In [None]:
import os, sys
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import random
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
# from tqdm.auto import tqdm
from dataclasses import dataclass
from rich import print as rprint
from rich.table import Table
from IPython.display import display, HTML
from pathlib import Path
import pandas as pd
pd.options.plotting.backend = "plotly"
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import copy
import shap
import plotly.express as px

chapter = "chapter1_transformer_interp"
# Everything in the cwd before the first occurrence of `chapter` is treated as the repo root.
exercises_dir = Path(
    f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises"
).resolve()
# NOTE(review): the utils import below uses `part31_superposition_and_saes`, not
# `part4_superposition_and_saes`; confirm which directory section_dir should point at.
section_dir = exercises_dir.joinpath("part4_superposition_and_saes").resolve()
# Make the exercise helper packages importable.
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, hist
from part31_superposition_and_saes.utils import (
    plot_features_in_2d,
    plot_features_in_Nd,
    plot_features_in_Nd_discrete,
    plot_correlated_features,
    plot_feature_geometry,
    frac_active_line_plot,
    animate_features_in_2d
)

from feature_geometry_utils import *
# Prefer Apple's MPS backend when available; otherwise fall back to CUDA, then CPU,
# explaining why MPS could not be used.
if t.backends.mps.is_available():
    device = t.device("mps")
else:
    if t.backends.mps.is_built():
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    else:
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

In [None]:
def feature_importance(n_features, decay=1.0):
    """Per-feature importance weights, shaped ``(1, n_features)``.

    The importance of feature ``i`` is ``decay ** i``. With the default ``decay=1.0``,
    every feature is equally important. That is the setting every call in this
    notebook uses.

    Args:
        n_features: number of features.
        decay: geometric decay factor per feature index (1.0 means uniform importance).

    Returns:
        Float32 tensor of shape ``(1, n_features)``. The leading singleton axis is
        presumably broadcast over model instances.
    """
    # Build in floating point. The previous `1 ** t.arange(n)` yielded an int64 tensor,
    # which is an awkward dtype for a loss weight and a plot colour scale.
    importance = decay ** t.arange(n_features, dtype=t.float32)
    return importance.unsqueeze(0)

In [None]:
feature_importance(n_features=6)  # display the importance weights used by the six-feature runs

In [None]:
def save_experiment_metadata(n_features, seed, feature_prob, n_instances=1, optim_fn=t.optim.Adam,
                             n_hidden=2, steps=10_000):
    """Seed all RNGs, train a toy superposition model, and collect its results.

    Args:
        n_features: number of input features.
        seed: value passed to ``torch``, ``numpy`` and ``random`` seeding before the model is built.
        feature_prob: feature probability (1 - S) passed to ``Model`` as ``feature_probability``.
        n_instances: number of model instances in ``Config``.
        optim_fn: optimizer class passed to ``Config``.
        n_hidden: hidden dimension (defaults to 2, the 2-D setting used throughout this notebook).
        steps: number of optimization steps passed to ``model.optimize``.

    Returns:
        Dict with keys ``n_features``, ``seed``, ``model``, ``importance``, ``feature_prob``,
        ``W`` (detached final weights), ``b`` (detached ``b_final``), and ``output_dict``
        (whatever ``model.optimize`` returns).
    """
    print(f"n_features:{n_features}")
    print(f"seed:{seed}")

    # Seed every RNG so model initialisation (and, presumably, batch sampling inside
    # optimize) is reproducible for a given seed.
    t.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cfg = Config(
        n_instances=n_instances,
        n_features=n_features,
        n_hidden=n_hidden,
        optim_fn=optim_fn,
    )

    # Importance varies across features; it is shared by all instances.
    importance = feature_importance(cfg.n_features)

    model = Model(
        cfg=cfg,
        device=device,
        importance=importance,
        feature_probability=feature_prob,
    )
    output_dict = model.optimize(steps=steps)

    return {
        "n_features": n_features,
        "seed": seed,
        "model": model,
        "importance": model.importance,
        "feature_prob": feature_prob,
        "W": model.W.detach(),
        "b": model.b_final.detach(),
        "output_dict": output_dict,
    }

