Imports

In [2]:
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
import re
import numpy as np
import os
import pandas as pd
from torch import nn
from datasets import load_dataset

Consts

In [3]:
# Size of the last axis of the embedding array allocated in
# get_full_length_protbert_embeddings (one vector of this size per residue).
EMBEDDING_SIZE = 1024
# Number of residue rows reserved per sequence in the embedding array;
# shorter sequences are zero-padded up to this many rows.
MAX_SEQ_LEN = 512
# Dataset variant name; interpolated into the Hub dataset ids below and into
# the name of the saved feature file.
K_LET = 'singlets'
# First CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Hugging Face Hub id of the model whose embeddings are extracted.
distil_protbert_path = 'yarongef/DistilProtBert'
# Hugging Face Hub ids of the evaluation and training datasets for K_LET.
test_set_path = f"yarongef/{K_LET}_test_set"
training_set_path = f"yarongef/{K_LET}_training_set"

In [25]:
# Load the DistilProtBert encoder (weights are downloaded on first use).
# NOTE(review): the Main cell below calls AutoModel.from_pretrained with the
# same path again and rebinds `model`, so this cell mainly warms the cache.
model = AutoModel.from_pretrained(distil_protbert_path)

Downloading:   0%|          | 0.00/589 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/882M [00:00<?, ?B/s]

Some weights of the model checkpoint at yarongef/DistilProtBert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at yarongef/DistilProtBert and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bi

Functions

In [None]:
def convert_dataset_to_df(dataset, split='test'):
    """Convert one split of a dataset mapping into a pandas DataFrame.

    Parameters
    ----------
    dataset : mapping
        Object indexable by split name whose splits are indexable by column
        name (e.g. the result of ``load_dataset``).
    split : str, optional
        Name of the split to convert. Defaults to ``'test'``, which is the
        split this function always read before the parameter existed.

    Returns
    -------
    pandas.DataFrame
        Frame with the columns ``'Seq'``, ``'length'`` and ``'label'``, in
        that order. The order matters: downstream helpers read the sequence
        and its length by column position.

    Raises
    ------
    KeyError
        If the split or one of the three columns is missing.
    """
    records = dataset[split]  # look the split up once instead of per column
    return pd.DataFrame({column: records[column] for column in ('Seq', 'length', 'label')})

def preprocess_seqs(data):
    """
    Turn raw sequences into tokenizer-ready strings.

    The sequence is read from the first column of ``data``. Its characters
    are separated by single spaces so that each amino acid becomes its own
    token, and the amino acids U, Z, O and B are replaced with X.

    Returns a list of strings, one per row of ``data``.
    """
    rare_residues = re.compile(r"[UZOB]")
    # one space-separated string per row, produced lazily
    spaced = (" ".join(data.iloc[row, 0]) for row in range(len(data)))
    return [rare_residues.sub("X", seq) for seq in spaced]

def get_full_length_protbert_embeddings(data, feature_extractor):
    """Build a fixed-size, zero-padded array of per-residue embeddings.

    Parameters
    ----------
    data : pandas.DataFrame
        Sequence in column 0 and its length in column 1 (both are read by
        position).
    feature_extractor : callable
        Called with one preprocessed sequence string at a time.
        Assumes the result converts to an array of shape
        ``(1, n_tokens, EMBEDDING_SIZE)`` whose first and last tokens are the
        special <CLS>/<SEP> tokens - TODO confirm against the pipeline in use.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(data), MAX_SEQ_LEN, EMBEDDING_SIZE)``. Row ``i``
        holds the residue embeddings of sequence ``i`` followed by all-zero
        rows up to ``MAX_SEQ_LEN``.

    Notes
    -----
    Sequences longer than ``MAX_SEQ_LEN`` are truncated to their first
    ``MAX_SEQ_LEN`` residues; previously they failed on assignment.
    Length-1 sequences are supported; previously ``squeeze()`` collapsed
    them to 1-D and the 2-D padding raised.
    """
    seqs_preprocessed = preprocess_seqs(data)
    # Zero-initialised, so any row that is not written below is already padding.
    all_embeddings = np.zeros(shape=(len(data), MAX_SEQ_LEN, EMBEDDING_SIZE))
    for i, sequence in enumerate(seqs_preprocessed):
        # Clamp so an over-long sequence cannot overflow its fixed-size slot.
        seq_len = min(int(data.iloc[i, 1]), MAX_SEQ_LEN)
        token_embeddings = np.asarray(feature_extractor(sequence))
        # Index the batch axis explicitly (instead of squeeze) so a single
        # residue keeps its 2-D shape; skip token 0 and stop before the
        # trailing special token: remove <CLS> & <SEP>.
        residues = token_embeddings[0, 1:seq_len + 1, :]
        # Write in place; the remaining rows stay zero (the padding).
        all_embeddings[i, :residues.shape[0], :] = residues
    return all_embeddings

Main

In [None]:
# Directory that receives the pooled feature tensor.
dataset_path = 'dataset_features/'
os.makedirs(dataset_path, exist_ok=True)  # no-op when the directory already exists

# Load the test set and flatten it into a DataFrame (Seq, length, label).
test_set = load_dataset(test_set_path)
test_set_df = convert_dataset_to_df(test_set)

# The tokenizer comes from Rostlab/prot_bert; the encoder is DistilProtBert.
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = AutoModel.from_pretrained(distil_protbert_path)

# Pipeline device ordinal: 0 = first GPU, -1 = CPU. Chosen at runtime so the
# notebook also runs on machines without CUDA instead of failing on device=0.
pipeline_device = 0 if torch.cuda.is_available() else -1
fe = pipeline('feature-extraction', model=model, tokenizer=tokenizer, device=pipeline_device)
all_embeds = get_full_length_protbert_embeddings(test_set_df, fe)

# Max-pool along the last (embedding) axis in non-overlapping windows of 16,
# shrinking that axis by a factor of 16.
all_embeds_torch = torch.from_numpy(all_embeds)
max_pooler = nn.MaxPool1d(16, stride=16)
output = max_pooler(all_embeds_torch)
print(f'test set shape is: {output.shape}')

torch.save(output, dataset_path+f'{K_LET}.pt')