# Comparing the automated-ML workflow with baseline models

In [8]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# The exported figure is written into this folder; create it up front.
os.makedirs('ArticleFigures', exist_ok=True)

# Trait codes in display order. Every metric list in `data` is positional and
# follows this same order.
traits = ['PALK', 'OLA', 'ECO', 'LINO', 'LINK', 'MUFA', 'PUFA']

# Table 5 numbers: model -> metric -> one value per trait.
# NOTE(review): for TPOT, RMSE is below MAE for six of the seven traits, and
# Linear Regression reports MAE 0.709 against RMSE 0.316 for MUFA. RMSE can
# never be smaller than MAE when both are computed on the same predictions, so
# these entries look transposed or mistyped - confirm against the source table.
data = {
    'TPOT': {
        'MAE': [0.020, 0.152, 0.134, 0.457, 0.253, 0.198, 0.540],
        'RMSE': [0.012, 0.114, 0.116, 0.327, 0.185, 0.203, 0.467],
        'R2': [0.992, 0.986, 0.971, 0.971, 0.977, 0.860, 0.849],
    },
    'Linear Regression': {
        'MAE': [0.062, 0.311, 0.191, 0.586, 0.535, 0.709, 0.691],
        'RMSE': [0.082, 0.393, 0.235, 0.753, 0.680, 0.316, 0.878],
        'R2': [0.635, 0.836, 0.741, 0.805, 0.677, 0.621, 0.460],
    },
    'Ridge': {
        'MAE': [0.062, 0.310, 0.192, 0.585, 0.534, 0.258, 0.691],
        'RMSE': [0.083, 0.392, 0.235, 0.753, 0.679, 0.316, 0.878],
        'R2': [0.627, 0.837, 0.741, 0.805, 0.679, 0.623, 0.460],
    },
    'Decision Tree': {
        'MAE': [0.030, 0.241, 0.163, 0.495, 0.375, 0.260, 0.648],
        'RMSE': [0.042, 0.411, 0.232, 0.707, 0.536, 0.331, 0.850],
        'R2': [0.904, 0.820, 0.748, 0.828, 0.800, 0.584, 0.494],
    },
    'Random Forest': {
        'MAE': [0.020, 0.159, 0.140, 0.473, 0.265, 0.193, 0.540],
        'RMSE': [0.028, 0.214, 0.199, 0.669, 0.340, 0.249, 0.722],
        'R2': [0.958, 0.951, 0.814, 0.846, 0.919, 0.766, 0.635],
    },
    'XGBoost': {
        'MAE': [0.019, 0.169, 0.142, 0.464, 0.237, 0.193, 0.535],
        'RMSE': [0.026, 0.234, 0.201, 0.667, 0.307, 0.250, 0.710],
        'R2': [0.963, 0.942, 0.811, 0.847, 0.934, 0.764, 0.647],
    },
}

# Flatten to long format: one record per (model, trait) pair, models in
# insertion order and traits in `traits` order within each model.
rows = [
    {
        'Model': model_name,
        'Trait': tr,
        'MAE': metrics['MAE'][i],
        'RMSE': metrics['RMSE'][i],
        'R2': metrics['R2'][i],
    }
    for model_name, metrics in data.items()
    for i, tr in enumerate(traits)
]
df = pd.DataFrame(rows)

# Styling choices
#
# `palette_name` is what gets handed to seaborn when `custom_hex` is empty. It
# may be a named palette ("colorblind", "muted", the local alias "vibrant", or
# any other name seaborn accepts) or, as here, an explicit list of colours.
# A non-empty `custom_hex` list bypasses seaborn and is used as-is.
palette_name = ["#d81159", "#ffbc42", "#0496ff", "#c86bfa", "#ffc2e2", "#ff6200"]
custom_hex = None

# Fonts and sizes
font_family = 'sans-serif'
font_name = 'DejaVu Sans'
title_size = 16
axis_label_size = 13
tick_label_size = 11
legend_size = 11
caption_size = 10

