# ProtEx: A PROTein EXtension Tool

The ultimate goal of this project is to detect selenoproteins in sequence databases that have been incorrectly truncated at their first selenocysteine (Sec) codon. Sec is encoded by the same codon (UGA) as the opal stop codon. For this task, we plan to apply the pre-trained ESM transformer.
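To make the truncation problem concrete, here is a toy sketch (not part of the ProtEx code) that translates a short, made-up mRNA using a tiny subset of the standard genetic code. With the standard reading, UGA terminates translation. When UGA is instead read as Sec (one-letter code `U`), translation continues. Real Sec insertion depends on additional signals such as the SECIS element, which this sketch ignores.

```python
# Toy subset of the standard genetic code (RNA codons -> one-letter amino acids).
CODON_TABLE = {
    "AUG": "M", "GCU": "A", "UGU": "C", "AAA": "K", "GGC": "G",
    "UAA": "*", "UAG": "*", "UGA": "*",  # '*' marks a stop codon
}


def translate(mrna: str, read_uga_as_sec: bool = False) -> str:
    """Translate codon by codon, stopping at the first stop codon."""
    protein = []
    for i in range(0, len(mrna) - 2, 3):
        codon = mrna[i:i + 3]
        if codon == "UGA" and read_uga_as_sec:
            protein.append("U")  # selenocysteine instead of a stop
            continue
        amino_acid = CODON_TABLE[codon]
        if amino_acid == "*":
            break
        protein.append(amino_acid)
    return "".join(protein)


mrna = "AUGGCUUGAAAAGGCUAA"
print(translate(mrna))                        # MA (truncated at UGA)
print(translate(mrna, read_uga_as_sec=True))  # MAUKG (full-length selenoprotein)
```

The first call shows exactly the kind of truncated record this project aims to detect.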

In [1]:
# Add the project root to the module search path so that the `src.*` imports below resolve.
import sys
sys.path.append('/home/prichter/Documents/protex/')

from src.utils import fasta_to_df, clstr_to_df
from src.dataset import SequenceDataset
from src.esm import ESMClassifier, esm_train
import scipy.stats as stats
import pandas as pd
import torch



## Data and preprocessing

All amino acid sequences were obtained from the UniProt database. The data is stored in FASTA format for compatibility with the CD-HIT sequence clustering tool; I wrote additional functions to read FASTA files into `pandas.DataFrame`s, and vice versa. The datasets used, along with descriptions, are given below.
1. `./data/sec.fasta`: All peptide sequences tagged in the UniProt database as containing a selenocysteine residue (about 20,000). All sequences have been truncated at the first selenocysteine.
2. `./data/short.fasta`: Examples of genuinely short proteins which are *not* incorrectly truncated. This file was generated by fitting a Gaussian KDE to the length distribution of the truncated selenoproteins (in `sec.fasta`), sampling lengths from that distribution, and downloading UniProt sequences which match the sampled lengths. This dataset is of size __.
3. `./data/all.fasta`: The combination of `sec.fasta` and `short.fasta`, generated using `cat sec.fasta short.fasta > all.fasta`.
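The FASTA helpers themselves are not shown in this notebook. As a rough, hypothetical sketch of what reading a FASTA file into a `DataFrame` (and writing it back) involves, assuming the record ID is the first whitespace-delimited token of each header and that the columns are named `id` and `seq` as in the cells below:

```python
import io

import pandas as pd


def read_fasta_df(lines) -> pd.DataFrame:
    """Parse FASTA lines into a DataFrame with 'id' and 'seq' columns (hypothetical helper)."""
    records, header, chunks = [], None, []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:  # flush the previous record
                records.append((header, "".join(chunks)))
            header, chunks = line[1:].split()[0], []
        else:
            chunks.append(line)  # sequences may span several lines
    if header is not None:
        records.append((header, "".join(chunks)))
    return pd.DataFrame(records, columns=["id", "seq"])


def write_fasta_df(df: pd.DataFrame, handle, width: int = 60) -> None:
    """Write 'id'/'seq' rows back out as FASTA, wrapping sequences at `width` residues."""
    for rec_id, seq in zip(df["id"], df["seq"]):
        handle.write(f">{rec_id}\n")
        for i in range(0, len(seq), width):
            handle.write(seq[i:i + width] + "\n")


example = io.StringIO(">sp|A|X some description\nMKT\nAAU\n>sp|B|Y\nMGG\n")
df = read_fasta_df(example)
print(df)
```

In practice `read_fasta_df(open(path))` would read a file on disk; a dedicated parser such as Biopython's `SeqIO` is an alternative if headers get more complicated.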

In [None]:
# This cell assumes the selenoprotein sequences have already been downloaded from UniProt to sec.fasta.
# truncate_selenoproteins overwrites that file in place with the truncated sequences.
data.truncate_selenoproteins('/home/prichter/Documents/protex/data/sec.fasta')
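`truncate_selenoproteins` is not shown here. A minimal sketch of the truncation step, assuming selenocysteine appears as `U` (its standard one-letter code) in the UniProt sequences, cuts each sequence at its first `U`. The function name and the choice to drop sequences that end up empty (Sec at the very first position) are hypothetical.

```python
import pandas as pd


def truncate_at_first_sec(seq: str) -> str:
    """Return the part of `seq` before the first selenocysteine ('U'); unchanged if there is none."""
    idx = seq.find("U")
    return seq if idx == -1 else seq[:idx]


df = pd.DataFrame({"id": ["a", "b", "c"], "seq": ["MKUAGU", "MKTAG", "UMK"]})
df["seq"] = df["seq"].map(truncate_at_first_sec)
# Drop sequences that become empty because Sec was the first residue.
df = df[df["seq"].str.len() > 0]
print(df)
```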

In [None]:
# Read the newly truncated selenoprotein sequences into a DataFrame.
sec_data = fasta_to_df('/home/prichter/Documents/protex/data/sec.fasta')
# Fit a Gaussian KDE to the lengths of the truncated selenoproteins.
lengths = sec_data['seq'].apply(len).to_numpy()
dist = stats.gaussian_kde(lengths)

# Download short proteins whose lengths are sampled from this distribution.
data.download_short_proteins(dist, '/home/prichter/Documents/protex/data/short.fasta')
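`download_short_proteins` is not shown either. The sampling half of that step might look like the hypothetical helper below; how the sampled lengths are then turned into UniProt queries is not sketched. `gaussian_kde.resample` draws from the fitted density and returns an array of shape `(d, n)` (here `d = 1`). The `seed` argument assumes a reasonably recent SciPy.

```python
import numpy as np
from scipy import stats


def sample_target_lengths(lengths, n, seed=0):
    """Draw `n` integer sequence lengths from a Gaussian KDE fitted to `lengths` (hypothetical helper)."""
    kde = stats.gaussian_kde(np.asarray(lengths, dtype=float))
    samples = kde.resample(size=n, seed=seed)[0]  # resample returns shape (1, n)
    # KDE samples are continuous and may fall below 1, so round and clip them.
    return np.clip(np.rint(samples), 1, None).astype(int)


lengths = [45, 60, 72, 88, 101, 130, 150, 210]
targets = sample_target_lengths(lengths, n=5)
print(targets)
```

Fixing the seed makes the set of target lengths reproducible, which helps if the download has to be re-run.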



### Generating train and test data

The CD-HIT protein clustering tool was used to group the sequences in `all.fasta` into clusters at an 80 percent sequence identity threshold, using the command `cd-hit -i all.fasta -o all -n 5 -c 0.8` (`-c` sets the identity threshold and `-n` the word length). Clustering the sequences, and organizing the training and testing sets such that no cluster spans the two, helps to ensure that the test accuracy reflects whether the model generalizes to unseen sequence families rather than memorizing near-duplicates of its training data (similar to the approach [here](https://www.biorxiv.org/content/10.1101/626507v4.full)).
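`clstr_to_df` is not shown in this notebook. The sketch below is a hypothetical parser for the `.clstr` file CD-HIT writes alongside its output, assuming the usual layout: a `>Cluster N` line followed by one member line per sequence, with the sequence ID between `>` and `...`. Since the IDs in `.clstr` are later merged against the FASTA IDs, it is worth checking CD-HIT's `-d` option, which limits how many header characters are written to the `.clstr` file.

```python
import io
import re

import pandas as pd

# Member lines look like "0\t120aa, >sp|P1|A... *"; capture the text between '>' and '...'.
MEMBER_RE = re.compile(r">(.+?)\.\.\.")


def parse_clstr(lines) -> pd.DataFrame:
    """Map each sequence ID to its CD-HIT cluster number (hypothetical helper)."""
    rows, cluster = [], None
    for line in lines:
        line = line.strip()
        if line.startswith(">Cluster"):
            cluster = int(line.split()[1])
        elif line:
            match = MEMBER_RE.search(line)
            if match:
                rows.append((match.group(1), cluster))
    return pd.DataFrame(rows, columns=["id", "cluster"])


example = io.StringIO(
    ">Cluster 0\n"
    "0\t120aa, >sp|P1|A... *\n"
    "1\t118aa, >sp|P2|B... at 85.59%\n"
    ">Cluster 1\n"
    "0\t64aa, >tr|Q9|C... *\n"
)
clusters = parse_clstr(example)
print(clusters)
```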

In [None]:
# Load the clustering and sequence information.
clstr_data = clstr_to_df('/home/prichter/Documents/protex/data/all.clstr')
fasta_data = fasta_to_df('/home/prichter/Documents/protex/data/all.fasta')

# Combine the two DataFrames on the 'id' column, keeping only IDs present in both (inner join).
data = clstr_data.merge(fasta_data, how='inner', on='id')
# Drop duplicate IDs, which were causing issues downstream.
data = data.drop_duplicates(subset=['id'])

Using a greedy partitioning algorithm, I split the combined dataset into a train and test set (with a roughly 75%/25% train/test split). This partitioning approach ensures that no cluster group spans both train and test datasets, for the reasons mentioned above. 
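The partitioning function used here is not shown in the notebook. One plausible greedy scheme, sketched below as an assumption (the name `greedy_cluster_split` is hypothetical), visits clusters from largest to smallest and adds a cluster to the training set only if doing so moves the training set closer to its target size; every other cluster goes to the test set. Because whole clusters are assigned at once, no cluster can span both sets.

```python
import pandas as pd


def greedy_cluster_split(df: pd.DataFrame, train_size: float = 0.75, cluster_col: str = "cluster"):
    """Split `df` into (train, test) without letting any cluster span both (hypothetical helper)."""
    sizes = df[cluster_col].value_counts()  # sorted largest cluster first
    target_train = train_size * len(df)
    train_clusters, n_train = set(), 0
    for cluster, size in sizes.items():
        # Only add the cluster if it brings the training set closer to its target size.
        if abs(n_train + size - target_train) < abs(n_train - target_train):
            train_clusters.add(cluster)
            n_train += size
    in_train = df[cluster_col].isin(train_clusters)
    return df[in_train], df[~in_train]


# Ten sequences in clusters of sizes 4, 3, 2 and 1.
toy = pd.DataFrame({"id": [f"seq{i}" for i in range(10)], "cluster": [0] * 4 + [1] * 3 + [2] * 2 + [3]})
train, test = greedy_cluster_split(toy, train_size=0.75)
print(len(train), len(test))  # 7 3
```

Placing large clusters first leaves the small ones to fine-tune the split, which is why the result lands close to the requested fraction.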

In [None]:
# NOTE: `train_test_split` is the greedy, cluster-aware partitioning function described above
# (its import is not shown here); it is not scikit-learn's function of the same name.
train_data, test_data = train_test_split(data, train_size=0.75, test_size=0.25)

# Confirm that the split sizes are roughly correct.
train_size = int(len(train_data) / len(data) * 100)
test_size = int(len(test_data) / len(data) * 100)

print(f'The training dataset makes up {train_size} percent of the total data.')
print(f'The testing dataset makes up {test_size} percent of the total data.')

# Confirm that no cluster appears in both sets.
train_clusters = set(train_data['cluster'])
test_clusters = set(test_data['cluster'])
assert train_clusters.isdisjoint(test_clusters)


After generating the train and test datasets, I saved them to CSV files so that the clustering and splitting steps don't need to be re-run. These files are also stored in the data directory, and contain the amino acid sequences, unique identifiers, and cluster numbers.

In [None]:
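# index=False keeps the DataFrame index from being written as an extra unnamed column.
train_data.to_csv('/home/prichter/Documents/protex/data/train.csv', columns=['seq', 'id', 'cluster'], index=False)
test_data.to_csv('/home/prichter/Documents/protex/data/test.csv', columns=['seq', 'id', 'cluster'], index=False)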


## Sequence classification

This part of the project is *very* incomplete. Eventually, I plan to try each of the classification techniques below and compare them. My current priority is the bag-of-words baseline, using a simple `torch.nn.Embedding` layer.

### Using ESM

This portion of the code is mostly complete, although I am considering some stylistic changes (e.g. moving the labeling and data-loading functions outside of the classifier so that they are more visible in the code, and removing some code repetition).
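For the `use_builtin_classifier=False` variant, one common custom head (shown as a hypothetical sketch, not necessarily what `ESMClassifier` does) averages ESM's per-residue embeddings over the non-padding positions and feeds the result to a linear layer. The tokenizer's attention mask marks which positions are real tokens; `hidden_size` would come from the ESM model's `config.hidden_size`.

```python
import torch
from torch import nn


class MeanPoolHead(nn.Module):
    """Masked mean pooling over residues followed by a linear classifier (hypothetical)."""

    def __init__(self, hidden_size: int, n_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(hidden_size, n_classes)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # avoid dividing by zero
        return self.linear(summed / counts)


# With a real model, hidden_states would be esm_model(**encoding).last_hidden_state.
torch.manual_seed(0)
head = MeanPoolHead(hidden_size=8)
hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
logits = head(hidden, mask)
print(logits.shape)  # torch.Size([2, 2])
```

Masking matters because sequences in a batch are padded to a common length; without it, padding embeddings would leak into the average.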

In [2]:
train_data = SequenceDataset(pd.read_csv('/home/prichter/Documents/protex/data/train.csv', nrows=1000))
# test_data = SequenceDataset(pd.read_csv('/home/prichter/Documents/protex/data/test.csv'))

In [3]:
model_v1 = ESMClassifier(use_builtin_classifier=True)
model_v2 = ESMClassifier(use_builtin_classifier=False)


In [None]:
losses_v1 = esm_train(model_v1, train_data, batch_size=100, n_epochs=5)
torch.save(model_v1, 'model_v1.pickle')

In [5]:
losses_v2 = esm_train(model_v2, train_data, batch_size=100, n_epochs=5)
torch.save(model_v2, 'model_v2.pickle')


In [None]:
# Losses from the model using the built-in classification head.
losses_v1

In [None]:
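import transformers

# Inspect the pre-trained ESM-2 model that the classifiers are built on.
#dir(transformers.EsmForSequenceClassification.from_pretrained("facebook/esm2_t6_8M_UR50D"))
esm_model = transformers.EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
dir(esm_model)

In [None]:
# A model is not directly iterable; list its named parameters and their shapes instead.
for name, param in esm_model.named_parameters():
    print(name, tuple(param.shape))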

### Using a "bag-of-words" baseline
`#TODO`
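A minimal sketch of what this baseline could look like, assuming one learned embedding per amino acid averaged over the sequence with `torch.nn.EmbeddingBag` (a variant of `torch.nn.Embedding` that pools each "bag" of tokens). The vocabulary, class name, and dimensions are hypothetical choices.

```python
import torch
from torch import nn

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}  # index 0 = any other residue (e.g. X)


def encode(seqs):
    """Flatten sequences into one token tensor plus per-sequence start offsets."""
    tokens, offsets = [], []
    for seq in seqs:
        offsets.append(len(tokens))
        tokens.extend(VOCAB.get(aa, 0) for aa in seq)
    return torch.tensor(tokens, dtype=torch.long), torch.tensor(offsets, dtype=torch.long)


class BagOfWordsClassifier(nn.Module):
    def __init__(self, embed_dim: int = 32, n_classes: int = 2):
        super().__init__()
        # mode="mean" averages residue embeddings, so residue order is ignored.
        self.embedding = nn.EmbeddingBag(len(VOCAB) + 1, embed_dim, mode="mean")
        self.linear = nn.Linear(embed_dim, n_classes)

    def forward(self, tokens: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return self.linear(self.embedding(tokens, offsets))


torch.manual_seed(0)
model = BagOfWordsClassifier()
tokens, offsets = encode(["MKT", "MKTAAG"])
logits = model(tokens, offsets)
print(logits.shape)  # torch.Size([2, 2])
```

Mean pooling discards sequence length; since `short.fasta` was length-matched to the truncated selenoproteins, length should carry little signal anyway, so composition is what this baseline tests.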

### Using LSTM
`#TODO`
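A minimal sketch of an LSTM classifier, assuming residues are embedded, padded into a batch, and packed so the LSTM skips padding; the final hidden state of the last layer is fed to a linear layer. The vocabulary, class name, and sizes are hypothetical.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}  # 0 = padding, 1 = other residue


def encode_padded(seqs):
    """Return a (batch, max_len) padded token tensor and the true lengths."""
    tensors = [torch.tensor([VOCAB.get(aa, 1) for aa in s], dtype=torch.long) for s in seqs]
    lengths = torch.tensor([len(t) for t in tensors], dtype=torch.long)
    return pad_sequence(tensors, batch_first=True, padding_value=0), lengths


class LSTMClassifier(nn.Module):
    def __init__(self, embed_dim: int = 32, hidden_dim: int = 64, n_classes: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(len(VOCAB) + 2, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, n_classes)

    def forward(self, tokens: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        # Packing lets the LSTM stop at each sequence's true end instead of reading padding.
        packed = pack_padded_sequence(
            self.embedding(tokens), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)  # h_n: (num_layers, batch, hidden_dim), original order
        return self.linear(h_n[-1])


torch.manual_seed(0)
model = LSTMClassifier().eval()
tokens, lengths = encode_padded(["MKT", "MKTAAGG"])
logits = model(tokens, lengths)
print(logits.shape)  # torch.Size([2, 2])
```

Unlike the bag-of-words baseline, this model sees residue order, and the final hidden state summarizes the C-terminal end, which is where a truncated selenoprotein would differ.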

In [None]:
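from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
# Pass a flat list of sequences to tokenize them as one padded batch;
# a nested list like [[a, b]] would be interpreted as a single (text, text_pair) input.
encoding = tokenizer(['MPSMSRRQFLKVTGTTLVGSSLALMGFAPGIALAEVRQYKLTRATETRNTCTYCSVACGI', 'FAPGIALAEVRQYKLTRATETRNTCTYCSVACGI'], return_tensors='pt', padding=True, truncation=True)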

In [None]:
encoding