In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from transformers import BertModel, BertTokenizer
import random
import os
from scipy.interpolate import interp1d
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score,roc_curve
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import random
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import torch
import csv
# Compute device used for the models and input tensors: the first CUDA GPU
# when one is available, otherwise the CPU so the notebook still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Model settings: hyper-parameters and backbone selection read by the cells below.
# Mini-batch size passed to every DataLoader.
batch_size=32
# Learning rate handed to the Adam optimizer in the cross-validation cell.
learning_rates=0.000065
# ESM-2 variant loaded by MyModel; it compares against "esm2_6", "esm2_12" and "esm2_30".
model1_name="esm2_12"
# ProtBert variant loaded by MyModel: "protbert_ur100", or "protbert_bdf" (sic) for the
# BFD checkpoint -- that is the exact spelling MyModel compares against.
model2_name="protbert_ur100"

In [None]:
#Dataset definitions
class MyDataset(Dataset):
    """Sequence/label dataset read from a CSV file.

    The file must have a header row (skipped) followed by rows whose first
    column is the sequence string and whose second column is an integer
    label. Rows are shuffled once, deterministically, at load time.

    Attributes:
        sequence: list of raw sequence strings (shuffled order).
        label: list of label strings, parallel to ``sequence``.
        sequence_protbert: ``sequence`` with a space between every character,
            the form passed to the ProtBert tokenizer by the model.
    """

    # Fixed seed so that every run sees the rows in the same shuffled order.
    SHUFFLE_SEED = 42

    def __init__(self, file):
        """Load and shuffle the CSV at path ``file``."""
        self.sequence, self.label = self.read_file(file)
        self.sequence_protbert = self.add_space_between_characters(self.sequence)

    def read_file(self, file_path):
        """Read ``file_path`` and return ``(sequences, labels)`` in shuffled order.

        The first row is treated as a header and discarded. Completely blank
        rows are skipped instead of raising ``IndexError``. A non-blank row
        with fewer than two columns still raises ``IndexError``.
        """
        with open(file_path, 'r', newline='') as csv_file:
            csv_reader = csv.reader(csv_file)
            next(csv_reader, None)  # drop the header row
            data = [row for row in csv_reader
                    if any(field.strip() for field in row)]
        # A private generator yields the same permutation as
        # ``random.seed(42); random.shuffle(data)`` but leaves the global
        # ``random`` state untouched, so building a dataset no longer
        # silently reseeds the rest of the program.
        random.Random(self.SHUFFLE_SEED).shuffle(data)
        sequences = [row[0] for row in data]
        labels = [row[1] for row in data]
        return sequences, labels

    def add_space_between_characters(self, input_list):
        """Return a new list where each string has its characters joined by spaces."""
        return [' '.join(element) for element in input_list]

    def __len__(self):
        """Number of samples."""
        return len(self.sequence)

    def __getitem__(self, index):
        """Return ``(sequence, int label, space-separated sequence)`` for ``index``."""
        sample = self.sequence[index]
        sample_protbert = self.sequence_protbert[index]
        label = int(self.label[index])
        return sample, label, sample_protbert

In [None]:
# Build the training dataset from its CSV file and wrap it in a sequential,
# fixed-batch-size loader.
train_file = 'data/trainCPP.csv'
train_dataset = MyDataset(file=train_file)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)

