In [1]:
#Import necessary libraries
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from sklearn.model_selection import StratifiedShuffleSplit
from transformers import DataCollatorWithPadding,DataCollatorForSeq2Seq
from transformers import AutoTokenizer, GPT2LMHeadModel,TrainingArguments, Trainer,GPT2Config,EarlyStoppingCallback
from sklearn.metrics import average_precision_score,matthews_corrcoef,f1_score, precision_score, recall_score, balanced_accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Expose only GPU 0 to CUDA; set the value to '-1' instead to hide every GPU and run on CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
# Load the ProtGPT2 tokenizer and register the start, end and padding markers used by the prompts.
tokenizer = AutoTokenizer.from_pretrained(
    'nferruz/ProtGPT2',
    bos_token='<startoftext>',
    eos_token='<endoftext>',
    pad_token='<PAD>',
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
#Add custom tokens: register the two class words as extra vocabulary entries.
#The value shown as the cell output is the call's return value.
tokenizer.add_tokens(['POSITIVE','NEGATIVE'])

2

In [5]:
#Map positive/negative labels and prepare prompt for inference
class SequenceClassificationDataset(Dataset):
    """Dataset that turns sequence fragments into tokenized inference prompts.

    Item ``idx`` is built from three aligned fragments and rendered as
    ``<startoftext>{first_ten} {center} {last_ten}`` followed by a newline,
    then passed through ``tokenizer`` with ``return_tensors='pt'``.

    Args:
        first_ten: Indexable collection with the leading fragment of each sequence.
        center: Indexable collection with the central residue of each sequence.
        last_ten: Indexable collection with the trailing fragment of each sequence.
        labels: Indexable collection of labels, returned unchanged under ``'label'``.
        tokenizer: Callable ``tokenizer(text, return_tensors='pt')`` returning a
            mapping with ``'input_ids'`` and ``'attention_mask'`` tensors whose
            first dimension is the batch dimension.
    """

    def __init__(self,first_ten, center, last_ten, labels, tokenizer):
        self.first_ten = first_ten
        self.center = center
        self.last_ten = last_ten
        self.labels = labels
        self.tokenizer = tokenizer
        # Neither attribute is read by the methods of this class; both are kept
        # so that any external code accessing them keeps working.
        self.map_label={1:'POSITIVE',0:'NEGATIVE'}
        self.dtype='Train'

    def __len__(self):
        """Return the number of examples (length of ``first_ten``)."""
        return len(self.first_ten)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``label`` for example ``idx``."""
        prep_txt= f'<startoftext>{self.first_ten[idx]} {self.center[idx]} {self.last_ten[idx]}\n'
        encoding = self.tokenizer(prep_txt,return_tensors='pt')
        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse a one-token sequence into a 0-d tensor that cannot be padded.
        return  {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': self.labels[idx]
        }

In [6]:
#Load benchmark dataset (the following cells read its 'Seq' and 'Label' columns)
data=pd.read_csv('benchmark.csv')

In [7]:
# Remove embedded line breaks; regex=False keeps the match literal on every pandas version.
data['Seq']=data['Seq'].str.replace('\n','',regex=False)

In [8]:
# Remove '-' characters; regex=False keeps the match literal on every pandas version.
data['Seq']=data['Seq'].str.replace('-','',regex=False)

In [9]:
# Function to extract parts of the sequence
def extract_seq_parts(seq):
    """Split a sequence into its leading window, middle residue and trailing window.

    Returns a ``(first_ten, center, last_ten)`` tuple: the first ten characters,
    the character at index ``len(seq) // 2`` when the length is odd (an empty
    string when the length is even), and the last ten characters.
    """
    half, is_odd = divmod(len(seq), 2)
    if is_odd:
        middle = seq[half]
    else:
        # Even-length sequences have no single middle character.
        middle = ''
    return seq[:10], middle, seq[-10:]

# Split every sequence once and store the three parts as new columns.
# Building one DataFrame from the list of tuples avoids constructing a
# pd.Series per row, and the explicit columns keep an empty frame working.
data[['First-Ten', 'Center', 'Last-Ten']] = pd.DataFrame(
    data['Seq'].map(extract_seq_parts).tolist(),
    index=data.index,
    columns=['First-Ten', 'Center', 'Last-Ten'],
)

In [10]:
# Display the frame together with the three derived columns
data

Unnamed: 0,Seq,Label,First-Ten,Center,Last-Ten
0,ICCDILDVLDKHLIPAANTGE,0,ICCDILDVLD,K,HLIPAANTGE
1,AWRVISSIEQKTDTSDKKLQL,0,AWRVISSIEQ,K,TDTSDKKLQL
2,LIANATNPESKVFYLKMKGDY,0,LIANATNPES,K,VFYLKMKGDY
3,LAEVACGDDRKQTIDNSQGAY,0,LAEVACGDDR,K,QTIDNSQGAY
4,SWRVVSSIEQKTEGAEKKQQM,0,SWRVVSSIEQ,K,TEGAEKKQQM
...,...,...,...,...,...
31177,LDDMTKNDPFKARVSSGYVPP,1,LDDMTKNDPF,K,ARVSSGYVPP
31178,AGGTAPLPPWKSPSSSQPLPQ,1,AGGTAPLPPW,K,SPSSSQPLPQ
31179,PAPKFSPVTPKFTPVASKFSP,1,PAPKFSPVTP,K,FTPVASKFSP
31180,VTPKFTPVASKFSPGAPGGSG,1,VTPKFTPVAS,K,FSPGAPGGSG


In [11]:
#Keep only the sequences whose central amino acid is 'R' (change the letter to evaluate another residue)
data = data[data['Center']=='R']

In [12]:
# Display the rows that remain after filtering on the central residue
data

Unnamed: 0,Seq,Label,First-Ten,Center,Last-Ten
7,TQTWAGSHSMRYFFTSVSRPG,0,TQTWAGSHSM,R,YFFTSVSRPG
8,SDAASQRMEPRAPWIEQEGPE,0,SDAASQRMEP,R,APWIEQEGPE
9,AVVVPSGQEQRYTCHVQHEGL,0,AVVVPSGQEQ,R,YTCHVQHEGL
10,TGAVVAAVMWRRKSSDRKGGS,0,TGAVVAAVMW,R,RKSSDRKGGS
11,GAVVAAVMWRRKSSDRKGGSY,0,GAVVAAVMWR,R,KSSDRKGGSY
...,...,...,...,...,...
31140,PQLRSPRLPFRGNSYPGAAEG,1,PQLRSPRLPF,R,GNSYPGAAEG
31141,LPRFYPAGRARGIPHRFAGHE,1,LPRFYPAGRA,R,GIPHRFAGHE
31146,YYSPYALYGQRLASASALGYQ,1,YYSPYALYGQ,R,LASASALGYQ
31149,YTCEECGKAFRQSAILYVHRR,1,YTCEECGKAF,R,QSAILYVHRR


In [13]:
# Number of rows per label value in the filtered benchmark
data['Label'].value_counts()

0    4683
1     552
Name: Label, dtype: int64

In [14]:
# Renumber the rows from 0 after the filter and discard the old index
data = data.reset_index(drop=True)

In [15]:
# Pull the prompt fragments and the labels out of the frame, each as a Series
# whose index is renumbered from 0.
first_test_texts, center_test_texts, last_test_texts, test_labels = (
    data[column].reset_index(drop=True)
    for column in ('First-Ten', 'Center', 'Last-Ten', 'Label')
)

In [16]:
# Wrap the fragments and labels so that each item becomes a tokenized prompt
test_dataset=SequenceClassificationDataset(first_test_texts,center_test_texts,last_test_texts,test_labels,tokenizer)

In [17]:
# One example per batch: get_score converts each batch's label with int() and
# decodes only the first sequence of each batch, so batch_size must stay 1.
test_data_loder= DataLoader(test_dataset,collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer,padding=True),batch_size=1)

In [18]:
# Integer label -> class word; get_score uses it to name the true label of each example
map_label={0:'NEGATIVE',1:'POSITIVE'}

In [19]:
def get_score(mdl_path):
    """Evaluate one fine-tuned checkpoint on the benchmark prompts.

    Loads the GPT-2 LM-head checkpoint at ``mdl_path``, greedily generates up to
    two new tokens for every batch of the module-level ``test_data_loder`` and
    takes the text after the last newline of the decoded output as the
    predicted class word.

    Args:
        mdl_path: Checkpoint location accepted by ``from_pretrained``.

    Returns:
        Tuple ``(f1, mcc, precision, recall, average_precision)`` computed from
        the binary true labels and the binary predictions.

    Raises:
        ValueError: If ``test_data_loder`` yields no batches.
        KeyError: If a batch label is not a key of ``map_label``.
    """
    # Run on the GPU when one is visible, otherwise on the CPU. An unconditional
    # .cuda() crashes when CUDA_VISIBLE_DEVICES is set to '-1'.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_config = GPT2Config.from_pretrained(mdl_path)
    model = GPT2LMHeadModel.from_pretrained(mdl_path,config=model_config,ignore_mismatched_sizes=True)
    model = model.to(device).eval()

    actual = []
    pred = []
    for batch in test_data_loder:
        # int() requires a single-element label, i.e. a loader with batch_size=1.
        true_name = map_label[int(batch.pop('label'))]
        # Greedy decoding: sampling knobs (top_k/top_p/temperature) have no
        # effect with do_sample=False, and exactly one sequence is returned.
        output_ids = model.generate(
            batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            do_sample=False,
            max_new_tokens=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        predicted_text = tokenizer.decode(output_ids[0],skip_special_tokens=True)
        predicted_name = predicted_text.split('\n')[-1]
        # NOTE(review): anything other than the exact word 'NEGATIVE' (including
        # empty or malformed generations) is counted as a positive prediction.
        actual.append(0 if true_name=='NEGATIVE' else 1)
        pred.append(0 if predicted_name=='NEGATIVE' else 1)

    if not actual:
        raise ValueError('test_data_loder produced no batches; nothing to score')

    actual = np.asarray(actual)
    pred = np.asarray(pred)
    # NOTE(review): average_precision_score receives hard 0/1 predictions here,
    # not scores or probabilities - confirm this is the intended metric.
    return f1_score(actual,pred),matthews_corrcoef(actual,pred), precision_score(actual,pred), recall_score(actual, pred), average_precision_score(actual,pred)

In [None]:
#Replace the path with the best-performing checkpoint; the result is (f1, mcc, precision, recall, average precision)
get_score('checkpoint-22500/')

# Find the best-performing checkpoint on the benchmark dataset

In [21]:
# NOTE(review): the checkpoint loop shown below never appends to this list; it writes to a CSV file instead
results=[]

In [None]:
#Replace the path with the output directory used during model training
checkpoint_root='/media/8TB_hardisk/results-Prompt2/'
# Sorted so that repeated runs visit the checkpoints in the same order.
for mdl in sorted(os.listdir(checkpoint_root)):
    mdl_path=os.path.join(checkpoint_root,mdl)
    # Only checkpoint directories can be loaded; skip any stray file whose name matches.
    if 'checkpoint' not in mdl or not os.path.isdir(mdl_path):
        continue
    f1,mcc,prc,rec,avg=get_score(mdl_path)
    # Append after every checkpoint so finished scores survive an interrupted run.
    with open('results-Prompt2.csv','a') as f:
        f.write(f'{mdl},{f1},{mcc},{prc},{rec},{avg}\n')