# Figures for cs182/282 Final Project

jenniferyjlin@berkeley.edu

# Load library and data

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import os

# Every figure saved by this notebook is written below this directory,
# which is created on demand.
output_dir = '../output/figures/'
os.makedirs(output_dir, exist_ok=True)

# Load the combined, tab-separated results table.
df = pd.read_csv(
    '../output/Results - 20251211 NT+Cad for report 3.tsv',
    sep='\t',
)

# Sanity dump of the table: each heading is followed by its payload on the
# next line (sep='\n' reproduces two consecutive print calls).
print("Data shape:", df.shape)
print("\nColumn names:", df.columns.tolist(), sep='\n')
print("\nFirst few rows:", df.head(), sep='\n')
print("\nModel types:", df['Model'].value_counts(), sep='\n')
print("\nFine-tuning types:", df['Fine-tuning'].value_counts(), sep='\n')

# Global matplotlib text defaults shared by every figure below.
plt.rcParams.update({'font.size': 10, 'font.family': 'serif'})

Data shape: (23, 53)

Column names:
['Model', 'Pretraining', 'Input length', 'Classifier head', 'Classifier head -  hidden layers', 'Embedding strategy', 'Embedding  Strategy ', 'Fine-tuning', 'LoRA rank', 'Learning rate', 'Training Sample size', 'Batch size', 'Num steps', 'Grad accum', 'Best Validation AUC', 'Training time (hr)', 'Training memory (GB)', 'Trainable params', 'All params', 'Training_Step1000_AUC', 'Training_Step2000_AUC', 'Training_Step3000_AUC', 'Training_Step4000_AUC', 'Training_Step5000_AUC', 'Training_Step6000_AUC', 'Training_Step7000_AUC', 'Training_Step8000_AUC', 'Training_Step9000_AUC', 'Training_Step10000_AUC', 'Training_Step11000_AUC', 'Training_Step12000_AUC', 'Training_Step13000_AUC', 'Training_Step14000_AUC', 'Val_Step1000_AUC', 'Val_Step2000_AUC', 'Val_Step3000_AUC', 'Val_Step4000_AUC', 'Val_Step5000_AUC', 'Val_Step6000_AUC', 'Val_Step7000_AUC', 'Val_Step8000_AUC', 'Val_Step9000_AUC', 'Val_Step10000_AUC', 'Val_Step11000_AUC', 'Val_Step12000_AUC', 'Val_Step13

# Helper functions

In [2]:
# Marker glyph for each plotted series, keyed by model name.
# 'NT-2-6k' is a pseudo-model key used for NT-2 runs with a 6k input length.
MARKERS = dict([
    ('NT-2', 'o'),      # circle
    ('NT-1', '^'),      # triangle
    ('Caduceus', 'x'),  # x
    ('NT-2-6k', 's'),   # square
])

# The hidden-layer column may appear with either one or two spaces after the
# dash; accept both spellings.
def get_classifier_layers(row):
    """Return the classifier hidden-layer count stored in ``row``.

    The double-space column name is tried first, then the single-space one.
    The first of them holding a non-missing value wins; ``None`` is returned
    when neither does.
    """
    candidate_columns = (
        'Classifier head -  hidden layers',
        'Classifier head - hidden layers',
    )
    for column in candidate_columns:
        if pd.notna(row.get(column)):
            return row[column]
    return None

# Add standardized columns: 'Classifier_layers' holds, for each row, whatever
# get_classifier_layers() returns (axis=1 applies the function row by row).
df['Classifier_layers'] = df.apply(get_classifier_layers, axis=1)

# Pull one AUC-vs-step curve out of a results row.
def extract_curves(row, prefix='Val_Step', max_steps=14000):
    """Collect the ``{prefix}{step}_AUC`` values of ``row`` as a curve.

    Steps 1000, 2000, ... up to and including ``max_steps`` are probed. A step
    is kept only when its column exists in ``row.index``, its value is not
    missing, and the value converts to ``float``; anything else is skipped.

    Returns:
        ``(steps, aucs)`` as two NumPy arrays of equal length.
    """
    kept_steps, kept_aucs = [], []
    for step in range(1000, max_steps + 1, 1000):
        column = f'{prefix}{step}_AUC'
        if column not in row.index:
            continue
        raw = row[column]
        if not pd.notna(raw):
            continue
        try:
            value = float(raw)
        except (ValueError, TypeError):
            # Non-numeric cell: leave this step out of the curve.
            continue
        kept_steps.append(step)
        kept_aucs.append(value)
    return np.array(kept_steps), np.array(kept_aucs)

# Map the free-text embedding strategy onto a small set of labels.
def standardize_embedding(embed_str):
    """Normalize an embedding-strategy string.

    Missing values give ``'unknown'``. Otherwise the text is lower-cased and
    stripped, and a leading ``'full-'`` is dropped before keyword matching:

    * contains ``'downsample'``: ``'mean_pool'`` if it also contains
      ``'mean'``, else ``'variant_position'`` if it contains ``'variant'``,
      else ``'downsample'``;
    * otherwise ``'variant_position'`` if it contains ``'variant'``, else
      ``'mean_pool'`` if it contains ``'mean'``;
    * if nothing matches, the original ``embed_str`` is returned unchanged.
    """
    if pd.isna(embed_str):
        return 'unknown'

    key = str(embed_str).lower().strip()
    prefix = 'full-'
    if key.startswith(prefix):
        key = key[len(prefix):]

    mentions_mean = 'mean' in key
    mentions_variant = 'variant' in key

    if 'downsample' in key:
        # Within downsample variants, 'mean' takes priority over 'variant'.
        if mentions_mean:
            return 'mean_pool'
        return 'variant_position' if mentions_variant else 'downsample'

    # Outside downsample variants, 'variant' takes priority over 'mean'.
    if mentions_variant:
        return 'variant_position'
    if mentions_mean:
        return 'mean_pool'
    return embed_str

# Normalized embedding label for every row, derived from the raw
# 'Embedding strategy' column.
df['Embedding_std'] = df['Embedding strategy'].apply(standardize_embedding)

# Heading and counts on consecutive lines (sep='\n' mirrors two print calls).
print("\nStandardized embedding strategies:", df['Embedding_std'].value_counts(), sep='\n')

# Row selection helper: exact model name plus a fine-tuning pattern.
def filter_models(df, model_name, finetuning_status):
    """Return a copy of the rows for one model and fine-tuning status.

    A row is kept when ``Model`` equals ``model_name`` and ``Fine-tuning``
    contains ``finetuning_status`` (``Series.str.contains`` semantics, so the
    match is case-sensitive; missing values count as no match).
    """
    is_requested_model = df['Model'] == model_name
    has_status = df['Fine-tuning'].str.contains(finetuning_status, na=False)
    selected = df[is_requested_model & has_status]
    return selected.copy()

# Function to extract training/val curves
# NOTE(review): a function with this same name and signature is already
# defined earlier in this cell; this later definition is the one in effect
# once the cell has run. Consider keeping only one of them.
def extract_curves(row, prefix='Val_Step', max_steps=14000):
    """Collect the ``{prefix}{step}_AUC`` values of ``row`` as a curve.

    Steps 1000, 2000, ... up to and including ``max_steps`` are probed. A step
    is kept only when its column exists in ``row.index``, its value is not
    missing, and the value converts to ``float``; anything else is skipped.

    Returns:
        ``(steps, aucs)`` as two NumPy arrays of equal length.
    """
    curve = []
    for step in range(1000, max_steps + 1, 1000):
        column = f'{prefix}{step}_AUC'
        if column not in row.index:
            continue
        raw = row[column]
        if not pd.notna(raw):
            continue
        try:
            curve.append((step, float(raw)))
        except (ValueError, TypeError):
            # Non-numeric cell: leave this step out of the curve.
            continue
    steps = [step for step, _ in curve]
    aucs = [auc for _, auc in curve]
    return np.array(steps), np.array(aucs)


Standardized embedding strategies:
variant_position    17
mean_pool            6
Name: Embedding_std, dtype: int64


# Figure 2

In [3]:
rule = "=" * 80
print("\n" + rule)
print("FIGURE 2: Classifier Head and Embedding Strategy Comparison (NT-2, NT-1, Caduceus)")
print(rule)

# One frozen-backbone subset per model family.
frozen_nt2, frozen_nt1, frozen_cad = (
    filter_models(df, family, 'frozen') for family in ('NT-2', 'NT-1', 'Caduceus')
)

print("\nFrozen NT-2 models:", len(frozen_nt2))
print("Frozen NT-1 models:", len(frozen_nt1))
print("Frozen Caduceus models:", len(frozen_cad))

# Bucket the rows of a subset by their standardized embedding label.
def organize_by_embedding(df_subset, model_type):
    """Group the rows of ``df_subset`` by ``Embedding_std``.

    Returns a dict with the fixed keys ``'variant_position'``, ``'mean_pool'``,
    ``'downsample_mean'`` and ``'downsample_variant'``. Each value is a list of
    ``(index, classifier_head, classifier_layers, model_type, input_length)``
    tuples in row order. Rows with any other embedding label are dropped, and
    ``input_length`` falls back to 12000 when the row has no ``'Input length'``
    entry.
    """
    bucket_names = ('variant_position', 'mean_pool', 'downsample_mean', 'downsample_variant')
    buckets = {bucket: [] for bucket in bucket_names}

    for row_index, record in df_subset.iterrows():
        label = record['Embedding_std']
        entry = (
            row_index,
            record['Classifier head'],
            record['Classifier_layers'],
            model_type,
            record.get('Input length', 12000),
        )
        if label not in buckets:
            continue
        buckets[label].append(entry)

    return buckets

# Organize all frozen models
# Each result is a dict keyed by embedding bucket whose values are lists of
# (index, head, layers, model_type, input_length) tuples.
nt2_models = organize_by_embedding(frozen_nt2, 'NT-2')
nt1_models = organize_by_embedding(frozen_nt1, 'NT-1')
cad_models = organize_by_embedding(frozen_cad, 'Caduceus')

# Assign colors based on classifier head and layers
def get_color_for_config(head, layers, embed_type):
    """Pick the line color for one classifier-head / embedding configuration.

    Args:
        head: Classifier head name (compared case-insensitively); a missing
            value is treated as ``'unknown'``.
        layers: Number of classifier hidden layers. May arrive as an int, a
            float such as ``2.0`` (pandas stores a column containing missing
            values as float) or a numeric string; anything that cannot be
            read as a number is treated as "not 2".
        embed_type: Embedding bucket name such as ``'variant_position'`` or
            ``'mean_pool'``.

    Returns:
        A matplotlib color name or cycle code.
    """
    head_lower = str(head).lower() if pd.notna(head) else 'unknown'

    # Special handling for downsample embeddings
    if 'downsample' in embed_type:
        if head_lower == 'cnn':
            return 'purple'
        elif head_lower == 'transformer':
            return 'darkviolet'
        else:
            return 'mediumorchid'

    # Compare the layer count numerically: the previous str(layers) == '2'
    # test missed float inputs, because str(2.0) is '2.0'.
    try:
        has_two_layers = float(layers) == 2
    except (TypeError, ValueError):
        has_two_layers = False

    is_variant = embed_type == 'variant_position'

    # Standard colors for other embeddings
    if head_lower == 'mlp':
        if has_two_layers:
            return 'skyblue' if is_variant else 'lightcoral'
        else:
            return 'C0' if is_variant else 'C1'
    elif head_lower == 'transformer':
        return 'C2' if is_variant else 'C3'
    elif head_lower == 'cnn':
        return 'C4' if is_variant else 'C5'
    return 'C6'

# Create Figure 2: a 1x3 grid in which the first two cells hold the training
# and validation axes. No axes are created in the narrower third cell
# (width ratio 0.35) -- presumably it leaves room for the figure-level legends.
fig = plt.figure(figsize=(18, 8))
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 0.35], wspace=0.3)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])

