In [10]:


# ESM-2 checkpoint identifiers; each follows "esm2_t<depth>_<size><unit>_UR50D".
esm_model_names = [
    "esm2_t6_8M_UR50D",
    "esm2_t12_35M_UR50D",
    "esm2_t30_150M_UR50D",
    "esm2_t33_650M_UR50D",
    "esm2_t36_3B_UR50D",
    "esm2_t48_15B_UR50D",
]

# Per-model metadata parsed out of the checkpoint name.
esm_model_depth = {}        # number after the "t" in the second field
esm_model_size_string = {}  # raw size field, e.g. "650M" or "3B"
esm_model_size_int = {}     # size in millions ("B" is scaled by 1000)

for model_name in esm_model_names:
    fields = model_name.split("_")
    depth_field, size_field = fields[1], fields[2]

    # Drop the leading "t" to get the numeric depth.
    esm_model_depth[model_name] = int(depth_field[1:])
    esm_model_size_string[model_name] = size_field

    # Split "<number><unit>" and normalise billions to millions.
    magnitude, unit = int(size_field[:-1]), size_field[-1]
    esm_model_size_int[model_name] = magnitude * 1000 if unit == "B" else magnitude

esm_checkpoints = {model_name: "facebook/%s" % model_name for model_name in esm_model_names}

# Order the models by numeric size once and read off both extremes.
models_by_size = sorted(esm_model_size_int.items(), key=lambda item: item[1])
smallest_model_name = models_by_size[0][0]
biggest_model_name = models_by_size[-1][0]

for model_name in esm_model_names:
    summary = (model_name, esm_model_depth[model_name], esm_model_size_string[model_name])
    print("%20s: depth=%d, size=%s" % summary)

print("%20s" % ("---",))
print("%20s: %s" % ("Smallest model", smallest_model_name))

print("%20s: %s" % ("Biggest model", biggest_model_name))

    esm2_t6_8M_UR50D: depth=6, size=8M
  esm2_t12_35M_UR50D: depth=12, size=35M
 esm2_t30_150M_UR50D: depth=30, size=150M
 esm2_t33_650M_UR50D: depth=33, size=650M
   esm2_t36_3B_UR50D: depth=36, size=3B
  esm2_t48_15B_UR50D: depth=48, size=15B
                 ---
      Smallest model: esm2_t6_8M_UR50D
       Biggest model: esm2_t48_15B_UR50D


In [11]:
import pandas as pd
# Load the SwissProt 8-mer table: one row per peptide ("seq") plus one boolean
# column per taxonomic group and a "label_count" column (see the display below).
# NOTE(review): the displayed frame carries "Unnamed: 0" and "Unnamed: 0.1"
# columns, which looks like a saved index being read back as data - confirm
# against the file and consider index_col / usecols if they are not needed.
df = pd.read_csv("../data/swissprot-8mers.csv")

In [12]:
df

Unnamed: 0.1,Unnamed: 0,seq,archaea,bacteria,fungi,human,invertebrates,mammals,plants,rodents,vertebrates,viruses,label_count
0,0,MTMDKSEL,False,False,False,True,False,True,False,True,False,False,3
1,1,TMDKSELV,False,False,False,True,False,True,False,True,False,False,3
2,2,MDKSELVQ,False,False,False,True,False,True,False,True,True,False,4
3,3,DKSELVQK,False,False,False,True,False,True,False,True,True,False,4
4,4,KSELVQKA,False,False,False,True,False,True,False,True,True,False,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...
101851838,101851838,LNVLTGTQ,False,False,False,False,False,False,True,False,False,False,1
101851839,101851839,NVLTGTQE,False,False,False,False,False,False,True,False,False,False,1
101851840,101851840,VLTGTQEG,False,False,False,False,False,False,True,False,False,False,1
101851841,101851841,LTGTQEGL,False,False,False,False,False,False,True,False,False,False,1


In [18]:
# Model inputs (peptide strings) and the target (boolean "human" flag),
# taken straight from the loaded frame as pandas Series.
sequences, labels = df["seq"], df["human"]

In [19]:
from transformers import AutoTokenizer


# Tokenizer for the smallest ESM-2 checkpoint; the id passed here is the
# "facebook/<model name>" string built in esm_checkpoints.
esm_tokenizer = AutoTokenizer.from_pretrained(esm_checkpoints[smallest_model_name])

In [20]:
esm_tokenizer(df.seq[0])

