In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import numpy as np
import pandas as pd
import time
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve, auc, recall_score
from sklearn.model_selection import train_test_split


# Define dot product similarity
class DotProductScore(nn.Module):
    """Score every position of a sequence by a dot product with a learned query.

    Args:
        hidden_size: size of the last dimension of the inputs, and therefore
            the length of the learned query vector ``q``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Learned query stored as a (hidden_size, 1) column so that
        # inputs @ q yields one score per position.
        self.q = nn.Parameter(torch.empty(size=(hidden_size, 1), dtype=torch.float32))
        self.init_weights()

    def init_weights(self):
        """Re-draw ``q`` uniformly from [-0.5, 0.5]."""
        bound = 0.5
        with torch.no_grad():
            self.q.uniform_(-bound, bound)

    def forward(self, inputs):
        """Return one score per position.

        Args:
            inputs: tensor of shape (batch_size, seq_length, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_length).
        """
        return (inputs @ self.q).squeeze(-1)
    
# Define attention mechanisms
class Attention(nn.Module):
    """Attention pooling along dim 1 of a batch of sequences.

    Positions are scored by a learned ``DotProductScore``. Positions at or past
    a row's valid length are pushed to a score of -1e9, so they receive
    (numerically) zero softmax weight. The rows of ``X`` are then averaged
    with those weights. A row whose valid length is 0 ends up with uniform
    weights over every position, because all of its scores become -1e9.

    Args:
        hidden_size: size of the last dimension of ``X``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.scores = DotProductScore(hidden_size)

    def forward(self, X, valid_lens):
        """Pool ``X`` into a single vector per row.

        Args:
            X: expected shape (batch, seq_len, hidden_size).
            valid_lens: number of leading valid positions per row. It is
                broadcast via ``unsqueeze(-1)`` against positions along dim 1,
                so presumably 1-D of length batch (TODO confirm for callers
                other than ModelLSTMAttention).

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        logits = self.scores(X)
        positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
        keep = (positions.unsqueeze(0) < valid_lens.unsqueeze(-1)).float()
        # Arithmetic masking: kept positions are unchanged and masked ones
        # become score * 0 - 1e9.
        logits = logits * keep - (1 - keep) * 1e9
        weights = nn.functional.softmax(logits, dim=-1)
        pooled = torch.matmul(weights.unsqueeze(1), X)
        return pooled.squeeze(1)
    
#Define the model class
class ModelLSTMAttention(nn.Module):
    """Bidirectional LSTM encoder with attention pooling.

    Per-sequence outputs are aggregated into one logit per sample.

    ``forward`` expects the ``ins_num`` sequences of each sample to be stored
    consecutively along dim 0 of ``seq``. The per-sequence outputs are
    regrouped with ``reshape(-1, ins_num)`` and mixed by ``fc_1``. Raw logits
    are returned, with no sigmoid applied.

    Args:
        input_size: features per time step.
        hidden_size: LSTM hidden units per direction.
        output_size: outputs of ``fc`` per sequence. The regrouping in
            ``forward`` assumes this is 1.
        num_layers: stacked LSTM layers.
        dropout: dropout probability, used inside the LSTM and by ``self.dropout``.
        ins_num: number of sequences per sample.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers, dropout, ins_num):
        super(ModelLSTMAttention, self).__init__()
        self.ins_num = ins_num
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                            dropout=dropout, bidirectional=True)
        # The bidirectional LSTM emits 2 * hidden_size features per step.
        self.attention = Attention(hidden_size * 2)
        self.bn1 = nn.BatchNorm1d(hidden_size * 2)
        # Dropout has no parameters, so one module serves every dropout step
        # in forward(). The original assigned it three times, which changed nothing.
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.fc_1 = nn.Linear(self.ins_num, 1)

    def forward(self, seq, valid_lens):
        """Compute one logit per sample.

        Args:
            seq: batch-first sequences, shape (n_samples * ins_num, seq_len, input_size).
            valid_lens: unpadded length of each sequence. Any shape with
                n_samples * ins_num elements works; it is flattened.

        Returns:
            Tensor of shape (n_samples, 1) when ``output_size == 1``.
        """
        output, _ = self.lstm(seq)
        # Follow the input's device instead of a notebook-global `device`,
        # which only exists inside the notebook.
        valid_lens = valid_lens.reshape(-1).to(seq.device)
        out = self.attention(output, valid_lens)
        out = self.bn1(out)
        out = self.dropout(out)
        out = self.fc(out)
        out = self.dropout(out)
        # Regroup: one row per sample, one column per sequence slot.
        out = out.reshape(-1, self.ins_num)
        out = self.fc_1(out)
        # NOTE(review): dropout is also applied to the final logit (original
        # design). In training mode this randomly zeroes or rescales the
        # per-sample logit. Confirm this is intended.
        out = self.dropout(out)
        return out

