In [3]:
""" Simple idea: check variances of errors at each timestep. If the variance is high, that step influences the error more. Therefore timestep weighting makes sense. 
"""

' Simple idea: check variances of errors at each timestep. If the variance is high, that step influences the error more. Therefore timestep weighting makes sense. \n'

In [165]:
""" Simple idea: check variances of errors at each timestep. If the variance is high, that step influences the error more. Therefore timestep weighting makes sense. 
"""
import torch
import os
import glob
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
def _list_batch_files(noise_dir):
    """Return the per-batch ``*.pt`` files of ``noise_dir`` sorted by batch number.

    Files whose name contains ``target_gaussian_noises.pt`` or
    ``target_gaussian_noises_batch`` are skipped. The batch number is the
    integer found between the last ``batch`` in the file name and the next ``_``.
    """
    excluded = ("target_gaussian_noises.pt", "target_gaussian_noises_batch")

    def batch_number(path):
        # Work on the file name only so that directory names containing
        # "batch" or "_" cannot disturb the parsing.
        return int(os.path.basename(path).split('batch')[-1].split('_')[0])

    candidates = [
        path for path in glob.glob(os.path.join(noise_dir, "*.pt"))
        if not any(tag in os.path.basename(path) for tag in excluded)
    ]
    return sorted(candidates, key=batch_number)


def _min_flip_margins(other_errors, current_errors):
    """Per-sample gap between the predicted class and its closest competitor.

    Both inputs are ``[num_classes, num_samples]`` arrays. The prediction is the
    argmin of ``other_errors + current_errors`` along the class axis, and the
    result is, for every sample, the smallest absolute difference between the
    predicted class' total error and any other class' total error (``inf`` when
    there is only one class).
    """
    total = other_errors + current_errors
    cols = np.arange(total.shape[1])
    predicted = np.argmin(total, axis=0)
    gaps = np.abs(total[predicted, cols] - total)
    if not np.issubdtype(gaps.dtype, np.floating):
        gaps = gaps.astype(np.float64)  # needed to be able to store inf below
    gaps[predicted, cols] = np.inf      # a class is not its own competitor
    gaps[np.isnan(gaps)] = np.inf       # NaN gaps never win the minimum
    return gaps.min(axis=0)