In [None]:
def norm_vs_feature_learned(result, index):
    """Plot weight norms and the number of learned features across saved weight snapshots.

    Args:
        result: dict returned by ``save_experiment_metadata``. Reads ``result['n_features']``
            and ``result['output_dict']['all_W']`` (a list of weight snapshots).
        index: instance index to plot (first axis of each snapshot).

    Left y-axis: the L2 norm of each feature's column of ``W[index]`` (one trace per
    feature), plus the overall norm. Right y-axis: the number of features counted as
    learned, i.e. whose column L1 norm rounds to a positive integer (>= 0.5).
    """
    print("n_feature:", result['n_features'])
    snapshots = [W[index] for W in result['output_dict']['all_W']]

    # One row per snapshot, one column per feature.
    feature_wise_norm = [t.norm(W, dim=0).cpu().tolist() for W in snapshots]
    overall_norm = [t.norm(W).item() for W in snapshots]
    # Torch ops (rather than numpy) so the count works for snapshots on any device.
    # torch.round, like np.round, rounds halves to even.
    n_features_learned = [int((W.abs().sum(dim=0).round() > 0).sum().item()) for W in snapshots]

    df = pd.DataFrame(feature_wise_norm)
    df['overall_norm'] = overall_norm

    fig = df.plot()

    # Feature count goes on a secondary y-axis so it isn't squashed by the norm scale.
    fig.add_trace(
        go.Scatter(x=df.index, y=n_features_learned, name='n_features_learned', yaxis='y2')
    )

    # The right axis previously said 'loss', but it carries n_features_learned.
    fig.update_layout(
        yaxis=dict(title='norm', side='left'),
        yaxis2=dict(title='n_features_learned', overlaying='y', side='right'),
        xaxis=dict(title='snapshot index'),
        title='Weight norms vs number of features learned',
    )

    fig.show()

def per_feature_loss_viz(result, index, test_loss=False, log=False, moving_avg_window=20):
    """Plot smoothed per-feature losses, the overall loss, and the number of learned features.

    Args:
        result: dict returned by ``save_experiment_metadata``. Reads ``result['output_dict']``
            keys ``'per_feature_losses'``/``'losses'`` (or the ``'test_'``-prefixed versions
            when ``test_loss``) and ``'all_W'``.
        index: instance index to plot.
        test_loss: plot test losses instead of training losses.
        log: plot the natural log of the smoothed losses.
        moving_avg_window: rolling-mean window in logged entries; must be >= 1.

    Raises:
        ValueError: if ``moving_avg_window`` is smaller than 1.
    """
    if moving_avg_window < 1:
        raise ValueError(f"moving_avg_window must be >= 1, got {moving_avg_window}")

    print("n_feature:", result['n_features'])
    output = result['output_dict']
    if not test_loss:
        per_feature_loss_key = 'per_feature_losses'
        loss_key = 'losses'
        title = 'Per Feature loss vs overall loss'
    else:
        per_feature_loss_key = 'test_per_feature_losses'
        loss_key = 'test_losses'
        title = "Test Per Feature loss vs test overall loss"

    per_feature_loss_df = pd.DataFrame([per_feature_loss[index] for per_feature_loss in output[per_feature_loss_key]])
    per_feature_loss_df['overall_loss'] = pd.DataFrame(output[loss_key])[index].values
    # Torch-native count (works on any device); torch.round rounds halves to even like np.round.
    n_features_learned = [int((W[index].abs().sum(dim=0).round() > 0).sum().item()) for W in output['all_W']]

    # Drop the leading rows whose rolling window is incomplete (all NaN).
    start = moving_avg_window - 1
    smoothed = per_feature_loss_df.rolling(window=moving_avg_window).mean().iloc[start:]
    if log:
        smoothed = np.log(smoothed)
    fig = smoothed.plot()

    # NOTE(review): assumes all_W snapshots are logged at the same steps as these losses;
    # if not (e.g. for test losses), the secondary axis will be misaligned.
    fig.add_trace(
        go.Scatter(x=smoothed.index, y=n_features_learned[start:], name='n_features_learned', yaxis='y2')
    )

    y_title = 'per_feature_loss/overall loss'
    fig.update_layout(
        yaxis=dict(title=f'log({y_title})' if log else y_title, side='left'),
        yaxis2=dict(title='n_features_learned', overlaying='y', side='right'),
        xaxis=dict(title='logged step'),
        title=f'{title} (log scale)' if log else title,
    )

    fig.show()

