# Figure 3: Single Sweep Perturbation Experiments

This notebook analyzes the effect of applying a single perturbation at each epoch during training. For each epoch from 1 to 98, a separate training run received one perturbation at that epoch, and its deviation from the baseline at that same epoch was measured.

## Import Libraries

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import numpy as np

## Load Baseline Data

In [None]:
# Read the seed-1 baseline CLIP results from disk.
baseline_csv = Path('../../Data/clip_results/baseline_clip_results_seed1.csv')
baseline_seed1_df = pd.read_csv(baseline_csv)

# Early-stopping cut: find the row label with the lowest test loss and keep
# every row up to and including it (label slicing with .loc is inclusive).
baseline_min_idx = baseline_seed1_df['test_loss'].idxmin()
baseline_seed1_df = baseline_seed1_df.loc[:baseline_min_idx, :].copy()

print(f"Loaded baseline data: {len(baseline_seed1_df)} epochs")

## Load Single Sweep Perturbation Data

In [None]:
# Load the per-run results of the single sweep perturbation experiments.
# Expected layout: <sweep_root>/training_run<N>/training_res_run<N>.csv,
# where <N> is the run number. Each CSV is stored under the key
# "training_run<N>".
# NOTE(review): this dict is named "seed42" while the baseline it is compared
# against is loaded as seed 1 -- confirm the naming/seed pairing is intended.
sweep_root = Path('../../Data/clip_results/single_sweep_experiments')
run_dir_prefix = "training_run"

# Keep only directories whose suffix is a plain decimal number. The glob also
# matches names such as "training_run5_backup", whose suffix would make the
# int() parsing of run numbers (here and in later cells) raise.
numbered_run_dirs = []
for run_dir in sweep_root.glob(f"{run_dir_prefix}*"):
    run_num = run_dir.name[len(run_dir_prefix):]
    if run_dir.is_dir() and run_num.isdecimal():
        numbered_run_dirs.append((int(run_num), run_num, run_dir))

all_data_seed42 = {}
missing_csv_runs = []  # run directories that contain no results CSV

# Sort numerically so the dict is filled in run order (1, 2, ..., 10, 11)
# instead of the lexicographic order of the directory names.
for _, run_num, run_dir in sorted(numbered_run_dirs):
    csv_file = run_dir / f"training_res_run{run_num}.csv"
    if csv_file.exists():
        all_data_seed42[f"training_run{run_num}"] = pd.read_csv(csv_file)
    else:
        missing_csv_runs.append(run_dir.name)

loaded_run_numbers = sorted(int(key[len(run_dir_prefix):]) for key in all_data_seed42)

print(f"Loaded {len(all_data_seed42)} perturbation runs")
print(f"Run numbers: {loaded_run_numbers}")

# Report incomplete runs instead of dropping them silently.
if missing_csv_runs:
    print(f"Skipped {len(missing_csv_runs)} run directories without a results CSV: {missing_csv_runs}")

## Plot Test Loss Deviation

In [None]:
# Test-loss deviation at the perturbation epoch, relative to the baseline.
# The number in each run name (training_run<N>) is used as that run's
# perturbation epoch, and the deviation is
#     test_loss(run, epoch N) - test_loss(baseline, epoch N)
# so a positive value means the perturbed run has the higher loss.
perturbation_deviations = []
run_numbers_deviation = []
skipped_runs_deviation = []  # run numbers with no epoch-N row in the run or the baseline

for run_name, df in all_data_seed42.items():
    run_num = int(run_name.split('run')[1])
    perturb_epoch = run_num

    # Rows at the perturbation epoch for this run and for the baseline
    run_at_perturb = df[df['epoch'] == perturb_epoch]
    baseline_at_perturb = baseline_seed1_df[baseline_seed1_df['epoch'] == perturb_epoch]

    # A run cannot be compared when either side has no row at that epoch
    # (e.g. the baseline was trimmed before it); record it rather than
    # dropping it silently.
    if run_at_perturb.empty or baseline_at_perturb.empty:
        skipped_runs_deviation.append(run_num)
        continue

    run_loss = run_at_perturb.iloc[0]['test_loss']
    baseline_loss = baseline_at_perturb.iloc[0]['test_loss']
    perturbation_deviations.append(run_loss - baseline_loss)
    run_numbers_deviation.append(run_num)