def _evaluate_scores(score_files, source_name):
    """Compute per-timestep statistics from precomputed score files."""
    batch_scores = []
    correct_indices = []
    all_timesteps = None

    for score_file in tqdm(score_files, desc=f"Loading score files from {source_name}"):
        # NOTE: torch.load unpickles the file -- only point this at trusted results.
        # map_location keeps .numpy() working for tensors that were saved on a GPU.
        data = torch.load(score_file, map_location='cpu')
        batch_scores.append(data['scores_with_timestep'].numpy())
        correct_indices.extend(data['correct_indices'])

        if all_timesteps is None:
            all_timesteps = data['all_timesteps_used']
        else:
            assert np.array_equal(all_timesteps, data['all_timesteps_used']), \
                f"{score_file} was scored on a different set of timesteps"

    # Batches are joined along the last axis: [T, P, total_samples]
    scores = np.concatenate(batch_scores, axis=2)
    correct_indices = np.array(correct_indices)
    num_timesteps, num_prompts, num_samples = scores.shape

    # incorrect_mask[p, s] is True for every prompt p that is NOT the correct one for sample s
    sample_cols = np.arange(num_samples)
    incorrect_mask = np.ones((num_prompts, num_samples), dtype=bool)
    incorrect_mask[correct_indices, sample_cols] = False

    true_class_errors = []
    true_class_stds = []
    incorrect_class_errors = []
    incorrect_class_stds = []
    accuracies_per_timestep = []
    per_sample_error_gap = []

    for t in range(num_timesteps):
        timestep_scores = scores[t]  # [P, total_samples]

        correct_scores = timestep_scores[correct_indices, sample_cols]  # [samples]
        # Row s holds the P-1 scores of the incorrect prompts of sample s
        incorrect_scores = timestep_scores.T[incorrect_mask.T].reshape(num_samples, num_prompts - 1)

        true_class_errors.append(np.mean(correct_scores))
        true_class_stds.append(np.std(correct_scores))
        incorrect_class_errors.append(np.mean(incorrect_scores))
        incorrect_class_stds.append(np.std(incorrect_scores))

        # The prompt with the lowest score is the prediction
        predictions = np.argmin(timestep_scores, axis=0)
        accuracies_per_timestep.append(np.mean(predictions == correct_indices))

        # NOTE(review): the gap here uses the MEAN incorrect score per sample, while the
        # noise-based path uses the MIN incorrect error -- confirm which one is intended.
        per_sample_error_gap.append(incorrect_scores.mean(axis=1) - correct_scores)

    summed_errors = np.sum(scores, axis=0)  # [P, samples]

    # NOTE(review): (sum of the other timesteps) + (this timestep) rebuilds the full sum,
    # so the margin below is the same for every t up to float rounding -- confirm that
    # this is the intended per-timestep influence measure.
    influences = []
    influence_stds = []
    for t in range(num_timesteps):
        sample_deltas = _min_flip_margins(summed_errors - scores[t], scores[t])
        influences.append(np.mean(sample_deltas))
        influence_stds.append(np.std(sample_deltas))

    # Overall accuracy: classify on the score averaged over all timesteps
    avg_scores = np.mean(scores, axis=0)  # [P, samples]
    overall_acc = np.mean(np.argmin(avg_scores, axis=0) == correct_indices)

    return {
        'true_errors': np.array(true_class_errors),
        'true_stds': np.array(true_class_stds),
        'incorrect_errors': np.array(incorrect_class_errors),
        'incorrect_stds': np.array(incorrect_class_stds),
        'accuracies': np.array(accuracies_per_timestep),
        'timesteps': all_timesteps,
        'num_samples': num_samples,
        'mean_accuracy': overall_acc,
        'summed_errors': summed_errors,
        'influences': np.array(influences),
        'influence_stds': np.array(influence_stds),
        'scores_shape': scores.shape,
        'per_sample_error_gap': np.array(per_sample_error_gap)  # [num_timesteps, num_samples]
    }


