<a href="https://colab.research.google.com/github/faezesarlakifar/Protein-toxicity-prediction/blob/main/Protein_toxixcity_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Protein Toxicity Prediction using ProtT5 Embeddings and LightGBM

This notebook implements a machine learning pipeline for predicting protein toxicity from ProtT5 embeddings with a LightGBM classifier.

## Overview
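- Represents each protein sequence by its per-protein ProtT5 embedding
- Selects features by Fisher score (computed with scikit-learn's ANOVA `f_classif`)
- Trains a LightGBM classifier with grid-search hyperparameter optimization
- Evaluates the model on a held-out test set and with cross-validation (accuracy, precision, recall, F1, AUC, MCC, sensitivity, specificity)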

## Requirements
- Python 3.x
- PyTorch
- Transformers
- SentencePiece (needed by the ProtT5 tokenizer)
- LightGBM
- scikit-learn
- pandas
- numpy
- h5py

## Data
- Input: FASTA files of protein sequences (`positive.fasta` for toxic, `negative.fasta` for non-toxic)
- Labels: binary, 1 = toxic (positive), 0 = non-toxic (negative)


The handy `prott5_embedder.py` script from the [ProtTrans repository](https://github.com/agemagician/ProtTrans) is used to extract ProtT5 embeddings from the protein sequences in FASTA format. ☕



## Data Preparation and Embedding Extraction

In [None]:
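# @title Extract ProtT5 embeddings from positive data in FASTA format
# Mount Google Drive (the FASTA files live there) and clone ProtTrans to get prott5_embedder.py
from google.colab import drive
drive.mount('/content/drive')
!git clone https://github.com/agemagician/ProtTrans.git

input_path = '/content/drive/MyDrive/toxicity-prediction/fasta-files/'
input_file = input_path + "positive.fasta"

!mkdir -p /train_positive
# Residue-level embeddings (one vector per residue), then per-protein embeddings (one vector per sequence)
!python ProtTrans/Embedding/prott5_embedder.py --input $input_file --output /train_positive/residue_embeddings.h5
!python ProtTrans/Embedding/prott5_embedder.py --input $input_file --output /train_positive/protein_embeddings.h5 --per_protein 1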

In [None]:
# @title Extract ProtT5 embeddings from negative data in FASTA format

input_path = '/content/drive/MyDrive/toxicity-prediction/fasta-files/'
input_file = input_path + "negative.fasta"

!mkdir -p /train_negative
# Residue-level embeddings (one vector per residue), then per-protein embeddings (one vector per sequence)
!python ProtTrans/Embedding/prott5_embedder.py --input $input_file --output /train_negative/residue_embeddings.h5
!python ProtTrans/Embedding/prott5_embedder.py --input $input_file --output /train_negative/protein_embeddings.h5 --per_protein 1
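Each FASTA file is embedded twice. Without `--per_protein`, the embedder stores one 1024-dimensional ProtT5 vector per residue (an L × 1024 array for a sequence of length L); with `--per_protein 1`, it stores a single 1024-dimensional vector per protein, which, as far as we can tell from the ProtTrans embedder, is the mean over the residue vectors. That fixed length is what lets the per-protein vectors be stacked into the feature matrix used below. A minimal NumPy sketch of the mean pooling, with random values standing in for real embeddings:

```python
import numpy as np

# Hypothetical residue-level embedding for one protein of length 7:
# one 1024-dimensional vector per residue (random stand-in values).
rng = np.random.default_rng(0)
residue_emb = rng.normal(size=(7, 1024)).astype(np.float32)

# Per-protein embedding: average over the residue axis, giving a
# fixed-length vector regardless of sequence length.
protein_emb = residue_emb.mean(axis=0)

print(residue_emb.shape, protein_emb.shape)  # (7, 1024) (1024,)
```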

## Environment Setup

In [None]:
#@title Install requirements. { display-mode: "form" }
# Install requirements
!pip install torch transformers sentencepiece h5py lightgbm
import warnings
warnings.filterwarnings('ignore')

In [None]:
import h5py
import lightgbm
import numpy as np
import pandas as pd
import torch

from sklearn.feature_selection import f_classif
from sklearn.metrics import (accuracy_score, make_scorer, precision_recall_fscore_support,
                             recall_score, roc_auc_score)
from sklearn.model_selection import GridSearchCV, cross_validate, train_test_split
from sklearn.preprocessing import MinMaxScaler

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


## Data Loading and Preprocessing

In [None]:
# Load per-protein ProtT5 embeddings from an HDF5 file (one dataset per protein ID)
# into a DataFrame, attaching the same class label to every protein in the file.

def load_T5_embedding(h5_path, toxin_label):
  embedding = []
  protein_id = []

  # Context manager closes the file once all embeddings are read into memory
  with h5py.File(h5_path, 'r') as emb_file:
    for pid in emb_file.keys():
      embedding.append(torch.from_numpy(emb_file[pid][:]))
      protein_id.append(pid)

  columns = ['protein_id', 'embedding']
  df = pd.DataFrame(list(zip(protein_id, embedding)), columns=columns)
  df['toxin_label'] = toxin_label

  return df

In [None]:
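# Per-protein embedding files (one dataset per protein). If you re-run the extraction cells above,
# point these to /train_positive/protein_embeddings.h5 and /train_negative/protein_embeddings.h5.
T5_POSITIVE_PATH = '/content/drive/MyDrive/university/8th_Semester_Spring2023/Article/final/embedding_data/alternate/positive_alternate_emb_ProtT5.h5'
T5_NEGATIVE_PATH = '/content/drive/MyDrive/university/8th_Semester_Spring2023/Article/final/embedding_data/alternate/negative_alternate_emb_ProtT5.h5'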

positive_t5_df = load_T5_embedding(T5_POSITIVE_PATH, 1)
negative_t5_df = load_T5_embedding(T5_NEGATIVE_PATH, 0)
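`load_T5_embedding` wraps each vector in a PyTorch tensor, which is converted back to NumPy a few cells later. A hypothetical NumPy-only alternative, `load_embedding_matrix` (not part of the original notebook), builds the feature matrix and label vector directly. It accepts any mapping from protein ID to array, so it should work with an open `h5py.File` (e.g. inside `with h5py.File(T5_POSITIVE_PATH, 'r') as f:`) or, as in this sketch, with a plain dict of random vectors:

```python
import numpy as np

def load_embedding_matrix(store, label):
    """Hypothetical helper: stack per-protein embeddings into a 2-D array.

    `store` is any mapping from protein ID to a 1-D array-like, e.g. an
    open h5py.File with one dataset per protein, or a plain dict.
    Returns (ids, X, y) with X of shape (n_proteins, embedding_dim).
    """
    ids = list(store.keys())
    X = np.vstack([np.asarray(store[pid][:]) for pid in ids])
    y = np.full(len(ids), label, dtype=int)
    return ids, X, y

# Usage with dicts standing in for the HDF5 files (random values).
rng = np.random.default_rng(0)
fake_pos = {f"P_{i}": rng.random(1024) for i in range(3)}
fake_neg = {f"N_{i}": rng.random(1024) for i in range(2)}

pos_ids, X_pos, y_pos = load_embedding_matrix(fake_pos, 1)
neg_ids, X_neg, y_neg = load_embedding_matrix(fake_neg, 0)
X = np.vstack([X_pos, X_neg])
y = np.concatenate([y_pos, y_neg])
print(X.shape, y)  # (5, 1024) [1 1 1 0 0]
```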

## Feature Preprocessing

In [None]:
t5_df = pd.concat([positive_t5_df, negative_t5_df]).set_index('protein_id')

In [None]:
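# Convert tensors to NumPy and min-max scale every embedding dimension to [0, 1].
# Note: the scaler is fit on all proteins, before the train/test split below.
emb_np = t5_df['embedding'].apply(lambda x: x.numpy()).tolist()

scaler = MinMaxScaler()
emb_normalized = scaler.fit_transform(emb_np)

t5_df['embedding'] = list(emb_normalized)
t5_df.head()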

| protein_id | embedding | toxin_label |
|---|---|---|
| P_1 | [0.5867490350301072, 0.5632674746817872, 0.326... | 1 |
| P_10 | [0.3447602492987731, 0.6437145213066595, 0.379... | 1 |
| P_100 | [0.4865383439899823, 0.508863055280562, 0.6732... | 1 |
| P_1000 | [0.3562618971742902, 0.6768918506364869, 0.492... | 1 |
| P_1001 | [0.5892896446339582, 0.4080563325747942, 0.590... | 1 |


In [None]:
X_train, X_test, y_train, y_test = train_test_split(t5_df['embedding'],
                                                    t5_df['toxin_label'],
                                                    test_size=0.2,
                                                    random_state=42)
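The scaling cell above fits `MinMaxScaler` on all proteins before this split, so test-set values contribute to the per-feature minima and maxima. A leakage-free variant fits the scaler on the training split and only applies it to the test split. The practical impact here is probably small: the ANOVA F-statistic used for feature selection is unchanged by per-feature min-max scaling, and tree-based models such as LightGBM are largely insensitive to monotonic rescaling of a feature. A sketch on synthetic data (not the notebook's embeddings):

```python
import numpy as np
from sklearn.feature_selection import f_classif
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Synthetic stand-in for the embedding matrix and labels.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 20))
y = rng.integers(0, 2, size=200)
X[y == 1, :5] += 1.0  # make the first five features informative

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit the scaler on the training split only, then reuse it on the test split.
scaler = MinMaxScaler().fit(X_tr)
X_tr_scaled = scaler.transform(X_tr)
X_te_scaled = scaler.transform(X_te)  # test values may fall slightly outside [0, 1]

# F-statistics are unchanged by a per-feature positive affine map such as
# min-max scaling, so the feature ranking does not depend on it.
f_raw, _ = f_classif(X_tr, y_tr)
f_scaled, _ = f_classif(X_tr_scaled, y_tr)
print(np.allclose(f_raw, f_scaled))  # True
```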

In [None]:
# Feature selection: rank features by their ANOVA F-statistic (f_classif), which orders
# features the same way as the Fisher score, and keep the top k. Only training rows are used.

X_train = np.vstack(X_train.values)
X_test = np.vstack(X_test.values)

fisher_scores, _ = f_classif(X_train, y_train)

ranked_features = sorted(range(len(fisher_scores)), key=lambda i: fisher_scores[i], reverse=True)

# Number of features to select after Fisher score analysis
k = 512
selected_features = ranked_features[:k]

X_train_selected = X_train[:, selected_features]
X_test_selected = X_test[:, selected_features]
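Strictly, `f_classif` returns ANOVA F-statistics rather than Fisher scores. For feature j, the Fisher score is Σ_k n_k (μ_kj − μ_j)² / Σ_k n_k σ_kj², where n_k, μ_kj and σ_kj² are the size, mean and population variance of class k. The F-statistic divides the same between-class and within-class sums by K − 1 and N − K degrees of freedom, so F = Fisher × (N − K)/(K − 1) for N samples and K classes, and both rank features identically. The sketch below checks this numerically with a `fisher_score` helper written for illustration, and shows that scikit-learn's `SelectKBest(f_classif, k=...)` keeps the same top-k set as the manual ranking (it preserves the original column order rather than sorting by score):

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_classif

def fisher_score(X, y):
    """Fisher score per feature: between-class scatter / within-class scatter."""
    overall_mean = X.mean(axis=0)
    between = np.zeros(X.shape[1])
    within = np.zeros(X.shape[1])
    for c in np.unique(y):
        Xc = X[y == c]
        n_c = Xc.shape[0]
        between += n_c * (Xc.mean(axis=0) - overall_mean) ** 2
        within += n_c * Xc.var(axis=0)  # population variance (ddof=0)
    return between / within

# Synthetic data with a graded signal in the first 8 features.
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 40))
y = rng.integers(0, 2, size=300)
X[y == 1, :8] += np.linspace(0.2, 1.6, 8)

n_samples, n_classes = X.shape[0], 2
fisher = fisher_score(X, y)
f_stat, _ = f_classif(X, y)

# The two scores differ only by the constant (N - K) / (K - 1).
print(np.allclose(f_stat, fisher * (n_samples - n_classes) / (n_classes - 1)))  # True

# Manual top-k ranking (as in the cell above) vs. SelectKBest.
k = 8
top_k_manual = np.argsort(f_stat)[::-1][:k]
selector = SelectKBest(f_classif, k=k).fit(X, y)
print(set(top_k_manual.tolist()) == set(selector.get_support(indices=True).tolist()))  # True
```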

## Model Training

In [None]:
# Initialize LightGBM classifier
# Define hyperparameter grid for optimization
# Perform grid search with cross-validation

lgbm = lightgbm.LGBMClassifier()

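# Hyperparameter ranges for grid search.
# Note: LightGBM only applies bagging_fraction when bagging_freq > 0. bagging_freq keeps its
# default of 0 here, so the two bagging_fraction values train identical models.
lgbm_params = {
    'num_leaves': [100, 200],
    'n_estimators': [50, 200],
    'min_data_in_leaf': range(10, 51, 20),
    'max_depth': [100, 150],
    'learning_rate': [0.01, 0.1],
    'bagging_fraction': [0.5, 0.7]
}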

lgbm_clf = GridSearchCV(
    estimator=lgbm,
    param_grid=lgbm_params,
    cv=5,
    n_jobs=5,
    verbose=1
)

lgbm_clf.fit(X_train_selected, y_train)

Fitting 5 folds for each of 96 candidates, totalling 480 fits
[LightGBM] [Info] Number of positive: 6556, number of negative: 6604
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.086029 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 130560
[LightGBM] [Info] Number of data points in the train set: 13160, number of used features: 512
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.498176 -> initscore=-0.007295
[LightGBM] [Info] Start training from score -0.007295
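The log line `Fitting 5 folds for each of 96 candidates, totalling 480 fits` follows from the grid: 2 × 2 × 3 × 2 × 2 × 2 = 96 combinations (`range(10, 51, 20)` yields 10, 30 and 50), each evaluated on 5 folds. A quick check with scikit-learn's `ParameterGrid`, which enumerates the same combinations `GridSearchCV` tries:

```python
from sklearn.model_selection import ParameterGrid

# Same grid as the GridSearchCV cell above.
lgbm_params = {
    'num_leaves': [100, 200],
    'n_estimators': [50, 200],
    'min_data_in_leaf': range(10, 51, 20),
    'max_depth': [100, 150],
    'learning_rate': [0.01, 0.1],
    'bagging_fraction': [0.5, 0.7],
}

n_candidates = len(ParameterGrid(lgbm_params))
n_folds = 5
print(n_candidates, n_candidates * n_folds)  # 96 480
```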


In [None]:
lgbm_clf.best_params_

{'bagging_fraction': 0.5,
 'learning_rate': 0.1,
 'max_depth': 100,
 'min_data_in_leaf': 50,
 'n_estimators': 200,
 'num_leaves': 100}

In [None]:
# Best hyperparameters found by the grid search (same values as lgbm_clf.best_params_)
lgbm_params = {
    'num_leaves': 100,
    'n_estimators': 200,
    'min_data_in_leaf': 50,
    'max_depth': 100,
    'learning_rate': 0.1,
    'bagging_fraction': 0.5
}

In [None]:
# Refit on the full training split with the selected hyperparameters
lgbm_clf_alternate = lightgbm.LGBMClassifier(**lgbm_params)
lgbm_clf_alternate.fit(X_train_selected, y_train)

In [None]:
y_pred_lgbm = lgbm_clf_alternate.predict(X_test_selected)
y_pred_proba = lgbm_clf_alternate.predict_proba(X_test_selected)



## Model Evaluation

In [None]:
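# Calculate and display performance metrics on the held-out test set.
# precision_recall_fscore_support returns one value per class, ordered [non-toxic (0), toxic (1)].

accuracy = accuracy_score(y_test, y_pred_lgbm)
precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred_lgbm)
auc = roc_auc_score(y_test, y_pred_proba[:, 1])  # score = predicted probability of the toxic class
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)
print("AUC", auc)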

Accuracy: 0.8062418725617685
Precision: [0.77806122 0.83554377]
Recall: [0.83106267 0.78358209]
F1 Score: [0.80368906 0.80872914]
AUC 0.8933127279135658


In [None]:
from sklearn.metrics import matthews_corrcoef
matthews_corrcoef(y_test, y_pred_lgbm)

0.6141246554071792
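Sensitivity and specificity, used in the cross-validation cells below, are the recalls of the toxic (1) and non-toxic (0) classes. In the test-set output above, `Recall` is printed per class as [class 0, class 1], so specificity ≈ 0.831 and sensitivity ≈ 0.784. The sketch below derives these and the MCC from a confusion matrix, using small made-up label vectors, and prints them next to scikit-learn's `recall_score` and `matthews_corrcoef` for comparison:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, matthews_corrcoef, recall_score

# Small hypothetical label vectors (1 = toxic, 0 = non-toxic).
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 1])
y_pred = np.array([1, 1, 0, 1, 0, 0, 1, 0, 0, 1])

# With labels=[0, 1], ravel() yields TN, FP, FN, TP in that order.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

sensitivity = tp / (tp + fn)  # recall of the toxic class
specificity = tn / (tn + fp)  # recall of the non-toxic class
mcc = (tp * tn - fp * fn) / np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))

print(sensitivity, recall_score(y_true, y_pred, pos_label=1))
print(specificity, recall_score(y_true, y_pred, pos_label=0))
print(mcc, matthews_corrcoef(y_true, y_pred))
```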

In [None]:
# Define scoring metrics for cross-validation:
# sensitivity = recall of the toxic class (1), specificity = recall of the non-toxic class (0)
sensitivity = make_scorer(recall_score, average='binary', pos_label=1)
specificity = make_scorer(recall_score, average='binary', pos_label=0)

scoring = {'accuracy': 'accuracy', 'sensitivity': sensitivity, 'specificity': specificity,
           'AUC': 'roc_auc', 'MCC': 'matthews_corrcoef'}
# 5-fold cross-validation (scikit-learn's default) on the training split
scores = cross_validate(lgbm_clf_alternate, X_train_selected, y_train, scoring=scoring)

In [None]:
# Mean cross-validation scores across folds
print('accuracy: ' + str(scores['test_accuracy'].mean()))
print('sensitivity: ' + str(scores['test_sensitivity'].mean()))
print('specificity: ' + str(scores['test_specificity'].mean()))
print('AUC: ' + str(scores['test_AUC'].mean()))
print('MCC: ' + str(scores['test_MCC'].mean()))

accuracy: 0.8148994994835942
sensitivity: 0.7778421921139482
specificity: 0.850993486684805
AUC: 0.8945853295803825
MCC: 0.6311112916426561
