In [1]:
"""Script to visualize learned timestep weights across different models and task types.
Loads results from pickle files generated by learn_timestep_weights.py and creates comparison plots.
"""
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Base path for processed results. The cluster location below remains the
# default; set the TIMESTEP_WEIGHTS_RESULTS_DIR environment variable to point
# the notebook at a different copy of the results without editing the code.
RESULTS_DIR = os.environ.get(
    'TIMESTEP_WEIGHTS_RESULTS_DIR',
    '/mnt/lustre/work/oh/owl661/compositional-vaes/src/vqvae/_post/self_bench/analysis_cfg/processed_runs',
)

# Run configurations to analyze. Keys of each entry:
#   "eval" - model name the run is grouped under ('sd1.5', 'sd2' or 'sd3')
#   "type" - task type used to select the subplot / summary-table section
#   "id"   - run identifier; results are read from '<id>_results.pkl'
#   "gen"  - not read by the code in this notebook; presumably the generator or
#            dataset the evaluated images came from (TODO confirm)
run_ids = [
    {"gen": "sd3", "eval": "sd2", "type": "position", "id": "aut3pbwh"},
    {"gen": "sd3", "eval": "sd2", "type": "counting", "id": "lxu0ohji"},
    {"gen": "sd3", "eval": "sd2", "type": "color_attr", "id": "8axobt1l"},

    {"gen": "sd3", "eval": "sd1.5", "type": "position", "id": "curxv9va"},
    {"gen": "sd3", "eval": "sd1.5", "type": "counting", "id": "9j8saxms"},
    {"gen": "sd3", "eval": "sd1.5", "type": "color_attr", "id": "ok41tk28"},

    {"gen": "sd3", "eval": "sd3", "type": "counting", "id": "1tcxc2rz"},
    {"gen": "sd3", "eval": "sd3", "type": "color_attr", "id": "znbz2uhj"},
    {"gen": "sd3", "eval": "sd3", "type": "position", "id": "4xuy2o92"},

    {"gen": "whatsup-a", "eval": "sd1.5", "type": "whatsup-a", "id": "gtqrvf9c"},
    {"gen": "whatsup-a", "eval": "sd2", "type": "whatsup-a", "id": "qtwnqk5c"},
    {"gen": "whatsup-a", "eval": "sd3", "type": "whatsup-a", "id": "2zfmbdgb"},

    {"gen": "whatsup-b", "eval": "sd1.5", "type": "whatsup-b", "id": "vqudhuh5"},
    {"gen": "whatsup-b", "eval": "sd2", "type": "whatsup-b", "id": "784a6ywm"},
    {"gen": "whatsup-b", "eval": "sd3", "type": "whatsup-b", "id": "72srd0uj"},

    {"gen": "sd1.5-cnt", "eval": "sd1.5", "type": "sd1.5-cnt", "id": "h315owa1"},
    {"gen": "sd1.5-cnt", "eval": "sd2", "type": "sd1.5-cnt", "id": "pctg0k6y"},
    {"gen": "sd1.5-cnt", "eval": "sd3", "type": "sd1.5-cnt", "id": "jx3fdoii"},

    {"gen": "sd2-cnt", "eval": "sd1.5", "type": "sd2-cnt", "id": "xc0afsg5"},
    {"gen": "sd2-cnt", "eval": "sd2", "type": "sd2-cnt", "id": "cmlxwe5a"},
    {"gen": "sd2-cnt", "eval": "sd3", "type": "sd2-cnt", "id": "ygjlsq3o"},
]

def load_all_results(runs=None, results_dir=None):
    """Preload the pickled results of every run into a dictionary.

    Args:
        runs: Iterable of run-config dicts; only the ``"id"`` key is read here.
            Defaults to the module-level ``run_ids``.
        results_dir: Directory that contains the ``<id>_results.pkl`` files.
            Defaults to the module-level ``RESULTS_DIR``.

    Returns:
        dict mapping run id -> the object unpickled from that run's file.
        Runs whose file does not exist are skipped after printing a warning,
        so the returned dict may contain fewer entries than ``runs``.
    """
    # Resolve the defaults lazily so explicit arguments never touch the globals.
    runs = run_ids if runs is None else runs
    results_dir = RESULTS_DIR if results_dir is None else results_dir

    results_cache = {}
    for run in runs:
        run_id = run['id']
        results_file = os.path.join(results_dir, f'{run_id}_results.pkl')
        if not os.path.exists(results_file):
            # Include the path so a wrong results_dir is easy to spot.
            print(f"Warning: Results file not found for run {run_id} ({results_file})")
            continue

        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only point this at result files you produced yourself.
        with open(results_file, 'rb') as f:
            results_cache[run_id] = pickle.load(f)

    return results_cache
