In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve, auc, recall_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# Define a function to read selected columns from a TSV file
def read_tsv(filename, inf_ind, skip_1st=False, file_encoding="utf8"):
    """Read selected columns from a tab-separated file.

    Args:
        filename: path to the TSV file.
        inf_ind: column indices to extract from every row, in output order.
        skip_1st: if True, discard the first line (e.g. a header row).
        file_encoding: text encoding used to open the file.

    Returns:
        A list with one entry per non-blank line; each entry is the list of the
        requested column values (strings), in the order given by ``inf_ind``.

    Raises:
        IndexError: if a non-blank line has fewer columns than an index in ``inf_ind``.

    Blank lines (for example a trailing newline at the end of the file) are skipped;
    previously they produced an empty row and raised IndexError for any index > 0.
    """
    extract_inf = []
    with open(filename, "r", encoding=file_encoding) as tsv_f:
        if skip_1st:
            tsv_f.readline()
        for line in tsv_f:
            stripped = line.strip()
            if not stripped:
                continue
            line_list = stripped.split("\t")
            extract_inf.append([line_list[ind] for ind in inf_ind])
    return extract_inf

# Define a function that reads an amino acid feature file and generates a feature dictionary
def get_features(filename, f_num=15):
    """Build an amino-acid -> feature-vector lookup table from a feature file.

    The file is read with ``read_tsv`` (header line skipped). Column 0 is the residue
    key and columns 1-15 are its 15 numeric features. When ``f_num`` exceeds 15, every
    vector is zero-padded on both sides up to ``f_num`` entries; the left side gets the
    smaller half. A padding entry ``"X"`` of all zeros is added with the same length as
    the other vectors.

    Args:
        filename: path to the tab-separated feature file.
        f_num: desired vector length; values below 15 do not truncate, and vectors
            keep their 15 features.

    Returns:
        dict mapping each residue key (and ``"X"``) to a list of numbers.
    """
    base_dim = 15  # number of feature columns present in the file
    f_list = read_tsv(filename, list(range(base_dim + 1)), True)
    left_num = 0
    right_num = 0
    if f_num > base_dim:
        left_num = (f_num - base_dim) // 2
        right_num = f_num - base_dim - left_num
    f_dict = {}
    for f in f_list:
        f_dict[f[0]] = [0] * left_num + [float(x) for x in f[1:]] + [0] * right_num
    # Size "X" from the real vector width. Previously it used f_num, which disagreed
    # with every other vector whenever f_num < 15.
    f_dict["X"] = [0] * (left_num + base_dim + right_num)
    return f_dict

