In [None]:
import os, sys
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import random
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
# from tqdm.auto import tqdm
from dataclasses import dataclass
from rich import print as rprint
from rich.table import Table
from IPython.display import display, HTML
from pathlib import Path
import pandas as pd
pd.options.plotting.backend = "plotly"
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import copy
import shap
import plotly.express as px

chapter = "chapter1_transformer_interp"
# Everything before the chapter folder in the CWD is taken as the repo root.
_repo_root = os.getcwd().split(chapter)[0]
exercises_dir = Path(f"{_repo_root}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part4_superposition_and_saes").resolve()
del _repo_root

# Make the exercises directory importable so the local helper modules below resolve.
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, hist
from part31_superposition_and_saes.utils import (
    plot_features_in_2d,
    plot_features_in_Nd,
    plot_features_in_Nd_discrete,
    plot_correlated_features,
    plot_feature_geometry,
    frac_active_line_plot,
    animate_features_in_2d
)

from feature_geometry_utils import *
# Choose the compute device: CUDA if present, else Apple-silicon MPS, else CPU.
# The MPS diagnostics are only printed when we actually fall back to the CPU, so
# CUDA machines are not told about an MPS backend that is irrelevant to them.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")
    if not t.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

MAIN = __name__ == "__main__"

# Lower Sparsity Runs

In [None]:
def feature_probability(n_instances, max_ratio=1000):
    """Per-instance feature probability (1 - S), log-spaced from 1 down to 1 / max_ratio.

    Args:
        n_instances: number of model instances (each gets one sparsity level).
        max_ratio: ratio between the densest and the sparsest instance; instance 0
            gets probability 1 and the last instance gets ``1 / max_ratio``.

    Returns:
        Float tensor of shape (n_instances, 1), broadcastable over features.
    """
    feature_prob = max_ratio ** -t.linspace(0, 1, n_instances)
    return feature_prob.unsqueeze(-1)


def feature_importance(n_features, decay=0.9):
    """Per-feature importance ``decay ** i`` for feature index i (feature 0 most important).

    Args:
        n_features: number of input features.
        decay: geometric decay factor between consecutive features.

    Returns:
        Float tensor of shape (1, n_features), broadcastable over instances.
    """
    # float() keeps the result floating-point even if an integer decay is passed.
    importance = float(decay) ** t.arange(n_features)
    return importance.unsqueeze(0)
    

In [None]:
# Inspect the sparsity sweep: one feature probability (1 - S) per instance.
feature_probability(n_instances=8)

In [None]:
# Inspect the importance profile across six features.
feature_importance(n_features=6)

In [None]:
def save_experiment_metadata(n_features, n_instances=8, optim_fn=t.optim.Adam, n_hidden=2, steps=10_000):
    """Train one multi-instance toy model and collect its results in a dict.

    Despite the name, nothing is written to disk: the trained model and its
    training history are returned in memory.

    The module-level ``feature_importance`` and ``feature_probability`` are looked
    up at call time. Whichever versions were defined most recently in the notebook
    therefore decide the importance profile and the sparsity sweep.

    Args:
        n_features: number of input features.
        n_instances: number of model instances trained in parallel (one sparsity
            level each).
        optim_fn: optimizer class passed through ``Config``.
        n_hidden: hidden dimension (2 keeps the weights plottable in 2D).
        steps: number of optimisation steps passed to ``model.optimize``.

    Returns:
        dict with keys n_features, model, importance, feature_prob, W, b,
        summed_loss, per_feature_losses, losses, batches, all_W, all_b.
    """
    print(f"n_features:{n_features}")

    cfg = Config(
        n_instances=n_instances,
        n_features=n_features,
        n_hidden=n_hidden,
        optim_fn=optim_fn,
    )

    # importance varies within features for each instance
    importance = feature_importance(cfg.n_features)
    # sparsity is the same for all features in a given instance, but varies over instances
    feature_prob = feature_probability(cfg.n_instances)

    model = Model(
        cfg=cfg,
        device=device,
        importance=importance,
        feature_probability=feature_prob,
    )
    summed_losses, losses, batches, all_W, all_b, per_feature_losses = model.optimize(steps=steps)

    return {
        "n_features": n_features,
        "model": model,
        "importance": model.importance,
        "feature_prob": feature_prob,
        "W": model.W.detach(),
        "b": model.b_final.detach(),
        "summed_loss": summed_losses,
        "per_feature_losses": per_feature_losses,
        "losses": losses,
        "batches": batches,
        "all_W": all_W,
        "all_b": all_b,
    }