In [None]:
# Seven-feature run (feature_prob = 0.057, seed = 20). Stored under its own name so it is
# neither confused with nor clobbered by the six-feature `hexagon_result_057_20` run below.
feature_prob = t.Tensor([0.057])
seed = 20
seven_feature_result_057_20 = save_experiment_metadata(n_features=7, seed=seed, feature_prob=feature_prob, n_instances=1)

print("n_features:", seven_feature_result_057_20['n_features'])

plot_features_in_2d(
    seven_feature_result_057_20['W'],
    colors=seven_feature_result_057_20['importance'],
    title=f"Superposition: {seven_feature_result_057_20['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {i:.3f}" for i in seven_feature_result_057_20['feature_prob']],
)

### Case 1

In [None]:
# Case 1: six features, feature_prob = 0.057, seed = 20.
seed = 20
feature_prob = t.Tensor([0.057])
hexagon_result_057_20 = save_experiment_metadata(
    n_features=6,
    seed=seed,
    feature_prob=feature_prob,
    n_instances=1,
)
print("n_features:", hexagon_result_057_20['n_features'])

plot_features_in_2d(
    hexagon_result_057_20['W'],
    colors=hexagon_result_057_20['importance'],
    title=f"Superposition: {hexagon_result_057_20['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {p:.3f}" for p in hexagon_result_057_20['feature_prob']],
)
norm_vs_feature_learned(hexagon_result_057_20, 0)
per_feature_loss_viz(hexagon_result_057_20, 0)

In [None]:
# Rolling-mean window. Defined here because this cell runs before the later cell that
# also sets `moving_avg_window = 20`; without it a fresh kernel raises NameError.
moving_avg_window = 20
# Per-feature (column) L2 norm of the W gradient for instance 0, one row per gradient snapshot.
feature_grad_norm_057_20 = [t.norm(W_grad[0], dim=0).cpu().tolist() for W_grad in hexagon_result_057_20['output_dict']['all_W_grad']]
grad_norm_057_20_df = pd.DataFrame(feature_grad_norm_057_20)
moving_grad_norm_057_20_df = grad_norm_057_20_df.rolling(window=moving_avg_window).mean()
moving_grad_norm_057_20_df.iloc[moving_avg_window-1:].plot()

### Case 2

In [None]:
# Case 2: six features, feature_prob = 0.057, seed = 26.
seed = 26
feature_prob = t.Tensor([0.057])
hexagon_result_057_26 = save_experiment_metadata(
    n_features=6,
    seed=seed,
    feature_prob=feature_prob,
    n_instances=1,
)
print("n_features:", hexagon_result_057_26['n_features'])

plot_features_in_2d(
    hexagon_result_057_26['W'],
    colors=hexagon_result_057_26['importance'],
    title=f"Superposition: {hexagon_result_057_26['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {p:.3f}" for p in hexagon_result_057_26['feature_prob']],
)
norm_vs_feature_learned(hexagon_result_057_26, 0)
per_feature_loss_viz(hexagon_result_057_26, 0)

In [None]:
# Same seed as case 2 (26) at a lower feature probability (0.040).
seed = 26
feature_prob = t.Tensor([0.040])
hexagon_result_040_26 = save_experiment_metadata(
    n_features=6,
    seed=seed,
    feature_prob=feature_prob,
    n_instances=1,
)
print("n_features:", hexagon_result_040_26['n_features'])

plot_features_in_2d(
    hexagon_result_040_26['W'],
    colors=hexagon_result_040_26['importance'],
    title=f"Superposition: {hexagon_result_040_26['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {p:.3f}" for p in hexagon_result_040_26['feature_prob']],
)
norm_vs_feature_learned(hexagon_result_040_26, 0)
per_feature_loss_viz(hexagon_result_040_26, 0)

### All three losses compared

In [None]:
# Instance-0 training loss of each of the three runs, in the same order as the column names below.
loss_057_20, loss_057_26, loss_040_26 = (
    [step_loss[0] for step_loss in res["output_dict"]['losses']]
    for res in (hexagon_result_057_20, hexagon_result_057_26, hexagon_result_040_26)
)

