# Test Loss vs Learning Rate Visualization

This notebook visualizes the test_loss data from MLflow experiments, plotting test_loss against learning rate on log-log scales.


In [None]:
import glob
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Notebook-wide plotting defaults: start from matplotlib's stock style,
# then override figure size, font size and grid appearance in one go.
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
    'axes.grid': True,
    'grid.alpha': 0.3,
})


In [None]:
def parse_mlflow_metrics(metrics_file: str) -> pd.DataFrame:
    """Parse a whitespace-separated metrics file into a DataFrame.

    Every line with at least three fields is read as ``timestamp value step``
    (int, float, int); lines with fewer fields are skipped.

    Args:
        metrics_file: Path of the metrics file to read.

    Returns:
        A DataFrame with columns ``timestamp``, ``value`` and ``step``, one row
        per parsed line. If the file cannot be read or a field cannot be
        converted, the error is printed and an empty DataFrame with the same
        three columns is returned (rows parsed so far are discarded).
    """
    columns = ['timestamp', 'value', 'step']
    rows = []
    try:
        with open(metrics_file, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue
                rows.append((int(parts[0]), float(parts[1]), int(parts[2])))
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: non-numeric field or
        # undecodable bytes. Anything else is a programming error and propagates.
        print(f"Error parsing {metrics_file}: {e}")
        return pd.DataFrame(columns=columns)

    # Passing the columns explicitly keeps the schema stable for empty files too.
    return pd.DataFrame(rows, columns=columns)

def get_parameter_value(params_file: str) -> Optional[float]:
    """Read a single numeric parameter from a file.

    The whole file content is stripped of surrounding whitespace and
    converted with ``float``.

    Args:
        params_file: Path of the parameter file.

    Returns:
        The parsed value, or ``None`` (after printing the error) when the file
        cannot be read or its content is not a valid float. Callers use the
        ``None`` result to skip runs that lack the parameter.
    """
    try:
        with open(params_file, 'r') as f:
            return float(f.read().strip())
    except (OSError, ValueError) as e:
        print(f"Error reading {params_file}: {e}")
        return None

def load_mlflow_data(mlruns_path: str) -> pd.DataFrame:
    """Gather the test_loss history of every usable run under ``<mlruns_path>/0``.

    A run is used only if it is a directory, both its ``params/lrs`` and
    ``params/lre`` files parse as floats, and its ``metrics/test_loss`` file
    exists and yields at least one row.

    Args:
        mlruns_path: Root of the MLflow file store.

    Returns:
        The concatenated metric rows with added ``run_id``, ``lrs``, ``lre``
        and ``learning_rate`` (a copy of ``lrs``) columns, or a column-less
        empty DataFrame when no run qualifies.
    """
    frames = []

    for run_dir in glob.glob(os.path.join(mlruns_path, "0", "*")):
        if not os.path.isdir(run_dir):
            continue

        # Both learning-rate parameters are required; skip the run otherwise.
        params_dir = os.path.join(run_dir, "params")
        lrs = get_parameter_value(os.path.join(params_dir, "lrs"))
        lre = get_parameter_value(os.path.join(params_dir, "lre"))
        if lrs is None or lre is None:
            continue

        metrics_path = os.path.join(run_dir, "metrics", "test_loss")
        if not os.path.exists(metrics_path):
            continue

        run_frame = parse_mlflow_metrics(metrics_path)
        if run_frame.empty:
            continue

        run_frame['run_id'] = os.path.basename(run_dir)
        run_frame['lrs'] = lrs
        run_frame['lre'] = lre
        run_frame['learning_rate'] = lrs  # lrs is the value plotted as "learning rate"
        frames.append(run_frame)

    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()


In [None]:
# Load the MLflow data
mlruns_path = "../results_train_garch/learning_rate/mlruns"
df = load_mlflow_data(mlruns_path)

# load_mlflow_data returns a column-less DataFrame when no run could be loaded;
# stop here with a clear message instead of a KeyError on 'run_id' below.
if df.empty:
    raise RuntimeError(f"No test_loss data found under {mlruns_path!r}")

print(f"Loaded data from {df['run_id'].nunique()} runs")
print(f"Total data points: {len(df)}")
print(f"Learning rate range: {df['learning_rate'].min():.2e} to {df['learning_rate'].max():.2e}")
print(f"Test loss range: {df['value'].min():.2e} to {df['value'].max():.2e}")

# Last expression of the cell: rendered by the notebook as a preview table
df.head()


In [None]:
# Collapse every run to one row: the last test_loss value in row order,
# together with the run's (constant) learning-rate parameters.
runs = df.groupby('run_id')
final_losses = runs.agg(
    value=('value', 'last'),
    learning_rate=('learning_rate', 'first'),
    lrs=('lrs', 'first'),
    lre=('lre', 'first'),
).reset_index()

print(f"Final test losses for {len(final_losses)} runs")
final_losses.head()


In [None]:
# Two side-by-side log-log views of test loss against learning rate
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Final test loss vs learning rate (one marker per run)
ax1.loglog(final_losses['learning_rate'], final_losses['value'], 'o', alpha=0.7, markersize=8)
ax1.set_xlabel('Learning Rate (log scale)')
ax1.set_ylabel('Final Test Loss (log scale)')
ax1.set_title('Final Test Loss vs Learning Rate')
ax1.grid(True, alpha=0.3)

# Add trend line: degree-1 least-squares fit in log10-log10 space,
# mapped back with 10**(...) so it can be drawn on the log-log axes
z = np.polyfit(np.log10(final_losses['learning_rate']), np.log10(final_losses['value']), 1)
p = np.poly1d(z)
ax1.loglog(final_losses['learning_rate'], 10**p(np.log10(final_losses['learning_rate'])),
           "r--", alpha=0.8, label=f'Trend (slope={z[0]:.2f})')
ax1.legend()

# Plot 2: All test loss values vs learning rate (colored by run)
unique_runs = df['run_id'].unique()
colors = plt.cm.tab20(np.linspace(0, 1, len(unique_runs)))

for i, run_id in enumerate(unique_runs):
    run_data = df[df['run_id'] == run_id]
    # Only the first 10 runs get a label; an empty label keeps a line out of the legend
    ax2.loglog(run_data['learning_rate'], run_data['value'],
               color=colors[i], alpha=0.6, linewidth=1, label=f'Run {i+1}' if i < 10 else "")

ax2.set_xlabel('Learning Rate (log scale)')
ax2.set_ylabel('Test Loss (log scale)')
ax2.set_title('Test Loss Evolution vs Learning Rate')
ax2.grid(True, alpha=0.3)

# Always draw the legend: at most 10 lines carry a label, so it stays small
# however many runs exist. (Previously it vanished entirely for > 10 runs.)
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()


In [None]:
# Four-panel breakdown of the per-run results:
#   top-left  scatter of final loss vs learning rate, sized/coloured by row count
#   top-right histogram of log10(learning rate)
#   bottom-left histogram of log10(final test loss)
#   bottom-right power-law fit of final loss against learning rate
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
(ax_scatter, ax_lr_hist), (ax_loss_hist, ax_fit) = axes

# step_count = number of recorded test_loss rows of each run; stored on
# final_losses so that later cells can report it as well.
step_counts = df.groupby('run_id').size()
final_losses['step_count'] = final_losses['run_id'].map(step_counts)
rows_per_run = final_losses['step_count']

scatter = ax_scatter.scatter(
    final_losses['learning_rate'], final_losses['value'],
    s=rows_per_run * 2, alpha=0.7,
    c=rows_per_run, cmap='viridis',
)
ax_scatter.set(
    xscale='log',
    yscale='log',
    xlabel='Learning Rate',
    ylabel='Final Test Loss',
    title='Final Test Loss vs Learning Rate\n(Size = Number of Steps)',
)
ax_scatter.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax_scatter, label='Number of Steps')