{'input_ids': [0, 20, 11, 20, 13, 15, 8, 9, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [21]:
from tqdm import tqdm
import numpy as np 
import time

def tokenize(seqs, lookup=esm_tokenizer._token_to_id):
    """Encode amino-acid strings as a padded 2-D array of token IDs.

    Every row is laid out as ``<cls> residue... <eos>`` and then filled with
    ``<pad>`` up to the width of the longest sequence plus two.

    Parameters
    ----------
    seqs : sized, re-iterable collection of str
        Sequences to encode. ``len(seqs)`` is taken and the collection is
        iterated twice, so a one-shot generator will not work. Every
        character must be a key of ``lookup``.
    lookup : mapping of str -> int
        Token-to-ID table; must contain "<cls>", "<eos>" and "<pad>".
        The default is the tokenizer's table as bound when this cell is
        executed, so re-run the cell if ``esm_tokenizer`` is replaced.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(seqs), max_len + 2)`` whose dtype is ``uint8``
        when all IDs fit in 8 bits and ``uint16`` otherwise.

    Raises
    ------
    KeyError
        If a special token or a residue character is missing from ``lookup``
        (for a dict-like ``lookup``).
    ValueError
        If the largest token ID does not fit in 16 bits.
    """
    start_token = lookup["<cls>"]
    end_token = lookup["<eos>"]
    pad_token = lookup["<pad>"]
    n = len(seqs)
    t0 = time.time()
    lengths = [len(s) for s in seqs]
    # default=0 lets an empty batch through; the result is then shape (0, 2).
    max_seq_length = max(lengths, default=0)
    t1 = time.time()
    print("Got sequence lengths in %0.2fs" % (t1 - t0))

    m = max_seq_length + 2  # format will be <cls> peptide <eos>

    # Pick the narrowest unsigned dtype able to hold every token ID, which
    # keeps the array small when there are very many sequences.
    unique_token_ids = np.array(list(lookup.values()))
    min_token_id = unique_token_ids.min()
    assert min_token_id >= 0
    max_token_id = unique_token_ids.max()

    if max_token_id < 2 ** 8:
        dtype = 'uint8'
    elif max_token_id < 2 ** 16:
        dtype = 'uint16'
    else:
        raise ValueError("max token ID too large")

    # Start from an all-<pad> array so only real tokens need writing below.
    result = np.full(shape=(n, m), fill_value=pad_token, dtype=dtype)
    t2 = time.time()
    print("Created token_ids array (shape=%dx%d, bytes=%0.2fG) in %0.2fs" % (
        result.shape[0],
        result.shape[1],
        result.nbytes / (1024 * 1024 * 1024), t2 - t1))

    # fill the first position of each token_ids sequence with the start token
    result[:, 0] = start_token

    # tqdm wraps the iterable directly and is told the total so that it can
    # render a real progress bar with an ETA instead of a bare counter.
    for i, (seq, length) in enumerate(tqdm(zip(seqs, lengths), total=n)):
        result[i, 1:length + 1] = [lookup[aa] for aa in seq]
        result[i, length + 1] = end_token
    t3 = time.time()
    print("Filled token_ids array in %0.2fs" % (t3 - t2))
    return result


In [22]:
%time sequences_tokenized = tokenize(sequences)

Got sequence lengths in 5.99s
Created token_ids array (shape=101851843x10, bytes=0.95G) in 0.09s
Created list of token ID lookups in 0.00s


101851843it [01:19, 1282699.10it/s]


Filled token_ids array in 79.73s
CPU times: user 1min 24s, sys: 1.42 s, total: 1min 26s
Wall time: 1min 26s


In [23]:
from sklearn.model_selection import train_test_split
# Raw peptide strings as a plain list and the "human" flag as a NumPy array.
sequences = list(df.seq.values)
labels = df.human.values

# Fixed seed so the shuffled 75/25 partition is identical on every
# Restart & Run All; without it each run trains and tests on different rows.
SPLIT_SEED = 42

train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    sequences_tokenized,
    labels,
    test_size=0.25,
    shuffle=True,
    random_state=SPLIT_SEED,
)



In [24]:
len(train_sequences)

76388882

In [25]:
len(test_sequences)

25462961

In [26]:
len(train_labels)

76388882

In [26]:
len(test_labels)

25462961

In [27]:
import esm 

# Load the esm2_t6_8M_UR50D model and its alphabet through the fair-esm loader
# (a separate code path from the Hugging Face tokenizer loaded earlier).
model_8M, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
# Example (name, sequence) pairs; the last two contain a literal "<mask>"
# token, and the last one separates residues with spaces.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
# The converter returns three values; presumably the names, the raw strings
# and the tokenised batch, as the variable names suggest - TODO confirm.
batch_labels, batch_strs, batch_tokens = batch_converter(data)

In [29]:
# Run the custom tokenize() on a single sequence (the same string as the
# "protein1" entry of the example data) to get a 1-row token-ID array.
token_ids = tokenize(["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"])

Got sequence lengths in 0.00s
Created token_ids array (shape=1x67, bytes=0.00G) in 0.00s
Created list of token ID lookups in 0.00s


1it [00:00, 26886.56it/s]

Filled token_ids array in 0.00s





In [39]:
type(token_ids)

torch.Tensor

In [30]:
import torch
# tokenize() returns a NumPy array; convert it to a tensor exactly once.
# The isinstance guard (which also accepts Tensor subclasses) keeps this cell
# safe to re-run after token_ids has already been converted.
if not isinstance(token_ids, torch.Tensor):
    token_ids = torch.from_numpy(token_ids.astype('int32'))

# Extract per-residue representations (on CPU)
# Layer 6 is the final layer of this model (depth=6 in the table printed above).
repr_layer_idx = 6

# Switch to inference mode so any train-time-only behaviour such as dropout
# is disabled and repeated runs give the same numbers.
model_8M.eval()

with torch.no_grad():
    results = model_8M(token_ids, repr_layers=[repr_layer_idx])
token_representations = results["representations"][repr_layer_idx]

In [31]:
results

{'logits': tensor([[[ 13.9791,  -9.0924,  -6.5645,  ..., -14.8934, -15.2000,  -9.0806],
          [ -8.2908, -14.3483,  -9.2372,  ..., -15.6513, -15.9180, -14.3464],
          [-13.0977, -23.2342, -13.0121,  ..., -16.0929, -16.0443, -23.2163],
          ...,
          [-11.8081, -23.6640, -13.1323,  ..., -16.5725, -16.6506, -23.6672],
          [-11.6440, -21.9399, -11.6811,  ..., -16.2364, -16.2875, -21.9438],
          [ -5.6318,  -6.4352,  19.2604,  ..., -16.4212, -16.2883,  -6.4788]]]),
 'representations': {6: tensor([[[ 0.2231,  0.5661,  0.1139,  ...,  1.0212, -0.1900, -0.6870],
           [ 0.4873,  0.2405, -0.1978,  ...,  0.6398, -0.0806, -0.3449],
           [-0.1065, -0.3528, -0.1022,  ..., -0.1548,  0.2464,  0.0080],
           ...,
           [-0.2542, -0.3260,  0.6081,  ...,  0.2127, -0.1515, -0.6503],
           [-0.0516, -0.1907,  0.3541,  ..., -0.0673, -0.0118, -0.5485],
           [-0.1355,  0.0183,  0.1578,  ...,  0.2765, -0.7183, -0.4105]]])}}

In [96]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class ProteinClassifier(nn.Module):
    """ESM-2 (650M) encoder with a dropout + linear + sigmoid head on top.

    The head is applied to the final-layer representation of the first token
    of every sequence, so inputs are expected to start with ``<cls>``.

    Parameters
    ----------
    n_classes : int
        Number of output units of the linear head.

    Notes
    -----
    The head ends in ``nn.Sigmoid``, so ``forward`` returns independent
    per-class probabilities rather than logits; pair it with a loss that
    expects probabilities (e.g. ``nn.BCELoss``), not one that expects logits.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.esm, _ = esm.pretrained.esm2_t33_650M_UR50D()
        # Index from the end instead of hard-coding 32, so the attribute stays
        # correct if a checkpoint with a different depth is swapped in.
        self.last_esm_layer = self.esm.layers[-1]

        # Output width of the last transformer block = size of the embedding
        # that is fed to the classification head.
        self.last_esm_layer_dim = self.last_esm_layer.fc2.out_features

        self.classifier = nn.Sequential(nn.Dropout(p=0.2),
                                        nn.Linear(self.last_esm_layer_dim, n_classes),
                                        nn.Sigmoid())

    def forward(self, input_ids):
        """Return per-class sigmoid scores for a batch of token-ID sequences.

        Parameters
        ----------
        input_ids : torch.Tensor
            Integer token IDs; assumed to be (batch, seq_len) with ``<cls>``
            at position 0, as produced by ``tokenize`` - TODO confirm for
            other callers.

        Returns
        -------
        torch.Tensor
            Scores in [0, 1] with one row per input sequence and one column
            per class.
        """
        # The fair-esm model returns a plain dict (there is no pooler_output
        # attribute) and only fills "representations" for requested layers.
        final_layer = len(self.esm.layers)
        output = self.esm(input_ids, repr_layers=[final_layer])
        token_repr = output["representations"][final_layer]

        # Use the first-token (<cls>) embedding as the sequence summary.
        cls_repr = token_repr[:, 0, :]
        return self.classifier(cls_repr)
    
prot_model = ProteinClassifier(2)

In [None]:
# Smoke-test the classifier on a handful of rows only: sending the full
# training set through the 650M-parameter encoder in a single forward call
# cannot fit in memory.
SMOKE_TEST_BATCH_SIZE = 8

train_tensor = torch.from_numpy(train_sequences.astype('int32'))
print(train_tensor.shape)

# This is only a shape check, so skip autograd bookkeeping.
with torch.no_grad():
    train_out = prot_model(train_tensor[:SMOKE_TEST_BATCH_SIZE])
print(train_out.shape)

In [103]:
train_sequences.dtype

dtype('uint8')