## Description

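Code for plotting layer spectra of MLP models: the spectral norm of each layer's weights over the course of training, averaged across ensemble runs.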

## Imports

In [None]:

%load_ext autoreload
%autoreload 2

import sys
import os
import pickle
import torch
import scipy
import matplotlib.pyplot as plt
from argparse import Namespace
from pathlib import Path
import plotly.graph_objects as go
import pathlib
import numpy as np
from collections import defaultdict

# Put the directory two levels above the current working directory on sys.path
# (presumably the repository root — confirm) so the project's `src` package can be imported.
sys.path.insert(0, str(Path(os.path.abspath('')).parent.parent.absolute()))

from src import helpers, plot_data, global_config, datasets, models
# Project-wide configuration object; `config.top_dir` is read when locating run outputs.
config = global_config.config

%matplotlib inline


## Plot all spectra (demo on random data)

In [None]:

# Synthetic demo data of shape (n_layers, train_len) = (5, 20), the layout the
# plotting functions below expect. Seeded so the demo figure is reproducible.
# stds are half-widths of the shaded band, so they must be non-negative: a raw
# randn draw flips sign between points and makes the band pinch to zero width.
_demo_rng = np.random.default_rng(0)
means = _demo_rng.standard_normal((5, 20))
stds = np.abs(_demo_rng.standard_normal((5, 20)))

def plot_multiple_spectral_norms_pl(means, stds, title=""):
    """Plot one mean curve per layer with a shaded +/- std band, using Plotly.

    Args:
        means: Array-like of shape (n_layers, train_len); row ``i`` is the curve
            drawn as "Layer i+1". A single 1-D curve is treated as one layer.
        stds: Array-like with the same shape as ``means``; half-width of the
            band drawn around each mean curve.
        title: Figure title.

    Raises:
        ValueError: If ``means`` and ``stds`` do not have the same shape.
    """
    # Coerce to arrays: list input would break `.shape` and, worse, make the
    # `+` / `-` below concatenate lists instead of adding element-wise.
    means = np.atleast_2d(np.asarray(means))
    stds = np.atleast_2d(np.asarray(stds))
    if means.shape != stds.shape:
        raise ValueError(
            f"means and stds must have the same shape, got {means.shape} and {stds.shape}"
        )

    n_layers, train_len = means.shape
    iter_nums = np.arange(train_len)

    fig = go.Figure()

    for layer_num, (mean_curve, std_curve) in enumerate(zip(means, stds)):
        color = f'hsl({layer_num * 360 / n_layers}, 50%, 50%)'  # evenly spaced hues, one per layer
        # Shared legend group: clicking a layer in the legend hides its band too.
        group = f'layer_{layer_num + 1}'

        # Mean line
        fig.add_trace(go.Scatter(
            x=iter_nums,
            y=mean_curve,
            mode='lines',
            name=f'Layer {layer_num + 1}',
            line=dict(color=color),
            legendgroup=group,
        ))

        # Band as a closed polygon: upper edge left->right, then lower edge right->left.
        fig.add_trace(go.Scatter(
            x=np.concatenate([iter_nums, iter_nums[::-1]]),
            y=np.concatenate([mean_curve + std_curve, (mean_curve - std_curve)[::-1]]),
            fill='toself',
            fillcolor=color,
            opacity=0.15,
            line=dict(width=0),
            showlegend=False,
            hoverinfo='skip',
            legendgroup=group,
        ))

    fig.update_layout(
        xaxis_title='Training Iteration',
        yaxis_title='Spectral Norm of Layer Weights',
        hovermode='x unified',
        title=title
    )

    fig.show()

def plot_multiple_spectral_norms(means, stds, title=""):
    """Plot one mean curve per layer with a shaded +/- std band, using matplotlib.

    Args:
        means: Array-like of shape (n_layers, train_len); row ``i`` is the curve
            labelled "Layer i+1". A single 1-D curve is treated as one layer.
        stds: Array-like with the same shape as ``means``; half-width of the
            band drawn around each mean curve.
        title: Optional figure title (empty string means no title), matching
            ``plot_multiple_spectral_norms_pl``.

    Raises:
        ValueError: If ``means`` and ``stds`` do not have the same shape.
    """
    # Coerce to arrays so `+` / `-` are element-wise even for list input.
    means = np.atleast_2d(np.asarray(means))
    stds = np.atleast_2d(np.asarray(stds))
    if means.shape != stds.shape:
        raise ValueError(
            f"means and stds must have the same shape, got {means.shape} and {stds.shape}"
        )

    iter_nums = np.arange(means.shape[1])

    # Own figure/axes so the plot does not draw into whatever pyplot figure is current.
    _, ax = plt.subplots()
    ax.set_xlabel("Training Iteration")
    ax.set_ylabel("Spectral Norm of Layer Weights")
    if title:
        ax.set_title(title)

    for layer_num, (mean_curve, std_curve) in enumerate(zip(means, stds)):
        (line,) = ax.plot(iter_nums, mean_curve, label=f'Layer {layer_num + 1}')
        # Band in the line's own colour so each layer reads as one unit.
        ax.fill_between(
            iter_nums,
            mean_curve + std_curve,
            mean_curve - std_curve,
            color=line.get_color(),
            alpha=0.15,
        )

    # legend() warns when there are no labelled artists (zero layers).
    if means.shape[0]:
        ax.legend()
    plt.show()

# Render the same demo data with both backends: matplotlib first, then Plotly.
for _plot_fn in (plot_multiple_spectral_norms, plot_multiple_spectral_norms_pl):
    _plot_fn(means, stds)


## Collect spectra from saved runs

In [None]:

def _load_pickle_or_none(path):
    """Unpickle and return the object stored at ``path``.

    Returns None (after printing the reason) when the file cannot be opened or
    unpickled, so the caller can skip that run and keep collecting the others.
    """
    # NOTE(review): pickle can execute arbitrary code on load; only point this
    # at run directories you produced yourself.
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except Exception as e:  # best-effort: one bad run must not abort the whole sweep
        print(f"Failed to load pickle file {path}: {e}")
        return None


total_data = dict()

# Input (identical for every dataset, so set up once)
in_dir = str(Path(config.top_dir) / 'fft_images')
spectrum_fname = 'model_spectrum.pkl'
hp_fname = 'hyperparams.pkl'

# Make sure top directory exists
if not os.path.isdir(in_dir):
    raise ValueError(f"The provided directory does not exist: {in_dir}")

# os.listdir order is arbitrary; sort so runs are collected (and printed) reproducibly.
in_dirs = sorted(d for d in os.listdir(in_dir) if os.path.isdir(os.path.join(in_dir, d)))

for dataset in ['MNIST']:
    # Output: {mlp_depth: {mlp_bottleneck_dim: [one entry per run]}}
    model_spectrums = defaultdict(lambda: defaultdict(list))
    model_bottlenecks = defaultdict(lambda: defaultdict(list))

    for tmp_dir in in_dirs:
        # Only run directories whose name starts with the dataset name belong to it.
        if not tmp_dir.startswith(dataset):
            continue

        pkl_dir = Path(in_dir) / tmp_dir / 'pkl'

        # Load data; skip the run if either pickle is unreadable.
        spectrum = _load_pickle_or_none(pkl_dir / spectrum_fname)
        if spectrum is None:
            continue
        hp = _load_pickle_or_none(pkl_dir / hp_fname)
        if hp is None:
            continue

        # Group runs by architecture
        layers = hp['mlp_depth']
        bottleneck = hp['mlp_bottleneck_dim']

        model_spectrums[layers][bottleneck].append(spectrum)
        model_bottlenecks[layers][bottleneck].append(bottleneck)

        print(tmp_dir)

    # Store as plain dicts: indexing a defaultdict with a missing key silently
    # inserts an empty entry, which the plotting cell would then try to plot.
    total_data[dataset] = {
        "model_spectrums": {k: dict(v) for k, v in model_spectrums.items()},
        "model_bottlenecks": {k: dict(v) for k, v in model_bottlenecks.items()},
    }

# Sanity check on one configuration.
# ensembles, n iters, layers
example_depth, example_bottleneck = 6, 784
example_runs = total_data["MNIST"]["model_spectrums"].get(example_depth, {}).get(example_bottleneck)
if example_runs:
    mat = np.asarray(example_runs)
    print(mat.shape)
    print(np.mean(mat, axis=0).shape)
else:
    print(f"No MNIST runs found with mlp_depth={example_depth}, mlp_bottleneck_dim={example_bottleneck}")

# total_data: dataset, parameter, layers, bottleneck, spectrum
print(total_data["MNIST"].keys())


## Make plots (one per depth/bottleneck configuration)

In [None]:
for dataset in ['MNIST']:
    plot_parameter = 'model_spectrums'
    plot_data_dict = total_data[dataset][plot_parameter]

    # One figure per (depth, bottleneck) configuration
    for layer in sorted(plot_data_dict):
        for bottleneck in sorted(plot_data_dict[layer]):
            runs = plot_data_dict[layer][bottleneck]
            if not runs:
                print(f"Skipping {dataset}: {layer} layers / {bottleneck} bottleneck has no runs")
                continue

            # Stacked runs, assumed (ensembles, n iters, layers) — TODO confirm
            # against the shape printed when the spectra are collected.
            data = np.asarray(runs)
            if data.ndim != 3:
                print(f"Skipping {dataset}: {layer} layers / {bottleneck} bottleneck, "
                      f"expected (ensembles, n iters, layers) but got shape {data.shape}")
                continue

            # Statistics across ensemble members (axis 0), transposed to
            # (layers, n iters): the plot function draws one curve per row.
            plot_means = np.mean(data, axis=0).T
            plot_stdevs = np.std(data, axis=0).T

            title = f"Layer spectra for {dataset} with {layer} layers and {bottleneck} bottleneck"
            plot_multiple_spectral_norms_pl(plot_means, plot_stdevs, title)

