In [3]:
""" Simple idea: check variances of errors at each timestep. If the variance is high, that step infleunces the error more. Therefore timestep weighting makes sense. 
"""

' Simple idea: check variances of errors at each timestep. If the variance is high, that step infleunces the error more. Therefore timestep weighting makes sense. \n'

In [1]:
""" Simple idea: check variances of errors at each timestep. If the variance is high, that step infleunces the error more. Therefore timestep weighting makes sense. 
"""
import torch
import os
import glob
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats

def load_and_evaluate_noises(noise_dir, wandb_run_id, similarity_method='l2'):
    """
    Load and evaluate noises from a specific directory.
    First index in selected_indices should be the correct one.
    Args:
       n noise_dir: Directory containing noise files
        wandb_run_id: WandB run ID for saving plots
        similarity_method: Either 'l2' or 'cosine' for distance calculation
    """
    # First try to load multiple target gaussian noises
    target_noise_paths = sorted(glob.glob(os.path.join(noise_dir, 'target_gaussian_noises_batch*.pt')))
    if len(target_noise_paths) > 1:
        target_noise_paths = sorted(target_noise_paths, key=lambda x: int(x.split('batch')[-1].split('.')[0]))
    device='cuda' if torch.cuda.is_available() else 'cpu'
    
    saved_all_timesteps_for_target = False
    
    if target_noise_paths:
        # Multiple target files found - concatenate them
        target_gaussian_noises_list = []
        for target_path in target_noise_paths:
            target_noise_data = torch.load(target_path)
            target_gaussian_noises_list.append(target_noise_data['target_gaussian_noises'])
        # Concatenate along batch dimension (dim=1)
        target_gaussian_noises = torch.cat(target_gaussian_noises_list, dim=1).to(device)  # shape: [timesteps, total_batches, ...]
    else:
        # Try legacy single file
        target_noise_path = os.path.join(noise_dir, 'target_gaussian_noises.pt')
        if not os.path.exists(target_noise_path):
            print(f"No target gaussian noise files found at {noise_dir}")
            return None
        target_noise_data = torch.load(target_noise_path)
        target_gaussian_noises = target_noise_data['target_gaussian_noises'].to(device) # shape: [timesteps, 1, ...]
    
    # Get all noise files except target_gaussian_noises*.pt
    noise_files = sorted([f for f in glob.glob(os.path.join(noise_dir, "*.pt")) 
                  if not any(x in f for x in ["target_gaussian_noises.pt", "target_gaussian_noises_batch"])],
                  key=lambda x: int(x.split('batch')[-1].split('_')[0]))
    
    # Load and concatenate all noise files
    all_cond_noises = []
    all_uncond_noises = []
    for noise_file in tqdm(noise_files, desc=f"Loading noise files from {os.path.basename(noise_dir)}"):
        data = torch.load(noise_file)
        all_cond_noises.append(data['conditional_noises'])  # shape: [num_selected, timesteps, batch_size, ...]
        all_uncond_noises.append(data['unconditional_noises'])  # [1, timesteps, batch_size, ...]
    
    # Concatenate along batch dimension
    
    cond_noises = torch.cat(all_cond_noises, dim=2).to(device)  # shape: [num_selected, timesteps, total_batch_size, ...]
    uncond_noises = torch.cat(all_uncond_noises, dim=2).to(device)  # [1, timesteps, total_batch_size, ...]
    
    if len(cond_noises.shape) == 6:
        cond_noises = cond_noises.reshape(cond_noises.shape[0], cond_noises.shape[1], cond_noises.shape[2], -1)
        uncond_noises = uncond_noises.reshape(uncond_noises.shape[0], uncond_noises.shape[1], uncond_noises.shape[2], -1)
        
        target_gaussian_noises = target_gaussian_noises.reshape(target_gaussian_noises.shape[0], target_gaussian_noises.shape[1], -1)
        saved_all_timesteps_for_target = True
        
    # Print shapes for debugging
    print(f"Initial shapes:")
    print(f"cond_noises: {cond_noises.shape}")
    print(f"uncond_noises: {uncond_noises.shape}")
    print(f"target_gaussian_noises: {target_gaussian_noises.shape}")
    
    # For each timestep, calculate error variances across all samples
    num_timesteps = cond_noises.shape[1]
    num_samples = cond_noises.shape[2]
    num_classes = cond_noises.shape[0]
    
    # Store errors for each timestep and sample
    true_class_errors = []  # Mean errors for class 0 (true class)
    true_class_stds = []    # Standard deviations for class 0
    incorrect_class_errors = []  # Mean errors for other classes
    incorrect_class_stds = []    # Standard deviations for other classes
    accuracies_per_timestep = []  # Track accuracy at each timestep
    all_errors = []  # Store all errors for aggregation [timestep, num_classes, batch_size]
    per_sample_error_gap = []  # Store error gaps between true and incorrect classes
    
    # For each timestep
    for t in range(num_timesteps):
        # Get target noise for this timestep
        if saved_all_timesteps_for_target:
            target_t = target_gaussian_noises[t]  # [total_batches, ...]
        else:
            target_t = target_gaussian_noises[t]  # [1, features]
            
        # Get conditional noise predictions for this timestep
        cond_t = cond_noises[:, t]  # [num_selected, total_batch_size, ...]
        
        print(f"\nShapes at timestep {t}:")
        print(f"cond_t: {cond_t.shape}")
        print(f"target_t: {target_t.shape}")
        
        # Calculate errors for each class prediction
        if similarity_method == 'l2':
            # L2 distance (lower is better)
            if not saved_all_timesteps_for_target:  # If spatial dimensions present
                cond_t = cond_t.reshape(cond_t.shape[0], cond_t.shape[1], -1)  # [num_classes, total_batch_size, flattened_features]
                target_t = target_t.reshape(target_t.shape[0], -1)  # [total_batch_size, flattened_features]
                
                print(f"After reshape:")
                print(f"cond_t: {cond_t.shape}")
                print(f"target_t: {target_t.shape}")
                
                # Expand target_t to match cond_t's shape for broadcasting
                target_t = target_t.unsqueeze(0).expand(num_classes, -1, -1)  # [num_classes, total_batch_size, flattened_features]
                
                print(f"After expansion:")
                print(f"target_t: {target_t.shape}")
            else:
                target_t = target_t.unsqueeze(0)
            dists = torch.norm(cond_t.to(torch.float32) - target_t.to(torch.float32), p=2, dim=-1)  # [num_classes, total_batch_size]
            errors = dists
            
        else:  # cosine similarity
            if len(cond_t.shape) == 3:  # If spatial dimensions present
                cond_t = cond_t.reshape(cond_t.shape[0], cond_t.shape[1], -1)  # [num_classes, total_batch_size, flattened_features]
                target_t = target_t.reshape(target_t.shape[0], -1)  # [total_batch_size, flattened_features]
                
                # Expand target_t to match cond_t's shape for broadcasting
                target_t = target_t.unsqueeze(0).expand(num_classes, -1, -1)  # [num_classes, total_batch_size, flattened_features]
            
            cond_t_norm = torch.nn.functional.normalize(cond_t, p=2, dim=-1)
            target_t_norm = torch.nn.functional.normalize(target_t, p=2, dim=-1)
            similarity = (cond_t_norm * target_t_norm).sum(dim=-1)  # [num_classes, total_batch_size]
            errors = -similarity  # Convert to error (lower is better)
        
        # Store raw errors for aggregation
        all_errors.append(errors)
        
        # Calculate per-timestep accuracy (for visualization)
        correct_predictions = (errors[0] < errors[1:].min(dim=0)[0])
        accuracies_per_timestep.append(correct_predictions.float().mean().item())
        
        # Calculate mean and std for true class
        true_class_mean = errors[0].mean().item()
        true_class_std = errors[0].std().item()
        true_class_errors.append(true_class_mean)
        true_class_stds.append(true_class_std)
        
        # Calculate mean and std for incorrect classes
        incorrect_mean = errors[1:].mean().item()
        incorrect_std = errors[1:].std().item()
        incorrect_class_errors.append(incorrect_mean)
        incorrect_class_stds.append(incorrect_std)
        
        # Calculate per-sample error gap between true class and incorrect classes
        true_class_errors_per_sample = errors[0]  # [batch_size]
        incorrect_class_errors_per_sample = errors[1:].min(dim=0)[0]  # [batch_size]
        error_gap = incorrect_class_errors_per_sample - true_class_errors_per_sample  # [batch_size]
        per_sample_error_gap.append(error_gap.cpu().numpy())
    
    # Stack all errors and calculate total accuracy
    all_errors = torch.stack(all_errors)
    summed_errors = all_errors.sum(dim=0)
    correct_predictions_total = (summed_errors[0] < summed_errors[1:].min(dim=0)[0])
    mean_accuracy = correct_predictions_total.float().mean().item()
    
    # Calculate influence of each timestep
    influences = []  # Store minimum delta needed for each timestep
    influence_stds = []  # Store standard deviation of deltas for each timestep
    
    # For each timestep
    for t in range(num_timesteps):
        # Get current timestep and other timesteps contributions
        current_timestep = all_errors[t]  # [num_classes, batch_size]
        other_timesteps_sum = summed_errors - current_timestep  # [num_classes, batch_size]
        
        # For each sample, find minimum delta needed
        sample_deltas = []
        for sample_idx in range(num_samples):
            # Get current errors from this timestep and others
            other_errors = other_timesteps_sum[:, sample_idx].cpu()  # [num_classes]
            current_errors = current_timestep[:, sample_idx].cpu()  # [num_classes]
            
            # Current prediction based on all timesteps
            total_errors = other_errors + current_errors
            current_pred = total_errors.argmin().item()
            
            # Find minimum delta needed to change prediction
            min_delta = float('inf')
            for target_class in range(num_classes):
                if target_class == current_pred:
                    continue
                
                # For target_class to win over current_pred, we need:
                # (other_errors[target_class] + (current_errors[target_class] + delta)) < (other_errors[current_pred] + current_errors[current_pred])
                # So: delta < (other_errors[current_pred] + current_errors[current_pred]) - (other_errors[target_class] + current_errors[target_class])
                margin = (other_errors[current_pred] + current_errors[current_pred]) - (other_errors[target_class] + current_errors[target_class])
                
                # How much do we need to change current timestep's contribution?
                # We need to change it by more than the margin
                delta_needed = abs(margin)
                min_delta = min(min_delta, delta_needed)
            
            sample_deltas.append(min_delta)
        
        # Store average influence and std for this timestep
        influences.append(np.mean(sample_deltas))
        influence_stds.append(np.std(sample_deltas))
    
    influences = np.array(influences)
    influence_stds = np.array(influence_stds)
    
    # Convert to numpy arrays for plotting
    true_errors = np.array(true_class_errors)
    true_stds = np.array(true_class_stds)
    incorrect_errors = np.array(incorrect_class_errors)
    incorrect_stds = np.array(incorrect_class_stds)
    accuracies = np.array(accuracies_per_timestep)
    timesteps = np.arange(num_timesteps)
    
    # Find the matching run info for this wandb_run_id
    run_info = next((run for run in run_ids if run['id'] == wandb_run_id), None)
    if run_info:
        title_info = f"{run_info['gen']} → {run_info['eval']} ({run_info['type']})"
    else:
        title_info = wandb_run_id
    
    # Create figures directory if it doesn't exist
    os.makedirs('figures', exist_ok=True)
    
    # Plot 1: Mean errors with std shading
    plt.figure(figsize=(10, 6))
    
    # Create primary axis for errors
    ax1 = plt.gca()
    # Plot true class errors with std shading
    ln1 = ax1.plot(timesteps, true_errors, label='True Class Error', color='blue')
    ax1.fill_between(timesteps, true_errors - true_stds, true_errors + true_stds,
                     alpha=0.3, color='blue', label='True Class Std')
    
    # Plot incorrect classes errors with std shading
    ln2 = ax1.plot(timesteps, incorrect_errors, label='Incorrect Classes Error', color='red')
    ax1.fill_between(timesteps, incorrect_errors - incorrect_stds, incorrect_errors + incorrect_stds,
                     alpha=0.3, color='red', label='Incorrect Classes Std')
    
    ax1.set_xlabel('Timestep')
    ax1.set_ylabel(f'Error ({similarity_method})')
    ax1.tick_params(axis='y')
    
    # Create secondary axis for accuracy
    ax2 = ax1.twinx()
    ln3 = ax2.plot(timesteps, accuracies, label='Per-timestep Accuracy', color='green', linestyle='--')
    ax2.set_ylabel('Accuracy', color='green')
    ax2.tick_params(axis='y', labelcolor='green')
    
    # Combine legends
    lns = ln1 + ln2 + ln3
    labs = [l.get_label() for l in lns]
    ax1.legend(lns, labs, loc='center right')
    
    plt.title(f'Error and Accuracy per Timestep\n{title_info}\nTotal Accuracy (all timesteps): {mean_accuracy:.3f}')
    plt.grid(True, alpha=0.3)
    plt.savefig(f'figures/{wandb_run_id}_errors_and_accuracy.pdf')
    plt.close()
    
    # Plot 2: Standard deviations only
    plt.figure(figsize=(10, 6))
    plt.plot(timesteps, true_stds, label='True Class Std', color='blue')
    plt.plot(timesteps, incorrect_stds, label='Incorrect Classes Std', color='red')
    plt.xlabel('Timestep')
    plt.ylabel(f'Standard Deviation of {similarity_method} Error')
    plt.title(f'Error Standard Deviations per Timestep\n{title_info}')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.savefig(f'figures/{wandb_run_id}_error_stds.pdf')
    plt.close()
    
    # Create correlation plots
    # Collect all errors per sample for each timestep
    errors_per_sample_timestep = []  # Will be [num_timesteps, num_samples]
    for t in range(num_timesteps):
        if saved_all_timesteps_for_target:
            target_t = target_gaussian_noises[t]
        else:
            target_t = target_gaussian_noises[t]
            
        cond_t = cond_noises[:, t]
        
        if similarity_method == 'l2':
            if not saved_all_timesteps_for_target:
                cond_t = cond_t.reshape(cond_t.shape[0], cond_t.shape[1], -1)
                target_t = target_t.reshape(target_t.shape[0], -1)
                target_t = target_t.unsqueeze(0).expand(num_classes, -1, -1)
            else:
                target_t = target_t.unsqueeze(0)
            dists = torch.norm(cond_t.to(torch.float32) - target_t.to(torch.float32), p=2, dim=-1)
            errors = dists
        else:
            if len(cond_t.shape) == 3:
                cond_t = cond_t.reshape(cond_t.shape[0], cond_t.shape[1], -1)
                target_t = target_t.reshape(target_t.shape[0], -1)
                target_t = target_t.unsqueeze(0).expand(num_classes, -1, -1)
            
            cond_t_norm = torch.nn.functional.normalize(cond_t, p=2, dim=-1)
            target_t_norm = torch.nn.functional.normalize(target_t, p=2, dim=-1)
            similarity = (cond_t_norm * target_t_norm).sum(dim=-1)
            errors = -similarity
            
        # Store raw errors for correlation analysis
        errors_per_sample_timestep.append(errors.cpu().numpy())
    
    # Convert to numpy array [num_timesteps, num_classes, num_samples]
    errors_array = np.stack(errors_per_sample_timestep)
    
    # Calculate correlations for true class (class 0)
    true_class_errors_per_sample = errors_array[:, 0, :]  # [num_timesteps, num_samples]
    true_class_pearson_corr = np.corrcoef(true_class_errors_per_sample)
    
    # Calculate Spearman rank correlation for true class
    true_class_spearman_corr = np.zeros((num_timesteps, num_timesteps))
    for i in range(num_timesteps):
        for j in range(num_timesteps):
            rho, _ = scipy.stats.spearmanr(true_class_errors_per_sample[i], true_class_errors_per_sample[j])
            true_class_spearman_corr[i, j] = rho
    
    # Calculate correlations for incorrect classes (average across classes 1+)
    incorrect_class_errors_per_sample = errors_array[:, 1:, :].mean(axis=1)  # [num_timesteps, num_samples]
    incorrect_class_pearson_corr = np.corrcoef(incorrect_class_errors_per_sample)
    
    # Calculate Spearman rank correlation for incorrect classes
    incorrect_class_spearman_corr = np.zeros((num_timesteps, num_timesteps))
    for i in range(num_timesteps):
        for j in range(num_timesteps):
            rho, _ = scipy.stats.spearmanr(incorrect_class_errors_per_sample[i], incorrect_class_errors_per_sample[j])
            incorrect_class_spearman_corr[i, j] = rho
    
    # Plot correlation heatmaps (2x2 grid: Pearson and Spearman for both true and incorrect)
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))
    
    # True class correlations - Pearson
    sns.heatmap(true_class_pearson_corr, ax=ax1, cmap='RdBu_r', vmin=-1, vmax=1,
                xticklabels=timesteps[::5], yticklabels=timesteps[::5])
    ax1.set_title('True Class Pearson Correlations')
    ax1.set_xlabel('Timestep')
    ax1.set_ylabel('Timestep')
    
    # True class correlations - Spearman
    sns.heatmap(true_class_spearman_corr, ax=ax2, cmap='RdBu_r', vmin=-1, vmax=1,
                xticklabels=timesteps[::5], yticklabels=timesteps[::5])
    ax2.set_title('True Class Spearman Rank Correlations')
    ax2.set_xlabel('Timestep')
    ax2.set_ylabel('Timestep')
    
    # Incorrect class correlations - Pearson
    sns.heatmap(incorrect_class_pearson_corr, ax=ax3, cmap='RdBu_r', vmin=-1, vmax=1,
                xticklabels=timesteps[::5], yticklabels=timesteps[::5])
    ax3.set_title('Incorrect Classes Pearson Correlations')
    ax3.set_xlabel('Timestep')
    ax3.set_ylabel('Timestep')
    
    # Incorrect class correlations - Spearman
    sns.heatmap(incorrect_class_spearman_corr, ax=ax4, cmap='RdBu_r', vmin=-1, vmax=1,
                xticklabels=timesteps[::5], yticklabels=timesteps[::5])
    ax4.set_title('Incorrect Classes Spearman Rank Correlations')
    ax4.set_xlabel('Timestep')
    ax4.set_ylabel('Timestep')
    
    plt.suptitle(f'Error Correlations Between Timesteps\n{title_info}')
    plt.tight_layout()
    plt.savefig(f'figures/{wandb_run_id}_error_correlations.pdf')
    plt.close()
    
    # Plot influences with std shading
    plt.figure(figsize=(10, 6))
    plt.plot(timesteps, influences, label='Mean Timestep Influence', color='purple')
    plt.fill_between(timesteps, influences - influence_stds, influences + influence_stds,
                     alpha=0.3, color='purple', label='Influence Std')
    plt.xlabel('Timestep')
    plt.ylabel('Average Influence (min delta needed)')
    plt.title(f'Timestep Influence on Final Prediction\n{title_info}')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.savefig(f'figures/{wandb_run_id}_timestep_influence.pdf')
    plt.close()
    
    # Return statistics
    return {
        'true_errors': true_errors,
        'true_stds': true_stds,
        'incorrect_errors': incorrect_errors,
        'incorrect_stds': incorrect_stds,
        'accuracies': accuracies,
        'timesteps': timesteps,
        'num_samples': num_samples,
        'mean_accuracy': mean_accuracy,
        'summed_errors': summed_errors.cpu().numpy(),
        'influences': influences,
        'influence_stds': influence_stds,
        'true_class_pearson_correlations': true_class_pearson_corr,
        'true_class_spearman_correlations': true_class_spearman_corr,
        'incorrect_class_pearson_correlations': incorrect_class_pearson_corr,
        'incorrect_class_spearman_correlations': incorrect_class_spearman_corr,
        'per_sample_error_gap': np.array(per_sample_error_gap) , # [num_timesteps, batch_size]
        'errors_array': errors_array,
    }

