### Notebook to generate protein embeddings using ESM2-3B fine-tuned with a Siamese network

In [1]:
import os

import esm
import torch
from Bio import SeqIO
from torch.nn import CosineSimilarity
from tqdm import tqdm

In [2]:
# Example inputs as (sequence id, amino-acid string) tuples — the format the
# batch converter is called with further down.
# Each sequence here is 60 residues long.
seqs = [
    ("YP_009794187.1", "MDLSAIGFASKQFRVIPVEKGNLVTDFIQGKFQVIGVECNTRGAYGSPIQQSLSRRFPEM"),
    ("YP_009293175.1", "MLIFRDERHVEGDLFNAPETYKVITINCVGAMGKGIALACRERYPDLYENYRTRCRAGEI"),
    ("YP_009882144.1", "MIKQYVNYDLLDAFEHNDFDAIVHGCNCFHTMGAGIAGAIAKRFPVAVEADKKTEYGDWS"),
]

In [3]:
# Device the model is moved to in this section (see `model.to(device)` below).
device = 'cpu'

In [None]:
# Build the ESM2 3B (36-layer) model together with its alphabet.
# The fine-tuned weights are loaded into `model` in a later cell.
# NOTE(review): presumably this fetches/loads the pretrained weights — confirm.
model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()

In [None]:
# Place the model on the device chosen above.
model = model.to(device)

In [6]:
# Batch converter from the alphabet: called below with a list of
# (id, sequence) tuples, it returns (ids, sequence strings, token tensor).
tokenizer = alphabet.get_batch_converter()

In [10]:
# Peek at the first (id, sequence) pair.
seqs[0]

('YP_009794187.1',
 'MDLSAIGFASKQFRVIPVEKGNLVTDFIQGKFQVIGVECNTRGAYGSPIQQSLSRRFPEM')

In [8]:
# Tokenize the first two sequences separately (a batch of one each).
# Each call yields the ids, the raw sequence strings and the token tensor.
seq_id1, seq_str1, tokenized_input1 = tokenizer([seqs[0]])
seq_id2, seq_str2, tokenized_input2 = tokenizer([seqs[1]])

In [9]:
# Inspect the token ids: the 60-residue sequence becomes 62 tokens, with one
# extra token at the start (0) and one at the end (2).
tokenized_input1

tensor([[ 0, 20, 13,  4,  8,  5, 12,  6, 18,  5,  8, 15, 16, 18, 10,  7, 12, 14,
          7,  9, 15,  6, 17,  4,  7, 11, 13, 18, 12, 16,  6, 15, 18, 16,  7, 12,
          6,  7,  9, 23, 17, 11, 10,  6,  5, 19,  6,  8, 14, 12, 16, 16,  8,  4,
          8, 10, 10, 18, 14,  9, 20,  2]])

In [11]:
# Path to the fine-tuned checkpoint, relative to the working directory.
# NOTE(review): a later cell loads '../PooledAAEmbeddings/3b_model_checkpoint.pt'
# instead — confirm which location is the intended one.
finetuned_checkpoint = '3b_model_checkpoint.pt'

In [10]:
def remap(layers):
    """Return a copy of a state dict with the first dotted prefix removed from each key.

    For example ``"wrapper.layers.0.weight"`` becomes ``"layers.0.weight"``.
    A key without any dot maps to the empty string. Values are carried over
    unchanged and the original mapping is not modified.
    """
    # partition(".")[2] is everything after the first dot ("" when there is none).
    return {key.partition(".")[2]: tensor for key, tensor in layers.items()}

In [13]:
# Ask PyTorch to release its cached CUDA memory before loading the checkpoint.
# NOTE(review): `device` is 'cpu' in this notebook, so this is presumably a
# no-op here — confirm whether it is still needed.
torch.cuda.empty_cache()

