In [1]:
# !unzip /content/full\ graphs.zip


In [2]:
# !pip install torch_geometric
# import os
# import torch
# import pickle
# import pandas as pd

# def load_graph(path, is_pickle=True):
#     """
#     Load a molecule graph (.pkl) or a protein graph (.pt).
#     If is_pickle is True, use pickle to load the file; otherwise, use torch.load.
#     """
#     if is_pickle:
#         with open(path, 'rb') as f:
#             return pickle.load(f)
#     else:
#         return torch.load(path)

# def prepare_dataset_individual_save_as_pt(filtered_dataset, molecule_graph_dir, protein_graph_dir, output_dir):
#     """
#     Incrementally prepares the dataset and saves each (molecule, protein, target) tuple as a separate .pt file.

#     Args:
#     - filtered_dataset: The filtered KIBA dataset (DataFrame).
#     - molecule_graph_dir: Directory where molecule graphs are stored.
#     - protein_graph_dir: Directory where protein graphs are stored.
#     - output_dir: Directory to save the prepared dataset incrementally.
#     """
#     if not os.path.exists(output_dir):
#         os.makedirs(output_dir)

#     for index, row in filtered_dataset.iterrows():
#         protein_id = row['Target_ID']
#         chembl_id = row['Drug_ID']

#         # Load the protein graph (.pt)
#         pro_graph_path = os.path.join(protein_graph_dir, f"{protein_id}_graph.pt")
#         if not os.path.exists(pro_graph_path):
#             print(f"Protein graph not found: {protein_id}")
#             continue
#         pro_graph = load_graph(pro_graph_path, is_pickle=False)

#         # Load the molecule graph (.pkl)
#         mol_graph_path = os.path.join(molecule_graph_dir, f"{chembl_id}_graph.pkl")
#         if not os.path.exists(mol_graph_path):
#             print(f"Molecule graph not found: {chembl_id}")
#             continue
#         mol_graph = load_graph(mol_graph_path)

#         # Load target (affinity value)
#         target = torch.tensor([row['Y']], dtype=torch.float)

#         # Create the sample as a tuple (molecule graph, protein graph, target)
#         sample = (mol_graph, pro_graph, target)

#         # Save the sample as a .pt file
#         sample_path = os.path.join(output_dir, f"sample_{index}.pt")
#         torch.save(sample, sample_path)

#         print(f"Saved sample {index} as {sample_path}")

# # Example usage for individual saving
# molecule_graph_dir = '/content/molecule_graphs'  # Directory where molecule graphs are stored
# protein_graph_dir = '/content/ProteinGraphs'  # Directory where protein graphs are stored
# filtered_dataset_path = '/content/filtered_KibaDataSet.csv'  # Path to the filtered dataset CSV
# output_dir = '/content/prepared_samples/'  # Directory to save individual samples

# # Load filtered dataset CSV
# filtered_dataset = pd.read_csv(filtered_dataset_path)

# # Prepare the dataset incrementally, saving each sample as a .pt file
# prepare_dataset_individual_save_as_pt(filtered_dataset, molecule_graph_dir, protein_graph_dir, output_dir)

# print("Dataset preparation completed.")


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj


# GCN based model
class GNNNet(torch.nn.Module):
    def __init__(self, n_output=1, num_features_pro=54, num_features_mol=78, output_dim=128, dropout=0.2):
        super(GNNNet, self).__init__()

        print('GNNNet Loaded')
        self.n_output = n_output
        self.mol_conv1 = GCNConv(num_features_mol, num_features_mol)
        self.mol_conv2 = GCNConv(num_features_mol, num_features_mol * 2)
        self.mol_conv3 = GCNConv(num_features_mol * 2, num_features_mol * 4)
        self.mol_fc_g1 = torch.nn.Linear(num_features_mol * 4, 1024)
        self.mol_fc_g2 = torch.nn.Linear(1024, output_dim)

        # self.pro_conv1 = GCNConv(embed_dim, embed_dim)
        self.pro_conv1 = GCNConv(num_features_pro, num_features_pro)
        self.pro_conv2 = GCNConv(num_features_pro, num_features_pro * 2)
        self.pro_conv3 = GCNConv(num_features_pro * 2, num_features_pro * 4)
        # self.pro_conv4 = GCNConv(embed_dim * 4, embed_dim * 8)
        self.pro_fc_g1 = torch.nn.Linear(num_features_pro * 4, 1024)
        self.pro_fc_g2 = torch.nn.Linear(1024, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # combined layers
        self.fc1 = nn.Linear(2 * output_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, self.n_output)

    def forward(self, data_mol, data_pro):
        # get graph input
        mol_x, mol_edge_index, mol_batch = data_mol.x, data_mol.edge_index, data_mol.batch
        # get protein input
        target_x, target_edge_index, target_batch = data_pro.x, data_pro.edge_index, data_pro.batch

        # target_seq=data_pro.target

        # print('size')
        # print('mol_x', mol_x.size(), 'edge_index', mol_edge_index.size(), 'batch', mol_batch.size())
        # print('target_x', target_x.size(), 'target_edge_index', target_batch.size(), 'batch', target_batch.size())

        x = self.mol_conv1(mol_x, mol_edge_index)
        x = self.relu(x)

        # mol_edge_index, _ = dropout_adj(mol_edge_index, training=self.training)
        x = self.mol_conv2(x, mol_edge_index)
        x = self.relu(x)

        # mol_edge_index, _ = dropout_adj(mol_edge_index, training=self.training)
        x = self.mol_conv3(x, mol_edge_index)
        x = self.relu(x)
        x = gep(x, mol_batch)  # global pooling

        # flatten
        x = self.relu(self.mol_fc_g1(x))
        x = self.dropout(x)
        x = self.mol_fc_g2(x)
        x = self.dropout(x)

        xt = self.pro_conv1(target_x, target_edge_index)
        xt = self.relu(xt)

        # target_edge_index, _ = dropout_adj(target_edge_index, training=self.training)
        xt = self.pro_conv2(xt, target_edge_index)
        xt = self.relu(xt)

        # target_edge_index, _ = dropout_adj(target_edge_index, training=self.training)
        xt = self.pro_conv3(xt, target_edge_index)
        xt = self.relu(xt)

        # xt = self.pro_conv4(xt, target_edge_index)
        # xt = self.relu(xt)
        xt = gep(xt, target_batch)  # global pooling

        # flatten
        xt = self.relu(self.pro_fc_g1(xt))
        xt = self.dropout(xt)
        xt = self.pro_fc_g2(xt)
        xt = self.dropout(xt)

        # print(x.size(), xt.size())
        # concat
        xc = torch.cat((x, xt), 1)
        # add some dense layers
        xc = self.fc1(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        xc = self.fc2(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        out = self.out(xc)
        return out


In [9]:
# ... [Your existing imports and GNNNet class definition remain unchanged] ...

import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools

# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)

def adjust_state_dict(state_dict, model_state_dict):
    new_state_dict = {}
    for k in state_dict.keys():
        new_k = k
        # Adjust mol_conv layers
        if 'mol_conv' in k:
            if 'weight' in k and 'lin.weight' not in k:
                new_k = k.replace('weight', 'lin.weight')
            if 'bias' in k and 'lin.bias' not in k:
                new_k = k.replace('bias', 'lin.bias')
        # Adjust pro_conv layers
        elif 'pro_conv' in k:
            if 'weight' in k and 'lin.weight' not in k:
                new_k = k.replace('weight', 'lin.weight')
            if 'bias' in k and 'lin.bias' not in k:
                new_k = k.replace('bias', 'lin.bias')

        param = state_dict[k]
        if new_k in model_state_dict:
            model_param = model_state_dict[new_k]
            if param.shape != model_param.shape:
                if len(param.shape) == 2:
                    # Transpose the weight matrix
                    param = param.t()
                    if param.shape != model_param.shape:
                        print(f"Shape mismatch after transpose for {new_k}: checkpoint {param.shape}, model {model_param.shape}")
                else:
                    print(f"Shape mismatch for {new_k}: checkpoint {param.shape}, model {model_param.shape}")
        new_state_dict[new_k] = param
    return new_state_dict

# Define the load_sample function
def load_sample(path):
    # Load individual sample from file
    sample = torch.load(path)
    mol_data = sample[0]
    pro_data = sample[1]
    target = sample[2]

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'x' attribute is set
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data does not have 'x' or 'features' attribute")

    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data does not have 'x' or 'features' attribute")

    # Ensure 'x' is a float tensor
    if not isinstance(mol_data.x, torch.Tensor):
        mol_data.x = torch.tensor(mol_data.x)
    if not isinstance(pro_data.x, torch.Tensor):
        pro_data.x = torch.tensor(pro_data.x)

    if mol_data.x.dtype != torch.float:
        mol_data.x = mol_data.x.float()
    if pro_data.x.dtype != torch.float:
        pro_data.x = pro_data.x.float()

    # Adjust 'edge_index' for mol_data
    # Ensure 'edge_index' is a tensor of type torch.long
    if not isinstance(mol_data.edge_index, torch.Tensor):
        mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
    else:
        mol_data.edge_index = mol_data.edge_index.long()

    # Ensure 'edge_index' has shape [2, num_edges]
    if mol_data.edge_index.shape[0] != 2:
        mol_data.edge_index = mol_data.edge_index.t()

    # Adjust 'edge_index' for pro_data
    if not isinstance(pro_data.edge_index, torch.Tensor):
        pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
    else:
        pro_data.edge_index = pro_data.edge_index.long()

    if pro_data.edge_index.shape[0] != 2:
        pro_data.edge_index = pro_data.edge_index.t()

    # Set 'num_nodes' attribute to suppress warnings
    mol_data.num_nodes = mol_data.x.size(0)
    pro_data.num_nodes = pro_data.x.size(0)

    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    batch = []
    for idx, file_name in enumerate(file_list):
        sample_path = os.path.join(sample_dir, file_name)
        sample = load_sample(sample_path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.
    """
    pairs = itertools.combinations(range(len(y_true)), 2)
    c = 0
    s = 0
    for i, j in pairs:
        if y_true[i] != y_true[j]:
            s += 1
            if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
               (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
                c += 1
            elif y_pred[i] == y_pred[j]:
                c += 0.5
    return c / s if s != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true.flatten(), y_pred.flatten())[0]



def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Create a single directory for all checkpoints
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    results = []
    loss_fn = MSELoss()

    for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
        fold_number = fold + 1
        print(f'\nFold {fold_number}/{n_splits}')
        train_files = [sample_files[i] for i in train_idx]
        test_files = [sample_files[i] for i in test_idx]

        # Determine input feature dimensions from your data
        sample = load_sample(os.path.join(sample_dir, train_files[0]))
        mol_data = sample[0]
        pro_data = sample[1]

        num_features_mol = mol_data.x.size(1)
        num_features_pro = pro_data.x.size(1)

        # Initialize the GNN model with correct input dimensions
        model = GNNNet(
            num_features_mol=num_features_mol,
            num_features_pro=num_features_pro
        ).to(device)
        print(f"Model is on device: {next(model.parameters()).device}")

        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Initialize starting epoch
        start_epoch = 1

        # Check for existing checkpoints in TrainingModel directory for the current fold
        existing_checkpoints = [f for f in os.listdir(training_model_dir)
                                if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

        if existing_checkpoints:
            # Find the latest checkpoint based on epoch number
            latest_checkpoint = max(existing_checkpoints,
                                    key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
            checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
            print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            loaded_epoch = checkpoint['epoch']
            start_epoch = loaded_epoch + 1
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"No checkpoint found for fold {fold_number}, loading pretrained model from 'model_GNNNet_kiba.model'.")

            # Load the pretrained model weights and adjust the keys
            state_dict = torch.load('model_GNNNet_kiba.model', map_location=device)
            adjusted_state_dict = adjust_state_dict(state_dict, model.state_dict())
            model.load_state_dict(adjusted_state_dict, strict=False)

            # The optimizer remains as initialized

        # Training loop with progress bar over epochs
        for epoch in tqdm(range(start_epoch, num_epochs + 1),
                          desc=f"Training Fold {fold_number}", unit="epoch"):
            model.train()
            running_loss = 0.0

            # Prepare batch loader without progress bar for batches
            batch_size = 256 # Adjust batch size as needed
            batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

            for batch_samples in batch_loader_iter:
                mol_data_list = []
                pro_data_list = []
                target_list = []

                for sample in batch_samples:
                    mol_data = sample[0]
                    pro_data = sample[1]
                    target = sample[2]

                    mol_data_list.append(mol_data)
                    pro_data_list.append(pro_data)
                    target_list.append(target)

                mol_batch = Batch.from_data_list(mol_data_list).to(device)
                pro_batch = Batch.from_data_list(pro_data_list).to(device)
                target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                optimizer.zero_grad()
                output = model(mol_batch, pro_batch)
                loss = loss_fn(output.view(-1), target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * len(batch_samples)

            avg_loss = running_loss / len(train_files)
            # Use tqdm.write() to print without interfering with the progress bar
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Save the model and optimizer states after each epoch
            checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
            checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

            # Evaluation on the test set after each epoch
            model.eval()
            total_preds, total_labels = [], []
            with torch.no_grad():
                batch_size = 256  # Adjust batch size as needed
                batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

                for batch_samples in batch_loader_iter:
                    mol_data_list = []
                    pro_data_list = []
                    target_list = []

                    for sample in batch_samples:
                        mol_data = sample[0]
                        pro_data = sample[1]
                        target = sample[2]

                        mol_data_list.append(mol_data)
                        pro_data_list.append(pro_data)
                        target_list.append(target)

                    mol_batch = Batch.from_data_list(mol_data_list).to(device)
                    pro_batch = Batch.from_data_list(pro_data_list).to(device)
                    target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                    output = model(mol_batch, pro_batch)
                    total_preds.append(output.cpu().numpy())
                    total_labels.append(target.cpu().numpy())

            # Convert lists to numpy arrays for evaluation
            total_preds = np.concatenate(total_preds)
            total_labels = np.concatenate(total_labels)

            # Calculate metrics
            mse = get_mse(total_labels, total_preds)
            ci = get_ci(total_labels, total_preds)
            pearson = get_pearson(total_labels, total_preds)

            # Print metrics
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

        # Evaluation at the end of training for this fold
        print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
        # Store results for this fold
        results.append((mse, ci, pearson))

    return results


if __name__ == "__main__":
    sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
    num_epochs = 5  # Adjust the number of epochs as needed
    n_splits = 5  # Number of folds for cross-validation
    learning_rate = 0.001  # Learning rate

    # Run the training function
    results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

    # Print overall results
    print("\nCross-validation Results:")
    for fold_idx, (mse, ci, pearson) in enumerate(results):
        print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

    # Optionally, compute and print average metrics across folds
    mse_values, ci_values, pearson_values = zip(*results)
    print(f"\nAverage Results:")
    print(f"MSE: {np.mean(mse_values):.4f}")
    print(f"CI: {np.mean(ci_values):.4f}")
    print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


Running on cpu.
Using existing TrainingModel directory at prepared_samples/TrainingModel

Fold 1/5
GNNNet Loaded
Model is on device: cpu
No checkpoint found for fold 1, loading pretrained model from 'model_GNNNet_kiba.model'.


Training Fold 1:   0%|          | 0/5 [15:42<?, ?epoch/s]

Fold 1, Epoch 1/5, Loss: 10.3111
Checkpoint saved for fold 1 at epoch 1


Training Fold 1:  20%|██        | 1/5 [23:32<1:34:09, 1412.46s/epoch]

Fold 1, Epoch 1/5 - MSE: 0.6774, CI: 0.4904, Pearson: -0.0099


Training Fold 1:  20%|██        | 1/5 [40:22<1:34:09, 1412.46s/epoch]

Fold 1, Epoch 2/5, Loss: 0.7066
Checkpoint saved for fold 1 at epoch 2


Training Fold 1:  40%|████      | 2/5 [48:11<1:12:34, 1451.43s/epoch]

Fold 1, Epoch 2/5 - MSE: 0.6781, CI: 0.5134, Pearson: 0.0454


Training Fold 1:  40%|████      | 2/5 [1:04:27<1:12:34, 1451.43s/epoch]

Fold 1, Epoch 3/5, Loss: 0.6967
Checkpoint saved for fold 1 at epoch 3


Training Fold 1:  60%|██████    | 3/5 [1:12:24<48:24, 1452.43s/epoch]  

Fold 1, Epoch 3/5 - MSE: 0.6842, CI: 0.5306, Pearson: 0.0858


Training Fold 1:  60%|██████    | 3/5 [1:29:43<48:24, 1452.43s/epoch]

Fold 1, Epoch 4/5, Loss: 0.6910
Checkpoint saved for fold 1 at epoch 4


Training Fold 1:  80%|████████  | 4/5 [1:37:16<24:27, 1467.79s/epoch]

Fold 1, Epoch 4/5 - MSE: 0.6788, CI: 0.5495, Pearson: 0.1465


Training Fold 1:  80%|████████  | 4/5 [1:48:29<27:07, 1627.32s/epoch]


KeyboardInterrupt: 

In [18]:
import os
import torch
import torch.nn as nn
import numpy as np
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr

# Define metric functions
def get_mse(y_true, y_pred):
    return np.mean((y_pred - y_true) ** 2)

def get_ci(y_true, y_pred):
    n = 0
    h_sum = 0.0
    for i in range(len(y_true)):
        for j in range(len(y_true)):
            if y_true[i] > y_true[j]:
                n += 1
                if y_pred[i] > y_pred[j]:
                    h_sum += 1
                elif y_pred[i] == y_pred[j]:
                    h_sum += 0.5
    return h_sum / n if n != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true, y_pred)[0]

def get_spearman(y_true, y_pred):
    return spearmanr(y_true, y_pred)[0]

def get_rmse(y_true, y_pred):
    return np.sqrt(get_mse(y_true, y_pred))

def get_rm2(y_true, y_pred):
    y_mean = np.mean(y_true)
    ss_tot = np.sum((y_true - y_mean) ** 2)
    ss_res = np.sum((y_true - y_pred) ** 2)
    r_squared = 1 - (ss_res / ss_tot)
    return r_squared

import torch

def adjust_state_dict(state_dict):
    new_state_dict = {}
    for k in state_dict.keys():
        new_k = k
        # Adjust mol_conv layers
        if 'mol_conv' in k:
            if 'lin.weight' in k:
                new_k = k.replace('lin.weight', 'weight')
            elif 'lin.bias' in k:
                new_k = k.replace('lin.bias', 'bias')
        # Adjust pro_conv layers
        elif 'pro_conv' in k:
            if 'lin.weight' in k:
                new_k = k.replace('lin.weight', 'weight')
            elif 'lin.bias' in k:
                new_k = k.replace('lin.bias', 'bias')
        new_state_dict[new_k] = state_dict[k]
    return new_state_dict




# Define the load_sample function
def load_sample(path):
    data = torch.load(path)
    if isinstance(data, tuple) and len(data) == 3:
        mol_data, pro_data, target = data
    elif isinstance(data, dict):
        mol_data = data.get('mol_data')
        pro_data = data.get('pro_data')
        target = data.get('target')
    else:
        print(f"Unexpected data format in file {path}")
        return None

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'batch' attribute is set
    if not hasattr(mol_data, 'batch') or mol_data.batch is None:
        mol_data.batch = torch.zeros(mol_data.num_nodes, dtype=torch.long)
    if not hasattr(pro_data, 'batch') or pro_data.batch is None:
        pro_data.batch = torch.zeros(pro_data.num_nodes, dtype=torch.long)

    return (mol_data, pro_data, target)

# Define the predicting function
def predicting(model, device, samples):
    model.eval()
    total_preds = []
    total_labels = []
    print('Making predictions for {} samples...'.format(len(samples)))
    with torch.no_grad():
        batch_size = 256  # Adjust as necessary
        for i in tqdm(range(0, len(samples), batch_size), desc='Predicting'):
            batch_samples = samples[i:i+batch_size]
            mol_data_list = []
            pro_data_list = []
            target_list = []
            for sample in batch_samples:
                if sample is None:
                    continue
                mol_data, pro_data, target = sample
                mol_data_list.append(mol_data)
                pro_data_list.append(pro_data)
                target_list.append(target)
            if len(mol_data_list) == 0:
                continue
            mol_batch = Batch.from_data_list(mol_data_list).to(device)
            pro_batch = Batch.from_data_list(pro_data_list).to(device)
            targets = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)
            outputs = model(mol_batch, pro_batch)
            total_preds.append(outputs.cpu().numpy())
            total_labels.append(targets.cpu().numpy())
    if len(total_preds) == 0:
        print("No predictions were made.")
        return None, None
    total_preds = np.concatenate(total_preds)
    total_labels = np.concatenate(total_labels)
    return total_labels, total_preds

# Define the main function
def main():
    # Set the sample directory where your data is stored
    sample_dir = 'prepared_samples'  # Adjust the path to your available samples
    model_path = 'model_GNNNet_kiba.model'  # Path to your pretrained model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    
    # Load the model
    # Determine input feature dimensions from your data
    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    if not sample_files:
        print(f"No sample files found in {sample_dir}")
        return
    sample_file_path = os.path.join(sample_dir, sample_files[0])
    sample = load_sample(sample_file_path)
    if sample is None:
        print("Failed to load sample.")
        return

    mol_data, pro_data, _ = sample

    num_features_mol = mol_data.x.size(1)
    num_features_pro = pro_data.x.size(1)
    
    model = GNNNet(num_features_mol=num_features_mol, num_features_pro=num_features_pro).to(device)
    
    
    # Load the pretrained model weights and adjust the keys
    state_dict = torch.load(model_path, map_location=device)
    adjusted_state_dict = adjust_state_dict(state_dict, model.state_dict())
    model.load_state_dict(adjusted_state_dict, strict=False)
    
    # Load the data
    samples = []
    for sample_file in sample_files:
        sample_path = os.path.join(sample_dir, sample_file)
        sample = load_sample(sample_path)
        if sample is not None:
            samples.append(sample)
    if not samples:
        print(f"No valid samples found in {sample_dir}")
        return
    
    # Run predictions
    Y, P = predicting(model, device, samples)
    if Y is None or P is None:
        print("No predictions were made.")
        return
    
    # Calculate metrics
    mse = get_mse(Y, P)
    ci = get_ci(Y, P)
    pearson = get_pearson(Y, P)
    spearman = get_spearman(Y, P)
    rmse = get_rmse(Y, P)
    rm2 = get_rm2(Y, P)
    
    # Print metrics
    print('Evaluation Metrics:')
    print(f'MSE: {mse:.4f}')
    print(f'RMSE: {rmse:.4f}')
    print(f'Pearson Correlation: {pearson:.4f}')
    print(f'Spearman Correlation: {spearman:.4f}')
    print(f'CI: {ci:.4f}')
    print(f'R-squared (RM2): {rm2:.4f}')
    
    # Optionally, plot the results
    # plot_density(Y, P)

if __name__ == '__main__':
    main()


Using device: cpu


TypeError: zeros() received an invalid combination of arguments - got (NoneType, dtype=torch.dtype), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [5]:
# import os
# import torch
# import torch.optim as optim
# from torch.nn import MSELoss
# from sklearn.model_selection import KFold
# import numpy as np
# from tqdm import tqdm
# from torch_geometric.data import Data, Batch
# from torch_geometric.nn import GCNConv, global_mean_pool as gep
# from scipy.stats import pearsonr
# import warnings
# import itertools
# from collections import OrderedDict

# # Suppress FutureWarning related to torch.load
# warnings.filterwarnings('ignore', category=FutureWarning)

# # Define the load_sample function
# def load_sample(path):
#     # Load individual sample from file
#     sample = torch.load(path)
#     mol_data = sample[0]
#     pro_data = sample[1]
#     target = sample[2]

#     # Convert dictionaries to Data objects if necessary
#     if isinstance(mol_data, dict):
#         mol_data = Data(**mol_data)
#     if isinstance(pro_data, dict):
#         pro_data = Data(**pro_data)

#     # Ensure that 'x' attribute is set
#     if not hasattr(mol_data, 'x') or mol_data.x is None:
#         if hasattr(mol_data, 'features'):
#             mol_data.x = mol_data.features
#             del mol_data.features
#         else:
#             raise ValueError("mol_data does not have 'x' or 'features' attribute")

#     if not hasattr(pro_data, 'x') or pro_data.x is None:
#         if hasattr(pro_data, 'features'):
#             pro_data.x = pro_data.features
#             del pro_data.features
#         else:
#             raise ValueError("pro_data does not have 'x' or 'features' attribute")

#     # Ensure 'x' is a float tensor
#     if not isinstance(mol_data.x, torch.Tensor):
#         mol_data.x = torch.tensor(mol_data.x)
#     if not isinstance(pro_data.x, torch.Tensor):
#         pro_data.x = torch.tensor(pro_data.x)

#     if mol_data.x.dtype != torch.float:
#         mol_data.x = mol_data.x.float()
#     if pro_data.x.dtype != torch.float:
#         pro_data.x = pro_data.x.float()

#     # Adjust 'edge_index' for mol_data
#     if not isinstance(mol_data.edge_index, torch.Tensor):
#         mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
#     else:
#         mol_data.edge_index = mol_data.edge_index.long()

#     if mol_data.edge_index.shape[0] != 2:
#         mol_data.edge_index = mol_data.edge_index.t()

#     # Adjust 'edge_index' for pro_data
#     if not isinstance(pro_data.edge_index, torch.Tensor):
#         pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
#     else:
#         pro_data.edge_index = pro_data.edge_index.long()

#     if pro_data.edge_index.shape[0] != 2:
#         pro_data.edge_index = pro_data.edge_index.t()

#     # Set 'num_nodes' attribute to suppress warnings
#     mol_data.num_nodes = mol_data.x.size(0)
#     pro_data.num_nodes = pro_data.x.size(0)

#     return (mol_data, pro_data, target)

# # Define the batch_loader function
# def batch_loader(file_list, sample_dir, batch_size):
#     batch = []
#     for idx, file_name in enumerate(file_list):
#         sample_path = os.path.join(sample_dir, file_name)
#         sample = load_sample(sample_path)
#         batch.append(sample)
#         if len(batch) == batch_size:
#             yield batch
#             batch = []
#     if len(batch) > 0:
#         yield batch

# # Define the evaluation metrics functions
# def get_mse(y_true, y_pred):
#     return np.mean((y_true - y_pred) ** 2)

# def get_ci(y_true, y_pred):
#     """
#     Compute the concordance index between true and predicted values.
#     """
#     pairs = itertools.combinations(range(len(y_true)), 2)
#     c = 0
#     s = 0
#     for i, j in pairs:
#         if y_true[i] != y_true[j]:
#             s += 1
#             if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
#                (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
#                 c += 1
#             elif y_pred[i] == y_pred[j]:
#                 c += 0.5
#     return c / s if s != 0 else 0

# def get_pearson(y_true, y_pred):
#     return pearsonr(y_true.flatten(), y_pred.flatten())[0]

# def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Running on {device}.")

#     sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
#     kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

#     # Create a single directory for all checkpoints
#     training_model_dir = os.path.join(sample_dir, 'TrainingModel')
#     if not os.path.exists(training_model_dir):
#         os.makedirs(training_model_dir)
#         print(f"Created directory for checkpoints at {training_model_dir}")
#     else:
#         print(f"Using existing TrainingModel directory at {training_model_dir}")

#     results = []
#     loss_fn = MSELoss()

#     for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
#         fold_number = fold + 1
#         print(f'\nFold {fold_number}/{n_splits}')
#         train_files = [sample_files[i] for i in train_idx]
#         test_files = [sample_files[i] for i in test_idx]

#         # Determine input feature dimensions from your data
#         sample = load_sample(os.path.join(sample_dir, train_files[0]))
#         mol_data = sample[0]
#         pro_data = sample[1]

#         num_features_mol = 78
#         num_features_pro = 54

#         # Initialize the GNN model with correct input dimensions
#         model = GNNNet(
#             num_features_mol=num_features_mol,
#             num_features_pro=num_features_pro
#         ).to(device)
#         print(f"Model is on device: {next(model.parameters()).device}")

#         # Initialize starting epoch
#         start_epoch = 1

#         # Check for existing checkpoints in TrainingModel directory for the current fold
#         existing_checkpoints = [f for f in os.listdir(training_model_dir)
#                                 if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

#         if existing_checkpoints:
#             # Find the latest checkpoint based on epoch number
#             latest_checkpoint = max(existing_checkpoints,
#                                     key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
#             checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
#             print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
#             checkpoint = torch.load(checkpoint_path, map_location=device)

#             # Determine the format of the checkpoint and load accordingly
#             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
#                 # Standard checkpoint format
#                 model.load_state_dict(checkpoint['model_state_dict'])
#                 optimizer = optim.Adam(model.parameters(), lr=lr)
#                 optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#                 loaded_epoch = checkpoint['epoch']
#                 start_epoch = loaded_epoch + 1
#                 print(f"Resuming training from epoch {start_epoch}")
#             else:
#                 # Checkpoint is the model's state_dict directly
#                 model.load_state_dict(checkpoint)
#                 optimizer = optim.Adam(model.parameters(), lr=lr)
#                 # Since epoch is not saved, start from 1 or set as needed
#                 start_epoch = 1
#                 print("Loaded model state_dict directly from checkpoint.")
#                 print(f"Starting training from epoch {start_epoch}")
#         else:
#             print(f"No checkpoint found for fold {fold_number}, starting training from scratch.")
#             # Load pre-trained model weights
#             model_weight_path = 'model_GNNNet_kiba.model'  # Adjust the path accordingly
#             if os.path.exists(model_weight_path):
#                 # Load the state_dict from the file
#                 state_dict = torch.load(model_weight_path, map_location=device)
#                 # Load the state_dict into the model
#                 if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
#                     # If state_dict has 'model_state_dict' key, use it
#                     model.load_state_dict(state_dict['model_state_dict'])
#                     print(f"Loaded pre-trained model weights from {model_weight_path} using 'model_state_dict' key.")
#                 else:
#                     # Assume state_dict is the model's state_dict directly
#                     model.load_state_dict(state_dict)
#                     print(f"Loaded pre-trained model weights from {model_weight_path}")
#             else:
#                 print(f"Pre-trained model weights not found at {model_weight_path}, proceeding without loading weights.")
#             optimizer = optim.Adam(model.parameters(), lr=lr)
#             start_epoch = 1

#         # Training loop with progress bar over epochs
#         for epoch in tqdm(range(start_epoch, num_epochs + 1),
#                           desc=f"Training Fold {fold_number}", unit="epoch"):
#             model.train()
#             running_loss = 0.0

#             # Prepare batch loader without progress bar for batches
#             batch_size = 128  # Adjust batch size as needed
#             batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

#             for batch_samples in batch_loader_iter:
#                 mol_data_list = []
#                 pro_data_list = []
#                 target_list = []

#                 for sample in batch_samples:
#                     mol_data = sample[0]
#                     pro_data = sample[1]
#                     target = sample[2]

#                     mol_data_list.append(mol_data)
#                     pro_data_list.append(pro_data)
#                     target_list.append(target)

#                 mol_batch = Batch.from_data_list(mol_data_list).to(device)
#                 pro_batch = Batch.from_data_list(pro_data_list).to(device)
#                 target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

#                 optimizer.zero_grad()
#                 output = model(mol_batch, pro_batch)
#                 loss = loss_fn(output.view(-1), target)
#                 loss.backward()
#                 optimizer.step()
#                 running_loss += loss.item() * len(batch_samples)

#             avg_loss = running_loss / len(train_files)
#             # Use tqdm.write() to print without interfering with the progress bar
#             tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

#             # Save the model and optimizer states after each epoch
#             checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
#             checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#             }, checkpoint_path)
#             tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

#             # Evaluation on the test set after each epoch
#             model.eval()
#             total_preds, total_labels = [], []
#             with torch.no_grad():
#                 batch_size = 128  # Adjust batch size as needed
#                 batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

#                 for batch_samples in batch_loader_iter:
#                     mol_data_list = []
#                     pro_data_list = []
#                     target_list = []

#                     for sample in batch_samples:
#                         mol_data = sample[0]
#                         pro_data = sample[1]
#                         target = sample[2]

#                         mol_data_list.append(mol_data)
#                         pro_data_list.append(pro_data)
#                         target_list.append(target)

#                     mol_batch = Batch.from_data_list(mol_data_list).to(device)
#                     pro_batch = Batch.from_data_list(pro_data_list).to(device)
#                     target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

#                     output = model(mol_batch, pro_batch)
#                     total_preds.append(output.cpu().numpy())
#                     total_labels.append(target.cpu().numpy())

#             # Convert lists to numpy arrays for evaluation
#             total_preds = np.concatenate(total_preds)
#             total_labels = np.concatenate(total_labels)

#             # Calculate metrics
#             mse = get_mse(total_labels, total_preds)
#             ci = get_ci(total_labels, total_preds)
#             pearson = get_pearson(total_labels, total_preds)

#             # Print metrics
#             tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

#         # Evaluation at the end of training for this fold
#         print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
#         # Store results for this fold
#         results.append((mse, ci, pearson))

#     return results

# # Example usage
# if __name__ == "__main__":
#     sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
#     num_epochs = 5  # Adjust the number of epochs as needed
#     n_splits = 5  # Number of folds for cross-validation
#     learning_rate = 0.001  # Learning rate

#     # Run the training function
#     results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

#     # Print overall results
#     print("\nCross-validation Results:")
#     for fold_idx, (mse, ci, pearson) in enumerate(results):
#         print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

#     # Optionally, compute and print average metrics across folds
#     mse_values, ci_values, pearson_values = zip(*results)
#     print(f"\nAverage Results:")
#     print(f"MSE: {np.mean(mse_values):.4f}")
#     print(f"CI: {np.mean(ci_values):.4f}")
#     print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


In [6]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools

# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def load_sample(path):
    # Load individual sample from file
    sample = torch.load(path)
    mol_data = sample[0]
    pro_data = sample[1]
    target = sample[2]

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'x' attribute is set
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data does not have 'x' or 'features' attribute")

    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data does not have 'x' or 'features' attribute")

    # Ensure 'x' is a float tensor
    if not isinstance(mol_data.x, torch.Tensor):
        mol_data.x = torch.tensor(mol_data.x)
    if not isinstance(pro_data.x, torch.Tensor):
        pro_data.x = torch.tensor(pro_data.x)

    if mol_data.x.dtype != torch.float:
        mol_data.x = mol_data.x.float()
    if pro_data.x.dtype != torch.float:
        pro_data.x = pro_data.x.float()

    # Adjust 'edge_index' for mol_data
    # Ensure 'edge_index' is a tensor of type torch.long
    if not isinstance(mol_data.edge_index, torch.Tensor):
        mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
    else:
        mol_data.edge_index = mol_data.edge_index.long()

    # Ensure 'edge_index' has shape [2, num_edges]
    if mol_data.edge_index.shape[0] != 2:
        mol_data.edge_index = mol_data.edge_index.t()

    # Adjust 'edge_index' for pro_data
    if not isinstance(pro_data.edge_index, torch.Tensor):
        pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
    else:
        pro_data.edge_index = pro_data.edge_index.long()

    if pro_data.edge_index.shape[0] != 2:
        pro_data.edge_index = pro_data.edge_index.t()

    # Set 'num_nodes' attribute to suppress warnings
    mol_data.num_nodes = mol_data.x.size(0)
    pro_data.num_nodes = pro_data.x.size(0)

    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    batch = []
    for idx, file_name in enumerate(file_list):
        sample_path = os.path.join(sample_dir, file_name)
        sample = load_sample(sample_path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.
    """
    pairs = itertools.combinations(range(len(y_true)), 2)
    c = 0
    s = 0
    for i, j in pairs:
        if y_true[i] != y_true[j]:
            s += 1
            if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
               (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
                c += 1
            elif y_pred[i] == y_pred[j]:
                c += 0.5
    return c / s if s != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true.flatten(), y_pred.flatten())[0]

def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Create a single directory for all checkpoints
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    results = []
    loss_fn = MSELoss()

    for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
        fold_number = fold + 1
        print(f'\nFold {fold_number}/{n_splits}')
        train_files = [sample_files[i] for i in train_idx]
        test_files = [sample_files[i] for i in test_idx]

        # Determine input feature dimensions from your data
        sample = load_sample(os.path.join(sample_dir, train_files[0]))
        mol_data = sample[0]
        pro_data = sample[1]

        num_features_mol = mol_data.x.size(1)
        num_features_pro = pro_data.x.size(1)

        # Initialize the GNN model with correct input dimensions
        model = GNNNet(
            num_features_mol=num_features_mol,
            num_features_pro=num_features_pro
        ).to(device)
        print(f"Model is on device: {next(model.parameters()).device}")

        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Initialize starting epoch
        start_epoch = 1

        # Check for existing checkpoints in TrainingModel directory for the current fold
        existing_checkpoints = [f for f in os.listdir(training_model_dir)
                                if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

        if existing_checkpoints:
            # Find the latest checkpoint based on epoch number
            latest_checkpoint = max(existing_checkpoints,
                                    key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
            checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
            print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            loaded_epoch = checkpoint['epoch']
            start_epoch = loaded_epoch + 1
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"No checkpoint found for fold {fold_number}, starting training from scratch.")

        # Training loop with progress bar over epochs
        for epoch in tqdm(range(start_epoch, num_epochs + 1),
                          desc=f"Training Fold {fold_number}", unit="epoch"):
            model.train()
            running_loss = 0.0

            # Prepare batch loader without progress bar for batches
            batch_size = 128  # Adjust batch size as needed
            batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

            for batch_samples in batch_loader_iter:
                mol_data_list = []
                pro_data_list = []
                target_list = []

                for sample in batch_samples:
                    mol_data = sample[0]
                    pro_data = sample[1]
                    target = sample[2]

                    mol_data_list.append(mol_data)
                    pro_data_list.append(pro_data)
                    target_list.append(target)

                mol_batch = Batch.from_data_list(mol_data_list).to(device)
                pro_batch = Batch.from_data_list(pro_data_list).to(device)
                target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                optimizer.zero_grad()
                output = model(mol_batch, pro_batch)
                loss = loss_fn(output.view(-1), target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * len(batch_samples)

            avg_loss = running_loss / len(train_files)
            # Use tqdm.write() to print without interfering with the progress bar
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Save the model and optimizer states after each epoch
            checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
            checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

            # Evaluation on the test set after each epoch
            model.eval()
            total_preds, total_labels = [], []
            with torch.no_grad():
                batch_size = 128  # Adjust batch size as needed
                batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

                for batch_samples in batch_loader_iter:
                    mol_data_list = []
                    pro_data_list = []
                    target_list = []

                    for sample in batch_samples:
                        mol_data = sample[0]
                        pro_data = sample[1]
                        target = sample[2]

                        mol_data_list.append(mol_data)
                        pro_data_list.append(pro_data)
                        target_list.append(target)

                    mol_batch = Batch.from_data_list(mol_data_list).to(device)
                    pro_batch = Batch.from_data_list(pro_data_list).to(device)
                    target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                    output = model(mol_batch, pro_batch)
                    total_preds.append(output.cpu().numpy())
                    total_labels.append(target.cpu().numpy())

            # Convert lists to numpy arrays for evaluation
            total_preds = np.concatenate(total_preds)
            total_labels = np.concatenate(total_labels)

            # Calculate metrics
            mse = get_mse(total_labels, total_preds)
            ci = get_ci(total_labels, total_preds)
            pearson = get_pearson(total_labels, total_preds)

            # Print metrics
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

        # Evaluation at the end of training for this fold
        print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
        # Store results for this fold
        results.append((mse, ci, pearson))

    return results

# Example usage
if __name__ == "__main__":
    sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
    num_epochs = 100  # Adjust the number of epochs as needed
    n_splits = 5  # Number of folds for cross-validation
    learning_rate = 0.001  # Learning rate

    # Run the training function
    results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

    # Print overall results
    print("\nCross-validation Results:")
    for fold_idx, (mse, ci, pearson) in enumerate(results):
        print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

    # Optionally, compute and print average metrics across folds
    mse_values, ci_values, pearson_values = zip(*results)
    print(f"\nAverage Results:")
    print(f"MSE: {np.mean(mse_values):.4f}")
    print(f"CI: {np.mean(ci_values):.4f}")
    print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


Running on cpu.
Using existing TrainingModel directory at prepared_samples/TrainingModel

Fold 1/5
GNNNet Loaded
Model is on device: cpu
Loading checkpoint for fold 1 from prepared_samples/TrainingModel/model_fold1_epoch2.pt
Resuming training from epoch 3


Training Fold 1:   0%|          | 0/98 [11:41<?, ?epoch/s]

Fold 1, Epoch 3/100, Loss: 1.1182
Checkpoint saved for fold 1 at epoch 3


Training Fold 1:   1%|          | 1/98 [18:24<29:45:14, 1104.28s/epoch]

Fold 1, Epoch 3/100 - MSE: 1.2433, CI: 0.6028, Pearson: 0.3476


Training Fold 1:   1%|          | 1/98 [30:24<29:45:14, 1104.28s/epoch]

Fold 1, Epoch 4/100, Loss: 1.0650
Checkpoint saved for fold 1 at epoch 4


Training Fold 1:   2%|▏         | 2/98 [37:03<29:41:10, 1113.23s/epoch]

Fold 1, Epoch 4/100 - MSE: 1.3993, CI: 0.6241, Pearson: 0.3955


Training Fold 1:   2%|▏         | 2/98 [49:06<29:41:10, 1113.23s/epoch]

Fold 1, Epoch 5/100, Loss: 0.9468
Checkpoint saved for fold 1 at epoch 5


Training Fold 1:   3%|▎         | 3/98 [55:42<29:26:32, 1115.71s/epoch]

Fold 1, Epoch 5/100 - MSE: 0.8051, CI: 0.6388, Pearson: 0.4455


Training Fold 1:   3%|▎         | 3/98 [1:07:46<29:26:32, 1115.71s/epoch]

Fold 1, Epoch 6/100, Loss: 0.8229
Checkpoint saved for fold 1 at epoch 6


Training Fold 1:   4%|▍         | 4/98 [1:14:24<29:11:41, 1118.10s/epoch]

Fold 1, Epoch 6/100 - MSE: 0.7991, CI: 0.6561, Pearson: 0.4714


Training Fold 1:   4%|▍         | 4/98 [1:27:07<29:11:41, 1118.10s/epoch]

Fold 1, Epoch 7/100, Loss: 0.7494
Checkpoint saved for fold 1 at epoch 7


Training Fold 1:   5%|▌         | 5/98 [1:33:43<29:16:08, 1132.99s/epoch]

Fold 1, Epoch 7/100 - MSE: 0.8292, CI: 0.6685, Pearson: 0.5049


Training Fold 1:   5%|▌         | 5/98 [1:45:55<29:16:08, 1132.99s/epoch]

Fold 1, Epoch 8/100, Loss: 0.7194
Checkpoint saved for fold 1 at epoch 8


Training Fold 1:   6%|▌         | 6/98 [1:52:27<28:52:26, 1129.85s/epoch]

Fold 1, Epoch 8/100 - MSE: 0.9032, CI: 0.6839, Pearson: 0.5291


Training Fold 1:   6%|▌         | 6/98 [2:04:41<28:52:26, 1129.85s/epoch]

Fold 1, Epoch 9/100, Loss: 0.6964
Checkpoint saved for fold 1 at epoch 9


Training Fold 1:   7%|▋         | 7/98 [2:11:17<28:33:45, 1129.95s/epoch]

Fold 1, Epoch 9/100 - MSE: 0.9257, CI: 0.6988, Pearson: 0.5519


Training Fold 1:   7%|▋         | 7/98 [2:23:04<28:33:45, 1129.95s/epoch]

Fold 1, Epoch 10/100, Loss: 0.6750
Checkpoint saved for fold 1 at epoch 10


Training Fold 1:   8%|▊         | 8/98 [2:29:34<27:59:10, 1119.45s/epoch]

Fold 1, Epoch 10/100 - MSE: 0.9027, CI: 0.7090, Pearson: 0.5750


Training Fold 1:   8%|▊         | 8/98 [2:41:43<27:59:10, 1119.45s/epoch]

Fold 1, Epoch 11/100, Loss: 0.6586
Checkpoint saved for fold 1 at epoch 11


Training Fold 1:   9%|▉         | 9/98 [2:48:11<27:39:25, 1118.71s/epoch]

Fold 1, Epoch 11/100 - MSE: 0.9542, CI: 0.7251, Pearson: 0.5881


Training Fold 1:   9%|▉         | 9/98 [3:00:18<27:39:25, 1118.71s/epoch]

Fold 1, Epoch 12/100, Loss: 0.6408
Checkpoint saved for fold 1 at epoch 12


Training Fold 1:  10%|█         | 10/98 [3:06:44<27:17:59, 1116.81s/epoch]

Fold 1, Epoch 12/100 - MSE: 0.9246, CI: 0.7334, Pearson: 0.6076


Training Fold 1:  10%|█         | 10/98 [3:19:04<27:17:59, 1116.81s/epoch]

Fold 1, Epoch 13/100, Loss: 0.6442
Checkpoint saved for fold 1 at epoch 13


Training Fold 1:  11%|█         | 11/98 [3:25:28<27:02:41, 1119.10s/epoch]

Fold 1, Epoch 13/100 - MSE: 0.9777, CI: 0.7384, Pearson: 0.6173


Training Fold 1:  11%|█         | 11/98 [3:37:17<27:02:41, 1119.10s/epoch]

Fold 1, Epoch 14/100, Loss: 0.6180
Checkpoint saved for fold 1 at epoch 14


Training Fold 1:  12%|█▏        | 12/98 [3:43:39<26:31:54, 1110.63s/epoch]

Fold 1, Epoch 14/100 - MSE: 0.8464, CI: 0.7457, Pearson: 0.6311


Training Fold 1:  12%|█▏        | 12/98 [3:55:32<26:31:54, 1110.63s/epoch]

Fold 1, Epoch 15/100, Loss: 0.6089
Checkpoint saved for fold 1 at epoch 15


Training Fold 1:  13%|█▎        | 13/98 [4:01:55<26:06:51, 1106.02s/epoch]

Fold 1, Epoch 15/100 - MSE: 0.8935, CI: 0.7492, Pearson: 0.6440


Training Fold 1:  13%|█▎        | 13/98 [4:13:42<26:06:51, 1106.02s/epoch]

Fold 1, Epoch 16/100, Loss: 0.5886
Checkpoint saved for fold 1 at epoch 16


Training Fold 1:  14%|█▍        | 14/98 [4:20:00<25:39:40, 1099.77s/epoch]

Fold 1, Epoch 16/100 - MSE: 0.9251, CI: 0.7581, Pearson: 0.6678


Training Fold 1:  14%|█▍        | 14/98 [4:31:47<25:39:40, 1099.77s/epoch]

Fold 1, Epoch 17/100, Loss: 0.5810
Checkpoint saved for fold 1 at epoch 17


Training Fold 1:  15%|█▌        | 15/98 [4:38:05<25:15:24, 1095.47s/epoch]

Fold 1, Epoch 17/100 - MSE: 1.0028, CI: 0.7604, Pearson: 0.6714


Training Fold 1:  15%|█▌        | 15/98 [4:49:46<25:15:24, 1095.47s/epoch]

Fold 1, Epoch 18/100, Loss: 0.5519
Checkpoint saved for fold 1 at epoch 18


Training Fold 1:  16%|█▋        | 16/98 [4:55:58<24:47:51, 1088.68s/epoch]

Fold 1, Epoch 18/100 - MSE: 1.0256, CI: 0.7687, Pearson: 0.6867


Training Fold 1:  16%|█▋        | 16/98 [5:07:37<24:47:51, 1088.68s/epoch]

Fold 1, Epoch 19/100, Loss: 0.5526
Checkpoint saved for fold 1 at epoch 19


Training Fold 1:  17%|█▋        | 17/98 [5:13:54<24:24:22, 1084.73s/epoch]

Fold 1, Epoch 19/100 - MSE: 0.9649, CI: 0.7674, Pearson: 0.6916


Training Fold 1:  17%|█▋        | 17/98 [5:25:33<24:24:22, 1084.73s/epoch]

Fold 1, Epoch 20/100, Loss: 0.5419
Checkpoint saved for fold 1 at epoch 20


Training Fold 1:  18%|█▊        | 18/98 [5:31:50<24:02:49, 1082.12s/epoch]

Fold 1, Epoch 20/100 - MSE: 1.0366, CI: 0.7673, Pearson: 0.6935


Training Fold 1:  18%|█▊        | 18/98 [5:43:30<24:02:49, 1082.12s/epoch]

Fold 1, Epoch 21/100, Loss: 0.5357
Checkpoint saved for fold 1 at epoch 21


Training Fold 1:  19%|█▉        | 19/98 [5:49:44<23:41:34, 1079.68s/epoch]

Fold 1, Epoch 21/100 - MSE: 1.0556, CI: 0.7761, Pearson: 0.7028


Training Fold 1:  19%|█▉        | 19/98 [6:01:30<23:41:34, 1079.68s/epoch]

Fold 1, Epoch 22/100, Loss: 0.5291
Checkpoint saved for fold 1 at epoch 22


Training Fold 1:  20%|██        | 20/98 [6:07:42<23:23:07, 1079.33s/epoch]

Fold 1, Epoch 22/100 - MSE: 1.0538, CI: 0.7742, Pearson: 0.7012


Training Fold 1:  20%|██        | 20/98 [6:19:58<23:23:07, 1079.33s/epoch]

Fold 1, Epoch 23/100, Loss: 0.5139
Checkpoint saved for fold 1 at epoch 23


Training Fold 1:  21%|██▏       | 21/98 [6:26:24<23:21:20, 1091.95s/epoch]

Fold 1, Epoch 23/100 - MSE: 1.0722, CI: 0.7811, Pearson: 0.7157


Training Fold 1:  21%|██▏       | 21/98 [6:38:44<23:21:20, 1091.95s/epoch]

Fold 1, Epoch 24/100, Loss: 0.5202
Checkpoint saved for fold 1 at epoch 24


Training Fold 1:  22%|██▏       | 22/98 [6:45:06<23:14:39, 1101.05s/epoch]

Fold 1, Epoch 24/100 - MSE: 1.1145, CI: 0.7874, Pearson: 0.7263


Training Fold 1:  22%|██▏       | 22/98 [6:57:40<23:14:39, 1101.05s/epoch]

Fold 1, Epoch 25/100, Loss: 0.5024
Checkpoint saved for fold 1 at epoch 25


Training Fold 1:  23%|██▎       | 23/98 [7:04:04<23:10:16, 1112.22s/epoch]

Fold 1, Epoch 25/100 - MSE: 1.0912, CI: 0.7777, Pearson: 0.7173


Training Fold 1:  23%|██▎       | 23/98 [7:16:11<23:10:16, 1112.22s/epoch]

Fold 1, Epoch 26/100, Loss: 0.4956
Checkpoint saved for fold 1 at epoch 26


Training Fold 1:  24%|██▍       | 24/98 [7:22:23<22:46:54, 1108.30s/epoch]

Fold 1, Epoch 26/100 - MSE: 1.0309, CI: 0.7724, Pearson: 0.7177


Training Fold 1:  24%|██▍       | 24/98 [7:34:11<22:46:54, 1108.30s/epoch]

Fold 1, Epoch 27/100, Loss: 0.4918
Checkpoint saved for fold 1 at epoch 27


Training Fold 1:  26%|██▌       | 25/98 [7:40:26<22:18:51, 1100.44s/epoch]

Fold 1, Epoch 27/100 - MSE: 1.0226, CI: 0.7747, Pearson: 0.7221


Training Fold 1:  26%|██▌       | 25/98 [7:52:11<22:18:51, 1100.44s/epoch]

Fold 1, Epoch 28/100, Loss: 0.4882
Checkpoint saved for fold 1 at epoch 28


Training Fold 1:  27%|██▋       | 26/98 [7:58:23<21:52:19, 1093.61s/epoch]

Fold 1, Epoch 28/100 - MSE: 1.1133, CI: 0.7795, Pearson: 0.7263


Training Fold 1:  27%|██▋       | 26/98 [8:10:12<21:52:19, 1093.61s/epoch]

Fold 1, Epoch 29/100, Loss: 0.4754
Checkpoint saved for fold 1 at epoch 29


Training Fold 1:  28%|██▊       | 27/98 [8:16:25<21:30:02, 1090.17s/epoch]

Fold 1, Epoch 29/100 - MSE: 1.0451, CI: 0.7875, Pearson: 0.7294


Training Fold 1:  28%|██▊       | 27/98 [8:28:09<21:30:02, 1090.17s/epoch]

Fold 1, Epoch 30/100, Loss: 0.4785
Checkpoint saved for fold 1 at epoch 30


Training Fold 1:  29%|██▊       | 28/98 [8:34:36<21:12:01, 1090.30s/epoch]

Fold 1, Epoch 30/100 - MSE: 1.0663, CI: 0.7858, Pearson: 0.7343


Training Fold 1:  29%|██▊       | 28/98 [8:51:49<22:09:33, 1139.63s/epoch]


KeyboardInterrupt: 

In [7]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools

# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def load_sample(path):
    # Load individual sample from file
    sample = torch.load(path)
    mol_data = sample[0]
    pro_data = sample[1]
    target = sample[2]

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'x' attribute is set
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data does not have 'x' or 'features' attribute")

    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data does not have 'x' or 'features' attribute")

    # Ensure 'x' is a float tensor
    if not isinstance(mol_data.x, torch.Tensor):
        mol_data.x = torch.tensor(mol_data.x)
    if not isinstance(pro_data.x, torch.Tensor):
        pro_data.x = torch.tensor(pro_data.x)

    if mol_data.x.dtype != torch.float:
        mol_data.x = mol_data.x.float()
    if pro_data.x.dtype != torch.float:
        pro_data.x = pro_data.x.float()

    # Adjust 'edge_index' for mol_data
    # Ensure 'edge_index' is a tensor of type torch.long
    if not isinstance(mol_data.edge_index, torch.Tensor):
        mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
    else:
        mol_data.edge_index = mol_data.edge_index.long()

    # Ensure 'edge_index' has shape [2, num_edges]
    if mol_data.edge_index.shape[0] != 2:
        mol_data.edge_index = mol_data.edge_index.t()

    # Adjust 'edge_index' for pro_data
    if not isinstance(pro_data.edge_index, torch.Tensor):
        pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
    else:
        pro_data.edge_index = pro_data.edge_index.long()

    if pro_data.edge_index.shape[0] != 2:
        pro_data.edge_index = pro_data.edge_index.t()

    # Set 'num_nodes' attribute to suppress warnings
    mol_data.num_nodes = mol_data.x.size(0)
    pro_data.num_nodes = pro_data.x.size(0)

    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    batch = []
    for idx, file_name in enumerate(file_list):
        sample_path = os.path.join(sample_dir, file_name)
        sample = load_sample(sample_path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.
    """
    pairs = itertools.combinations(range(len(y_true)), 2)
    c = 0
    s = 0
    for i, j in pairs:
        if y_true[i] != y_true[j]:
            s += 1
            if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
               (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
                c += 1
            elif y_pred[i] == y_pred[j]:
                c += 0.5
    return c / s if s != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true.flatten(), y_pred.flatten())[0]

def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Create a single directory for all checkpoints
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    results = []
    loss_fn = MSELoss()

    for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
        fold_number = fold + 1
        print(f'\nFold {fold_number}/{n_splits}')
        train_files = [sample_files[i] for i in train_idx]
        test_files = [sample_files[i] for i in test_idx]

        # Determine input feature dimensions from your data
        sample = load_sample(os.path.join(sample_dir, train_files[0]))
        mol_data = sample[0]
        pro_data = sample[1]

        num_features_mol = mol_data.x.size(1)
        num_features_pro = pro_data.x.size(1)

        # Initialize the GNN model with correct input dimensions
        model = GNNNet(
            num_features_mol=num_features_mol,
            num_features_pro=num_features_pro
        ).to(device)
        print(f"Model is on device: {next(model.parameters()).device}")

        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Initialize starting epoch
        start_epoch = 1

        # Check for existing checkpoints in TrainingModel directory for the current fold
        existing_checkpoints = [f for f in os.listdir(training_model_dir)
                                if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

        if existing_checkpoints:
            # Find the latest checkpoint based on epoch number
            latest_checkpoint = max(existing_checkpoints,
                                    key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
            checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
            print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            loaded_epoch = checkpoint['epoch']
            start_epoch = loaded_epoch + 1
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"No checkpoint found for fold {fold_number}, starting training from scratch.")

        # Training loop with progress bar over epochs
        for epoch in tqdm(range(start_epoch, num_epochs + 1),
                          desc=f"Training Fold {fold_number}", unit="epoch"):
            model.train()
            running_loss = 0.0

            # Prepare batch loader without progress bar for batches
            batch_size = 200  # Adjust batch size as needed
            batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

            for batch_samples in batch_loader_iter:
                mol_data_list = []
                pro_data_list = []
                target_list = []

                for sample in batch_samples:
                    mol_data = sample[0]
                    pro_data = sample[1]
                    target = sample[2]

                    mol_data_list.append(mol_data)
                    pro_data_list.append(pro_data)
                    target_list.append(target)

                mol_batch = Batch.from_data_list(mol_data_list).to(device)
                pro_batch = Batch.from_data_list(pro_data_list).to(device)
                target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                optimizer.zero_grad()
                output = model(mol_batch, pro_batch)
                loss = loss_fn(output.view(-1), target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * len(batch_samples)

            avg_loss = running_loss / len(train_files)
            # Use tqdm.write() to print without interfering with the progress bar
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Save the model and optimizer states after each epoch
            checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
            checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

            # Evaluation on the test set after each epoch
            model.eval()
            total_preds, total_labels = [], []
            with torch.no_grad():
                batch_size = 200  # Adjust batch size as needed
                batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

                for batch_samples in batch_loader_iter:
                    mol_data_list = []
                    pro_data_list = []
                    target_list = []

                    for sample in batch_samples:
                        mol_data = sample[0]
                        pro_data = sample[1]
                        target = sample[2]

                        mol_data_list.append(mol_data)
                        pro_data_list.append(pro_data)
                        target_list.append(target)

                    mol_batch = Batch.from_data_list(mol_data_list).to(device)
                    pro_batch = Batch.from_data_list(pro_data_list).to(device)
                    target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                    output = model(mol_batch, pro_batch)
                    total_preds.append(output.cpu().numpy())
                    total_labels.append(target.cpu().numpy())

            # Convert lists to numpy arrays for evaluation
            total_preds = np.concatenate(total_preds)
            total_labels = np.concatenate(total_labels)

            # Calculate metrics
            mse = get_mse(total_labels, total_preds)
            ci = get_ci(total_labels, total_preds)
            pearson = get_pearson(total_labels, total_preds)

            # Print metrics
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

        # Evaluation at the end of training for this fold
        print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
        # Store results for this fold
        results.append((mse, ci, pearson))

    return results


if __name__ == "__main__":
    sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
    num_epochs = 120  # Adjust the number of epochs as needed
    n_splits = 5  # Number of folds for cross-validation
    learning_rate = 0.001  # Learning rate

    # Run the training function
    results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

    # Print overall results
    print("\nCross-validation Results:")
    for fold_idx, (mse, ci, pearson) in enumerate(results):
        print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

    # Optionally, compute and print average metrics across folds
    mse_values, ci_values, pearson_values = zip(*results)
    print(f"\nAverage Results:")
    print(f"MSE: {np.mean(mse_values):.4f}")
    print(f"CI: {np.mean(ci_values):.4f}")
    print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


Running on cpu.
Using existing TrainingModel directory at prepared_samples/TrainingModel

Fold 1/5
GNNNet Loaded
Model is on device: cpu
Loading checkpoint for fold 1 from prepared_samples/TrainingModel/model_fold1_epoch30.pt
Resuming training from epoch 31


Training Fold 1:   0%|          | 0/90 [11:58<?, ?epoch/s]

Fold 1, Epoch 31/120, Loss: 0.4399
Checkpoint saved for fold 1 at epoch 31


Training Fold 1:   1%|          | 1/90 [18:14<27:04:14, 1095.00s/epoch]

Fold 1, Epoch 31/120 - MSE: 1.0760, CI: 0.7888, Pearson: 0.7425


Training Fold 1:   1%|          | 1/90 [30:12<27:04:14, 1095.00s/epoch]

Fold 1, Epoch 32/120, Loss: 0.4396
Checkpoint saved for fold 1 at epoch 32


Training Fold 1:   2%|▏         | 2/90 [36:27<26:43:45, 1093.47s/epoch]

Fold 1, Epoch 32/120 - MSE: 1.0232, CI: 0.7918, Pearson: 0.7460


Training Fold 1:   2%|▏         | 2/90 [48:59<26:43:45, 1093.47s/epoch]

Fold 1, Epoch 33/120, Loss: 0.4380
Checkpoint saved for fold 1 at epoch 33


Training Fold 1:   3%|▎         | 3/90 [55:31<26:59:21, 1116.80s/epoch]

Fold 1, Epoch 33/120 - MSE: 1.0571, CI: 0.7964, Pearson: 0.7507


Training Fold 1:   3%|▎         | 3/90 [1:08:21<26:59:21, 1116.80s/epoch]

Fold 1, Epoch 34/120, Loss: 0.4374
Checkpoint saved for fold 1 at epoch 34


Training Fold 1:   4%|▍         | 4/90 [1:14:53<27:05:53, 1134.34s/epoch]

Fold 1, Epoch 34/120 - MSE: 1.1009, CI: 0.7982, Pearson: 0.7528


Training Fold 1:   4%|▍         | 4/90 [1:27:55<27:05:53, 1134.34s/epoch]

Fold 1, Epoch 35/120, Loss: 0.4323
Checkpoint saved for fold 1 at epoch 35


Training Fold 1:   6%|▌         | 5/90 [1:34:41<27:14:25, 1153.71s/epoch]

Fold 1, Epoch 35/120 - MSE: 1.1038, CI: 0.7969, Pearson: 0.7524


Training Fold 1:   6%|▌         | 5/90 [1:47:44<27:14:25, 1153.71s/epoch]

Fold 1, Epoch 36/120, Loss: 0.4295
Checkpoint saved for fold 1 at epoch 36


Training Fold 1:   7%|▋         | 6/90 [1:54:38<27:16:05, 1168.64s/epoch]

Fold 1, Epoch 36/120 - MSE: 1.0726, CI: 0.7988, Pearson: 0.7532


Training Fold 1:   7%|▋         | 6/90 [2:07:21<27:16:05, 1168.64s/epoch]

Fold 1, Epoch 37/120, Loss: 0.4233
Checkpoint saved for fold 1 at epoch 37


Training Fold 1:   8%|▊         | 7/90 [2:13:58<26:52:32, 1165.69s/epoch]

Fold 1, Epoch 37/120 - MSE: 1.0449, CI: 0.8004, Pearson: 0.7579


Training Fold 1:   8%|▊         | 7/90 [2:26:51<26:52:32, 1165.69s/epoch]

Fold 1, Epoch 38/120, Loss: 0.4190
Checkpoint saved for fold 1 at epoch 38


Training Fold 1:   9%|▉         | 8/90 [2:33:22<26:32:24, 1165.17s/epoch]

Fold 1, Epoch 38/120 - MSE: 1.1403, CI: 0.8003, Pearson: 0.7525


Training Fold 1:   9%|▉         | 8/90 [2:45:45<26:32:24, 1165.17s/epoch]

Fold 1, Epoch 39/120, Loss: 0.4081
Checkpoint saved for fold 1 at epoch 39


Training Fold 1:  10%|█         | 9/90 [2:52:17<26:00:05, 1155.62s/epoch]

Fold 1, Epoch 39/120 - MSE: 1.0947, CI: 0.8015, Pearson: 0.7604


Training Fold 1:  10%|█         | 9/90 [3:04:42<26:00:05, 1155.62s/epoch]

Fold 1, Epoch 40/120, Loss: 0.4034
Checkpoint saved for fold 1 at epoch 40


Training Fold 1:  11%|█         | 10/90 [3:11:05<25:29:40, 1147.25s/epoch]

Fold 1, Epoch 40/120 - MSE: 1.1249, CI: 0.8037, Pearson: 0.7616


Training Fold 1:  11%|█         | 10/90 [3:23:10<25:29:40, 1147.25s/epoch]

Fold 1, Epoch 41/120, Loss: 0.4031
Checkpoint saved for fold 1 at epoch 41


Training Fold 1:  12%|█▏        | 11/90 [3:29:30<24:53:37, 1134.40s/epoch]

Fold 1, Epoch 41/120 - MSE: 1.0975, CI: 0.8017, Pearson: 0.7623


Training Fold 1:  12%|█▏        | 11/90 [3:41:35<24:53:37, 1134.40s/epoch]

Fold 1, Epoch 42/120, Loss: 0.4000
Checkpoint saved for fold 1 at epoch 42


Training Fold 1:  13%|█▎        | 12/90 [3:47:53<24:22:10, 1124.75s/epoch]

Fold 1, Epoch 42/120 - MSE: 1.0834, CI: 0.8042, Pearson: 0.7623


Training Fold 1:  13%|█▎        | 12/90 [3:59:55<24:22:10, 1124.75s/epoch]

Fold 1, Epoch 43/120, Loss: 0.3943
Checkpoint saved for fold 1 at epoch 43


Training Fold 1:  14%|█▍        | 13/90 [4:06:17<23:55:12, 1118.34s/epoch]

Fold 1, Epoch 43/120 - MSE: 1.0799, CI: 0.7974, Pearson: 0.7591


Training Fold 1:  14%|█▍        | 13/90 [4:18:17<23:55:12, 1118.34s/epoch]

Fold 1, Epoch 44/120, Loss: 0.3952
Checkpoint saved for fold 1 at epoch 44


Training Fold 1:  16%|█▌        | 14/90 [4:24:36<23:29:15, 1112.57s/epoch]

Fold 1, Epoch 44/120 - MSE: 1.0724, CI: 0.8055, Pearson: 0.7661


Training Fold 1:  16%|█▌        | 14/90 [4:36:38<23:29:15, 1112.57s/epoch]

Fold 1, Epoch 45/120, Loss: 0.3890
Checkpoint saved for fold 1 at epoch 45


Training Fold 1:  17%|█▋        | 15/90 [4:42:57<23:06:12, 1108.96s/epoch]

Fold 1, Epoch 45/120 - MSE: 1.1344, CI: 0.8024, Pearson: 0.7652


Training Fold 1:  17%|█▋        | 15/90 [4:55:00<23:06:12, 1108.96s/epoch]

Fold 1, Epoch 46/120, Loss: 0.3830
Checkpoint saved for fold 1 at epoch 46


Training Fold 1:  18%|█▊        | 16/90 [5:01:18<22:44:57, 1106.72s/epoch]

Fold 1, Epoch 46/120 - MSE: 1.1139, CI: 0.8095, Pearson: 0.7722


Training Fold 1:  18%|█▊        | 16/90 [5:13:21<22:44:57, 1106.72s/epoch]

Fold 1, Epoch 47/120, Loss: 0.3792
Checkpoint saved for fold 1 at epoch 47


Training Fold 1:  19%|█▉        | 17/90 [5:19:40<22:24:46, 1105.30s/epoch]

Fold 1, Epoch 47/120 - MSE: 1.0907, CI: 0.8066, Pearson: 0.7762


Training Fold 1:  19%|█▉        | 17/90 [5:32:36<22:24:46, 1105.30s/epoch]

Fold 1, Epoch 48/120, Loss: 0.3750
Checkpoint saved for fold 1 at epoch 48


Training Fold 1:  20%|██        | 18/90 [5:39:21<22:33:35, 1128.00s/epoch]

Fold 1, Epoch 48/120 - MSE: 1.0592, CI: 0.8051, Pearson: 0.7742


Training Fold 1:  20%|██        | 18/90 [5:52:32<22:33:35, 1128.00s/epoch]

Fold 1, Epoch 49/120, Loss: 0.3771
Checkpoint saved for fold 1 at epoch 49


Training Fold 1:  21%|██        | 19/90 [5:59:18<22:39:31, 1148.90s/epoch]

Fold 1, Epoch 49/120 - MSE: 1.0624, CI: 0.8107, Pearson: 0.7761


Training Fold 1:  21%|██        | 19/90 [6:12:41<22:39:31, 1148.90s/epoch]

Fold 1, Epoch 50/120, Loss: 0.3728
Checkpoint saved for fold 1 at epoch 50


Training Fold 1:  22%|██▏       | 20/90 [6:19:26<22:40:50, 1166.43s/epoch]

Fold 1, Epoch 50/120 - MSE: 1.0951, CI: 0.8121, Pearson: 0.7795


Training Fold 1:  22%|██▏       | 20/90 [6:32:42<22:40:50, 1166.43s/epoch]

Fold 1, Epoch 51/120, Loss: 0.3688
Checkpoint saved for fold 1 at epoch 51


Training Fold 1:  23%|██▎       | 21/90 [6:38:53<22:21:48, 1166.79s/epoch]

Fold 1, Epoch 51/120 - MSE: 1.0294, CI: 0.8115, Pearson: 0.7814


Training Fold 1:  23%|██▎       | 21/90 [6:50:43<22:21:48, 1166.79s/epoch]

Fold 1, Epoch 52/120, Loss: 0.3633
Checkpoint saved for fold 1 at epoch 52


Training Fold 1:  23%|██▎       | 21/90 [6:54:49<22:42:59, 1185.21s/epoch]


KeyboardInterrupt: 

In [8]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools

# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def load_sample(path):
    # Load individual sample from file
    sample = torch.load(path)
    mol_data = sample[0]
    pro_data = sample[1]
    target = sample[2]

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'x' attribute is set
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data does not have 'x' or 'features' attribute")

    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data does not have 'x' or 'features' attribute")

    # Ensure 'x' is a float tensor
    if not isinstance(mol_data.x, torch.Tensor):
        mol_data.x = torch.tensor(mol_data.x)
    if not isinstance(pro_data.x, torch.Tensor):
        pro_data.x = torch.tensor(pro_data.x)

    if mol_data.x.dtype != torch.float:
        mol_data.x = mol_data.x.float()
    if pro_data.x.dtype != torch.float:
        pro_data.x = pro_data.x.float()

    # Adjust 'edge_index' for mol_data
    # Ensure 'edge_index' is a tensor of type torch.long
    if not isinstance(mol_data.edge_index, torch.Tensor):
        mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
    else:
        mol_data.edge_index = mol_data.edge_index.long()

    # Ensure 'edge_index' has shape [2, num_edges]
    if mol_data.edge_index.shape[0] != 2:
        mol_data.edge_index = mol_data.edge_index.t()

    # Adjust 'edge_index' for pro_data
    if not isinstance(pro_data.edge_index, torch.Tensor):
        pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
    else:
        pro_data.edge_index = pro_data.edge_index.long()

    if pro_data.edge_index.shape[0] != 2:
        pro_data.edge_index = pro_data.edge_index.t()

    # Set 'num_nodes' attribute to suppress warnings
    mol_data.num_nodes = mol_data.x.size(0)
    pro_data.num_nodes = pro_data.x.size(0)

    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    batch = []
    for idx, file_name in enumerate(file_list):
        sample_path = os.path.join(sample_dir, file_name)
        sample = load_sample(sample_path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.
    """
    pairs = itertools.combinations(range(len(y_true)), 2)
    c = 0
    s = 0
    for i, j in pairs:
        if y_true[i] != y_true[j]:
            s += 1
            if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
               (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
                c += 1
            elif y_pred[i] == y_pred[j]:
                c += 0.5
    return c / s if s != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true.flatten(), y_pred.flatten())[0]

def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Create a single directory for all checkpoints
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    results = []
    loss_fn = MSELoss()

    for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
        fold_number = fold + 1
        print(f'\nFold {fold_number}/{n_splits}')
        train_files = [sample_files[i] for i in train_idx]
        test_files = [sample_files[i] for i in test_idx]

        # Determine input feature dimensions from your data
        sample = load_sample(os.path.join(sample_dir, train_files[0]))
        mol_data = sample[0]
        pro_data = sample[1]

        num_features_mol = mol_data.x.size(1)
        num_features_pro = pro_data.x.size(1)

        # Initialize the GNN model with correct input dimensions
        model = GNNNet(
            num_features_mol=num_features_mol,
            num_features_pro=num_features_pro
        ).to(device)
        print(f"Model is on device: {next(model.parameters()).device}")

        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Initialize starting epoch
        start_epoch = 1

        # Check for existing checkpoints in TrainingModel directory for the current fold
        existing_checkpoints = [f for f in os.listdir(training_model_dir)
                                if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

        if existing_checkpoints:
            # Find the latest checkpoint based on epoch number
            latest_checkpoint = max(existing_checkpoints,
                                    key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
            checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
            print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            loaded_epoch = checkpoint['epoch']
            start_epoch = loaded_epoch + 1
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"No checkpoint found for fold {fold_number}, starting training from scratch.")

        # Training loop with progress bar over epochs
        for epoch in tqdm(range(start_epoch, num_epochs + 1),
                          desc=f"Training Fold {fold_number}", unit="epoch"):
            model.train()
            running_loss = 0.0

            # Prepare batch loader without progress bar for batches
            batch_size = 256 # Adjust batch size as needed
            batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

            for batch_samples in batch_loader_iter:
                mol_data_list = []
                pro_data_list = []
                target_list = []

                for sample in batch_samples:
                    mol_data = sample[0]
                    pro_data = sample[1]
                    target = sample[2]

                    mol_data_list.append(mol_data)
                    pro_data_list.append(pro_data)
                    target_list.append(target)

                mol_batch = Batch.from_data_list(mol_data_list).to(device)
                pro_batch = Batch.from_data_list(pro_data_list).to(device)
                target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                optimizer.zero_grad()
                output = model(mol_batch, pro_batch)
                loss = loss_fn(output.view(-1), target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * len(batch_samples)

            avg_loss = running_loss / len(train_files)
            # Use tqdm.write() to print without interfering with the progress bar
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Save the model and optimizer states after each epoch
            checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
            checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

            # Evaluation on the test set after each epoch
            model.eval()
            total_preds, total_labels = [], []
            with torch.no_grad():
                batch_size = 256  # Adjust batch size as needed
                batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

                for batch_samples in batch_loader_iter:
                    mol_data_list = []
                    pro_data_list = []
                    target_list = []

                    for sample in batch_samples:
                        mol_data = sample[0]
                        pro_data = sample[1]
                        target = sample[2]

                        mol_data_list.append(mol_data)
                        pro_data_list.append(pro_data)
                        target_list.append(target)

                    mol_batch = Batch.from_data_list(mol_data_list).to(device)
                    pro_batch = Batch.from_data_list(pro_data_list).to(device)
                    target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                    output = model(mol_batch, pro_batch)
                    total_preds.append(output.cpu().numpy())
                    total_labels.append(target.cpu().numpy())

            # Convert lists to numpy arrays for evaluation
            total_preds = np.concatenate(total_preds)
            total_labels = np.concatenate(total_labels)

            # Calculate metrics
            mse = get_mse(total_labels, total_preds)
            ci = get_ci(total_labels, total_preds)
            pearson = get_pearson(total_labels, total_preds)

            # Print metrics
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

        # Evaluation at the end of training for this fold
        print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
        # Store results for this fold
        results.append((mse, ci, pearson))

    return results


if __name__ == "__main__":
    sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
    num_epochs = 250  # Adjust the number of epochs as needed
    n_splits = 5  # Number of folds for cross-validation
    learning_rate = 0.001  # Learning rate

    # Run the training function
    results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

    # Print overall results
    print("\nCross-validation Results:")
    for fold_idx, (mse, ci, pearson) in enumerate(results):
        print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

    # Optionally, compute and print average metrics across folds
    mse_values, ci_values, pearson_values = zip(*results)
    print(f"\nAverage Results:")
    print(f"MSE: {np.mean(mse_values):.4f}")
    print(f"CI: {np.mean(ci_values):.4f}")
    print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


Running on cpu.
Using existing TrainingModel directory at prepared_samples/TrainingModel

Fold 1/5
GNNNet Loaded
Model is on device: cpu
Loading checkpoint for fold 1 from prepared_samples/TrainingModel/model_fold1_epoch52.pt
Resuming training from epoch 53


Training Fold 1:   0%|          | 0/198 [13:06<?, ?epoch/s]

Fold 1, Epoch 53/250, Loss: 0.3572
Checkpoint saved for fold 1 at epoch 53


Training Fold 1:   1%|          | 1/198 [19:33<64:12:53, 1173.47s/epoch]

Fold 1, Epoch 53/250 - MSE: 1.0370, CI: 0.8181, Pearson: 0.7939


Training Fold 1:   1%|          | 1/198 [32:14<64:12:53, 1173.47s/epoch]

Fold 1, Epoch 54/250, Loss: 0.3532
Checkpoint saved for fold 1 at epoch 54


Training Fold 1:   1%|          | 2/198 [38:43<63:07:43, 1159.51s/epoch]

Fold 1, Epoch 54/250 - MSE: 1.0438, CI: 0.8176, Pearson: 0.7883


Training Fold 1:   1%|          | 2/198 [51:29<63:07:43, 1159.51s/epoch]

Fold 1, Epoch 55/250, Loss: 0.3491
Checkpoint saved for fold 1 at epoch 55


Training Fold 1:   2%|▏         | 3/198 [58:06<62:53:36, 1161.11s/epoch]

Fold 1, Epoch 55/250 - MSE: 1.0644, CI: 0.8180, Pearson: 0.7924


Training Fold 1:   2%|▏         | 3/198 [1:11:04<62:53:36, 1161.11s/epoch]

Fold 1, Epoch 56/250, Loss: 0.3451
Checkpoint saved for fold 1 at epoch 56


Training Fold 1:   2%|▏         | 4/198 [1:17:30<62:37:53, 1162.23s/epoch]

Fold 1, Epoch 56/250 - MSE: 1.1196, CI: 0.8232, Pearson: 0.7978


Training Fold 1:   2%|▏         | 4/198 [1:30:42<62:37:53, 1162.23s/epoch]

Fold 1, Epoch 57/250, Loss: 0.3397
Checkpoint saved for fold 1 at epoch 57


Training Fold 1:   3%|▎         | 5/198 [1:37:10<62:39:30, 1168.76s/epoch]

Fold 1, Epoch 57/250 - MSE: 1.0449, CI: 0.8194, Pearson: 0.8009


Training Fold 1:   3%|▎         | 5/198 [1:49:59<62:39:30, 1168.76s/epoch]

Fold 1, Epoch 58/250, Loss: 0.3379
Checkpoint saved for fold 1 at epoch 58


Training Fold 1:   3%|▎         | 6/198 [1:56:24<62:04:21, 1163.86s/epoch]

Fold 1, Epoch 58/250 - MSE: 1.1172, CI: 0.8235, Pearson: 0.8029


Training Fold 1:   3%|▎         | 6/198 [2:09:38<62:04:21, 1163.86s/epoch]

Fold 1, Epoch 59/250, Loss: 0.3373
Checkpoint saved for fold 1 at epoch 59


Training Fold 1:   4%|▎         | 7/198 [2:16:10<62:07:22, 1170.91s/epoch]

Fold 1, Epoch 59/250 - MSE: 1.0967, CI: 0.8226, Pearson: 0.7967


Training Fold 1:   4%|▎         | 7/198 [2:29:23<62:07:22, 1170.91s/epoch]

Fold 1, Epoch 60/250, Loss: 0.3357
Checkpoint saved for fold 1 at epoch 60


Training Fold 1:   4%|▍         | 8/198 [2:35:49<61:56:19, 1173.58s/epoch]

Fold 1, Epoch 60/250 - MSE: 1.1151, CI: 0.8251, Pearson: 0.8026


Training Fold 1:   4%|▍         | 8/198 [2:48:58<61:56:19, 1173.58s/epoch]

Fold 1, Epoch 61/250, Loss: 0.3327
Checkpoint saved for fold 1 at epoch 61


Training Fold 1:   5%|▍         | 9/198 [2:55:24<61:37:51, 1173.92s/epoch]

Fold 1, Epoch 61/250 - MSE: 1.0891, CI: 0.8295, Pearson: 0.8087


Training Fold 1:   5%|▍         | 9/198 [3:08:48<61:37:51, 1173.92s/epoch]

Fold 1, Epoch 62/250, Loss: 0.3296
Checkpoint saved for fold 1 at epoch 62


Training Fold 1:   5%|▌         | 10/198 [3:14:55<61:15:59, 1173.19s/epoch]

Fold 1, Epoch 62/250 - MSE: 1.1296, CI: 0.8314, Pearson: 0.8135


Training Fold 1:   5%|▌         | 10/198 [3:26:57<61:15:59, 1173.19s/epoch]

Fold 1, Epoch 63/250, Loss: 0.3258
Checkpoint saved for fold 1 at epoch 63


Training Fold 1:   6%|▌         | 11/198 [3:33:10<59:41:25, 1149.12s/epoch]

Fold 1, Epoch 63/250 - MSE: 1.1304, CI: 0.8286, Pearson: 0.8091


Training Fold 1:   6%|▌         | 11/198 [3:45:27<59:41:25, 1149.12s/epoch]

Fold 1, Epoch 64/250, Loss: 0.3241
Checkpoint saved for fold 1 at epoch 64


Training Fold 1:   6%|▌         | 12/198 [3:51:44<58:49:42, 1138.62s/epoch]

Fold 1, Epoch 64/250 - MSE: 1.0846, CI: 0.8297, Pearson: 0.8116


Training Fold 1:   6%|▌         | 12/198 [4:04:00<58:49:42, 1138.62s/epoch]

Fold 1, Epoch 65/250, Loss: 0.3201
Checkpoint saved for fold 1 at epoch 65


Training Fold 1:   7%|▋         | 13/198 [4:10:10<57:59:59, 1128.64s/epoch]

Fold 1, Epoch 65/250 - MSE: 1.0599, CI: 0.8275, Pearson: 0.8120


Training Fold 1:   7%|▋         | 13/198 [4:22:26<57:59:59, 1128.64s/epoch]

Fold 1, Epoch 66/250, Loss: 0.3203
Checkpoint saved for fold 1 at epoch 66


Training Fold 1:   7%|▋         | 14/198 [4:28:42<57:25:22, 1123.49s/epoch]

Fold 1, Epoch 66/250 - MSE: 1.1327, CI: 0.8264, Pearson: 0.8129


Training Fold 1:   7%|▋         | 14/198 [4:41:02<57:25:22, 1123.49s/epoch]

Fold 1, Epoch 67/250, Loss: 0.3151
Checkpoint saved for fold 1 at epoch 67


Training Fold 1:   8%|▊         | 15/198 [4:47:15<56:57:30, 1120.49s/epoch]

Fold 1, Epoch 67/250 - MSE: 1.1074, CI: 0.8317, Pearson: 0.8174


Training Fold 1:   8%|▊         | 15/198 [4:59:34<56:57:30, 1120.49s/epoch]

Fold 1, Epoch 68/250, Loss: 0.3145
Checkpoint saved for fold 1 at epoch 68


Training Fold 1:   8%|▊         | 16/198 [5:05:46<56:29:51, 1117.54s/epoch]

Fold 1, Epoch 68/250 - MSE: 1.1396, CI: 0.8306, Pearson: 0.8171


Training Fold 1:   8%|▊         | 16/198 [5:18:17<56:29:51, 1117.54s/epoch]

Fold 1, Epoch 69/250, Loss: 0.3115
Checkpoint saved for fold 1 at epoch 69


Training Fold 1:   9%|▊         | 17/198 [5:24:29<56:15:57, 1119.10s/epoch]

Fold 1, Epoch 69/250 - MSE: 1.1543, CI: 0.8312, Pearson: 0.8162


Training Fold 1:   9%|▊         | 17/198 [5:38:15<56:15:57, 1119.10s/epoch]

Fold 1, Epoch 70/250, Loss: 0.3105
Checkpoint saved for fold 1 at epoch 70


Training Fold 1:   9%|▉         | 18/198 [5:44:40<57:20:19, 1146.78s/epoch]

Fold 1, Epoch 70/250 - MSE: 1.0938, CI: 0.8317, Pearson: 0.8201


Training Fold 1:   9%|▉         | 18/198 [5:58:05<57:20:19, 1146.78s/epoch]

Fold 1, Epoch 71/250, Loss: 0.3070
Checkpoint saved for fold 1 at epoch 71


Training Fold 1:  10%|▉         | 19/198 [6:04:23<57:33:39, 1157.65s/epoch]

Fold 1, Epoch 71/250 - MSE: 1.1410, CI: 0.8368, Pearson: 0.8252


Training Fold 1:  10%|▉         | 19/198 [6:17:30<57:33:39, 1157.65s/epoch]

Fold 1, Epoch 72/250, Loss: 0.3102
Checkpoint saved for fold 1 at epoch 72


Training Fold 1:  10%|█         | 20/198 [6:23:38<57:11:58, 1156.84s/epoch]

Fold 1, Epoch 72/250 - MSE: 1.1440, CI: 0.8354, Pearson: 0.8210


Training Fold 1:  10%|█         | 20/198 [6:36:04<57:11:58, 1156.84s/epoch]

Fold 1, Epoch 73/250, Loss: 0.3000
Checkpoint saved for fold 1 at epoch 73


Training Fold 1:  11%|█         | 21/198 [6:42:16<56:18:13, 1145.16s/epoch]

Fold 1, Epoch 73/250 - MSE: 1.1317, CI: 0.8355, Pearson: 0.8194


Training Fold 1:  11%|█         | 21/198 [6:54:42<56:18:13, 1145.16s/epoch]

Fold 1, Epoch 74/250, Loss: 0.2994
Checkpoint saved for fold 1 at epoch 74


Training Fold 1:  11%|█         | 22/198 [7:00:49<55:31:25, 1135.72s/epoch]

Fold 1, Epoch 74/250 - MSE: 1.1484, CI: 0.8320, Pearson: 0.8172


Training Fold 1:  11%|█         | 22/198 [7:13:09<55:31:25, 1135.72s/epoch]

Fold 1, Epoch 75/250, Loss: 0.2958
Checkpoint saved for fold 1 at epoch 75


Training Fold 1:  12%|█▏        | 23/198 [7:19:26<54:56:10, 1130.12s/epoch]

Fold 1, Epoch 75/250 - MSE: 1.1473, CI: 0.8350, Pearson: 0.8223


Training Fold 1:  12%|█▏        | 23/198 [7:32:06<54:56:10, 1130.12s/epoch]

Fold 1, Epoch 76/250, Loss: 0.2972
Checkpoint saved for fold 1 at epoch 76


Training Fold 1:  12%|█▏        | 24/198 [7:38:18<54:38:28, 1130.51s/epoch]

Fold 1, Epoch 76/250 - MSE: 1.1520, CI: 0.8413, Pearson: 0.8300


Training Fold 1:  12%|█▏        | 24/198 [7:51:14<54:38:28, 1130.51s/epoch]

Fold 1, Epoch 77/250, Loss: 0.2903
Checkpoint saved for fold 1 at epoch 77


Training Fold 1:  13%|█▎        | 25/198 [7:57:32<54:39:59, 1137.57s/epoch]

Fold 1, Epoch 77/250 - MSE: 1.1161, CI: 0.8386, Pearson: 0.8234


Training Fold 1:  13%|█▎        | 25/198 [8:10:56<54:39:59, 1137.57s/epoch]

Fold 1, Epoch 78/250, Loss: 0.2937
Checkpoint saved for fold 1 at epoch 78


Training Fold 1:  13%|█▎        | 26/198 [8:17:13<54:58:33, 1150.66s/epoch]

Fold 1, Epoch 78/250 - MSE: 1.0964, CI: 0.8376, Pearson: 0.8235


Training Fold 1:  13%|█▎        | 26/198 [8:31:25<54:58:33, 1150.66s/epoch]

Fold 1, Epoch 79/250, Loss: 0.2939
Checkpoint saved for fold 1 at epoch 79


Training Fold 1:  14%|█▎        | 27/198 [8:37:50<55:53:20, 1176.61s/epoch]

Fold 1, Epoch 79/250 - MSE: 1.0846, CI: 0.8346, Pearson: 0.8256


Training Fold 1:  14%|█▎        | 27/198 [8:50:35<55:53:20, 1176.61s/epoch]

Fold 1, Epoch 80/250, Loss: 0.2871
Checkpoint saved for fold 1 at epoch 80


Training Fold 1:  14%|█▍        | 28/198 [8:56:48<55:00:56, 1165.04s/epoch]

Fold 1, Epoch 80/250 - MSE: 1.1807, CI: 0.8414, Pearson: 0.8283


Training Fold 1:  14%|█▍        | 28/198 [9:09:14<55:00:56, 1165.04s/epoch]

Fold 1, Epoch 81/250, Loss: 0.2853
Checkpoint saved for fold 1 at epoch 81


Training Fold 1:  15%|█▍        | 29/198 [9:15:27<54:02:18, 1151.12s/epoch]

Fold 1, Epoch 81/250 - MSE: 1.0849, CI: 0.8404, Pearson: 0.8287


Training Fold 1:  15%|█▍        | 29/198 [9:27:51<54:02:18, 1151.12s/epoch]

Fold 1, Epoch 82/250, Loss: 0.2872
Checkpoint saved for fold 1 at epoch 82


Training Fold 1:  15%|█▌        | 30/198 [9:34:01<53:11:59, 1140.00s/epoch]

Fold 1, Epoch 82/250 - MSE: 1.1513, CI: 0.8385, Pearson: 0.8287


Training Fold 1:  15%|█▌        | 30/198 [9:47:56<53:11:59, 1140.00s/epoch]

Fold 1, Epoch 83/250, Loss: 0.2808
Checkpoint saved for fold 1 at epoch 83


Training Fold 1:  16%|█▌        | 31/198 [9:54:11<53:51:31, 1161.03s/epoch]

Fold 1, Epoch 83/250 - MSE: 1.1494, CI: 0.8412, Pearson: 0.8242


Training Fold 1:  16%|█▌        | 31/198 [10:07:17<53:51:31, 1161.03s/epoch]

Fold 1, Epoch 84/250, Loss: 0.2806
Checkpoint saved for fold 1 at epoch 84


Training Fold 1:  16%|█▌        | 32/198 [10:13:38<53:37:05, 1162.80s/epoch]

Fold 1, Epoch 84/250 - MSE: 1.1007, CI: 0.8343, Pearson: 0.8262


Training Fold 1:  16%|█▌        | 32/198 [10:26:48<53:37:05, 1162.80s/epoch]

Fold 1, Epoch 85/250, Loss: 0.2810
Checkpoint saved for fold 1 at epoch 85


Training Fold 1:  17%|█▋        | 33/198 [10:33:07<53:22:21, 1164.49s/epoch]

Fold 1, Epoch 85/250 - MSE: 1.1405, CI: 0.8406, Pearson: 0.8326


Training Fold 1:  17%|█▋        | 33/198 [10:46:16<53:22:21, 1164.49s/epoch]

Fold 1, Epoch 86/250, Loss: 0.2790
Checkpoint saved for fold 1 at epoch 86


Training Fold 1:  17%|█▋        | 34/198 [10:52:36<53:07:15, 1166.07s/epoch]

Fold 1, Epoch 86/250 - MSE: 1.1511, CI: 0.8413, Pearson: 0.8328


Training Fold 1:  17%|█▋        | 34/198 [11:06:13<53:07:15, 1166.07s/epoch]

Fold 1, Epoch 87/250, Loss: 0.2744
Checkpoint saved for fold 1 at epoch 87


Training Fold 1:  18%|█▊        | 35/198 [11:12:33<53:13:04, 1175.37s/epoch]

Fold 1, Epoch 87/250 - MSE: 1.1459, CI: 0.8396, Pearson: 0.8310


Training Fold 1:  18%|█▊        | 35/198 [11:25:58<53:13:04, 1175.37s/epoch]

Fold 1, Epoch 88/250, Loss: 0.2745
Checkpoint saved for fold 1 at epoch 88


Training Fold 1:  18%|█▊        | 36/198 [11:32:21<53:03:52, 1179.21s/epoch]

Fold 1, Epoch 88/250 - MSE: 1.1533, CI: 0.8449, Pearson: 0.8351


Training Fold 1:  18%|█▊        | 36/198 [11:45:58<53:03:52, 1179.21s/epoch]

Fold 1, Epoch 89/250, Loss: 0.2687
Checkpoint saved for fold 1 at epoch 89


Training Fold 1:  19%|█▊        | 37/198 [11:52:19<52:59:10, 1184.79s/epoch]

Fold 1, Epoch 89/250 - MSE: 1.1709, CI: 0.8438, Pearson: 0.8313


Training Fold 1:  19%|█▊        | 37/198 [12:05:01<52:59:10, 1184.79s/epoch]

Fold 1, Epoch 90/250, Loss: 0.2696
Checkpoint saved for fold 1 at epoch 90


Training Fold 1:  19%|█▉        | 38/198 [12:11:13<51:58:25, 1169.41s/epoch]

Fold 1, Epoch 90/250 - MSE: 1.1495, CI: 0.8411, Pearson: 0.8352


Training Fold 1:  19%|█▉        | 38/198 [12:23:43<51:58:25, 1169.41s/epoch]

Fold 1, Epoch 91/250, Loss: 0.2680
Checkpoint saved for fold 1 at epoch 91


Training Fold 1:  20%|█▉        | 39/198 [12:29:54<51:00:15, 1154.81s/epoch]

Fold 1, Epoch 91/250 - MSE: 1.1742, CI: 0.8437, Pearson: 0.8325


Training Fold 1:  20%|█▉        | 39/198 [12:42:27<51:00:15, 1154.81s/epoch]

Fold 1, Epoch 92/250, Loss: 0.2662
Checkpoint saved for fold 1 at epoch 92


Training Fold 1:  20%|██        | 40/198 [12:48:40<50:18:17, 1146.19s/epoch]

Fold 1, Epoch 92/250 - MSE: 1.1266, CI: 0.8438, Pearson: 0.8342


Training Fold 1:  20%|██        | 40/198 [13:01:01<50:18:17, 1146.19s/epoch]

Fold 1, Epoch 93/250, Loss: 0.2684
Checkpoint saved for fold 1 at epoch 93


Training Fold 1:  21%|██        | 41/198 [13:07:09<49:30:03, 1135.06s/epoch]

Fold 1, Epoch 93/250 - MSE: 1.1219, CI: 0.8438, Pearson: 0.8382


Training Fold 1:  21%|██        | 41/198 [13:19:30<49:30:03, 1135.06s/epoch]

Fold 1, Epoch 94/250, Loss: 0.2609
Checkpoint saved for fold 1 at epoch 94


Training Fold 1:  21%|██        | 42/198 [13:25:43<48:54:53, 1128.81s/epoch]

Fold 1, Epoch 94/250 - MSE: 1.1546, CI: 0.8419, Pearson: 0.8384


Training Fold 1:  21%|██        | 42/198 [13:38:20<48:54:53, 1128.81s/epoch]

Fold 1, Epoch 95/250, Loss: 0.2627
Checkpoint saved for fold 1 at epoch 95


Training Fold 1:  22%|██▏       | 43/198 [13:44:31<48:35:06, 1128.43s/epoch]

Fold 1, Epoch 95/250 - MSE: 1.1111, CI: 0.8431, Pearson: 0.8387


Training Fold 1:  22%|██▏       | 43/198 [13:57:16<48:35:06, 1128.43s/epoch]

Fold 1, Epoch 96/250, Loss: 0.2617
Checkpoint saved for fold 1 at epoch 96


Training Fold 1:  22%|██▏       | 44/198 [14:03:26<48:21:46, 1130.56s/epoch]

Fold 1, Epoch 96/250 - MSE: 1.1435, CI: 0.8448, Pearson: 0.8357


Training Fold 1:  22%|██▏       | 44/198 [14:16:01<48:21:46, 1130.56s/epoch]

Fold 1, Epoch 97/250, Loss: 0.2611
Checkpoint saved for fold 1 at epoch 97


Training Fold 1:  23%|██▎       | 45/198 [14:22:13<48:00:16, 1129.52s/epoch]

Fold 1, Epoch 97/250 - MSE: 1.1358, CI: 0.8477, Pearson: 0.8372


Training Fold 1:  23%|██▎       | 45/198 [14:34:47<48:00:16, 1129.52s/epoch]

Fold 1, Epoch 98/250, Loss: 0.2583
Checkpoint saved for fold 1 at epoch 98


Training Fold 1:  23%|██▎       | 46/198 [14:40:56<47:36:15, 1127.47s/epoch]

Fold 1, Epoch 98/250 - MSE: 1.1795, CI: 0.8440, Pearson: 0.8419


Training Fold 1:  23%|██▎       | 46/198 [14:53:35<47:36:15, 1127.47s/epoch]

Fold 1, Epoch 99/250, Loss: 0.2563
Checkpoint saved for fold 1 at epoch 99


Training Fold 1:  24%|██▎       | 47/198 [14:59:46<47:19:15, 1128.18s/epoch]

Fold 1, Epoch 99/250 - MSE: 1.1733, CI: 0.8481, Pearson: 0.8418


Training Fold 1:  24%|██▎       | 47/198 [15:12:26<47:19:15, 1128.18s/epoch]

Fold 1, Epoch 100/250, Loss: 0.2563
Checkpoint saved for fold 1 at epoch 100


Training Fold 1:  24%|██▍       | 48/198 [15:18:40<47:05:24, 1130.16s/epoch]

Fold 1, Epoch 100/250 - MSE: 1.1841, CI: 0.8462, Pearson: 0.8401


Training Fold 1:  24%|██▍       | 48/198 [15:31:20<47:05:24, 1130.16s/epoch]

Fold 1, Epoch 101/250, Loss: 0.2523
Checkpoint saved for fold 1 at epoch 101


Training Fold 1:  25%|██▍       | 49/198 [15:37:37<46:50:58, 1131.94s/epoch]

Fold 1, Epoch 101/250 - MSE: 1.1490, CI: 0.8451, Pearson: 0.8388


Training Fold 1:  25%|██▍       | 49/198 [15:50:19<46:50:58, 1131.94s/epoch]

Fold 1, Epoch 102/250, Loss: 0.2516
Checkpoint saved for fold 1 at epoch 102


Training Fold 1:  25%|██▌       | 50/198 [15:56:35<46:36:55, 1133.89s/epoch]

Fold 1, Epoch 102/250 - MSE: 1.1484, CI: 0.8475, Pearson: 0.8406


Training Fold 1:  25%|██▌       | 50/198 [16:09:20<46:36:55, 1133.89s/epoch]

Fold 1, Epoch 103/250, Loss: 0.2506
Checkpoint saved for fold 1 at epoch 103


Training Fold 1:  26%|██▌       | 51/198 [16:15:34<46:21:53, 1135.47s/epoch]

Fold 1, Epoch 103/250 - MSE: 1.2122, CI: 0.8473, Pearson: 0.8424


Training Fold 1:  26%|██▌       | 51/198 [16:28:34<46:21:53, 1135.47s/epoch]

Fold 1, Epoch 104/250, Loss: 0.2497
Checkpoint saved for fold 1 at epoch 104


Training Fold 1:  26%|██▋       | 52/198 [16:34:48<46:16:25, 1140.99s/epoch]

Fold 1, Epoch 104/250 - MSE: 1.1699, CI: 0.8491, Pearson: 0.8468


Training Fold 1:  26%|██▋       | 52/198 [16:47:46<46:16:25, 1140.99s/epoch]

Fold 1, Epoch 105/250, Loss: 0.2490
Checkpoint saved for fold 1 at epoch 105


Training Fold 1:  27%|██▋       | 53/198 [16:54:09<46:12:10, 1147.10s/epoch]

Fold 1, Epoch 105/250 - MSE: 1.1608, CI: 0.8460, Pearson: 0.8406


Training Fold 1:  27%|██▋       | 53/198 [17:08:10<46:12:10, 1147.10s/epoch]

Fold 1, Epoch 106/250, Loss: 0.2441
Checkpoint saved for fold 1 at epoch 106


Training Fold 1:  27%|██▋       | 54/198 [17:14:16<46:35:47, 1164.91s/epoch]

Fold 1, Epoch 106/250 - MSE: 1.1547, CI: 0.8469, Pearson: 0.8440


Training Fold 1:  27%|██▋       | 54/198 [17:26:45<46:35:47, 1164.91s/epoch]

Fold 1, Epoch 107/250, Loss: 0.2447
Checkpoint saved for fold 1 at epoch 107


Training Fold 1:  28%|██▊       | 55/198 [17:32:57<45:45:14, 1151.85s/epoch]

Fold 1, Epoch 107/250 - MSE: 1.2118, CI: 0.8470, Pearson: 0.8457


Training Fold 1:  28%|██▊       | 55/198 [17:45:24<45:45:14, 1151.85s/epoch]

Fold 1, Epoch 108/250, Loss: 0.2435
Checkpoint saved for fold 1 at epoch 108


Training Fold 1:  28%|██▊       | 56/198 [17:51:36<45:02:53, 1142.07s/epoch]

Fold 1, Epoch 108/250 - MSE: 1.1822, CI: 0.8505, Pearson: 0.8471


Training Fold 1:  28%|██▊       | 56/198 [18:03:56<45:02:53, 1142.07s/epoch]

Fold 1, Epoch 109/250, Loss: 0.2427
Checkpoint saved for fold 1 at epoch 109


Training Fold 1:  29%|██▉       | 57/198 [18:10:08<44:22:08, 1132.83s/epoch]

Fold 1, Epoch 109/250 - MSE: 1.1499, CI: 0.8497, Pearson: 0.8485


Training Fold 1:  29%|██▉       | 57/198 [18:22:50<44:22:08, 1132.83s/epoch]

Fold 1, Epoch 110/250, Loss: 0.2404
Checkpoint saved for fold 1 at epoch 110


Training Fold 1:  29%|██▉       | 58/198 [18:29:15<44:13:09, 1137.07s/epoch]

Fold 1, Epoch 110/250 - MSE: 1.1955, CI: 0.8486, Pearson: 0.8461


Training Fold 1:  29%|██▉       | 58/198 [18:42:38<44:13:09, 1137.07s/epoch]

Fold 1, Epoch 111/250, Loss: 0.2387
Checkpoint saved for fold 1 at epoch 111


Training Fold 1:  30%|██▉       | 59/198 [18:48:51<44:21:10, 1148.71s/epoch]

Fold 1, Epoch 111/250 - MSE: 1.1563, CI: 0.8513, Pearson: 0.8508


Training Fold 1:  30%|██▉       | 59/198 [19:02:11<44:21:10, 1148.71s/epoch]

Fold 1, Epoch 112/250, Loss: 0.2384
Checkpoint saved for fold 1 at epoch 112


Training Fold 1:  30%|███       | 60/198 [19:08:33<44:25:30, 1158.92s/epoch]

Fold 1, Epoch 112/250 - MSE: 1.2318, CI: 0.8509, Pearson: 0.8489


Training Fold 1:  30%|███       | 60/198 [19:21:58<44:25:30, 1158.92s/epoch]

Fold 1, Epoch 113/250, Loss: 0.2379
Checkpoint saved for fold 1 at epoch 113


Training Fold 1:  31%|███       | 61/198 [19:28:11<44:19:19, 1164.67s/epoch]

Fold 1, Epoch 113/250 - MSE: 1.1926, CI: 0.8483, Pearson: 0.8453


Training Fold 1:  31%|███       | 61/198 [19:40:30<44:19:19, 1164.67s/epoch]

Fold 1, Epoch 114/250, Loss: 0.2340
Checkpoint saved for fold 1 at epoch 114


Training Fold 1:  31%|███▏      | 62/198 [19:46:35<43:18:18, 1146.31s/epoch]

Fold 1, Epoch 114/250 - MSE: 1.2180, CI: 0.8504, Pearson: 0.8467


Training Fold 1:  31%|███▏      | 62/198 [19:59:25<43:18:18, 1146.31s/epoch]

Fold 1, Epoch 115/250, Loss: 0.2335
Checkpoint saved for fold 1 at epoch 115


Training Fold 1:  32%|███▏      | 63/198 [20:05:40<42:58:10, 1145.85s/epoch]

Fold 1, Epoch 115/250 - MSE: 1.1550, CI: 0.8499, Pearson: 0.8494


Training Fold 1:  32%|███▏      | 63/198 [20:18:13<42:58:10, 1145.85s/epoch]

Fold 1, Epoch 116/250, Loss: 0.2336
Checkpoint saved for fold 1 at epoch 116


Training Fold 1:  32%|███▏      | 64/198 [20:24:25<42:25:40, 1139.85s/epoch]

Fold 1, Epoch 116/250 - MSE: 1.1639, CI: 0.8499, Pearson: 0.8460


Training Fold 1:  32%|███▏      | 64/198 [20:36:49<42:25:40, 1139.85s/epoch]

Fold 1, Epoch 117/250, Loss: 0.2323
Checkpoint saved for fold 1 at epoch 117


Training Fold 1:  33%|███▎      | 65/198 [20:43:02<41:51:23, 1132.96s/epoch]

Fold 1, Epoch 117/250 - MSE: 1.1554, CI: 0.8506, Pearson: 0.8486


Training Fold 1:  33%|███▎      | 65/198 [20:55:40<41:51:23, 1132.96s/epoch]

Fold 1, Epoch 118/250, Loss: 0.2348
Checkpoint saved for fold 1 at epoch 118


Training Fold 1:  33%|███▎      | 66/198 [21:01:55<41:32:14, 1132.84s/epoch]

Fold 1, Epoch 118/250 - MSE: 1.1762, CI: 0.8528, Pearson: 0.8476


Training Fold 1:  33%|███▎      | 66/198 [21:14:46<41:32:14, 1132.84s/epoch]

Fold 1, Epoch 119/250, Loss: 0.2289
Checkpoint saved for fold 1 at epoch 119


Training Fold 1:  34%|███▍      | 67/198 [21:20:55<41:18:25, 1135.15s/epoch]

Fold 1, Epoch 119/250 - MSE: 1.1469, CI: 0.8545, Pearson: 0.8525


Training Fold 1:  34%|███▍      | 67/198 [21:33:26<41:18:25, 1135.15s/epoch]

Fold 1, Epoch 120/250, Loss: 0.2272
Checkpoint saved for fold 1 at epoch 120


Training Fold 1:  34%|███▍      | 68/198 [21:39:38<40:51:22, 1131.41s/epoch]

Fold 1, Epoch 120/250 - MSE: 1.2068, CI: 0.8534, Pearson: 0.8481


Training Fold 1:  34%|███▍      | 68/198 [21:52:13<40:51:22, 1131.41s/epoch]

Fold 1, Epoch 121/250, Loss: 0.2303
Checkpoint saved for fold 1 at epoch 121


Training Fold 1:  35%|███▍      | 69/198 [21:58:27<40:30:42, 1130.56s/epoch]

Fold 1, Epoch 121/250 - MSE: 1.1715, CI: 0.8504, Pearson: 0.8474


Training Fold 1:  35%|███▍      | 69/198 [22:11:04<40:30:42, 1130.56s/epoch]

Fold 1, Epoch 122/250, Loss: 0.2266
Checkpoint saved for fold 1 at epoch 122


Training Fold 1:  35%|███▌      | 70/198 [22:17:12<40:08:30, 1128.98s/epoch]

Fold 1, Epoch 122/250 - MSE: 1.2005, CI: 0.8536, Pearson: 0.8493


Training Fold 1:  35%|███▌      | 70/198 [22:29:39<40:08:30, 1128.98s/epoch]

Fold 1, Epoch 123/250, Loss: 0.2270
Checkpoint saved for fold 1 at epoch 123


Training Fold 1:  36%|███▌      | 71/198 [22:35:49<39:41:46, 1125.25s/epoch]

Fold 1, Epoch 123/250 - MSE: 1.1446, CI: 0.8522, Pearson: 0.8484


Training Fold 1:  36%|███▌      | 71/198 [22:48:21<39:41:46, 1125.25s/epoch]

Fold 1, Epoch 124/250, Loss: 0.2207
Checkpoint saved for fold 1 at epoch 124


Training Fold 1:  36%|███▋      | 72/198 [22:54:33<39:22:36, 1125.05s/epoch]

Fold 1, Epoch 124/250 - MSE: 1.1818, CI: 0.8523, Pearson: 0.8517


Training Fold 1:  36%|███▋      | 72/198 [23:07:17<39:22:36, 1125.05s/epoch]

Fold 1, Epoch 125/250, Loss: 0.2221
Checkpoint saved for fold 1 at epoch 125


Training Fold 1:  37%|███▋      | 73/198 [23:13:28<39:09:52, 1127.94s/epoch]

Fold 1, Epoch 125/250 - MSE: 1.1858, CI: 0.8532, Pearson: 0.8529


Training Fold 1:  37%|███▋      | 73/198 [23:26:21<39:09:52, 1127.94s/epoch]

Fold 1, Epoch 126/250, Loss: 0.2191
Checkpoint saved for fold 1 at epoch 126


Training Fold 1:  37%|███▋      | 74/198 [23:32:31<39:00:42, 1132.60s/epoch]

Fold 1, Epoch 126/250 - MSE: 1.1736, CI: 0.8547, Pearson: 0.8512


Training Fold 1:  37%|███▋      | 74/198 [23:45:09<39:00:42, 1132.60s/epoch]

Fold 1, Epoch 127/250, Loss: 0.2197
Checkpoint saved for fold 1 at epoch 127


Training Fold 1:  38%|███▊      | 75/198 [23:51:21<38:39:52, 1131.65s/epoch]

Fold 1, Epoch 127/250 - MSE: 1.1997, CI: 0.8556, Pearson: 0.8541


Training Fold 1:  38%|███▊      | 75/198 [24:03:54<38:39:52, 1131.65s/epoch]

Fold 1, Epoch 128/250, Loss: 0.2189
Checkpoint saved for fold 1 at epoch 128


Training Fold 1:  38%|███▊      | 76/198 [24:10:07<38:17:35, 1129.96s/epoch]

Fold 1, Epoch 128/250 - MSE: 1.1862, CI: 0.8558, Pearson: 0.8541


Training Fold 1:  38%|███▊      | 76/198 [24:22:43<38:17:35, 1129.96s/epoch]

Fold 1, Epoch 129/250, Loss: 0.2182
Checkpoint saved for fold 1 at epoch 129


Training Fold 1:  39%|███▉      | 77/198 [24:29:08<38:05:50, 1133.48s/epoch]

Fold 1, Epoch 129/250 - MSE: 1.1684, CI: 0.8573, Pearson: 0.8558


Training Fold 1:  39%|███▉      | 77/198 [24:42:27<38:05:50, 1133.48s/epoch]

Fold 1, Epoch 130/250, Loss: 0.2160
Checkpoint saved for fold 1 at epoch 130


Training Fold 1:  39%|███▉      | 78/198 [24:48:39<38:09:09, 1144.58s/epoch]

Fold 1, Epoch 130/250 - MSE: 1.1937, CI: 0.8518, Pearson: 0.8485


Training Fold 1:  39%|███▉      | 78/198 [25:01:47<38:09:09, 1144.58s/epoch]

Fold 1, Epoch 131/250, Loss: 0.2174
Checkpoint saved for fold 1 at epoch 131


Training Fold 1:  40%|███▉      | 79/198 [25:08:01<38:00:12, 1149.69s/epoch]

Fold 1, Epoch 131/250 - MSE: 1.1955, CI: 0.8554, Pearson: 0.8547


Training Fold 1:  40%|███▉      | 79/198 [25:20:47<38:00:12, 1149.69s/epoch]

Fold 1, Epoch 132/250, Loss: 0.2147
Checkpoint saved for fold 1 at epoch 132


Training Fold 1:  40%|████      | 80/198 [25:26:57<37:33:11, 1145.69s/epoch]

Fold 1, Epoch 132/250 - MSE: 1.1935, CI: 0.8545, Pearson: 0.8560


Training Fold 1:  40%|████      | 80/198 [25:39:54<37:33:11, 1145.69s/epoch]

Fold 1, Epoch 133/250, Loss: 0.2171
Checkpoint saved for fold 1 at epoch 133


Training Fold 1:  41%|████      | 81/198 [25:46:44<37:38:13, 1158.06s/epoch]

Fold 1, Epoch 133/250 - MSE: 1.1685, CI: 0.8558, Pearson: 0.8508


Training Fold 1:  41%|████      | 81/198 [25:59:49<37:38:13, 1158.06s/epoch]

Fold 1, Epoch 134/250, Loss: 0.2096
Checkpoint saved for fold 1 at epoch 134


Training Fold 1:  41%|████      | 81/198 [26:03:10<37:37:54, 1157.90s/epoch]


KeyboardInterrupt: 

In [9]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools

# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def load_sample(path):
    # Load individual sample from file
    sample = torch.load(path)
    mol_data = sample[0]
    pro_data = sample[1]
    target = sample[2]

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'x' attribute is set
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data does not have 'x' or 'features' attribute")

    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data does not have 'x' or 'features' attribute")

    # Ensure 'x' is a float tensor
    if not isinstance(mol_data.x, torch.Tensor):
        mol_data.x = torch.tensor(mol_data.x)
    if not isinstance(pro_data.x, torch.Tensor):
        pro_data.x = torch.tensor(pro_data.x)

    if mol_data.x.dtype != torch.float:
        mol_data.x = mol_data.x.float()
    if pro_data.x.dtype != torch.float:
        pro_data.x = pro_data.x.float()

    # Adjust 'edge_index' for mol_data
    # Ensure 'edge_index' is a tensor of type torch.long
    if not isinstance(mol_data.edge_index, torch.Tensor):
        mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
    else:
        mol_data.edge_index = mol_data.edge_index.long()

    # Ensure 'edge_index' has shape [2, num_edges]
    if mol_data.edge_index.shape[0] != 2:
        mol_data.edge_index = mol_data.edge_index.t()

    # Adjust 'edge_index' for pro_data
    if not isinstance(pro_data.edge_index, torch.Tensor):
        pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
    else:
        pro_data.edge_index = pro_data.edge_index.long()

    if pro_data.edge_index.shape[0] != 2:
        pro_data.edge_index = pro_data.edge_index.t()

    # Set 'num_nodes' attribute to suppress warnings
    mol_data.num_nodes = mol_data.x.size(0)
    pro_data.num_nodes = pro_data.x.size(0)

    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    batch = []
    for idx, file_name in enumerate(file_list):
        sample_path = os.path.join(sample_dir, file_name)
        sample = load_sample(sample_path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    return np.mean((y_pred - y_true ) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.
    """
    pairs = itertools.combinations(range(len(y_true)), 2)
    c = 0
    s = 0
    for i, j in pairs:
        if y_true[i] != y_true[j]:
            s += 1
            if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
               (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
                c += 1
            elif y_pred[i] == y_pred[j]:
                c += 0.5
    return c / s if s != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true.flatten(), y_pred.flatten())[0]

def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Create a single directory for all checkpoints
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    results = []
    loss_fn = MSELoss()

    for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
        fold_number = fold + 1
        print(f'\nFold {fold_number}/{n_splits}')
        train_files = [sample_files[i] for i in train_idx]
        test_files = [sample_files[i] for i in test_idx]

        # Determine input feature dimensions from your data
        sample = load_sample(os.path.join(sample_dir, train_files[0]))
        mol_data = sample[0]
        pro_data = sample[1]

        num_features_mol = mol_data.x.size(1)
        num_features_pro = pro_data.x.size(1)

        # Initialize the GNN model with correct input dimensions
        model = GNNNet(
            num_features_mol=num_features_mol,
            num_features_pro=num_features_pro
        ).to(device)
        print(f"Model is on device: {next(model.parameters()).device}")

        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Initialize starting epoch
        start_epoch = 1

        # Check for existing checkpoints in TrainingModel directory for the current fold
        existing_checkpoints = [f for f in os.listdir(training_model_dir)
                                if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

        if existing_checkpoints:
            # Find the latest checkpoint based on epoch number
            latest_checkpoint = max(existing_checkpoints,
                                    key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
            checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
            print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            loaded_epoch = checkpoint['epoch']
            start_epoch = loaded_epoch + 1
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"No checkpoint found for fold {fold_number}, starting training from scratch.")

        # Training loop with progress bar over epochs
        for epoch in tqdm(range(start_epoch, num_epochs + 1),
                          desc=f"Training Fold {fold_number}", unit="epoch"):
            model.train()
            running_loss = 0.0

            # Prepare batch loader without progress bar for batches
            batch_size = 256 # Adjust batch size as needed
            batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

            for batch_samples in batch_loader_iter:
                mol_data_list = []
                pro_data_list = []
                target_list = []

                for sample in batch_samples:
                    mol_data = sample[0]
                    pro_data = sample[1]
                    target = sample[2]

                    mol_data_list.append(mol_data)
                    pro_data_list.append(pro_data)
                    target_list.append(target)

                mol_batch = Batch.from_data_list(mol_data_list).to(device)
                pro_batch = Batch.from_data_list(pro_data_list).to(device)
                target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                optimizer.zero_grad()
                output = model(mol_batch, pro_batch)
                loss = loss_fn(output.view(-1), target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * len(batch_samples)

            avg_loss = running_loss / len(train_files)
            # Use tqdm.write() to print without interfering with the progress bar
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Save the model and optimizer states after each epoch
            checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
            checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

            # Evaluation on the test set after each epoch
            model.eval()
            total_preds, total_labels = [], []
            with torch.no_grad():
                batch_size = 256  # Adjust batch size as needed
                batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

                for batch_samples in batch_loader_iter:
                    mol_data_list = []
                    pro_data_list = []
                    target_list = []

                    for sample in batch_samples:
                        mol_data = sample[0]
                        pro_data = sample[1]
                        target = sample[2]

                        mol_data_list.append(mol_data)
                        pro_data_list.append(pro_data)
                        target_list.append(target)

                    mol_batch = Batch.from_data_list(mol_data_list).to(device)
                    pro_batch = Batch.from_data_list(pro_data_list).to(device)
                    target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                    output = model(mol_batch, pro_batch)
                    total_preds.append(output.cpu().numpy())
                    total_labels.append(target.cpu().numpy())

            # Convert lists to numpy arrays for evaluation
            total_preds = np.concatenate(total_preds)
            total_labels = np.concatenate(total_labels)

            # Calculate metrics
            mse = get_mse(total_labels, total_preds)
            ci = get_ci(total_labels, total_preds)
            pearson = get_pearson(total_labels, total_preds)

            # Print metrics
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

        # Evaluation at the end of training for this fold
        print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
        # Store results for this fold
        results.append((mse, ci, pearson))

    return results


if __name__ == "__main__":
    sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
    num_epochs = 250  # Adjust the number of epochs as needed
    n_splits = 5  # Number of folds for cross-validation
    learning_rate = 0.001  # Learning rate

    # Run the training function
    results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

    # Print overall results
    print("\nCross-validation Results:")
    for fold_idx, (mse, ci, pearson) in enumerate(results):
        print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

    # Optionally, compute and print average metrics across folds
    mse_values, ci_values, pearson_values = zip(*results)
    print(f"\nAverage Results:")
    print(f"MSE: {np.mean(mse_values):.4f}")
    print(f"CI: {np.mean(ci_values):.4f}")
    print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


Running on cpu.
Using existing TrainingModel directory at prepared_samples/TrainingModel

Fold 1/5
GNNNet Loaded
Model is on device: cpu
Loading checkpoint for fold 1 from prepared_samples/TrainingModel/model_fold1_epoch134.pt
Resuming training from epoch 135


Training Fold 1:   0%|          | 0/116 [13:33<?, ?epoch/s]

Fold 1, Epoch 135/250, Loss: 0.2120
Checkpoint saved for fold 1 at epoch 135


Training Fold 1:   1%|          | 1/116 [20:02<38:24:04, 1202.13s/epoch]

Fold 1, Epoch 135/250 - MSE: 1.1602, CI: 0.8533, Pearson: 0.8513


Training Fold 1:   1%|          | 1/116 [33:27<38:24:04, 1202.13s/epoch]

Fold 1, Epoch 136/250, Loss: 0.2120
Checkpoint saved for fold 1 at epoch 136


Training Fold 1:   2%|▏         | 2/116 [39:45<37:42:40, 1190.88s/epoch]

Fold 1, Epoch 136/250 - MSE: 1.1482, CI: 0.8550, Pearson: 0.8539


Training Fold 1:   2%|▏         | 2/116 [53:13<37:42:40, 1190.88s/epoch]

Fold 1, Epoch 137/250, Loss: 0.2086
Checkpoint saved for fold 1 at epoch 137


Training Fold 1:   3%|▎         | 3/116 [59:29<37:17:36, 1188.11s/epoch]

Fold 1, Epoch 137/250 - MSE: 1.1583, CI: 0.8573, Pearson: 0.8539


Training Fold 1:   3%|▎         | 3/116 [1:12:42<37:17:36, 1188.11s/epoch]

Fold 1, Epoch 138/250, Loss: 0.2108
Checkpoint saved for fold 1 at epoch 138


Training Fold 1:   3%|▎         | 4/116 [1:18:55<36:41:11, 1179.21s/epoch]

Fold 1, Epoch 138/250 - MSE: 1.1979, CI: 0.8581, Pearson: 0.8551


Training Fold 1:   3%|▎         | 4/116 [1:31:53<36:41:11, 1179.21s/epoch]

Fold 1, Epoch 139/250, Loss: 0.2066
Checkpoint saved for fold 1 at epoch 139


Training Fold 1:   4%|▍         | 5/116 [1:38:10<36:05:20, 1170.45s/epoch]

Fold 1, Epoch 139/250 - MSE: 1.1869, CI: 0.8572, Pearson: 0.8548


Training Fold 1:   4%|▍         | 5/116 [1:51:22<36:05:20, 1170.45s/epoch]

Fold 1, Epoch 140/250, Loss: 0.2072
Checkpoint saved for fold 1 at epoch 140


Training Fold 1:   5%|▌         | 6/116 [1:57:45<35:48:26, 1171.88s/epoch]

Fold 1, Epoch 140/250 - MSE: 1.1759, CI: 0.8585, Pearson: 0.8554


Training Fold 1:   5%|▌         | 6/116 [2:11:39<35:48:26, 1171.88s/epoch]

Fold 1, Epoch 141/250, Loss: 0.2098
Checkpoint saved for fold 1 at epoch 141


Training Fold 1:   6%|▌         | 7/116 [2:18:03<35:56:46, 1187.22s/epoch]

Fold 1, Epoch 141/250 - MSE: 1.1773, CI: 0.8583, Pearson: 0.8572


Training Fold 1:   6%|▌         | 7/116 [2:32:09<35:56:46, 1187.22s/epoch]

Fold 1, Epoch 142/250, Loss: 0.2065
Checkpoint saved for fold 1 at epoch 142


Training Fold 1:   7%|▋         | 8/116 [2:38:40<36:05:03, 1202.81s/epoch]

Fold 1, Epoch 142/250 - MSE: 1.1901, CI: 0.8552, Pearson: 0.8563


Training Fold 1:   7%|▋         | 8/116 [2:52:32<36:05:03, 1202.81s/epoch]

Fold 1, Epoch 143/250, Loss: 0.2054
Checkpoint saved for fold 1 at epoch 143


Training Fold 1:   8%|▊         | 9/116 [2:58:53<35:51:11, 1206.28s/epoch]

Fold 1, Epoch 143/250 - MSE: 1.1932, CI: 0.8556, Pearson: 0.8538


Training Fold 1:   8%|▊         | 9/116 [3:11:53<35:51:11, 1206.28s/epoch]

Fold 1, Epoch 144/250, Loss: 0.2046
Checkpoint saved for fold 1 at epoch 144


Training Fold 1:   9%|▊         | 10/116 [3:17:58<34:57:12, 1187.10s/epoch]

Fold 1, Epoch 144/250 - MSE: 1.2295, CI: 0.8577, Pearson: 0.8603


Training Fold 1:   9%|▊         | 10/116 [3:31:52<34:57:12, 1187.10s/epoch]

Fold 1, Epoch 145/250, Loss: 0.2031
Checkpoint saved for fold 1 at epoch 145


Training Fold 1:   9%|▉         | 11/116 [3:38:12<34:51:59, 1195.42s/epoch]

Fold 1, Epoch 145/250 - MSE: 1.2576, CI: 0.8547, Pearson: 0.8558


Training Fold 1:   9%|▉         | 11/116 [3:52:02<34:51:59, 1195.42s/epoch]

Fold 1, Epoch 146/250, Loss: 0.2000
Checkpoint saved for fold 1 at epoch 146


Training Fold 1:  10%|█         | 12/116 [3:58:24<34:40:52, 1200.50s/epoch]

Fold 1, Epoch 146/250 - MSE: 1.1480, CI: 0.8580, Pearson: 0.8585


Training Fold 1:  10%|█         | 12/116 [4:11:55<34:40:52, 1200.50s/epoch]

Fold 1, Epoch 147/250, Loss: 0.2007
Checkpoint saved for fold 1 at epoch 147


Training Fold 1:  11%|█         | 13/116 [4:18:18<34:17:29, 1198.54s/epoch]

Fold 1, Epoch 147/250 - MSE: 1.1918, CI: 0.8577, Pearson: 0.8513


Training Fold 1:  11%|█         | 13/116 [4:31:16<34:17:29, 1198.54s/epoch]

Fold 1, Epoch 148/250, Loss: 0.2017
Checkpoint saved for fold 1 at epoch 148


Training Fold 1:  12%|█▏        | 14/116 [4:37:29<33:33:11, 1184.23s/epoch]

Fold 1, Epoch 148/250 - MSE: 1.2058, CI: 0.8554, Pearson: 0.8578


Training Fold 1:  12%|█▏        | 14/116 [4:49:53<33:33:11, 1184.23s/epoch]

Fold 1, Epoch 149/250, Loss: 0.1988
Checkpoint saved for fold 1 at epoch 149


Training Fold 1:  13%|█▎        | 15/116 [4:56:05<32:38:37, 1163.54s/epoch]

Fold 1, Epoch 149/250 - MSE: 1.2168, CI: 0.8583, Pearson: 0.8591


Training Fold 1:  13%|█▎        | 15/116 [5:08:40<32:38:37, 1163.54s/epoch]

Fold 1, Epoch 150/250, Loss: 0.1993
Checkpoint saved for fold 1 at epoch 150


Training Fold 1:  14%|█▍        | 16/116 [5:14:46<31:58:08, 1150.88s/epoch]

Fold 1, Epoch 150/250 - MSE: 1.2167, CI: 0.8583, Pearson: 0.8582


Training Fold 1:  14%|█▍        | 16/116 [5:27:06<31:58:08, 1150.88s/epoch]

Fold 1, Epoch 151/250, Loss: 0.1998
Checkpoint saved for fold 1 at epoch 151


Training Fold 1:  15%|█▍        | 17/116 [5:33:13<31:17:00, 1137.58s/epoch]

Fold 1, Epoch 151/250 - MSE: 1.1708, CI: 0.8579, Pearson: 0.8534


Training Fold 1:  15%|█▍        | 17/116 [5:45:54<31:17:00, 1137.58s/epoch]

Fold 1, Epoch 152/250, Loss: 0.2011
Checkpoint saved for fold 1 at epoch 152


Training Fold 1:  16%|█▌        | 18/116 [5:52:20<31:02:28, 1140.29s/epoch]

Fold 1, Epoch 152/250 - MSE: 1.2071, CI: 0.8562, Pearson: 0.8553


Training Fold 1:  16%|█▌        | 18/116 [6:05:49<31:02:28, 1140.29s/epoch]

Fold 1, Epoch 153/250, Loss: 0.1969
Checkpoint saved for fold 1 at epoch 153


Training Fold 1:  16%|█▋        | 19/116 [6:12:13<31:09:17, 1156.27s/epoch]

Fold 1, Epoch 153/250 - MSE: 1.2210, CI: 0.8577, Pearson: 0.8570


Training Fold 1:  16%|█▋        | 19/116 [6:25:37<31:09:17, 1156.27s/epoch]

Fold 1, Epoch 154/250, Loss: 0.1955
Checkpoint saved for fold 1 at epoch 154


Training Fold 1:  17%|█▋        | 20/116 [6:31:41<30:55:49, 1159.89s/epoch]

Fold 1, Epoch 154/250 - MSE: 1.2090, CI: 0.8576, Pearson: 0.8570


Training Fold 1:  17%|█▋        | 20/116 [6:43:56<30:55:49, 1159.89s/epoch]

Fold 1, Epoch 155/250, Loss: 0.1965
Checkpoint saved for fold 1 at epoch 155


Training Fold 1:  18%|█▊        | 21/116 [6:50:07<30:10:50, 1143.69s/epoch]

Fold 1, Epoch 155/250 - MSE: 1.2240, CI: 0.8596, Pearson: 0.8593


Training Fold 1:  18%|█▊        | 21/116 [7:02:33<30:10:50, 1143.69s/epoch]

Fold 1, Epoch 156/250, Loss: 0.1918
Checkpoint saved for fold 1 at epoch 156


Training Fold 1:  19%|█▉        | 22/116 [7:08:38<29:36:06, 1133.69s/epoch]

Fold 1, Epoch 156/250 - MSE: 1.2119, CI: 0.8578, Pearson: 0.8547


Training Fold 1:  19%|█▉        | 22/116 [7:21:03<29:36:06, 1133.69s/epoch]

Fold 1, Epoch 157/250, Loss: 0.1926
Checkpoint saved for fold 1 at epoch 157


Training Fold 1:  20%|█▉        | 23/116 [7:27:09<29:06:42, 1126.91s/epoch]

Fold 1, Epoch 157/250 - MSE: 1.2402, CI: 0.8578, Pearson: 0.8581


Training Fold 1:  20%|█▉        | 23/116 [7:39:31<29:06:42, 1126.91s/epoch]

Fold 1, Epoch 158/250, Loss: 0.1888
Checkpoint saved for fold 1 at epoch 158


Training Fold 1:  21%|██        | 24/116 [7:45:40<28:40:54, 1122.34s/epoch]

Fold 1, Epoch 158/250 - MSE: 1.1576, CI: 0.8608, Pearson: 0.8586


Training Fold 1:  21%|██        | 24/116 [7:58:08<28:40:54, 1122.34s/epoch]

Fold 1, Epoch 159/250, Loss: 0.1898
Checkpoint saved for fold 1 at epoch 159


Training Fold 1:  22%|██▏       | 25/116 [8:04:14<28:18:21, 1119.79s/epoch]

Fold 1, Epoch 159/250 - MSE: 1.2068, CI: 0.8578, Pearson: 0.8593


Training Fold 1:  22%|██▏       | 25/116 [8:16:44<28:18:21, 1119.79s/epoch]

Fold 1, Epoch 160/250, Loss: 0.1888
Checkpoint saved for fold 1 at epoch 160


Training Fold 1:  22%|██▏       | 26/116 [8:22:56<28:00:40, 1120.45s/epoch]

Fold 1, Epoch 160/250 - MSE: 1.1977, CI: 0.8601, Pearson: 0.8594


Training Fold 1:  22%|██▏       | 26/116 [8:35:23<28:00:40, 1120.45s/epoch]

Fold 1, Epoch 161/250, Loss: 0.1910
Checkpoint saved for fold 1 at epoch 161


Training Fold 1:  23%|██▎       | 27/116 [8:41:31<27:39:16, 1118.62s/epoch]

Fold 1, Epoch 161/250 - MSE: 1.1951, CI: 0.8622, Pearson: 0.8627


Training Fold 1:  23%|██▎       | 27/116 [8:54:05<27:39:16, 1118.62s/epoch]

Fold 1, Epoch 162/250, Loss: 0.1888
Checkpoint saved for fold 1 at epoch 162


Training Fold 1:  24%|██▍       | 28/116 [9:00:17<27:24:13, 1121.06s/epoch]

Fold 1, Epoch 162/250 - MSE: 1.1946, CI: 0.8620, Pearson: 0.8626


Training Fold 1:  24%|██▍       | 28/116 [9:13:09<27:24:13, 1121.06s/epoch]

Fold 1, Epoch 163/250, Loss: 0.1879
Checkpoint saved for fold 1 at epoch 163


Training Fold 1:  25%|██▌       | 29/116 [9:19:26<27:17:27, 1129.28s/epoch]

Fold 1, Epoch 163/250 - MSE: 1.2116, CI: 0.8603, Pearson: 0.8574


Training Fold 1:  25%|██▌       | 29/116 [9:33:36<27:17:27, 1129.28s/epoch]

Fold 1, Epoch 164/250, Loss: 0.1898
Checkpoint saved for fold 1 at epoch 164


Training Fold 1:  26%|██▌       | 30/116 [9:39:50<27:39:26, 1157.75s/epoch]

Fold 1, Epoch 164/250 - MSE: 1.2335, CI: 0.8618, Pearson: 0.8602


Training Fold 1:  26%|██▌       | 30/116 [9:54:03<27:39:26, 1157.75s/epoch]

Fold 1, Epoch 165/250, Loss: 0.1869
Checkpoint saved for fold 1 at epoch 165


Training Fold 1:  27%|██▋       | 31/116 [10:00:30<27:55:07, 1182.44s/epoch]

Fold 1, Epoch 165/250 - MSE: 1.1986, CI: 0.8637, Pearson: 0.8622


Training Fold 1:  27%|██▋       | 31/116 [10:13:46<27:55:07, 1182.44s/epoch]

Fold 1, Epoch 166/250, Loss: 0.1833
Checkpoint saved for fold 1 at epoch 166


Training Fold 1:  28%|██▊       | 32/116 [10:19:57<27:28:51, 1177.76s/epoch]

Fold 1, Epoch 166/250 - MSE: 1.1585, CI: 0.8632, Pearson: 0.8611


Training Fold 1:  28%|██▊       | 32/116 [10:32:23<27:28:51, 1177.76s/epoch]

Fold 1, Epoch 167/250, Loss: 0.1842
Checkpoint saved for fold 1 at epoch 167


Training Fold 1:  28%|██▊       | 33/116 [10:38:27<26:41:13, 1157.52s/epoch]

Fold 1, Epoch 167/250 - MSE: 1.1861, CI: 0.8614, Pearson: 0.8607


Training Fold 1:  28%|██▊       | 33/116 [10:50:50<26:41:13, 1157.52s/epoch]

Fold 1, Epoch 168/250, Loss: 0.1858
Checkpoint saved for fold 1 at epoch 168


Training Fold 1:  29%|██▉       | 34/116 [10:56:58<26:02:37, 1143.39s/epoch]

Fold 1, Epoch 168/250 - MSE: 1.1994, CI: 0.8617, Pearson: 0.8616


Training Fold 1:  29%|██▉       | 34/116 [11:09:23<26:02:37, 1143.39s/epoch]

Fold 1, Epoch 169/250, Loss: 0.1851
Checkpoint saved for fold 1 at epoch 169


Training Fold 1:  30%|███       | 35/116 [11:15:30<25:30:54, 1134.01s/epoch]

Fold 1, Epoch 169/250 - MSE: 1.1956, CI: 0.8618, Pearson: 0.8602


Training Fold 1:  30%|███       | 35/116 [11:27:54<25:30:54, 1134.01s/epoch]

Fold 1, Epoch 170/250, Loss: 0.1857
Checkpoint saved for fold 1 at epoch 170


Training Fold 1:  31%|███       | 36/116 [11:34:01<25:02:52, 1127.15s/epoch]

Fold 1, Epoch 170/250 - MSE: 1.2293, CI: 0.8613, Pearson: 0.8589


Training Fold 1:  31%|███       | 36/116 [11:46:59<25:02:52, 1127.15s/epoch]

Fold 1, Epoch 171/250, Loss: 0.1848
Checkpoint saved for fold 1 at epoch 171


Training Fold 1:  32%|███▏      | 37/116 [11:53:10<24:52:53, 1133.85s/epoch]

Fold 1, Epoch 171/250 - MSE: 1.1696, CI: 0.8594, Pearson: 0.8578


Training Fold 1:  32%|███▏      | 37/116 [12:06:00<24:52:53, 1133.85s/epoch]

Fold 1, Epoch 172/250, Loss: 0.1814
Checkpoint saved for fold 1 at epoch 172


Training Fold 1:  33%|███▎      | 38/116 [12:12:12<24:37:08, 1136.26s/epoch]

Fold 1, Epoch 172/250 - MSE: 1.1879, CI: 0.8591, Pearson: 0.8605


Training Fold 1:  33%|███▎      | 38/116 [12:24:55<24:37:08, 1136.26s/epoch]

Fold 1, Epoch 173/250, Loss: 0.1799
Checkpoint saved for fold 1 at epoch 173


Training Fold 1:  34%|███▎      | 39/116 [12:31:08<24:17:53, 1136.01s/epoch]

Fold 1, Epoch 173/250 - MSE: 1.1888, CI: 0.8608, Pearson: 0.8590


Training Fold 1:  34%|███▎      | 39/116 [12:43:54<24:17:53, 1136.01s/epoch]

Fold 1, Epoch 174/250, Loss: 0.1788
Checkpoint saved for fold 1 at epoch 174


Training Fold 1:  34%|███▍      | 40/116 [12:50:11<24:01:42, 1138.19s/epoch]

Fold 1, Epoch 174/250 - MSE: 1.1817, CI: 0.8635, Pearson: 0.8620


Training Fold 1:  34%|███▍      | 40/116 [12:58:53<24:39:53, 1168.34s/epoch]


KeyboardInterrupt: 

In [12]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools

# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def load_sample(path):
    # Load individual sample from file
    sample = torch.load(path)
    mol_data = sample[0]
    pro_data = sample[1]
    target = sample[2]

    # Convert dictionaries to Data objects if necessary
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Ensure that 'x' attribute is set
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data does not have 'x' or 'features' attribute")

    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data does not have 'x' or 'features' attribute")

    # Ensure 'x' is a float tensor
    if not isinstance(mol_data.x, torch.Tensor):
        mol_data.x = torch.tensor(mol_data.x)
    if not isinstance(pro_data.x, torch.Tensor):
        pro_data.x = torch.tensor(pro_data.x)

    if mol_data.x.dtype != torch.float:
        mol_data.x = mol_data.x.float()
    if pro_data.x.dtype != torch.float:
        pro_data.x = pro_data.x.float()

    # Adjust 'edge_index' for mol_data
    # Ensure 'edge_index' is a tensor of type torch.long
    if not isinstance(mol_data.edge_index, torch.Tensor):
        mol_data.edge_index = torch.tensor(mol_data.edge_index, dtype=torch.long)
    else:
        mol_data.edge_index = mol_data.edge_index.long()

    # Ensure 'edge_index' has shape [2, num_edges]
    if mol_data.edge_index.shape[0] != 2:
        mol_data.edge_index = mol_data.edge_index.t()

    # Adjust 'edge_index' for pro_data
    if not isinstance(pro_data.edge_index, torch.Tensor):
        pro_data.edge_index = torch.tensor(pro_data.edge_index, dtype=torch.long)
    else:
        pro_data.edge_index = pro_data.edge_index.long()

    if pro_data.edge_index.shape[0] != 2:
        pro_data.edge_index = pro_data.edge_index.t()

    # Set 'num_nodes' attribute to suppress warnings
    mol_data.num_nodes = mol_data.x.size(0)
    pro_data.num_nodes = pro_data.x.size(0)

    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    batch = []
    for idx, file_name in enumerate(file_list):
        sample_path = os.path.join(sample_dir, file_name)
        sample = load_sample(sample_path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    return np.mean((y_pred - y_true ) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.
    """
    pairs = itertools.combinations(range(len(y_true)), 2)
    c = 0
    s = 0
    for i, j in pairs:
        if y_true[i] != y_true[j]:
            s += 1
            if (y_true[i] < y_true[j] and y_pred[i] < y_pred[j]) or \
               (y_true[i] > y_true[j] and y_pred[i] > y_pred[j]):
                c += 1
            elif y_pred[i] == y_pred[j]:
                c += 0.5
    return c / s if s != 0 else 0

def get_pearson(y_true, y_pred):
    return pearsonr(y_true.flatten(), y_pred.flatten())[0]

def train_5fold_cross_validation(sample_dir, num_epochs=1000, n_splits=5, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Create a single directory for all checkpoints
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    results = []
    loss_fn = MSELoss()

    for fold, (train_idx, test_idx) in enumerate(kfold.split(sample_files)):
        fold_number = fold + 1
        print(f'\nFold {fold_number}/{n_splits}')
        train_files = [sample_files[i] for i in train_idx]
        test_files = [sample_files[i] for i in test_idx]

        # Determine input feature dimensions from your data
        sample = load_sample(os.path.join(sample_dir, train_files[0]))
        mol_data = sample[0]
        pro_data = sample[1]

        num_features_mol = mol_data.x.size(1)
        num_features_pro = pro_data.x.size(1)

        # Initialize the GNN model with correct input dimensions
        model = GNNNet(
            num_features_mol=num_features_mol,
            num_features_pro=num_features_pro
        ).to(device)
        print(f"Model is on device: {next(model.parameters()).device}")

        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Initialize starting epoch
        start_epoch = 1

        # Check for existing checkpoints in TrainingModel directory for the current fold
        existing_checkpoints = [f for f in os.listdir(training_model_dir)
                                if f.endswith('.pt') and f.startswith(f'model_fold{fold_number}_epoch')]

        if existing_checkpoints:
            # Find the latest checkpoint based on epoch number
            latest_checkpoint = max(existing_checkpoints,
                                    key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
            checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
            print(f"Loading checkpoint for fold {fold_number} from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            loaded_epoch = checkpoint['epoch']
            start_epoch = loaded_epoch + 1
            print(f"Resuming training from epoch {start_epoch}")
        else:
            print(f"No checkpoint found for fold {fold_number}, starting training from scratch.")

        # Training loop with progress bar over epochs
        for epoch in tqdm(range(start_epoch, num_epochs + 1),
                          desc=f"Training Fold {fold_number}", unit="epoch"):
            model.train()
            running_loss = 0.0

            # Prepare batch loader without progress bar for batches
            batch_size = 256 # Adjust batch size as needed
            batch_loader_iter = batch_loader(train_files, sample_dir, batch_size=batch_size)

            for batch_samples in batch_loader_iter:
                mol_data_list = []
                pro_data_list = []
                target_list = []

                for sample in batch_samples:
                    mol_data = sample[0]
                    pro_data = sample[1]
                    target = sample[2]

                    mol_data_list.append(mol_data)
                    pro_data_list.append(pro_data)
                    target_list.append(target)

                mol_batch = Batch.from_data_list(mol_data_list).to(device)
                pro_batch = Batch.from_data_list(pro_data_list).to(device)
                target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                optimizer.zero_grad()
                output = model(mol_batch, pro_batch)
                loss = loss_fn(output.view(-1), target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            avg_loss = running_loss / len(batch_samples)
            # Use tqdm.write() to print without interfering with the progress bar
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Save the model and optimizer states after each epoch
            checkpoint_filename = f"model_fold{fold_number}_epoch{epoch}.pt"
            checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            tqdm.write(f"Checkpoint saved for fold {fold_number} at epoch {epoch}")

            # Evaluation on the test set after each epoch
            model.eval()
            total_preds, total_labels = [], []
            with torch.no_grad():
                batch_size = 256  # Adjust batch size as needed
                batch_loader_iter = batch_loader(test_files, sample_dir, batch_size=batch_size)

                for batch_samples in batch_loader_iter:
                    mol_data_list = []
                    pro_data_list = []
                    target_list = []

                    for sample in batch_samples:
                        mol_data = sample[0]
                        pro_data = sample[1]
                        target = sample[2]

                        mol_data_list.append(mol_data)
                        pro_data_list.append(pro_data)
                        target_list.append(target)

                    mol_batch = Batch.from_data_list(mol_data_list).to(device)
                    pro_batch = Batch.from_data_list(pro_data_list).to(device)
                    target = torch.tensor(target_list, dtype=torch.float32).view(-1).to(device)

                    output = model(mol_batch, pro_batch)
                    total_preds.append(output.cpu().numpy())
                    total_labels.append(target.cpu().numpy())

            # Convert lists to numpy arrays for evaluation
            total_preds = np.concatenate(total_preds)
            total_labels = np.concatenate(total_labels)

            # Calculate metrics
            mse = get_mse(total_labels, total_preds)
            ci = get_ci(total_labels, total_preds)
            pearson = get_pearson(total_labels, total_preds)

            # Print metrics
            tqdm.write(f"Fold {fold_number}, Epoch {epoch}/{num_epochs} - MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")

        # Evaluation at the end of training for this fold
        print(f"Final evaluation for Fold {fold_number}: MSE: {mse:.4f}, CI: {ci:.4f}, Pearson: {pearson:.4f}")
        # Store results for this fold
        results.append((mse, ci, pearson))

    return results


if __name__ == "__main__":
    sample_dir = 'prepared_samples'  # Adjust the path to your samples directory
    num_epochs = 250  # Adjust the number of epochs as needed
    n_splits = 5  # Number of folds for cross-validation
    learning_rate = 0.001  # Learning rate

    # Run the training function
    results = train_5fold_cross_validation(sample_dir, num_epochs=num_epochs, n_splits=n_splits, lr=learning_rate)

    # Print overall results
    print("\nCross-validation Results:")
    for fold_idx, (mse, ci, pearson) in enumerate(results):
        print(f"Fold {fold_idx + 1}: MSE={mse:.4f}, CI={ci:.4f}, Pearson={pearson:.4f}")

    # Optionally, compute and print average metrics across folds
    mse_values, ci_values, pearson_values = zip(*results)
    print(f"\nAverage Results:")
    print(f"MSE: {np.mean(mse_values):.4f}")
    print(f"CI: {np.mean(ci_values):.4f}")
    print(f"Pearson Correlation: {np.mean(pearson_values):.4f}")


Running on cpu.
Using existing TrainingModel directory at prepared_samples/TrainingModel

Fold 1/5
GNNNet Loaded
Model is on device: cpu
Loading checkpoint for fold 1 from prepared_samples/TrainingModel/model_fold1_epoch249.pt


KeyError: 'model_state_dict'