# The two histograms differ only in the column shown and in their captions.
histogram_panels = (
    (ax_lr_hist, 'learning_rate', 'Log10(Learning Rate)', 'Learning Rate Distribution'),
    (ax_loss_hist, 'value', 'Log10(Final Test Loss)', 'Final Test Loss Distribution'),
)
for hist_ax, column, x_caption, caption in histogram_panels:
    hist_ax.hist(np.log10(final_losses[column]), bins=20, alpha=0.7, edgecolor='black')
    hist_ax.set_xlabel(x_caption)
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title(caption)
    hist_ax.grid(True, alpha=0.3)

# Power law loss = a * lr^b, obtained as a straight-line fit in log10 space:
# the slope is the exponent b and the intercept is log10(a).
ax_fit.loglog(final_losses['learning_rate'], final_losses['value'], 'o', alpha=0.7)

log_lr = np.log10(final_losses['learning_rate'])
log_loss = np.log10(final_losses['value'])
z = np.polyfit(log_lr, log_loss, 1)
b = z[0]
a = 10**z[1]

# Evaluate the fitted curve on 100 log-spaced learning rates spanning the data
lowest_lr = final_losses['learning_rate'].min()
highest_lr = final_losses['learning_rate'].max()
lr_range = np.logspace(np.log10(lowest_lr), np.log10(highest_lr), 100)
fitted_loss = a * (lr_range ** b)
ax_fit.loglog(lr_range, fitted_loss, 'r--', linewidth=2,
              label=f'Power law: loss = {a:.2e} × lr^{b:.2f}')

