In [117]:
import os
import torch
import esm
from sae_model import SparseAutoencoder
from esm_wrapper import ESM2Model

In [118]:
# Configuration: model sizes, the input sequence, device, and checkpoint paths.
# D_MODEL matches the embed_dim of the ESM-2 650M model built below;
# D_HIDDEN is the hidden (latent) width passed to SparseAutoencoder.
D_MODEL = 1280
D_HIDDEN = 4096
# Protein sequence that is encoded, reconstructed through the SAE, and decoded back.
SEQUENCE = 'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG'
device = 'cuda:0'
# Checkpoint directory. Set INTERP_WEIGHTS_DIR to point elsewhere instead of
# editing this absolute cluster path; the default is unchanged.
weights_dir = os.environ.get('INTERP_WEIGHTS_DIR', '/global/cfs/cdirs/m4351/ml5045/interp_weights')

esm2_weight = os.path.join(weights_dir, 'esm2_t33_650M_UR50D.pt')
sae_weight = os.path.join(weights_dir, 'sae_weights.pt')
# Token alphabet handed to the ESM-2 wrapper (ESM-1b architecture vocabulary).
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")

In [None]:
# Build the 33-layer ESM-2 model from its local checkpoint, then the SAE.
# Sizes and device come from the config cell so they are defined in one place.
esm2_model = ESM2Model(num_layers=33, embed_dim=D_MODEL, attention_heads=20,
                       alphabet=alphabet, token_dropout=False, device=device)
esm2_model.load_esm_ckpt(esm2_weight)
# Inference only: switch to eval mode (assumes ESM2Model is a torch.nn.Module,
# as its use of .to(device) suggests).
esm2_model = esm2_model.to(device).eval()

sae_model = SparseAutoencoder(D_MODEL, D_HIDDEN).to(device)
# sae_weight is defined in the config cell but was never loaded, so the SAE ran
# with freshly initialised parameters. Assumes the file holds a plain state_dict
# — TODO confirm the checkpoint format.
sae_model.load_state_dict(torch.load(sae_weight, map_location=device))
sae_model = sae_model.eval()

In [120]:
# Activations of SEQUENCE at ESM-2 layer 24, the layer the SAE is applied to.
# This layer number must match the one passed to get_sequence further down.
# No gradients are needed, so skip building an autograd graph.
with torch.no_grad():
    embed = esm2_model.get_layer_activations(SEQUENCE, 24)

In [126]:
# Encode the first element of `embed` (presumably the single batch item — TODO
# confirm the wrapper's return layout) into SAE latents. mu/std are returned
# alongside and handed back to decode; presumably input normalisation statistics.
with torch.no_grad():
    acts, mu, std = sae_model.encode(embed[0])

In [127]:
# NOTE(review): DIM is not referenced by any visible cell. Presumably an SAE
# latent index (out of D_HIDDEN) meant for inspection or steering — confirm, or remove.
DIM = 220

In [None]:
# Show the sizes of the SAE latents and of the mu/std values returned by encode.
tuple(t.size() for t in (acts, mu, std))

In [129]:
# Reconstruct layer activations from the SAE latents, passing back the mu/std
# values returned by encode. Inference only, so no autograd graph.
with torch.no_grad():
    updated_latents = sae_model.decode(acts, mu, std)

In [130]:
# Feed the SAE reconstruction back into ESM-2 at layer 24 to get output logits.
# unsqueeze(0) re-adds the leading dim removed by embed[0]. The layer number must
# match the one used with get_layer_activations above.
with torch.no_grad():
    logits = esm2_model.get_sequence(updated_latents.unsqueeze(0), 24)

In [None]:
# Show the logits' size rather than dumping the whole tensor (same style as the
# latent-size cell); inspect a slice such as logits[0, :5] if values are needed.
logits.size()

In [132]:
# Greedy-decode the logits back to amino-acid strings.
# The argmax is restricted to vocabulary indices [AA_START, AA_END) — presumably
# the 20 canonical amino acids of the alphabet (checked below to be 1-char tokens).
# The first/last positions are dropped, presumably the BOS/EOS tokens — TODO confirm.
AA_START, AA_END = 4, 24
aa_toks = esm2_model.alphabet.all_toks[AA_START:AA_END]
assert all(len(tok) == 1 for tok in aa_toks), aa_toks
tokens = torch.argmax(logits[:, 1:-1, AA_START:AA_END], dim=-1)
# Argmax indices are relative to AA_START, so they index aa_toks directly.
sequences = [''.join(aa_toks[i] for i in row.tolist()) for row in tokens]

In [None]:
# Decoded reconstruction of the (single) input sequence, and how faithfully it
# matches SEQUENCE position by position. The length check also guards the
# BOS/EOS-stripping assumption made when decoding.
reconstructed = sequences[0]
assert len(reconstructed) == len(SEQUENCE), (len(reconstructed), len(SEQUENCE))
recovery = sum(a == b for a, b in zip(reconstructed, SEQUENCE)) / len(SEQUENCE)
reconstructed, recovery