# Define the model class
class AotoY(nn.Module):
    """Multi-instance CNN classifier over amino-acid feature matrices.

    Every sample is a bag of ``ins_num`` instances, and each instance is a
    ``(feature_num, max_len)`` matrix.
    - Four parallel branches, each two Conv1d/BatchNorm1d/ReLU stages with its own
      kernel size, are max-pooled over positions and concatenated.
    - ``fc_1`` projects the result to one score per instance.
    - The ``ins_num`` instance scores of a sample are mapped by ``fc_2`` to two class
      logits.

    Args:
        ins_num: number of instances per sample.
        drop_out: dropout probability.
        n: output channels of the four convolution branches.
        k: kernel sizes of the four convolution branches.
        feature_num: input feature channels per position (default 15).
        max_len: number of positions per instance (default 24).
    """

    def __init__(self, ins_num, drop_out, n=(12, 12, 12, 6), k=(2, 3, 4, 5),
                 feature_num=15, max_len=24):
        super().__init__()

        self.ins_num = ins_num
        self.feature_num = feature_num
        self.max_len = max_len
        # The attribute names conv1..conv4 and the Sequential layout are unchanged, so
        # state_dict keys (and previously saved checkpoints) stay compatible.
        self.conv1 = self._conv_branch(feature_num, n[0], k[0])
        self.conv2 = self._conv_branch(feature_num, n[1], k[1])
        self.conv3 = self._conv_branch(feature_num, n[2], k[2])
        self.conv4 = self._conv_branch(feature_num, n[3], k[3])

        self.fc_1 = nn.Linear(sum(n), 1)
        self.dropout = nn.Dropout(p=drop_out)
        self.fc_2 = nn.Linear(self.ins_num, 2)

    @staticmethod
    def _conv_branch(in_channels, channels, kernel_size):
        """Build one branch: two (Conv1d -> BatchNorm1d -> ReLU) stages, then a global max-pool."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=0),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=0),
            nn.BatchNorm1d(channels),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
        )

    def forward(self, x):
        """Map input of shape ``(batch, ins_num, feature_num, max_len)`` to logits ``(batch, 2)``."""
        # Flatten samples and instances into one batch of (feature_num, max_len) matrices.
        x = x.reshape(-1, self.feature_num, self.max_len)

        out1 = self.conv1(x)
        out2 = self.conv2(x)
        out3 = self.conv3(x)
        out4 = self.conv4(x)
        out = torch.cat([out1, out2, out3, out4], dim=1)
        out = out.view(out.size(0), -1)  # (batch * ins_num, sum(n))
        out = self.fc_1(out)
        out = self.dropout(out)
        out = out.view(-1, self.ins_num)  # regroup the per-instance scores by sample
        out = self.fc_2(out)
        # NOTE(review): dropout is also applied to the output logits; kept as-is so that
        # training behaviour matches previously reported results.
        out = self.dropout(out)
        return out

# Define the input-encoding function
def generate_input(sps, sp_lbs, feature_dict, feature_num, ins_num, max_len):
    """Encode samples of amino-acid sequences as a dense feature array.

    Args:
        sps: list of samples; each sample is a list of rows whose first element is a
            sequence string (for example rows returned by ``read_tsv``).
        sp_lbs: labels, indexed in parallel with ``sps``.
        feature_dict: maps upper-case residue letters, plus the padding letter ``"X"``,
            to feature vectors of length ``feature_num``.
        feature_num: length of each per-residue feature vector.
        ins_num: instance slots per sample; unused slots stay all-zero.
        max_len: padded sequence length; shorter sequences are right-padded with ``"X"``.

    Returns:
        ``(xs, ys)``:
        - ``xs`` has shape ``(len(sps), ins_num, feature_num, max_len)``.
        - ``ys`` is a 1-D array holding the label of each sample.

    Raises:
        ValueError: if a sample has more than ``ins_num`` sequences or a sequence is
            longer than ``max_len``.
        KeyError: if a residue has no entry in ``feature_dict``.
    """
    # Fill in (sample, instance, position, feature) layout, then swap the last two axes.
    xs = np.zeros((len(sps), ins_num, max_len, feature_num))
    ys = []
    for i, sp in enumerate(sps):
        ys.append(sp_lbs[i])
        if len(sp) > ins_num:
            raise ValueError(f"Sample {i} has {len(sp)} sequences but ins_num={ins_num}")
        for j, tcr in enumerate(sp):
            tcr_seq = tcr[0]
            if len(tcr_seq) > max_len:
                raise ValueError(f"Sequence {tcr_seq!r} is longer than max_len={max_len}")
            padded = tcr_seq + "X" * (max_len - len(tcr_seq))
            xs[i, j] = [feature_dict[aa.upper()] for aa in padded]
    xs = xs.swapaxes(2, 3)
    ys = np.array(ys)
    return xs, ys

# Define the data-loading function (reads sample files and derives labels from their names)
def load_data(sample_dir):
    """Load every sample file in ``sample_dir`` and label it from its filename.

    Files whose name contains ``"P"`` are labelled 1 (positive). Otherwise, files whose
    name contains ``"H"`` are labelled 0 (negative). Any other visible file aborts the
    program with exit code 1. Columns 0 and 1 of each file are read, with the header
    line skipped.

    Entries are processed in sorted order, so the sample order is reproducible and the
    seeded KFold split downstream is reproducible as well. Sub-directories and hidden
    entries (for example ``.ipynb_checkpoints`` or ``.DS_Store``) are skipped.

    Returns:
        ``(training_data, training_labels)``: parallel lists of per-file rows and labels.
    """
    training_data = []
    training_labels = []
    for sample_file in sorted(os.listdir(sample_dir)):
        sample_path = os.path.join(sample_dir, sample_file)
        if sample_file.startswith(".") or not os.path.isfile(sample_path):
            continue
        # Classify the name before reading, so a bad name aborts without parsing the file.
        if "P" in sample_file:
            label = 1
        elif "H" in sample_file:
            label = 0
        else:
            print("Wrong sample filename! Please name positive samples with 'P' and negative samples with 'H'.")
            sys.exit(1)
        training_data.append(read_tsv(sample_path, [0, 1], True))
        training_labels.append(label)

    return training_data, training_labels

# Define the evaluation function
def evaluate(model, criterion, test_loader, device="cuda"):
    """Run ``model`` over every batch of ``test_loader`` without tracking gradients.

    The model is switched to eval mode and left in it.

    Returns:
        ``(mean_loss, pred_batches, label_batches)``:
        - ``mean_loss`` is the sum of per-batch losses divided by the number of batches.
        - The two lists hold one NumPy array per batch: model outputs and targets.
    """
    model.eval()
    loss_sum = 0.0
    pred_batches, label_batches = [], []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item()
            pred_batches.append(outputs.cpu().numpy())
            label_batches.append(targets.cpu().numpy())
        mean_loss = loss_sum / len(test_loader)

    return mean_loss, pred_batches, label_batches
    
# Define the training function
from tqdm import tqdm
def train(fold, model, criterion, optimizer, train_loader, valid_loader, epoches=100, device='cuda',
          patience=None, name=None):
    """Train ``model`` with early stopping on the validation loss.

    After each epoch the model is evaluated on ``valid_loader``, and ``EarlyStopping``
    is updated with that loss. ``EarlyStopping`` writes its checkpoint to
    ``../model/AutoY/{name}checkpoint{fold}.pt``. After training ends, whether by early
    stopping or by running all ``epoches``, the checkpoint is loaded back into ``model``.

    Args:
        fold: fold index, used in the checkpoint filename and the progress label.
        model, criterion, optimizer: the usual PyTorch training objects.
        train_loader: batches used for gradient updates.
        valid_loader: batches used for early stopping (must not be the test fold).
        epoches: maximum number of epochs.
        device: device that batches are moved to.
        patience: early-stopping patience; defaults to the module-level ``PATIENCE``.
        name: dataset name for the checkpoint path and progress label; defaults to
            the module-level ``disease_name``.

    Returns:
        ``(epoch_train_losses, epoch_valid_losses)``: per-epoch mean losses.
    """
    if name is None:
        name = disease_name  # notebook-level global, as in the original behaviour
    if patience is None:
        patience = PATIENCE

    model_path = f'../model/AutoY/{name}checkpoint{fold}.pt'  # Save path of the model file
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    early_stopping = EarlyStopping(patience, path=model_path, verbose=False)

    epoch_train_losses = []
    epoch_valid_losses = []
    with tqdm(total=epoches) as t:
        t.set_description(f'{name} - Fold {fold}')
        for _ in range(epoches):
            model.train()
            total_loss = 0.0
            for batch_x, batch_y in train_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                pred = model(batch_x)

                loss = criterion(pred, batch_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            epoch_train_losses.append(avg_loss)
            # Monitor the validation split. The original read the global ``test_loader``
            # here, which leaked the held-out test fold into model selection.
            valid_avg_loss, _, _ = evaluate(model, criterion, valid_loader, device=device)
            epoch_valid_losses.append(valid_avg_loss)
            t.set_postfix(loss=avg_loss, valid_loss=valid_avg_loss)
            t.update(1)
            early_stopping(valid_avg_loss, model)

            if early_stopping.early_stop:
                break

    # Restore the checkpointed weights even when all epochs ran. Previously, a run that
    # never triggered early stopping kept its last-epoch weights. The guard avoids
    # loading a stale file from an earlier run when no epoch was executed.
    if epoch_valid_losses and os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
    return epoch_train_losses, epoch_valid_losses

def sigmoid(x):
    """Return the element-wise logistic function 1 / (1 + exp(-x)) of a scalar or NumPy array."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