# Define a function that reads selected columns from a TSV file
def read_tsv(filename, inf_ind, skip_1st=False, file_encoding="utf8"):
    """Read selected tab-separated columns from a text file.

    Fields are split on tabs after removing only the line terminator. Empty
    leading or trailing fields are therefore preserved and column indices
    never shift; the original ``line.strip()`` also stripped edge tabs.
    Surrounding whitespace is then stripped from each extracted field.
    Blank lines are skipped instead of crashing with IndexError.

    Args:
        filename: path of the TSV file.
        inf_ind: column indices to extract, in the desired order.
        skip_1st: if True, discard the first line (a header).
        file_encoding: text encoding used to open the file.

    Returns:
        List with one entry per non-blank data line, each entry being the
        list of the requested field strings.

    Raises:
        IndexError: if a non-blank line has fewer columns than requested.
    """
    extract_inf = []
    with open(filename, "r", encoding=file_encoding) as tsv_f:
        if skip_1st:
            next(tsv_f, None)
        for line in tsv_f:
            line = line.rstrip("\r\n")
            if not line.strip():
                continue
            line_list = line.split("\t")
            extract_inf.append([line_list[ind].strip() for ind in inf_ind])
    return extract_inf

# Define a function that reads an amino acid feature file and generates a feature dictionary
def get_features(filename, f_num=15):
    """Build a residue -> feature-vector dictionary from a feature file.

    The file has a header row. Column 0 holds the residue letter and columns
    1-15 hold its 15 features. Each vector is zero-padded to ``f_num``
    entries, split as evenly as possible between left and right, with the
    odd zero going right. ``"X"`` (the padding residue) maps to ``f_num``
    zeros.

    Args:
        filename: path of the tab-separated feature file.
        f_num: length of the returned vectors; must be at least 15.

    Returns:
        Dict mapping residue letters (and ``"X"``) to lists of length ``f_num``.

    Raises:
        ValueError: if ``f_num < 15``. Otherwise the real vectors (15 long)
            and ``"X"`` (``f_num`` long) would have different lengths.
    """
    base_num = 15
    if f_num < base_num:
        raise ValueError(f"f_num must be >= {base_num}, got {f_num}")
    f_list = read_tsv(filename, list(range(base_num + 1)), True)
    pad_total = f_num - base_num
    left_num = pad_total // 2
    right_num = pad_total - left_num
    f_dict = {
        row[0]: [0] * left_num + [float(x) for x in row[1:]] + [0] * right_num
        for row in f_list
    }
    f_dict["X"] = [0] * f_num
    return f_dict

# Defining Input Functions
def generate_input(sps, sp_lbs, feature_dict, feature_num, ins_num, max_len):
    """Encode repertoire samples as fixed-size feature tensors.

    Each sample ``sps[i]`` is a list of rows whose first element is an
    amino-acid sequence, such as the rows returned by ``read_tsv``. Up to
    ``ins_num`` sequences per sample are encoded. Each sequence is
    right-padded with ``feature_dict["X"]`` to ``max_len`` positions. Unused
    sequence slots stay all-zero.

    Args:
        sps: list of samples; each sample is a list of rows with row[0] = sequence.
        sp_lbs: labels; ``sp_lbs[i]`` belongs to ``sps[i]``.
        feature_dict: maps upper-case residue letters and ``"X"`` to vectors
            of length ``feature_num``.
        feature_num: features per residue.
        ins_num: sequence slots per sample.
        max_len: positions per sequence.

    Returns:
        xs: float32 tensor of shape (n_samples, ins_num, feature_num, max_len).
            The last two axes are swapped relative to the (position, feature)
            encoding order.
        ys: float32 tensor of shape (n_samples, 1).
        lens: int64 tensor of shape (n_samples, ins_num) with each sequence's
            unpadded length (0 for empty slots).

    Raises:
        ValueError: if a sample has more than ``ins_num`` sequences or a
            sequence is longer than ``max_len``.
        KeyError: if a residue is missing from ``feature_dict``.
    """
    n_samples = len(sps)
    encoded = np.zeros((n_samples, ins_num, max_len, feature_num), dtype=np.float32)
    # One row of lengths per sample, so lens[i] always describes xs[i] even
    # when samples hold different numbers of sequences. The original
    # flattened all lengths and re-chunked them, which misaligned them.
    seq_lens = np.zeros((n_samples, ins_num), dtype=np.int64)
    labels = [sp_lbs[i] for i in range(n_samples)]

    for i, sp in enumerate(sps):
        if len(sp) > ins_num:
            raise ValueError(
                f"sample {i} has {len(sp)} sequences; at most ins_num={ins_num} are supported"
            )
        for j, tcr in enumerate(sp):
            tcr_seq = tcr[0]
            if len(tcr_seq) > max_len:
                raise ValueError(
                    f"sequence {tcr_seq!r} in sample {i} is longer than max_len={max_len}"
                )
            seq_lens[i, j] = len(tcr_seq)
            padded = tcr_seq + "X" * (max_len - len(tcr_seq))
            encoded[i, j] = [feature_dict[aa.upper()] for aa in padded]

    xs = torch.tensor(encoded, dtype=torch.float32).swapaxes(2, 3)
    ys = torch.tensor(np.array(labels), dtype=torch.float32).view(-1, 1)
    lens = torch.tensor(seq_lens, dtype=torch.long)
    return xs, ys, lens