# Load every available run once; the plotting/summary calls below reuse this
# cache instead of re-reading the pickle files.
results_cache = load_all_results()



In [2]:
import seaborn as sns

# Reset rc parameters to their defaults before applying the overrides below.
sns.reset_defaults()

# Compact styling for small figures: no top/right spines and small text.
# NOTE(review): the plotting functions later call plt.style.use(...), which
# also writes to rcParams and may replace some of these sizes - confirm which
# values are actually in effect when the figures are rendered.
plt.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': 6,          # base font size
    'axes.labelsize': 4,     # axis labels
    'axes.titlesize': 6,     # subplot titles
    'legend.fontsize': 6,
    'xtick.labelsize': 4,
    'ytick.labelsize': 2,    # y tick numbers kept deliberately unobtrusive
})

In [3]:


def create_comparison_plots(results_cache,
                            output_path='figures/timestep_weights_comparison.pdf',
                            runs=None):
    """Plot max-normalized timestep weights, one subplot per task type.

    A single row of subplots is created (one per entry of the internal
    ``task_types`` list, currently seven). In each subplot, every run whose
    ``"type"`` matches the task and whose ``"eval"`` is one of the three models
    contributes one line: ``results['timestep_weights']`` divided by its own
    maximum, drawn against the timestep index divided by the number of
    timesteps.

    Args:
        results_cache: Mapping run id -> results object. Only
            ``results['timestep_weights']`` is read; it must support ``len``,
            ``.max()`` and division (e.g. a NumPy array - TODO confirm the
            stored type). Runs missing from the mapping are skipped.
        output_path: File the figure is saved to. The parent directory is
            created if needed.
        runs: Iterable of run-config dicts with ``"type"``, ``"eval"`` and
            ``"id"`` keys. Defaults to the module-level ``run_ids``.

    Returns:
        None. The figure is written to ``output_path`` and then closed.
    """
    runs = run_ids if runs is None else runs

    # NOTE(review): plt.style.use changes the global rcParams, so it affects
    # every figure created afterwards and may override font sizes configured
    # earlier in the notebook - confirm this is intended.
    plt.style.use('seaborn-v0_8-paper')
    w, h = 5, 1.5

    # One colour per evaluated model, taken from the colour-blind-safe palette.
    palette = sns.color_palette("colorblind")
    model_styles = {
        'sd1.5': {'color': palette[0], 'linestyle': '-'},
        'sd2': {'color': palette[1], 'linestyle': '-'},
        'sd3': {'color': palette[2], 'linestyle': '-'},
    }

    task_types = ['position', 'counting', 'color_attr', 'whatsup-a', 'whatsup-b', 'sd1.5-cnt', 'sd2-cnt']
    models = ['sd1.5', 'sd2', 'sd3']

    fig, axes = plt.subplots(1, len(task_types), figsize=(w, h), sharey=True)

    for i, (ax, task_type) in enumerate(zip(axes, task_types)):
        # Iterate models in the outer loop so lines are drawn in model order.
        for model in models:
            style = model_styles[model]
            for run in runs:
                if run['type'] != task_type or run['eval'] != model:
                    continue
                if run['id'] not in results_cache:
                    continue

                weights = results_cache[run['id']]['timestep_weights']
                # Normalize by the maximum so curves of different runs are comparable.
                weights = weights / weights.max()
                timesteps = np.arange(len(weights))

                ax.plot(timesteps / len(timesteps), weights,
                        color=style['color'],
                        linestyle=style['linestyle'],
                        alpha=0.8,
                        label=f"{model.upper()}")

        ax.set_title(f'{task_type.replace("_", " ").title()}')
        ax.set_xlabel('Timestep')
        if i == 0:  # Only the leftmost plot carries the shared y label
            ax.set_ylabel('Weight')

        # Single legend on the second subplot; skipped when nothing was drawn
        # there so matplotlib does not warn about an empty legend.
        if i == 1 and ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=6, frameon=False)

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        # One decimal digit on both axes. With '%.0f' every tick inside the
        # [0, 1] range was rendered as a misleading "0" or "1".
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.1f'))
        ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.1f'))

    fig.subplots_adjust(left=0.15, right=0.97, top=0.8, bottom=0.3)
    fig.set_size_inches(w, h)

    # Create the output directory if it doesn't exist ('.' for bare filenames).
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, pad_inches=0, dpi=300)
    plt.close(fig)