loss_df = pd.DataFrame(dict(
    loss_057_20=loss_057_20,
    loss_057_26=loss_057_26,
    loss_040_26=loss_040_26,
))

moving_avg_window = 20
moving_avg_loss = loss_df.rolling(moving_avg_window).mean()
moving_avg_loss.iloc[moving_avg_window - 1:].plot()

### All three norms compared

In [None]:
# Overall norm of W at every saved snapshot, for each of the three runs.
overall_norm_057_20 = [t.norm(W).item() for W in hexagon_result_057_20["output_dict"]['all_W']]
overall_norm_057_26 = [t.norm(W).item() for W in hexagon_result_057_26["output_dict"]['all_W']]
overall_norm_040_26 = [t.norm(W).item() for W in hexagon_result_040_26["output_dict"]['all_W']]

# Plot all three runs. The 040_26 norms were previously computed but left out of the plot.
pd.DataFrame({
    "057_20 (pentagon)": overall_norm_057_20,
    "057_26 (hexagon)": overall_norm_057_26,
    "040_26": overall_norm_040_26,
}).plot()

In [None]:
(overall_norm_057_20[0], overall_norm_057_26[0], overall_norm_040_26[0])  # norms at the first saved snapshot

In [None]:
# Weights at the first saved snapshot (presumably the initialization) for each run.
# One loop instead of three copy-pasted calls.
for _initial_result in (hexagon_result_057_20, hexagon_result_057_26, hexagon_result_040_26):
    plot_features_in_2d(
        # as_tensor + detach/cpu: accepts a numpy array or a tensor on any device,
        # unlike the legacy t.Tensor(...) constructor.
        t.as_tensor(_initial_result["output_dict"]['all_W'][0], dtype=t.float32).detach().cpu(),
        colors=_initial_result['importance'],
        title=f"Superposition: {_initial_result['n_features']} features represented in 2D space",
        subplot_titles=[f"1 - S = {i:.3f}" for i in _initial_result['feature_prob']],
    )

### All three gradient norms compared

In [None]:
# Rolling mean of per-feature gradient norms (instance 0), feature_prob=0.057, seed=20.
feature_grad_norm_057_20 = [
    t.norm(grad[0], dim=0).cpu().tolist()
    for grad in hexagon_result_057_20['output_dict']['all_W_grad']
]
grad_norm_057_20_df = pd.DataFrame(feature_grad_norm_057_20)
moving_grad_norm_057_20_df = grad_norm_057_20_df.rolling(moving_avg_window).mean()
moving_grad_norm_057_20_df.iloc[moving_avg_window - 1:].plot()

In [None]:
# Rolling mean of per-feature gradient norms (instance 0), feature_prob=0.057, seed=26.
feature_grad_norm_057_26 = [
    t.norm(grad[0], dim=0).cpu().tolist()
    for grad in hexagon_result_057_26['output_dict']['all_W_grad']
]
grad_norm_057_26_df = pd.DataFrame(feature_grad_norm_057_26)
moving_grad_norm_057_26_df = grad_norm_057_26_df.rolling(moving_avg_window).mean()
moving_grad_norm_057_26_df.iloc[moving_avg_window - 1:].plot()

In [None]:
# Rolling mean of per-feature gradient norms (instance 0), feature_prob=0.040, seed=26.
feature_grad_norm_040_26 = [
    t.norm(grad[0], dim=0).cpu().tolist()
    for grad in hexagon_result_040_26['output_dict']['all_W_grad']
]
grad_norm_040_26_df = pd.DataFrame(feature_grad_norm_040_26)
moving_grad_norm_040_26_df = grad_norm_040_26_df.rolling(moving_avg_window).mean()
moving_grad_norm_040_26_df.iloc[moving_avg_window - 1:].plot()

In [None]:
# Overall norm of the W gradient at each snapshot, for all three runs (same order as below).
overall_grad_norm_057_20, overall_grad_norm_057_26, overall_grad_norm_040_26 = (
    [t.norm(grad).item() for grad in res["output_dict"]['all_W_grad']]
    for res in (hexagon_result_057_20, hexagon_result_057_26, hexagon_result_040_26)
)

