In [5]:
import numpy as np
from tdc.multi_pred import DTI
import sys
import os
import pandas as pd

# Report the interpreter version (useful when comparing environments).
major, minor, micro = sys.version_info[:3]
print(f"Python version: {major}.{minor}.{micro}")

# Fetch the KIBA drug-target interaction table through TDC and summarize it.
data = DTI(name='KIBA')
KibaDataSet = data.get_data()
print('number of drugs =', len(KibaDataSet['Drug_ID'].unique()))
print('number of proteins = ', len(KibaDataSet["Target_ID"].unique()))
KibaDataSet

Found local copy...
Loading...


Python version: 3.9.20


Done!


number of drugs = 2068
number of proteins =  229


Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O00141,MTVKTEAAKGTLTYSRMRGMVAILIAFMKQRRMGLNDFIQKIANNS...,11.10000
1,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O14920,MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQIAIKQC...,11.10000
2,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,O15111,MERPPGLRPGAGGPWEMRERLGTGGFGNVCLYQHRELDLKIAIKSC...,11.10000
3,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P00533,MRPSGTAGAALLALLAALCPASRALEEKKVCQGTSNKLTQLGTFED...,11.10000
4,CHEMBL1087421,COc1cc2c(cc1Cl)C(c1ccc(Cl)c(Cl)c1)=NCC2,P04626,MELAALCRWGLLLALLPPGAASTQVCTGTDMKLRLPASPETHLDML...,11.10000
...,...,...,...,...,...
117652,CHEMBL230654,CCCc1nc[nH]c1CNc1cc(Cl)c2ncc(C#N)c(Nc3ccc(F)c(...,Q13554,MATTVTCTRFTDEYQLYEDIGKGAFSVVRRCVKLCTGHEYAAKIIN...,10.49794
117653,CHEMBL230654,CCCc1nc[nH]c1CNc1cc(Cl)c2ncc(C#N)c(Nc3ccc(F)c(...,Q13555,MATTATCTRFTDDYQLFEELGKGAFSVVRRCVKKTSTQEYAAKIIN...,10.49794
117654,CHEMBL230654,CCCc1nc[nH]c1CNc1cc(Cl)c2ncc(C#N)c(Nc3ccc(F)c(...,Q13557,MASTTTCTRFTDEYQLFEELGKGAFSVVRRCMKIPTGQEYAAKIIN...,10.49794
117655,CHEMBL230654,CCCc1nc[nH]c1CNc1cc(Cl)c2ncc(C#N)c(Nc3ccc(F)c(...,Q16539,MSQERPTFYRQELNKTIWEVPERYQNLSPVGSGAYGSVCAAFDTKT...,10.49794


In [6]:

# Folder holding one "<Target_ID>_graph.pt" file per protein.
graph_folder = "./ProteinGraphs"

# Protein IDs that have a saved graph ("<Target_ID>_graph.pt" -> "<Target_ID>"),
# collected into a set for fast membership tests.
protein_ids_with_graphs = {
    name.partition('_graph.pt')[0]
    for name in os.listdir(graph_folder)
    if name.endswith('_graph.pt')
}

# Keep only the interactions whose protein has a graph, and persist them.
valid_protein_dataset = KibaDataSet.loc[KibaDataSet["Target_ID"].isin(protein_ids_with_graphs)]
valid_protein_dataset.to_csv("filtered_KibaDataSet.csv", index=False)

# Compare protein coverage before and after filtering.
print(f"Original dataset had {len(KibaDataSet['Target_ID'].unique())} unique proteins.")
print(f"Filtered dataset has {len(valid_protein_dataset['Target_ID'].unique())} unique proteins.")


Original dataset had 229 unique proteins.
Filtered dataset has 203 unique proteins.


In [7]:
# Number of interactions (rows) kept after filtering.
valid_protein_dataset.shape[0]


101209

In [8]:
import os
import torch
import pickle
import pandas as pd

def load_graph(path, is_pickle=True):
    """
    Load a molecule graph (.pkl) or a protein graph (.pt).

    Args:
        path: file to read.
        is_pickle: if True, read the file with ``pickle.load``; otherwise
            read it with ``torch.load``.

    Returns:
        The object stored in the file.

    Security: both branches unpickle arbitrary Python objects, which can
    execute code. Only call this on files you generated yourself.
    """
    if is_pickle:
        with open(path, 'rb') as f:
            return pickle.load(f)
    # Pin weights_only=False explicitly: the saved graph may be a full Python
    # object (not just tensors), which torch.load refuses to unpickle once its
    # default switches to weights_only=True (torch >= 2.6).
    return torch.load(path, weights_only=False)

def _load_graph_cached(cache, graph_id, path, label, is_pickle):
    """
    Return the graph for `graph_id`, reading it from `path` only on first use.

    A missing file is reported once ("<label> graph not found: <id>") and
    remembered as None, so later rows with the same id are skipped silently.
    """
    if graph_id not in cache:
        if os.path.exists(path):
            cache[graph_id] = load_graph(path, is_pickle=is_pickle)
        else:
            print(f"{label} graph not found: {graph_id}")
            cache[graph_id] = None
    return cache[graph_id]


def prepare_dataset_individual_save_as_pt(filtered_dataset, molecule_graph_dir, protein_graph_dir, output_dir,
                                          log_every=10000):
    """
    Save each (molecule graph, protein graph, target) row as its own .pt file.

    Each file is ``<output_dir>/sample_<index>.pt`` (``index`` is the
    DataFrame index label). It holds the tuple
    ``(mol_graph, pro_graph, torch.tensor([Y], dtype=torch.float))``.
    Rows are skipped when the protein graph ``<Target_ID>_graph.pt`` or the
    molecule graph ``<Drug_ID>_graph.pkl`` is missing.

    Args:
    - filtered_dataset: DataFrame with 'Target_ID', 'Drug_ID' and 'Y' columns.
    - molecule_graph_dir: Directory where molecule graphs (.pkl) are stored.
    - protein_graph_dir: Directory where protein graphs (.pt) are stored.
    - output_dir: Directory to write the samples into (created if needed).
    - log_every: print a progress line for rows whose index is a multiple of
      this value; 0/None disables progress output.

    Returns:
    - The number of sample files written.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Many rows share the same protein / drug, so keep each graph in memory
    # after its first load instead of re-reading it for every row.
    # NOTE: this holds every distinct graph in RAM for the duration of the call.
    protein_graphs = {}
    molecule_graphs = {}
    saved = 0

    rows = zip(filtered_dataset.index,
               filtered_dataset['Target_ID'],
               filtered_dataset['Drug_ID'],
               filtered_dataset['Y'])
    for index, protein_id, chembl_id, affinity in rows:
        pro_graph = _load_graph_cached(
            protein_graphs, protein_id,
            os.path.join(protein_graph_dir, f"{protein_id}_graph.pt"),
            "Protein", is_pickle=False)
        if pro_graph is None:
            continue

        mol_graph = _load_graph_cached(
            molecule_graphs, chembl_id,
            os.path.join(molecule_graph_dir, f"{chembl_id}_graph.pkl"),
            "Molecule", is_pickle=True)
        if mol_graph is None:
            continue

        # Affinity value stored as a 1-element float tensor.
        target = torch.tensor([affinity], dtype=torch.float)

        sample_path = os.path.join(output_dir, f"sample_{index}.pt")
        torch.save((mol_graph, pro_graph, target), sample_path)
        saved += 1

        if log_every and index % log_every == 0:
            print(f"Saved sample {index} as {sample_path}")

    return saved
            

        

# Inputs and output locations for the per-sample export.
molecule_graph_dir = 'molecule_graphs/'  # holds <Drug_ID>_graph.pkl files
protein_graph_dir = 'ProteinGraphs/'  # holds <Target_ID>_graph.pt files
filtered_dataset_path = 'filtered_KibaDataSet.csv'  # CSV written by the filtering cell
output_dir = 'prepared_samples/'  # receives one sample_<index>.pt per row

# Re-read the filtered interactions and write one .pt file per row.
filtered_dataset = pd.read_csv(filtered_dataset_path)
prepare_dataset_individual_save_as_pt(
    filtered_dataset,
    molecule_graph_dir=molecule_graph_dir,
    protein_graph_dir=protein_graph_dir,
    output_dir=output_dir,
)

print("Dataset preparation completed.")


Saved sample 0 as prepared_samples/sample_0.pt
Saved sample 10000 as prepared_samples/sample_10000.pt
Saved sample 20000 as prepared_samples/sample_20000.pt
Saved sample 30000 as prepared_samples/sample_30000.pt
Saved sample 40000 as prepared_samples/sample_40000.pt
Saved sample 50000 as prepared_samples/sample_50000.pt
Saved sample 60000 as prepared_samples/sample_60000.pt
Saved sample 70000 as prepared_samples/sample_70000.pt
Saved sample 80000 as prepared_samples/sample_80000.pt
Saved sample 90000 as prepared_samples/sample_90000.pt
Saved sample 100000 as prepared_samples/sample_100000.pt
Dataset preparation completed.


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj


# GCN based model
class GNNNet(torch.nn.Module):
    """
    Drug-target affinity regressor: two GCN encoders followed by an MLP head.

    The molecule graph and the protein graph each pass through three GCNConv
    layers with ReLU (feature widths F -> F -> 2F -> 4F). Each branch is then
    mean-pooled per graph and projected through a 1024-unit layer down to
    ``output_dim``. The two embeddings are concatenated and mapped through
    1024- and 512-unit layers to ``n_output`` values.

    Args:
        n_output: number of values predicted per (molecule, protein) pair.
        num_features_pro: node-feature width of protein graphs.
        num_features_mol: node-feature width of molecule graphs.
        output_dim: width of each branch's embedding.
        dropout: dropout probability used after hidden/embedding layers.
    """

    def __init__(self, n_output=1, num_features_pro=54, num_features_mol=78, output_dim=128, dropout=0.2):
        super().__init__()

        print('GNNNet Loaded')
        self.n_output = n_output

        # Layers are registered in a fixed order: reordering them would change
        # parameters() order (and the RNG draws used for initialization),
        # which breaks resuming from saved optimizer state.

        # Molecule encoder.
        m = num_features_mol
        self.mol_conv1 = GCNConv(m, m)
        self.mol_conv2 = GCNConv(m, m * 2)
        self.mol_conv3 = GCNConv(m * 2, m * 4)
        self.mol_fc_g1 = torch.nn.Linear(m * 4, 1024)
        self.mol_fc_g2 = torch.nn.Linear(1024, output_dim)

        # Protein encoder.
        p = num_features_pro
        self.pro_conv1 = GCNConv(p, p)
        self.pro_conv2 = GCNConv(p, p * 2)
        self.pro_conv3 = GCNConv(p * 2, p * 4)
        self.pro_fc_g1 = torch.nn.Linear(p * 4, 1024)
        self.pro_fc_g2 = torch.nn.Linear(1024, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Head over the concatenated [molecule | protein] embedding.
        self.fc1 = nn.Linear(2 * output_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, self.n_output)

    def _encode(self, x, edge_index, batch, convs, fc_hidden, fc_embed):
        """Run one branch: GCN+ReLU stack, per-graph mean pool, then two dense layers with dropout."""
        for conv in convs:
            x = self.relu(conv(x, edge_index))
        pooled = gep(x, batch)
        hidden = self.dropout(self.relu(fc_hidden(pooled)))
        return self.dropout(fc_embed(hidden))

    def forward(self, data_mol, data_pro):
        """
        Predict affinities for a batch of (molecule, protein) graph pairs.

        ``data_mol`` and ``data_pro`` are batched graphs exposing ``x``,
        ``edge_index`` and ``batch``. The result has one row per pooled graph
        and ``n_output`` columns.
        """
        mol_embedding = self._encode(
            data_mol.x, data_mol.edge_index, data_mol.batch,
            (self.mol_conv1, self.mol_conv2, self.mol_conv3),
            self.mol_fc_g1, self.mol_fc_g2,
        )
        pro_embedding = self._encode(
            data_pro.x, data_pro.edge_index, data_pro.batch,
            (self.pro_conv1, self.pro_conv2, self.pro_conv3),
            self.pro_fc_g1, self.pro_fc_g2,
        )

        joint = torch.cat((mol_embedding, pro_embedding), 1)
        joint = self.dropout(self.relu(self.fc1(joint)))
        joint = self.dropout(self.relu(self.fc2(joint)))
        return self.out(joint)


In [12]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools
from sklearn.model_selection import train_test_split


# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def _to_graph_data(graph, label):
    """
    Normalize one stored graph so it has a float `x` and a long [2, num_edges] `edge_index`.

    `graph` may be a dict (converted with ``Data(**graph)``) or any object with
    `x` (or `features`) and `edge_index` attributes; it is modified in place
    and returned. `label` ('mol_data' / 'pro_data') prefixes error messages.

    Raises:
        ValueError: if the graph has neither a non-None `x` nor `features`.
    """
    if isinstance(graph, dict):
        graph = Data(**graph)

    # Node features: some graph builders store them under `features` instead of `x`.
    if getattr(graph, 'x', None) is None:
        if not hasattr(graph, 'features'):
            raise ValueError(f"{label} does not have 'x' or 'features' attribute")
        graph.x = graph.features
        del graph.features
    x = graph.x if isinstance(graph.x, torch.Tensor) else torch.tensor(graph.x)
    graph.x = x.float()

    # Edges: long dtype with shape [2, num_edges]; [num_edges, 2] input is transposed.
    edge_index = graph.edge_index
    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.long()
    else:
        edge_index = torch.tensor(edge_index, dtype=torch.long)
    if edge_index.numel() == 0:
        # An edgeless graph loads as shape [0] (or [0, 2]); `.t()` cannot fix a
        # 1-D tensor, so reshape to the expected empty [2, 0] layout.
        edge_index = edge_index.reshape(2, 0)
    elif edge_index.shape[0] != 2:
        edge_index = edge_index.t()
    graph.edge_index = edge_index

    # Set 'num_nodes' explicitly to suppress PyG's inference warnings.
    graph.num_nodes = graph.x.size(0)
    return graph


def load_sample(path):
    """
    Load one saved (mol_data, pro_data, target) sample and normalize both graphs.

    Returns:
        (mol_data, pro_data, target), where each graph has a float32 `x`, a long
        `edge_index` of shape [2, num_edges] and `num_nodes` set. `target` is
        returned as stored.

    Raises:
        ValueError: if either graph lacks both `x` and `features`.
    """
    # weights_only=False: samples contain pickled graph objects, which
    # torch.load rejects once its default becomes weights_only=True
    # (torch >= 2.6). Unpickling can execute code, so only load sample files
    # you produced yourself.
    sample = torch.load(path, weights_only=False)
    mol_data, pro_data, target = sample[0], sample[1], sample[2]

    mol_data = _to_graph_data(mol_data, 'mol_data')
    pro_data = _to_graph_data(pro_data, 'pro_data')
    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    """
    Lazily load samples and yield them in lists of `batch_size`.

    Files are read one at a time with load_sample, in `file_list` order,
    from `sample_dir`. The final list may be shorter than `batch_size`; it
    is yielded only if it is non-empty.
    """
    current = []
    for name in file_list:
        current.append(load_sample(os.path.join(sample_dir, name)))
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    """
    Mean squared error between two arrays holding the same number of values.

    Both inputs are flattened first, so (N, 1) predictions are compared
    element-wise with (N,) labels. Without flattening, NumPy broadcasting
    would silently build an (N, N) difference matrix.

    Raises:
        ValueError: if the inputs hold different numbers of elements.
    """
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"y_true and y_pred differ in size: {y_true.size} vs {y_pred.size}")
    return np.mean((y_true - y_pred) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.

    Only pairs whose true values differ are considered. Such a pair scores
    1 if the predictions are ordered the same way as the true values, 0.5 if
    the predictions are tied, and 0 otherwise. The CI is the mean score, or 0
    when no such pair exists. Inputs are flattened first, so (N, 1)
    predictions are accepted.

    The direct pairwise loop is O(n^2), which is impractical for
    evaluation sets of tens of thousands of samples. This version sorts by
    true value and counts with a Fenwick tree over prediction ranks, in
    O(n log n).

    Raises:
        ValueError: if the inputs hold different numbers of elements.
    """
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"y_true and y_pred differ in size: {y_true.size} vs {y_pred.size}")
    n = y_true.size
    if n < 2:
        return 0

    # Dense 1-based prediction ranks: equal predictions share a rank, so ties are countable.
    pred_rank = (np.unique(y_pred, return_inverse=True)[1].ravel() + 1).tolist()
    max_rank = max(pred_rank)
    tree = [0] * (max_rank + 1)  # Fenwick tree counting inserted ranks

    def count_upto(rank):
        """Number of inserted predictions whose rank is <= `rank`."""
        total = 0
        while rank > 0:
            total += tree[rank]
            rank -= rank & -rank
        return total

    def insert(rank):
        while rank <= max_rank:
            tree[rank] += 1
            rank += rank & -rank

    order = np.argsort(y_true, kind='stable')
    true_sorted = y_true[order].tolist()
    order = order.tolist()

    concordant = 0.0
    comparable = 0
    inserted = 0
    start = 0
    while start < n:
        # [start, stop) is a run of equal true values; everything already in
        # the tree has a strictly smaller true value.
        stop = start + 1
        while stop < n and true_sorted[stop] == true_sorted[start]:
            stop += 1
        for i in order[start:stop]:
            lower = count_upto(pred_rank[i] - 1)
            tied = count_upto(pred_rank[i]) - lower
            concordant += lower + 0.5 * tied
        comparable += inserted * (stop - start)
        for i in order[start:stop]:
            insert(pred_rank[i])
        inserted += stop - start
        start = stop

    return concordant / comparable if comparable else 0

def get_pearson(y_true, y_pred):
    """Pearson correlation coefficient between the flattened label and prediction arrays."""
    flat_true = y_true.reshape(-1)
    flat_pred = y_pred.reshape(-1)
    result = pearsonr(flat_true, flat_pred)
    return result[0]



# FutureWarning suppression is already configured right after this cell's
# imports; repeating the identical filterwarnings call was a no-op and is omitted.


In [13]:
#calculate metrics on training and testing data at the end of training
def _collate_samples(batch_samples, device):
    """
    Turn a list of (mol_data, pro_data, target) samples into model inputs on `device`.

    Returns (mol_batch, pro_batch, target): two PyG `Batch` objects and a
    1-D float32 tensor with one target value per sample.
    """
    mol_graphs = [sample[0] for sample in batch_samples]
    pro_graphs = [sample[1] for sample in batch_samples]
    # Each stored target is expected to hold a single affinity value; concatenate
    # rather than calling torch.tensor(list_of_tensors), which is slow and warns.
    targets = torch.cat([torch.as_tensor(sample[2], dtype=torch.float32).view(-1)
                         for sample in batch_samples])
    return (Batch.from_data_list(mol_graphs).to(device),
            Batch.from_data_list(pro_graphs).to(device),
            targets.to(device))


def _predict(model, files, sample_dir, batch_size, device):
    """Run `model` in eval mode over `files`; return (predictions, labels) as flat numpy arrays."""
    preds, labels = [], []
    model.eval()
    with torch.no_grad():
        for batch_samples in batch_loader(files, sample_dir, batch_size=batch_size):
            mol_batch, pro_batch, target = _collate_samples(batch_samples, device)
            output = model(mol_batch, pro_batch)
            # Flatten (B, n_output=1) -> (B,) so predictions align element-wise
            # with the (B,) labels; an (N, 1) vs (N,) pair broadcasts to (N, N).
            preds.append(output.view(-1).cpu().numpy())
            labels.append(target.cpu().numpy())
    return np.concatenate(preds), np.concatenate(labels)


def train_and_evaluate(sample_dir, num_epochs=100, test_size=0.2, lr=0.001, batch_size=200, shuffle=True, seed=42):
    """
    Train GNNNet on the samples in `sample_dir`, then report train/test metrics.

    Checkpoints go to ``<sample_dir>/TrainingModel/model_epoch<N>.pt`` after
    every epoch. When checkpoints already exist, training resumes after the
    latest one.

    Args:
        sample_dir: directory containing the sample_*.pt files.
        num_epochs: last epoch to train (inclusive).
        test_size: fraction of samples held out for evaluation.
        lr: Adam learning rate.
        batch_size: samples per batch for training and evaluation.
        shuffle: reshuffle the training files every epoch (seeded by
            `seed` + epoch, so a resumed run sees the same order).
        seed: random_state for the train/test split and the shuffle seed base.

    Returns:
        {'train': {'mse', 'ci', 'pearson'}, 'test': {...}} with the final metrics.

    Raises:
        FileNotFoundError: if `sample_dir` contains no .pt samples.
        ValueError: if `batch_size` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    if not sample_files:
        raise FileNotFoundError(f"No .pt samples found in {sample_dir}")

    # NOTE(review): os.listdir order is filesystem-dependent, so this seeded
    # split is only reproducible for the same directory listing. Resuming on
    # another machine could move test samples into training.
    train_files, test_files = train_test_split(sample_files, test_size=test_size, random_state=seed)

    # Create a directory for the model checkpoint
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    # Infer input feature widths from the first training sample.
    first_mol, first_pro, _ = load_sample(os.path.join(sample_dir, train_files[0]))
    model = GNNNet(
        num_features_mol=first_mol.x.size(1),
        num_features_pro=first_pro.x.size(1)
    ).to(device)
    print(f"Model is on device: {next(model.parameters()).device}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = MSELoss()

    # Resume from the latest checkpoint, if any.
    start_epoch = 1
    existing_checkpoints = [f for f in os.listdir(training_model_dir)
                            if f.endswith('.pt') and f.startswith('model_epoch')]
    if existing_checkpoints:
        latest_checkpoint = max(existing_checkpoints,
                                key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
        checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
    else:
        print("No checkpoint found, starting training from scratch.")

    # Training loop with progress bar over epochs
    for epoch in tqdm(range(start_epoch, num_epochs + 1),
                      desc="Training", unit="epoch"):
        model.train()
        running_loss = 0.0

        epoch_files = train_files
        if shuffle:
            order = np.random.default_rng(seed + epoch).permutation(len(train_files))
            epoch_files = [train_files[i] for i in order]

        for batch_samples in batch_loader(epoch_files, sample_dir, batch_size=batch_size):
            mol_batch, pro_batch, target = _collate_samples(batch_samples, device)

            optimizer.zero_grad()
            output = model(mol_batch, pro_batch)
            loss = loss_fn(output.view(-1), target)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per sample.
            running_loss += loss.item() * len(batch_samples)

        avg_loss = running_loss / len(train_files)
        tqdm.write(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Save the model and optimizer states after each epoch
        checkpoint_path = os.path.join(training_model_dir, f"model_epoch{epoch}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        tqdm.write(f"Checkpoint saved at epoch {epoch}")

    # Final evaluation (also reached when resuming a run that already finished).
    print("\nEvaluating model after training...")

    def metrics_for(files):
        preds, labels = _predict(model, files, sample_dir, batch_size, device)
        return get_mse(labels, preds), get_ci(labels, preds), get_pearson(labels, preds)

    train_mse, train_ci, train_pearson = metrics_for(train_files)
    test_mse, test_ci, test_pearson = metrics_for(test_files)

    print(f"Final Training Metrics: MSE: {train_mse:.4f}, CI: {train_ci:.4f}, Pearson: {train_pearson:.4f}")
    print(f"Final Test Metrics: MSE: {test_mse:.4f}, CI: {test_ci:.4f}, Pearson: {test_pearson:.4f}")

    return {
        'train': {'mse': train_mse, 'ci': train_ci, 'pearson': train_pearson},
        'test': {'mse': test_mse, 'ci': test_ci, 'pearson': test_pearson},
    }

if __name__ == "__main__":
    # Run configuration; adjust as needed.
    sample_dir = 'prepared_samples'  # directory holding the sample_*.pt files
    num_epochs = 300  # last epoch to train
    test_size = 0.2  # fraction of samples held out for testing
    learning_rate = 0.001  # Adam step size

    train_and_evaluate(sample_dir=sample_dir, num_epochs=num_epochs, test_size=test_size, lr=learning_rate)


Running on cuda.
Created directory for checkpoints at prepared_samples/TrainingModel
GNNNet Loaded
Model is on device: cuda:0
No checkpoint found, starting training from scratch.


Training:   0%|                          | 1/300 [02:00<10:00:40, 120.54s/epoch]

Epoch 1/300, Loss: 3.2072
Checkpoint saved at epoch 1


Training:   1%|▏                         | 2/300 [04:34<11:34:57, 139.92s/epoch]

Epoch 2/300, Loss: 1.0997
Checkpoint saved at epoch 2


Training:   1%|▎                         | 3/300 [07:24<12:41:37, 153.86s/epoch]

Epoch 3/300, Loss: 1.0402
Checkpoint saved at epoch 3


Training:   1%|▎                         | 4/300 [09:07<11:00:12, 133.83s/epoch]

Epoch 4/300, Loss: 1.0058
Checkpoint saved at epoch 4


Training:   2%|▍                         | 5/300 [10:53<10:07:57, 123.65s/epoch]

Epoch 5/300, Loss: 0.9419
Checkpoint saved at epoch 5


Training:   2%|▌                          | 6/300 [12:36<9:31:11, 116.57s/epoch]

Epoch 6/300, Loss: 0.8823
Checkpoint saved at epoch 6


Training:   2%|▋                          | 7/300 [14:16<9:03:27, 111.29s/epoch]

Epoch 7/300, Loss: 0.7982
Checkpoint saved at epoch 7


Training:   3%|▋                          | 8/300 [15:58<8:47:49, 108.46s/epoch]

Epoch 8/300, Loss: 0.7433
Checkpoint saved at epoch 8


Training:   3%|▊                          | 9/300 [17:38<8:33:07, 105.80s/epoch]

Epoch 9/300, Loss: 0.7066
Checkpoint saved at epoch 9


Training:   3%|▊                         | 10/300 [19:18<8:21:48, 103.82s/epoch]

Epoch 10/300, Loss: 0.6996
Checkpoint saved at epoch 10


Training:   4%|▉                         | 11/300 [20:58<8:15:01, 102.77s/epoch]

Epoch 11/300, Loss: 0.6824
Checkpoint saved at epoch 11


Training:   4%|█                         | 12/300 [22:39<8:11:03, 102.30s/epoch]

Epoch 12/300, Loss: 0.6699
Checkpoint saved at epoch 12


Training:   4%|█▏                        | 13/300 [24:19<8:05:36, 101.52s/epoch]

Epoch 13/300, Loss: 0.6412
Checkpoint saved at epoch 13


Training:   5%|█▏                        | 14/300 [26:00<8:02:35, 101.24s/epoch]

Epoch 14/300, Loss: 0.6424
Checkpoint saved at epoch 14


Training:   5%|█▎                        | 15/300 [27:39<7:57:49, 100.60s/epoch]

Epoch 15/300, Loss: 0.6389
Checkpoint saved at epoch 15


Training:   5%|█▍                         | 16/300 [29:16<7:52:02, 99.73s/epoch]

Epoch 16/300, Loss: 0.6116
Checkpoint saved at epoch 16


Training:   6%|█▌                         | 17/300 [30:54<7:46:45, 98.96s/epoch]

Epoch 17/300, Loss: 0.6039
Checkpoint saved at epoch 17


Training:   6%|█▌                         | 18/300 [32:30<7:41:59, 98.30s/epoch]

Epoch 18/300, Loss: 0.5982
Checkpoint saved at epoch 18


Training:   6%|█▋                         | 19/300 [34:07<7:38:28, 97.89s/epoch]

Epoch 19/300, Loss: 0.5925
Checkpoint saved at epoch 19


Training:   7%|█▊                         | 20/300 [35:45<7:36:43, 97.87s/epoch]

Epoch 20/300, Loss: 0.5704
Checkpoint saved at epoch 20


Training:   7%|█▉                         | 21/300 [37:23<7:35:06, 97.87s/epoch]

Epoch 21/300, Loss: 0.5551
Checkpoint saved at epoch 21


Training:   7%|█▉                         | 22/300 [39:01<7:33:02, 97.78s/epoch]

Epoch 22/300, Loss: 0.5468
Checkpoint saved at epoch 22


Training:   8%|██                         | 23/300 [40:40<7:33:06, 98.15s/epoch]

Epoch 23/300, Loss: 0.5375
Checkpoint saved at epoch 23


Training:   8%|██▏                        | 24/300 [42:22<7:37:42, 99.50s/epoch]

Epoch 24/300, Loss: 0.5297
Checkpoint saved at epoch 24


Training:   8%|██▎                        | 25/300 [44:03<7:37:41, 99.86s/epoch]

Epoch 25/300, Loss: 0.5262
Checkpoint saved at epoch 25


Training:   9%|██▎                       | 26/300 [45:47<7:41:47, 101.12s/epoch]

Epoch 26/300, Loss: 0.5182
Checkpoint saved at epoch 26


Training:   9%|██▎                       | 27/300 [47:27<7:38:32, 100.78s/epoch]

Epoch 27/300, Loss: 0.5117
Checkpoint saved at epoch 27


Training:   9%|██▍                       | 28/300 [49:06<7:34:36, 100.28s/epoch]

Epoch 28/300, Loss: 0.5043
Checkpoint saved at epoch 28


Training:  10%|██▌                        | 29/300 [50:44<7:30:18, 99.70s/epoch]

Epoch 29/300, Loss: 0.4969
Checkpoint saved at epoch 29


Training:  10%|██▋                        | 30/300 [52:22<7:25:09, 98.92s/epoch]

Epoch 30/300, Loss: 0.4932
Checkpoint saved at epoch 30


Training:  10%|██▊                        | 31/300 [53:58<7:19:45, 98.09s/epoch]

Epoch 31/300, Loss: 0.4934
Checkpoint saved at epoch 31


Training:  11%|██▉                        | 32/300 [55:35<7:17:38, 97.98s/epoch]

Epoch 32/300, Loss: 0.4865
Checkpoint saved at epoch 32


Training:  11%|██▉                        | 33/300 [57:13<7:14:56, 97.74s/epoch]

Epoch 33/300, Loss: 0.4762
Checkpoint saved at epoch 33


Training:  11%|███                        | 34/300 [58:49<7:11:51, 97.41s/epoch]

Epoch 34/300, Loss: 0.4742
Checkpoint saved at epoch 34


Training:  12%|██▉                      | 35/300 [1:00:27<7:10:30, 97.48s/epoch]

Epoch 35/300, Loss: 0.4693
Checkpoint saved at epoch 35


Training:  12%|███                      | 36/300 [1:02:03<7:07:16, 97.11s/epoch]

Epoch 36/300, Loss: 0.4623
Checkpoint saved at epoch 36


Training:  12%|███                      | 37/300 [1:03:38<7:03:12, 96.55s/epoch]

Epoch 37/300, Loss: 0.4530
Checkpoint saved at epoch 37


Training:  13%|███▏                     | 38/300 [1:05:12<6:58:13, 95.78s/epoch]

Epoch 38/300, Loss: 0.4573
Checkpoint saved at epoch 38


Training:  13%|███▎                     | 39/300 [1:06:48<6:56:18, 95.70s/epoch]

Epoch 39/300, Loss: 0.4429
Checkpoint saved at epoch 39


Training:  13%|███▎                     | 40/300 [1:08:25<6:56:37, 96.14s/epoch]

Epoch 40/300, Loss: 0.4473
Checkpoint saved at epoch 40


Training:  14%|███▍                     | 41/300 [1:10:02<6:56:40, 96.53s/epoch]

Epoch 41/300, Loss: 0.4413
Checkpoint saved at epoch 41


Training:  14%|███▌                     | 42/300 [1:11:39<6:54:53, 96.49s/epoch]

Epoch 42/300, Loss: 0.4304
Checkpoint saved at epoch 42


Training:  14%|███▌                     | 43/300 [1:13:14<6:52:08, 96.22s/epoch]

Epoch 43/300, Loss: 0.4298
Checkpoint saved at epoch 43


Training:  15%|███▋                     | 44/300 [1:14:50<6:49:11, 95.90s/epoch]

Epoch 44/300, Loss: 0.4377
Checkpoint saved at epoch 44


Training:  15%|███▊                     | 45/300 [1:16:25<6:46:43, 95.70s/epoch]

Epoch 45/300, Loss: 0.4230
Checkpoint saved at epoch 45


Training:  15%|███▊                     | 46/300 [1:18:01<6:45:19, 95.75s/epoch]

Epoch 46/300, Loss: 0.4132
Checkpoint saved at epoch 46


Training:  16%|███▉                     | 47/300 [1:19:38<6:46:01, 96.29s/epoch]

Epoch 47/300, Loss: 0.4125
Checkpoint saved at epoch 47


Training:  16%|████                     | 48/300 [1:21:16<6:46:32, 96.79s/epoch]

Epoch 48/300, Loss: 0.4070
Checkpoint saved at epoch 48


Training:  16%|████                     | 49/300 [1:22:53<6:44:38, 96.73s/epoch]

Epoch 49/300, Loss: 0.3981
Checkpoint saved at epoch 49


Training:  17%|████▏                    | 50/300 [1:24:28<6:40:52, 96.21s/epoch]

Epoch 50/300, Loss: 0.3907
Checkpoint saved at epoch 50


Training:  17%|████▎                    | 51/300 [1:26:05<6:40:50, 96.59s/epoch]

Epoch 51/300, Loss: 0.3897
Checkpoint saved at epoch 51


Training:  17%|████▎                    | 52/300 [1:27:41<6:38:32, 96.42s/epoch]

Epoch 52/300, Loss: 0.3851
Checkpoint saved at epoch 52


Training:  18%|████▍                    | 53/300 [1:29:19<6:38:20, 96.76s/epoch]

Epoch 53/300, Loss: 0.3830
Checkpoint saved at epoch 53


Training:  18%|████▌                    | 54/300 [1:30:55<6:36:21, 96.67s/epoch]

Epoch 54/300, Loss: 0.3772
Checkpoint saved at epoch 54


Training:  18%|████▌                    | 55/300 [1:32:33<6:36:11, 97.03s/epoch]

Epoch 55/300, Loss: 0.3763
Checkpoint saved at epoch 55


Training:  19%|████▋                    | 56/300 [1:34:14<6:39:20, 98.20s/epoch]

Epoch 56/300, Loss: 0.3739
Checkpoint saved at epoch 56


Training:  19%|████▊                    | 57/300 [1:35:57<6:42:55, 99.49s/epoch]

Epoch 57/300, Loss: 0.3673
Checkpoint saved at epoch 57


Training:  19%|████▋                   | 58/300 [1:37:39<6:44:44, 100.35s/epoch]

Epoch 58/300, Loss: 0.3602
Checkpoint saved at epoch 58


Training:  20%|████▋                   | 59/300 [1:39:20<6:43:45, 100.52s/epoch]

Epoch 59/300, Loss: 0.3578
Checkpoint saved at epoch 59


Training:  20%|████▊                   | 60/300 [1:41:01<6:42:13, 100.56s/epoch]

Epoch 60/300, Loss: 0.3543
Checkpoint saved at epoch 60


Training:  20%|████▉                   | 61/300 [1:42:44<6:43:57, 101.41s/epoch]

Epoch 61/300, Loss: 0.3481
Checkpoint saved at epoch 61


Training:  21%|█████▏                   | 62/300 [1:44:20<6:36:21, 99.92s/epoch]

Epoch 62/300, Loss: 0.3476
Checkpoint saved at epoch 62


Training:  21%|█████▎                   | 63/300 [1:45:58<6:31:53, 99.21s/epoch]

Epoch 63/300, Loss: 0.3431
Checkpoint saved at epoch 63


Training:  21%|█████▎                   | 64/300 [1:47:33<6:25:34, 98.03s/epoch]

Epoch 64/300, Loss: 0.3384
Checkpoint saved at epoch 64


Training:  22%|█████▍                   | 65/300 [1:49:09<6:20:51, 97.24s/epoch]

Epoch 65/300, Loss: 0.3373
Checkpoint saved at epoch 65


Training:  22%|█████▌                   | 66/300 [1:50:43<6:16:04, 96.43s/epoch]

Epoch 66/300, Loss: 0.3297
Checkpoint saved at epoch 66


Training:  22%|█████▌                   | 67/300 [1:52:17<6:11:20, 95.63s/epoch]

Epoch 67/300, Loss: 0.3254
Checkpoint saved at epoch 67


Training:  23%|█████▋                   | 68/300 [1:53:51<6:08:14, 95.23s/epoch]

Epoch 68/300, Loss: 0.3229
Checkpoint saved at epoch 68


Training:  23%|█████▊                   | 69/300 [1:55:28<6:07:59, 95.58s/epoch]

Epoch 69/300, Loss: 0.3218
Checkpoint saved at epoch 69


Training:  23%|█████▊                   | 70/300 [1:57:04<6:07:27, 95.86s/epoch]

Epoch 70/300, Loss: 0.3223
Checkpoint saved at epoch 70


Training:  24%|█████▉                   | 71/300 [1:58:41<6:06:34, 96.05s/epoch]

Epoch 71/300, Loss: 0.3124
Checkpoint saved at epoch 71


Training:  24%|██████                   | 72/300 [2:00:17<6:05:12, 96.11s/epoch]

Epoch 72/300, Loss: 0.3127
Checkpoint saved at epoch 72


Training:  24%|██████                   | 73/300 [2:01:52<6:02:38, 95.85s/epoch]

Epoch 73/300, Loss: 0.3082
Checkpoint saved at epoch 73


Training:  25%|██████▏                  | 74/300 [2:03:26<5:58:34, 95.20s/epoch]

Epoch 74/300, Loss: 0.3069
Checkpoint saved at epoch 74


Training:  25%|██████▎                  | 75/300 [2:05:01<5:56:28, 95.06s/epoch]

Epoch 75/300, Loss: 0.3031
Checkpoint saved at epoch 75


Training:  25%|██████▎                  | 76/300 [2:06:36<5:55:14, 95.15s/epoch]

Epoch 76/300, Loss: 0.3005
Checkpoint saved at epoch 76


Training:  26%|██████▍                  | 77/300 [2:08:12<5:54:48, 95.46s/epoch]

Epoch 77/300, Loss: 0.2989
Checkpoint saved at epoch 77


Training:  26%|██████▌                  | 78/300 [2:09:49<5:54:34, 95.83s/epoch]

Epoch 78/300, Loss: 0.2951
Checkpoint saved at epoch 78


Training:  26%|██████▌                  | 79/300 [2:11:25<5:53:02, 95.85s/epoch]

Epoch 79/300, Loss: 0.2924
Checkpoint saved at epoch 79


Training:  27%|██████▋                  | 80/300 [2:12:59<5:50:02, 95.47s/epoch]

Epoch 80/300, Loss: 0.2904
Checkpoint saved at epoch 80


Training:  27%|██████▊                  | 81/300 [2:14:34<5:47:22, 95.17s/epoch]

Epoch 81/300, Loss: 0.2861
Checkpoint saved at epoch 81


Training:  27%|██████▊                  | 82/300 [2:16:07<5:43:32, 94.55s/epoch]

Epoch 82/300, Loss: 0.2865
Checkpoint saved at epoch 82


Training:  28%|██████▉                  | 83/300 [2:17:40<5:40:41, 94.20s/epoch]

Epoch 83/300, Loss: 0.2841
Checkpoint saved at epoch 83


Training:  28%|███████                  | 84/300 [2:19:14<5:38:33, 94.05s/epoch]

Epoch 84/300, Loss: 0.2793
Checkpoint saved at epoch 84


Training:  28%|███████                  | 85/300 [2:20:50<5:39:29, 94.74s/epoch]

Epoch 85/300, Loss: 0.2761
Checkpoint saved at epoch 85


Training:  29%|███████▏                 | 86/300 [2:22:28<5:40:52, 95.57s/epoch]

Epoch 86/300, Loss: 0.2735
Checkpoint saved at epoch 86


Training:  29%|███████▏                 | 87/300 [2:24:04<5:39:37, 95.67s/epoch]

Epoch 87/300, Loss: 0.2732
Checkpoint saved at epoch 87


Training:  29%|███████▎                 | 88/300 [2:25:40<5:38:28, 95.80s/epoch]

Epoch 88/300, Loss: 0.2747
Checkpoint saved at epoch 88


Training:  30%|███████▍                 | 89/300 [2:27:15<5:35:58, 95.54s/epoch]

Epoch 89/300, Loss: 0.2705
Checkpoint saved at epoch 89


Training:  30%|███████▌                 | 90/300 [2:28:48<5:32:33, 95.02s/epoch]

Epoch 90/300, Loss: 0.2656
Checkpoint saved at epoch 90


Training:  30%|███████▌                 | 91/300 [2:30:22<5:29:08, 94.49s/epoch]

Epoch 91/300, Loss: 0.2650
Checkpoint saved at epoch 91


Training:  31%|███████▋                 | 92/300 [2:31:55<5:25:59, 94.03s/epoch]

Epoch 92/300, Loss: 0.2645
Checkpoint saved at epoch 92


Training:  31%|███████▊                 | 93/300 [2:33:28<5:23:56, 93.90s/epoch]

Epoch 93/300, Loss: 0.2618
Checkpoint saved at epoch 93


Training:  31%|███████▊                 | 94/300 [2:35:02<5:22:32, 93.94s/epoch]

Epoch 94/300, Loss: 0.2597
Checkpoint saved at epoch 94


Training:  32%|███████▉                 | 95/300 [2:36:38<5:22:45, 94.47s/epoch]

Epoch 95/300, Loss: 0.2586
Checkpoint saved at epoch 95


Training:  32%|████████                 | 96/300 [2:38:15<5:24:11, 95.35s/epoch]

Epoch 96/300, Loss: 0.2543
Checkpoint saved at epoch 96


Training:  32%|████████                 | 97/300 [2:39:50<5:22:08, 95.22s/epoch]

Epoch 97/300, Loss: 0.2551
Checkpoint saved at epoch 97


Training:  33%|████████▏                | 98/300 [2:41:24<5:19:10, 94.80s/epoch]

Epoch 98/300, Loss: 0.2509
Checkpoint saved at epoch 98


Training:  33%|████████▎                | 99/300 [2:42:58<5:16:17, 94.42s/epoch]

Epoch 99/300, Loss: 0.2484
Checkpoint saved at epoch 99


Training:  33%|████████                | 100/300 [2:44:31<5:13:59, 94.20s/epoch]

Epoch 100/300, Loss: 0.2481
Checkpoint saved at epoch 100


Training:  34%|████████                | 101/300 [2:46:06<5:13:18, 94.47s/epoch]

Epoch 101/300, Loss: 0.2478
Checkpoint saved at epoch 101


Training:  34%|████████▏               | 102/300 [2:47:42<5:12:28, 94.69s/epoch]

Epoch 102/300, Loss: 0.2441
Checkpoint saved at epoch 102


Training:  34%|████████▏               | 103/300 [2:49:17<5:11:48, 94.97s/epoch]

Epoch 103/300, Loss: 0.2455
Checkpoint saved at epoch 103


Training:  35%|████████▎               | 104/300 [2:50:53<5:10:54, 95.17s/epoch]

Epoch 104/300, Loss: 0.2418
Checkpoint saved at epoch 104


Training:  35%|████████▍               | 105/300 [2:52:28<5:08:42, 94.99s/epoch]

Epoch 105/300, Loss: 0.2424
Checkpoint saved at epoch 105


Training:  35%|████████▍               | 106/300 [2:54:01<5:05:58, 94.63s/epoch]

Epoch 106/300, Loss: 0.2393
Checkpoint saved at epoch 106


Training:  36%|████████▌               | 107/300 [2:55:34<5:02:40, 94.10s/epoch]

Epoch 107/300, Loss: 0.2357
Checkpoint saved at epoch 107


Training:  36%|████████▋               | 108/300 [2:57:08<5:00:56, 94.04s/epoch]

Epoch 108/300, Loss: 0.2345
Checkpoint saved at epoch 108


Training:  36%|████████▋               | 109/300 [2:58:41<4:58:21, 93.73s/epoch]

Epoch 109/300, Loss: 0.2337
Checkpoint saved at epoch 109


Training:  37%|████████▊               | 110/300 [3:00:14<4:56:05, 93.50s/epoch]

Epoch 110/300, Loss: 0.2312
Checkpoint saved at epoch 110


Training:  37%|████████▉               | 111/300 [3:01:48<4:55:07, 93.69s/epoch]

Epoch 111/300, Loss: 0.2316
Checkpoint saved at epoch 111


Training:  37%|████████▉               | 112/300 [3:03:23<4:54:46, 94.08s/epoch]

Epoch 112/300, Loss: 0.2306
Checkpoint saved at epoch 112


Training:  38%|█████████               | 113/300 [3:04:58<4:53:51, 94.29s/epoch]

Epoch 113/300, Loss: 0.2284
Checkpoint saved at epoch 113


Training:  38%|█████████               | 114/300 [3:06:32<4:52:20, 94.30s/epoch]

Epoch 114/300, Loss: 0.2261
Checkpoint saved at epoch 114


Training:  38%|█████████▏              | 115/300 [3:08:06<4:50:23, 94.18s/epoch]

Epoch 115/300, Loss: 0.2260
Checkpoint saved at epoch 115


Training:  39%|█████████▎              | 116/300 [3:09:40<4:48:26, 94.05s/epoch]

Epoch 116/300, Loss: 0.2228
Checkpoint saved at epoch 116


Training:  39%|█████████▎              | 117/300 [3:11:13<4:45:52, 93.73s/epoch]

Epoch 117/300, Loss: 0.2225
Checkpoint saved at epoch 117


Training:  39%|█████████▍              | 118/300 [3:12:48<4:45:20, 94.07s/epoch]

Epoch 118/300, Loss: 0.2208
Checkpoint saved at epoch 118


Training:  40%|█████████▌              | 119/300 [3:14:27<4:48:17, 95.56s/epoch]

Epoch 119/300, Loss: 0.2193
Checkpoint saved at epoch 119


Training:  40%|█████████▌              | 119/300 [3:15:58<4:58:04, 98.81s/epoch]


KeyboardInterrupt: 

In [None]:
# if __name__ == "__main__":
#     sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
#     num_epochs = 300  # Adjust the number of epochs as needed
#     test_size = 0.2   # Proportion of the dataset to include in the test split
#     learning_rate = 0.001  # Learning rate

#     # Run the training function
#     train_mse_list, train_ci_list, train_pearson_list, test_mse_list, test_ci_list, test_pearson_list = train_and_evaluate(
#         sample_dir,
#         num_epochs=num_epochs,
#         test_size=test_size,
#         lr=learning_rate
#     )

#     # Optionally, plot the metrics over epochs
#     # For example, using matplotlib
#     import matplotlib.pyplot as plt

#     epochs = range(1, len(train_mse_list) + 1)  # derive from recorded metrics so x/y lengths match even if training stops early

#     plt.figure(figsize=(12, 4))

#     # Plot MSE
#     plt.subplot(1, 3, 1)
#     plt.plot(epochs, train_mse_list, label='Train MSE')
#     plt.plot(epochs, test_mse_list, label='Test MSE')
#     plt.xlabel('Epoch')
#     plt.ylabel('MSE')
#     plt.title('MSE over Epochs')
#     plt.legend()

#     # Plot CI
#     plt.subplot(1, 3, 2)
#     plt.plot(epochs, train_ci_list, label='Train CI')
#     plt.plot(epochs, test_ci_list, label='Test CI')
#     plt.xlabel('Epoch')
#     plt.ylabel('Concordance Index (CI)')
#     plt.title('CI over Epochs')
#     plt.legend()

#     # Plot Pearson Correlation
#     plt.subplot(1, 3, 3)
#     plt.plot(epochs, train_pearson_list, label='Train Pearson')
#     plt.plot(epochs, test_pearson_list, label='Test Pearson')
#     plt.xlabel('Epoch')
#     plt.ylabel('Pearson Correlation')
#     plt.title('Pearson Correlation over Epochs')
#     plt.legend()

#     plt.tight_layout()
#     plt.show()