ax_fit.set_xlabel('Learning Rate')
ax_fit.set_ylabel('Final Test Loss')
ax_fit.set_title('Power Law Fit')
ax_fit.legend()
ax_fit.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# R² of the linear fit in log space = squared Pearson correlation of the logs
r_squared = np.corrcoef(log_lr, log_loss)[0,1]**2
print(f"Power law fit: loss = {a:.2e} × lr^{b:.2f}")
print(f"R² = {r_squared:.3f}")


In [None]:
# Create a detailed trajectory plot showing how test loss evolves for different learning rates
fig, ax = plt.subplots(figsize=(14, 8))

# Sort runs by learning rate for better visualization
sorted_runs = final_losses.sort_values('learning_rate')['run_id'].values

# One normalisation shared by the line colours and the colorbar, so that a
# line's colour always matches the colorbar value of its learning rate.
# (The previous colour formula did not subtract the minimum, which pushed the
# values outside [0, 1], and it divided by zero for a single learning rate.)
lr_norm = plt.Normalize(vmin=np.log10(final_losses['learning_rate'].min()),
                        vmax=np.log10(final_losses['learning_rate'].max()))

# Plot trajectories for each run
for i, run_id in enumerate(sorted_runs):
    run_data = df[df['run_id'] == run_id].sort_values('step')
    lr = run_data['learning_rate'].iloc[0]

    # Color by learning rate (log scale)
    color = plt.cm.plasma(lr_norm(np.log10(lr)))

    # Label every third run only; empty labels are left out of the legend
    ax.loglog(run_data['learning_rate'], run_data['value'],
              color=color, alpha=0.7, linewidth=1.5,
              label=f'LR={lr:.2e}' if i % 3 == 0 else "")

ax.set_xlabel('Learning Rate (log scale)', fontsize=14)
ax.set_ylabel('Test Loss (log scale)', fontsize=14)
ax.set_title('Test Loss Trajectories for Different Learning Rates', fontsize=16)
ax.grid(True, alpha=0.3)