grad_norm_df = pd.DataFrame(dict(
    overall_grad_norm_057_20=overall_grad_norm_057_20,
    overall_grad_norm_057_26=overall_grad_norm_057_26,
    overall_grad_norm_040_26=overall_grad_norm_040_26,
))

moving_avg_window = 20
moving_grad_norm = grad_norm_df.rolling(moving_avg_window).mean()
moving_grad_norm.iloc[moving_avg_window - 1:].plot()

In [None]:
# `model` is only a local inside save_experiment_metadata, so it is undefined at top level
# (NameError on a fresh kernel). Inspect the most recently trained stored model instead.
# NOTE(review): .grad may be None if optimize() clears grads with set_to_none — confirm.
hexagon_result_040_26["model"].W.grad.detach().cpu().shape

### Initialization Distribution Differences

In [None]:
# Reference std for the normal-density plot below. 2 / (2 + 6) is, presumably, the
# Xavier-normal variance 2 / (fan_in + fan_out) for a 2 x 6 weight matrix
# (n_hidden = 2, n_features = 6).
# NOTE(review): confirm this matches how Model initializes W; torch's xavier fans differ
# for a 3-D (n_instances, n_hidden, n_features) tensor.
INIT_N_HIDDEN = 2
INIT_N_FEATURES = 6
std = np.sqrt(2 / (INIT_N_HIDDEN + INIT_N_FEATURES))
std

In [None]:
import matplotlib.pyplot as plt
import scipy.stats as stats

In [None]:
mu = 0
# Normal density N(mu, std) over +/- 3 standard deviations, sampled at 100 points.
x = np.linspace(mu - 3 * std, mu + 3 * std, num=100)
plt.plot(x, stats.norm.pdf(x, loc=mu, scale=std))
plt.show()

In [None]:
# Histogram of the first saved weight snapshot for feature_prob=0.040, seed=26.
# The snapshots live under result['output_dict']['all_W']; the result dict has no top-level
# 'all_W' key, so the previous lookup raised KeyError.
plt.hist(hexagon_result_040_26['output_dict']['all_W'][0].detach().cpu().flatten().numpy())
plt.title("Initial W distribution (feature_prob=0.040, seed=26)")
plt.show()

In [None]:
# Histogram of the first saved weight snapshot for feature_prob=0.057, seed=26.
# The snapshots live under result['output_dict']['all_W']; the result dict has no top-level
# 'all_W' key, so the previous lookup raised KeyError.
plt.hist(hexagon_result_057_26['output_dict']['all_W'][0].detach().cpu().flatten().numpy())
plt.title("Initial W distribution (feature_prob=0.057, seed=26)")
plt.show()

In [None]:
# Histogram of the first saved weight snapshot for feature_prob=0.057, seed=20.
# The snapshots live under result['output_dict']['all_W']; the result dict has no top-level
# 'all_W' key, so the previous lookup raised KeyError.
plt.hist(hexagon_result_057_20['output_dict']['all_W'][0].detach().cpu().flatten().numpy())
plt.title("Initial W distribution (feature_prob=0.057, seed=20)")
plt.show()

### Animation of the weight evolution (feature_prob = 0.040, seed = 26)

# Animate the saved weight snapshots of the feature_prob=0.040, seed=26 run
# (the filename suggests the animation is written out as HTML).
all_w = list(hexagon_result_040_26['output_dict']['all_W'])
animate_features_in_2d(
    {"weights": t.stack(all_w)},
    steps=len(all_w),
    filename="hexagon_result_040_26.html",
    title="Visualizing 6 features across epochs",
)

### Does the hexagon replicate across all sparsities for a given seed?

In [None]:
# Re-train the seed-26, six-feature model at a range of feature probabilities (1 - S).
# NOTE(review): iterating a 1-D tensor yields 0-d tensors, whereas other cells pass a
# 1-element tensor as feature_prob; confirm Model treats both the same way.
seed = 26
feature_prob_list = t.Tensor([0.057, 0.040, 0.03, 0.02, 0.01, 0.001, 0.0001, 0.00001])
hexagon_result_diff_spars = []
for feature_prob in feature_prob_list:
    print("feature_prob:", feature_prob)
    hexagon_result_diff_spars.append(
        save_experiment_metadata(n_features=6, seed=seed, feature_prob=feature_prob, n_instances=1)
    )