In [15]:
# Load the fine-tuned weights into the model.
# `map_location=device` remaps the checkpoint's tensors onto the device used in
# this notebook, so a checkpoint saved from a GPU run can still be loaded on a
# CPU-only machine (and matches how `get_model_and_tokenizer` loads it).
# `remap` strips the leading prefix from every key so the names match `model`.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = remap(torch.load(finetuned_checkpoint, map_location=device))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [23]:
# Compute per-residue embeddings for the two tokenized sequences and save each
# one under its own sequence id.
# NOTE(review): the batch loop later in this notebook writes dict-valued files
# to the same tmp/<id>.pt paths, overwriting these raw tensors.
os.makedirs("tmp", exist_ok=True)  # torch.save does not create directories
model.eval()  # inference mode: disable dropout for deterministic embeddings
last_layer = model.num_layers
model_device = next(model.parameters()).device  # keep inputs on the model's device

with torch.no_grad():
    # Model output for one sequence has shape [1, seq_len + 2, 2560].
    raw1 = model(tokenized_input1.to(model_device), repr_layers=[last_layer])["representations"][last_layer]
    embedding1 = raw1[0][1:-1]  # drop the start/end tokens -> [seq_len, 2560]

    raw2 = model(tokenized_input2.to(model_device), repr_layers=[last_layer])["representations"][last_layer]
    embedding2 = raw2[0][1:-1]  # drop the start/end tokens -> [seq_len, 2560]

torch.save(embedding1, f"tmp/{seq_id1[0]}.pt")
torch.save(embedding2, f"tmp/{seq_id2[0]}.pt")

----

In [2]:
# Input FASTA file and the directory the per-sequence embedding files are
# written to (both relative to the working directory).
fasta_file_path = "data/sample.fasta"
output_path = "tmp"

In [3]:
# Parse every FASTA entry up front so the total is known for the progress bar.
records = [entry for entry in SeqIO.parse(fasta_file_path, "fasta")]
num_records = len(records)  # how many sequences were read

# Report what was loaded.
print('Number of sequences = ', num_records)
print('Seq ids:')
print([entry.id for entry in records])

Number of sequences =  3
Seq ids:
['YP_009794187.1', 'YP_009293175.1', 'YP_009882144.1']


In [4]:
def remap(layers):
    """Return a copy of a state dict with the first dotted prefix removed from each key.

    For example ``"wrapper.layers.0.weight"`` becomes ``"layers.0.weight"``.
    A key without any dot maps to the empty string. Values are carried over
    unchanged and the original mapping is not modified.
    """
    # partition(".")[2] is everything after the first dot ("" when there is none).
    return {key.partition(".")[2]: tensor for key, tensor in layers.items()}

def get_model_and_tokenizer(weight_path, device='cpu', model_type=esm.pretrained.esm2_t36_3B_UR50D):
    """Build an ESM model, load fine-tuned weights into it, and return it with its tokenizer.

    Args:
        weight_path: Path to the fine-tuned checkpoint (a state dict whose keys
            carry one leading prefix that `remap` strips).
        device: Device the checkpoint tensors are mapped to and the model is
            placed on.
        model_type: Zero-argument factory returning ``(model, alphabet)``.

    Returns:
        ``(model, tokenizer)`` where ``model`` is on ``device`` in eval mode and
        ``tokenizer`` is the alphabet's batch converter.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. for a missing
        file or keys that do not match the model.
    """
    model, alphabet = model_type()
    tokenizer = alphabet.get_batch_converter()
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = remap(torch.load(weight_path, map_location=device))
    model.load_state_dict(state_dict)
    # Honour `device` for the model itself, not only for the loaded tensors,
    # and hand back a model that is ready for inference.
    model = model.to(device)
    model.eval()
    return model, tokenizer

