<br>

## Installation and imports

Install the `fair-esm` package, which provides the `esm` module (https://github.com/facebookresearch/esm).

In [1]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-0.4.0-py3-none-any.whl (37 kB)
Installing collected packages: fair-esm
Successfully installed fair-esm-0.4.0


<br>

Import libraries

In [2]:
import pandas as pd
import torch
import esm
import time

In [3]:
from torch.utils.data import Dataset, DataLoader

In [4]:
from sklearn.preprocessing import LabelEncoder

<br>

Initialize the pre-trained ESM-1b model and its alphabet (this may take about 5 min).

In [5]:
# Alternative loader, kept for reference only and not executed: fetch the same
# ESM-1b checkpoint through torch.hub instead of esm.pretrained.
# esm_model, alphabet = torch.hub.load(
#     "facebookresearch/esm:main", 
#     "esm1b_t33_650M_UR50S"
# )

In [None]:
# Load the pre-trained ESM-1b model (33 layers, 650M parameters) and its alphabet.
# Bound as `esm_model` because that is the name the embedder cell passes on.
esm_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# Keep the previous name available for any cell that still refers to `model`.
model = esm_model

<br>

Initialize the batch converter, which turns (label, sequence) pairs into token tensors.

In [None]:
# Callable that tokenises a list of (label, sequence) pairs using this alphabet.
batch_converter = alphabet.get_batch_converter()

<br>

## Load and pre-process data

In [None]:
# Read the EC-number / sequence table. The file's own header row is skipped
# and replaced by explicit column names.
seq_data = pd.read_csv(
    'results.csv',
    names=['ECnumber', 'Sequence', 'Specimen'],
    skiprows=1,
)
print(seq_data.shape)

<br>

Pre-process the data: keep only sequences of at most 1000 residues, then encode each EC number as an integer class label.

In [None]:
# Maximum number of residues kept. This is presumably chosen to stay under
# ESM-1b's 1024-token positional limit (1022 residues plus BOS/EOS); confirm
# against the fair-esm documentation.
MAX_SEQ_LEN = 1000

# .str.len() gives NaN for a missing sequence, which fails the comparison, so
# such rows are dropped instead of raising like .apply(len) would.
# .copy() makes the result an independent frame, so the label column added
# later does not trigger a SettingWithCopyWarning.
seq_data = seq_data[seq_data['Sequence'].str.len() <= MAX_SEQ_LEN].copy()

In [None]:
# Encode each distinct EC number string as an integer class id. The fitted
# encoder is kept so ids can be mapped back (e.g. ec_encoder.inverse_transform).
ec_encoder = LabelEncoder()
# .assign returns a new frame instead of writing a column into a possibly
# sliced one, which avoids pandas' SettingWithCopyWarning.
seq_data = seq_data.assign(EClabel=ec_encoder.fit_transform(seq_data['ECnumber']))

In [None]:
# Size of the table after length filtering and label encoding.
print(seq_data.shape)

In [None]:
# Preview the first five rows, including the new EClabel column.
seq_data.head(5)

<br>

## Create dataset

In [None]:
class protDataset(Dataset):
    """Map-style dataset pairing tokenised sequences with labels and lengths.

    Args:
        labels: per-sequence class labels (anything ``torch.as_tensor`` accepts).
        sequences: raw sequence strings; only their lengths are stored.
        tokens: token tensor whose first dimension indexes sequences
            (one row per sequence).

    Each item is a ``(label, length, token_row)`` tuple of tensors.
    """

    def __init__(self, labels, sequences, tokens):
        super().__init__()
        count = len(labels)
        # All three inputs must describe the same set of sequences.
        assert count == len(sequences)
        assert count == tokens.shape[0]

        self.n = count
        self.labels = torch.as_tensor(labels)
        self.tokens = tokens
        # Residue counts per sequence, used later to pool only real positions.
        self.lengths = torch.as_tensor(list(map(len, sequences)))

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        label = self.labels[idx]
        length = self.lengths[idx]
        token_row = self.tokens[idx, :]
        return label, length, token_row

In [None]:
cols = ['EClabel', 'Sequence']
# One [label, sequence] pair per row, the input format batch_converter takes.
seq_data_list = seq_data.loc[:, cols].to_numpy().tolist()
# Tokenise every sequence at once. The converter returns the labels, the
# sequence strings and a token tensor with one row per sequence.
labels, sequences, tokens = batch_converter(seq_data_list)
dataset = protDataset(labels, sequences, tokens)

In [15]:
# Number of sequences available for embedding.
print(len(dataset))

NameError: ignored

<br>

## Get sequence representations

In [63]:
class sequenceEmbedder():
    """Compute mean-pooled ESM sequence embeddings for every item of a dataset.

    Args:
        model: ESM-style model called as ``model(tokens, repr_layers=[...],
            return_contacts=False)``. It returns a dict whose
            ``'representations'`` entry maps a layer index to per-token
            representations.
        dataset: map-style dataset yielding ``(label, length, tokens)`` items
            (e.g. ``protDataset``).
        batch_size: number of sequences per forward pass.
        num_layers: index of the layer whose representations are pooled
            (33 matches the last layer of esm1b_t33_650M_UR50S).
        device: torch device string or ``torch.device`` to run on
            (``'cuda'``, ``'cuda:1'``, ``'cpu'``, ...).
    """

    def __init__(self, model, dataset, batch_size = 16, num_layers = 33, device = 'cuda'):
        self.device = device
        self.model = model
        self.num_layers = num_layers
        self.batch_size = batch_size

        # No shuffling, so embeddings come out in dataset order.
        self.loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle = False
        )

        # Inference only: put the model in eval mode.
        self.model.eval()

        # Free cached GPU memory before placing the model on any CUDA device.
        # .to() honours explicit indices such as 'cuda:1'; previously anything
        # other than the exact string 'cuda' silently ran on the CPU.
        if torch.device(self.device).type == 'cuda':
            torch.cuda.empty_cache()
        self.model.to(self.device)

    def average_tokens(self, lengths, representations):
        """Mean-pool each sequence's residue representations.

        Position 0 is skipped (presumably the BOS token added by the batch
        converter). Positions ``1..length`` are averaged, so anything after the
        residues (EOS, padding) is excluded.

        Args:
            lengths: residue count per row.
            representations: per-token representations indexed as
                ``[row, position, feature]``.

        Returns:
            Tensor stacking one pooled vector per row.
        """
        seq_reps = [
            representations[i, 1:int(length) + 1].mean(0)
            for i, length in enumerate(lengths)
        ]
        return torch.stack(seq_reps)

    def get_embeddings(self, verbosity = 100):
        """Run the model over the whole dataset and pool per-sequence embeddings.

        Args:
            verbosity: print a progress line every ``verbosity`` batches,
                starting with the first. 0 or None disables progress output.

        Returns:
            ``(embeddings, labels)``: pooled embeddings moved to the CPU and
            the concatenated labels, both in dataset order.
        """
        embeddings = []
        all_labels = []
        n_seen = 0
        with torch.no_grad():
            start = time.time()
            for i, (labels, lengths, tokens) in enumerate(self.loader):
                # The token tensor was presumably padded to the longest
                # sequence in the whole dataset. Each row holds at most
                # length + 2 real tokens (BOS/EOS), so trailing columns beyond
                # that are padding for every row in this batch. ESM masks
                # padding tokens, so trimming them should match converting each
                # batch separately while skipping attention over unused positions.
                tokens = tokens[:, :int(lengths.max()) + 2]
                tokens = tokens.to(self.device)

                # Get ESM results
                res = self.model(
                    tokens,
                    repr_layers = [self.num_layers],
                    return_contacts = False
                )
                token_representations = res['representations'][self.num_layers]

                # Average across residue positions, then move to the CPU to unload the GPU.
                sequence_representations = self.average_tokens(lengths, token_representations)
                embeddings.append(sequence_representations.detach().cpu())
                all_labels.append(labels)

                # Count sequences actually processed; the final batch may be short.
                n_seen += len(labels)
                if verbosity and i % verbosity == 0:
                    print('Batch no. {}; Sequence no.: {}; Elapsed time: {:1.2f}'.format(i + 1, n_seen, time.time() - start))

            embeddings = torch.cat(embeddings)
            all_labels = torch.cat(all_labels)

        return embeddings, all_labels