# Add colorbar (built from the same normalisation as the lines)
sm = plt.cm.ScalarMappable(cmap=plt.cm.plasma, norm=lr_norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label('Log10(Learning Rate)', fontsize=12)

# Only show some labels to avoid clutter
handles, labels = ax.get_legend_handles_labels()
if len(handles) > 10:
    ax.legend(handles[::len(handles)//10], labels[::len(handles)//10],
              bbox_to_anchor=(1.05, 1), loc='upper left')
else:
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()


In [None]:
# Summary statistics over the per-run final losses.
# NOTE: reads the 'step_count' column, which is added to final_losses by the
# detailed-analysis cell, so that cell has to run first.
def _print_run_details(run):
    """Print id, learning rate, final test loss and step count of one run row."""
    print(f"  Run ID: {run['run_id']}")
    print(f"  Learning Rate: {run['learning_rate']:.2e}")
    print(f"  Final Test Loss: {run['value']:.2e}")
    print(f"  Steps: {run['step_count']}")


final_values = final_losses['value']
final_lrs = final_losses['learning_rate']

print("=== SUMMARY STATISTICS ===")
print(f"Number of runs: {len(final_losses)}")
print(f"Learning rate range: {final_lrs.min():.2e} to {final_lrs.max():.2e}")
print(f"Final test loss range: {final_values.min():.2e} to {final_values.max():.2e}")
print(f"Mean final test loss: {final_values.mean():.2e}")
print(f"Std final test loss: {final_values.std():.2e}")

# Best = lowest final test loss
print("\nBest performing run:")
best_run = final_losses.loc[final_values.idxmin()]
_print_run_details(best_run)

# Worst = highest final test loss
print("\nWorst performing run:")
worst_run = final_losses.loc[final_values.idxmax()]
_print_run_details(worst_run)


In [None]:
# Create log-log plot of train_loss vs step, colored by lrs and lrd
def load_train_loss_data(mlruns_path: str) -> pd.DataFrame:
    """Collect the train_loss history of every usable run under ``<mlruns_path>/0``.

    A run is used only if it is a directory, its ``params/lrs`` and
    ``params/lrd`` files both parse as floats, and its ``metrics/train_loss``
    file exists and yields at least one row.

    Args:
        mlruns_path: Root of the MLflow file store.

    Returns:
        The concatenated metric rows with added ``run_id``, ``lrs`` and ``lrd``
        columns, or a column-less empty DataFrame when no run qualifies.
    """
    run_frames = []
    candidates = glob.glob(os.path.join(mlruns_path, "0", "*"))

    for run_dir in filter(os.path.isdir, candidates):
        # Start learning rate and decay factor; both must be readable
        lrs = get_parameter_value(os.path.join(run_dir, "params", "lrs"))
        lrd = get_parameter_value(os.path.join(run_dir, "params", "lrd"))
        if lrs is None or lrd is None:
            continue

        metrics_path = os.path.join(run_dir, "metrics", "train_loss")
        if not os.path.exists(metrics_path):
            continue

        history = parse_mlflow_metrics(metrics_path)
        if history.empty:
            continue

        run_frames.append(
            history.assign(run_id=os.path.basename(run_dir), lrs=lrs, lrd=lrd)
        )

    if run_frames:
        return pd.concat(run_frames, ignore_index=True)
    return pd.DataFrame()

# Load train loss data
train_df = load_train_loss_data(mlruns_path)

# load_train_loss_data returns a column-less DataFrame when nothing was loaded;
# fail with a clear message instead of a KeyError on 'run_id' below.
if train_df.empty:
    raise RuntimeError(f"No train_loss data found under {mlruns_path!r}")

print(f"Loaded train loss data from {train_df['run_id'].nunique()} runs")
print(f"Total train loss data points: {len(train_df)}")

# Create the log-log plot
fig, ax = plt.subplots(figsize=(12, 8))

# Get unique combinations of lrs and lrd for coloring
unique_combinations = train_df[['lrs', 'lrd']].drop_duplicates().sort_values(['lrs', 'lrd'])
colors = plt.cm.tab20(np.linspace(0, 1, len(unique_combinations)))

# Plot each run, one colour per (lrs, lrd) combination
for i, (_, combo) in enumerate(unique_combinations.iterrows()):
    lrs_val, lrd_val = combo['lrs'], combo['lrd']
    run_data = train_df[(train_df['lrs'] == lrs_val) & (train_df['lrd'] == lrd_val)]
    combo_label = f'lrs={lrs_val:.2e}, lrd={lrd_val:.2e}'

    # Several runs may share a combination: label only the first one so the
    # legend has exactly one entry per combination (empty labels are skipped).
    for j, (run_id, run_group) in enumerate(run_data.groupby('run_id')):
        ax.loglog(run_group['step'], run_group['value'],
                  color=colors[i], alpha=0.7, linewidth=1.5,
                  label=combo_label if j == 0 else "")

ax.set_xlabel('Step (log scale)', fontsize=14)
ax.set_ylabel('Train Loss (log scale)', fontsize=14)
ax.set_title('Train Loss vs Step (Log-Log Plot)', fontsize=16)
ax.grid(True, alpha=0.3)

# Add legend (limit to avoid clutter)
handles, labels = ax.get_legend_handles_labels()
if len(handles) > 15:
    # Show every nth label to avoid overcrowding
    step = len(handles) // 15
    ax.legend(handles[::step], labels[::step],
              bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
else:
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

plt.tight_layout()
plt.show()

# Print summary of unique lrs/lrd combinations
print("\nUnique lrs/lrd combinations:")
for _, combo in unique_combinations.iterrows():
    in_combo = (train_df['lrs'] == combo['lrs']) & (train_df['lrd'] == combo['lrd'])
    count = train_df.loc[in_combo, 'run_id'].nunique()
    print(f"  lrs={combo['lrs']:.2e}, lrd={combo['lrd']:.2e}: {count} runs")


In [None]:
# Create plot of learning_rate vs step (linear x-axis, log y-axis)
def load_learning_rate_data(mlruns_path: str) -> pd.DataFrame:
    """Collect the learning_rate history of every usable run under ``<mlruns_path>/0``.

    A run is used only if it is a directory, its ``params/lrs`` and
    ``params/lrd`` files both parse as floats, and its
    ``metrics/learning_rate`` file exists and yields at least one row.

    Args:
        mlruns_path: Root of the MLflow file store.

    Returns:
        The concatenated metric rows with added ``run_id``, ``lrs`` and ``lrd``
        columns, or a column-less empty DataFrame when no run qualifies.
    """

    def _read_run(run_dir):
        """Return the annotated learning_rate rows of one run, or None if unusable."""
        if not os.path.isdir(run_dir):
            return None

        lrs = get_parameter_value(os.path.join(run_dir, "params", "lrs"))
        lrd = get_parameter_value(os.path.join(run_dir, "params", "lrd"))
        if lrs is None or lrd is None:
            return None

        lr_file = os.path.join(run_dir, "metrics", "learning_rate")
        if not os.path.exists(lr_file):
            return None

        lr_rows = parse_mlflow_metrics(lr_file)
        if lr_rows.empty:
            return None

        lr_rows['run_id'] = os.path.basename(run_dir)
        lr_rows['lrs'] = lrs
        lr_rows['lrd'] = lrd
        return lr_rows

    run_dirs = glob.glob(os.path.join(mlruns_path, "0", "*"))
    loaded = (_read_run(run_dir) for run_dir in run_dirs)
    usable = [rows for rows in loaded if rows is not None]

    if not usable:
        return pd.DataFrame()
    return pd.concat(usable, ignore_index=True)

# Load learning rate data
lr_df = load_learning_rate_data(mlruns_path)

# load_learning_rate_data returns a column-less DataFrame when nothing was
# loaded; fail with a clear message instead of a KeyError on 'run_id' below.
if lr_df.empty:
    raise RuntimeError(f"No learning_rate data found under {mlruns_path!r}")

print(f"Loaded learning rate data from {lr_df['run_id'].nunique()} runs")
print(f"Total learning rate data points: {len(lr_df)}")

# Create the plot
fig, ax = plt.subplots(figsize=(12, 8))

# Get unique combinations of lrs and lrd for coloring
unique_combinations = lr_df[['lrs', 'lrd']].drop_duplicates().sort_values(['lrs', 'lrd'])
colors = plt.cm.tab20(np.linspace(0, 1, len(unique_combinations)))

# Plot each run, one colour per (lrs, lrd) combination
for i, (_, combo) in enumerate(unique_combinations.iterrows()):
    lrs_val, lrd_val = combo['lrs'], combo['lrd']
    run_data = lr_df[(lr_df['lrs'] == lrs_val) & (lr_df['lrd'] == lrd_val)]
    combo_label = f'lrs={lrs_val:.2e}, lrd={lrd_val:.2e}'

    # Several runs may share a combination: label only the first one so the
    # legend has exactly one entry per combination (empty labels are skipped).
    for j, (run_id, run_group) in enumerate(run_data.groupby('run_id')):
        ax.semilogy(run_group['step'], run_group['value'],
                    color=colors[i], alpha=0.7, linewidth=1.5,
                    label=combo_label if j == 0 else "")

ax.set_xlabel('Step', fontsize=14)
ax.set_ylabel('Learning Rate (log scale)', fontsize=14)
ax.set_title('Learning Rate vs Step (Linear-Log Plot)', fontsize=16)
ax.grid(True, alpha=0.3)

# Add legend (limit to avoid clutter)
handles, labels = ax.get_legend_handles_labels()
if len(handles) > 15:
    # Show every nth label to avoid overcrowding
    step = len(handles) // 15
    ax.legend(handles[::step], labels[::step],
              bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
else:
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

plt.tight_layout()
plt.show()

# Print summary of learning rate evolution: first/last recorded value per run,
# averaged over the runs that share a combination
print("\nLearning rate evolution summary:")
for _, combo in unique_combinations.iterrows():
    combo_data = lr_df[(lr_df['lrs'] == combo['lrs']) & (lr_df['lrd'] == combo['lrd'])]
    per_run_lr = combo_data.groupby('run_id')['value']
    initial_lr = per_run_lr.first().mean()
    final_lr = per_run_lr.last().mean()
    print(f"  lrs={combo['lrs']:.2e}, lrd={combo['lrd']:.2e}: {combo_data['run_id'].nunique()} runs")
    print(f"    Initial LR: {initial_lr:.2e}, Final LR: {final_lr:.2e}")


In [None]:
# Prepare the data for the comparison figure: a hand-picked set of (lrs, lrd)
# settings, each with a fixed colour, selected from both the learning_rate
# history (lr_df) and the test_loss history (df).

# First, let's check what columns we have in our dataframes
print("Columns in df (test_loss data):", df.columns.tolist())
print("Columns in lr_df (learning_rate data):", lr_df.columns.tolist())

# df was loaded without an lrd column, so look lrd (and lrs) up per run in
# lr_df. Runs that have no learning_rate history end up with NaN here.
run_to_lrd = lr_df.groupby('run_id')['lrd'].first().to_dict()
run_to_lrs = lr_df.groupby('run_id')['lrs'].first().to_dict()

# Add lrd and lrs columns to df
df['lrd'] = df['run_id'].map(run_to_lrd)
df['lrs'] = df['run_id'].map(run_to_lrs)

print(f"Added lrd and lrs columns to df. Unique lrd values: {df['lrd'].unique()}")

# Show which lrs values actually exist for lrd=1.0, so the targets below can
# be checked against the available runs
available_lrs_lrd1 = sorted(lr_df[lr_df['lrd'] == 1.0]['lrs'].unique())
print(f"Available lrs values for lrd=1.0: {available_lrs_lrd1}")

# The specific curves to show. Selection is by exact float equality, so the
# values must match the logged parameters exactly.
target_curves = [
    {'lrs': 1e-3, 'lrd': 1.0, 'label': 'lrs=1E-3, lrd=1.0', 'color': '#1f77b4'},  # blue
    {'lrs': 2e-5, 'lrd': 1.0, 'label': 'lrs=2E-5, lrd=1.0', 'color': '#ff7f0e'},  # orange
    {'lrs': 2e-3, 'lrd': 0.999, 'label': 'lrs=2E-3, lrd=0.999', 'color': 'red'},  # red
    {'lrs': 2e-3, 'lrd': 0.998, 'label': 'lrs=2E-3, lrd=0.998', 'color': '#9467bd'}  # purple
]

print("Target curves:")
for curve in target_curves:
    print(f"  {curve['label']} - {curve['color']}")

# Filter data for these specific curves
filtered_lr_data = []
filtered_test_data = []

for curve in target_curves:
    lrs_val = curve['lrs']
    lrd_val = curve['lrd']

    # Get learning rate data. copy() makes the selection an independent frame,
    # so adding the label/colour columns never writes into a slice of lr_df.
    lr_data = lr_df[(lr_df['lrs'] == lrs_val) & (lr_df['lrd'] == lrd_val)].copy()
    if not lr_data.empty:
        lr_data['curve_label'] = curve['label']
        lr_data['curve_color'] = curve['color']
        filtered_lr_data.append(lr_data)

    # Get test loss data (same reasoning for copy())
    test_data = df[(df['lrs'] == lrs_val) & (df['lrd'] == lrd_val)].copy()
    if not test_data.empty:
        test_data['curve_label'] = curve['label']
        test_data['curve_color'] = curve['color']
        filtered_test_data.append(test_data)

print(f"\nFound data for {len(filtered_lr_data)} learning rate curves and {len(filtered_test_data)} test loss curves")

# Create the 2x2 figure: ax1/ax2 on the top row, ax3/ax4 on the bottom row
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))

# Learning-rate histories of the no-decay runs (lrd == 1.0). This frame was
# referenced here without ever being defined, which raised a NameError.
no_decay_lr_data = lr_df[lr_df['lrd'] == 1.0]

# Get unique lrs values for no-decay curves, sorted
unique_lrs_no_decay = sorted(no_decay_lr_data['lrs'].unique())
print(f"Unique lrs values (no decay): {unique_lrs_no_decay}")

# Colour scheme running blue -> purple -> green -> orange, one colour per
# no-decay lrs value. Up to four values use hand-picked colours; beyond that
# (and for zero values, which yields an empty list) the colours are sampled
# from a gradient through the same four anchors.
n_colors = len(unique_lrs_no_decay)
custom_colors = ['#1f77b4', '#9467bd', '#2ca02c', '#ff7f0e']  # blue, purple, green, orange
preset_palettes = {
    1: ['#1f77b4'],                        # blue
    2: ['#1f77b4', '#ff7f0e'],             # blue, orange
    3: ['#1f77b4', '#9467bd', '#ff7f0e'],  # blue, purple, orange
    4: list(custom_colors),                # blue, purple, green, orange
}
if n_colors in preset_palettes:
    colors_list = preset_palettes[n_colors]
else:
    import matplotlib.colors as mcolors
    cmap = mcolors.LinearSegmentedColormap.from_list('custom', custom_colors)
    colors_list = [cmap(i / (n_colors - 1)) for i in range(n_colors)]

print(f"Using {len(colors_list)} colors: {colors_list}")

# Draw the selected curves. Each entry of filtered_lr_data / filtered_test_data
# holds the rows of one target curve, with its label and colour repeated in the
# 'curve_label' / 'curve_color' columns.

# Learning-rate curves: ax2 (semilogy) and ax4 (log-log)
for curve_data in filtered_lr_data:
    curve_label = curve_data['curve_label'].iloc[0]
    curve_color = curve_data['curve_color'].iloc[0]

    for j, (run_id, run_group) in enumerate(curve_data.groupby('run_id')):
        # Label only the first run of a curve so that a curve backed by
        # several runs yields a single legend entry (empty labels are skipped)
        run_label = curve_label if j == 0 else ""
        ax2.semilogy(run_group['step'], run_group['value'],
                     color=curve_color, alpha=0.8, linewidth=1,
                     label=run_label)
        ax4.loglog(run_group['step'], run_group['value'],
                   color=curve_color, alpha=0.8, linewidth=1,
                   label=run_label)

# Test-loss curves: ax1 (semilogy) and ax3 (log-log)
for curve_data in filtered_test_data:
    curve_label = curve_data['curve_label'].iloc[0]
    curve_color = curve_data['curve_color'].iloc[0]

    for j, (run_id, run_group) in enumerate(curve_data.groupby('run_id')):
        run_label = curve_label if j == 0 else ""
        ax1.semilogy(run_group['step'], run_group['value'],
                     color=curve_color, alpha=0.8, linewidth=1,
                     label=run_label)
        # Drawn once only: a second, duplicated loglog call used to overplot
        # every curve on ax3 at the default line width.
        ax3.loglog(run_group['step'], run_group['value'],
                   color=curve_color, alpha=0.8, linewidth=1,
                   label=run_label)


# Configure the four panels. All share the same x label, grid and spine
# treatment and differ only in their title; no y-axis labels are set.
panel_titles = (
    (ax1, 'Test Loss (Semi-Log)'),      # top left
    (ax2, 'Learning Rate (Semi-Log)'),  # top right
    (ax3, 'Test Loss (Log-Log)'),       # bottom left
    (ax4, 'Learning Rate (Log-Log)'),   # bottom right
)
for ax, panel_title in panel_titles:
    ax.set_xlabel('Steps')
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)
    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# The bottom-left panel ends up with a heavier grid than the other three
ax3.grid(True, alpha=0.5)

# Legend entries of the top right plot. The legend call itself is disabled:
handles, labels = ax2.get_legend_handles_labels()
#ax2.legend(handles, labels, loc='upper right',
#          frameon=True, fancybox=True, shadow=True,
#          facecolor='white', edgecolor='none', framealpha=0.8, fontsize=8)

# Legend entries of the top left plot, split into "red" entries (label
# contains 'Best:') and "blue" entries (everything else).
# NOTE(review): none of these lists is passed to a legend call in this cell.
handles, labels = ax1.get_legend_handles_labels()

blue_handles, blue_labels = [], []
red_handles, red_labels = [], []

for handle, label in zip(handles, labels):
    if 'Best:' in label:
        handle_bucket, label_bucket = red_handles, red_labels
    else:
        handle_bucket, label_bucket = blue_handles, blue_labels
    handle_bucket.append(handle)
    label_bucket.append(label)

# Blue entries in reverse order, red entries after them
reversed_handles = list(reversed(blue_handles)) + red_handles
reversed_labels = list(reversed(blue_labels)) + red_labels

plt.tight_layout()
# NOTE(review): 'leanring_rate.png' looks like a misspelling of "learning";
# kept as-is because other material may reference this exact filename.
plt.savefig('leanring_rate.png', dpi=300)
plt.show()