# Take only SD3 runs of few categories
# run_ids = [
#     # {"gen": "sd3", "eval":"sd3", "type": "whatsup_A", "id": "2zfmbdgb"},
#     # {"gen": "sd3", "eval":"sd3", "type": "whatsup_B", "id": "72srd0uj"},
    
#     {"gen": "sd3", "eval":"sd2", "type": "position", "id": "aut3pbwh"},
#     {"gen": "sd3", "eval":"sd2", "type": "counting", "id": "lxu0ohji"},
#     {"gen": "sd3", "eval":"sd2", "type": "color_attr", "id": "8axobt1l"},
#     # # {"gen": "sd3", "eval":"sd2", "type": "single", "id": "wwieih3x"},
    
#     {"gen": "sd3", "eval":"sd1.5", "type": "position", "id": "curxv9va"},
#     {"gen": "sd3", "eval":"sd1.5", "type": "counting", "id": "9j8saxms"},
#     {"gen": "sd3", "eval":"sd1.5", "type": "color_attr", "id": "ok41tk28"},
#     # # {"gen": "sd3", "eval":"sd1.5", "type": "single", "id": "4xuy2o92"},
    
#     {"gen": "sd3", "eval":"sd3", "type": "counting", "id": "1tcxc2rz"},
#     {"gen": "sd3", "eval":"sd3", "type": "color_attr", "id": "znbz2uhj"},
#     {"gen": "sd3", "eval":"sd3", "type": "position", "id": "4xuy2o92"},
    