# Apply rcParams
# set_context() only honours keys that belong to seaborn's context definition
# (font/label sizes, line widths, ...) and silently drops anything else, so the
# face colour and font family are passed to set_style(), which accepts them.
sns.set_style("whitegrid", rc={
    "axes.facecolor": "#FFFFFF",
    "font.family": font_family,
})
sns.set_context("talk", rc={
    "font.size": tick_label_size,
    "axes.titlesize": title_size,
    "axes.labelsize": axis_label_size,
})
# Put the preferred font first in the sans-serif fallback list without
# leaving a second copy of it further down.
plt.rcParams['font.sans-serif'] = [font_name] + [
    name for name in plt.rcParams.get('font.sans-serif', []) if name != font_name
]
plt.rcParams['legend.frameon'] = True
plt.rcParams['legend.framealpha'] = 0.9
plt.rcParams['axes.edgecolor'] = '#333333'
plt.rcParams['axes.linewidth'] = 0.8

# Choose palette
if custom_hex:
    palette = custom_hex
else:
    # "vibrant" is a local alias for the tab10 palette; every other value
    # (palette name or list of colours) is passed to seaborn unchanged.
    # Comparing a list against the string is simply False, so list input
    # falls through to the pass-through case.
    palette_source = "tab10" if palette_name == "vibrant" else palette_name
    palette = sns.color_palette(palette_source, n_colors=6)

# Plotting order and layout
models_order = ['TPOT', 'Linear Regression', 'Ridge', 'Decision Tree', 'Random Forest', 'XGBoost']
metric_names = ['MAE', 'RMSE', 'R2']
# Y-axis labels, parallel to metric_names. The superscript two is written as
# an escape so the label cannot be corrupted by a file-encoding round trip
# (it previously rendered as the mojibake "RÂ²").
ylabels = ['Mean Absolute Error', 'Root Mean Squared Error', 'R\u00b2']

# Grouped-bar geometry. It depends only on the trait and model counts, so it
# is computed once and shared by all three panels: each trait gets a cluster
# `total_width` wide, split evenly between the models and centred on the tick.
x = np.arange(len(traits))
total_width = 0.82
n_models = len(models_order)
bar_width = total_width / n_models
offsets = np.linspace(-total_width / 2 + bar_width / 2,
                      total_width / 2 - bar_width / 2,
                      n_models)

fig, axes = plt.subplots(1, 3, figsize=(20, 6), sharey=False)
fig.subplots_adjust(wspace=0.28, top=0.88)

# One panel per metric.
for ax, metric, ylabel in zip(axes, metric_names, ylabels):
    # Wide table for this metric: rows are traits, columns are models, both
    # forced into the display order defined above.
    plot_df = df.pivot(index='Trait', columns='Model', values=metric).loc[traits, models_order]

    # Thin dark-grey outline keeps adjacent bars distinguishable.
    for i, model in enumerate(models_order):
        ax.bar(x + offsets[i], plot_df[model].values,
               width=bar_width,
               label=model,
               color=palette[i % len(palette)],  # wrap if fewer colours than models
               edgecolor='#444444',
               linewidth=0.4,
               alpha=0.90)

    ax.set_xticks(x)
    ax.set_xticklabels(traits, fontsize=tick_label_size)
    ax.set_xlabel('Trait', fontsize=axis_label_size)
    ax.set_ylabel(ylabel, fontsize=axis_label_size)
    ax.set_title(metric, fontsize=title_size, fontweight='semibold')
    ax.grid(axis='y', linestyle='--', linewidth=0.7, alpha=0.6)

    if metric == 'R2':
        # Fixed 0-1 scale (with a little headroom) and a tick every 0.2.
        ax.set_ylim(0, 1.02)
        ax.yaxis.set_major_locator(plt.MultipleLocator(0.2))

# One legend for the whole figure, built from the last panel's bars (every
# panel plots the same models) and anchored just outside its right edge.
legend = axes[-1].legend(title="Model", fontsize=legend_size, title_fontsize=legend_size,
                         loc='upper left', bbox_to_anchor=(1.02, 1.0))
legend.get_frame().set_edgecolor('#CCCCCC')
legend.get_frame().set_linewidth(0.6)

# Export high-resolution PNG at 600 dpi. bbox_inches='tight' expands the saved
# area so the legend outside the axes is not cropped.
out_path = os.path.join('ArticleFigures', 'Model_comparison_metrics.png')
fig.savefig(out_path, dpi=600, bbox_inches='tight', facecolor='white')
plt.close(fig)

print(f"Saved improved figure to: {out_path}")


Saved improved figure to: ArticleFigures\Model_comparison_metrics.png