# Define a function to compute binary classification metrics
def metrics(all_preds, all_labels, threshold=0.5):
    """Compute binary-classification metrics from two-class logits.

    Args:
        all_preds: array of shape ``(n, 2)`` holding two output logits per sample. The
            positive-class probability is ``sigmoid(logit_1 - logit_0)``, which equals
            the softmax probability of class 1.
        all_labels: array of ``n`` ground-truth labels in {0, 1}.
        threshold: probability above which a sample is predicted positive.

    Returns:
        ``(accuracy, sensitivity, specificity, auc)``:
        - Sensitivity is NaN when the split has no positive labels.
        - Specificity is NaN when it has no negative labels.
        - AUC is NaN when only one class is present (ROC AUC is undefined there).
    """
    all_preds = np.asarray(all_preds)
    all_labels = np.asarray(all_labels)
    all_probs = sigmoid(all_preds[:, 1] - all_preds[:, 0])
    binary_preds = (all_probs > threshold).astype(int)
    # labels=[0, 1] forces a 2x2 matrix even when a class is absent. Without it
    # confusion_matrix returns 1x1 and indexing [1, 1] raises IndexError.
    tn, fp, fn, tp = confusion_matrix(all_labels, binary_preds, labels=[0, 1]).ravel()
    accuracy = accuracy_score(all_labels, binary_preds)
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    # Local name differs from sklearn's ``auc`` function to avoid shadowing it.
    if np.unique(all_labels).size < 2:
        auc_score = float("nan")
    else:
        auc_score = roc_auc_score(all_labels, all_probs)

    return accuracy, sensitivity, specificity, auc_score