# Store all plotted models for legend
# Entries are (last validation AUC, label, Line2D) tuples appended by plot_models().
plotted_models = []

# Draw the curves of one embedding bucket onto the two axes.
def plot_models(models_list, ax1, ax2, embed_type, linestyle='-', alpha=1.0):
    """Plot training (``ax1``) and validation (``ax2``) AUC curves.

    ``models_list`` holds ``(index, head, layers, model, input_length)`` tuples
    as produced by ``organize_by_embedding``. For every entry the row is read
    from the module-level ``df``, its curves up to step 5000 are extracted, and
    each non-empty curve is drawn. Every validation line drawn is also appended
    to the module-level ``plotted_models`` as
    ``(last validation AUC, label, line)``.
    """
    embed_names = {
        'variant_position': 'Var',
        'mean_pool': 'Mean',
        'downsample_mean': 'Down+Mean',
        'downsample_variant': 'Down+Var',
    }

    for idx, head, layers, model, input_length in models_list:
        record = df.loc[idx]
        train_steps, train_aucs = extract_curves(record, 'Training_Step', 5000)
        val_steps, val_aucs = extract_curves(record, 'Val_Step', 5000)

        color = get_color_for_config(head, layers, embed_type)

        # NT-2 runs with a 6k input get their own marker and a label suffix;
        # every other series (NT-1 included) is looked up under its model name.
        if model == 'NT-2' and input_length == 6000:
            marker_key, fallback_marker, input_label = 'NT-2-6k', 's', ' (6k)'
        else:
            marker_key, input_label = model, ''
            fallback_marker = '^' if model == 'NT-1' else 'o'
        marker = MARKERS.get(marker_key, fallback_marker)

        # Legend label pieces.
        head_str = 'UNK' if not pd.notna(head) else str(head).upper()
        layers_str = '?' if not pd.notna(layers) else str(int(layers))
        if embed_type in embed_names:
            embed_label = embed_names[embed_type]
        else:
            embed_label = embed_type[:4].capitalize()
        model_label = 'Cad' if model == 'Caduceus' else model
        label = f'{model_label}: {head_str}-{layers_str}L + {embed_label}{input_label}'

        # Last available validation AUC (0 when there is none); used later to
        # rank the legend entries.
        final_val_auc = val_aucs[-1] if len(val_aucs) > 0 else 0

        line_style = dict(
            marker=marker,
            color=color,
            linewidth=2.5,
            linestyle=linestyle,
            markersize=6 if model == 'Caduceus' else 7,
            alpha=alpha,
        )

        if len(train_steps) > 0:
            ax1.plot(train_steps, train_aucs, **line_style)
        if len(val_steps) > 0:
            line, = ax2.plot(val_steps, val_aucs, **line_style)
            plotted_models.append((final_val_auc, label, line))