if skipped_runs_deviation:
    print(f"Skipped {len(skipped_runs_deviation)} runs with no row at their perturbation epoch "
          f"in the run or baseline data: {sorted(skipped_runs_deviation)}")

# Fail with a clear message instead of the opaque unpacking error that
# zip(*[]) would raise below.
if not perturbation_deviations:
    raise ValueError(
        "No perturbation run could be compared with the baseline at its perturbation epoch; "
        "check that the sweep and baseline data were loaded."
    )

# Sort by run number for proper ordering
sorted_data_deviation = sorted(zip(run_numbers_deviation, perturbation_deviations))
run_numbers_deviation_sorted, perturbation_deviations_sorted = zip(*sorted_data_deviation)

# Bar plot: red = higher loss than baseline, green = lower (or equal) loss
fig, ax = plt.subplots(figsize=(6, 4))
colors = ['red' if dev > 0 else 'green' for dev in perturbation_deviations_sorted]
ax.bar(run_numbers_deviation_sorted, perturbation_deviations_sorted,
       alpha=0.7, color=colors, edgecolor='black', linewidth=0.5)

ax.set_xlabel('Perturbation Epoch', fontsize=15, fontweight='bold')
ax.set_ylabel(r'$\Delta$ Test Loss', fontsize=15, fontweight='bold')
ax.set_title('Baseline Seed 1', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, linestyle='--', axis='y')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.axhline(y=0, color='black', linestyle='-', alpha=0.8, linewidth=1)

plt.tight_layout()
plt.show()

# Summary statistics. The nan-aware reductions keep the numbers well defined
# if a run produced a NaN loss: built-in min/max are order-dependent with NaN
# and np.mean/np.std would turn into NaN.
nan_deviations = int(np.isnan(perturbation_deviations_sorted).sum())

print("\nTest loss deviation at perturbation epoch statistics:")
print(f"  Overall minimum: {np.nanmin(perturbation_deviations_sorted):.2f}")
print(f"  Overall maximum: {np.nanmax(perturbation_deviations_sorted):.2f}")
print(f"  Mean: {np.nanmean(perturbation_deviations_sorted):.2f}")
print(f"  Standard deviation: {np.nanstd(perturbation_deviations_sorted):.2f}")
if nan_deviations:
    print(f"  NaN deviations excluded from statistics: {nan_deviations}")

# Exact zeros (and NaN) fall into neither count
positive_deviations = sum(1 for dev in perturbation_deviations_sorted if dev > 0)
negative_deviations = sum(1 for dev in perturbation_deviations_sorted if dev < 0)
print("\nDeviation breakdown:")
print(f"  Positive deviations (higher loss): {positive_deviations}")
print(f"  Negative deviations (lower loss): {negative_deviations}")

## Plot Behavioral Alignment Deviation

In [None]:
# Behavioral-alignment deviation at the perturbation epoch, relative to the
# baseline. The number in each run name (training_run<N>) is used as that
# run's perturbation epoch, and the deviation is
#     behavioral_rsa_rho(run, epoch N) - behavioral_rsa_rho(baseline, epoch N)
# so a positive value means the perturbed run has the higher alignment.
perturbation_deviations_ba = []
run_numbers_deviation_ba = []
skipped_runs_deviation_ba = []  # run numbers with no epoch-N row in the run or the baseline

