## Load the model and weights, then modify the model

In [None]:
# Reload edited Python modules (e.g. a local esm checkout) automatically before each cell runs,
# so changes to the model code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import torch
import esm

# Load the pretrained ESM-1b model (esm1b_t33_650M_UR50S) and its alphabet -- exactly once,
# so the instance configured below is the one every later cell uses.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# Switch to inference mode (disables dropout) for reproducible results. This must be called
# on the instance that is actually used downstream, not on a copy that gets overwritten.
model.eval()

# Converts a list of (label, sequence) pairs into (labels, strings, token-id tensor).
batch_converter = alphabet.get_batch_converter()

In [2]:
# Example input: a list of (label, amino-acid sequence) pairs -- here a single protein.
data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")]

# Tokenize the batch; batch_tokens holds the token ids fed to the embedding layer below.
batch_labels, batch_strs, batch_tokens = batch_converter(data)

### Split off the embedding layer so token embeddings can be computed (and edited) outside the transformer

In [3]:
# Keep a handle on the model's token-embedding layer. Access it by attribute rather than by
# position in named_children(), which silently depends on module registration order.
# Guarded so that re-running this cell after embed_tokens has been swapped for Identity
# keeps the real embedding layer instead of picking up the Identity.
if not isinstance(model.embed_tokens, torch.nn.Identity):
    embedding_layer = model.embed_tokens

# Compute the token embeddings outside the model; no autograd graph is needed here.
with torch.no_grad():
    embeddings = embedding_layer(batch_tokens)

# Replace the embedding layer with an Identity so the model's forward pass receives
# `embeddings` unchanged (assumes the local model code accepts embeddings in place of
# token ids -- TODO confirm).
model.embed_tokens = torch.nn.Identity()

# Set token dropout to False (presumably it would otherwise alter the embedding input in
# the forward pass -- confirm against the model code).
model.args.token_dropout = False

In [4]:
# Modify `embeddings` here, before the forward pass in the next cell.
# (insert code)

In [5]:
# Forward pass on the precomputed embeddings (CPU, no gradients), collecting the
# layer-33 representations (the last layer of esm1b_t33).
with torch.no_grad():
    results = model(
        embeddings,
        repr_layers=[33],
        return_contacts=False,
    )
token_representations = results["representations"][33]

# Mean-pool each sequence's residue positions into one vector per sequence.
# Position 0 is the beginning-of-sequence token, so residues occupy positions 1..len(seq).
sequence_representations = []
for i, (_, seq) in enumerate(data):
    residues = token_representations[i, 1 : len(seq) + 1]
    sequence_representations.append(residues.mean(dim=0))

In [6]:
# Display the per-sequence representations: one mean-pooled vector per entry in `data`.
sequence_representations

[tensor([ 0.1786,  0.0513,  0.0074,  ..., -0.0532, -0.0705, -0.0256])]

## Example of changing the embedding values

### Get the embedding of the amino-acid token X

In [7]:
# Look up the embedding of amino-acid token X by tokenizing a short all-X sequence and
# passing it through the (original) embedding layer.
x_data = [
    ("protein1", "XX")
]
# Named outputs instead of `_`, which would shadow IPython's last-output variable.
x_labels, x_strs, x_token = batch_converter(x_data)
with torch.no_grad():
    x_embedding = embedding_layer(x_token)

# Keep position 1 (the first X); position 0 is the beginning-of-sequence token.
# Result has shape (batch=1, embed_dim).
x_embedding = x_embedding[:, 1, :]

x_embedding.shape

torch.Size([1, 1280])

In [8]:
# Shape of the precomputed token embeddings: (batch, tokens, embed_dim). The token axis is
# longer than the sequence because it includes special tokens: position 0 is the
# beginning-of-sequence token, and the other extra position is presumably an end token.
embeddings.shape

torch.Size([1, 67, 1280])

### Replace the embedding of a chosen residue with that of X

In [9]:
# 0-based index (within the amino-acid sequence) of the residue to replace with X.
# Set explicitly instead of relying on a loop variable left over from an earlier cell.
residue_idx = 0
# The write below applies to every sequence in the batch, so the residue must exist in all.
assert 0 <= residue_idx < min(len(seq) for _, seq in data), "residue_idx is out of range"

# clone() so the unmodified `embeddings` stays intact and this cell can be re-run safely;
# a plain assignment would only alias (and mutate) the same tensor.
# Pass `new_embedding` to the model to get representations for the modified input.
new_embedding = embeddings.clone()
# +1 skips the beginning-of-sequence token at position 0.
new_embedding[:, residue_idx + 1, :] = x_embedding