In [1]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

import plotly.express as px

torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7fa5f0126200>

In [2]:
from transformers import pipeline
sentiment_pipeline = pipeline("sentiment-analysis")

No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


In [3]:
# model: HookedTransformer = HookedTransformer.from_pretrained('gpt2-small', device='cpu')
device = 'cuda' if torch.cuda.is_available() else 'cpu' # mps will break when using model.generate()
model: HookedTransformer = HookedTransformer.from_pretrained('gpt2-small', device=device)


Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# layer = 1
# prompt_pos = "Yes, I talk about wedding constantly"
# prompt_neg = "I do not talk about wedding constantly"
# prompt_pos = "Love "
# prompt_neg = "Hate"
prompt_pos = "Anger"
prompt_neg = "Calm"

In [5]:
logits, pos_cache = model.run_with_cache(prompt_pos)
# h_p = pos_cache["resid_pre", layer]

logits, neg_cache = model.run_with_cache(prompt_neg)
# h_n = neg_cache["resid_pre", layer]

# print(h_p.shape, h_n.shape)
# steering = h_p - h_n
# steering.shape

In [6]:
# c = 5
def residual_stream_patching_hook(
    resid,
    hook,
    c
):
    h_p = pos_cache[hook.name]
    h_n = neg_cache[hook.name]
    steering = h_p - h_n

    len_steering = steering.shape[1]
    
    # resid shape is (batch, pos, d_model)
    # print(resid.shape, steering.shape)
    resid[:, :len_steering, :] =  resid[:, :len_steering, :] + c * steering

    return resid

In [7]:
text = "I think you're"
n_samples = 7
hook_name = utils.get_act_name("resid_pre", 7)

with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook, c=5))]):
    # model('testing testing asdf;lkjasdf;l k')
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
        print(output)

I think you're right. I've collected over a thousand links from e-commerce sites, including Yelp, Best Buy
I think you're missing the point. The situation is that I don't get paid enough to take care of Barry.
I think you're correct. These factors in turn could have created a more violent and punitive world in which bullies become society
I think you're saying that MSFT is a triple-income income tax that collects collected income from the government. That
I think you're right that most white first ancestries are given to the conquistadors and a few "s
I think you're likely expecting people such as me to be critical of companies that do good things for consumers, or because
I think you're right. Regardless of which side you're on, nationally there are many who will conveniently ignore anything that


In [8]:
love_words = ["love", "like", "adore", "enjoy", "appreciate", "cherish", "admire", "care", "fancy", "favor", "prefer"]
hate_words = ["hate", "dislike", "detest", "abhor", "despise", "scorn", "loathe", "despise", "fuck you"]

def compute_metric(positive_words, layer, n_samples, factor):
    count = 0
    hook_name = utils.get_act_name("resid_pre", layer)

    with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook, c=factor))]):
        for i in range(n_samples):
            output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=25, verbose=False)
            for word in positive_words:
                if word in output.lower():
                    count += 1
                    break
        
    return count/n_samples

In [9]:
# for l in range(model.cfg.n_layers):
#     n_samples = 10
#     score = compute_metric(hate_words, l, n_samples, factor=10)
#     print(f"layer: {l}, socre: {score}")

In [10]:
layer = 7 # pick a layer you want.

hook_name = utils.get_act_name("resid_pre", layer)
saes, sparsities = get_gpt2_res_jb_saes(hook_name)

print(saes.keys())
sae = saes[hook_name]
sea = sae.to(model.W_E.device)

100%|██████████| 1/1 [00:00<00:00,  1.14it/s]

dict_keys(['blocks.7.hook_resid_pre'])





In [11]:
# cache anger and then pass through sae, find anger feature, add anger feature during the forward pass.

logits, cache = model.run_with_cache("Anger")
anger_hidden_state = cache[hook_name][0, -1, :].unsqueeze(0)

feature_acts = sae(anger_hidden_state).feature_acts[0]
print(f'Num of activated features: {(feature_acts != 0).sum()}')