# Draw order: variant-position curves (solid), then mean-pool curves (dashed).
# NT-1 takes part in the variant-position pass only.
for bucket, dash, groups in [
    ('variant_position', '-', [(nt2_models, 1.0), (nt1_models, 1.0), (cad_models, 0.8)]),
    ('mean_pool', '--', [(nt2_models, 1.0), (cad_models, 0.8)]),
]:
    for models_dict, alpha_val in groups:
        if models_dict[bucket]:
            plot_models(models_dict[bucket], ax1, ax2, bucket,
                        linestyle=dash, alpha=alpha_val)

# Downsample buckets (dotted), Caduceus only: mean first, then variant.
for models_dict, alpha_val in [(cad_models, 0.8)]:
    for bucket in ('downsample_mean', 'downsample_variant'):
        if models_dict[bucket]:
            plot_models(models_dict[bucket], ax1, ax2, bucket,
                        linestyle=':', alpha=alpha_val)

# Rank the legend entries by validation AUC, best first.
plotted_models.sort(key=lambda entry: entry[0], reverse=True)

# Both panels share font sizes, y-range and grid; only the texts differ.
for axis, x_text, y_text, heading in (
    (ax1, 'Training Step', 'Training AUC', 'Training Performance'),
    (ax2, 'Validation Step', 'Validation AUC', 'Validation Performance'),
):
    axis.set_xlabel(x_text, fontsize=20, fontweight='bold')
    axis.set_ylabel(y_text, fontsize=20, fontweight='bold')
    axis.set_title(heading, fontsize=20, fontweight='bold')
    axis.set_ylim([0.4, 1.0])
    axis.grid(True, alpha=0.3)

# Performance-tier annotations on the validation panel, placed at x=5300 at
# hand-picked y positions: (y, text, face color, edge color).
tier_notes = [
    (0.875, 'NT-2: Var', 'lightblue', 'navy'),
    (0.78, 'NT-2: Mean', 'lightcoral', 'darkred'),
    (0.735, 'NT-1: Var', 'gray', 'black'),
    (0.6, 'Cad: Var', 'lightblue', 'navy'),
    (0.45, 'Cad: Mean', 'lightcoral', 'darkred'),
]
for y_pos, note, face, edge in tier_notes:
    ax2.text(5300, y_pos, note,
             fontsize=16,
             bbox=dict(boxstyle='round,pad=0.4', facecolor=face,
                       alpha=0.7, edgecolor=edge),
             verticalalignment='center')

# Legend proxies: one black marker per model series ...
marker_legend_elements = [
    Line2D([0], [0], marker=glyph, color='black', linestyle='None', markersize=size, label=text)
    for glyph, size, text in (
        ('o', 8, 'NT-2 (12k)'),
        ('s', 8, 'NT-2 (6k)'),
        ('^', 8, 'NT-1 (6k)'),
        ('x', 6, 'Caduceus (30k)'),
    )
]

# ... and one black line per embedding strategy.
linestyle_legend_elements = [
    Line2D([0], [0], color='black', linestyle=dash, linewidth=2, label=text)
    for dash, text in (('-', 'Var pos'), ('--', 'Mean pool'))
]

# Options shared by the two titled legends.
legend_box = dict(loc='upper left', fontsize=16, title_fontsize=16,
                  frameon=True, fancybox=True, shadow=True)

marker_legend = fig.legend(handles=marker_legend_elements,
                           bbox_to_anchor=(0.82, 0.9), title='Model', **legend_box)

linestyle_legend = fig.legend(handles=linestyle_legend_elements,
                              bbox_to_anchor=(0.97, 0.9), title='Embedding Strategy',
                              **legend_box)

# Per-model legend, in the ranked order established above.
ranked_handles = [line for _, _, line in plotted_models]
ranked_labels = [f'{text} ({score:.3f})' for score, text, _ in plotted_models]
fig.legend(ranked_handles, ranked_labels,
           loc='center left', bbox_to_anchor=(0.82, 0.32), fontsize=16, framealpha=0.9)

plt.suptitle('Frozen Model Comparison: NT-2, NT-1, and Caduceus',
             fontsize=24, fontweight='bold', y=0.98)