In [None]:
for result in hexagon_result_diff_spars:
    # float() accepts both 0-d and 1-element tensors. ".3g" keeps the small probabilities
    # (1e-4, 1e-5) distinguishable; ".3f" rendered them all as "0.000".
    prob = float(result['feature_prob'])
    print("feature_prob:", result['feature_prob'])
    plot_features_in_2d(
        result['W'],
        colors=result['importance'],
        title=f"Superposition: {result['n_features']} features represented in 2D space",
        subplot_titles=[f"1 - S = {prob:.3g}"],
    )

### Do hexagon seeds have lower loss than pentagon seeds for a given sparsity?

In [None]:
# Fixed sparsity (feature_prob = 0.057), several seeds: compare which geometries emerge.
feature_prob = t.Tensor([0.057])
seed_list = [26, 94, 32, 715, 158, 50, 30, 40, 60, 70]
result_diff_seeds = []
for seed in seed_list:
    print("seed:", seed)
    result_diff_seeds.append(
        save_experiment_metadata(n_features=6, seed=seed, feature_prob=feature_prob, n_instances=1)
    )

In [None]:
for result in result_diff_seeds:
    print("seed:", result['seed'])
    plot_features_in_2d(
        result['W'],
        colors=result['importance'],
        title=f"Superposition: {result['n_features']} features represented in 2D space",
        # Use this run's own feature_prob (this previously read hexagon_result_057_26's).
        subplot_titles=[f"1 - S = {i:.3f}" for i in result['feature_prob']],
    )

In [None]:
# seed -> list of per-step training losses for instance 0
losses_dict = dict()
for result in result_diff_seeds:
    losses_dict[result['seed']] = [
        step_losses[0] for step_losses in result['output_dict']['losses']
    ]

In [None]:
moving_avg_window = 100

# Seeds 158 and 50 (per the original note, runs that ended with four features) had much
# higher loss and distorted the plot, so they are left out.
excluded_seeds = [158, 50]
# Geometry each remaining seed converged to, as observed in the plots above.
shape_by_seed = {26: "hex", 94: "hex", 32: "hex", 715: "hex", 30: "hex", 40: "pent", 60: "pent", 70: "pent"}

# Label columns from the seed itself rather than by position. This also fixes the old
# "pen_70" typo; the new labels are e.g. "hex_26" and "pent_70".
losses_df = (
    pd.DataFrame(losses_dict)
    .drop(columns=excluded_seeds)
    .rename(columns=lambda s: f"{shape_by_seed[s]}_{s}")
)
color_dict = {col: ("red" if col.startswith("pent") else '#565454') for col in losses_df.columns}
moving_avg_loss = losses_df.rolling(window=moving_avg_window).mean()
moving_avg_loss.plot(x=losses_df.index, y=list(losses_df.columns), color_discrete_map=color_dict)

### Patching the hexagon seed's initialization of the unlearnt feature into the pentagon seed