def _evaluate_noises(noise_files, source_name, similarity_method):
    """Compute per-timestep statistics from stored noise tensors."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    all_cond_noises = []
    all_uncond_noises = []
    for noise_file in tqdm(noise_files, desc=f"Loading noise files from {source_name}"):
        # NOTE: torch.load unpickles the file -- only point this at trusted results.
        data = torch.load(noise_file, map_location='cpu')
        all_cond_noises.append(data['conditional_noises'])
        all_uncond_noises.append(data['unconditional_noises'])

    cond_noises = torch.cat(all_cond_noises, dim=2).to(device)
    # NOTE(review): the unconditional noises are loaded and reshaped but never enter
    # the error computation below -- confirm whether they should.
    uncond_noises = torch.cat(all_uncond_noises, dim=2).to(device)

    if cond_noises.dim() == 6:
        # Flatten everything after the first three axes into one feature axis
        cond_noises = cond_noises.reshape(*cond_noises.shape[:3], -1)
        uncond_noises = uncond_noises.reshape(*uncond_noises.shape[:3], -1)

    print("Initial shapes:")
    print(f"cond_noises: {cond_noises.shape}")
    print(f"uncond_noises: {uncond_noises.shape}")

    num_classes = cond_noises.shape[0]
    num_timesteps = cond_noises.shape[1]
    num_samples = cond_noises.shape[2]

    true_class_errors = []
    true_class_stds = []
    incorrect_class_errors = []
    incorrect_class_stds = []
    accuracies_per_timestep = []
    all_errors = []
    per_sample_error_gap = []

    for t in range(num_timesteps):
        cond_t = cond_noises[:, t]
        print(f"\nShapes at timestep {t}:")
        print(f"cond_t: {cond_t.shape}")

        if cond_t.dim() == 3:
            cond_t = cond_t.reshape(cond_t.shape[0], cond_t.shape[1], -1)

        if similarity_method == 'l2':
            errors = torch.norm(cond_t.to(torch.float32), p=2, dim=-1)
        else:
            # NOTE(review): this is the similarity of each normalised vector with
            # itself, i.e. 1 for every non-zero vector -- confirm the intended reference.
            cond_t_norm = torch.nn.functional.normalize(cond_t, p=2, dim=-1)
            errors = -(cond_t_norm * cond_t_norm).sum(dim=-1)

        all_errors.append(errors)

        # Class index 0 is treated as the true class throughout this path
        best_incorrect = errors[1:].min(dim=0)[0]
        correct_predictions = errors[0] < best_incorrect
        accuracies_per_timestep.append(correct_predictions.float().mean().item())

        true_class_errors.append(errors[0].mean().item())
        true_class_stds.append(errors[0].std().item())
        incorrect_class_errors.append(errors[1:].mean().item())
        incorrect_class_stds.append(errors[1:].std().item())

        # Per-sample gap between the best incorrect class and the true class
        per_sample_error_gap.append((best_incorrect - errors[0]).cpu().numpy())

    all_errors = torch.stack(all_errors)
    summed_errors = all_errors.sum(dim=0)
    correct_predictions_total = summed_errors[0] < summed_errors[1:].min(dim=0)[0]
    mean_accuracy = correct_predictions_total.float().mean().item()

    # NOTE(review): (sum of the other timesteps) + (this timestep) rebuilds the full sum,
    # so the margin below is the same for every t up to float rounding -- confirm that
    # this is the intended per-timestep influence measure.
    influences = []
    influence_stds = []
    for t in range(num_timesteps):
        current_timestep = all_errors[t]
        sample_deltas = _min_flip_margins(
            (summed_errors - current_timestep).cpu().numpy(),
            current_timestep.cpu().numpy(),
        )
        influences.append(np.mean(sample_deltas))
        influence_stds.append(np.std(sample_deltas))

    return {
        'true_errors': np.array(true_class_errors),
        'true_stds': np.array(true_class_stds),
        'incorrect_errors': np.array(incorrect_class_errors),
        'incorrect_stds': np.array(incorrect_class_stds),
        'accuracies': np.array(accuracies_per_timestep),
        'timesteps': np.arange(num_timesteps),
        'num_samples': num_samples,
        'mean_accuracy': mean_accuracy,
        'summed_errors': summed_errors.cpu().numpy(),
        'influences': np.array(influences),
        'influence_stds': np.array(influence_stds),
        'per_sample_error_gap': np.array(per_sample_error_gap)  # [num_timesteps, num_samples]
    }


def load_and_evaluate_noises(noise_dir, wandb_run_id, similarity_method='l2', use_scores=False):
    """
    Load and evaluate noises/scores from a specific directory.

    Args:
        noise_dir: directory holding the per-batch ``*.pt`` files of one run.
        wandb_run_id: not used by the evaluation; kept so existing calls keep working.
        similarity_method: ``'l2'`` for the L2 norm; any other value selects the
            cosine-style error. Only used when ``use_scores`` is False.
        use_scores: if True, read precomputed ``scores_with_timestep`` tensors
            (concatenated to ``[timesteps, prompts, samples]``) together with
            ``correct_indices``; otherwise compute errors from ``conditional_noises``
            (``[classes, timesteps, samples, ...]``, class 0 taken as the true class).

    Returns:
        ``None`` when the directory contains no usable batch files, otherwise a dict
        with the per-timestep arrays ``true_errors``, ``true_stds``,
        ``incorrect_errors``, ``incorrect_stds``, ``accuracies``, ``influences``,
        ``influence_stds`` and ``per_sample_error_gap``, plus ``timesteps``,
        ``num_samples``, ``mean_accuracy`` and ``summed_errors`` (and
        ``scores_shape`` when ``use_scores`` is True).
    """
    noise_files = _list_batch_files(noise_dir)
    if not noise_files:
        # Callers already skip runs for which None is returned.
        print(f"No batch files found in {noise_dir}")
        return None

    source_name = os.path.basename(noise_dir)
    if use_scores:
        return _evaluate_scores(noise_files, source_name)
    return _evaluate_noises(noise_files, source_name, similarity_method)





In [166]:
class Args:
    """Notebook stand-in for command-line arguments of the run evaluation."""

    def __init__(self):
        # Single run to evaluate; None means every entry of `run_ids` is processed.
        self.run_id = None
        # Error measure handed to load_and_evaluate_noises ('l2' selects the L2 norm).
        self.similarity = 'l2'
        # Root directory that is searched for a sub-directory named after each run id.
        # Alternative root: '/mnt/lustre/work/oh/owl661/compositional-vaes/noise_results'
        self.base_dir = '/mnt/lustre/work/oh/owl661/compositional-vaes/score_results'

args = Args()


def _iter_run_dirs(base_dir, run_id):
    """Yield every directory named ``run_id`` found while walking ``base_dir``."""
    for root, dirs, _files in os.walk(base_dir):
        if run_id in dirs:
            yield os.path.join(root, run_id)


def _print_run_summary(stats, show_num_samples=False):
    """Print the headline numbers of one ``load_and_evaluate_noises`` result."""
    true_errors = stats['true_errors']
    incorrect_errors = stats['incorrect_errors']
    timesteps = stats['timesteps']

    if show_num_samples:
        print(f"Total samples: {stats['num_samples']}")
    # These lines report mean errors of the true / incorrect classes; the values
    # come from 'true_errors' and 'incorrect_errors', not from the std arrays.
    print(f"Mean true-class error across all timesteps: {true_errors.mean():.4f}")
    print(f"Mean incorrect-class error across all timesteps: {incorrect_errors.mean():.4f}")
    print(f"Max incorrect-class error at timestep {timesteps[incorrect_errors.argmax()]}: {incorrect_errors.max():.4f}")
    print(f"Min incorrect-class error at timestep {timesteps[incorrect_errors.argmin()]}: {incorrect_errors.min():.4f}")
    print(f"Mean accuracy: {stats['mean_accuracy']:.3f}")


if args.run_id:
    # Evaluate every directory that matches the single requested run id
    run_dirs = list(_iter_run_dirs(args.base_dir, args.run_id))
    if not run_dirs:
        # Raise instead of exit(1): exit() would shut down the notebook kernel.
        raise FileNotFoundError(f"No directory found for run_id: {args.run_id}")

    for run_dir in run_dirs:
        print(f"\nEvaluating run: {args.run_id}")
        # NOTE(review): this path uses the noise-based evaluation (use_scores is left
        # at its default) while the multi-run path below uses use_scores=True -- confirm.
        stats = load_and_evaluate_noises(run_dir, args.run_id, args.similarity)
        if stats is None:
            continue

        print(f"\nResults for {args.run_id}:")
        _print_run_summary(stats, show_num_samples=True)
else:
    # Evaluate every run listed in `run_ids` and keep the stats keyed by run id.
    # NOTE: `run_ids` must already exist in the kernel; in this notebook it is
    # assigned in a cell further down.
    all_stats = {}
    for run_info in run_ids:
        run_id = run_info['id']

        # First matching directory wins; the walk stops as soon as it is found
        run_dir = next(_iter_run_dirs(args.base_dir, run_id), None)
        if run_dir is None:
            print(f"No directory found for run_id: {run_id}")
            continue

        print(f"\nEvaluating run: {run_id} ({run_info['gen']}->{run_info['eval']}, {run_info['type']})")
        stats = load_and_evaluate_noises(run_dir, run_id, args.similarity, use_scores=True)
        if stats is None:
            continue

        all_stats[run_id] = stats
        _print_run_summary(stats)

    # The aggregate comparison plot is produced by its own cell, after
    # create_aggregate_comparison_plot has been defined.




Evaluating run: an1a08rx (sd3->sd2, counting)


  data = torch.load(noise_file)
Loading score files from an1a08rx: 100%|██████████| 58/58 [00:00<00:00, 647.79it/s]


Mean error across all timesteps: 34.6992
Mean std across all timesteps: 34.7312
Max std at timestep 19.0: 88.6906
Min std at timestep 989.0: 6.8932
Mean accuracy: 0.639

Evaluating run: jyo8eg54 (sd3->sd2, single)


Loading score files from jyo8eg54: 100%|██████████| 79/79 [00:00<00:00, 481.72it/s]


Mean error across all timesteps: 35.8252
Mean std across all timesteps: 36.1726
Max std at timestep 19.0: 92.5875
Min std at timestep 989.0: 7.0972
Mean accuracy: 0.997

Evaluating run: 3me8aw5a (sd3->sd2, two)


Loading score files from 3me8aw5a: 100%|██████████| 33/33 [00:00<00:00, 1135.64it/s]


Mean error across all timesteps: 38.2739
Mean std across all timesteps: 38.4998
Max std at timestep 19.0: 97.0145
Min std at timestep 989.0: 7.9158
Mean accuracy: 1.000

Evaluating run: mst9bl9p (sd2->sd3, counting)


Loading score files from mst9bl9p: 100%|██████████| 28/28 [00:00<00:00, 1300.37it/s]


Mean error across all timesteps: 262.9860
Mean std across all timesteps: 263.5419
Max std at timestep 84.9056625366211: 449.8859
Min std at timestep 771.8446044921875: 211.7143
Mean accuracy: 0.685

Evaluating run: 01dpmapt (sd2->sd3, single)


Loading score files from 01dpmapt: 100%|██████████| 34/34 [00:00<00:00, 489.26it/s]


Mean error across all timesteps: 256.2554
Mean std across all timesteps: 259.4910
Max std at timestep 84.9056625366211: 447.0573
Min std at timestep 778.8462524414062: 206.8103
Mean accuracy: 0.941

Evaluating run: 8cv1l6b6 (sd2->sd3, two)


Loading score files from 8cv1l6b6: 100%|██████████| 26/26 [00:00<00:00, 442.36it/s]


Mean error across all timesteps: 255.1764
Mean std across all timesteps: 257.0352
Max std at timestep 84.9056625366211: 440.5613
Min std at timestep 757.4257202148438: 202.5320
Mean accuracy: 0.779


In [167]:
# Global matplotlib style: seaborn defaults reset, no top/right spines, 8 pt text.
import seaborn as sns

sns.reset_defaults()

plt.rcParams.update({
    # hide the top and right spines
    'axes.spines.right': False,
    'axes.spines.top': False,
    # every text element uses font size 8
    'font.size': 8,
    'axes.labelsize': 8,
    'axes.titlesize': 8,
    'legend.fontsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

In [168]:
def create_aggregate_comparison_plot(all_stats, run_ids):
    """Plot the normalised per-sample error gap over timesteps, SD2 vs SD3, per task.

    Args:
        all_stats: dict mapping run id -> stats dict that provides ``'timesteps'``
            and ``'per_sample_error_gap'`` (one row per timestep, one column per sample).
        run_ids: list of dicts with the keys ``'eval'``, ``'type'`` and ``'id'``; for
            every task type one run with ``eval == 'sd2'`` and one with
            ``eval == 'sd3'`` must be present (``StopIteration`` is raised otherwise).

    The figure has one panel per task type and is written to
    ``figures/aggregate_comparison.pdf``; nothing is returned.
    """
    # Alternative task set: ['counting', 'position', 'color_attr']
    task_types = ['single', 'two', 'counting']
    task_name_replacer = {
        'single': 'Single object',
        'two': 'Two objects',
        'counting': 'Counting',
        'position': 'Position',
        'color_attr': 'Color attribute',
    }
    w, h = 3.25, 0.7
    fig, axes = plt.subplots(1, len(task_types), figsize=(w, h), sharey=True)

    # Colorblind-friendly colors: (value of run['eval'], legend label, line color)
    palette = sns.color_palette("colorblind")
    evaluators = [('sd2', 'SD2', palette[0]), ('sd3', 'SD3', palette[1])]

    # Only the x-axis is drawn, placed at y = 0
    for ax in axes:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_position(('data', 0))
        ax.spines['bottom'].set_color('black')
        ax.spines['bottom'].set_linewidth(0.5)
        ax.tick_params(axis='both', length=0, width=0.5)
        ax.set_yticks([0])  # only label 0 on the y-axis

    for task_idx, task_type in enumerate(task_types):
        ax = axes[task_idx]

        for eval_model, label, color in evaluators:
            run = next(r for r in run_ids if r['eval'] == eval_model and r['type'] == task_type)
            stats = all_stats[run['id']]

            gap = np.asarray(stats['per_sample_error_gap'])
            timesteps = np.array(stats['timesteps']) / 1000
            # Normalise by the largest absolute gap; fall back to 1 to avoid 0/0
            scale = np.max(np.abs(gap)) or 1.0
            gap_mean = gap.mean(axis=1)
            gap_std = gap.std(axis=1)

            ax.plot(timesteps, gap_mean / scale,
                    color=color, linestyle='-', linewidth=1, label=label, zorder=10)
            # Shaded band: mean +/- one standard deviation across samples
            ax.fill_between(timesteps,
                            (gap_mean - gap_std) / scale,
                            (gap_mean + gap_std) / scale,
                            alpha=0.1, color=color, edgecolor=None)

        ax.set_title(f'{task_name_replacer[task_type]}', fontsize=7, pad=-20)
        if task_idx == 0:
            ax.set_ylabel('Error\ngap$_{\\text{pos}-\\text{neg}}$', fontsize=5, labelpad=1)
        ax.set_xlabel('Timestep', fontsize=5, labelpad=-5)

        ax.locator_params(axis='x', nbins=2)
        ax.tick_params(axis='x', labelsize=4)
        ax.tick_params(axis='y', labelsize=4)

    # Force the final size and margins, then write the file
    fig.set_size_inches(w, h)
    fig.subplots_adjust(left=0.08, right=0.98, top=0.87, bottom=0.0, wspace=0.1)
    os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories
    fig.savefig('figures/aggregate_comparison.pdf', pad_inches=0)
    plt.close(fig)


In [169]:
# Render the aggregate error-gap figure from the stats gathered in `all_stats`.
create_aggregate_comparison_plot(all_stats, run_ids)

In [170]:
def _stable_softmax(x, axis=1):
    """Softmax along ``axis``; the maximum is subtracted first for numerical stability."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    # Small epsilon guards against division by zero
    return exp_x / (np.sum(exp_x, axis=axis, keepdims=True) + 1e-10)