In [64]:
# Run on the GPU when one is available; otherwise fall back to the CPU, which
# works but is much slower for the 650M-parameter model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedder = sequenceEmbedder(esm_model, dataset, batch_size = 16, device = device)

In [None]:
# NOTE: get_embeddings returns an (embeddings, labels) tuple, so this name
# holds both tensors; index [0] for the embeddings and [1] for the labels.
embeddings = embedder.get_embeddings(verbosity=5)

Batch no. 1; Sequence no.: 16; Elapsed time: 2.91
Batch no. 6; Sequence no.: 96; Elapsed time: 82.02
Batch no. 11; Sequence no.: 176; Elapsed time: 161.34
Batch no. 16; Sequence no.: 256; Elapsed time: 240.69
Batch no. 21; Sequence no.: 336; Elapsed time: 319.22
Batch no. 26; Sequence no.: 416; Elapsed time: 398.52
Batch no. 31; Sequence no.: 496; Elapsed time: 477.90
Batch no. 36; Sequence no.: 576; Elapsed time: 557.26
Batch no. 41; Sequence no.: 656; Elapsed time: 636.58
Batch no. 46; Sequence no.: 736; Elapsed time: 715.84
Batch no. 51; Sequence no.: 816; Elapsed time: 795.24
Batch no. 56; Sequence no.: 896; Elapsed time: 874.53
Batch no. 61; Sequence no.: 976; Elapsed time: 953.91
Batch no. 66; Sequence no.: 1056; Elapsed time: 1033.08
Batch no. 71; Sequence no.: 1136; Elapsed time: 1111.26
Batch no. 76; Sequence no.: 1216; Elapsed time: 1189.85
Batch no. 81; Sequence no.: 1296; Elapsed time: 1268.86
Batch no. 86; Sequence no.: 1376; Elapsed time: 1348.09
Batch no. 91; Sequence no