save_options = dict(dpi=300, bbox_inches='tight')
plt.savefig(output_dir + 'fig2_classifier_embedding_comparison.pdf', **save_options)
plt.savefig(output_dir + 'fig2_classifier_embedding_comparison.png', **save_options)
print("\nSaved fig2_classifier_embedding_comparison.pdf/png")
print(f"  ✓ Plotted {len(plotted_models)} models total")
print(f"  ✓ NT-2: {len(nt2_models['variant_position']) + len(nt2_models['mean_pool'])} models")
print(f"  ✓ NT-1: {len(nt1_models['variant_position'])} models")
print(f"  ✓ Caduceus: {sum(len(v) for v in cad_models.values())} models")
plt.close()


FIGURE 2: Classifier Head and Embedding Strategy Comparison (NT-2, NT-1, Caduceus)

Frozen NT-2 models: 8
Frozen NT-1 models: 1
Frozen Caduceus models: 7

Saved fig2_classifier_embedding_comparison.pdf/png
  ✓ Plotted 16 models total
  ✓ NT-2: 8 models
  ✓ NT-1: 1 models
  ✓ Caduceus: 7 models


# Figure 3

In [4]:
rule = "=" * 80
print("\n" + rule)
print("FIGURE 3: LoRA Rank Comparison")
print(rule)

# Figure 3: training (left) and validation (right) AUC per LoRA rank, on a
# shared y-range.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Only runs with Num steps == 4000 take part in this comparison.
df_filtered = df[df['Num steps'] == 4000]

# (LoRA rank, legend label, line color)
lora_models = [
    (8, 'LoRA rank=8', 'C0'),
    (16, 'LoRA rank=16', 'C1'),
    (32, 'LoRA rank=32', 'C2'),
]

for rank, label, color in lora_models:
    row_data = df_filtered[df_filtered['LoRA rank'] == rank]
    if len(row_data) == 0:
        # No run with this rank: nothing to draw.
        continue

    row = row_data.iloc[0]  # first matching run
    steps_train, aucs_train = extract_curves(row, 'Training_Step', 4000)
    steps_val, aucs_val = extract_curves(row, 'Val_Step', 4000)

    curve_style = dict(marker='o', label=label, color=color, linewidth=2.5)
    if len(steps_train) > 0:
        ax1.plot(steps_train, aucs_train, **curve_style)
    if len(steps_val) > 0:
        ax2.plot(steps_val, aucs_val, **curve_style)

# Same y-range on both panels.
y_min = 0.81
y_max = 0.99
for axis in (ax1, ax2):
    axis.set_ylim([y_min, y_max])

for axis, y_text, heading in (
    (ax1, 'Training AUC', 'Training Performance'),
    (ax2, 'Validation AUC', 'Validation Performance'),
):
    axis.set_xlabel('Training Steps', fontsize=11)
    axis.set_ylabel(y_text, fontsize=11)
    axis.set_title(heading, fontsize=12, fontweight='bold')
    axis.legend(loc='upper left', fontsize=10)
    axis.grid(True, alpha=0.3)

plt.suptitle('LoRA Rank Comparison (5k samples, CNN + Variant pos.)', fontsize=13, fontweight='bold', y=0.995)
plt.tight_layout()

save_options = dict(dpi=300, bbox_inches='tight')
plt.savefig(output_dir + 'fig3_lora_rank_comparison.pdf', **save_options)
plt.savefig(output_dir + 'fig3_lora_rank_comparison.png', **save_options)
print("Saved fig3_lora_rank_comparison.pdf/png")
plt.close()


FIGURE 3: LoRA Rank Comparison
Saved fig3_lora_rank_comparison.pdf/png


# Figure 4

In [5]:
print("\n" + "="*80)
print("FIGURE 4: Learning Rate Comparison")
print("="*80)

# Figure 4: Learning rate comparison with training (left) and validation (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Filter for rows with LoRA rank == 32 and Num steps == 40000
df_filtered = df[(df['LoRA rank'] == 32) & (df['Num steps'] == 40000)]

# Learning rates are floats parsed from the text file, so a parsed value is not
# guaranteed to be bit-identical to the literal written below. Rows are
# therefore matched with a relative tolerance instead of ``==``; values that
# are not numeric become NaN and never match.
lr_column = pd.to_numeric(df_filtered['Learning rate'], errors='coerce').to_numpy(dtype=float)

# (learning rate, legend text, line color)
lr_models = [
    (0.00003, '3e-5', 'C0'),
    (0.00005, '5e-5', 'C1'),
    (0.00001, '1e-5', 'C2'),
]

for lr_value, lr_label, color in lr_models:
    # Select rows with this learning rate. rtol=1e-6 is far tighter than the
    # spacing between the listed rates, so the rates cannot be confused.
    row_data = df_filtered[np.isclose(lr_column, lr_value, rtol=1e-6, atol=0.0)]

    if len(row_data) > 0:
        row = row_data.iloc[0]  # Get the first matching row

        # Training curves
        steps_train, aucs_train = extract_curves(row, 'Training_Step', 14000)
        # Validation curves
        steps_val, aucs_val = extract_curves(row, 'Val_Step', 14000)

        if len(steps_train) > 0:
            ax1.plot(steps_train, aucs_train, marker='o', label=f'LR={lr_label}', color=color, linewidth=2.5)
        if len(steps_val) > 0:
            ax2.plot(steps_val, aucs_val, marker='o', label=f'LR={lr_label}', color=color, linewidth=2.5)

# Set same y-range
y_min = 0.7
y_max = 0.97
ax1.set_ylim([y_min, y_max])
ax2.set_ylim([y_min, y_max])

ax1.set_xlabel('Training Steps', fontsize=11)
ax1.set_ylabel('Training AUC', fontsize=11)
ax1.set_title('Training Performance', fontsize=12, fontweight='bold')
ax1.legend(loc='lower right', fontsize=10)
ax1.grid(True, alpha=0.3)

ax2.set_xlabel('Training Steps', fontsize=11)
ax2.set_ylabel('Validation AUC', fontsize=11)
ax2.set_title('Validation Performance', fontsize=12, fontweight='bold')
ax2.legend(loc='lower right', fontsize=10)
ax2.grid(True, alpha=0.3)