In [None]:
n_features_list = list(range(4, 11))
# One training run per feature count (4..10); each run sweeps the instance sparsities.
results_list = [save_experiment_metadata(n_feature) for n_feature in n_features_list]

## Plotting per feature norms with total norm and total loss

In [None]:
# Training dynamics of the sparsest instance (the last one, which has the smallest
# feature probability 1 - S) for every run.
# Left axis: per-feature weight norms, total weight norm and the count of features
# with non-negligible weight. Right axis: a moving average of that instance's loss.
moving_avg_window = 20
for result in results_list:
    print("n_feature:", result['n_features'])
    # all_W holds one weight snapshot per logged step; [-1] selects the last instance.
    sparsest_W = [t.as_tensor(W[-1]) for W in result['all_W']]
    feature_wise_norm = [W_inst.norm(dim=0).cpu().tolist() for W_inst in sparsest_W]
    total_norm = [W_inst.norm().item() for W_inst in sparsest_W]
    # A feature counts as learned when its column's absolute weights sum to more than 0.5.
    n_features_learned = [
        (W_inst.abs().sum(dim=0) > 0.5).sum().item() for W_inst in sparsest_W
    ]
    # Column -1 is the last instance, matching W[-1] above for any n_instances.
    sparsest_loss = pd.DataFrame(result['losses']).iloc[:, -1].to_numpy()

    df = pd.DataFrame(feature_wise_norm)
    df['norm'] = total_norm
    df['loss'] = sparsest_loss
    df['n_features_learned'] = n_features_learned
    moving_avg_loss = df['loss'].rolling(window=moving_avg_window).mean()

    # The raw loss lives on a very different scale, so it is kept off the norm axis.
    fig = df.drop(columns='loss').plot()

    # The first window-1 rolling values are NaN, so they are skipped.
    fig.add_trace(
        go.Scatter(
            x=df.index[moving_avg_window - 1:],
            y=moving_avg_loss.iloc[moving_avg_window - 1:],
            name=f'loss ({moving_avg_window}-pt moving avg)',
            yaxis='y2',
        )
    )

    fig.update_layout(
        yaxis=dict(title='norm / n_features_learned', side='left'),
        yaxis2=dict(title='loss', overlaying='y', side='right'),
        xaxis=dict(title='snapshot index'),
        title=f"n_features = {result['n_features']}: weight norms and loss of the sparsest instance",
    )
    fig.show()

# Animation of the weight geometry during training (n_features = 7)

In [None]:
# First weight snapshot of the n_features = 4 run, as a tensor.
t.tensor(results_list[0]["all_W"][0])

In [None]:
# Every weight snapshot of the n_features = 7 run (n_features_list[3]), as tensors.
all_w = list(map(t.tensor, results_list[3]["all_W"]))

In [None]:
# Shape of a single snapshot.
all_w[0].size()

In [None]:
# Animate how the weights of the n_features = 7 run (results_list[3], loaded into
# all_w above) evolve over training. The per-frame labels are derived from all_w
# itself: the leftover `df` from the plotting loop belongs to a different run, and its
# moving-average slice is shorter than the number of frames.
# NOTE(review): labels are snapshot indices, not optimiser step numbers.
animated_n_features = results_list[3]["n_features"]
animate_features_in_2d(
    {
        "weights": t.stack(all_w),
    },
    steps=list(range(len(all_w))),
    filename=f"animation-n_features_{animated_n_features}.html",
    title=f"Visualizing {animated_n_features} features across epochs",
)

In [None]:
# Loss history of the last (sparsest) instance for each optimiser run, one column per optimiser.
# NOTE(review): `optim_fn_list` and `optim_results_list` are not defined anywhere in this
# notebook; this cell only works if they survive from an earlier session.
optim_losses = {}
for optim_fn, results in zip(optim_fn_list, optim_results_list):
    print("n_features:", results['n_features'])
    print("optim_fn:", optim_fn)
    # Column -1 rather than a hard-coded 7, so any number of instances works.
    optim_losses[optim_fn.__name__] = pd.DataFrame(results['losses']).iloc[:, -1].to_numpy()