def create_accuracy_summary(results_cache,
                            task_types=('position', 'counting', 'color_attr'),
                            models=('sd1.5', 'sd2', 'sd3'),
                            runs=None):
    """Print a table of accuracies before and after timestep weighting.

    One row is printed per run that matches a (task type, model) pair and is
    present in ``results_cache``, followed by one average row per task type
    that has at least one run.

    Args:
        results_cache: Mapping run id -> results object; this function reads
            ``results['accuracy_original']`` and ``results['accuracy_best']``.
        task_types: Task types to include, in print order. Defaults to the
            three task types the table was originally restricted to.
        models: Values of the run's ``"eval"`` field to include, in print order.
        runs: Iterable of run-config dicts with ``"type"``, ``"eval"`` and
            ``"id"`` keys. Defaults to the module-level ``run_ids``.

    Returns:
        None; the table is written to stdout.

    Note:
        "Improvement" is ``(accuracy_best - accuracy_original) * 100``, i.e.
        percentage points if the stored accuracies are fractions in [0, 1]
        (TODO confirm the stored scale).
    """
    runs = run_ids if runs is None else runs

    rule = "-" * 80
    # Same column widths as the header below, so the numbers line up under
    # their titles (rows used to be one character short from "Weighted" on).
    row_fmt = "{:<12} {:<8} {:<10.3f} {:<10.3f} {:+.1f}%"

    print("\nAccuracy Summary:")
    print(rule)
    print(f"{'Task Type':<12} {'Model':<8} {'Original':<10} {'Weighted':<10} {'Improvement':<12}")
    print(rule)

    # Per-task rows are kept so the averages can be computed afterwards.
    results_by_type = {}

    for task_type in task_types:
        rows = results_by_type[task_type] = []
        for model in models:
            for run in runs:
                if run['type'] != task_type or run['eval'] != model:
                    continue
                if run['id'] not in results_cache:
                    continue

                results = results_cache[run['id']]
                orig_acc = results['accuracy_original']
                best_acc = results['accuracy_best']
                improvement = (best_acc - orig_acc) * 100

                rows.append({
                    'model': model,
                    'orig_acc': orig_acc,
                    'best_acc': best_acc,
                    'improvement': improvement,
                })
                print(row_fmt.format(task_type, model, orig_acc, best_acc, improvement))

    print(rule)
    print("Averages by task type:")
    for task_type, rows in results_by_type.items():
        if not rows:
            # No run of this task type was loaded; nothing to average.
            continue
        avg_orig = np.mean([r['orig_acc'] for r in rows])
        avg_best = np.mean([r['best_acc'] for r in rows])
        avg_improvement = np.mean([r['improvement'] for r in rows])
        print(row_fmt.format(task_type, 'AVG', avg_orig, avg_best, avg_improvement))
    print(rule)

# Render the weights-only comparison figure from the cached results.
# NOTE: a later cell redefines create_comparison_plots and saves to the same
# PDF path, so on a full run the file written here is overwritten.
create_comparison_plots(results_cache)