def create_accuracy_comparison_plot(all_stats, run_ids, show_softmax=False):
    """Plot per-timestep classification accuracy of the SD2 and SD3 evaluators per task.

    Args:
        all_stats: dict mapping run id -> stats dict that provides ``'timesteps'``
            and ``'accuracies'``.
        run_ids: list of dicts with the keys ``'eval'``, ``'type'`` and ``'id'``
            (``'gen'`` is used for the legend when present); for every task type one
            run with ``eval == 'sd2'`` and one with ``eval == 'sd3'`` must be present
            (``StopIteration`` is raised otherwise).
        show_softmax: if True, add a second row showing the mean top-1 softmax
            probability per timestep. This reads ``stats['errors_array']``, which the
            stats dicts built by ``load_and_evaluate_noises`` in this notebook do not
            contain -- presumably a ``[timesteps, classes, samples]`` array; verify
            before enabling.

    The figure is written to ``figures/accuracy_comparison.pdf``; nothing is returned.
    """
    # Alternative task set: ['counting', 'position', 'color_attr']
    task_types = ['single', 'two', 'counting']
    w, h = 3.25, 0.75

    task_name_replacer = {
        'single': 'Single object',
        'two': 'Two objects',
        'counting': 'Counting',
        'position': 'Position',
        'color_attr': 'Color attribute',
    }

    if show_softmax:
        fig, axes = plt.subplots(2, 3, figsize=(w, h))
    else:
        fig, axes = plt.subplots(1, 3, figsize=(w, h), sharey=True)
        axes = axes[np.newaxis, :]  # add a row axis so both layouts index the same way

    # Colorblind-friendly colors per evaluating model; 'sd1.5' is reserved but
    # not plotted unless it is added to `eval_models`.
    palette = sns.color_palette("colorblind")
    model_colors = {'sd1.5': palette[2], 'sd2': palette[0], 'sd3': palette[1]}
    eval_models = ['sd2', 'sd3']

    for ax in axes.flatten():
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_position(('data', 0))
        ax.spines['bottom'].set_linewidth(0.5)
        ax.spines['left'].set_linewidth(0.5)
        # short tick marks, small tick labels
        ax.tick_params(axis='both', length=2, width=0.5, labelsize=4)

    for task_idx, task_type in enumerate(task_types):
        ax_top = axes[0, task_idx]

        task_stats = []
        for eval_model in eval_models:
            run = next(r for r in run_ids if r['eval'] == eval_model and r['type'] == task_type)
            stats = all_stats[run['id']]
            task_stats.append((eval_model, stats))

            # Legend names the evaluating model and the model that generated the
            # images, so cross-domain runs are labelled correctly.
            generator = run.get('gen', run['eval'])
            ax_top.plot(np.array(stats['timesteps']) / 1000, np.array(stats['accuracies']),
                        label=f"{run['eval'].upper()} model, {generator.upper()} generations",
                        color=model_colors[eval_model], linestyle='-', linewidth=1)

        if show_softmax:
            ax_bottom = axes[1, task_idx]
            for eval_model, stats in task_stats:
                # Lower error -> higher probability; softmax over axis 1
                probs = _stable_softmax(-np.asarray(stats['errors_array']))
                top1 = np.mean(np.max(probs, axis=1), axis=1)
                # Same timestep scaling as the accuracy row
                ax_bottom.plot(np.array(stats['timesteps']) / 1000, top1,
                               color=model_colors[eval_model], linestyle='-', linewidth=1)

            ax_bottom.set_xlabel('Timestep', fontsize=5)
            if task_idx == 0:
                ax_bottom.set_ylabel('Top-1 Prob', fontsize=6)

        ax_top.set_title(f'{task_name_replacer[task_type]}', fontsize=7, pad=-20)
        if task_idx == 0:
            ax_top.set_ylabel('Accuracy', fontsize=5, labelpad=0)
        if not show_softmax:
            ax_top.set_xlabel('Timestep', fontsize=5, labelpad=0)

        if task_idx == 0:
            # Build the legend while only the first panel has lines, so every
            # model appears exactly once; it is placed below the plots.
            legend_y = 0.0 if show_softmax else -0.1
            fig.legend(fontsize=6, frameon=False, loc='lower center',
                       bbox_to_anchor=(0.5, legend_y), ncol=3)

    # Horizontal grid lines
    for ax in axes.flatten():
        ax.grid(which='major', axis='y', linestyle='-', alpha=0.2)

    # Force the final size and margins
    fig.set_size_inches(w, h)
    if show_softmax:
        fig.subplots_adjust(left=0.15, right=0.98, top=0.9, bottom=0.25, wspace=0.7, hspace=0.4)
    else:
        fig.subplots_adjust(left=0.08, right=0.98, top=0.87, bottom=0.4, wspace=0.1)

    os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories
    fig.savefig('figures/accuracy_comparison.pdf', pad_inches=0)
    plt.close(fig)
    
