# some test code to generate esm2 embeddings
Code from https://kaggle.com/code/viktorfairuschin/extracting-esm-2-embeddings-from-fasta-files 

In [1]:
import pathlib
import torch
import pickle
from esm import FastaBatchedDataset, pretrained

In [45]:
# this code seems to batch from a fasta file better from the code later 
def extract_embeddings(model_name, fasta_file, output_dir, tokens_per_batch=4096, seq_length=1022,repr_layers=[33]):
    
    # read in the esm model 
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    # move model to gpu if available 
    if torch.cuda.is_available():
        model = model.cuda()
      
    # batch the fasta file 
    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    # create data loader obj
    data_loader = torch.utils.data.DataLoader(
        dataset, 
        collate_fn=alphabet.get_batch_converter(), 
        batch_sampler=batches
    )

    # make output directory 
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # start processing batches 
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):

            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            # move tokens to gpu if available 
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            
            # Extract embeddings 
            with torch.no_grad():
                results = model(toks, repr_layers=repr_layers, return_contacts=False)
            token_representations = results["representations"][33]

            # update this to save dictionary for an entire fasta file 
            results = dict() 
            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                
                filename = output_dir / f"{entry_id}.pt"
                
                representation = token_representations[i, 1 : len(strs[i]) - 1].mean(0)
                results[label] = representation 
    
                
            torch.save(results, filename) 
            print(results)
            break 
                

In [46]:
#test on a simple fasta file 
model_name = 'esm2_t33_650M_UR50D' # see which model this here corresponds to 
fasta_file = pathlib.Path('/home/grig0076/scratch/databases/millardlab_phages/4May2024/pharokka/chunk_1/phanotate.faa')
output_dir = pathlib.Path('train_embeddings')

extract_embeddings(model_name, fasta_file, output_dir)

Processing batch 1 of 1671
{'JQHJTWNH_CDS_0019 hypothetical protein': tensor([ 0.0664,  0.0853,  0.0053,  ...,  0.0562, -0.2026,  0.0552]), 'JQHJTWNH_CDS_0086 hypothetical protein': tensor([ 0.0675,  0.0432,  0.0613,  ...,  0.0466, -0.0917, -0.0452]), 'JQHJTWNH_CDS_0113 hypothetical protein': tensor([ 0.0664,  0.0853,  0.0053,  ...,  0.0562, -0.2026,  0.0552]), 'JQHJTWNH_CDS_0180 hypothetical protein': tensor([ 0.0675,  0.0432,  0.0613,  ...,  0.0466, -0.0917, -0.0452]), 'JQHJTWNH_CDS_0311 hypothetical protein': tensor([ 0.0202,  0.0422, -0.0027,  ..., -0.0105, -0.0526, -0.0751]), 'JQHJTWNH_CDS_0324 hypothetical protein': tensor([ 0.0498,  0.0677,  0.1151,  ...,  0.0385, -0.1930, -0.0097]), 'JQHJTWNH_CDS_0914 hypothetical protein': tensor([ 0.0608,  0.0276,  0.0244,  ...,  0.0234, -0.0759, -0.0523]), 'JQHJTWNH_CDS_0923 hypothetical protein': tensor([ 0.0323,  0.0105, -0.0528,  ..., -0.0311, -0.0141,  0.0359]), 'JQHJTWNH_CDS_0960 hypothetical protein': tensor([ 0.0883,  0.0284,  0.0568,

In [30]:
# using code from esm2 github https://github.com/facebookresearch/esm 

In [33]:
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results
if torch.cuda.is_available():
        model = model.cuda()
        print('cuda')

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))


In [32]:
sequence_representations

[tensor([ 0.0614, -0.0687,  0.0430,  ..., -0.1642, -0.0678,  0.0446]),
 tensor([ 0.0553, -0.0757,  0.0414,  ..., -0.3117, -0.0026,  0.1683]),
 tensor([ 0.0618, -0.0769,  0.0405,  ..., -0.3037, -0.0013,  0.1741]),
 tensor([ 0.0084,  0.1425,  0.0506,  ...,  0.0403, -0.1063,  0.0079])]

In [34]:
batch_lens

tensor([67, 73, 73,  8])