# get top 10 features
top_values, top_ids = torch.topk(feature_acts, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# L1 contribution of top feature
l1_contribution = top_values[0]/feature_acts.sum()
print(f'\nL1 contribution of top feature: {l1_contribution}')


Num of activated features: 78

Top 10 features:
tensor([18.4649, 16.4535, 12.0989, 11.0684,  7.7472,  7.2738,  5.0492,  4.7868,
         4.7161,  4.6675], device='cuda:0')
tensor([16077, 21456,  6857, 23357, 19453, 14237, 12147, 21901, 20881,  9111],
       device='cuda:0')

L1 contribution of top feature: 0.10172756761312485


In [12]:
target_feature = top_ids[0]

# only top feature
steering = sae.W_dec[top_ids[0]] * top_values[0] # shape: [d_model]
# top 5 features
# steering = torch.stack([sae.W_dec[top_ids[i]] * top_values[i] for i in range(5)]).sum(dim=0)

steering = steering.to(model.cfg.device)
steering.shape

torch.Size([768])

In [13]:
# c = 5
def residual_stream_patching_hook_sae(
    resid,
    hook,
    c,
    pos
):    
    # resid shape is (batch, pos, d_model)
    # print(resid.shape, steering.shape)
    if pos < resid.shape[1]:
        resid[:, pos, :] = resid[:, pos, :] + c * steering

    return resid


text = "I think you're"
n_samples = 7

with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook_sae, c=20, pos=0))]):
    # model('testing testing asdf;lkjasdf;l k')
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
        print(output)

I think you're well aware that equality-minded people are angry. What we don't know is the reality of who
I think you're going to wait for someone else to come up with some more outrageous stories about your chest. Seriously,
I think you're going mad. Angry angry Angry boiling angry is this; rage and anger is different things. Anger insults
I think you're right on twitter. Reports are spreading like an angry mob at people protesting Dhia. Some links to
I think you're now starting to understand how abusive a million angry upset users feel about this. I also think you're
I think you're seeing reactions (like [too many] comments from angry people who think these things riot like idiots.)
I think you're right about one thing about this charade that broke countless people within it. The mainstream media and its


In [14]:
def n_negative(sentences):
    data = sentiment_pipeline(sentences)
    n_neg = sum([1 for d in data if d['label'] == 'NEGATIVE'])
    # print("================")
    # print(sentences)
    # print(data)
    # print("================")
    return n_neg

In [41]:
love_words = ["love", "like", "adore", "enjoy", "appreciate", "cherish", "admire", "care", "fancy", "favor", "prefer"]
angry_words = ["angry", "mad", "pissed", "irritated", "annoyed", "frustrated", "enraged", "furious", "infuriated", "outraged", "resentful", "agitated", "aggravated", "bitter", "hate", "fuck"]
# angry_words = [
#     'hate', 'despise', 'loathe', 'detest', 'abhor', 'resent', 'dislike', 'disgust', 'anger', 'fury',
#     'rage', 'wrath', 'irritation', 'annoyance', 'frustration', 'bitterness', 'spite', 'vengeance',
#     'revenge', 'resentment', 'hostility', 'animosity', 'contempt', 'scorn', 'disdain', 'malice',
#     'venom', 'rancor', 'enmity', 'aversion', 'repulsion', 'revulsion', 'antipathy', 'ire',
#     'indignation', 'outrage', 'mad', 'furious', 'infuriated', 'incensed', 'irate', 'livid',
#     'seething', 'enraged', 'ballistic', 'pissed', 'aggravated', 'exasperated', 'disgusted',
#     'appalled', 'revolted', 'sickened', 'nauseated', 'fed up', 'sick and tired', 'bitter',
#     'vengeful', 'spiteful', 'vindictive', 'hostile', 'antagonistic', 'contemptuous', 'scornful',
#     'disdainful', 'evil', 'cruel', 'mean', 'nasty', 'vicious', 'vile', 'wicked', 'malicious',
#     'malevolent', 'hateful', 'venomous', 'caustic', 'virulent', 'toxic', 'noxious', 'poisonous',
#     'vitriolic', 'acrimonious'
# ]

batch_size = 128  ###
tokens = model.to_tokens(text, prepend_bos=True)
# print(tokens.shape)
batch_tokens = tokens.repeat((batch_size, 1))
# print(batch_tokens)

def compute_metric(related_words, n_samples, coef, pos, max_new_tokens=20):
    count = 0

    with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook_sae, c=coef, pos=pos))]):
        for i in range(n_samples):
            output = model.generate(batch_tokens, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=max_new_tokens, verbose=False)
            strings = model.to_string(output)

            count += n_negative(strings)

            # for s in strings:
            #     for word in related_words:
            #         if word in s.lower():
            #             count += 1
            #             break
        
    return count/(n_samples * batch_size)