# Render the per-timestep accuracy figure from the stats gathered in `all_stats`.
create_accuracy_comparison_plot(all_stats, run_ids)


In [163]:
# Show the run configuration currently held in the kernel.
print(run_ids)

[{'gen': 'sd2', 'eval': 'sd2', 'type': 'counting', 'id': 'rf5u3pz9'}, {'gen': 'sd2', 'eval': 'sd2', 'type': 'single', 'id': 'pjyhjvyw'}, {'gen': 'sd2', 'eval': 'sd2', 'type': 'two', 'id': '3me8aw5a'}, {'gen': 'sd3', 'eval': 'sd3', 'type': 'counting', 'id': 'mst9bl9p'}, {'gen': 'sd3', 'eval': 'sd3', 'type': 'single', 'id': 'dvxgc18d'}, {'gen': 'sd3', 'eval': 'sd3', 'type': 'two', 'id': 'ioer1zo7'}]


In [161]:
# Sanity check: a softmax over nearly identical values is close to uniform.
arr = torch.tensor([
    -0.2638195,
    -0.2645999,
    -0.2667893,
    -0.2674649,
])

print(torch.softmax(arr, dim=0))



tensor([0.2505, 0.2503, 0.2497, 0.2496])


In [164]:
# Take only SD3 runs of few categories

