In [1]:
import torch
import esm

# Load the pretrained ESM-1b model and its alphabet through torch hub.
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1b_t33_650M_UR50S")
# The model is only used for inference here, so switch it to eval mode
# (the ESM usage example does this to disable dropout for deterministic results).
model.eval()
model.cuda()
# Callable that turns a list of (label, sequence) pairs into token tensors.
batch_converter = alphabet.get_batch_converter()

Using cache found in /home/step/.cache/torch/hub/facebookresearch_esm_master


In [6]:
import pandas as pd
import numpy as np

import pickle
from pathlib import Path
from tqdm import tqdm
from torch.utils.data import DataLoader


import random

def predict(sequence):
    """Run the ESM model on one sequence or on a batch of sequences.

    Args:
        sequence: either a single sequence given as a string (it is
            upper-cased and wrapped as ``[("0", sequence)]``), or a list of
            ``(label, sequence)`` pairs that is handed to ``batch_converter``
            unchanged (batch entries are NOT upper-cased).

    Returns:
        The output of ``model`` for the tokenized batch, computed without
        gradient tracking and with the layer-33 representations requested.
    """
    # A bare string is a single sequence; anything else is treated as an
    # already-formed batch.
    if isinstance(sequence, str):
        data = [("0", sequence.upper())]
    else:
        data = sequence

    _, _, batch_tokens = batch_converter(data)

    # Send the tokens to whichever device the model currently lives on, so the
    # call keeps working after the model has been moved with .cpu() / .cuda().
    device = next(model.parameters()).device
    with torch.no_grad():
        results = model(batch_tokens.to(device), repr_layers=[33])
    return results
    
def get_embedding(sequence, sequence_emb = True):
    """Compute ESM layer-33 embeddings for one sequence or a batch.

    Args:
        sequence: a single sequence string, or a list of ``(label, sequence)``
            pairs (the formats accepted by ``predict``).
        sequence_emb: when true, return one vector per sequence obtained by
            averaging the per-residue embeddings; when false, return the raw
            per-token tensor.

    Returns:
        * ``sequence_emb=False``: the ``(bs, seq_len, emb_dim)`` tensor taken
          from the model output.
        * ``sequence_emb=True``: a NumPy array of shape ``(emb_dim,)`` when
          the batch holds exactly one sequence, otherwise ``(bs, emb_dim)``
          with one row per input sequence, in input order.
    """
    results = predict(sequence)

    token_embeddings = results["representations"][33] # shape: (bs, seq_len, emb_dim)
    if not sequence_emb:
        return token_embeddings

    # Number of residues of every sequence in the batch. The slice below must
    # use each sequence's own length, not the batch size.
    if isinstance(sequence, str):
        lengths = [len(sequence)]
    else:
        lengths = [len(seq) for _, seq in sequence]

    # NOTE: token 0 is always a beginning-of-sequence token, so the residues of
    # row i are tokens 1..lengths[i]. Averaging over that slice (the token
    # axis of a single row) gives one emb_dim vector per sequence and leaves
    # any trailing end/padding tokens out of the mean.
    pooled = torch.stack(
        [
            token_embeddings[row, 1 : n_residues + 1].mean(dim=0)
            for row, n_residues in enumerate(lengths)
        ]
    ).cpu().numpy()

    # Keep the historical contract: a lone sequence yields a 1-D vector.
    if pooled.shape[0] == 1:
        return pooled[0]
    return pooled
    

def generate_embeddings(path_in, path_out = None, kind = 'train', bs = 32, subset = None):
    """Embed every sequence of one dataset split and optionally pickle the result.

    Args:
        path_in: directory handed to ``open_sets``.
        path_out: when not ``None``, path of the file the embeddings dict is
            pickled to.
        kind: key of the split to embed in the dict returned by ``open_sets``.
        bs: batch size used to iterate over the split.
        subset: not supported yet; any value other than ``None`` raises.

    Returns:
        ``(embeddings, df)``: ``embeddings`` maps each sequence string to the
        value ``get_embedding`` produced for it (a repeated sequence keeps the
        last value), and ``df`` is the dataframe of the selected split.

    Raises:
        NotImplementedError: if ``subset`` is not ``None``.
    """
    df = open_sets(path_in)[kind]

    if subset is not None:
        raise NotImplementedError

    # batch the dataset and return a batch of the form [(label, seq), ...]
    rows = list(df.itertuples(index=False, name=None))
    dl = DataLoader(
        rows,
        batch_size=bs,
        collate_fn=lambda batch: batch  # keep the raw list of tuples
    )

    embeddings = {}
    for batch in tqdm(dl):
        batch_embeddings = get_embedding(batch)

        if batch_embeddings.ndim == 2:
            # One row per sequence of the batch.
            for (label, seq), emb in zip(batch, batch_embeddings):
                embeddings[seq] = emb
        else:
            # A 1-D result: store it under the first sequence of the batch.
            embeddings[batch[0][1]] = batch_embeddings

    if path_out is not None:
        # Context manager so the file is flushed and closed even on error.
        with open(path_out, 'wb') as fh:
            pickle.dump(embeddings, fh)

    return embeddings, df



def open_sets(base_path):
    """Read every ``*.csv`` directly under ``base_path`` into a dict of dataframes.

    The dict key is the second ``_``-separated piece of the file stem
    (``foo_train.csv`` -> ``'train'``). Each dataframe has its columns in the
    reverse of the order found in the file.
    """
    frames_by_kind = {}

    for csv_file in base_path.glob('*.csv'):
        # Derive the key first so a badly named file fails before it is read.
        split_name = csv_file.stem.split('_')[1]

        table = pd.read_csv(csv_file)
        reversed_columns = table.columns[::-1]
        frames_by_kind[split_name] = table[reversed_columns]

    return frames_by_kind

# Load every CSV split found under ./data and preview the 'train' table.
data_path = Path('data')
sets = open_sets(base_path=data_path)
sets['train']

Unnamed: 0,consensus_stability_score,sequence
0,0.37,GSSQETIEVEDEEEARRVAKELRKKGYEVKDERRGNKWHVHRT
1,0.62,TLDEARELVERAKKEGTGMDVNGQRFEDWREAERWVREQEKNK
2,-0.03,TELKKKLEEALKKGEEVRVKFNGIEIRNTSEDAARKAVELLEK
3,1.41,GSSQETIEVEDEEEARRVAKELRKTGYEVKIERRGNKWHVHRT
4,1.11,TTIHVGDLTLKYDNPKKAYEIAKKLAKKYNLQVTIKNGKITVT
...,...,...
7705,0.80,GSSKTQYEYDTKEEHQKAYEKFKKQGIPVTITQKNGKWFVQVE
7706,0.82,TIDEIIKALEQAVKDNKPIQVGNYTVTSADEAEKLAKKLKKPY
7707,0.66,TQDEIIKALEQAVKDNKPIQVGNYTVTSADEAEKLAKKLKKEY
7708,1.05,TTIKVNGQEYTVPLSPEQAAKAAKKRWPDYEVQIHGNTVWVTR


In [13]:
# generate_embeddings returns an (embeddings dict, dataframe) pair; with no
# `kind` given it processes the 'train' split.
emb = generate_embeddings(path_in=data_path, bs=64)
# Move the model back to the CPU; the trailing ';' hides the cell output.
model.cpu();

In [14]:
# Ask PyTorch to release the cached GPU memory it is no longer using.
torch.cuda.empty_cache()