def run_model(model, tokenizer, sequence, layers=None, device='cpu'):
    """Embed a single sequence and return per-residue and mean-pooled representations.

    Args:
        model: Model exposing ``num_layers`` and callable as
            ``model(tokens, repr_layers=...)``, returning a dict with a
            ``"representations"`` mapping of layer index -> tensor.
        tokenizer: Batch converter; called with ``[sequence]`` and expected to
            return a 3-tuple whose last item is the token tensor.
        sequence: One tokenizer input; callers in this notebook pass an
            ``(id, sequence_string)`` tuple.
        layers: Layer indices to extract. Defaults to the last layer only.
        device: Device the model and the tokens are moved to before the forward pass.

    Returns:
        dict with two sub-dicts keyed by layer index:
            ``"representations"``: model output with the first and last
                position removed (the start/end tokens), shape
                ``[1, n_positions - 2, dim]``.
            ``"mean_representations"``: the average of the above over the
                position axis for the single sequence, a 1-D tensor.
    """
    if layers is None:
        layers = [model.num_layers]  # last layer

    model.eval()
    model = model.to(device)
    with torch.no_grad():
        # Tokenizer errors propagate unchanged; the traceback already carries the message.
        _, _, tokenized_inputs = tokenizer([sequence])
        tokenized_inputs = tokenized_inputs.to(device)
        representations = model(tokenized_inputs, repr_layers=layers)["representations"]

        output = {"mean_representations": {}, "representations": {}}
        for layer in layers:
            # Drop the start/end tokens so positions line up with residues.
            per_residue = representations[layer][:, 1:-1, :]
            output["representations"][layer] = per_residue
            # Average over positions; [0] selects the only sequence in the batch.
            output["mean_representations"][layer] = torch.mean(per_residue, dim=1)[0]
    return output

In [6]:
# Build the model and load the fine-tuned checkpoint onto the CPU.
# NOTE(review): this checkpoint path differs from `finetuned_checkpoint`
# defined earlier in the notebook — confirm which one is intended.
print('Loading model...')
model, tokenizer = get_model_and_tokenizer('../PooledAAEmbeddings/3b_model_checkpoint.pt', device='cpu')
print('Model loaded')

Loading model...
Model loaded


In [9]:
# Layers to extract; 36 is the final layer of the 36-layer model loaded above.
layers = [36]
# torch.save does not create missing directories, so make sure it exists.
os.makedirs(output_path, exist_ok=True)

# Embed each FASTA record and save the result under its sequence id.
for record in tqdm(records, total=num_records, desc="Processing sequences"):
    seq_id = record.id
    sequence = str(record.seq)
    # run_model already strips the start/end tokens, so each representation has
    # exactly one position per residue. Pass `layers` explicitly so the layers
    # extracted are the ones checked below.
    embedding = run_model(model, tokenizer, (seq_id, sequence), layers=layers, device='cpu')
    for layer in layers:
        assert embedding["representations"][layer].size(1) == len(sequence)
    # Saved object is the full dict: per-residue and mean-pooled representations.
    torch.save(embedding, f"{output_path}/{seq_id}.pt")

print('Done. Embeddings saved at:', output_path)



Processing sequences: 100%|██████████| 3/3 [01:44<00:00, 34.87s/it]

Done. Embeddings saved at: tmp





----

### Load Embeddings

In [18]:
# Accessions whose saved embeddings are reloaded for comparison.
seq_id1, seq_id2, seq_id3 = 'YP_009794187.1', 'YP_009293175.1', 'YP_009882144.1'

# Each embedding is read back from tmp/<accession>.pt.
emb1, emb2, emb3 = (
    torch.load(f'tmp/{accession}.pt')
    for accession in (seq_id1, seq_id2, seq_id3)
)

In [19]:
# Pull the mean-pooled layer-36 vector out of each saved embedding dict.
emb1_mean, emb2_mean, emb3_mean = (
    saved["mean_representations"][36] for saved in (emb1, emb2, emb3)
)

### Compute Cosine Similarity
##### using: https://pytorch.org/docs/stable/generated/torch.nn.CosineSimilarity.html

In [20]:
# Cosine similarity along dim 0: the pooled embeddings compared below are 1-D
# vectors (run_model averages over positions and takes the single batch item).
# `eps` guards against division by zero for near-zero-norm vectors.
cos = CosineSimilarity(dim=0, eps=1e-6)

In [21]:
# Similarity between sequences 1 and 2 (mean-pooled layer-36 embeddings).
# NOTE(review): all three pairwise scores in this notebook fall in a narrow
# 0.978-0.980 band — worth checking how discriminative mean pooling is here.
cos(emb1_mean, emb2_mean)

tensor(0.9778)

In [22]:
# Similarity between sequences 1 and 3 (mean-pooled layer-36 embeddings).
cos(emb1_mean, emb3_mean)

tensor(0.9779)

In [23]:
# Similarity between sequences 2 and 3 (mean-pooled layer-36 embeddings).
cos(emb2_mean, emb3_mean)

tensor(0.9798)