# ]


run_ids = [
    {"gen": "sd3", "eval": "sd2", "type": "position", "id": "aut3pbwh"},
    {"gen": "sd3", "eval": "sd2", "type": "counting", "id": "lxu0ohji"},
    {"gen": "sd3", "eval": "sd2", "type": "color_attr", "id": "8axobt1l"},
    
    {"gen": "sd3", "eval": "sd1.5", "type": "position", "id": "curxv9va"},
    {"gen": "sd3", "eval": "sd1.5", "type": "counting", "id": "9j8saxms"},
    {"gen": "sd3", "eval": "sd1.5", "type": "color_attr", "id": "ok41tk28"},
    
    {"gen": "sd3", "eval": "sd3", "type": "counting", "id": "1tcxc2rz"},
    {"gen": "sd3", "eval": "sd3", "type": "color_attr", "id": "znbz2uhj"},
    {"gen": "sd3", "eval": "sd3", "type": "position", "id": "4xuy2o92"},
    
    {"gen": "whatsup-a", "eval": "sd1.5", "type": "whatsup-a", "id": "gtqrvf9c"},
    {"gen": "whatsup-a", "eval": "sd2", "type": "whatsup-a", "id": "qtwnqk5c"},
    {"gen": "whatsup-a", "eval": "sd3", "type": "whatsup-a", "id": "2zfmbdgb"},
    
    {"gen": "whatsup-b", "eval": "sd1.5", "type": "whatsup-b", "id": "vqudhuh5"},
    {"gen": "whatsup-b", "eval": "sd2", "type": "whatsup-b", "id": "784a6ywm"},
    {"gen": "whatsup-b", "eval": "sd3", "type": "whatsup-b", "id": "72srd0uj"},
    
    {"gen": "sd1.5-cnt", "eval": "sd1.5", "type": "sd1.5-cnt", "id": "h315owa1"},
    {"gen": "sd1.5-cnt", "eval": "sd2", "type": "sd1.5-cnt", "id": "pctg0k6y"},
    {"gen": "sd1.5-cnt", "eval": "sd3", "type": "sd1.5-cnt", "id": "jx3fdoii"},
    
    {"gen": "sd2-cnt", "eval": "sd1.5", "type": "sd2-cnt", "id": "xc0afsg5"},
    {"gen": "sd2-cnt", "eval": "sd2", "type": "sd2-cnt", "id": "cmlxwe5a"},
    {"gen": "sd2-cnt", "eval": "sd3", "type": "sd2-cnt", "id": "ygjlsq3o"},
]


