In [26]:
import os
import sys
import json
sys.path.append(os.path.abspath('..'))

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d, scores_clamp_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Inference-only notebook: switch autograd off globally. The trailing ';' suppresses the
# `<torch.autograd.grad_mode.set_grad_enabled ...>` repr that would otherwise be the cell output.
torch.set_grad_enabled(False);


<torch.autograd.grad_mode.set_grad_enabled at 0x7fee01bd0850>

In [415]:
# Steering sweeps available under plots/<name>/. Each entry maps the run name to
#   (scales swept for both features,
#    [feature on the y-axis / rows, feature on the x-axis / columns]).
# Pick the run to analyse by changing EXPERIMENT; save_dir, scales and
# feature_descriptions are derived from it for the cells below.
_COARSE_0_TO_120 = list(range(0, 121, 10))   # 0, 10, ..., 120
_COARSE_0_TO_140 = list(range(0, 141, 10))   # 0, 10, ..., 140
_NARROW_40_TO_70 = list(range(40, 71, 5))    # 40, 45, ..., 70

EXPERIMENTS = {
    "anger_wedding":         (_COARSE_0_TO_140, ["anger", "wedding"]),
    "anger_london":          (_COARSE_0_TO_140, ["anger", "london"]),
    "castle_writing":        (_COARSE_0_TO_120, ["castle", "writing"]),
    "clamp_anger_london":    (_COARSE_0_TO_120, ["anger", "london"]),
    "clamp_anger_wedding":   (_COARSE_0_TO_120, ["anger", "wedding"]),
    "clamp_london_wedding":  (_COARSE_0_TO_120, ["london", "wedding"]),
    "dyno_anger_wedding":    (_COARSE_0_TO_120, ["anger", "wedding"]),
    "london_london_culture": (_COARSE_0_TO_120, ["london", "london_culture"]),
    "london_wedding":        (_COARSE_0_TO_140, ["london", "wedding"]),
    "london_writing":        (_COARSE_0_TO_120, ["london", "writing"]),
    "narrow_anger_wedding":  (_NARROW_40_TO_70, ["anger", "wedding"]),
}

EXPERIMENT = "narrow_anger_wedding"

save_dir = f"plots/{EXPERIMENT}"
_selected_scales, _selected_features = EXPERIMENTS[EXPERIMENT]
# Copies, so any later in-place edit of these lists cannot leak back into EXPERIMENTS.
scales = list(_selected_scales)
feature_descriptions = list(_selected_features)


In [416]:
# Load the score grids saved by the steering sweep for the selected experiment.
# (`losses` is loaded for completeness; no visible cell below reads it.)
coherence_scores, losses, scores_1, scores_2 = (
    torch.load(path)
    for path in (
        f"{save_dir}/coherence_scores.pt",
        f"{save_dir}/losses.pt",
        f"{save_dir}/scores_1.pt",
        f"{save_dir}/scores_2.pt",
    )
)

# Rescale ratings with (x - 1) / 9, which sends 1 -> 0 and 10 -> 1
# (assumes the judge rates on a 1-10 scale — TODO confirm).
coherence_scores, scores_1, scores_2 = (
    (grid - 1) / 9 for grid in (coherence_scores, scores_1, scores_2)
)

# Combined metric: elementwise product of both feature scores and coherence.
multi = scores_1 * scores_2 * coherence_scores

In [417]:
# Side-by-side heatmaps of the two rescaled feature scores over the steering-scale grid
# (rows / y: scale of feature_descriptions[0]; columns / x: scale of feature_descriptions[1]).
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(feature_descriptions[0], feature_descriptions[1]),
    horizontal_spacing=0.2,
)

# (score grid, colorbar x-position in paper coordinates) per panel, left to right.
for panel_col, (panel_z, bar_x) in enumerate(((scores_1, 0.40), (scores_2, 1.0)), start=1):
    fig.add_trace(
        go.Heatmap(z=panel_z, x=scales, y=scales, colorscale="RdBu", zmid=0, colorbar=dict(title="", x=bar_x)),
        row=1,
        col=panel_col,
    )

# Same axis titles on both panels; reversed y puts the smallest scale at the top.
for panel_col in (1, 2):
    fig.update_xaxes(title_text=feature_descriptions[1], row=1, col=panel_col)
    fig.update_yaxes(title_text=feature_descriptions[0], row=1, col=panel_col, autorange='reversed')

# Black frames around each panel; with horizontal_spacing=0.2 the two subplot
# domains are [0, 0.4] and [0.6, 1] in paper coordinates.
for frame_x0, frame_x1 in ((0, 0.4), (0.6, 1)):
    fig.add_shape(
        type="rect",
        x0=frame_x0, x1=frame_x1, y0=0, y1=1,
        line=dict(color="Black", width=2),
        xref="paper", yref="paper",
    )

fig.update_layout(width=1000, height=500, title_text="Steering Vector Addition")
fig.show()
fig.write_image(f"{save_dir}/both_scores.png")

In [418]:
# Single heatmap of the combined metric `multi` (coherence * feature 1 * feature 2).
multi_axis_labels = {'x': feature_descriptions[1], 'y': feature_descriptions[0]}
fig = px.imshow(
    multi,
    x=scales,
    y=scales,
    title=f"coherence * {feature_descriptions[0]} * {feature_descriptions[1]}",
    labels=multi_axis_labels,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)
fig.show()
fig.write_image(f"{save_dir}/multi.png")

In [419]:

# Two heatmaps side by side: each feature's rescaled score weighted by coherence.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        f"coherence * {feature_descriptions[0]}",
        f"coherence * {feature_descriptions[1]}",
    ),
    horizontal_spacing=0.2,
)

# (score grid, colorbar x-position in paper coordinates) for the left and right panels.
weighted_panels = (
    (scores_1 * coherence_scores, 0.40),
    (scores_2 * coherence_scores, 1.0),
)
for panel_col, (panel_z, bar_x) in enumerate(weighted_panels, start=1):
    fig.add_trace(
        go.Heatmap(
            z=panel_z,
            x=scales,
            y=scales,
            colorscale="RdBu",
            zmid=0,
            colorbar=dict(title="", x=bar_x),
        ),
        row=1,
        col=panel_col,
    )

# Both panels share the x label; only the left panel carries the y label.
for panel_col, y_title in ((1, feature_descriptions[0]), (2, "")):
    fig.update_xaxes(title_text=feature_descriptions[1], row=1, col=panel_col)
    fig.update_yaxes(title_text=y_title, row=1, col=panel_col, autorange='reversed')

# Black frames around each panel; with horizontal_spacing=0.2 the two subplot
# domains are [0, 0.4] and [0.6, 1] in paper coordinates.
for frame_x0, frame_x1 in ((0, 0.4), (0.6, 1)):
    fig.add_shape(
        type="rect",
        x0=frame_x0, x1=frame_x1, y0=0, y1=1,
        line=dict(color="Black", width=2),
        xref="paper", yref="paper",
    )

fig.update_layout(title_text="Coherence * Individual Feature Scores", width=1000, height=500)

fig.show()
fig.write_image(f"{save_dir}/individual_coherence_scores.png")

In [420]:
# # load clamp tensors
# c_save_dir = "runs/clamp_anger_wedding"
# c_scales = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]
# feature_descriptions = ["anger", "wedding"]

# # load tensors
# c_coherence_scores = torch.load(f"{c_save_dir}/coherence_scores.pt")
# c_losses = torch.load(f"{c_save_dir}/losses.pt")
# c_scores_1 = torch.load(f"{c_save_dir}/scores_1.pt")
# c_scores_2 = torch.load(f"{c_save_dir}/scores_2.pt")

# # rescale
# c_coherence_scores = (c_coherence_scores-1) / 9
# c_scores_1 = (c_scores_1-1) / 9
# c_scores_2 = (c_scores_2-1) / 9
# c_multi = c_scores_1 * c_scores_2 * c_coherence_scores

In [421]:
# fig = px.imshow(c_multi, x=c_scales, y=c_scales,
#             title=f"(clamp) coherence * {feature_descriptions[0]} * {feature_descriptions[1]}",labels={'x': feature_descriptions[1], 'y': feature_descriptions[0]},
#             color_continuous_scale="RdBu", color_continuous_midpoint=0)
# fig.show()

In [422]:
# fig = make_subplots(rows=1, cols=2, subplot_titles=("Anger", "Wedding"), horizontal_spacing=0.2)
# fig.add_trace(go.Heatmap(z=c_scores_1, x=scales, y=scales, colorscale="RdBu", zmid=0, colorbar=dict(title="", x=0.40)), row=1, col=1)
# fig.add_trace(go.Heatmap(z=c_scores_2, x=scales, y=scales, colorscale="RdBu", zmid=0, colorbar=dict(title="", x=1.0)), row=1, col=2)

# fig.update_xaxes(title_text=feature_descriptions[1], row=1, col=1)
# fig.update_yaxes(title_text=feature_descriptions[0], row=1, col=1, autorange='reversed')
# fig.update_xaxes(title_text=feature_descriptions[1], row=1, col=2)
# # fig.update_yaxes(title_text=feature_descriptions[0], row=1, col=2, autorange='reversed')
# fig.update_yaxes(title_text="", row=1, col=2, autorange='reversed')

# # Add borders around the heatmaps
# fig.add_shape(type="rect",
#               x0=0, x1=0.4, y0=0, y1=1,
#               line=dict(color="Black", width=2),
#               xref="paper", yref="paper")
# fig.add_shape(type="rect",
#               x0=0.6, x1=1, y0=0, y1=1,
#               line=dict(color="Black", width=2),
#               xref="paper", yref="paper")

# fig.update_layout(title_text="Feature Clamping")
# fig.show()

## Compute per-instance coherence * score_1 * score_2 (averaged over generations at each scale pair)

In [423]:
# Per-generation log written by the sweep. Decode explicitly as UTF-8 (the encoding
# RFC 8259 requires for JSON) rather than the platform default, e.g. cp1252 on Windows.
with open(f"{save_dir}/gen_log.json", "r", encoding="utf-8") as f:
    gen_log = json.load(f)

In [424]:
def _rescale_ratings(ratings):
    """Map raw ratings through (x - 1) / 9, i.e. 1 -> 0 and 10 -> 1.

    Assumes the judge scores on a 1-10 scale (TODO confirm); this is the same
    transform applied to the saved score tensors earlier in the notebook.
    """
    return [(x - 1) / 9 for x in ratings]


def _mean_triple_product(s1, s2, coh):
    """Mean over generations of s1[k] * s2[k] * coh[k]; NaN if there are no generations.

    Divides by the number of products actually formed (zip stops at the shortest
    list), so mismatched list lengths cannot skew the mean. An empty log entry yields
    NaN instead of raising ZeroDivisionError.
    """
    products = [a * b * c for a, b, c in zip(s1, s2, coh)]
    return sum(products) / len(products) if products else float("nan")


# NaN-initialised so (scale_1, scale_2) cells missing from gen_log render as blanks
# instead of being indistinguishable from a genuine score of 0.
per_instance_mults = torch.full((len(scales), len(scales)), float("nan"))

for d in gen_log:
    scale1, scale2 = d['scales']
    i = scales.index(scale1)  # row    <-> scale of feature_descriptions[0]
    j = scales.index(scale2)  # column <-> scale of feature_descriptions[1]
    per_instance_mults[i, j] = _mean_triple_product(
        _rescale_ratings(d['scores_1']),
        _rescale_ratings(d['scores_2']),
        _rescale_ratings(d['coherence_scores']),
    )


fig = px.imshow(per_instance_mults, x=scales, y=scales,
            title="per-instance mult", labels={'x': feature_descriptions[1], 'y': feature_descriptions[0]},
            color_continuous_scale="RdBu", color_continuous_midpoint=0)
fig.show()
fig.write_image(f"{save_dir}/per_instance_mult.png")

In [425]:
# # now for clamp
# with open(f"{c_save_dir}/gen_log.json", "r") as f:
#     c_gen_log = json.load(f)

In [426]:
# NOTE(review): if this clamp cell is re-enabled, the two `scales.index(...)` lookups below
# must use `c_scales.index(...)`. The clamp run's scale grid (13 entries in the commented
# config above) can differ from the active experiment's `scales`, which would raise
# ValueError or index into the wrong cell.
# c_per_instance_mults = torch.zeros((len(c_scales), len(c_scales)))

# for d in c_gen_log:
#     scale1, scale2 = d['scales']
#     i = scales.index(scale1)
#     j = scales.index(scale2)
#     s1 = d['scores_1']
#     s1 = [(x - 1) / 9 for x in s1]
#     s2 = d['scores_2']
#     s2 = [(x - 1) / 9 for x in s2]
#     coh = d['coherence_scores']
#     coh = [(x - 1) / 9 for x in coh]
#     total = sum([a*b*c for a, b, c in zip(s1, s2, coh)])
#     total = total / len(s1)
#     c_per_instance_mults[i, j] = total


# fig = px.imshow(c_per_instance_mults, x=c_scales, y=c_scales,
#             title=f"(clamp) per-instance mult",labels={'x': feature_descriptions[1], 'y': feature_descriptions[0]},
#             color_continuous_scale="RdBu", color_continuous_midpoint=0)
# fig.show()

## Correlations
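Pearson correlation between the two feature scores across the generations at each (scale_1, scale_2) pair. **TODO:** it is not yet clear that correlation is the right metric here.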

In [427]:
def _pearson_or_nan(xs, ys):
    """Return the Pearson correlation of two equal-length score lists as a float.

    Returns NaN, without numpy's divide-by-zero RuntimeWarnings, when the correlation
    is undefined: fewer than two samples, or either list constant (zero variance).
    Raises ValueError if the two lists have different lengths.
    """
    a = np.asarray(xs, dtype=float)
    b = np.asarray(ys, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"score lists differ in length: {a.shape} vs {b.shape}")
    if a.size < 2 or a.std() == 0 or b.std() == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


# NaN-initialised so grid cells absent from gen_log stay undefined instead of reading
# as a genuine zero correlation (the min/max cell below uses nan-aware reductions).
correlations = torch.full((len(scales), len(scales)), float("nan"))
for d in gen_log:
    scale1, scale2 = d['scales']
    i = scales.index(scale1)
    j = scales.index(scale2)
    # Pearson correlation is invariant under the positive affine rescale (x - 1) / 9
    # used elsewhere in this notebook, so the raw ratings are correlated directly.
    correlations[i, j] = _pearson_or_nan(d['scores_1'], d['scores_2'])


# NOTE(review): if re-enabled, index with `c_scales.index(...)` rather than `scales.index(...)`.
# The clamp grid can differ from the active experiment's `scales`. `c_gen_log` is also
# only defined by an earlier cell that is itself commented out.
# c_correlations = torch.zeros((len(c_scales), len(c_scales)))
# for d in c_gen_log:
#     scale1, scale2 = d['scales']
#     i = scales.index(scale1)
#     j = scales.index(scale2)
#     s1 = d['scores_1']
#     s1 = [(x - 1) / 9 for x in s1]
#     s2 = d['scores_2']
#     s2 = [(x - 1) / 9 for x in s2]

#     array1 = np.array(s1)
#     array2 = np.array(s2)
#     correlation_matrix = np.corrcoef(array1, array2)
#     correlation = correlation_matrix[0, 1]

#     c_correlations[i, j] = correlation



In [428]:
# Colour-scale limits for the correlation plots. The nan-aware reductions skip cells
# where np.corrcoef produced NaN (e.g. a score list with zero variance).
min_correlation, max_correlation = np.nanmin(correlations), np.nanmax(correlations)

In [429]:
# fig = make_subplots(rows=1, cols=2, subplot_titles=("Adding", "Clamping"), horizontal_spacing=0.2)
# fig.add_trace(go.Heatmap(z=correlations[:13, :13], x=scales, y=scales, colorscale="RdBu", zmid=0, colorbar=dict(title="", x=0.40), zmin=min_correlation, zmax=max_correlation), row=1, col=1)
# fig.add_trace(go.Heatmap(z=c_correlations, x=scales, y=scales, colorscale="RdBu", zmid=0, colorbar=dict(title="", x=1.0), zmin=min_correlation, zmax=max_correlation), row=1, col=2)

# fig.update_xaxes(title_text=feature_descriptions[1], row=1, col=1)
# fig.update_yaxes(title_text=feature_descriptions[0], row=1, col=1, autorange='reversed')
# fig.update_xaxes(title_text=feature_descriptions[1], row=1, col=2)
# # fig.update_yaxes(title_text=feature_descriptions[0], row=1, col=2, autorange='reversed')
# fig.update_yaxes(title_text="", row=1, col=2, autorange='reversed')

# # Add borders around the heatmaps
# fig.add_shape(type="rect",
#               x0=0, x1=0.4, y0=0, y1=1,
#               line=dict(color="Black", width=2),
#               xref="paper", yref="paper")
# fig.add_shape(type="rect",
#               x0=0.6, x1=1, y0=0, y1=1,
#               line=dict(color="Black", width=2),
#               xref="paper", yref="paper")

# fig.update_layout(title_text="Score correlations")
# fig.show()

In [430]:
# Standalone heatmap of the per-(scale_1, scale_2) correlation between the two feature scores.

# Symmetric colour limits centred on 0, so white in the diverging RdBu scale means zero
# correlation. Plotly ignores `zmid` whenever `zmin`/`zmax` are set explicitly, so zmid=0
# alone cannot centre the scale. Fall back to +/-1 if every cell is NaN or exactly 0.
corr_bound = max(abs(float(min_correlation)), abs(float(max_correlation)))
if not np.isfinite(corr_bound) or corr_bound == 0:
    corr_bound = 1.0

fig = go.Figure()
fig.add_trace(go.Heatmap(
    # Plot the whole grid. A fixed crop such as [:13, :13] silently drops rows/columns
    # for sweeps with more than 13 scales and desynchronises z from the x/y labels.
    z=correlations,
    x=scales,
    y=scales,
    colorscale="RdBu",
    colorbar=dict(title=""),
    zmin=-corr_bound,
    zmax=corr_bound,
))

fig.update_xaxes(title_text=feature_descriptions[1])
# Square cells: anchor the y scale to x. Only one axis carries `scaleanchor`, because
# anchoring x->y and y->x at the same time is a constraint loop that plotly.js discards.
fig.update_yaxes(title_text=feature_descriptions[0], autorange='reversed', scaleanchor="x", scaleratio=1)

# Border around the plotting area (paper coordinates).
fig.add_shape(type="rect",
              x0=0, x1=1, y0=0, y1=1,
              line=dict(color="Black", width=2),
              xref="paper", yref="paper")

fig.update_layout(
    title_text="Score correlations",
    autosize=False,
    width=500,   # fixed size so the exported PNG is square
    height=500,
)

fig.show()
fig.write_image(f"{save_dir}/correlations.png")