optim_df = pd.DataFrame(optim_losses)

In [None]:
# Feature counts of the runs in results_list, in order.
[res["n_features"] for res in results_list]

In [None]:
# Second-sparsest instance of the n_features = 7 run: total norm, per-feature norms, raw weights.
(
    t.norm(results_list[3]["W"][-2]),
    t.norm(results_list[3]["W"][-2], dim=0),
    results_list[3]["W"][-2],
)

In [None]:
# Per-feature norms of the same instance as a plain Python list.
results_list[3]["W"][-2].norm(dim=0).cpu().tolist()

In [None]:
# Weight geometry (annotated) and per-instance feature correlations for each run.
# NOTE(review): corr_plots is defined in a later cell; on a fresh kernel that cell must be
# executed before this one, otherwise the call below raises NameError.
for results in results_list:
    print("n_features:", results["n_features"])
    plot_features_in_2d(
        results["W"],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        colors=results["importance"],
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
        annotations=True,
    )
    corr_plots(results)

In [None]:
# `results` is whatever the most recent loop left behind; show its first importance row.
results["importance"][0]

In [None]:
# First instance's weights, transposed to one row per feature.
results["W"][0].T

In [None]:
def corr_plots(results, instances=8):
    """Plot per-instance feature-feature correlation heatmaps of the weight matrix.

    For each instance i, ``results['W'][i].T`` (one row per feature) is passed to
    ``torch.corrcoef``, giving an (n_features, n_features) matrix. With
    ``n_hidden == 2`` each feature row has only two entries, so every defined
    correlation is +1 or -1. A feature whose two weights are equal has zero variance,
    so its correlation is undefined (NaN).

    Args:
        results: dict as produced by ``save_experiment_metadata``; reads "W" and
            "feature_prob".
        instances: number of leading instances to plot (must not exceed the number
            of instances in ``results``).
    """
    # squeeze=False keeps a 2-D grid of Axes even when instances == 1.
    fig, axes = plt.subplots(1, instances, figsize=(25, 4), squeeze=False)
    axes = axes[0]

    axes[0].set_ylabel("feature correlation")
    im = None
    for i, ax in enumerate(axes):
        corr = t.corrcoef(results['W'][i].T.cpu()).numpy()
        n = corr.shape[0]
        # Fixed colour limits, so the single shared colour bar is valid for every panel.
        im = ax.imshow(corr, cmap='viridis', aspect='equal', vmin=-1, vmax=1)
        # feature_prob is 1 - S (feature probability), not the sparsity S itself.
        ax.set_title(f"1 - S = {results['feature_prob'][i].item():.3f}")

        ticks = np.arange(0, n, 1)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(ticks)
        ax.set_yticklabels(ticks)
        # Minor ticks at the cell boundaries draw a white grid between cells.
        ax.set_xticks(np.arange(-.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-.5, n, 1), minor=True)
        ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

    fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.05, pad=0.04)
    plt.show()

In [None]:
# Correlation heatmaps for the `results` left over from the previous loop.
corr_plots(results, instances=8)

In [None]:
# Hidden-layer view: map the last training batch of each run through W (only the
# einsum below, no bias or activation) and scatter the resulting 2-D hidden vectors
# for every instance.
for results in results_list:
    print("n_features:", results["n_features"])

    hidden = einops.rearrange(
        einops.einsum(
            results["batches"][-1].to(device),
            results["W"],
            "... instances features, instances n_hidden features -> ... instances n_hidden",
        ),
        "batch instances hidden -> instances hidden batch",
    )

    plot_features_in_2d(
        hidden,
        title=f"Hidden Layer: Input of {results['n_features']} features represented in 2D space",
        colors="red",
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
    )

In [None]:
# Five repeated runs with n_features = 6 (default instance count), to compare outcomes.
results_list_6 = [save_experiment_metadata(6) for _ in range(5)]