# Define a function that loads the sample files and labels each one from its filename
def load_data(sample_dir):
    """Load every sample file in ``sample_dir`` and label it from its filename.

    Files are visited in sorted name order. ``os.listdir`` order is
    arbitrary, so without sorting the sample order, and hence the KFold
    split despite its fixed random_state, would differ between machines.
    Hidden entries (names starting with '.') and non-files such as
    ``.ipynb_checkpoints`` are skipped.

    A filename containing "P" is labelled 1 (positive). Otherwise a filename
    containing "H" is labelled 0. Any other name prints an error and exits.

    Args:
        sample_dir: directory holding one TSV file per sample.

    Returns:
        (training_data, training_labels): per-sample rows of columns [0, 1]
        from ``read_tsv`` (header skipped), and the matching 0/1 labels.
    """
    training_data = []
    training_labels = []
    for sample_file in sorted(os.listdir(sample_dir)):
        sample_path = os.path.join(sample_dir, sample_file)
        if sample_file.startswith(".") or not os.path.isfile(sample_path):
            continue
        # Decide the label before reading, so a badly named file fails fast.
        if "P" in sample_file:
            label = 1
        elif "H" in sample_file:
            label = 0
        else:
            print("Wrong sample filename! Please name positive samples with 'P' and negative samples with 'H'.")
            sys.exit(1)
        training_data.append(read_tsv(sample_path, [0, 1], True))
        training_labels.append(label)

    return training_data, training_labels

#Define the evaluation function
from tqdm import tqdm
def _to_lstm_input(batch_x, max_len=24, feature_num=15):
    """Flatten a batch of samples into per-sequence (max_len, feature_num) inputs.

    ``generate_input`` swaps the last two axes, so batches arrive as
    (batch, ins_num, feature_num, max_len). A plain
    ``view(-1, max_len, feature_num)`` on that layout reinterprets memory and
    interleaves features with positions. The axes are therefore swapped back
    first. Input already in (..., max_len, feature_num) order is only reshaped.
    """
    if max_len != feature_num and tuple(batch_x.shape[-2:]) == (feature_num, max_len):
        batch_x = batch_x.transpose(-1, -2)
    return batch_x.reshape(-1, max_len, feature_num)


def evaluate(model, criterion, test_loader, device='cuda'):
    """Run ``model`` over ``test_loader`` in eval mode without gradients.

    Args:
        model: called as ``model(x, valid_lens)`` with ``x`` of shape
            (batch * ins_num, 24, 15).
        criterion: loss taking (predictions, labels).
        test_loader: yields (x, y, valid_lens) batches built from
            ``generate_input`` output.
        device: device that inputs and labels are moved to.

    Returns:
        (average per-batch loss, list of per-batch prediction arrays, list of
        per-batch label arrays). The average is NaN for an empty loader.
    """
    test_total_loss = 0.0
    all_preds = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for test_batch_x, test_batch_y, test_valid_lens in test_loader:
            test_batch_x = _to_lstm_input(test_batch_x).to(device)
            test_batch_y = test_batch_y.to(device)
            test_valid_lens = test_valid_lens.to(device)
            test_pred = model(test_batch_x, test_valid_lens)

            test_loss = criterion(test_pred, test_batch_y)
            test_total_loss += test_loss.item()
            all_preds.append(test_pred.cpu().numpy())
            all_labels.append(test_batch_y.cpu().numpy())

    n_batches = len(test_loader)
    test_avg_loss = test_total_loss / n_batches if n_batches else float("nan")
    return test_avg_loss, all_preds, all_labels