In [69]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def create_comparison_plots(results_cache,
                            output_path='figures/timestep_weights_comparison.pdf',
                            runs=None):
    """Plot timestep weights (top row) and accuracy gains (bottom row) per task.

    The figure has one column per entry of the internal ``task_types`` list.
    The top axis of a column shows, for every matching run, the weights
    ``results['timestep_weights']`` divided by their maximum against the
    timestep index divided by the number of timesteps. The bottom axis shows
    one bar per model with ``(accuracy_best - accuracy_original) * 100``,
    clipped at 0 so that a decrease is displayed as "0%".

    Args:
        results_cache: Mapping run id -> results object. Reads the keys
            ``'timestep_weights'`` (must support ``len``, ``.max()`` and
            division, e.g. a NumPy array - TODO confirm the stored type),
            ``'accuracy_original'`` and ``'accuracy_best'``. Runs missing from
            the mapping are skipped.
        output_path: File the figure is saved to. The parent directory is
            created if needed.
        runs: Iterable of run-config dicts with ``"type"``, ``"eval"`` and
            ``"id"`` keys. Defaults to the module-level ``run_ids``.

    Returns:
        None. The figure is written to ``output_path`` and then closed.
    """
    runs = run_ids if runs is None else runs

    # NOTE(review): plt.style.use changes the global rcParams, so it affects
    # every figure created afterwards and may override font sizes configured
    # earlier in the notebook - confirm this is intended.
    plt.style.use('seaborn-v0_8-paper')
    w, h = 6, 1.5
    fig = plt.figure(figsize=(w, h))

    # One colour per evaluated model, shared by the lines and the bars.
    palette = sns.color_palette("colorblind")
    model_styles = {
        'sd1.5': {'color': palette[0], 'linestyle': '-'},
        'sd2': {'color': palette[1], 'linestyle': '-'},
        'sd3': {'color': palette[2], 'linestyle': '-'},
    }

    task_types = ['position', 'counting', 'color_attr', 'whatsup-a', 'whatsup-b', 'sd1.5-cnt', 'sd2-cnt']
    models = ['sd1.5', 'sd2', 'sd3']

    # The legend is attached to the middle column (index 3 of 7) and pushed
    # below the bar row through bbox_to_anchor.
    legend_column = len(task_types) // 2

    # Two rows per column: weights on top (2/3 of the height), bars below.
    gs = GridSpec(2, len(task_types), height_ratios=[2, 1], hspace=0.8)

    bar_positions = np.arange(len(models))
    bar_width = 0.35
    # Short x tick labels for the bar row: 'sd1.5' -> '1.5', 'sd2' -> '2', ...
    bar_labels = [m[2:] if m.startswith('sd') else m for m in models]

    for i, task_type in enumerate(task_types):
        ax_top = fig.add_subplot(gs[0, i])
        ax_bottom = fig.add_subplot(gs[1, i])

        # model -> (uniform accuracy, learned accuracy) of the FIRST available
        # run; additional runs of the same model are drawn but not used for bars.
        first_accuracies = {}

        for model in models:
            style = model_styles[model]
            for run in runs:
                if run['type'] != task_type or run['eval'] != model:
                    continue
                if run['id'] not in results_cache:
                    continue

                results = results_cache[run['id']]

                weights = results['timestep_weights']
                # Normalize by the maximum so curves of different runs are comparable.
                weights = weights / weights.max()
                timesteps = np.arange(len(weights))

                ax_top.plot(timesteps / len(timesteps), weights,
                            color=style['color'],
                            linestyle=style['linestyle'],
                            alpha=0.8,
                            label=f"{model.upper()}")

                first_accuracies.setdefault(
                    model, (results['accuracy_original'], results['accuracy_best']))

        # --- weight panel -------------------------------------------------
        ax_top.set_title(f'{task_type.replace("_", " ").title()}')
        if i == 0:
            ax_top.set_ylabel('Weight')
        ax_top.set_xlim(0, 1)
        ax_top.set_ylim(0, 1)
        # Skip the legend when nothing was drawn to avoid an empty-legend warning.
        if i == legend_column and ax_top.get_legend_handles_labels()[0]:
            ax_top.legend(fontsize=6, frameon=False, ncol=3,
                          bbox_to_anchor=(0.5, -1.6), loc='center')
        if i != 0:
            # All columns share the same 0..1 range; label only the first.
            ax_top.set_yticklabels([])
        ax_top.set_xlabel('Timestep')

        # --- accuracy-improvement panel -------------------------------------
        for idx, model in enumerate(models):
            if model not in first_accuracies:
                continue
            uniform_acc, learned_acc = first_accuracies[model]
            if np.isnan(uniform_acc) or np.isnan(learned_acc):
                continue

            # Negative differences are deliberately displayed as 0%.
            improvement = max(0, (learned_acc - uniform_acc) * 100)

            ax_bottom.bar(bar_positions[idx], improvement, bar_width,
                          color=model_styles[model]['color'],
                          alpha=1)
            if improvement == 0:
                ax_bottom.text(bar_positions[idx], improvement,
                               '0%',
                               ha='center', va='bottom', fontsize=3)
            else:
                ax_bottom.text(bar_positions[idx], improvement,
                               f'+{improvement:.1f}%',
                               ha='center', va='bottom', fontsize=5)

        ax_bottom.set_xticks(bar_positions)
        ax_bottom.set_xticklabels(bar_labels, fontsize=6)
        if i == 0:
            ax_bottom.set_ylabel('Diff. (%)', position=(-30, 1), labelpad=18)

        # The bars are annotated with their values, so the y axis is hidden:
        # no left spine, no y ticks (which also removes the tick labels), and
        # no x tick marks.
        ax_bottom.spines['left'].set_visible(False)
        ax_bottom.set_yticks([])
        ax_bottom.tick_params(axis='x', length=0)

    # hspace is intentionally not passed here: the GridSpec above sets its own
    # hspace, which takes precedence over the figure-level value.
    fig.subplots_adjust(left=0.1, right=0.97, top=0.87, bottom=0.2, wspace=0.3)
    fig.set_size_inches(w, h)

    # Create the output directory if it doesn't exist ('.' for bare filenames).
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, pad_inches=0, dpi=300)
    plt.close(fig)
# Render the combined weights + accuracy-improvement figure from the cached results.
create_comparison_plots(results_cache)

In [68]:
# Results were already loaded once into `results_cache` by the first cell.

# Optional: uncomment to print the accuracy summary table from the cached results.
# create_accuracy_summary(results_cache)