n_samples = 2
max_new_tokens = 20
n_positions = 5 # 10 
cs = [0, 0.5, 1, 5, 7, 10, 15, 20, 30, 50] ###

score_matrix = torch.zeros((n_positions, len(cs)))
for pos in range(n_positions):
    for ci, c in enumerate(cs):
        score = compute_metric(angry_words, n_samples, c, pos, max_new_tokens)
        score_matrix[pos, ci] = score

        print(f'pos: {pos}, c: {c}, socre: {score}')

pos: 0, c: 0, socre: 0.53125
pos: 0, c: 0.5, socre: 0.51953125
pos: 0, c: 1, socre: 0.57421875
pos: 0, c: 5, socre: 0.58203125
pos: 0, c: 7, socre: 0.60546875
pos: 0, c: 10, socre: 0.6328125
pos: 0, c: 15, socre: 0.6953125
pos: 0, c: 20, socre: 0.78515625
pos: 0, c: 30, socre: 0.96484375
pos: 0, c: 50, socre: 1.0
pos: 1, c: 0, socre: 0.6328125
pos: 1, c: 0.5, socre: 0.60546875
pos: 1, c: 1, socre: 0.5390625
pos: 1, c: 5, socre: 0.6015625
pos: 1, c: 7, socre: 0.60546875
pos: 1, c: 10, socre: 0.56640625
pos: 1, c: 15, socre: 0.63671875
pos: 1, c: 20, socre: 0.671875
pos: 1, c: 30, socre: 0.61328125
pos: 1, c: 50, socre: 0.69140625
pos: 2, c: 0, socre: 0.546875
pos: 2, c: 0.5, socre: 0.5390625
pos: 2, c: 1, socre: 0.59375
pos: 2, c: 5, socre: 0.62890625
pos: 2, c: 7, socre: 0.71484375
pos: 2, c: 10, socre: 0.7109375
pos: 2, c: 15, socre: 0.71484375
pos: 2, c: 20, socre: 0.6875
pos: 2, c: 30, socre: 0.7265625
pos: 2, c: 50, socre: 0.74609375
pos: 3, c: 0, socre: 0.52734375
pos: 3, c: 0.5, 

In [42]:
toks = model.to_str_tokens(text)
x_labels = toks + [f"pos_{i}" for i in range(len(toks), n_positions)]
fig = px.imshow(score_matrix, y=x_labels, x=[str(c) for c in cs], color_continuous_scale="RdBu", color_continuous_midpoint=0)
fig.show()

In [43]:

sum_over_pos = score_matrix.mean(0)
sum_over_pos.shape

px.line(y=sum_over_pos, x=cs, title="Sum of scores over positions", markers=True, labels={'x': "coefficient", "y": "angry score"}).show()

In [44]:
sum_over_pos = score_matrix.mean(0)
sum_over_pos.shape


px.line(y=score_matrix[:, 2:].mean(1), x=x_labels, title="Sum of scores over positions", markers=True, labels={'x': "pos", "y": "angry score"}).show()

Compute model loss using pile-10k dataset

In [45]:
# load data manually to allow loading a subset of the data
pile_data = load_dataset("NeelNanda/pile-10k", split="train[:1%]")
# pile_data = pile_data.select(range(1)) # add this because my computer is extremely slow
print(len(pile_data))

dataset = utils.tokenize_and_concatenate(pile_data, model.tokenizer)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=True)

100


In [46]:
# loss for original forward pass

running_loss = 0
total = 0
max_length = 30
for batch in tqdm.tqdm(data_loader):
    loss = model(batch["tokens"][:, :max_length].to(device), return_type="loss").mean()
    running_loss += loss.item()
    total += 1
    # print(loss)

print(running_loss / total)

100%|██████████| 116/116 [00:02<00:00, 55.63it/s]

4.627064813827646





In [47]:
d = next(iter(data_loader))
model.to_str_tokens(d["tokens"])