plt.suptitle('Learning Rate Optimization (44k samples, LoRA rank=32)', fontsize=13, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(output_dir + 'fig4_learning_rate_comparison.pdf', dpi=300, bbox_inches='tight')
plt.savefig(output_dir + 'fig4_learning_rate_comparison.png', dpi=300, bbox_inches='tight')
print("Saved fig4_learning_rate_comparison.pdf/png")
plt.close()


FIGURE 4: Learning Rate Comparison
Saved fig4_learning_rate_comparison.pdf/png


# Figure 5

In [6]:
rule = "=" * 80
print("\n" + rule)
print("FIGURE 5: Best LoRA vs Full Fine-tuning (Overfitting Analysis)")
print(rule)

# Figure 5: frozen baseline vs best LoRA vs fully unfrozen model, with
# training AUC on the left and validation AUC on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# --- Frozen baseline: highest 'Best Validation AUC' among rows with Num steps == 5000.
df_frozen = df[df['Num steps'] == 5000]
frozen_idx = df_frozen['Best Validation AUC'].idxmax()
frozen_row = df.loc[frozen_idx]
print(f"Best frozen model: index {frozen_idx}, Val AUC = {frozen_row['Best Validation AUC']:.4f}")

frozen_style = dict(marker='o', label='Frozen baseline', color='C0', linewidth=2.5, linestyle='-')
steps_train, aucs_train = extract_curves(frozen_row, 'Training_Step', 5000)
steps_val, aucs_val = extract_curves(frozen_row, 'Val_Step', 5000)
if len(steps_train) > 0:
    ax1.plot(steps_train, aucs_train, **frozen_style)
if len(steps_val) > 0:
    ax2.plot(steps_val, aucs_val, **frozen_style)

# --- Best LoRA run: highest 'Best Validation AUC' among LoRA rank == 32, Num steps == 40000.
# NOTE(review): the legend text hard-codes LR=3e-5 regardless of which row is
# selected -- confirm that the selected run really uses that learning rate.
df_lora = df[(df['LoRA rank'] == 32) & (df['Num steps'] == 40000)]
lora_idx = df_lora['Best Validation AUC'].idxmax()
lora_row = df.loc[lora_idx]
print(f"Best LoRA model: index {lora_idx}, Val AUC = {lora_row['Best Validation AUC']:.4f}")

lora_style = dict(marker='s', label='LoRA (rank=32, LR=3e-5)', color='C2', linewidth=2.5, linestyle='-')
steps_train, aucs_train = extract_curves(lora_row, 'Training_Step', 14000)
steps_val, aucs_val = extract_curves(lora_row, 'Val_Step', 14000)
if len(steps_train) > 0:
    ax1.plot(steps_train, aucs_train, **lora_style)
if len(steps_val) > 0:
    ax2.plot(steps_val, aucs_val, **lora_style)

# --- Full fine-tuning: first row whose 'Fine-tuning' is exactly 'unfreeze all'.
df_unfrozen = df[df['Fine-tuning'] == 'unfreeze all']
unfreeze_idx = df_unfrozen.index[0]
unfreeze_row = df.loc[unfreeze_idx]
print(f"Unfrozen model: index {unfreeze_idx}, Val AUC = {unfreeze_row['Best Validation AUC']:.4f}")

unfrozen_style = dict(marker='^', label='Full fine-tuning (unfrozen)', color='C3',
                      linewidth=2.5, linestyle='--', alpha=0.8)
steps_train, aucs_train = extract_curves(unfreeze_row, 'Training_Step', 10000)
steps_val, aucs_val = extract_curves(unfreeze_row, 'Val_Step', 10000)
if len(steps_train) > 0:
    ax1.plot(steps_train, aucs_train, **unfrozen_style)
if len(steps_val) > 0:
    ax2.plot(steps_val, aucs_val, **unfrozen_style)

# Same y-range on both panels.
y_min = 0.5
y_max = 1
for axis in (ax1, ax2):
    axis.set_ylim([y_min, y_max])

for axis, y_text, heading in (
    (ax1, 'Training AUC', 'Training Performance'),
    (ax2, 'Validation AUC', 'Validation Performance'),
):
    axis.set_xlabel('Training Steps', fontsize=11)
    axis.set_ylabel(y_text, fontsize=11)
    axis.set_title(heading, fontsize=12, fontweight='bold')
    axis.legend(fontsize=10, loc='lower right')
    axis.grid(True, alpha=0.3)

plt.suptitle('Generalization: LoRA vs Full Fine-tuning', fontsize=13, fontweight='bold', y=0.995)
plt.tight_layout()

save_options = dict(dpi=300, bbox_inches='tight')
plt.savefig(output_dir + 'fig5_lora_vs_full_finetuning.pdf', **save_options)
plt.savefig(output_dir + 'fig5_lora_vs_full_finetuning.png', **save_options)
print("Saved fig5_lora_vs_full_finetuning.pdf/png")
plt.close()


FIGURE 5: Best LoRA vs Full Fine-tuning (Overfitting Analysis)
Best frozen model: index 6, Val AUC = 0.8710
Best LoRA model: index 12, Val AUC = 0.8861
Unfrozen model: index 7, Val AUC = 0.7880
Saved fig5_lora_vs_full_finetuning.pdf/png


# Figure 6

In [7]:
rule = "=" * 80
print("\n" + rule)
print("FIGURE 6: Model Performance Summary")
print(rule)

# Summary rows, one tuple per bar, filled in by add_model_to_summary().
model_data = []

def add_model_to_summary(row, short_name):
    """Format one results row and append it to the module-level ``model_data``.

    Args:
        row: A results row (for example a ``pandas.Series``) that supports
            ``row[...]`` and ``row.get(...)``.
        short_name: Display name used for this model's bar.

    The appended tuple is
    ``(short_name, best_auc, model, classifier, layers, embedding,
    finetuning, rank, lr, samples, input_len)``. Every field after
    ``best_auc`` is a display value, with ``'-'``, ``'?'`` or ``'UNK'``
    standing in for missing data.
    """
    best_auc = row['Best Validation AUC']
    model_type = row['Model']
    classifier = str(row['Classifier head']).upper() if pd.notna(row['Classifier head']) else 'UNK'
    layers = str(int(row['Classifier_layers'])) if pd.notna(row['Classifier_layers']) else '?'

    # Determine embedding type
    embed_std = row['Embedding_std']
    if embed_std == 'variant_position':
        embedding = 'Var'
    elif embed_std == 'mean_pool':
        embedding = 'Mean'
    elif 'downsample' in str(embed_std):
        # str() keeps the membership test safe if the label is not a string.
        embedding = 'Down'
    else:
        embedding = '?'

    # Fine-tuning status (case-insensitive keyword match; unknown text is kept as is)
    finetuning_raw = row['Fine-tuning']
    finetuning_lower = str(finetuning_raw).lower()
    if 'lora' in finetuning_lower:
        finetuning = 'LoRA'
    elif 'frozen' in finetuning_lower:
        finetuning = 'Frozen'
    elif 'unfreeze' in finetuning_lower:
        finetuning = 'Full'
    else:
        finetuning = str(finetuning_raw)

    # LoRA rank
    rank = str(int(row['LoRA rank'])) if pd.notna(row['LoRA rank']) else '-'

    # Learning rate, e.g. 3e-05 -> '3e-5'
    lr_val = row['Learning rate']
    if pd.notna(lr_val):
        lr = f"{lr_val:.0e}".replace('e-0', 'e-')
    else:
        lr = '-'

    # Sample size: keep values already written like '5k', otherwise convert to thousands
    samples_raw = row['Training Sample size']
    if pd.notna(samples_raw):
        if 'k' in str(samples_raw).lower():
            samples = str(samples_raw)
        else:
            samples = f"{int(samples_raw)/1000:.0f}k"
    else:
        samples = '-'

    # Input length in thousands. The lookup defaults to None rather than '-':
    # a '-' default counts as "not missing" and int('-') would raise ValueError
    # for rows that have no 'Input length' entry.
    input_length = row.get('Input length')
    if pd.notna(input_length):
        input_len = f"{int(input_length)/1000:.0f}k"
    else:
        input_len = '-'

    # Shorten model name for Caduceus
    if model_type == 'Caduceus':
        model_display = 'Cad'
    else:
        model_display = model_type

    model_data.append((
        short_name,
        best_auc,
        model_display,
        classifier,
        layers,
        embedding,
        finetuning,
        rank,
        lr,
        samples,
        input_len
    ))

# Track stage boundaries
# Each entry is (stage name, x position). The position is len(model_data) - 0.5
# at the moment the stage ends, i.e. halfway between the stage's last bar and
# the next stage's first bar when bars sit at integer positions.
stage_boundaries = []

# Stage 1: Mean pool - Caduceus (sorted by AUC ascending - lowest first)
mean_pool_cad = frozen_cad[frozen_cad['Embedding_std'] == 'mean_pool'].copy()
mean_pool_cad = mean_pool_cad.sort_values('Best Validation AUC')
for idx, row in mean_pool_cad.iterrows():
    head = str(row['Classifier head']).upper()
    # A missing layer count is shown as 0 in the bar name.
    layers = int(row['Classifier_layers']) if pd.notna(row['Classifier_layers']) else 0
    add_model_to_summary(row, f'Cad {head}-{layers}L Mean')

# Stage 1: Mean pool - NT-2 (12k only, sorted by AUC ascending)
mean_pool_nt2 = frozen_nt2[frozen_nt2['Embedding_std'] == 'mean_pool'].copy()
mean_pool_nt2_12k = mean_pool_nt2[mean_pool_nt2['Input length'] == 12000]
mean_pool_nt2_12k = mean_pool_nt2_12k.sort_values('Best Validation AUC')
for idx, row in mean_pool_nt2_12k.iterrows():
    head = str(row['Classifier head']).upper()
    layers = int(row['Classifier_layers']) if pd.notna(row['Classifier_layers']) else 0
    add_model_to_summary(row, f'{head}-{layers}L Mean')

# Record boundary after mean pool
stage_boundaries.append(('mean_pool', len(model_data) - 0.5))

# Stage 1: Variant position - Caduceus (sorted by AUC ascending)
var_pos_cad = frozen_cad[frozen_cad['Embedding_std'] == 'variant_position'].copy()
var_pos_cad = var_pos_cad.sort_values('Best Validation AUC')
for idx, row in var_pos_cad.iterrows():
    head = str(row['Classifier head']).upper()
    layers = int(row['Classifier_layers']) if pd.notna(row['Classifier_layers']) else 0
    add_model_to_summary(row, f'Cad {head}-{layers}L Var')

# Stage 1: Variant position - NT-2 (12k only, sorted by AUC ascending)
var_pos_nt2 = frozen_nt2[frozen_nt2['Embedding_std'] == 'variant_position'].copy()
var_pos_nt2_12k = var_pos_nt2[var_pos_nt2['Input length'] == 12000]
var_pos_nt2_12k = var_pos_nt2_12k.sort_values('Best Validation AUC')
for idx, row in var_pos_nt2_12k.iterrows():
    head = str(row['Classifier head']).upper()
    layers = int(row['Classifier_layers']) if pd.notna(row['Classifier_layers']) else 0
    add_model_to_summary(row, f'{head}-{layers}L Var')

# Record boundary after variant position
stage_boundaries.append(('var_pos', len(model_data) - 0.5))

# Stage 1: 6k input models (NT-2 with 6k and NT-1)
# NT-2 with 6k input
# Unlike the groups above, these rows are not sorted by AUC; they keep the
# order they have in var_pos_nt2.
var_pos_nt2_6k = var_pos_nt2[var_pos_nt2['Input length'] == 6000]
for idx, row in var_pos_nt2_6k.iterrows():
    head = str(row['Classifier head']).upper()
    layers = int(row['Classifier_layers']) if pd.notna(row['Classifier_layers']) else 0
    add_model_to_summary(row, f'{head}-{layers}L (6k)')

# NT-1
# Every frozen NT-1 row gets the same bar name.
for idx, row in frozen_nt1.iterrows():
    add_model_to_summary(row, 'NT-1')

# Record boundary after 6k models
stage_boundaries.append(('6k', len(model_data) - 0.5))

# Stage 2: LoRA ranks - get models with different ranks
# Any row whose 'Fine-tuning' text contains 'lora' (case-insensitive).
lora_models = df[df['Fine-tuning'].str.contains('lora', case=False, na=False)].copy()
# Filter for the 5k training samples
lora_5k = lora_models[lora_models['Training Sample size'].astype(str).str.contains('5k', na=False)]
lora_5k = lora_5k.sort_values('LoRA rank')
for idx, row in lora_5k.iterrows():
    rank = int(row['LoRA rank'])
    add_model_to_summary(row, f'LoRA r={rank}')

# Record boundary after LoRA
stage_boundaries.append(('lora', len(model_data) - 0.5))

# Stage 3: Learning rates - get models with different LRs (44k samples)
lora_44k = lora_models[lora_models['Training Sample size'].astype(str).str.contains('44k', na=False)]
lora_44k = lora_44k.sort_values('Learning rate', ascending=True)  # smallest learning rate first
for idx, row in lora_44k.iterrows():
    lr_val = row['Learning rate']
    # e.g. 3e-05 -> '3e-5'
    lr_str = f"{lr_val:.0e}".replace('e-0', 'e-')
    add_model_to_summary(row, f'LR={lr_str}')

# Record boundary after learning rates
stage_boundaries.append(('lr', len(model_data) - 0.5))

# Stage 4: Full fine-tuning
# Any row whose 'Fine-tuning' text contains 'unfreeze' (case-insensitive).
full_ft = df[df['Fine-tuning'].str.contains('unfreeze', case=False, na=False)]
for idx, row in full_ft.iterrows():
    add_model_to_summary(row, 'Full FT')

print(f"\nTotal models in summary: {len(model_data)}")

# Create Figure 6: two stacked grid cells -- the bar chart on top and a
# shorter cell (height ratio 0.4) below whose axes are switched off.
fig = plt.figure(figsize=(22, 11))
gs = fig.add_gridspec(2, 1, height_ratios=[1, 0.4], hspace=0.3)
ax = fig.add_subplot(gs[0])
ax_table = fig.add_subplot(gs[1])
ax_table.axis('off')

# Define colors by stage and group: one color per entry of model_data, in order.
# (This initialisation used to be pasted three times in a row; once is enough.)
colors = []
lora_count = 0
lr_count = 0

# Gradients cycled through within a stage (lighter to darker).
lr_colors = ['lightsalmon', 'salmon', 'tomato']              # learning-rate sweep
lora_colors = ['lightgreen', 'mediumseagreen', 'seagreen']   # LoRA rank sweep

for name, _, model_type, _, _, embedding, finetuning, _, _, _, _ in model_data:
    if 'LR=' in name:
        # Learning rate models in warm colors
        colors.append(lr_colors[lr_count % len(lr_colors)])
        lr_count += 1
    elif finetuning == 'LoRA':
        # LoRA models in green gradient
        colors.append(lora_colors[lora_count % len(lora_colors)])
        lora_count += 1
    elif finetuning == 'Full':
        # Full fine-tuning in red
        colors.append('indianred')
    elif '6k' in name or model_type == 'NT-1':
        # 6k input models in purple/silver. This test has to come before the
        # embedding tests: these rows have embedding 'Var', so with the old
        # ordering they were always colored sky blue and this branch never ran.
        if model_type == 'NT-1':
            colors.append('silver')
        else:
            colors.append('mediumpurple')
    elif embedding == 'Mean':
        # All mean pool in light gray
        colors.append('lightgray')
    elif embedding == 'Var':
        # All variant position in sky blue
        colors.append('skyblue')
    else:
        # Fallback
        colors.append('gray')
        

# Add hatching patterns
hatches = []
for i, (name, _, model_type, _, _, _, _, _, _, _, _) in enumerate(model_data):
    if model_type == 'Cad':
        hatches.append('xxx')
    elif model_type == 'NT-1':
        hatches.append('//')
    else:
        hatches.append('')

# One bar per model, in model_data order
positions = [idx for idx, _ in enumerate(model_data)]
names = [entry[0] for entry in model_data]
aucs = [entry[1] for entry in model_data]

bars = ax.bar(positions, aucs, color=colors, alpha=0.85,
              edgecolor='black', linewidth=1.5)

# Apply the hatch and write the AUC value just above each bar
for bar, hatch, auc in zip(bars, hatches, aucs):
    bar.set_hatch(hatch)
    label_x = bar.get_x() + bar.get_width()/2.
    label_y = bar.get_height() + 0.003  # small gap so the text clears the bar top
    ax.text(label_x, label_y, f'{auc:.3f}',
            ha='center', va='bottom', fontsize=14, fontweight='bold')

# Dashed vertical separator at every stage boundary
for _, divider_x in stage_boundaries:
    ax.axvline(x=divider_x, linestyle='--', color='dimgray', linewidth=2.5, alpha=0.6)

# Horizontal centre of each stage, used to position the stage labels.
# Edges run: 0, the first five boundary positions, then len(model_data) - 0.5.
stage_edges = [0] + [stage_boundaries[k][1] for k in range(5)] + [len(model_data) - 0.5]

# Consecutive edge pairs give, in order: mean pool, variant position, 6k input,
# LoRA, learning rate, full fine-tuning.
(mean_pool_center, var_pos_center, input_6k_center,
 lora_center, lr_center, full_ft_center) = (
    (left + right) / 2 for left, right in zip(stage_edges, stage_edges[1:])
)

# Stage captions drawn at y=0.93 (data coordinates), one boxed label per stage.
# Each entry: (x position, caption, box style, box fill, box edge colour).
stage_labels = [
    (mean_pool_center,     'Stage 1:\nMean Pool',     'round,pad=0.3', 'lightgray',   'gray'),
    (var_pos_center,       'Stage 1:\nVariant Pos',   'round,pad=0.4', 'lightblue',   'navy'),
    (input_6k_center,      'Stage 1:\n6k Input',      'round,pad=0.3', 'lavender',    'purple'),
    (lora_center,          'Stage 2:\nLoRA Rank',     'round,pad=0.4', 'lightgreen',  'darkgreen'),
    (lr_center,            'Stage 3:\nLearning Rate', 'round,pad=0.4', 'lightsalmon', 'darkred'),
    # The last caption is shifted 0.5 to the right of its stage centre
    (full_ft_center + 0.5, 'Stage 4:\nFull FT',       'round,pad=0.3', 'mistyrose',   'darkred'),
]

for label_x, caption, box_style, box_fill, box_edge in stage_labels:
    ax.text(label_x, 0.93, caption, ha='center', fontsize=14, fontweight='bold',
            bbox=dict(boxstyle=box_style, facecolor=box_fill, alpha=0.5, edgecolor=box_edge))

# Emphasise the bar with the highest AUC (first one if several tie): thick red outline, drawn on top
top_auc = max(aucs)
best_idx = aucs.index(top_auc)
best_bar = bars[best_idx]
best_bar.set(linewidth=4, edgecolor='red', zorder=10)

# Axis cosmetics
ax.set_title('Model Performance Summary: Progressive Optimization (NT-2, NT-1, Caduceus)',
             fontsize=22, fontweight='bold', pad=20)
ax.set_ylabel('Best Validation AUC', fontsize=20, fontweight='bold')
ax.set_xticks(positions)
ax.set_xticklabels(names, fontsize=10, rotation=45, ha='right', fontweight='bold')
ax.set_ylim(0.40, 0.99)
ax.grid(True, axis='y', alpha=0.3, linewidth=0.5)

# Annotation table under the bar chart: one column per model, one row per attribute.
row_labels = ['Model Type', 'Classifier', 'Layers', 'Embedding', 'Fine-tuning',
              'LoRA Rank', 'Learning Rate', 'Samples', 'Input Len']

# Each model_data entry is (name, auc, <the nine attributes listed in row_labels>);
# only the nine attributes go into the table.
table_data = [list(entry[2:]) for entry in model_data]

# Transpose so that attributes become rows and models become columns
table_data_transposed = list(map(list, zip(*table_data)))

# Create table; column headers are the 1-based model numbers
table = ax_table.table(cellText=table_data_transposed,
                       rowLabels=row_labels,
                       colLabels=[f'{i+1}' for i in range(len(model_data))],
                       cellLoc='center',
                       rowLoc='center',
                       loc='center',
                       bbox=[0, 0, 1, 1])

table.auto_set_font_size(False)
table.set_fontsize(8)
table.scale(1, 2.5)

# Thin black borders on every cell. This must run BEFORE the best-model highlight below:
# applied afterwards, it would reset the highlighted cell back to a thin black edge.
for cell in table.get_celld().values():
    cell.set_edgecolor('black')
    cell.set_linewidth(0.5)

# Colour each column header (row 0) to match its bar, and outline the best model in red
for col_idx in range(len(model_data)):
    header_cell = table[(0, col_idx)]
    header_cell.set_facecolor(colors[col_idx])
    header_cell.set_alpha(0.7)
    header_cell.set_text_props(weight='bold', fontsize=9)

    if col_idx == best_idx:
        header_cell.set_edgecolor('red')
        header_cell.set_linewidth(3)

# Style row labels (column index -1; data rows start at 1 because row 0 is the header)
for row_idx in range(1, len(row_labels) + 1):
    label_cell = table[(row_idx, -1)]
    label_cell.set_facecolor('lightgray')
    label_cell.set_alpha(0.5)
    label_cell.set_text_props(weight='bold', fontsize=9)

# Write the figure in both vector and raster form
for ext in ('pdf', 'png'):
    fig.savefig(f'{output_dir}fig6_model_performance_summary.{ext}', dpi=300, bbox_inches='tight')

# Console summary of what was plotted
summary_lines = [
    "\nSaved fig6_model_performance_summary.pdf/png",
    f"  ✓ {len(model_data)} models in progressive optimization pipeline",
    f"  ✓ Best model: {names[best_idx]} (AUC = {aucs[best_idx]:.3f})",
    "  ✓ Stage 1: Mean pool (gray), Var pos (blue), 6k input (purple/silver)",
    "  ✓ Stage 2: LoRA ranks (green gradient)",
    "  ✓ Stage 3: Learning rates (warm gradient)",
    "  ✓ Stage 4: Full FT (red)",
]
for line in summary_lines:
    print(line)

# Release the figure now that it has been saved
plt.close(fig)

banner = "=" * 80
print("\n" + banner)
print("Processing complete!")
print(banner)


FIGURE 6: Model Performance Summary

Total models in summary: 23

Saved fig6_model_performance_summary.pdf/png
  ✓ 23 models in progressive optimization pipeline
  ✓ Best model: LR=3e-5 (AUC = 0.886)
  ✓ Stage 1: Mean pool (gray), Var pos (blue), 6k input (purple/silver)
  ✓ Stage 2: LoRA ranks (green gradient)
  ✓ Stage 3: Learning rates (warm gradient)
  ✓ Stage 4: Full FT (red)

Processing complete!


In [8]:
# Tabular export of the Figure 6 inputs: one row per model, one column per model_data field
summary_columns = [
    'Model Name', 'Best Val AUC', 'Model Type', 'Classifier', 'Layers', 'Embedding',
    'Fine-tuning', 'LoRA Rank', 'Learning Rate', 'Samples', 'Input Length',
]
table_df = pd.DataFrame(model_data, columns=summary_columns)

print("\nModel Summary Table:")
print(table_df)

# Save the model-per-row layout
summary_csv = f"{output_dir}fig6_model_summary_table.csv"
table_df.to_csv(summary_csv, index=False)
print(f"\nSaved table to {summary_csv}")

# Attribute-per-row layout: model names become the column headers
table_df_transposed = table_df.set_index('Model Name').transpose()
print("\nTransposed Table (as shown in plot):")
print(table_df_transposed)

# Save the transposed layout as well (index kept, since it holds the attribute names)
table_df_transposed.to_csv(f"{output_dir}fig6_model_summary_table_transposed.csv")


Model Summary Table:
                 Model Name  Best Val AUC Model Type   Classifier Layers  \
0           Cad MLP-2L Mean      0.459180        Cad          MLP      2   
1   Cad TRANSFORMER-2L Mean      0.560211        Cad  TRANSFORMER      2   
2           Cad CNN-2L Mean      0.670923        Cad          CNN      2   
3               MLP-2L Mean      0.751000       NT-2          MLP      2   
4       TRANSFORMER-2L Mean      0.775000       NT-2  TRANSFORMER      2   
5               CNN-2L Mean      0.778000       NT-2          CNN      2   
6            Cad MLP-2L Var      0.596876        Cad          MLP      2   
7            Cad MLP-3L Var      0.599565        Cad          MLP      3   
8    Cad TRANSFORMER-2L Var      0.700477        Cad  TRANSFORMER      2   
9            Cad CNN-2L Var      0.753605        Cad          CNN      2   
10               MLP-2L Var      0.869000       NT-2          MLP      2   
11               MLP-3L Var      0.869000       NT-2          MLP 