Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 158 additions & 5 deletions public/machine_learning/visualize_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,83 @@
from simstack.models.charts_artifact import (
ChartArtifactModel,
AGBarSeriesConfig,
AGHeatmapSeriesConfig,
AGRangeBarSeriesConfig,
AGChartAxisConfig,
AGChartTitleConfig
AGChartTitleConfig,
create_simple_heatmap_chart,
)
from simstack.models.table_artifact import TableArtifactModel, AGGridColumnDef
from simstack.models.pandas_model import PandasModel

from simstack.models.charts_artifact import create_simple_scatter_chart


def _configure_heatmap_series(
heatmap_chart,
*,
x_name: str,
y_name: str,
color_name: str,
color_domain=None,
color_range=None,
):
if heatmap_chart.series and isinstance(heatmap_chart.series[0], AGHeatmapSeriesConfig):
heatmap_chart.series[0].xName = x_name
heatmap_chart.series[0].yName = y_name
heatmap_chart.series[0].colorName = color_name
if color_domain is not None:
heatmap_chart.series[0].colorDomain = color_domain
if color_range is not None:
heatmap_chart.series[0].colorRange = color_range


async def _save_correlation_heatmap(
corr_frame,
title: str,
task_id,
node_runner,
charts: list,
x_name: str = "Column",
y_name: str = "Row",
color_name: str = "Correlation",
):
heatmap_data = []
for row_name in corr_frame.index:
for col_name in corr_frame.columns:
value = corr_frame.loc[row_name, col_name]
if value != value:
continue
heatmap_data.append(
{
"x_feature": str(col_name),
"y_feature": str(row_name),
"correlation": float(value),
}
)

heatmap_chart = create_simple_heatmap_chart(
data=heatmap_data,
x_key="x_feature",
y_key="y_feature",
color_key="correlation",
title=title,
parent_id=ObjectId(task_id) if task_id else None,
)
_configure_heatmap_series(
heatmap_chart,
x_name=x_name,
y_name=y_name,
color_name=color_name,
color_domain=[-1.0, 1.0],
color_range=["#2166ac", "#f7f7f7", "#b2182b"],
)

await context.db.save(heatmap_chart)
charts.append(heatmap_chart)
node_runner.info(f"Saved heatmap: {title}")


async def _visualize_strain_vs_concentration_internal(dataset: PandasModel, **kwargs):
node_runner = kwargs.get("node_runner")
task_id = kwargs.get("task_id")
Expand Down Expand Up @@ -82,6 +150,36 @@ async def _visualize_strain_vs_concentration_internal(dataset: PandasModel, **kw
charts.append(corr_table)
node_runner.info("Saved correlation matrix table")

model_feature_cols = [
col
for col in [
"youngs_modulus_MPa",
"yield_strength_MPa",
"ultimate_strength_MPa",
"fracture_stress_MPa",
"fracture_strain",
"uniform_strain",
]
if col in df.columns
]

if model_feature_cols:
spearman_feature_target_corr = (
df[model_feature_cols + impurity_cols]
.corr(method="spearman")
.loc[model_feature_cols, impurity_cols]
)
await _save_correlation_heatmap(
spearman_feature_target_corr,
title="Spearman Correlation: Concentrations vs Stress-Strain Features",
task_id=task_id,
node_runner=node_runner,
charts=charts,
x_name="Concentration",
y_name="Stress-Strain Feature",
color_name="Spearman correlation",
)

if hasattr(node_runner, 'result'):
node_runner.result = {"charts_count": len(charts), "correlation_matrix": corr.to_dict()}
return charts
Expand All @@ -92,6 +190,55 @@ async def visualize_strain_vs_concentration(dataset: PandasModel, **kwargs):
return await _visualize_strain_vs_concentration_internal(dataset, **kwargs)


async def _visualize_impurity_ranges_internal(dataset: PandasModel, **kwargs):
node_runner = kwargs.get("node_runner")
task_id = kwargs.get("task_id")

df = dataset.table
impurity_cols = ["C_wt_percent", "Mn_wt_percent", "P_wt_percent", "S_wt_percent"]

min_values = df[impurity_cols].min()
max_values = df[impurity_cols].max()

chart_data = []
for col in impurity_cols:
chart_data.append({
"impurity": col.split("_")[0],
"min_value": float(min_values[col]),
"max_value": float(max_values[col])
})

range_series = [
AGRangeBarSeriesConfig(
type="range-bar",
xKey="impurity",
yLowKey="min_value",
yHighKey="max_value",
title="Impurity Concentration Range",
data=chart_data
)
]

axes = [
AGChartAxisConfig(type="category", position="bottom", title="Impurity"),
AGChartAxisConfig(type="number", position="left", title="Concentration (wt%)")
]

range_chart = ChartArtifactModel(
title=AGChartTitleConfig(text="Impurity Concentration Min/Max Ranges"),
series=range_series,
axes=axes,
data=chart_data
)

if task_id:
range_chart.parent_id = ObjectId(task_id)
await context.db.save(range_chart)
node_runner.info("Saved impurity min/max range chart")

return range_chart


async def _visualize_impurity_maxima_internal(dataset: PandasModel, **kwargs):
node_runner = kwargs.get("node_runner")
task_id = kwargs.get("task_id")
Expand Down Expand Up @@ -154,9 +301,14 @@ async def _visualize_impurity_maxima_internal(dataset: PandasModel, **kwargs):

@node(parameters=Parameters(force_rerun=True))
async def visualize_impurity_maxima(dataset: PandasModel, **kwargs):
chart = await _visualize_impurity_maxima_internal(dataset, **kwargs)
kwargs.get("node_runner").chart = chart
return kwargs.get("node_runner").succeed()
max_chart = await _visualize_impurity_maxima_internal(dataset, **kwargs)

# We can remove old max_chart and set this new range_chart as node_runner.chart if needed
range_chart = await _visualize_impurity_ranges_internal(dataset, **kwargs)

node_runner = kwargs.get("node_runner")
node_runner.chart = max_chart
return node_runner.succeed()


async def main():
Expand Down Expand Up @@ -188,6 +340,7 @@ def fail(self, msg): print(f"FAIL: {msg}")

print("\n--- Visualizing Impurity Maxima ---")
await _visualize_impurity_maxima_internal(dataset, **kwargs)
await _visualize_impurity_ranges_internal(dataset, **kwargs)

if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())