In [None]:
##Define the fusion models
class MyModel(nn.Module):
    """Fusion classifier combining an ESM-2 encoder with a ProtBert encoder.

    The backbones are chosen by the module-level settings ``model1_name``
    (ESM-2 variant) and ``model2_name`` (ProtBert variant). The ESM-2 pooled
    output is projected to the ProtBert width, the two pooled vectors are
    summed, and a linear layer maps the sum to two class logits.

    Raises:
        ValueError: if ``model1_name`` or ``model2_name`` is not a known key.

    NOTE(review): ``pooler_output`` of the ESM-2 model is used as-is; verify
    that the chosen checkpoint ships trained pooler weights -- TODO confirm.
    """

    # ESM-2 setting -> (HuggingFace checkpoint, embedding width).
    _ESM2_VARIANTS = {
        "esm2_30": ("facebook/esm2_t30_150M_UR50D", 640),
        "esm2_12": ("facebook/esm2_t12_35M_UR50D", 480),
        "esm2_6": ("facebook/esm2_t6_8M_UR50D", 320),
    }
    # ProtBert setting -> HuggingFace checkpoint. "protbert_bdf" is the
    # misspelling this class originally tested for; it is kept as an alias so
    # existing configurations keep working alongside the documented
    # "protbert_bfd".
    _PROTBERT_VARIANTS = {
        "protbert_bfd": "Rostlab/prot_bert_bfd",
        "protbert_bdf": "Rostlab/prot_bert_bfd",
        "protbert_ur100": "Rostlab/prot_bert",
    }
    # Width that the ESM-2 vector is projected to before it is added to the
    # ProtBert pooled output.
    _FUSION_DIM = 1024

    def __init__(self,):
        super(MyModel, self).__init__()
        # Fail early with a clear message instead of an AttributeError deep
        # inside forward() when a setting is misspelled.
        if model1_name not in self._ESM2_VARIANTS:
            raise ValueError(
                f"Unknown model1_name {model1_name!r}; "
                f"expected one of {sorted(self._ESM2_VARIANTS)}"
            )
        if model2_name not in self._PROTBERT_VARIANTS:
            raise ValueError(
                f"Unknown model2_name {model2_name!r}; "
                f"expected one of {sorted(self._PROTBERT_VARIANTS)}"
            )

        esm_checkpoint, self.layer = self._ESM2_VARIANTS[model1_name]
        self.model = AutoModel.from_pretrained(esm_checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_checkpoint)

        protbert_checkpoint = self._PROTBERT_VARIANTS[model2_name]
        self.tokenizer_pro = BertTokenizer.from_pretrained(protbert_checkpoint, do_lower_case=False)
        self.model_pro = BertModel.from_pretrained(protbert_checkpoint)

        self.dropout = nn.Dropout(0.2)  # defined but not applied in forward()
        # The projection input must follow the selected ESM-2 width; a fixed
        # 480 only matched "esm2_12" and broke "esm2_6" / "esm2_30".
        self.fc_pro = nn.Linear(self.layer, self._FUSION_DIM)
        self.fc1 = nn.Linear(self._FUSION_DIM, 2)
        # The two layers below are not used by forward(); they are kept so the
        # module's attribute set and state_dict keys stay unchanged.
        self.sigmoid = nn.Sigmoid()
        self.fc2 = nn.Linear(20, 2)

    def forward(self, inputs,inputs2):
        """Return class logits of shape ``(batch, 2)``.

        Args:
            inputs: batch of raw sequence strings, tokenized for ESM-2.
            inputs2: the same sequences with a space between characters,
                tokenized for ProtBert.
        """
        # Follow the device the module itself lives on rather than a global,
        # so the model also works after being moved with .to(...).
        target_device = self.fc1.weight.device

        esm_batch = self.tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
        input_ids = esm_batch["input_ids"].to(target_device)
        attention_mask = esm_batch["attention_mask"].to(target_device)
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        encoded_input = self.tokenizer_pro(
            inputs2, padding=True, truncation=True, return_tensors='pt'
        ).to(target_device)
        outputs_pro = self.model_pro(**encoded_input)

        # Project the ESM-2 pooled vector to the fusion width, then fuse the
        # two encoders by element-wise addition.
        pooler_output1 = self.fc_pro(outputs.pooler_output)
        pooler_output2 = outputs_pro.pooler_output
        x = pooler_output1 + pooler_output2
        x = self.fc1(x)
        return x

In [None]:
#Model loading and setting
# Use the first CUDA GPU when available and fall back to the CPU otherwise,
# instead of unconditionally requesting "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every random number generator in use so runs are repeatable.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# The cross-validation cell builds a fresh model for every fold; this
# instance is still created here so that the random-number stream consumed
# before fold 1 (and therefore the fold-1 initialisation) is unchanged.
model = MyModel()
model.to(device)#Model loading
criterion = nn.CrossEntropyLoss()

# Bookkeeping shared with the following cells.
loss_all=99999
best_auc=0
all_fpr = []   # per-fold FPR array of the best epoch
all_tpr = []   # per-fold TPR array of the best epoch
all_aucs = []  # per-fold best validation AUC