In [None]:
# Weight geometry of the repeated n_features = 6 runs built in the previous cell.
# results_list holds exactly len(n_features_list) == 7 runs, so slicing it from
# index 7 would iterate over nothing; the repeated runs are plotted instead.
for results in results_list_6:
    print("n_features:", results['n_features'])

    plot_features_in_2d(
        results['W'],
        colors=results['importance'],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        subplot_titles=[f"1 - S = {i:.3f}" for i in results['feature_prob'].squeeze()],
    )

In [None]:
def feature_probability(n_instances, max_ratio=1500):
    """Per-instance feature probability (1 - S), log-spaced from 1 down to 1 / max_ratio.

    Args:
        n_instances: number of model instances (each gets one sparsity level).
        max_ratio: ratio between the densest and sparsest instance.

    Returns:
        Float tensor of shape (n_instances, 1), broadcastable over features.
    """
    feature_prob = max_ratio ** -t.linspace(0, 1, n_instances)
    return feature_prob.unsqueeze(-1)


def feature_importance(n_features, decay=0.7):
    """Per-feature importance ``decay ** i`` for feature index i.

    Args:
        n_features: number of input features.
        decay: geometric decay factor between consecutive features.

    Returns:
        Float tensor of shape (1, n_features), broadcastable over instances.
    """
    # float() keeps the result floating-point even if an integer decay is passed.
    importance = float(decay) ** t.arange(n_features)
    return importance.unsqueeze(0)

In [None]:
# Inspect the sparsity sweep for ten instances.
feature_probability(n_instances=10)

In [None]:
# Inspect the importance profile across six features.
feature_importance(n_features=6)

In [None]:
# Five repeated runs with n_features = 6 over ten sparsity levels.
results_list_6 = [save_experiment_metadata(6, n_instances=10) for _ in range(5)]

In [None]:
# Weight geometry of each repeated n_features = 6 run.
for results in results_list_6:
    print("n_features:", results["n_features"])
    plot_features_in_2d(
        results["W"],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        colors=results["importance"],
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
    )

# Lower-sparsity runs for all n_features values with RMSprop

In [None]:
def feature_probability(n_instances, max_ratio=1500):
    """Per-instance feature probability (1 - S), log-spaced from 1 down to 1 / max_ratio.

    Same defaults as the previous section's definition; presumably re-declared so the
    RMSprop section can be run on its own.

    Args:
        n_instances: number of model instances (each gets one sparsity level).
        max_ratio: ratio between the densest and sparsest instance.

    Returns:
        Float tensor of shape (n_instances, 1), broadcastable over features.
    """
    feature_prob = max_ratio ** -t.linspace(0, 1, n_instances)
    return feature_prob.unsqueeze(-1)


def feature_importance(n_features, decay=0.7):
    """Per-feature importance ``decay ** i`` for feature index i.

    Args:
        n_features: number of input features.
        decay: geometric decay factor between consecutive features.

    Returns:
        Float tensor of shape (1, n_features), broadcastable over instances.
    """
    # float() keeps the result floating-point even if an integer decay is passed.
    importance = float(decay) ** t.arange(n_features)
    return importance.unsqueeze(0)

In [None]:
# Inspect the sparsity sweep for eight instances.
feature_probability(n_instances=8)

In [None]:
# Inspect the importance profile across six features.
feature_importance(n_features=6)

In [None]:
n_features_list = list(range(4, 11))
# Same feature-count sweep as before, but trained with RMSprop on 8 instances.
rmsprop_low_sparsity_results_list = [
    save_experiment_metadata(n_feature, n_instances=8, optim_fn=t.optim.RMSprop)
    for n_feature in n_features_list
]

In [None]:
# Weight geometry and feature correlations for every RMSprop run.
for results in rmsprop_low_sparsity_results_list:
    print("n_features:", results["n_features"])
    plot_features_in_2d(
        results["W"],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        colors=results["importance"],
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
    )
    corr_plots(results)

# Lower-sparsity runs for n_features = 6 with RMSprop

In [None]:
# Five repeated RMSprop runs with n_features = 6 over ten sparsity levels.
rmsprop_low_sparsity_results_list_6 = [
    save_experiment_metadata(6, n_instances=10, optim_fn=t.optim.RMSprop)
    for _ in range(5)
]