In [None]:
def pentagon_patching(n_features, feature_idx, pent_seed, hex_seed, feature_prob, n_instances=1, optim_fn=t.optim.Adam):
    """Train a pentagon-seed model whose initial weights for one feature come from a hexagon seed.

    Two models with the same config are built, one after seeding with ``pent_seed`` and one
    after seeding with ``hex_seed``. Column ``feature_idx`` of the pentagon model's ``W`` is
    overwritten with the hexagon model's column. The patched pentagon model is then
    trained for 10,000 steps.

    Args:
        n_features: number of input features.
        feature_idx: feature index (last axis of ``W``) whose initialization is patched.
        pent_seed: seed for the model that is patched and trained.
        hex_seed: seed for the donor model whose column is copied.
        feature_prob: feature probability passed to ``Model``.
        n_instances: number of model instances in ``Config``.
        optim_fn: optimizer class passed to ``Config``.

    Returns:
        Dict with the same keys as ``save_experiment_metadata``'s result (``"seed"`` is
        ``pent_seed``), plus ``"pent_seed"``, ``"hex_seed"`` and ``"feature_idx"``.
    """
    cfg = Config(
        n_instances=n_instances,
        n_features=n_features,
        n_hidden=2,
        optim_fn=optim_fn,
    )

    # Importance varies across features; it is shared by all instances.
    importance = feature_importance(cfg.n_features)

    def _seeded_model(model_seed):
        # Reset every RNG right before construction, so initialization depends only on model_seed.
        t.manual_seed(model_seed)
        np.random.seed(model_seed)
        random.seed(model_seed)
        return Model(
            cfg=cfg,
            device=device,
            importance=importance,
            feature_probability=feature_prob,
        )

    print(f"pent_seed:{pent_seed}")
    pent_model = _seeded_model(pent_seed)
    print(f"hex_seed:{hex_seed}")
    hex_model = _seeded_model(hex_seed)

    print("initial weight comparison", pent_model.W == hex_model.W)

    old_pent_w = pent_model.W.clone()

    # Patch the chosen feature's column. W is presumably indexed [instance, hidden, feature].
    # This previously hard-coded column 1 and ignored feature_idx.
    with t.no_grad():
        pent_model.W[:, :, feature_idx] = hex_model.W[:, :, feature_idx]

    print("later pent weight comparison ", pent_model.W == old_pent_w)
    print("later hex and pent weight comparison", pent_model.W == hex_model.W)

    # NOTE(review): the global RNGs were last seeded with hex_seed (then consumed by the hex
    # model's init), so any randomness inside optimize() follows that stream rather than the
    # unpatched pentagon run's. Restore/reseed if the comparison should isolate the patch.
    output_dict = pent_model.optimize(steps=10_000)

    return {
        "n_features": n_features,
        # Previously read the notebook-global `seed`; the trained model is the pentagon one.
        "seed": pent_seed,
        "model": pent_model,
        "importance": pent_model.importance,
        "feature_prob": feature_prob,
        "W": pent_model.W.detach(),
        "b": pent_model.b_final.detach(),
        "output_dict": output_dict,
        "pent_seed": pent_seed,
        "hex_seed": hex_seed,
        "feature_idx": feature_idx,
    }

In [None]:
# Patch feature 1 of pentagon seed 20 with its initialization from hexagon seed 94.
feature_prob = t.Tensor([0.057])
pent_seed, hex_seed, feature_idx = 20, 94, 1
patched_pent_model_results_dict = pentagon_patching(
    n_features=6,
    feature_idx=feature_idx,
    pent_seed=pent_seed,
    hex_seed=hex_seed,
    feature_prob=feature_prob,
)

plot_features_in_2d(
    patched_pent_model_results_dict['W'],
    colors=patched_pent_model_results_dict['importance'],
    title=f"Superposition: {patched_pent_model_results_dict['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {p:.3f}" for p in patched_pent_model_results_dict['feature_prob']],
)
norm_vs_feature_learned(patched_pent_model_results_dict, 0)
per_feature_loss_viz(patched_pent_model_results_dict, 0)

In [None]:
# Patch feature 1 of pentagon seed 20 with its initialization from hexagon seed 32.
feature_prob = t.Tensor([0.057])
pent_seed, hex_seed, feature_idx = 20, 32, 1
patched_pent_model_results_dict = pentagon_patching(
    n_features=6,
    feature_idx=feature_idx,
    pent_seed=pent_seed,
    hex_seed=hex_seed,
    feature_prob=feature_prob,
)

plot_features_in_2d(
    patched_pent_model_results_dict['W'],
    colors=patched_pent_model_results_dict['importance'],
    title=f"Superposition: {patched_pent_model_results_dict['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {p:.3f}" for p in patched_pent_model_results_dict['feature_prob']],
)
norm_vs_feature_learned(patched_pent_model_results_dict, 0)
per_feature_loss_viz(patched_pent_model_results_dict, 0)

In [None]:
# Patch feature 1 of pentagon seed 20 with its initialization from hexagon seed 30.
feature_prob = t.Tensor([0.057])
pent_seed, hex_seed, feature_idx = 20, 30, 1
patched_pent_model_results_dict = pentagon_patching(
    n_features=6,
    feature_idx=feature_idx,
    pent_seed=pent_seed,
    hex_seed=hex_seed,
    feature_prob=feature_prob,
)

plot_features_in_2d(
    patched_pent_model_results_dict['W'],
    colors=patched_pent_model_results_dict['importance'],
    title=f"Superposition: {patched_pent_model_results_dict['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {p:.3f}" for p in patched_pent_model_results_dict['feature_prob']],
)
norm_vs_feature_learned(patched_pent_model_results_dict, 0)
per_feature_loss_viz(patched_pent_model_results_dict, 0)