for run_name, df in all_data_seed42.items():
    run_num = int(run_name.split('run')[1])
    perturb_epoch = run_num

    # Rows at the perturbation epoch for this run and for the baseline
    run_at_perturb = df[df['epoch'] == perturb_epoch]
    baseline_at_perturb = baseline_seed1_df[baseline_seed1_df['epoch'] == perturb_epoch]

    # A run cannot be compared when either side has no row at that epoch
    # (e.g. the baseline was trimmed before it); record it rather than
    # dropping it silently.
    if run_at_perturb.empty or baseline_at_perturb.empty:
        skipped_runs_deviation_ba.append(run_num)
        continue

    run_ba = run_at_perturb.iloc[0]['behavioral_rsa_rho']
    baseline_ba = baseline_at_perturb.iloc[0]['behavioral_rsa_rho']
    perturbation_deviations_ba.append(run_ba - baseline_ba)
    run_numbers_deviation_ba.append(run_num)

if skipped_runs_deviation_ba:
    print(f"Skipped {len(skipped_runs_deviation_ba)} runs with no row at their perturbation epoch "
          f"in the run or baseline data: {sorted(skipped_runs_deviation_ba)}")

# Fail with a clear message instead of the opaque unpacking error that
# zip(*[]) would raise below.
if not perturbation_deviations_ba:
    raise ValueError(
        "No perturbation run could be compared with the baseline at its perturbation epoch; "
        "check that the sweep and baseline data were loaded."
    )

# Sort by run number for proper ordering
sorted_data_deviation_ba = sorted(zip(run_numbers_deviation_ba, perturbation_deviations_ba))
run_numbers_deviation_ba_sorted, perturbation_deviations_ba_sorted = zip(*sorted_data_deviation_ba)

# Bar plot: green = higher alignment than baseline, blue = lower (or equal)
fig, ax = plt.subplots(figsize=(6, 4))
colors = ['green' if dev > 0 else 'blue' for dev in perturbation_deviations_ba_sorted]
ax.bar(run_numbers_deviation_ba_sorted, perturbation_deviations_ba_sorted,
       alpha=0.6, color=colors, edgecolor='black', linewidth=0.5)

ax.set_xlabel('Perturbation Epoch', fontweight='bold', fontsize=15)
ax.set_ylabel(r'$\Delta$ Behavioral Alignment', fontweight='bold', fontsize=15)
# Same title as the test-loss figure: both compare against the seed-1 baseline,
# and a titled figure still reads correctly when viewed on its own.
ax.set_title('Baseline Seed 1', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, linestyle='--', axis='y')
ax.tick_params(axis='x', labelsize=15)
ax.tick_params(axis='y', labelsize=15)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.axhline(y=0, color='black', linestyle='-', alpha=0.8, linewidth=1)

plt.tight_layout()
plt.show()

# Summary statistics. The nan-aware reductions keep the numbers well defined
# if a run produced a NaN value: built-in min/max are order-dependent with NaN
# and np.mean/np.std would turn into NaN.
nan_deviations_ba = int(np.isnan(perturbation_deviations_ba_sorted).sum())

print("\nBehavioral alignment deviation at perturbation epoch statistics:")
print(f"  Overall minimum: {np.nanmin(perturbation_deviations_ba_sorted):.4f}")
print(f"  Overall maximum: {np.nanmax(perturbation_deviations_ba_sorted):.4f}")
print(f"  Mean: {np.nanmean(perturbation_deviations_ba_sorted):.4f}")
print(f"  Standard deviation: {np.nanstd(perturbation_deviations_ba_sorted):.4f}")
if nan_deviations_ba:
    print(f"  NaN deviations excluded from statistics: {nan_deviations_ba}")

# Exact zeros (and NaN) fall into neither count
positive_deviations_ba = sum(1 for dev in perturbation_deviations_ba_sorted if dev > 0)
negative_deviations_ba = sum(1 for dev in perturbation_deviations_ba_sorted if dev < 0)
print("\nDeviation breakdown:")
print(f"  Positive deviations (higher alignment): {positive_deviations_ba}")
print(f"  Negative deviations (lower alignment): {negative_deviations_ba}")