In [None]:
# Weight geometry and feature correlations for each repeated RMSprop n_features = 6 run.
for results in rmsprop_low_sparsity_results_list_6:
    print("n_features:", results["n_features"])
    plot_features_in_2d(
        results["W"],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        colors=results["importance"],
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
    )
    corr_plots(results, instances=10)

# Hexagon Replication

In [None]:
def feature_probability(n_instances, max_ratio=25):
    """Per-instance feature probability (1 - S), log-spaced from 1 down to 1 / max_ratio.

    Args:
        n_instances: number of model instances (each gets one sparsity level).
        max_ratio: ratio between the densest and sparsest instance.

    Returns:
        Float tensor of shape (n_instances, 1), broadcastable over features.
    """
    feature_prob = max_ratio ** -t.linspace(0, 1, n_instances)
    return feature_prob.unsqueeze(-1)


def feature_importance(n_features, decay=1.0):
    """Per-feature importance ``decay ** i``; the default gives equal importance to all features.

    The integer expression ``1 ** t.arange(n)`` produces an int64 tensor. Forcing a
    float base keeps the dtype float32, like every other importance profile in this
    notebook.

    Args:
        n_features: number of input features.
        decay: geometric decay factor between consecutive features.

    Returns:
        Float tensor of shape (1, n_features), broadcastable over instances.
    """
    importance = float(decay) ** t.arange(n_features)
    return importance.unsqueeze(0)

In [None]:
# Inspect the (much denser) sparsity sweep used for the hexagon replication.
feature_probability(n_instances=10)

In [None]:
# Inspect the importance profile across six features.
feature_importance(n_features=6)

In [None]:
# Ten repeated runs with n_features = 6 over ten sparsity levels.
results_list_hexagon_replication = [
    save_experiment_metadata(6, n_instances=10) for _ in range(10)
]

In [None]:
# Weight geometry of every hexagon-replication run.
for results in results_list_hexagon_replication:
    print("n_features:", results["n_features"])
    plot_features_in_2d(
        results["W"],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        colors=results["importance"],
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
    )

# Initialization Effects

In [None]:
# Seed torch, numpy and Python's RNGs so the initialisation experiment is reproducible.
SEED = 123
t.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# Initialisation effects: build one model (n_features = 6, 8 instances, Adam) and train
# five deep copies of it. All copies start from identical weights, so differences
# between runs presumably come from the randomly sampled training batches.
random_batch_6_fixed_init = []

n_features = 6
n_instances = 8
optim_fn = t.optim.Adam

cfg = Config(
    n_instances=n_instances,
    n_features=n_features,
    n_hidden=2,
    optim_fn=optim_fn,
)

# importance varies within features for each instance
importance = feature_importance(cfg.n_features)

# sparsity is the same for all features in a given instance, but varies over instances
feature_prob = feature_probability(cfg.n_instances)

model = Model(
    cfg=cfg,
    device=device,
    importance=importance,
    feature_probability=feature_prob,
)

for _ in range(5):
    new_model = copy.deepcopy(model)
    # optimize() returns six values (see save_experiment_metadata); unpacking only
    # three raised "too many values to unpack".
    summed_losses, losses, batches, all_W, all_b, per_feature_losses = new_model.optimize(steps=10_000)

    # Same keys as save_experiment_metadata, so the same plotting code works on both.
    random_batch_6_fixed_init.append({
        "n_features": n_features,
        "model": new_model,
        "importance": new_model.importance,
        "feature_prob": feature_prob,
        "W": new_model.W.detach(),
        "b": new_model.b_final.detach(),
        "summed_loss": summed_losses,
        "per_feature_losses": per_feature_losses,
        "losses": losses,
        "batches": batches,
        "all_W": all_W,
        "all_b": all_b,
    })

In [None]:
# Weight geometry (annotated) and feature correlations for each fixed-initialisation run.
for results in random_batch_6_fixed_init:
    print("n_features:", results["n_features"])
    plot_features_in_2d(
        results["W"],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        colors=results["importance"],
        subplot_titles=["1 - S = {:.3f}".format(p) for p in results["feature_prob"].squeeze()],
        annotations=True,
    )
    corr_plots(results)