In [None]:
# Five-fold cross-validation
# For every fold a fresh model is trained for `num_epochs` epochs; after each
# epoch it is scored on the held-out part and the ROC curve of the epoch with
# the highest validation AUC is kept for that fold.
num_epochs = 50

# Start from empty result lists so that re-running this cell does not append
# a second set of folds to the results of a previous run.
all_fpr = []
all_tpr = []
all_aucs = []

kf = KFold(n_splits=5, shuffle=False)
for fold, (train_indices, valid_indices) in enumerate(kf.split(train_dataset)):
    best_auc=0
    best_acc=0
    best_fpr=np.array([])
    best_tpr=np.array([])
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(valid_indices)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    valid_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler)
    model = MyModel()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rates)
    item=0  # number of epochs run so far in this fold
    for epoch in range(num_epochs):
        item=item+1

        # ---- training pass ----
        model.train()  # once per epoch is enough; eval() below switches it off again
        for batch_data, batch_labels, batch_data_protbert in train_dataloader:
            batch_labels = batch_labels.to(device)
            outputs = model(batch_data,batch_data_protbert)
            loss = criterion(outputs, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ---- validation pass ----
        all_labels = []
        all_scores = []
        model.eval()
        # No gradients are needed for scoring; without no_grad() every
        # validation batch builds an autograd graph and wastes memory.
        with torch.no_grad():
            for batch_data, batch_labels, batch_data_protbert in valid_dataloader:
                outputs = model(batch_data,batch_data_protbert)
                probabilities = nn.functional.softmax(outputs, dim=1)
                scores = probabilities[:, 1]  # probability of class 1
                all_labels.extend(batch_labels.tolist())
                all_scores.extend(scores.tolist())

        fpr, tpr, _ = roc_curve(all_labels, all_scores)
        auc = roc_auc_score(all_labels, all_scores)
        predicted_labels = (np.array(all_scores) >= 0.5).astype(int)
        acc = np.mean(predicted_labels == np.array(all_labels))
        if auc>best_auc:
            best_fpr=fpr
            best_tpr=tpr
            best_auc=auc
            best_acc=acc  # accuracy at the best-AUC epoch (was never updated before)
    all_fpr.append(best_fpr)
    all_tpr.append(best_tpr)
    all_aucs.append(best_auc)
    print(f"Fold {fold + 1}: AUC = {best_auc:.6f}")

In [None]:
#Drawing roc diagrams
# Plots the ROC curve of every fold plus their mean curve, saves the figure
# and prints the per-fold and mean AUC.
roc_image_path = '../../autodl-tmp/5fold_probertUR100+esm2_12/5_fold_roc_8.png'  # roc image save path

fig, ax = plt.subplots(figsize=(8, 6))

# Common FPR grid for averaging. Each fold's TPR is interpolated onto this
# grid so the mean is a vertical average at identical false-positive rates.
# Resampling fpr and tpr by array index instead would average points that
# belong to unrelated thresholds. The per-fold lists are left unmodified, so
# this cell can be re-run safely.
mean_fpr = np.linspace(0, 1, 100)
interp_tprs = []
for i, (fpr, tpr) in enumerate(zip(all_fpr, all_tpr)):
    if len(fpr) == 0:
        # No epoch of this fold produced a curve; nothing to draw.
        continue
    ax.plot(fpr, tpr, linestyle='--', lw=1, label=f'Fold {i + 1} (AUC = {all_aucs[i]:.3f})')
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0  # every ROC curve starts at (0, 0)
    interp_tprs.append(interp_tpr)

if interp_tprs:
    mean_tpr = np.mean(interp_tprs, axis=0)
    mean_tpr[-1] = 1.0  # every ROC curve ends at (1, 1)
    ax.plot(mean_fpr, mean_tpr, color='b', linestyle='-', lw=1.5, label='Mean ROC (AUC = {:.3f})'.format(np.mean(all_aucs)))

# Setting Graphic Properties
ax.set_xlim([-0.05, 1.05])
ax.set_ylim([-0.05, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend(loc="lower right")

# Create the output directory first; savefig does not create missing folders.
roc_image_dir = os.path.dirname(roc_image_path)
if roc_image_dir:
    os.makedirs(roc_image_dir, exist_ok=True)
fig.savefig(roc_image_path, dpi=400)
plt.show()
print("AUC for each fold:", all_aucs)
print("Mean AUC:", np.mean(all_aucs))