# Foundation Model Training Time Comparison

Visualize training time differences across foundation models using gradient color intensity.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

## Load and Process Data

In [None]:
def time_to_minutes(time_str: str) -> float:
    """Return the duration encoded as ``H:MM:SS`` expressed in minutes.

    The string must consist of exactly three colon-separated integer
    fields (hours, minutes, seconds); anything else raises ``ValueError``.
    """
    hrs, mins, secs = (int(field) for field in time_str.split(':'))
    return hrs * 60 + mins + secs / 60

# Read the benchmark results; skiprows=1 drops the first line of the file
# before pandas looks for the header row.
csv_path = 'data/experiments/foundational_train_time_comp.csv'
df = pd.read_csv(csv_path, skiprows=1)

# Derive a numeric column (minutes) from the H:MM:SS strings for plotting.
training_time_strings = df['Training Time']
df['Training Time (minutes)'] = training_time_strings.apply(time_to_minutes)

# Show the raw and converted durations side by side (notebook cell output).
summary_columns = ['Model Name', 'Training Time', 'Training Time (minutes)']
df[summary_columns]

## Create Visualization

In [None]:
# Draw one bar per model, with bar height = training time in minutes and
# bar color taken from the "rocket" gradient according to that time.

# Setup plot style
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(12, 7))

# Normalize times to [0, 1] for color mapping
train_minutes = df['Training Time (minutes)']
min_time = train_minutes.min()
max_time = train_minutes.max()
time_range = max_time - min_time
if time_range > 0:
    norm_times = (train_minutes - min_time) / time_range
else:
    # A single row, or all models with identical times, makes the plain
    # ratio 0/0 = NaN. matplotlib colormaps render NaN with their "bad"
    # color rather than a gradient shade, so fall back to the midpoint of
    # the gradient for every bar.
    norm_times = pd.Series(0.5, index=df.index, dtype=float)

# Map to rocket colormap
rocket_cmap = sns.color_palette("rocket", as_cmap=True)
bar_colors = [rocket_cmap(val) for val in norm_times]

# Create bars
bars = ax.bar(
    df['Model Name'],
    df['Training Time (minutes)'],
    color=bar_colors,
    edgecolor='black',
    linewidth=1.2
)

# Styling
ax.set_xlabel('Model Name', fontsize=13, fontweight='bold')
ax.set_ylabel('Training Time (minutes)', fontsize=13, fontweight='bold')
ax.set_title(
    'Foundation Model Training Time Comparison\n(10,000 sequences on 2 A100 GPUs)',
    fontsize=15,
    fontweight='bold',
    pad=20
)
# Rotate the model names so long labels do not overlap.
plt.xticks(rotation=45, ha='right')

# Add time labels: the original H:MM:SS string, centered on top of each bar
for bar, time_str in zip(bars, df['Training Time']):
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        height,
        str(time_str),
        ha='center',
        va='bottom',
        fontsize=10,
        fontweight='bold'
    )

plt.tight_layout()
plt.show()

## Save Figure

In [None]:
# Save high-resolution figure
output_path = 'figures/time/training_time_comparison.png'
# savefig does not create missing directories, so make sure the target
# folder exists before writing the PNG.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✓ Plot saved as '{output_path}'")