### Impact of optimizer

In [None]:
# NOTE(review): `diff_optim_runs` is not defined in this notebook; presumably it comes from
# the `feature_geometry_utils` star import. Confirm, otherwise this cell raises NameError.
# The next cell runs the same optimizer sweep explicitly via `save_experiment_metadata`.
optim_fn_list = [
    t.optim.Adam,
    t.optim.AdamW,
    t.optim.SGD,
    t.optim.RMSprop,
]
for optim_fn in optim_fn_list:
    print(optim_fn)
    diff_optim_runs(optim_fn)

In [None]:
# Train the seed-26 / feature_prob-0.057 configuration once per optimizer.
seed = 26
feature_prob = t.Tensor([0.057])
optim_fn_list = [t.optim.Adam, t.optim.AdamW, t.optim.SGD, t.optim.RMSprop]
diff_optim_results = []
for optim_fn in optim_fn_list:
    print(optim_fn)
    diff_optim_results.append(
        save_experiment_metadata(
            n_features=6,
            seed=seed,
            feature_prob=feature_prob,
            n_instances=1,
            optim_fn=optim_fn,
        )
    )

In [None]:
# For each optimizer: learned geometry, weight norms, and per-feature losses.
for optim_fn, result in zip(optim_fn_list, diff_optim_results):
    print("optim_fn:", optim_fn)
    plot_features_in_2d(
        result['W'],
        colors=result['importance'],
        title=f"Superposition: {result['n_features']} features represented in 2D space",
        subplot_titles=[f"1 - S = {p:.3f}" for p in result['feature_prob']],
    )
    norm_vs_feature_learned(result, 0)
    per_feature_loss_viz(result, 0)

In [None]:
# RMSprop with seed 94 at feature_prob = 0.057. Stored under its own name: assigning to
# `hexagon_result_057_26` here replaced the seed-26/Adam run that the earlier comparison cells
# read, so re-running those cells afterwards would silently plot this run instead.
feature_prob = t.Tensor([0.057])
seed = 94
rmsprop_result_057_94 = save_experiment_metadata(n_features=6, seed=seed, feature_prob=feature_prob, n_instances=1, optim_fn=t.optim.RMSprop)

print("n_features:", rmsprop_result_057_94['n_features'])

plot_features_in_2d(
    rmsprop_result_057_94['W'],
    colors=rmsprop_result_057_94['importance'],
    title=f"Superposition: {rmsprop_result_057_94['n_features']} features represented in 2D space",
    subplot_titles=[f"1 - S = {i:.3f}" for i in rmsprop_result_057_94['feature_prob']],
)

norm_vs_feature_learned(rmsprop_result_057_94, 0)

per_feature_loss_viz(rmsprop_result_057_94, 0)

# Draw 10 seeds from [100, 1000) with a dedicated, explicitly seeded RNG, so the choice does
# not depend on whatever global `random` state the previously executed cell left behind.
# NOTE(review): 94 is the seed the global RNG was last reset to by the preceding RMSprop cell,
# so a top-to-bottom run picks the same seeds as before only if nothing consumed global
# `random` in between. Confirm against the originally reported seeds.
random_seed_list = random.Random(94).sample(range(100, 1000), 10)
print(random_seed_list)

feature_prob = t.Tensor([0.057])  # same sparsity for every seed
results_list_hexagon_replication = []
for seed in random_seed_list:
    results_list_hexagon_replication.append(
        save_experiment_metadata(n_features=6, seed=seed, feature_prob=feature_prob, n_instances=1)
    )

# For each replication seed: learned geometry, weight norms, and per-feature losses.
for results in results_list_hexagon_replication:
    print("n_features:", results['n_features'])
    print("seed:", results['seed'])
    plot_features_in_2d(
        results['W'],
        colors=results['importance'],
        title=f"Superposition: {results['n_features']} features represented in 2D space",
        subplot_titles=[f"1 - S = {p:.3f}" for p in results['feature_prob']],
    )
    norm_vs_feature_learned(results, 0)
    per_feature_loss_viz(results, 0)