['<|endoftext|>',
 'It',
 ' is',
 ' done',
 ',',
 ' and',
 ' submitted',
 '.',
 ' You',
 ' can',
 ' play',
 ' �',
 '�',
 'Surv',
 'ival',
 ' of',
 ' the',
 ' T',
 'ast',
 'iest',
 '�',
 '�',
 ' on',
 ' Android',
 ',',
 ' and',
 ' on',
 ' the',
 ' web',
 '.',
 ' Playing',
 ' on',
 ' the',
 ' web',
 ' works',
 ',',
 ' but',
 ' you',
 ' have',
 ' to',
 ' simulate',
 ' multi',
 '-',
 'touch',
 ' for',
 ' table',
 ' moving',
 ' and',
 ' that',
 ' can',
 ' be',
 ' a',
 ' bit',
 ' confusing',
 '.',
 '\n',
 '\n',
 'There',
 '�',
 '�',
 's',
 ' a',
 ' lot',
 ' I',
 '�',
 '�',
 'd',
 ' like',
 ' to',
 ' talk',
 ' about',
 '.',
 ' I',
 '�',
 '�',
 'll',
 ' go',
 ' through',
 ' every',
 ' topic',
 ',',
 ' inst',
 'ed',
 ' of',
 ' making',
 ' the',
 ' typical',
 ' what',
 ' went',
 ' right',
 '/',
 'wrong',
 ' list',
 '.',
 '\n',
 '\n',
 'Con',
 'cept',
 '\n',
 '\n',
 'Working',
 ' over',
 ' the',
 ' theme',
 ' was',
 ' probably',
 ' one',
 ' of',
 ' the',
 ' hardest',
 ' tasks',
 ' I',
 ' had',
 '

In [48]:
# # Testing code calculating loss for fixed coef and pos
# coef = 20
# pos = 0

# with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook_sae, c=coef, pos=pos))]):
#     running_loss = 0
#     total = 0
#     for batch in tqdm.tqdm(data_loader):
#         loss = model(batch["tokens"].to(device), return_type="loss").mean()
#         running_loss += loss.item()
#         total += 1
#         # print(loss)

#     print(running_loss / total)


loss_matrix = torch.zeros((n_positions, len(cs)))

for pos in range(n_positions):
    for ci, c in enumerate(cs):

        with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook_sae, c=c, pos=pos))]):
            running_loss = 0
            total = 0
            for batch in tqdm.tqdm(data_loader):
                loss = model(batch["tokens"][:, :max_length].to(device), return_type="loss").mean()
                running_loss += loss.item()
                total += 1
                # print(loss)
            
            loss_matrix[pos, ci] = running_loss / total

        print(f'pos: {pos}, c: {c}, loss: {loss_matrix[pos, ci]}')

100%|██████████| 116/116 [00:02<00:00, 57.10it/s]


pos: 0, c: 0, loss: 4.6270647048950195


100%|██████████| 116/116 [00:02<00:00, 57.05it/s]


pos: 0, c: 0.5, loss: 4.619467258453369


100%|██████████| 116/116 [00:02<00:00, 56.15it/s]


pos: 0, c: 1, loss: 4.612154006958008


100%|██████████| 116/116 [00:02<00:00, 56.72it/s]


pos: 0, c: 5, loss: 4.593219757080078


100%|██████████| 116/116 [00:02<00:00, 56.80it/s]


pos: 0, c: 7, loss: 4.612366199493408


100%|██████████| 116/116 [00:02<00:00, 56.80it/s]


pos: 0, c: 10, loss: 4.652684688568115


100%|██████████| 116/116 [00:02<00:00, 57.21it/s]


pos: 0, c: 15, loss: 4.731049537658691


100%|██████████| 116/116 [00:02<00:00, 57.19it/s]


pos: 0, c: 20, loss: 4.834829330444336


100%|██████████| 116/116 [00:02<00:00, 56.15it/s]


pos: 0, c: 30, loss: 5.20871114730835


100%|██████████| 116/116 [00:02<00:00, 55.67it/s]


pos: 0, c: 50, loss: 6.623801231384277


100%|██████████| 116/116 [00:02<00:00, 57.07it/s]


pos: 1, c: 0, loss: 4.6270647048950195


100%|██████████| 116/116 [00:02<00:00, 57.28it/s]


pos: 1, c: 0.5, loss: 4.625704288482666


100%|██████████| 116/116 [00:02<00:00, 57.15it/s]


pos: 1, c: 1, loss: 4.624902725219727