In [2]:
class Args:
    def __init__(self):
        self.base_dir = '/mnt/lustre/work/oh/owl661/compositional-vaes/noise_results'
        self.run_id = None
        self.similarity = 'l2'

args = Args()

if args.run_id:
    # Find the directory containing this run_id
    run_dirs = []
    for root, dirs, files in os.walk(args.base_dir):
        if args.run_id in dirs:
            run_dirs.append(os.path.join(root, args.run_id))
    
    if not run_dirs:
        print(f"No directory found for run_id: {args.run_id}")
        exit(1)
        
    # Process the specified run
    for run_dir in run_dirs:
        print(f"\nEvaluating run: {args.run_id}")
        stats = load_and_evaluate_noises(run_dir, args.run_id, args.similarity)
        if stats is None:
            continue
        
        print(f"\nResults for {args.run_id}:")
        print(f"Total samples: {stats['num_samples']}")
        print(f"Mean error across all timesteps: {stats['true_errors'].mean():.4f}")
        print(f"Mean std across all timesteps: {stats['incorrect_errors'].mean():.4f}")
        print(f"Max std at timestep {stats['timesteps'][stats['incorrect_errors'].argmax()]}: {stats['incorrect_errors'].max():.4f}")
        print(f"Min std at timestep {stats['timesteps'][stats['incorrect_errors'].argmin()]}: {stats['incorrect_errors'].min():.4f}")
        print(f"Mean accuracy: {stats['mean_accuracy']:.3f}")
else:
    # Process all runs in run_ids list and collect stats
    all_stats = {}
    for run_info in run_ids:
        run_id = run_info['id']
        run_dir = None
        
        # Find the directory for this run_id
        for root, dirs, files in os.walk(args.base_dir):
            if run_id in dirs:
                run_dir = os.path.join(root, run_id)
                break
        
        if run_dir is None:
            print(f"No directory found for run_id: {run_id}")
            continue
            
        print(f"\nEvaluating run: {run_id} ({run_info['gen']}->{run_info['eval']}, {run_info['type']})")
        stats = load_and_evaluate_noises(run_dir, run_id, args.similarity)
        if stats is None:
            continue
        
        # Store stats for this run
        all_stats[run_id] = stats
        
        print(f"Total samples: {stats['num_samples']}")
        print(f"Mean error across all timesteps: {stats['true_errors'].mean():.4f}")
        print(f"Mean std across all timesteps: {stats['incorrect_errors'].mean():.4f}")
        print(f"Max std at timestep {stats['timesteps'][stats['incorrect_errors'].argmax()]}: {stats['incorrect_errors'].max():.4f}")
        print(f"Min std at timestep {stats['timesteps'][stats['incorrect_errors'].argmin()]}: {stats['incorrect_errors'].min():.4f}")
        print(f"Mean accuracy: {stats['mean_accuracy']:.3f}")
    
    # Create aggregate comparison plot
    create_aggregate_comparison_plot(all_stats, run_ids)




Evaluating run: aut3pbwh (sd3->sd2, position)


  target_noise_data = torch.load(target_noise_path)
  data = torch.load(noise_file)