# Define the training function
def train(fold, model, criterion, optimizer, train_loader, test_loader, epoches=100, device='cuda'):
    """Train ``model`` with early stopping on ``test_loader`` (the validation split).

    Reads the notebook globals ``disease_name``, ``PATIENCE`` and
    ``EarlyStopping``. Checkpoints go to
    ``../model/LSTMY/{disease_name}checkpoint{fold}.pt``; the directory is
    created if missing. After training stops, by patience or because the
    epochs ran out, the checkpoint is loaded back into ``model``. The
    original only reloaded it on early stop, so a run that used every epoch
    kept its last-epoch weights. NOTE(review): assumes EarlyStopping writes
    that file whenever validation loss improves; confirm in
    python_codes/pytorchtools.

    Returns:
        (per-epoch training losses, per-epoch validation losses).
    """
    model_path = f'../model/LSTMY/{disease_name}checkpoint{fold}.pt'   # Save path of the model file
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    early_stopping = EarlyStopping(PATIENCE, path=model_path, verbose=False)

    epoch_train_losses = []
    epoch_test_losses = []
    with tqdm(total=epoches) as t:
        t.set_description(f'{disease_name} - Fold {fold}')
        for epoch in range(epoches):
            model.train()
            total_loss = 0.0
            for batch_x, batch_y, valid_lens in train_loader:
                batch_x = _to_lstm_input(batch_x).to(device)
                batch_y = batch_y.to(device)
                valid_lens = valid_lens.to(device)
                pred = model(batch_x, valid_lens)

                loss = criterion(pred, batch_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            epoch_train_losses.append(avg_loss)
            test_avg_loss, _, _ = evaluate(model, criterion, test_loader, device=device)
            epoch_test_losses.append(test_avg_loss)

            t.set_postfix(loss=avg_loss, test_loss=test_avg_loss)
            t.update(1)

            early_stopping(test_avg_loss, model)
            if early_stopping.early_stop:
                break

    # Restore the saved checkpoint whether we stopped early or ran out of epochs.
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
    return epoch_train_losses, epoch_test_losses
                
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)), computed without overflow.

    The naive form overflows in ``np.exp(-x)`` for large negative ``x`` and
    emits a RuntimeWarning. Working with ``exp(-|x|)`` keeps every
    intermediate in [0, 1]. Accepts scalars, lists or arrays, and preserves
    the dtype of floating-point array inputs.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    # out[()] turns a 0-d result back into a NumPy scalar, as the naive form
    # returned for scalar input, and is a plain view for n-d arrays.
    return out[()]

# Define a function to compute a binary classification indicator               
def metrics(all_preds, all_labels, threshold=0.5):
    """Compute binary-classification metrics from logits.

    Args:
        all_preds: model logits; ``sigmoid`` is applied here.
        all_labels: 0/1 ground-truth labels with the same length as ``all_preds``.
        threshold: probabilities strictly above this are predicted positive.

    Returns:
        (accuracy, sensitivity, specificity, auc). Sensitivity or specificity
        is NaN when there are no positive or no negative labels respectively.
        AUC is NaN when only one class is present, because ROC AUC is
        undefined in that case.
    """
    all_probs = sigmoid(np.array(all_preds))
    binary_preds = (all_probs > threshold).astype(int)
    # Fixing labels keeps the matrix 2x2 even when only one class occurs
    # in the labels or predictions.
    conf_matrix = confusion_matrix(all_labels, binary_preds, labels=[0, 1])
    tn, fp, fn, tp = conf_matrix.ravel()
    accuracy = accuracy_score(all_labels, binary_preds)
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    # Named auc_score so it does not shadow sklearn.metrics.auc.
    if np.unique(np.asarray(all_labels)).size < 2:
        auc_score = float("nan")
    else:
        auc_score = roc_auc_score(all_labels, all_probs)

    return accuracy, sensitivity, specificity, auc_score

In [None]:
# Model parameterization
def init_model():
    """Create a fresh ``ModelLSTMAttention`` with this notebook's fixed hyper-parameters.

    The model takes 15 input features per residue and has a 3-layer
    bidirectional LSTM with 30 hidden units per direction, dropout 0.6,
    100 sequences per sample and a single output logit.
    """
    hyperparams = {
        "input_size": 15,
        "hidden_size": 30,
        "output_size": 1,
        "num_layers": 3,
        "dropout": 0.6,
        "ins_num": 100,
    }
    return ModelLSTMAttention(**hyperparams)

