In [1]:
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

In [2]:
import pandas as pd
df = pd.read_csv("../data/swissprot-8mers.csv")

In [3]:
df

Unnamed: 0.1,Unnamed: 0,seq,archaea,bacteria,fungi,human,invertebrates,mammals,plants,rodents,vertebrates,viruses,label_count
0,0,MTMDKSEL,False,False,False,True,False,True,False,True,False,False,3
1,1,TMDKSELV,False,False,False,True,False,True,False,True,False,False,3
2,2,MDKSELVQ,False,False,False,True,False,True,False,True,True,False,4
3,3,DKSELVQK,False,False,False,True,False,True,False,True,True,False,4
4,4,KSELVQKA,False,False,False,True,False,True,False,True,True,False,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...
101851838,101851838,LNVLTGTQ,False,False,False,False,False,False,True,False,False,False,1
101851839,101851839,NVLTGTQE,False,False,False,False,False,False,True,False,False,False,1
101851840,101851840,VLTGTQEG,False,False,False,False,False,False,True,False,False,False,1
101851841,101851841,LTGTQEGL,False,False,False,False,False,False,True,False,False,False,1


In [4]:
sequences = df.seq
labels = df.human

In [5]:
from transformers import AutoTokenizer

esm_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [6]:
esm_tokenizer(df.seq[0])

{'input_ids': [0, 20, 11, 20, 13, 15, 8, 9, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [70]:
from tqdm import tqdm
import numpy as np 
import time

def tokenize(seqs, lookup=esm_tokenizer._token_to_id):
    start_token = lookup["<cls>"]
    end_token = lookup["<eos>"]
    pad_token = lookup["<pad>"]
    n = len(seqs)
    t0 = time.time()
    lengths = [len(s) for s in seqs]
    max_seq_length = max(lengths)
    t1 = time.time()
    print("Got sequence lengths in %0.2fs" % (t1 - t0))
    
    m = max_seq_length + 2 # format will be <cls> peptide <eos> 
    unique_token_ids = np.array(list(lookup.values()))
    min_token_id = unique_token_ids.min()
    assert min_token_id >= 0
    max_token_id = unique_token_ids.max()
    
    if max_token_id < 2 ** 8:
        dtype = 'uint8'
    elif max_token_id < 2 ** 16:
        dtype = 'uint16'
    else:
        raise ValueError("max token ID too large")
    
    result = np.empty(shape=(n, m), dtype=dtype)
    result.fill(pad_token)
    t2 = time.time()
    print("Created token_ids array (shape=%dx%d, bytes=%0.2fG) in %0.2fs" % (
        result.shape[0],
        result.shape[1],
        result.nbytes / (1024 * 1024 * 1024), t2 - t1))
    
    ascii_values = {ord(aa) for aa in lookup.keys() if len(aa) == 1}
    max_ascii_value = max(ascii_values)
    table = [pad_token] * (max_ascii_value + 1)
    for (token, token_id) in lookup.items():
        if len(token) > 1:
            continue
        table[ord(token)] = token_id
    
    t3 = time.time()
    print("Created list of token ID lookups in %0.2fs" % (
        t3 - t2,))
    
    # fill the first position of each token_ids sequence with the start token
    result[:, 0] = start_token

    for i, (seq, length) in tqdm(enumerate(zip(seqs, lengths))):
        result[i, 1:length + 1] = [lookup[aa] for aa in seq]
        result[i, length + 1] = end_token
    t4 = time.time()
    print("Filled token_ids array in %0.2fs" % (t4 - t3))
    return result


In [71]:
%time sequences_tokenized = tokenize(sequences)

Got sequence lengths in 3.49s
Created token_ids array (shape=101851843x10, bytes=0.95G) in 0.10s
Created list of token ID lookups in 0.00s


101851843it [01:20, 1264983.77it/s]


Filled token_ids array in 80.58s
CPU times: user 1min 23s, sys: 781 ms, total: 1min 24s
Wall time: 1min 24s


In [9]:
from sklearn.model_selection import train_test_split
sequences = list(df.seq.values)
labels = df.human.values

train_sequences, test_sequences, train_labels, test_labels = \
    train_test_split(sequences_tokenized, labels, test_size=0.25, shuffle=True)



In [72]:
len(train_sequences)

76388882

In [73]:
len(test_sequences)

25462961

In [74]:
len(train_labels)

76388882

In [75]:
len(test_labels)

25462961