Loading noise files from aut3pbwh: 100%|██████████| 2/2 [00:00<00:00,  2.19it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 113, 16384])
uncond_noises: torch.Size([1, 30, 113, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: t

  target_noise_data = torch.load(target_noise_path)


Total samples: 113
Mean error across all timesteps: 37.1801
Mean std across all timesteps: 37.1849
Max std at timestep 0: 87.8117
Min std at timestep 29: 7.2873
Mean accuracy: 0.372

Evaluating run: lxu0ohji (sd3->sd2, counting)


  data = torch.load(noise_file)
Loading noise files from lxu0ohji: 100%|██████████| 4/4 [00:01<00:00,  3.51it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 230, 16384])
uncond_noises: torch.Size([1, 30, 230, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: t

  target_noise_data = torch.load(target_noise_path)


Total samples: 230
Mean error across all timesteps: 34.2766
Mean std across all timesteps: 34.3095
Max std at timestep 0: 81.6847
Min std at timestep 29: 7.0640
Mean accuracy: 0.609

Evaluating run: 8axobt1l (sd3->sd2, color_attr)


  data = torch.load(noise_file)
Loading noise files from 8axobt1l: 100%|██████████| 8/8 [00:01<00:00,  6.51it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 252, 16384])
uncond_noises: torch.Size([1, 30, 252, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: t

  target_noise_data = torch.load(target_noise_path)


Total samples: 252
Mean error across all timesteps: 33.2698
Mean std across all timesteps: 33.3401
Max std at timestep 0: 78.8245
Min std at timestep 29: 7.5274
Mean accuracy: 0.952

Evaluating run: curxv9va (sd3->sd1.5, position)


  data = torch.load(noise_file)
Loading noise files from curxv9va: 100%|██████████| 2/2 [00:00<00:00,  3.90it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 113, 16384])
uncond_noises: torch.Size([1, 30, 113, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 113, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 113, 16384])
target_t: t

  target_noise_data = torch.load(target_noise_path)


Total samples: 113
Mean error across all timesteps: 37.2327
Mean std across all timesteps: 37.2335
Max std at timestep 0: 88.0518
Min std at timestep 29: 7.2980
Mean accuracy: 0.274

Evaluating run: 9j8saxms (sd3->sd1.5, counting)


  data = torch.load(noise_file)
Loading noise files from 9j8saxms: 100%|██████████| 4/4 [00:01<00:00,  3.43it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 230, 16384])
uncond_noises: torch.Size([1, 30, 230, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 230, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 230, 16384])
target_t: t

  target_noise_data = torch.load(target_noise_path)


Total samples: 230
Mean error across all timesteps: 34.3401
Mean std across all timesteps: 34.3637
Max std at timestep 0: 81.9747
Min std at timestep 29: 7.0496
Mean accuracy: 0.609

Evaluating run: ok41tk28 (sd3->sd1.5, color_attr)


  data = torch.load(noise_file)
Loading noise files from ok41tk28: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 252, 16384])
uncond_noises: torch.Size([1, 30, 252, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 252, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 252, 16384])
target_t: t

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from 1tcxc2rz: 100%|██████████| 58/58 [00:15<00:00,  3.71it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 230, 262144])
uncond_noises: torch.Size([1, 30, 230, 262144])
target_gaussian_noises: torch.Size([30, 230, 262144])

Shapes at timestep 0:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 1:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 2:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 3:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 4:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 5:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 6:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 7:
cond_t: torch.Size([4, 230, 262144])
target_t: torch.Size([230, 262144])

Shapes at timestep 8:
cond_t: torch.Size([4, 230, 262144])
target_t

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from znbz2uhj: 100%|██████████| 63/63 [00:04<00:00, 13.36it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 252, 65536])
uncond_noises: torch.Size([1, 30, 252, 65536])
target_gaussian_noises: torch.Size([30, 252, 65536])

Shapes at timestep 0:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 1:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 2:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 3:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 4:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 5:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 6:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 7:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 65536])

Shapes at timestep 8:
cond_t: torch.Size([4, 252, 65536])
target_t: torch.Size([252, 6

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from 4xuy2o92: 100%|██████████| 4/4 [00:01<00:00,  2.15it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 113, 65536])
uncond_noises: torch.Size([1, 30, 113, 65536])
target_gaussian_noises: torch.Size([30, 113, 65536])

Shapes at timestep 0:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 1:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 2:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 3:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 4:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 5:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 6:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 7:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 65536])