# Introduce an early stop mechanism
sys.path.append('../')
from python_codes.pytorchtools import EarlyStopping

# Reading amino acid profile files
aa_file = "../data/PCA15.txt"
aa_vectors = get_features(aa_file)

# Input geometry passed to generate_input(). It must agree with init_model()
# (input_size=15, ins_num=100) and with the (24, 15) per-sequence shape used
# by train()/evaluate().
FEATURE_NUM = 15  # features per amino acid
INS_NUM = 100     # TCR sequence slots per sample
MAX_LEN = 24      # positions per TCR sequence

# 5-fold cross-validation
k_fold = 5
kf = KFold(n_splits=k_fold, shuffle=True, random_state=42)

BATCH_SIZE = 64     # Batch size
NUM_EPOCHES = 2000  # Maximum number of training epochs
PATIENCE = 300      # Patience value for early stopping

all_accuracies = []
all_sensitivities = []
all_specificities = []
all_aucs = []

# Fall back to the CPU so the notebook also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Four autoimmune diseases
disease_list = ["RA", "T1D", "MS", "IAA"]
results = []
results_ROC = []

for disease_name in disease_list:
    data_dir = f'../data/{disease_name}'   # Disease file path
    training_data, training_labels = load_data(data_dir)
    print(f"Working on {disease_name} dataset: {len(training_data)} samples")

    # Out-of-fold test predictions and labels, pooled over all folds.
    all_preds = []
    all_labels = []

    for fold, (train_idx, test_idx) in enumerate(kf.split(training_data)):
        train_data = [training_data[i] for i in train_idx]
        train_labels = [training_labels[i] for i in train_idx]
        test_data = [training_data[i] for i in test_idx]
        test_labels = [training_labels[i] for i in test_idx]

        # With the test fold fixed, split the remaining data into training and validation sets.
        train_data, valid_data, train_labels, valid_labels = train_test_split(train_data, train_labels, test_size=0.2, random_state=1234)

        train_input_batch, train_label_batch, train_lens_batch = generate_input(train_data, train_labels, aa_vectors, FEATURE_NUM, INS_NUM, MAX_LEN)
        valid_input_batch, valid_label_batch, valid_lens_batch = generate_input(valid_data, valid_labels, aa_vectors, FEATURE_NUM, INS_NUM, MAX_LEN)
        test_input_batch, test_label_batch, test_lens_batch = generate_input(test_data, test_labels, aa_vectors, FEATURE_NUM, INS_NUM, MAX_LEN)

        train_dataset = Data.TensorDataset(train_input_batch, train_label_batch, train_lens_batch)
        valid_dataset = Data.TensorDataset(valid_input_batch, valid_label_batch, valid_lens_batch)
        test_dataset = Data.TensorDataset(test_input_batch, test_label_batch, test_lens_batch)

        train_loader = Data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        valid_loader = Data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = Data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        model = init_model().to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.0005)

        train(fold, model, criterion, optimizer, train_loader, valid_loader, epoches=NUM_EPOCHES, device=device)

        # Final evaluation on the held-out test fold
        _, preds, labels = evaluate(model, criterion, test_loader, device=device)
        all_preds += preds
        all_labels += labels

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    # Named `auc_score`, not `auc`: a module-level `auc` would overwrite
    # sklearn.metrics.auc imported at the top of the notebook.
    accuracy, sensitivity, specificity, auc_score = metrics(all_preds, all_labels)
    # The metrics are computed once over the pooled out-of-fold predictions,
    # not averaged over folds.
    print(f"Pooled CV Accuracy ({disease_name}): {accuracy:.4f}")
    print(f"Pooled CV Sensitivity ({disease_name}): {sensitivity:.4f}")
    print(f"Pooled CV Specificity ({disease_name}): {specificity:.4f}")
    print(f"Pooled CV AUC ({disease_name}): {auc_score:.4f}")

    results.append({
        'disease': disease_name,
        'accuracy': accuracy,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'auc': auc_score
    })

    results_ROC.append({
        'disease': disease_name,
        'auc': auc_score,
        'all_preds': all_preds,
        'all_labels': all_labels
    })

In [None]:
# Print the results as a table indexed by disease name
results_df = pd.DataFrame(results).set_index('disease')
print(results_df)
