diff --git a/public/machine_learning/visualize_training_data.py b/public/machine_learning/visualize_training_data.py index 4709d73..455b351 100644 --- a/public/machine_learning/visualize_training_data.py +++ b/public/machine_learning/visualize_training_data.py @@ -7,8 +7,11 @@ from simstack.models.charts_artifact import ( ChartArtifactModel, AGBarSeriesConfig, + AGHeatmapSeriesConfig, + AGRangeBarSeriesConfig, AGChartAxisConfig, - AGChartTitleConfig + AGChartTitleConfig, + create_simple_heatmap_chart, ) from simstack.models.table_artifact import TableArtifactModel, AGGridColumnDef from simstack.models.pandas_model import PandasModel @@ -16,6 +19,71 @@ from simstack.models.charts_artifact import create_simple_scatter_chart +def _configure_heatmap_series( + heatmap_chart, + *, + x_name: str, + y_name: str, + color_name: str, + color_domain=None, + color_range=None, +): + if heatmap_chart.series and isinstance(heatmap_chart.series[0], AGHeatmapSeriesConfig): + heatmap_chart.series[0].xName = x_name + heatmap_chart.series[0].yName = y_name + heatmap_chart.series[0].colorName = color_name + if color_domain is not None: + heatmap_chart.series[0].colorDomain = color_domain + if color_range is not None: + heatmap_chart.series[0].colorRange = color_range + + +async def _save_correlation_heatmap( + corr_frame, + title: str, + task_id, + node_runner, + charts: list, + x_name: str = "Column", + y_name: str = "Row", + color_name: str = "Correlation", +): + heatmap_data = [] + for row_name in corr_frame.index: + for col_name in corr_frame.columns: + value = corr_frame.loc[row_name, col_name] + if value != value: + continue + heatmap_data.append( + { + "x_feature": str(col_name), + "y_feature": str(row_name), + "correlation": float(value), + } + ) + + heatmap_chart = create_simple_heatmap_chart( + data=heatmap_data, + x_key="x_feature", + y_key="y_feature", + color_key="correlation", + title=title, + parent_id=ObjectId(task_id) if task_id else None, + ) + _configure_heatmap_series( + heatmap_chart, + x_name=x_name, + y_name=y_name, + color_name=color_name, + color_domain=[-1.0, 1.0], + color_range=["#2166ac", "#f7f7f7", "#b2182b"], + ) + + await context.db.save(heatmap_chart) + charts.append(heatmap_chart) + node_runner.info(f"Saved heatmap: {title}") + + async def _visualize_strain_vs_concentration_internal(dataset: PandasModel, **kwargs): node_runner = kwargs.get("node_runner") task_id = kwargs.get("task_id") @@ -82,6 +150,36 @@ async def _visualize_strain_vs_concentration_internal(dataset: PandasModel, **kw charts.append(corr_table) node_runner.info("Saved correlation matrix table") + model_feature_cols = [ + col + for col in [ + "youngs_modulus_MPa", + "yield_strength_MPa", + "ultimate_strength_MPa", + "fracture_stress_MPa", + "fracture_strain", + "uniform_strain", + ] + if col in df.columns + ] + + if model_feature_cols: + spearman_feature_target_corr = ( + df[model_feature_cols + impurity_cols] + .corr(method="spearman") + .loc[model_feature_cols, impurity_cols] + ) + await _save_correlation_heatmap( + spearman_feature_target_corr, + title="Spearman Correlation: Concentrations vs Stress-Strain Features", + task_id=task_id, + node_runner=node_runner, + charts=charts, + x_name="Concentration", + y_name="Stress-Strain Feature", + color_name="Spearman correlation", + ) + if hasattr(node_runner, 'result'): node_runner.result = {"charts_count": len(charts), "correlation_matrix": corr.to_dict()} return charts @@ -92,6 +190,55 @@ async def visualize_strain_vs_concentration(dataset: PandasModel, **kwargs): return await _visualize_strain_vs_concentration_internal(dataset, **kwargs) +async def _visualize_impurity_ranges_internal(dataset: PandasModel, **kwargs): + node_runner = kwargs.get("node_runner") + task_id = kwargs.get("task_id") + + df = dataset.table + impurity_cols = ["C_wt_percent", "Mn_wt_percent", "P_wt_percent", "S_wt_percent"] + + min_values = df[impurity_cols].min() + max_values = df[impurity_cols].max() + + chart_data = [] + for col in impurity_cols: + chart_data.append({ + "impurity": col.split("_")[0], + "min_value": float(min_values[col]), + "max_value": float(max_values[col]) + }) + + range_series = [ + AGRangeBarSeriesConfig( + type="range-bar", + xKey="impurity", + yLowKey="min_value", + yHighKey="max_value", + title="Impurity Concentration Range", + data=chart_data + ) + ] + + axes = [ + AGChartAxisConfig(type="category", position="bottom", title="Impurity"), + AGChartAxisConfig(type="number", position="left", title="Concentration (wt%)") + ] + + range_chart = ChartArtifactModel( + title=AGChartTitleConfig(text="Impurity Concentration Min/Max Ranges"), + series=range_series, + axes=axes, + data=chart_data + ) + + if task_id: + range_chart.parent_id = ObjectId(task_id) + await context.db.save(range_chart) + node_runner.info("Saved impurity min/max range chart") + + return range_chart + + async def _visualize_impurity_maxima_internal(dataset: PandasModel, **kwargs): node_runner = kwargs.get("node_runner") task_id = kwargs.get("task_id") @@ -154,9 +301,14 @@ async def _visualize_impurity_maxima_internal(dataset: PandasModel, **kwargs): @node(parameters=Parameters(force_rerun=True)) async def visualize_impurity_maxima(dataset: PandasModel, **kwargs): - chart = await _visualize_impurity_maxima_internal(dataset, **kwargs) - kwargs.get("node_runner").chart = chart - return kwargs.get("node_runner").succeed() + max_chart = await _visualize_impurity_maxima_internal(dataset, **kwargs) + + # We can remove old max_chart and set this new range_chart as node_runner.chart if needed + range_chart = await _visualize_impurity_ranges_internal(dataset, **kwargs) + + node_runner = kwargs.get("node_runner") + node_runner.chart = max_chart + return node_runner.succeed() async def main(): @@ -188,6 +340,7 @@ def fail(self, msg): print(f"FAIL: {msg}") print("\n--- Visualizing Impurity Maxima ---") await _visualize_impurity_maxima_internal(dataset, **kwargs) + await _visualize_impurity_ranges_internal(dataset, **kwargs) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())