In [None]:
# Model parameterization
def init_model():
    """Create a fresh AotoY model: 100 instances per sample, dropout probability 0.5."""
    return AotoY(ins_num=100, drop_out=0.5)

# Introduce an early stop mechanism
sys.path.append('../')
from python_codes.pytorchtools import EarlyStopping

# Read the amino-acid feature (PCA) file
aa_file = "../data/PCA15.txt"
aa_vectors = get_features(aa_file)

# Four autoimmune diseases
disease_list = ["RA", "T1D", "MS", "IAA"]

# 5-fold cross-validation
k_fold = 5
kf = KFold(n_splits=k_fold, shuffle=True, random_state=42)

# Total number of training epochs
NUM_EPOCHES = 3000
# Patience (in epochs) for early stopping
PATIENCE = 300
# Fall back to CPU when CUDA is unavailable, instead of failing on the first .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

# Input encoding sizes: 15 features x 24 positions for each of 100 instances per sample
# (these match the shapes AotoY / init_model expect).
FEATURE_NUM = 15
INS_NUM = 100
MAX_LEN = 24

all_accuracies = []
all_sensitivities = []
all_specificities = []
all_aucs = []
results = []
results_ROC = []


def _make_loader(samples, labels):
    """Encode samples, move them to ``device``, and wrap them in one shuffled full-batch loader."""
    inputs, targets = generate_input(samples, labels, aa_vectors, FEATURE_NUM, INS_NUM, MAX_LEN)
    dataset = Data.TensorDataset(torch.Tensor(inputs).to(device), torch.LongTensor(targets).to(device))
    return Data.DataLoader(dataset, len(dataset), True)


for disease_name in disease_list:
    data_dir = f'../data/{disease_name}'  # disease data directory
    training_data, training_labels = load_data(data_dir)
    print(f"Working on {disease_name} dataset: {len(training_data)} samples")

    all_preds = []
    all_labels = []

    for fold, (train_idx, test_idx) in enumerate(kf.split(training_data)):
        train_data = [training_data[i] for i in train_idx]
        train_labels = [training_labels[i] for i in train_idx]
        test_data = [training_data[i] for i in test_idx]
        test_labels = [training_labels[i] for i in test_idx]

        # Once the test fold is fixed, split the remaining data again into training and
        # validation sets; the validation set is used for early stopping.
        train_data, valid_data, train_labels, valid_labels = train_test_split(
            train_data, train_labels, test_size=0.2, random_state=1234)

        train_loader = _make_loader(train_data, train_labels)
        valid_loader = _make_loader(valid_data, valid_labels)
        test_loader = _make_loader(test_data, test_labels)

        model = init_model().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.0005)

        train(fold, model, criterion, optimizer, train_loader, valid_loader, epoches=NUM_EPOCHES, device=device)
        # Final evaluation on the held-out test fold
        _, preds, labels = evaluate(model, criterion, test_loader, device=device)
        all_preds += preds
        all_labels += labels

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Metrics are computed once over the pooled out-of-fold predictions of all folds.
    # ``auc_score`` is used instead of ``auc`` so the module-level sklearn.metrics.auc
    # function is not overwritten for later cells.
    accuracy, sensitivity, specificity, auc_score = metrics(all_preds, all_labels)
    print(f"Mean Accuracy ({disease_name}): {accuracy:.4f}")
    print(f"Mean Sensitivity ({disease_name}): {sensitivity:.4f}")
    print(f"Mean Specificity ({disease_name}): {specificity:.4f}")
    print(f"Mean AUC ({disease_name}): {auc_score:.4f}")

    results.append({
        'disease': disease_name,
        'accuracy': accuracy,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'auc': auc_score,
    })

    results_ROC.append({
        'disease': disease_name,
        'auc': auc_score,
        'all_preds': all_preds,
        'all_labels': all_labels,
    })
    

In [None]:
# Tabulate the per-disease metrics, indexed by disease name, and print the table
results_df = pd.DataFrame(results).set_index('disease')
print(results_df)