100%|██████████| 116/116 [00:02<00:00, 57.20it/s]


pos: 1, c: 5, loss: 4.652102947235107


100%|██████████| 116/116 [00:02<00:00, 57.19it/s]


pos: 1, c: 7, loss: 4.686624050140381


100%|██████████| 116/116 [00:02<00:00, 57.20it/s]


pos: 1, c: 10, loss: 4.726689338684082


100%|██████████| 116/116 [00:02<00:00, 57.01it/s]


pos: 1, c: 15, loss: 4.773638725280762


100%|██████████| 116/116 [00:02<00:00, 56.98it/s]


pos: 1, c: 20, loss: 4.813216686248779


100%|██████████| 116/116 [00:02<00:00, 56.90it/s]


pos: 1, c: 30, loss: 4.8774237632751465


100%|██████████| 116/116 [00:02<00:00, 57.23it/s]


pos: 1, c: 50, loss: 4.970141887664795


100%|██████████| 116/116 [00:02<00:00, 56.79it/s]


pos: 2, c: 0, loss: 4.6270647048950195


100%|██████████| 116/116 [00:02<00:00, 56.80it/s]


pos: 2, c: 0.5, loss: 4.626372814178467


100%|██████████| 116/116 [00:02<00:00, 56.50it/s]


pos: 2, c: 1, loss: 4.626222133636475


100%|██████████| 116/116 [00:02<00:00, 56.47it/s]


pos: 2, c: 5, loss: 4.665950775146484


100%|██████████| 116/116 [00:02<00:00, 56.30it/s]


pos: 2, c: 7, loss: 4.703161239624023


100%|██████████| 116/116 [00:02<00:00, 56.34it/s]


pos: 2, c: 10, loss: 4.747038841247559


100%|██████████| 116/116 [00:02<00:00, 57.40it/s]


pos: 2, c: 15, loss: 4.796876907348633


100%|██████████| 116/116 [00:02<00:00, 56.87it/s]


pos: 2, c: 20, loss: 4.834746837615967


100%|██████████| 116/116 [00:02<00:00, 57.06it/s]


pos: 2, c: 30, loss: 4.892271995544434


100%|██████████| 116/116 [00:02<00:00, 56.20it/s]


pos: 2, c: 50, loss: 4.97376823425293


100%|██████████| 116/116 [00:02<00:00, 57.03it/s]


pos: 3, c: 0, loss: 4.6270647048950195


100%|██████████| 116/116 [00:02<00:00, 57.08it/s]


pos: 3, c: 0.5, loss: 4.626429557800293


100%|██████████| 116/116 [00:02<00:00, 56.83it/s]


pos: 3, c: 1, loss: 4.626209735870361


100%|██████████| 116/116 [00:02<00:00, 56.69it/s]


pos: 3, c: 5, loss: 4.667377948760986


100%|██████████| 116/116 [00:02<00:00, 57.01it/s]


pos: 3, c: 7, loss: 4.700442790985107


100%|██████████| 116/116 [00:02<00:00, 56.89it/s]


pos: 3, c: 10, loss: 4.741722106933594


100%|██████████| 116/116 [00:02<00:00, 56.29it/s]


pos: 3, c: 15, loss: 4.785483360290527


100%|██████████| 116/116 [00:02<00:00, 56.86it/s]


pos: 3, c: 20, loss: 4.819249629974365


100%|██████████| 116/116 [00:02<00:00, 56.82it/s]


pos: 3, c: 30, loss: 4.871582984924316


100%|██████████| 116/116 [00:02<00:00, 56.87it/s]


pos: 3, c: 50, loss: 4.948605537414551


100%|██████████| 116/116 [00:02<00:00, 56.91it/s]


pos: 4, c: 0, loss: 4.6270647048950195


100%|██████████| 116/116 [00:02<00:00, 57.17it/s]


pos: 4, c: 0.5, loss: 4.627649307250977


100%|██████████| 116/116 [00:02<00:00, 57.03it/s]


pos: 4, c: 1, loss: 4.628964424133301


100%|██████████| 116/116 [00:02<00:00, 55.95it/s]


pos: 4, c: 5, loss: 4.677468299865723


100%|██████████| 116/116 [00:02<00:00, 56.66it/s]


pos: 4, c: 7, loss: 4.713242053985596


100%|██████████| 116/116 [00:02<00:00, 56.44it/s]