Shapes at timestep 8:
cond_t: torch.Size([4, 113, 65536])
target_t: torch.Size([113, 6

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from gtqrvf9c: 100%|██████████| 13/13 [00:01<00:00,  6.83it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 412, 16384])
uncond_noises: torch.Size([1, 30, 412, 16384])
target_gaussian_noises: torch.Size([30, 412, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After expansion:
target_t: torch.Size([4, 412, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After expansion:
target_t: torch.Size([4, 412, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After expansion:
target_t: torch.Size([4, 412, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 41

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from qtwnqk5c: 100%|██████████| 13/13 [00:01<00:00,  7.41it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 412, 16384])
uncond_noises: torch.Size([1, 30, 412, 16384])
target_gaussian_noises: torch.Size([30, 412, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After expansion:
target_t: torch.Size([4, 412, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After expansion:
target_t: torch.Size([4, 412, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After expansion:
target_t: torch.Size([4, 412, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 412, 16384])
target_t: torch.Size([412, 16384])
After reshape:
cond_t: torch.Size([4, 41

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from 2zfmbdgb: 100%|██████████| 26/26 [00:06<00:00,  3.73it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 412, 65536])
uncond_noises: torch.Size([1, 30, 412, 65536])
target_gaussian_noises: torch.Size([30, 412, 65536])

Shapes at timestep 0:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 1:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 2:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 3:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 4:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 5:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 6:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 7:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 65536])

Shapes at timestep 8:
cond_t: torch.Size([4, 412, 65536])
target_t: torch.Size([412, 6

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from vqudhuh5: 100%|██████████| 13/13 [00:01<00:00,  7.60it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 408, 16384])
uncond_noises: torch.Size([1, 30, 408, 16384])
target_gaussian_noises: torch.Size([30, 408, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After expansion:
target_t: torch.Size([4, 408, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After expansion:
target_t: torch.Size([4, 408, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After expansion:
target_t: torch.Size([4, 408, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 40

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from 784a6ywm: 100%|██████████| 13/13 [00:01<00:00,  6.93it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 408, 16384])
uncond_noises: torch.Size([1, 30, 408, 16384])
target_gaussian_noises: torch.Size([30, 408, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After expansion:
target_t: torch.Size([4, 408, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After expansion:
target_t: torch.Size([4, 408, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After expansion:
target_t: torch.Size([4, 408, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 408, 16384])
target_t: torch.Size([408, 16384])
After reshape:
cond_t: torch.Size([4, 40

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from 72srd0uj: 100%|██████████| 26/26 [00:06<00:00,  3.96it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 408, 65536])
uncond_noises: torch.Size([1, 30, 408, 65536])
target_gaussian_noises: torch.Size([30, 408, 65536])

Shapes at timestep 0:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 1:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 2:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 3:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 4:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 5:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 6:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 7:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 65536])

Shapes at timestep 8:
cond_t: torch.Size([4, 408, 65536])
target_t: torch.Size([408, 6

  target_noise_data = torch.load(target_noise_path)


Total samples: 408
Mean error across all timesteps: 103.4205
Mean std across all timesteps: 103.4922
Max std at timestep 0: 219.2736
Min std at timestep 13: 70.5974
Mean accuracy: 0.304

Evaluating run: h315owa1 (sd1.5-cnt->sd1.5, sd1.5-cnt)


  data = torch.load(noise_file)
Loading noise files from h315owa1: 100%|██████████| 2/2 [00:00<00:00,  4.71it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 98, 16384])
uncond_noises: torch.Size([1, 30, 98, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size(

  target_noise_data = torch.load(target_noise_path)


Total samples: 98
Mean error across all timesteps: 38.9019
Mean std across all timesteps: 38.9356
Max std at timestep 0: 92.7239
Min std at timestep 29: 7.6618
Mean accuracy: 0.755

Evaluating run: pctg0k6y (sd1.5-cnt->sd2, sd1.5-cnt)


  data = torch.load(noise_file)
Loading noise files from pctg0k6y: 100%|██████████| 2/2 [00:00<00:00,  4.57it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 98, 16384])
uncond_noises: torch.Size([1, 30, 98, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 98, 16384])
target_t: torch.Size(

  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from jx3fdoii: 100%|██████████| 4/4 [00:01<00:00,  2.70it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 98, 65536])
uncond_noises: torch.Size([1, 30, 98, 65536])
target_gaussian_noises: torch.Size([30, 98, 65536])

Shapes at timestep 0:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 1:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 2:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 3:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 4:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 5:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 6:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 7:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at timestep 8:
cond_t: torch.Size([4, 98, 65536])
target_t: torch.Size([98, 65536])

Shapes at tim

  target_noise_data = torch.load(target_noise_path)


Total samples: 98
Mean error across all timesteps: 139.1871
Mean std across all timesteps: 139.5909
Max std at timestep 0: 233.5324
Min std at timestep 13: 109.6694
Mean accuracy: 0.327

Evaluating run: xc0afsg5 (sd2-cnt->sd1.5, sd2-cnt)


  data = torch.load(noise_file)
Loading noise files from xc0afsg5: 100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 111, 16384])
uncond_noises: torch.Size([1, 30, 111, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 111, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 111, 16384])
target_t: t

  target_noise_data = torch.load(target_noise_path)


Total samples: 111
Mean error across all timesteps: 38.3425
Mean std across all timesteps: 38.3605
Max std at timestep 0: 91.7874
Min std at timestep 29: 7.7348
Mean accuracy: 0.613

Evaluating run: cmlxwe5a (sd2-cnt->sd2, sd2-cnt)


  data = torch.load(noise_file)
Loading noise files from cmlxwe5a: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]

Initial shapes:
cond_noises: torch.Size([4, 30, 19, 16384])
uncond_noises: torch.Size([1, 30, 19, 16384])
target_gaussian_noises: torch.Size([30, 1, 16384])

Shapes at timestep 0:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 1:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 2:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After expansion:
target_t: torch.Size([4, 1, 16384])

Shapes at timestep 3:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size([1, 16384])
After reshape:
cond_t: torch.Size([4, 19, 16384])
target_t: torch.Size(




Total samples: 19
Mean error across all timesteps: 39.2746
Mean std across all timesteps: 39.2955
Max std at timestep 0: 93.7110
Min std at timestep 29: 7.7130
Mean accuracy: 0.842

Evaluating run: ygjlsq3o (sd2-cnt->sd3, sd2-cnt)


  target_noise_data = torch.load(target_path)
  data = torch.load(noise_file)
Loading noise files from ygjlsq3o: 100%|██████████| 4/4 [00:01<00:00,  2.41it/s]


Initial shapes:
cond_noises: torch.Size([4, 30, 111, 65536])
uncond_noises: torch.Size([1, 30, 111, 65536])
target_gaussian_noises: torch.Size([30, 111, 65536])

Shapes at timestep 0:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 1:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 2:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 3:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 4:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 5:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 6:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 7:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 65536])

Shapes at timestep 8:
cond_t: torch.Size([4, 111, 65536])
target_t: torch.Size([111, 6

NameError: name 'create_aggregate_comparison_plot' is not defined

In [3]:
# plt stuff - hide top right splines
import seaborn as sns
sns.reset_defaults()
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

# global sizes of font 8
plt.rcParams['font.size'] = 8
plt.rcParams['axes.labelsize'] = 8
plt.rcParams['axes.titlesize'] = 8
plt.rcParams['legend.fontsize'] = 8
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8

In [73]:
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
def create_aggregate_comparison_plot(all_stats, run_ids):
    """Create aggregate plots comparing SD2 and SD3 evaluations for each task type."""
    task_types = ['counting', 'position', 'color_attr', 'whatsup-a', 'whatsup-b', 'sd1.5-cnt', 'sd2-cnt']
    w,h=8,1.5
    fig, axes = plt.subplots(1, 7, figsize=(w,h), sharey=True)
    
    # Get colorblind-friendly colors
    colors = sns.color_palette("colorblind")
    color1, color2 = colors[0], colors[1]
    # Enable right spine for the second y-axis that will show differences
    for ax in axes:
        ax2 = ax.get_shared_y_axes().get_siblings(ax)[0]  # Get the twin axis
        ax2.spines['right'].set_visible(True)
    
    for task_idx, task_type in enumerate(task_types):
        ax = axes[task_idx]
        ax2 = ax.twinx()  # Create second y-axis
        ax2.spines['right'].set_visible(True)
        
        # Get runs for this task type
        sd2_run = next(run for run in run_ids if run['eval'] == 'sd2' and run['type'] == task_type)
        sd3_run = next(run for run in run_ids if run['eval'] == 'sd3' and run['type'] == task_type)
        
        # Get stats and normalize each model's errors to [0,1] range
        sd2_stats = all_stats[sd2_run['id']]
        sd3_stats = all_stats[sd3_run['id']]
        
        # Normalize SD2 errors
        sd2_errors = sd2_stats['true_errors']
        sd2_stds = sd2_stats['true_stds']
        sd2_errors_norm = (sd2_errors - sd2_errors.min()) / (sd2_errors.max() - sd2_errors.min())
        sd2_stds_norm = sd2_stds / (sd2_errors.max() - sd2_errors.min())
        
        # Normalize SD3 errors
        sd3_errors = sd3_stats['true_errors']
        sd3_stds = sd3_stats['true_stds']
        sd3_errors_norm = (sd3_errors - sd3_errors.min()) / (sd3_errors.max() - sd3_errors.min())
        sd3_stds_norm = sd3_stds / (sd3_errors.max() - sd3_errors.min())
        
        # # Plot SD2 evaluation
        # ax.plot(sd2_stats['timesteps'] / 30, sd2_errors_norm, 
        #         label='SD2', color=color1, linestyle='-', linewidth=0.5)
        # ax.fill_between(sd2_stats['timesteps'] / 30, 
        #                sd2_errors_norm - sd2_stds_norm,
        #                sd2_errors_norm + sd2_stds_norm,
        #                alpha=0.2, color=color1, edgecolor=None)
        
        # Plot SD3 evaluation
        # ax.plot(sd3_stats['timesteps'] / 30, sd3_errors_norm,
        #         label='SD3', color=color2, linestyle='-', linewidth=0.5)
        # ax.fill_between(sd3_stats['timesteps'] / 30,
        #                sd3_errors_norm - sd3_stds_norm,
        #                sd3_errors_norm + sd3_stds_norm,
        #                alpha=0.2, color=color2, edgecolor=None)
        
        # Plot mean per-sample error gaps on second y-axis
        # Calculate mean and std per timestep across samples
        sd3_errors = sd3_stats['per_sample_error_gap']
        sd2_errors = sd2_stats['per_sample_error_gap']
        
        
        
        # Normalize the error gaps by dividing by max value
        sd2_norm = sd2_errors / np.max(np.abs(sd2_errors), axis=1, keepdims=True)
        sd3_norm = sd3_errors / np.max(np.abs(sd3_errors), axis=1, keepdims=True)
        # Create violin plots at each timestep
        positions = sd2_stats['timesteps'] / 30
        # for t_idx, t in enumerate(positions):
        #     # Create violin plots for SD2
        #     df_sd2 = pd.DataFrame({'x': [t]*len(sd2_norm[t_idx]), 'y': sd2_norm[t_idx]})
        #     # sns.violinplot(data=df_sd2, x='x', y='y', ax=ax2, color=color1, alpha=0.4)
        #     sns.stripplot(data=df_sd2, x='x', y='y', ax=ax2, color=color1, alpha=0.4, size=1, jitter=0.02)
            
        #     # Create violin plots for SD3
        #     df_sd3 = pd.DataFrame({'x': [t]*len(sd3_norm[t_idx]), 'y': sd3_norm[t_idx]})
        #     # sns.violinplot(data=df_sd3, x='x', y='y', ax=ax2, color=color2, alpha=0.4)
        #     sns.stripplot(data=df_sd3, x='x', y='y', ax=ax2, color=color2, alpha=0.4, size=1, jitter=0.02)
        
        # Calculate rate of positive errors for each timestep
        sd2_err = np.mean(sd2_errors, axis=1)
        sd3_err = np.mean(sd3_errors, axis=1)
        
        sd2_err = sd2_err / (np.abs(sd2_err).max())
        sd3_err = sd3_err / (np.abs(sd3_err).max())
        
        sd2_std=np.std(sd2_errors, axis=1)
        sd3_std=np.std(sd3_errors, axis=1)
        
        # plot with dashed
        ax2.plot(positions, sd2_err, color=color1, linestyle='--', linewidth=1, label='SD2')
        ax2.plot(positions, sd3_err, color=color2, linestyle='--', linewidth=1, label='SD3')
        ax2.set_ylabel('Error')
        
        # plot stds shading
        # ax2.fill_between(positions, sd2_err - sd2_std, sd2_err + sd2_std, color=color1, alpha=0.2)
        # ax2.fill_between(positions, sd3_err - sd3_std, sd3_err + sd3_std, color=color2, alpha=0.2)
        
        
        
        
        
        # Plot positive error rates
        # ax2.plot(positions, sd2_pos_rate, color=color1, linestyle='-', linewidth=1, label='SD2')
        # ax2.plot(positions, sd3_pos_rate, color=color2, linestyle='-', linewidth=1, label='SD3')
        # ax2.set_ylabel('Rate of Positive Errors')
        
        ax.set_title(f'{task_type.replace("_", " ").title()}')
        ax.set_xlabel('Timestep')
        if task_idx == 0:
            ax.set_ylabel('Normalized Error')
        if task_idx == 1:
            ax.legend(fontsize=7, frameon=False, loc='lower center', bbox_to_anchor=(3.5, -0.74), ncol=2)
    
        ax.locator_params(axis='x', nbins=4)
        ax2.locator_params(axis='x', nbins=4)
        ax.locator_params(axis='y', nbins=4)
        ax2.locator_params(axis='y', nbins=4)
        ax2.yaxis.set_major_locator(plt.MaxNLocator(4))  # Force exactly 4 ticks on ax2 y-axis
        plt.rc('xtick', labelsize=5)
        plt.rc('ytick', labelsize=5)
    # resize to w, h, force
    fig.set_size_inches(w, h)
    # subplot adjust
    plt.subplots_adjust(left=0.1, right=0.9, top=0.85, bottom=0.35, wspace=0.7)
    plt.savefig('figures/aggregate_comparison.pdf', pad_inches=0)
    plt.close()


In [74]:
create_aggregate_comparison_plot(all_stats, run_ids)

No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


In [23]:
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import gaussian_kde

def create_aggregate_comparison_plot_violin(all_stats, run_ids):
    """Create aggregate plots comparing SD2 and SD3 evaluations for each task type."""
    task_types = ['counting', 'position', 'color_attr', 'whatsup-a', 'whatsup-b', 'sd1.5-cnt', 'sd2-cnt']
    w,h=18,3 # Increased height to accommodate two rows
    fig, axes = plt.subplots(2, 7, figsize=(w,h), sharex=True)
    
    # Get colorblind-friendly colors
    colors = sns.color_palette("colorblind")
    color1, color2 = colors[0], colors[1]

    for task_idx, task_type in enumerate(task_types):
        ax_sd2 = axes[0, task_idx]  # Top row for SD2
        ax_sd3 = axes[1, task_idx]  # Bottom row for SD3
        
        # Get runs for this task type
        sd2_run = next(run for run in run_ids if run['eval'] == 'sd2' and run['type'] == task_type)
        sd3_run = next(run for run in run_ids if run['eval'] == 'sd3' and run['type'] == task_type)
        
        # Get stats and normalize
        sd2_stats = all_stats[sd2_run['id']]
        sd3_stats = all_stats[sd3_run['id']]
        
        # Get error gaps and normalize
        sd2_errors = sd2_stats['per_sample_error_gap']
        sd3_errors = sd3_stats['per_sample_error_gap']
        
        # Normalize to [0,1] range per sample
        sd2_norm = sd2_errors / np.max(sd2_errors, axis=1, keepdims=True)
        sd3_norm = sd3_errors / np.max(sd3_errors, axis=1, keepdims=True)
        
        # Create density plots at each timestep
        timesteps = sd2_stats['timesteps'] / 30
        
        # Create meshgrid for density plot
        X, Y = np.meshgrid(timesteps, np.linspace(0, 1, 100))
        Z_sd2 = np.zeros_like(X)
        Z_sd3 = np.zeros_like(X)
        
        # Calculate density for each timestep
        for t_idx, _ in enumerate(timesteps):
            kernel_sd2 = gaussian_kde(sd2_norm[t_idx])
            kernel_sd3 = gaussian_kde(sd3_norm[t_idx])
            Z_sd2[:, t_idx] = kernel_sd2(Y[:, t_idx])
            Z_sd3[:, t_idx] = kernel_sd3(Y[:, t_idx])
        
        # Normalize densities
        Z_sd2 = Z_sd2 / Z_sd2.max()
        Z_sd3 = Z_sd3 / Z_sd3.max()
        
        # Plot densities
        ax_sd2.pcolormesh(X, Y, Z_sd2, cmap='viridis', shading='auto')
        ax_sd3.pcolormesh(X, Y, Z_sd3, cmap='viridis', shading='auto')
        
        # Customize plots
        if task_idx == 0:
            ax_sd2.set_ylabel('SD2\nNormalized Range')
            ax_sd3.set_ylabel('SD3\nNormalized Range')
        
        ax_sd3.set_xlabel('Timestep')
        ax_sd2.set_title(f'{task_type.replace("_", " ").title()}')
        
        # Set ticks
        ax_sd2.locator_params(axis='x', nbins=4)
        ax_sd3.locator_params(axis='x', nbins=4)
        ax_sd2.locator_params(axis='y', nbins=4)
        ax_sd3.locator_params(axis='y', nbins=4)
        
        plt.rc('xtick', labelsize=5)
        plt.rc('ytick', labelsize=5)

    # Adjust layout
    plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.3, wspace=0.7)
    plt.savefig('figures/aggregate_comparison_density.pdf', pad_inches=0)
    plt.close()

create_aggregate_comparison_plot_violin(all_stats, run_ids)


In [2]:
def create_accuracy_comparison_plot(all_stats, run_ids):
    """Create plots comparing per-timestep and overall accuracy for SD1.5, SD2, and SD3."""
    task_types = ['counting', 'position', 'color_attr', 'whatsup-a', 'whatsup-b', 'sd1.5-cnt', 'sd2-cnt']
    w, h = 7, 3  # Increased height to accommodate second row
    fig, axes = plt.subplots(2, 7, figsize=(w, h))
    
    # Get colorblind-friendly colors
    colors = sns.color_palette("colorblind")
    color1, color2, color3 = colors[0], colors[1], colors[2]
    
    for task_idx, task_type in enumerate(task_types):
        ax_top = axes[0, task_idx]
        ax_bottom = axes[1, task_idx]
        
        # Get runs for this task type
        sd15_run = next(run for run in run_ids if run['eval'] == 'sd1.5' and run['type'] == task_type)
        sd2_run = next(run for run in run_ids if run['eval'] == 'sd2' and run['type'] == task_type)
        sd3_run = next(run for run in run_ids if run['eval'] == 'sd3' and run['type'] == task_type)
        
        # Get stats
        sd15_stats = all_stats[sd15_run['id']]
        sd2_stats = all_stats[sd2_run['id']]
        sd3_stats = all_stats[sd3_run['id']]
        
        # Plot per-timestep accuracies on top axis
        ax_top.plot(sd15_stats['timesteps'] / 30, sd15_stats['accuracies'], 
                label='SD1.5', color=color1, linestyle='-')
        ax_top.plot(sd2_stats['timesteps'] / 30, sd2_stats['accuracies'],
                label='SD2', color=color2, linestyle='-')
        ax_top.plot(sd3_stats['timesteps'] / 30, sd3_stats['accuracies'],
                label='SD3', color=color3, linestyle='-')
        
        # Add horizontal lines for overall accuracies
        ax_top.axhline(y=sd15_stats['mean_accuracy'], color=color1, linestyle='--', alpha=0.5)
        ax_top.axhline(y=sd2_stats['mean_accuracy'], color=color2, linestyle='--', alpha=0.5)
        ax_top.axhline(y=sd3_stats['mean_accuracy'], color=color3, linestyle='--', alpha=0.5)
        
        # Calculate softmax probabilities safely using log-sum-exp trick
        def safe_softmax(x, axis=1):
            # Subtract max for numerical stability
            x_max = np.max(x, axis=axis, keepdims=True)
            exp_x = np.exp(x - x_max)
            # Add small epsilon to avoid division by zero
            return exp_x / (np.sum(exp_x, axis=axis, keepdims=True) + 1e-10)
        
        # Get probabilities for individual timesteps
        sd15_probs = safe_softmax(-sd15_stats['errors_array'])
        sd2_probs = safe_softmax(-sd2_stats['errors_array'])
        sd3_probs = safe_softmax(-sd3_stats['errors_array'])
        
        # Get mean top-1 probabilities per timestep
        sd15_top1 = np.mean(np.max(sd15_probs, axis=1), axis=1)
        sd2_top1 = np.mean(np.max(sd2_probs, axis=1), axis=1)
        sd3_top1 = np.mean(np.max(sd3_probs, axis=1), axis=1)
        
        # Plot individual timestep probabilities
        ax_bottom.plot(sd15_stats['timesteps'] / 30, sd15_top1,
                    color=color1, linestyle='-')
        ax_bottom.plot(sd2_stats['timesteps'] / 30, sd2_top1,
                    color=color2, linestyle='-')
        ax_bottom.plot(sd3_stats['timesteps'] / 30, sd3_top1,
                    color=color3, linestyle='-')
        
        # Add horizontal lines for mean probabilities
        # for this, we need to get the actual mean first;)
        error_sd15_mean=np.mean(sd15_stats['errors_array'], axis=0)
        error_sd2_mean=np.mean(sd2_stats['errors_array'], axis=0)
        error_sd3_mean=np.mean(sd3_stats['errors_array'], axis=0)
        
        sd15_mean_softmax=safe_softmax(-error_sd15_mean, axis=0)
        sd2_mean_softmax=safe_softmax(-error_sd2_mean, axis=0)
        sd3_mean_softmax=safe_softmax(-error_sd3_mean, axis=0)
        
        sd15_mean_top1=np.mean(np.max(sd15_mean_softmax, axis=0))
        sd2_mean_top1=np.mean(np.max(sd2_mean_softmax, axis=0))
        sd3_mean_top1=np.mean(np.max(sd3_mean_softmax, axis=0))
        
        ax_bottom.axhline(y=sd15_mean_top1, color=color1, linestyle='--', alpha=0.5)
        ax_bottom.axhline(y=sd2_mean_top1, color=color2, linestyle='--', alpha=0.5)
        ax_bottom.axhline(y=sd3_mean_top1, color=color3, linestyle='--', alpha=0.5)
        
        
        
        
        
        # Formatting
        ax_top.set_title(f'{task_type.replace("_", " ").title()}')
        if task_idx == 0:
            ax_top.set_ylabel('Accuracy')
            ax_bottom.set_ylabel('Top-1 Prob')
        if task_idx == 1:
            ax_top.legend(fontsize=7, frameon=False)
            
        ax_bottom.set_xlabel('Timestep')
        
        # Make y labels small
        ax_top.tick_params(axis='y', labelsize=6)
        ax_bottom.tick_params(axis='y', labelsize=6)
    
    # Adjust figure size and spacing
    fig.set_size_inches(w, h)
    plt.subplots_adjust(left=0.1, right=0.98, top=0.9, bottom=0.15, wspace=0.7, hspace=0.4)
    plt.savefig('figures/accuracy_comparison.pdf', pad_inches=0)
    plt.close()
    
create_accuracy_comparison_plot(all_stats, run_ids)


NameError: name 'all_stats' is not defined

In [134]:
arr=torch.tensor([-0.2638195, -0.2645999, -0.2667893, -0.2674649])

print(torch.nn.functional.softmax(arr, dim=0))



tensor([0.2505, 0.2503, 0.2497, 0.2496])


In [None]:
# Take only SD3 runs of few categories

run_ids = [
    {"gen": "sd3", "eval":"sd2", "type": "position", "id": "aut3pbwh"},
    {"gen": "sd3", "eval":"sd2", "type": "counting", "id": "lxu0ohji"},
    {"gen": "sd3", "eval":"sd2", "type": "color_attr", "id": "8axobt1l"},
    
    {"gen": "sd3", "eval":"sd1.5", "type": "position", "id": "curxv9va"},
    {"gen": "sd3", "eval":"sd1.5", "type": "counting", "id": "9j8saxms"},
    {"gen": "sd3", "eval":"sd1.5", "type": "color_attr", "id": "ok41tk28"},
    
    {"gen": "sd3", "eval":"sd3", "type": "position", "id": "4xuy2o92"},
    {"gen": "sd3", "eval":"sd3", "type": "counting", "id": "1tcxc2rz"},
    {"gen": "sd3", "eval":"sd3", "type": "color_attr", "id": "znbz2uhj"},
]

