## Sequence embeddings

In [1]:
from transformers import BertModel, BertTokenizer
import re
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
model = BertModel.from_pretrained("Rostlab/prot_bert")



In [2]:
sequence_Example = "A E T C Z A O"
sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example)
encoded_input = tokenizer(sequence_Example, return_tensors='pt')
encoded_input

{'input_ids': tensor([[ 2,  6,  9, 15, 23, 25,  6, 25,  3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [3]:
output = model(**encoded_input)

### Load seq

In [4]:
hyper_aa_regex = re.compile('[BXZJUO]')

In [5]:
MAX_LEN = 200

In [13]:
import torch
import os
import pickle

In [14]:
modelname = 'seq_ann_1'

In [15]:
states_dir = f'states/{modelname}/'
if not os.path.exists(states_dir):
    os.makedirs(states_dir)

In [16]:
seq_file = "data/uniprot_sprot.fasta"
# seq_file = "data/debugging_sequence.fasta"


seqdb_pickle = f'{states_dir}/seqdb.pickle'
if os.path.exists(seqdb_pickle):
    seqdb = pickle.load(open(seqdb_pickle, 'rb'))
else:
    seqdb = {}
    irregs = 0
    for record in Bio.SeqIO.parse(seq_file, "fasta"):
        if '|' in record.id:
            _, acc, geneid = record.id.split('|')
        else:
            acc = record.id
        if hyper_aa_regex.findall(str(record.seq)):
            irregs += 1
            continue
        # if len(record.seq) > MAX_LEN:
        #     irregs += 1
        #     continue
        seqdb[acc] = record
    print('irregs', irregs)
    pickle.dump(seqdb, open(seqdb_pickle, 'wb'))
    seqdb = pickle.load(open(seqdb_pickle, 'rb'))



In [17]:
len(seqdb)

568610

### Formatting seq

In [22]:
import pandas as pd

df = pd.DataFrame(((k, ' '.join(v)) for k, v in seqdb.items()), columns=['id', 'seq']).set_index('id')
df

Unnamed: 0_level_0,seq
id,Unnamed: 1_level_1
Q6GZX4,M A F S A E D V L K E Y D R R R R M E A L L L ...
Q6GZX3,M S I I G A T R L Q N D K S D T Y S A G P C Y ...
Q197F8,M A S N T V S A Q G G S N R P V R D F S N I Q ...
Q197F7,M Y Q A I N P C P Q S W Y G S P Q L E R E I V ...
Q6GZX2,M A R P L L G K T S S V R R R L E S L S A C S ...
...,...
Q6UY62,M G N S K S K S K L S A N Q Y E Q Q T V N S T ...
P08105,M S S S L E I T S F Y S F I W T P H I G P L L ...
Q88470,M G N C N R T Q K P S S S S N N L E K P P Q A ...
A9JR22,M G L R Y S K E V R D R H G D K D P E G R I P ...


In [24]:
from tqdm import tqdm


In [30]:
from torch.utils.data import DataLoader, TensorDataset

# Tokenize all definitions and create batches
inputs = tokenizer(list(df['seq']), return_tensors='pt', padding=True, truncation=True)
dataset = TensorDataset(inputs['input_ids'], inputs['attention_mask'])
dataloader = DataLoader(dataset, batch_size=16)

In [34]:
dataloader = DataLoader(dataset, batch_size=2)

In [36]:
encoded_definitions = []

for batch in tqdm(dataloader, desc="Encoding Definitions"):
    input_ids, attention_mask = batch
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state.mean(dim=1)
    # embeddings = output['last_hidden_state'][:,0][0].detach().cpu().numpy()
    encoded_definitions.extend(embeddings)
    break

Encoding Definitions:   0%|          | 0/142153 [02:22<?, ?it/s]


KeyboardInterrupt: 

In [None]:
encoded_definitions = []

for batch in tqdm(dataloader, desc="Encoding Definitions"):
    input_ids, attention_mask = batch
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state.mean(dim=1)
    encoded_definitions.extend(embeddings)