pos: 4, c: 10, loss: 4.754469394683838


100%|██████████| 116/116 [00:02<00:00, 56.32it/s]


pos: 4, c: 15, loss: 4.798036575317383


100%|██████████| 116/116 [00:02<00:00, 56.36it/s]


pos: 4, c: 20, loss: 4.831392765045166


100%|██████████| 116/116 [00:02<00:00, 56.29it/s]


pos: 4, c: 30, loss: 4.884396553039551


100%|██████████| 116/116 [00:02<00:00, 56.21it/s]

pos: 4, c: 50, loss: 4.961891174316406





In [49]:
# plot scatter plot of loss (x axis) vs score (y axis). with different colors for different c values, and different shapes for different pos values.

import pandas as pd
import numpy as np


n_c = len(cs)
n_pos = n_positions

cs_str = [str(c) for c in cs]
pos_str = [f"pos_{i}" for i in range(n_positions)]

# Random score and loss matrices
# np.random.seed(42)
# score_matrix = torch.randn(n_pos, n_c)
# loss_matrix = torch.randn(n_pos, n_c)

# Create a DataFrame
data = {
    'Loss': loss_matrix.numpy().flatten(),
    'Score': score_matrix.numpy().flatten(),
    'Position': np.repeat(pos_str, n_c), # ['a', 'b'] -> ['a', 'a', 'b', 'b']
    'Coef': np.tile(cs_str, n_pos) # ['a', 'b'] -> ['a', 'b', 'a', 'b']
}

df = pd.DataFrame(data)

# Map the position to marker shapes
# markers = ['circle', 'square', 'diamond', 'cross', 'x']
# df['Marker'] = df['Position'].apply(lambda x: markers[x % len(markers)])

# Plotting, with coef as color
fig = px.scatter(
    df, x='Loss', y='Score', color='Coef',
    labels={'Coef': 'Coef', 'Marker': 'Position'},
    title='Angry score vs Loss (Coef as color)',
    hover_data=['Position', 'Coef', 'Loss', 'Score'],
)

fig.update_traces(marker=dict(size=10))  # Adjust marker size
fig.show()

# line plot, one line for each coef, color by position, 
fig = px.scatter(
    df, x='Loss', y='Score', color='Position',
    labels={'Position': 'Position', 'Score': 'Score'},
    title='Angry score vs Loss (Position as color)',
    hover_data=['Position', 'Coef', 'Loss', 'Score'],
)

fig.update_traces(marker=dict(size=10))  # Adjust marker size
fig.update_traces(mode='lines+markers')
fig.show()



In [50]:
 # would be cool to plot attention score vs sentiment score.

In [51]:
score_matrix

tensor([[0.5312, 0.5195, 0.5742, 0.5820, 0.6055, 0.6328, 0.6953, 0.7852, 0.9648,
         1.0000],
        [0.6328, 0.6055, 0.5391, 0.6016, 0.6055, 0.5664, 0.6367, 0.6719, 0.6133,
         0.6914],
        [0.5469, 0.5391, 0.5938, 0.6289, 0.7148, 0.7109, 0.7148, 0.6875, 0.7266,
         0.7461],
        [0.5273, 0.5391, 0.5664, 0.5859, 0.6211, 0.6484, 0.6094, 0.6328, 0.6562,
         0.7109],
        [0.5078, 0.5312, 0.5430, 0.5938, 0.6484, 0.6094, 0.7148, 0.7109, 0.7227,
         0.7734]])

In [52]:
loss_matrix

tensor([[4.6271, 4.6195, 4.6122, 4.5932, 4.6124, 4.6527, 4.7310, 4.8348, 5.2087,
         6.6238],
        [4.6271, 4.6257, 4.6249, 4.6521, 4.6866, 4.7267, 4.7736, 4.8132, 4.8774,
         4.9701],
        [4.6271, 4.6264, 4.6262, 4.6660, 4.7032, 4.7470, 4.7969, 4.8347, 4.8923,
         4.9738],
        [4.6271, 4.6264, 4.6262, 4.6674, 4.7004, 4.7417, 4.7855, 4.8192, 4.8716,
         4.9486],
        [4.6271, 4.6276, 4.6290, 4.6775, 4.7132, 4.7545, 4.7980, 4.8314, 4.8844,
         4.9619]])