# Earlier run sets, kept for reference:
#     {"gen": "sd3", "eval":"sd2", "type": "position", "id": "aut3pbwh"},
#     {"gen": "sd3", "eval":"sd2", "type": "counting", "id": "lxu0ohji"},
#     {"gen": "sd3", "eval":"sd2", "type": "color_attr", "id": "8axobt1l"},
#
#     {"gen": "sd3", "eval":"sd1.5", "type": "position", "id": "curxv9va"},
#     {"gen": "sd3", "eval":"sd1.5", "type": "counting", "id": "9j8saxms"},
#     {"gen": "sd3", "eval":"sd1.5", "type": "color_attr", "id": "ok41tk28"},
#
#     {"gen": "sd3", "eval":"sd3", "type": "position", "id": "4xuy2o92"},
#     {"gen": "sd3", "eval":"sd3", "type": "counting", "id": "1tcxc2rz"},
#     {"gen": "sd3", "eval":"sd3", "type": "color_attr", "id": "znbz2uhj"},

# Active runs: each row is (generating model, evaluating model, task type, run id).
run_ids = [
    {"gen": gen, "eval": evaluator, "type": task, "id": run_id}
    for gen, evaluator, task, run_id in (
        ("sd3", "sd2", "counting", "an1a08rx"),
        ("sd3", "sd2", "single", "jyo8eg54"),
        ("sd3", "sd2", "two", "3me8aw5a"),
        ("sd2", "sd3", "counting", "mst9bl9p"),
        ("sd2", "sd3", "single", "01dpmapt"),
        ("sd2", "sd3", "two", "8cv1l6b6"),
    )
]

# Self-domain runs (generator and evaluator are the same model)
# run_ids  = [
#     {"gen": "sd2", "eval":"sd2", "type": "counting", "id": "rf5u3pz9"},
#     {"gen": "sd2", "eval":"sd2", "type": "single", "id": "pjyhjvyw"},
#     {"gen": "sd2", "eval":"sd2", "type": "two", "id": "3me8aw5a"},
    
#     {"gen": "sd3", "eval":"sd3", "type": "counting", "id": "mst9bl9p"},
#     {"gen": "sd3", "eval":"sd3", "type": "single", "id": "dvxgc18d"},
#     {"gen": "sd3", "eval":"sd3", "type": "two", "id": "ioer1zo7"},
# ]

