In [1]:
!pip install torch-geometric

Collecting torch-geometric


  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m






Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m40.5 MB/s[0m eta [36m0:00:00[0m
[?25h

Installing collected packages: torch-geometric


Successfully installed torch-geometric-2.6.1


In [2]:

#create samples :
import os
import torch
import pickle
import pandas as pd

def load_graph(path, is_pickle=True):
    """
    Load a molecule graph (.pkl) or a protein graph (.pt).
    If is_pickle is True, use pickle to load the file; otherwise, use torch.load.
    """
    if is_pickle:
        with open(path, 'rb') as f:
            return pickle.load(f)
    else:
        return torch.load(path)

def prepare_dataset_individual_save_as_pt(filtered_dataset, molecule_graph_dir, protein_graph_dir, output_dir):
    """
    Incrementally prepares the dataset and saves each (molecule, protein, target) tuple as a separate .pt file.

    Args:
    - filtered_dataset: The filtered KIBA dataset (DataFrame).
    - molecule_graph_dir: Directory where molecule graphs are stored.
    - protein_graph_dir: Directory where protein graphs are stored.
    - output_dir: Directory to save the prepared dataset incrementally.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for index, row in filtered_dataset.iterrows():
        protein_id = row['Target_ID']
        chembl_id = row['Drug_ID']

        # Load the protein graph (.pt)
        pro_graph_path = os.path.join(protein_graph_dir, f"{protein_id}_graph.pt")
        if not os.path.exists(pro_graph_path):
            print(f"Protein graph not found: {protein_id}")
            continue
        pro_graph = load_graph(pro_graph_path, is_pickle=False)

        # Load the molecule graph (.pkl)
        mol_graph_path = os.path.join(molecule_graph_dir, f"{chembl_id}_graph.pkl")
        if not os.path.exists(mol_graph_path):
            print(f"Molecule graph not found: {chembl_id}")
            continue
        mol_graph = load_graph(mol_graph_path)

        # Load target (affinity value)
        target = torch.tensor([row['Y']], dtype=torch.float)

        # Create the sample as a tuple (molecule graph, protein graph, target)
        sample = (mol_graph, pro_graph, target)

        # Save the sample as a .pt file
        sample_path = os.path.join(output_dir, f"sample_{index}.pt")
        torch.save(sample, sample_path)

        if(index%10000 == 0 ):
            print(f"Saved sample {index} as {sample_path}")




# Example usage for individual saving
molecule_graph_dir = '//kaggle/input/davis-graphs/molecule_graphs'  # Directory where molecule graphs are stored
protein_graph_dir = '/kaggle/input/davis-graphs/ProteinGraphs'  # Directory where protein graphs are stored
filtered_dataset_path = '/kaggle/input/davis-graphs/filtered_DavisDataSet.csv'  # Path to the filtered dataset CSV
output_dir = 'prepared_samples/'  # Directory to save individual samples

# Load filtered dataset CSV
filtered_dataset = pd.read_csv(filtered_dataset_path)

# Prepare the dataset incrementally, saving each sample as a .pt file
prepare_dataset_individual_save_as_pt(filtered_dataset, molecule_graph_dir, protein_graph_dir, output_dir)

print("Dataset preparation completed.")



  return torch.load(path)


Saved sample 0 as prepared_samples/sample_0.pt


  return torch.load(path)


Saved sample 10000 as prepared_samples/sample_10000.pt


Dataset preparation completed.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class NonLocalBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=None, chunk_size=2048):
        super().__init__()
        self.output_dim = output_dim or input_dim

        # GCN-based projections (Φ, Θ, Γ)
        # Φ: H = σ(Â H W_φ + b_φ), where Â = D^{-1/2} (A + I) D^{-1/2}
        self.W_phi = GCNConv(input_dim, hidden_dim)
        self.W_theta = GCNConv(input_dim, hidden_dim)
        self.W_gamma = GCNConv(input_dim, hidden_dim)

        # Output projection: O = O_{agg} W_{out} + b_{out}
        self.W_out = nn.Linear(hidden_dim, self.output_dim)

        # Residual connection: V_n = O + V_m (if dimensions match)
        self.res_proj = nn.Linear(input_dim, self.output_dim) if input_dim != self.output_dim else nn.Identity()

        self.chunk_size = chunk_size

    def forward(self, V_m, edge_index):
        # Step 1: Compute Φ, Θ, Γ using GCNConv
        # Φ = GCN(V_m, edge_index), Θ = GCN(V_m, edge_index), Γ = GCN(V_m, edge_index)
        phi = self.W_phi(V_m, edge_index)  # Φ ∈ ℝ^{N×h}
        theta = self.W_theta(V_m, edge_index)  # Θ ∈ ℝ^{N×h}
        gamma = self.W_gamma(V_m, edge_index)  # Γ ∈ ℝ^{N×h}

        N = phi.size(0)
        O = torch.zeros_like(gamma)

        # Step 2: Chunked attention computation
        for i in range(0, N, self.chunk_size):
            # For chunk i: Φ_chunk = Φ[i:i+chunk_size]
            phi_chunk = phi[i:i + self.chunk_size]  # Φ_chunk ∈ ℝ^{C×h}

            # Compute similarity scores: S = (Θ Φ_chunk^T) / √h
            sim_chunk = torch.einsum('nh,ch->cn', theta, phi_chunk)  # S ∈ ℝ^{C×N}
            sim_chunk = sim_chunk / (phi.size(1) ** 0.5)  # Scaling by √h

            # Softmax normalization: α = softmax(S, dim=1)
            attn_chunk = F.softmax(sim_chunk, dim=1)  # α ∈ ℝ^{C×N}

            # Aggregate values: O_chunk = α Γ
            O_chunk = torch.einsum('cn,nh->ch', attn_chunk, gamma)  # O_chunk ∈ ℝ^{C×h}
            O[i:i + self.chunk_size] = O_chunk

            del phi_chunk, sim_chunk, attn_chunk, O_chunk
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Step 3: Projection and residual connection
        O_projected = self.W_out(O)  # O_projected ∈ ℝ^{N×d_out}
        V_n = O_projected + self.res_proj(V_m)  # V_n = O_projected + V_m (with dim alignment)
        return V_n

In [4]:
# !rm -rf "TrainingModelNLB"

In [5]:
import os
training_model_dir = os.path.join(os.getcwd(), 'TrainingModelNLB')
os.makedirs(training_model_dir, exist_ok=True)

In [6]:
!cp '/kaggle/input/nlb/pytorch/default/1/model_epoch150.pt'  'TrainingModelNLB/model_epoch150.pt'
!cp '/kaggle/input/nlb/pytorch/default/1/training_metrics.pt'  'TrainingModelNLB/training_metrics.pt'

In [7]:
#model one
import os
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import MSELoss
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from torch_geometric.data import Data, Batch
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np

# Optional, for plotting:
import matplotlib.pyplot as plt

##############################################################################
#                               1. METRICS
##############################################################################

@torch.no_grad()
def ci_vectorized(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Naive O(N^2) Concordance Index using vectorized GPU operations.
    preds, targets: (N,) on the same device (e.g. cuda).
    """
    # (N, N) differences
    p_diff = preds.unsqueeze(1) - preds.unsqueeze(0)
    t_diff = targets.unsqueeze(1) - targets.unsqueeze(0)
    # mask out pairs where targets are identical
    mask = (t_diff != 0)
    # sign of product => +1 (concordant), 0 (tie), -1 (discordant)
    sign_mat = torch.sign(p_diff * t_diff)
    # step function h(x): 1 if x>0, 0.5 if x=0, 0 if x<0
    h = (sign_mat == 1).float() + 0.5 * (sign_mat == 0).float()
    # apply mask
    h_masked = h * mask.float()
    c = h_masked.sum()
    s = mask.sum().float()
    return (c / s).item() if s > 0 else 0.0

@torch.no_grad()
def mse_torch(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """
    GPU-friendly MSE using torch's built-in mean squared error.
    """
    return F.mse_loss(preds, targets, reduction='mean').item()

@torch.no_grad()
def pearson_torch(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Pearson correlation coefficient in PyTorch (GPU-friendly).
    preds, targets: (N,) on the same device
    """
    p_centered = preds - preds.mean()
    t_centered = targets - targets.mean()
    cov = (p_centered * t_centered).sum()
    denom = torch.sqrt((p_centered**2).sum()) * torch.sqrt((t_centered**2).sum())
    eps = 1e-8
    return (cov / (denom + eps)).item()

##############################################################################
#                       2. GNN MODEL DEFINITION
##############################################################################

class GNNNet(torch.nn.Module):
    def __init__(self, n_output=1, num_features_pro=54, num_features_mol=78, output_dim=128, hidden_dim=128, dropout=0.2):
        super(GNNNet, self).__init__()

        print('GNNNet Loaded')
        self.n_output = n_output
        self.hidden_dim = hidden_dim

        # Molecular graph pipeline
        self.mol_conv1 = GCNConv(num_features_mol, num_features_mol)
        self.mol_nonlocal = NonLocalBlock(num_features_mol, hidden_dim, num_features_mol * 2)
        self.mol_conv3 = GCNConv(num_features_mol * 2, num_features_mol * 4)
        self.mol_fc_g1 = torch.nn.Linear(num_features_mol * 4, 1024)
        self.mol_fc_g2 = torch.nn.Linear(1024, output_dim)

        # Protein graph pipeline
        self.pro_conv1 = GCNConv(num_features_pro, num_features_pro)
        self.pro_nonlocal = NonLocalBlock(num_features_pro, hidden_dim, num_features_pro * 2)
        self.pro_conv3 = GCNConv(num_features_pro * 2, num_features_pro * 4)
        self.pro_fc_g1 = torch.nn.Linear(num_features_pro * 4, 1024)
        self.pro_fc_g2 = torch.nn.Linear(1024, output_dim)

        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)

        # Combined layers
        self.fc1 = torch.nn.Linear(2 * output_dim, 1024)
        self.fc2 = torch.nn.Linear(1024, 512)
        self.out = torch.nn.Linear(512, self.n_output)

    def forward(self, data_mol, data_pro):
        # Molecular graph input
        mol_x, mol_edge_index, mol_batch = data_mol.x, data_mol.edge_index, data_mol.batch
        mol_x = self.mol_conv1(mol_x, mol_edge_index)
        mol_x = self.relu(mol_x)
        mol_x = self.mol_nonlocal(mol_x ,mol_edge_index)  # Non-local block
        mol_x = self.mol_conv3(mol_x, mol_edge_index)
        mol_x = self.relu(mol_x)
        mol_x = gep(mol_x, mol_batch)  # Global pooling
        mol_x = self.relu(self.mol_fc_g1(mol_x))
        mol_x = self.dropout(mol_x)
        mol_x = self.mol_fc_g2(mol_x)
        mol_x = self.dropout(mol_x)

        # Protein graph input
        pro_x, pro_edge_index, pro_batch = data_pro.x, data_pro.edge_index, data_pro.batch
        pro_x = self.pro_conv1(pro_x, pro_edge_index)
        pro_x = self.relu(pro_x)
        pro_x = self.pro_nonlocal(pro_x,pro_edge_index)  # Non-local block
        pro_x = self.pro_conv3(pro_x, pro_edge_index)
        pro_x = self.relu(pro_x)
        pro_x = gep(pro_x, pro_batch)  # Global pooling
        pro_x = self.relu(self.pro_fc_g1(pro_x))
        pro_x = self.dropout(pro_x)
        pro_x = self.pro_fc_g2(pro_x)
        pro_x = self.dropout(pro_x)

        # Concatenate molecular and protein features
        xc = torch.cat((mol_x, pro_x), dim=1)
        xc = self.fc1(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        xc = self.fc2(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        out = self.out(xc)
        return out



##############################################################################
#                   3. DATA LOADING HELPERS
##############################################################################

warnings.filterwarnings('ignore', category=FutureWarning)

def load_sample(path):
    """
    Load a .pt sample, fix up 'x' and 'edge_index', return (mol_data, pro_data, target).
    """
    sample = torch.load(path)
    mol_data, pro_data, target = sample[0], sample[1], sample[2]

    # Convert dict to Data if needed
    if isinstance(mol_data, dict):
        mol_data = Data(**mol_data)
    if isinstance(pro_data, dict):
        pro_data = Data(**pro_data)

    # Fix x
    if not hasattr(mol_data, 'x') or mol_data.x is None:
        if hasattr(mol_data, 'features'):
            mol_data.x = mol_data.features
            del mol_data.features
        else:
            raise ValueError("mol_data missing 'x' or 'features'")
    if not hasattr(pro_data, 'x') or pro_data.x is None:
        if hasattr(pro_data, 'features'):
            pro_data.x = pro_data.features
            del pro_data.features
        else:
            raise ValueError("pro_data missing 'x' or 'features'")

    mol_data.x = torch.as_tensor(mol_data.x, dtype=torch.float32)
    pro_data.x = torch.as_tensor(pro_data.x, dtype=torch.float32)

    # Fix edge_index
    def fix_edge_index(d):
        if not isinstance(d.edge_index, torch.Tensor):
            d.edge_index = torch.tensor(d.edge_index, dtype=torch.long)
        else:
            d.edge_index = d.edge_index.long()
        if d.edge_index.shape[0] != 2:
            d.edge_index = d.edge_index.t()
        d.num_nodes = d.x.size(0)

    fix_edge_index(mol_data)
    fix_edge_index(pro_data)

    return (mol_data, pro_data, target)

def batch_loader(file_list, sample_dir, batch_size):
    """
    Yields batches of (mol_data, pro_data, target) from file_list.
    """
    batch = []
    for file_name in file_list:
        path = os.path.join(sample_dir, file_name)
        sample = load_sample(path)
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

##############################################################################
#              4. TRAINING / EVALUATION WITH METRICS EACH EPOCH
##############################################################################
import os
import torch

def train_and_evaluate(sample_dir, num_epochs=10, test_size=0.2, lr=0.001):
    """
    Trains the GNN model, evaluates on train & test each epoch, saves metrics + checkpoints.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    # Gather samples
    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
    assert len(sample_files) > 0, "No .pt files found in sample_dir!"

    # Split
    train_files, test_files = train_test_split(sample_files, test_size=test_size, random_state=42)

    # Make checkpoint dir
    training_model_dir = os.path.join(os.getcwd(), 'TrainingModelNLB')
    os.makedirs(training_model_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {training_model_dir}")

    metrics_path = os.path.join(training_model_dir, "training_metrics.pt")

    # Load existing metrics if available (Ensures metrics continue from previous runs)
    if os.path.exists(metrics_path):
        saved_metrics = torch.load(metrics_path)
        train_metrics = saved_metrics['train_metrics']
        test_metrics = saved_metrics['test_metrics']
        print("Loaded previous training metrics!")
    else:
        train_metrics = {'epoch': [], 'mse': [], 'ci': [], 'pearson': []}
        test_metrics = {'epoch': [], 'mse': [], 'ci': [], 'pearson': []}
        print("Starting fresh metrics tracking.")

    # Infer input dims from one sample
    sample0 = load_sample(os.path.join(sample_dir, train_files[0]))
    mol_data0, pro_data0 = sample0[0], sample0[1]
    num_features_mol = mol_data0.x.size(1)
    num_features_pro = pro_data0.x.size(1)

    # Initialize model
    model = GNNNet(num_features_mol=num_features_mol,
                   num_features_pro=num_features_pro).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = MSELoss()

    # Possibly resume from checkpoint
    start_epoch = 1
    existing_checkpoints = [
        f for f in os.listdir(training_model_dir)
        if f.endswith('.pt') and f.startswith('model_epoch')
    ]
    last_ckpt_path = None  # Track previous checkpoint for deletion

    if existing_checkpoints:
        latest_ckpt = max(existing_checkpoints, key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
        ckpt_path = os.path.join(training_model_dir, latest_ckpt)
        print(f"Loading checkpoint from {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        loaded_epoch = ckpt['epoch']
        start_epoch = loaded_epoch + 1
        last_ckpt_path = ckpt_path  # Store last checkpoint path for deletion
        print(f"Resuming from epoch {start_epoch}")
    else:
        print("No existing checkpoint found; starting fresh.")

    @torch.no_grad()
    def evaluate(files):
        model.eval()
        all_preds = []
        all_targets = []
        batch_eval_size = 200
        for batch_samples in batch_loader(files, sample_dir, batch_eval_size):
            mol_list, pro_list, tgt_list = [], [], []
            for (md, pd, t) in batch_samples:
                mol_list.append(md)
                pro_list.append(pd)
                tgt_list.append(t)

            mol_batch = Batch.from_data_list(mol_list).to(device)
            pro_batch = Batch.from_data_list(pro_list).to(device)
            t_tensor = torch.tensor(tgt_list, dtype=torch.float32, device=device)

            out = model(mol_batch, pro_batch).view(-1)
            all_preds.append(out)
            all_targets.append(t_tensor)

        all_preds = torch.cat(all_preds, dim=0)
        all_targets = torch.cat(all_targets, dim=0)

        mse_val = mse_torch(all_preds, all_targets)
        ci_val = ci_vectorized(all_preds, all_targets)
        pearson_val = pearson_torch(all_preds, all_targets)
        return mse_val, ci_val, pearson_val

    # Training loop
    batch_size = 61
    for epoch in tqdm(range(start_epoch, num_epochs + 1), desc="Training", unit="epoch"):
        model.train()
        running_loss = 0.0

        for batch_samples in batch_loader(train_files, sample_dir, batch_size):
            mol_list, pro_list, tgt_list = [], [], []
            for (md, pd, t) in batch_samples:
                mol_list.append(md)
                pro_list.append(pd)
                tgt_list.append(t)

            mol_batch = Batch.from_data_list(mol_list).to(device)
            pro_batch = Batch.from_data_list(pro_list).to(device)
            t_tensor = torch.tensor(tgt_list, dtype=torch.float32, device=device).view(-1)

            optimizer.zero_grad()
            out = model(mol_batch, pro_batch).view(-1)
            loss = loss_fn(out, t_tensor)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * len(batch_samples)

        avg_loss = running_loss / len(train_files)
        tqdm.write(f"[Epoch {epoch}/{num_epochs}] Training Loss: {avg_loss:.4f}")

        # Evaluate on train & test
        train_mse, train_ci, train_pearson = evaluate(train_files)
        test_mse, test_ci, test_pearson = evaluate(test_files)

        train_metrics['epoch'].append(epoch)
        train_metrics['mse'].append(train_mse)
        train_metrics['ci'].append(train_ci)
        train_metrics['pearson'].append(train_pearson)

        test_metrics['epoch'].append(epoch)
        test_metrics['mse'].append(test_mse)
        test_metrics['ci'].append(test_ci)
        test_metrics['pearson'].append(test_pearson)

        tqdm.write(f"  Train => MSE={train_mse:.4f}, CI={train_ci:.4f}, Pearson={train_pearson:.4f}")
        tqdm.write(f"  Test  => MSE={test_mse:.4f}, CI={test_ci:.4f}, Pearson={test_pearson:.4f}")

        # Save new checkpoint
        ckpt_name = f"model_epoch{epoch}.pt"
        ckpt_path = os.path.join(training_model_dir, ckpt_name)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, ckpt_path)
        tqdm.write(f"Checkpoint saved at epoch {epoch}")

        # Delete the previous checkpoint after saving the new one
        if last_ckpt_path and os.path.exists(last_ckpt_path):
            os.remove(last_ckpt_path)
            tqdm.write(f"Deleted previous checkpoint: {last_ckpt_path}")

        # Update last checkpoint path
        last_ckpt_path = ckpt_path

        # Save/Update the metrics after each epoch
        torch.save({
            'train_metrics': train_metrics,
            'test_metrics': test_metrics
        }, metrics_path)

    # Final evaluations
    final_train_mse, final_train_ci, final_train_pearson = evaluate(train_files)
    final_test_mse, final_test_ci, final_test_pearson = evaluate(test_files)
    print(f"\nFinal Train => MSE={final_train_mse:.4f}, CI={final_train_ci:.4f}, Pearson={final_train_pearson:.4f}")
    print(f"Final Test  => MSE={final_test_mse:.4f}, CI={final_test_ci:.4f}, Pearson={final_test_pearson:.4f}")

    # Save final metrics
    torch.save({
        'train_metrics': train_metrics,
        'test_metrics': test_metrics
    }, metrics_path)
    print(f"Metrics saved to {metrics_path}")

    return train_metrics, test_metrics


##############################################################################
#               5. OPTIONAL: PLOT THE SAVED METRICS
##############################################################################

def plot_metrics(checkpoint_dir='TrainingModelNLB'):
    """
    Load training_metrics.pt from the checkpoint_dir and plot MSE, CI, Pearson over epochs.
    """
    metrics_path = os.path.join(checkpoint_dir, "training_metrics.pt")
    if not os.path.exists(metrics_path):
        print(f"No metrics file found at {metrics_path}!")
        return

    saved_data = torch.load(metrics_path)
    train_metrics = saved_data['train_metrics']
    test_metrics = saved_data['test_metrics']
    epochs = train_metrics['epoch']

    # Plot MSE
    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_metrics['mse'], 'o-', label='Train MSE')
    plt.plot(epochs, test_metrics['mse'], 'o-', label='Test MSE')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.title('Mean Squared Error over Epochs')
    plt.legend()
    plt.savefig(os.path.join(checkpoint_dir, "MSE_plot.png"))  
    plt.show()

    # Plot CI
    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_metrics['ci'], 'o-', label='Train CI')
    plt.plot(epochs, test_metrics['ci'], 'o-', label='Test CI')
    plt.xlabel('Epoch')
    plt.ylabel('Concordance Index')
    plt.title('CI over Epochs')
    plt.savefig(os.path.join(checkpoint_dir, "CI_plot.png"))  
    plt.legend()
    plt.show()

    # Plot Pearson
    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_metrics['pearson'], 'o-', label='Train Pearson')
    plt.plot(epochs, test_metrics['pearson'], 'o-', label='Test Pearson')
    plt.xlabel('Epoch')
    plt.ylabel('Pearson Correlation')
    plt.title('Pearson Correlation over Epochs')
    plt.legend()
    plt.savefig(os.path.join(checkpoint_dir, "PEARSON_plot.png"))  
    plt.show()

##############################################################################
#                                 MAIN
##############################################################################

if __name__ == "__main__":
    # Adjust the paths/parameters as needed
    SAMPLE_DIR = "prepared_samples"   # Directory with your .pt samples
    NUM_EPOCHS = 250
    TEST_SPLIT = 0.2
    LR = 0.001

    # 1) Train and evaluate
    train_metrics, test_metrics = train_and_evaluate(
        sample_dir=SAMPLE_DIR,
        num_epochs=NUM_EPOCHS,
        test_size=TEST_SPLIT,
        lr=LR
    )

    # 2) Plot the metrics
    plot_metrics('TrainingModelNLB')

Running on cuda.
Checkpoints will be saved to: /kaggle/working/TrainingModelNLB
Loaded previous training metrics!
GNNNet Loaded


Loading checkpoint from /kaggle/working/TrainingModelNLB/model_epoch150.pt
Resuming from epoch 151


Training:   0%|          | 0/100 [00:00<?, ?epoch/s]

                                                    



Training:   0%|          | 0/100 [07:55<?, ?epoch/s]

[Epoch 151/250] Training Loss: 0.3301


                                                    



Training:   0%|          | 0/100 [23:34<?, ?epoch/s]

                                                    



Training:   0%|          | 0/100 [23:34<?, ?epoch/s]

                                                    



Training:   0%|          | 0/100 [23:34<?, ?epoch/s]

                                                    



Training:   0%|          | 0/100 [23:34<?, ?epoch/s]

Training:   1%|          | 1/100 [23:34<38:54:06, 1414.61s/epoch]

  Train => MSE=0.2510, CI=0.8717, Pearson=0.8293
  Test  => MSE=0.2784, CI=0.8620, Pearson=0.8034
Checkpoint saved at epoch 151
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch150.pt


                                                                 



Training:   1%|          | 1/100 [31:26<38:54:06, 1414.61s/epoch]

[Epoch 152/250] Training Loss: 0.3130


                                                                 



Training:   1%|          | 1/100 [47:05<38:54:06, 1414.61s/epoch]

                                                                 



Training:   1%|          | 1/100 [47:05<38:54:06, 1414.61s/epoch]

                                                                 



Training:   1%|          | 1/100 [47:05<38:54:06, 1414.61s/epoch]

                                                                 



Training:   1%|          | 1/100 [47:05<38:54:06, 1414.61s/epoch]

Training:   2%|▏         | 2/100 [47:05<38:26:45, 1412.30s/epoch]

  Train => MSE=0.2394, CI=0.8624, Pearson=0.8305
  Test  => MSE=0.2733, CI=0.8469, Pearson=0.7988
Checkpoint saved at epoch 152
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch151.pt


                                                                 



Training:   2%|▏         | 2/100 [54:57<38:26:45, 1412.30s/epoch]

[Epoch 153/250] Training Loss: 0.2971


                                                                 



Training:   2%|▏         | 2/100 [1:10:35<38:26:45, 1412.30s/epoch]

                                                                   



Training:   2%|▏         | 2/100 [1:10:35<38:26:45, 1412.30s/epoch]

                                                                   



Training:   2%|▏         | 2/100 [1:10:35<38:26:45, 1412.30s/epoch]

                                                                   



Training:   2%|▏         | 2/100 [1:10:35<38:26:45, 1412.30s/epoch]

Training:   3%|▎         | 3/100 [1:10:35<38:01:55, 1411.50s/epoch]

  Train => MSE=0.2231, CI=0.8789, Pearson=0.8416
  Test  => MSE=0.2534, CI=0.8676, Pearson=0.8134
Checkpoint saved at epoch 153
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch152.pt


                                                                   



Training:   3%|▎         | 3/100 [1:18:27<38:01:55, 1411.50s/epoch]

[Epoch 154/250] Training Loss: 0.2948


                                                                   



Training:   3%|▎         | 3/100 [1:34:05<38:01:55, 1411.50s/epoch]

                                                                   



Training:   3%|▎         | 3/100 [1:34:05<38:01:55, 1411.50s/epoch]

                                                                   



Training:   3%|▎         | 3/100 [1:34:05<38:01:55, 1411.50s/epoch]

                                                                   



Training:   3%|▎         | 3/100 [1:34:05<38:01:55, 1411.50s/epoch]

Training:   4%|▍         | 4/100 [1:34:05<37:37:15, 1410.79s/epoch]

  Train => MSE=0.2123, CI=0.8772, Pearson=0.8564
  Test  => MSE=0.2566, CI=0.8606, Pearson=0.8167
Checkpoint saved at epoch 154
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch153.pt


                                                                   



Training:   4%|▍         | 4/100 [1:41:57<37:37:15, 1410.79s/epoch]

[Epoch 155/250] Training Loss: 0.2922


                                                                   



Training:   4%|▍         | 4/100 [1:57:35<37:37:15, 1410.79s/epoch]

                                                                   



Training:   4%|▍         | 4/100 [1:57:35<37:37:15, 1410.79s/epoch]

                                                                   



Training:   4%|▍         | 4/100 [1:57:35<37:37:15, 1410.79s/epoch]

                                                                   



Training:   4%|▍         | 4/100 [1:57:35<37:37:15, 1410.79s/epoch]

Training:   5%|▌         | 5/100 [1:57:35<37:13:22, 1410.56s/epoch]

  Train => MSE=0.2000, CI=0.8854, Pearson=0.8561
  Test  => MSE=0.2396, CI=0.8737, Pearson=0.8222
Checkpoint saved at epoch 155
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch154.pt


                                                                   



Training:   5%|▌         | 5/100 [2:05:27<37:13:22, 1410.56s/epoch]

[Epoch 156/250] Training Loss: 0.2789


                                                                   



Training:   5%|▌         | 5/100 [2:21:06<37:13:22, 1410.56s/epoch]

                                                                   



Training:   5%|▌         | 5/100 [2:21:06<37:13:22, 1410.56s/epoch]

                                                                   



Training:   5%|▌         | 5/100 [2:21:06<37:13:22, 1410.56s/epoch]

                                                                   



Training:   5%|▌         | 5/100 [2:21:06<37:13:22, 1410.56s/epoch]

Training:   6%|▌         | 6/100 [2:21:06<36:50:11, 1410.76s/epoch]

  Train => MSE=0.2320, CI=0.8776, Pearson=0.8377
  Test  => MSE=0.2723, CI=0.8661, Pearson=0.7993
Checkpoint saved at epoch 156
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch155.pt


                                                                   



Training:   6%|▌         | 6/100 [2:28:58<36:50:11, 1410.76s/epoch]

[Epoch 157/250] Training Loss: 0.2752


                                                                   



Training:   6%|▌         | 6/100 [2:44:36<36:50:11, 1410.76s/epoch]

                                                                   



Training:   6%|▌         | 6/100 [2:44:36<36:50:11, 1410.76s/epoch]

                                                                   



Training:   6%|▌         | 6/100 [2:44:36<36:50:11, 1410.76s/epoch]

                                                                   



Training:   6%|▌         | 6/100 [2:44:36<36:50:11, 1410.76s/epoch]

Training:   7%|▋         | 7/100 [2:44:36<36:26:19, 1410.53s/epoch]

  Train => MSE=0.2049, CI=0.8791, Pearson=0.8542
  Test  => MSE=0.2427, CI=0.8599, Pearson=0.8212
Checkpoint saved at epoch 157
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch156.pt


                                                                   



Training:   7%|▋         | 7/100 [2:52:28<36:26:19, 1410.53s/epoch]

[Epoch 158/250] Training Loss: 0.2811


                                                                   



Training:   7%|▋         | 7/100 [3:08:07<36:26:19, 1410.53s/epoch]

                                                                   



Training:   7%|▋         | 7/100 [3:08:07<36:26:19, 1410.53s/epoch]

                                                                   



Training:   7%|▋         | 7/100 [3:08:07<36:26:19, 1410.53s/epoch]

                                                                   



Training:   7%|▋         | 7/100 [3:08:07<36:26:19, 1410.53s/epoch]

Training:   8%|▊         | 8/100 [3:08:07<36:02:44, 1410.49s/epoch]

  Train => MSE=0.2088, CI=0.8781, Pearson=0.8527
  Test  => MSE=0.2557, CI=0.8621, Pearson=0.8103
Checkpoint saved at epoch 158
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch157.pt


                                                                   



Training:   8%|▊         | 8/100 [3:15:59<36:02:44, 1410.49s/epoch]

[Epoch 159/250] Training Loss: 0.2760


                                                                   



Training:   8%|▊         | 8/100 [3:31:36<36:02:44, 1410.49s/epoch]

                                                                   



Training:   8%|▊         | 8/100 [3:31:36<36:02:44, 1410.49s/epoch]

                                                                   



Training:   8%|▊         | 8/100 [3:31:36<36:02:44, 1410.49s/epoch]

                                                                   



Training:   8%|▊         | 8/100 [3:31:36<36:02:44, 1410.49s/epoch]

Training:   9%|▉         | 9/100 [3:31:36<35:38:45, 1410.17s/epoch]

  Train => MSE=0.2068, CI=0.8804, Pearson=0.8522
  Test  => MSE=0.2601, CI=0.8685, Pearson=0.8055
Checkpoint saved at epoch 159
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch158.pt


                                                                   



Training:   9%|▉         | 9/100 [3:39:28<35:38:45, 1410.17s/epoch]

[Epoch 160/250] Training Loss: 0.2737


                                                                   



Training:   9%|▉         | 9/100 [3:55:05<35:38:45, 1410.17s/epoch]

                                                                   



Training:   9%|▉         | 9/100 [3:55:05<35:38:45, 1410.17s/epoch]

                                                                   



Training:   9%|▉         | 9/100 [3:55:06<35:38:45, 1410.17s/epoch]

                                                                   



Training:   9%|▉         | 9/100 [3:55:06<35:38:45, 1410.17s/epoch]

Training:  10%|█         | 10/100 [3:55:06<35:14:49, 1409.88s/epoch]

  Train => MSE=0.2079, CI=0.8824, Pearson=0.8543
  Test  => MSE=0.2644, CI=0.8657, Pearson=0.8047
Checkpoint saved at epoch 160
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch159.pt


                                                                    



Training:  10%|█         | 10/100 [4:02:58<35:14:49, 1409.88s/epoch]

[Epoch 161/250] Training Loss: 0.2734


                                                                    



Training:  10%|█         | 10/100 [4:18:35<35:14:49, 1409.88s/epoch]

                                                                    



Training:  10%|█         | 10/100 [4:18:35<35:14:49, 1409.88s/epoch]

                                                                    



Training:  10%|█         | 10/100 [4:18:35<35:14:49, 1409.88s/epoch]

                                                                    



Training:  10%|█         | 10/100 [4:18:35<35:14:49, 1409.88s/epoch]

Training:  11%|█         | 11/100 [4:18:35<34:51:15, 1409.84s/epoch]

  Train => MSE=0.2058, CI=0.8825, Pearson=0.8581
  Test  => MSE=0.2608, CI=0.8628, Pearson=0.8086
Checkpoint saved at epoch 161
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch160.pt


                                                                    



Training:  11%|█         | 11/100 [4:26:28<34:51:15, 1409.84s/epoch]

[Epoch 162/250] Training Loss: 0.2793


                                                                    



Training:  11%|█         | 11/100 [4:42:06<34:51:15, 1409.84s/epoch]

                                                                    



Training:  11%|█         | 11/100 [4:42:06<34:51:15, 1409.84s/epoch]

                                                                    



Training:  11%|█         | 11/100 [4:42:06<34:51:15, 1409.84s/epoch]

                                                                    



Training:  11%|█         | 11/100 [4:42:06<34:51:15, 1409.84s/epoch]

Training:  12%|█▏        | 12/100 [4:42:06<34:28:18, 1410.22s/epoch]

  Train => MSE=0.2050, CI=0.8813, Pearson=0.8556
  Test  => MSE=0.2557, CI=0.8638, Pearson=0.8104
Checkpoint saved at epoch 162
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch161.pt


                                                                    



Training:  12%|█▏        | 12/100 [4:49:58<34:28:18, 1410.22s/epoch]

[Epoch 163/250] Training Loss: 0.2681


                                                                    



Training:  12%|█▏        | 12/100 [5:05:35<34:28:18, 1410.22s/epoch]

                                                                    



Training:  12%|█▏        | 12/100 [5:05:35<34:28:18, 1410.22s/epoch]

                                                                    



Training:  12%|█▏        | 12/100 [5:05:35<34:28:18, 1410.22s/epoch]

                                                                    



Training:  12%|█▏        | 12/100 [5:05:35<34:28:18, 1410.22s/epoch]

Training:  13%|█▎        | 13/100 [5:05:35<34:04:09, 1409.76s/epoch]

  Train => MSE=0.2019, CI=0.8872, Pearson=0.8634
  Test  => MSE=0.2544, CI=0.8629, Pearson=0.8169
Checkpoint saved at epoch 163
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch162.pt


                                                                    



Training:  13%|█▎        | 13/100 [5:13:26<34:04:09, 1409.76s/epoch]

[Epoch 164/250] Training Loss: 0.2693


                                                                    



Training:  13%|█▎        | 13/100 [5:29:04<34:04:09, 1409.76s/epoch]

                                                                    



Training:  13%|█▎        | 13/100 [5:29:04<34:04:09, 1409.76s/epoch]

                                                                    



Training:  13%|█▎        | 13/100 [5:29:04<34:04:09, 1409.76s/epoch]

                                                                    



Training:  13%|█▎        | 13/100 [5:29:04<34:04:09, 1409.76s/epoch]

Training:  14%|█▍        | 14/100 [5:29:04<33:40:22, 1409.56s/epoch]

  Train => MSE=0.1922, CI=0.8880, Pearson=0.8716
  Test  => MSE=0.2487, CI=0.8669, Pearson=0.8222
Checkpoint saved at epoch 164
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch163.pt


                                                                    



Training:  14%|█▍        | 14/100 [5:36:55<33:40:22, 1409.56s/epoch]

[Epoch 165/250] Training Loss: 0.2640


                                                                    



Training:  14%|█▍        | 14/100 [5:52:33<33:40:22, 1409.56s/epoch]

                                                                    



Training:  14%|█▍        | 14/100 [5:52:33<33:40:22, 1409.56s/epoch]

                                                                    



Training:  14%|█▍        | 14/100 [5:52:33<33:40:22, 1409.56s/epoch]

                                                                    



Training:  14%|█▍        | 14/100 [5:52:33<33:40:22, 1409.56s/epoch]

Training:  15%|█▌        | 15/100 [5:52:33<33:16:32, 1409.32s/epoch]

  Train => MSE=0.1953, CI=0.8776, Pearson=0.8605
  Test  => MSE=0.2483, CI=0.8568, Pearson=0.8154
Checkpoint saved at epoch 165
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch164.pt


                                                                    



Training:  15%|█▌        | 15/100 [6:00:25<33:16:32, 1409.32s/epoch]

[Epoch 166/250] Training Loss: 0.2739


                                                                    



Training:  15%|█▌        | 15/100 [6:16:03<33:16:32, 1409.32s/epoch]

                                                                    



Training:  15%|█▌        | 15/100 [6:16:03<33:16:32, 1409.32s/epoch]

                                                                    



Training:  15%|█▌        | 15/100 [6:16:03<33:16:32, 1409.32s/epoch]

                                                                    



Training:  15%|█▌        | 15/100 [6:16:03<33:16:32, 1409.32s/epoch]

Training:  16%|█▌        | 16/100 [6:16:03<32:53:33, 1409.68s/epoch]

  Train => MSE=0.2096, CI=0.8830, Pearson=0.8537
  Test  => MSE=0.2682, CI=0.8569, Pearson=0.8015
Checkpoint saved at epoch 166
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch165.pt


                                                                    



Training:  16%|█▌        | 16/100 [6:23:56<32:53:33, 1409.68s/epoch]

[Epoch 167/250] Training Loss: 0.2609


                                                                    



Training:  16%|█▌        | 16/100 [6:39:34<32:53:33, 1409.68s/epoch]

                                                                    



Training:  16%|█▌        | 16/100 [6:39:34<32:53:33, 1409.68s/epoch]

                                                                    



Training:  16%|█▌        | 16/100 [6:39:34<32:53:33, 1409.68s/epoch]

                                                                    



Training:  16%|█▌        | 16/100 [6:39:34<32:53:33, 1409.68s/epoch]

Training:  17%|█▋        | 17/100 [6:39:34<32:30:35, 1410.06s/epoch]

  Train => MSE=0.1912, CI=0.8875, Pearson=0.8693
  Test  => MSE=0.2494, CI=0.8590, Pearson=0.8174
Checkpoint saved at epoch 167
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch166.pt


                                                                    



Training:  17%|█▋        | 17/100 [6:47:26<32:30:35, 1410.06s/epoch]

[Epoch 168/250] Training Loss: 0.2593


                                                                    



Training:  17%|█▋        | 17/100 [7:03:03<32:30:35, 1410.06s/epoch]

                                                                    



Training:  17%|█▋        | 17/100 [7:03:03<32:30:35, 1410.06s/epoch]

                                                                    



Training:  17%|█▋        | 17/100 [7:03:03<32:30:35, 1410.06s/epoch]

                                                                    



Training:  17%|█▋        | 17/100 [7:03:03<32:30:35, 1410.06s/epoch]

Training:  18%|█▊        | 18/100 [7:03:03<32:06:28, 1409.61s/epoch]

  Train => MSE=0.1980, CI=0.8845, Pearson=0.8597
  Test  => MSE=0.2502, CI=0.8613, Pearson=0.8140
Checkpoint saved at epoch 168
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch167.pt


                                                                    



Training:  18%|█▊        | 18/100 [7:10:56<32:06:28, 1409.61s/epoch]

[Epoch 169/250] Training Loss: 0.2624


                                                                    



Training:  18%|█▊        | 18/100 [7:26:33<32:06:28, 1409.61s/epoch]

                                                                    



Training:  18%|█▊        | 18/100 [7:26:33<32:06:28, 1409.61s/epoch]

                                                                    



Training:  18%|█▊        | 18/100 [7:26:33<32:06:28, 1409.61s/epoch]

                                                                    



Training:  18%|█▊        | 18/100 [7:26:33<32:06:28, 1409.61s/epoch]

Training:  19%|█▉        | 19/100 [7:26:33<31:43:13, 1409.80s/epoch]

  Train => MSE=0.2013, CI=0.8851, Pearson=0.8681
  Test  => MSE=0.2569, CI=0.8595, Pearson=0.8172
Checkpoint saved at epoch 169
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch168.pt


                                                                    



Training:  19%|█▉        | 19/100 [7:34:25<31:43:13, 1409.80s/epoch]

[Epoch 170/250] Training Loss: 0.2582


                                                                    



Training:  19%|█▉        | 19/100 [7:50:03<31:43:13, 1409.80s/epoch]

                                                                    



Training:  19%|█▉        | 19/100 [7:50:03<31:43:13, 1409.80s/epoch]

                                                                    



Training:  19%|█▉        | 19/100 [7:50:03<31:43:13, 1409.80s/epoch]

                                                                    



Training:  19%|█▉        | 19/100 [7:50:03<31:43:13, 1409.80s/epoch]

Training:  20%|██        | 20/100 [7:50:03<31:19:40, 1409.76s/epoch]

  Train => MSE=0.1929, CI=0.8843, Pearson=0.8660
  Test  => MSE=0.2476, CI=0.8604, Pearson=0.8176
Checkpoint saved at epoch 170
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch169.pt


                                                                    



Training:  20%|██        | 20/100 [7:57:53<31:19:40, 1409.76s/epoch]

[Epoch 171/250] Training Loss: 0.2559


                                                                    



Training:  20%|██        | 20/100 [8:13:31<31:19:40, 1409.76s/epoch]

                                                                    



Training:  20%|██        | 20/100 [8:13:31<31:19:40, 1409.76s/epoch]

                                                                    



Training:  20%|██        | 20/100 [8:13:31<31:19:40, 1409.76s/epoch]

                                                                    



Training:  20%|██        | 20/100 [8:13:31<31:19:40, 1409.76s/epoch]

Training:  21%|██        | 21/100 [8:13:31<30:55:36, 1409.32s/epoch]

  Train => MSE=0.1887, CI=0.8819, Pearson=0.8661
  Test  => MSE=0.2524, CI=0.8559, Pearson=0.8118
Checkpoint saved at epoch 171
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch170.pt


                                                                    



Training:  21%|██        | 21/100 [8:21:24<30:55:36, 1409.32s/epoch]

[Epoch 172/250] Training Loss: 0.2537


                                                                    



Training:  21%|██        | 21/100 [8:37:01<30:55:36, 1409.32s/epoch]

                                                                    



Training:  21%|██        | 21/100 [8:37:01<30:55:36, 1409.32s/epoch]

                                                                    



Training:  21%|██        | 21/100 [8:37:01<30:55:36, 1409.32s/epoch]

                                                                    



Training:  21%|██        | 21/100 [8:37:01<30:55:36, 1409.32s/epoch]

Training:  22%|██▏       | 22/100 [8:37:01<30:32:23, 1409.53s/epoch]

  Train => MSE=0.1875, CI=0.8854, Pearson=0.8673
  Test  => MSE=0.2553, CI=0.8567, Pearson=0.8097
Checkpoint saved at epoch 172
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch171.pt


                                                                    



Training:  22%|██▏       | 22/100 [8:44:53<30:32:23, 1409.53s/epoch]

[Epoch 173/250] Training Loss: 0.2541


                                                                    



Training:  22%|██▏       | 22/100 [9:00:31<30:32:23, 1409.53s/epoch]

                                                                    



Training:  22%|██▏       | 22/100 [9:00:31<30:32:23, 1409.53s/epoch]

                                                                    



Training:  22%|██▏       | 22/100 [9:00:31<30:32:23, 1409.53s/epoch]

                                                                    



Training:  22%|██▏       | 22/100 [9:00:31<30:32:23, 1409.53s/epoch]

Training:  23%|██▎       | 23/100 [9:00:31<30:09:02, 1409.64s/epoch]

  Train => MSE=0.2089, CI=0.8772, Pearson=0.8536
  Test  => MSE=0.2693, CI=0.8470, Pearson=0.8009
Checkpoint saved at epoch 173
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch172.pt


                                                                    



Training:  23%|██▎       | 23/100 [9:08:23<30:09:02, 1409.64s/epoch]

[Epoch 174/250] Training Loss: 0.2568


                                                                    



Training:  23%|██▎       | 23/100 [9:24:00<30:09:02, 1409.64s/epoch]

                                                                    



Training:  23%|██▎       | 23/100 [9:24:00<30:09:02, 1409.64s/epoch]

                                                                    



Training:  23%|██▎       | 23/100 [9:24:01<30:09:02, 1409.64s/epoch]

                                                                    



Training:  23%|██▎       | 23/100 [9:24:01<30:09:02, 1409.64s/epoch]

Training:  24%|██▍       | 24/100 [9:24:01<29:45:29, 1409.60s/epoch]

  Train => MSE=0.1805, CI=0.8919, Pearson=0.8786
  Test  => MSE=0.2456, CI=0.8608, Pearson=0.8223
Checkpoint saved at epoch 174
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch173.pt


                                                                    



Training:  24%|██▍       | 24/100 [9:31:52<29:45:29, 1409.60s/epoch]

[Epoch 175/250] Training Loss: 0.2463


                                                                    



Training:  24%|██▍       | 24/100 [9:47:30<29:45:29, 1409.60s/epoch]

                                                                    



Training:  24%|██▍       | 24/100 [9:47:30<29:45:29, 1409.60s/epoch]

                                                                    



Training:  24%|██▍       | 24/100 [9:47:30<29:45:29, 1409.60s/epoch]

                                                                    



Training:  24%|██▍       | 24/100 [9:47:30<29:45:29, 1409.60s/epoch]

Training:  25%|██▌       | 25/100 [9:47:30<29:21:57, 1409.56s/epoch]

  Train => MSE=0.1808, CI=0.8927, Pearson=0.8726
  Test  => MSE=0.2453, CI=0.8695, Pearson=0.8197
Checkpoint saved at epoch 175
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch174.pt


                                                                    



Training:  25%|██▌       | 25/100 [9:55:22<29:21:57, 1409.56s/epoch]

[Epoch 176/250] Training Loss: 0.2458


                                                                    



Training:  25%|██▌       | 25/100 [10:11:00<29:21:57, 1409.56s/epoch]

                                                                     



Training:  25%|██▌       | 25/100 [10:11:00<29:21:57, 1409.56s/epoch]

                                                                     



Training:  25%|██▌       | 25/100 [10:11:00<29:21:57, 1409.56s/epoch]

                                                                     



Training:  25%|██▌       | 25/100 [10:11:00<29:21:57, 1409.56s/epoch]

Training:  26%|██▌       | 26/100 [10:11:00<28:58:36, 1409.68s/epoch]

  Train => MSE=0.1697, CI=0.8943, Pearson=0.8810
  Test  => MSE=0.2358, CI=0.8660, Pearson=0.8255
Checkpoint saved at epoch 176
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch175.pt


                                                                     



Training:  26%|██▌       | 26/100 [10:18:54<28:58:36, 1409.68s/epoch]

[Epoch 177/250] Training Loss: 0.2449


                                                                     



Training:  26%|██▌       | 26/100 [10:34:33<28:58:36, 1409.68s/epoch]

                                                                     



Training:  26%|██▌       | 26/100 [10:34:33<28:58:36, 1409.68s/epoch]

                                                                     



Training:  26%|██▌       | 26/100 [10:34:33<28:58:36, 1409.68s/epoch]

                                                                     



Training:  26%|██▌       | 26/100 [10:34:33<28:58:36, 1409.68s/epoch]

Training:  27%|██▋       | 27/100 [10:34:33<28:36:20, 1410.70s/epoch]

  Train => MSE=0.1868, CI=0.8850, Pearson=0.8695
  Test  => MSE=0.2551, CI=0.8558, Pearson=0.8109
Checkpoint saved at epoch 177
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch176.pt


                                                                     



Training:  27%|██▋       | 27/100 [10:42:28<28:36:20, 1410.70s/epoch]

[Epoch 178/250] Training Loss: 0.2481


                                                                     



Training:  27%|██▋       | 27/100 [10:58:06<28:36:20, 1410.70s/epoch]

                                                                     



Training:  27%|██▋       | 27/100 [10:58:06<28:36:20, 1410.70s/epoch]

                                                                     



Training:  27%|██▋       | 27/100 [10:58:06<28:36:20, 1410.70s/epoch]

                                                                     



Training:  27%|██▋       | 27/100 [10:58:06<28:36:20, 1410.70s/epoch]

Training:  28%|██▊       | 28/100 [10:58:06<28:13:32, 1411.28s/epoch]

  Train => MSE=0.1953, CI=0.8814, Pearson=0.8646
  Test  => MSE=0.2578, CI=0.8493, Pearson=0.8099
Checkpoint saved at epoch 178
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch177.pt


                                                                     



Training:  28%|██▊       | 28/100 [11:05:59<28:13:32, 1411.28s/epoch]

[Epoch 179/250] Training Loss: 0.2428


                                                                     



Training:  28%|██▊       | 28/100 [11:21:38<28:13:32, 1411.28s/epoch]

                                                                     



Training:  28%|██▊       | 28/100 [11:21:38<28:13:32, 1411.28s/epoch]

                                                                     



Training:  28%|██▊       | 28/100 [11:21:38<28:13:32, 1411.28s/epoch]

                                                                     



Training:  28%|██▊       | 28/100 [11:21:38<28:13:32, 1411.28s/epoch]

Training:  29%|██▉       | 29/100 [11:21:38<27:50:23, 1411.60s/epoch]

  Train => MSE=0.1763, CI=0.8860, Pearson=0.8757
  Test  => MSE=0.2488, CI=0.8634, Pearson=0.8149
Checkpoint saved at epoch 179
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch178.pt


                                                                     



Training:  29%|██▉       | 29/100 [11:29:32<27:50:23, 1411.60s/epoch]

[Epoch 180/250] Training Loss: 0.2443


                                                                     



Training:  29%|██▉       | 29/100 [11:45:10<27:50:23, 1411.60s/epoch]

                                                                     



Training:  29%|██▉       | 29/100 [11:45:10<27:50:23, 1411.60s/epoch]

                                                                     



Training:  29%|██▉       | 29/100 [11:45:10<27:50:23, 1411.60s/epoch]

                                                                     



Training:  29%|██▉       | 29/100 [11:45:10<27:50:23, 1411.60s/epoch]

Training:  30%|███       | 30/100 [11:45:10<27:26:57, 1411.68s/epoch]

  Train => MSE=0.1796, CI=0.8904, Pearson=0.8796
  Test  => MSE=0.2550, CI=0.8626, Pearson=0.8150
Checkpoint saved at epoch 180
Deleted previous checkpoint: /kaggle/working/TrainingModelNLB/model_epoch179.pt


                                                                     

In [None]:

# import os
# import torch

# def train_and_evaluate(sample_dir, num_epochs=10, test_size=0.2, lr=0.001):
#     """
#     Trains the GNN model, evaluates on train & test each epoch, saves metrics + checkpoints.
#     """
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Running on {device}.")

#     # Gather samples
#     sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]
#     assert len(sample_files) > 0, "No .pt files found in sample_dir!"

#     # Split
#     train_files, test_files = train_test_split(sample_files, test_size=test_size, random_state=42)

#     # Make checkpoint dir
#     training_model_dir = os.path.join(os.getcwd(), 'TrainingModelMul')
#     os.makedirs(training_model_dir, exist_ok=True)
#     print(f"Checkpoints will be saved to: {training_model_dir}")

#     metrics_path = os.path.join(training_model_dir, "training_metrics.pt")

#     # Load existing metrics if available (Ensures metrics continue from previous runs)
#     if os.path.exists(metrics_path):
#         saved_metrics = torch.load(metrics_path)
#         train_metrics = saved_metrics['train_metrics']
#         test_metrics = saved_metrics['test_metrics']
#         print("Loaded previous training metrics!")
#     else:
#         train_metrics = {'epoch': [], 'mse': [], 'ci': [], 'pearson': []}
#         test_metrics = {'epoch': [], 'mse': [], 'ci': [], 'pearson': []}
#         print("Starting fresh metrics tracking.")

#     # Infer input dims from one sample
#     sample0 = load_sample(os.path.join(sample_dir, train_files[0]))
#     mol_data0, pro_data0 = sample0[0], sample0[1]
#     num_features_mol = mol_data0.x.size(1)
#     num_features_pro = pro_data0.x.size(1)

#     # Initialize model
#     model = GNNNet(num_features_mol=num_features_mol,
#                    num_features_pro=num_features_pro).to(device)
#     optimizer = optim.Adam(model.parameters(), lr=lr)
#     loss_fn = MSELoss()

#     # Possibly resume from checkpoint
#     start_epoch = 1
#     existing_checkpoints = [
#         f for f in os.listdir(training_model_dir)
#         if f.endswith('.pt') and f.startswith('model_epoch')
#     ]
#     last_ckpt_path = None  # Track previous checkpoint for deletion

#     if existing_checkpoints:
#         latest_ckpt = max(existing_checkpoints, key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
#         ckpt_path = os.path.join(training_model_dir, latest_ckpt)
#         print(f"Loading checkpoint from {ckpt_path}")
#         ckpt = torch.load(ckpt_path, map_location=device)
#         model.load_state_dict(ckpt['model_state_dict'])
#         optimizer.load_state_dict(ckpt['optimizer_state_dict'])
#         loaded_epoch = ckpt['epoch']
#         start_epoch = loaded_epoch + 1
#         last_ckpt_path = ckpt_path  # Store last checkpoint path for deletion
#         print(f"Resuming from epoch {start_epoch}")
#     else:
#         print("No existing checkpoint found; starting fresh.")

#     @torch.no_grad()
#     def evaluate(files):
#         model.eval()
#         all_preds = []
#         all_targets = []
#         batch_eval_size = 200
#         for batch_samples in batch_loader(files, sample_dir, batch_eval_size):
#             mol_list, pro_list, tgt_list = [], [], []
#             for (md, pd, t) in batch_samples:
#                 mol_list.append(md)
#                 pro_list.append(pd)
#                 tgt_list.append(t)

#             mol_batch = Batch.from_data_list(mol_list).to(device)
#             pro_batch = Batch.from_data_list(pro_list).to(device)
#             t_tensor = torch.tensor(tgt_list, dtype=torch.float32, device=device)

#             out = model(mol_batch, pro_batch).view(-1)
#             all_preds.append(out)
#             all_targets.append(t_tensor)

#         all_preds = torch.cat(all_preds, dim=0)
#         all_targets = torch.cat(all_targets, dim=0)

#         mse_val = mse_torch(all_preds, all_targets)
#         ci_val = ci_vectorized(all_preds, all_targets)
#         pearson_val = pearson_torch(all_preds, all_targets)
#         return mse_val, ci_val, pearson_val

#     # Training loop
#     batch_size = 500
#     for epoch in tqdm(range(start_epoch, num_epochs + 1), desc="Training", unit="epoch"):
#         model.train()
#         running_loss = 0.0

#         for batch_samples in batch_loader(train_files, sample_dir, batch_size):
#             mol_list, pro_list, tgt_list = [], [], []
#             for (md, pd, t) in batch_samples:
#                 mol_list.append(md)
#                 pro_list.append(pd)
#                 tgt_list.append(t)

#             mol_batch = Batch.from_data_list(mol_list).to(device)
#             pro_batch = Batch.from_data_list(pro_list).to(device)
#             t_tensor = torch.tensor(tgt_list, dtype=torch.float32, device=device).view(-1)

#             optimizer.zero_grad()
#             out = model(mol_batch, pro_batch).view(-1)
#             loss = loss_fn(out, t_tensor)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item() * len(batch_samples)

#         avg_loss = running_loss / len(train_files)
#         tqdm.write(f"[Epoch {epoch}/{num_epochs}] Training Loss: {avg_loss:.4f}")

#         # Evaluate on train & test
#         train_mse, train_ci, train_pearson = evaluate(train_files)
#         test_mse, test_ci, test_pearson = evaluate(test_files)

#         train_metrics['epoch'].append(epoch)
#         train_metrics['mse'].append(train_mse)
#         train_metrics['ci'].append(train_ci)
#         train_metrics['pearson'].append(train_pearson)

#         test_metrics['epoch'].append(epoch)
#         test_metrics['mse'].append(test_mse)
#         test_metrics['ci'].append(test_ci)
#         test_metrics['pearson'].append(test_pearson)

#         tqdm.write(f"  Train => MSE={train_mse:.4f}, CI={train_ci:.4f}, Pearson={train_pearson:.4f}")
#         tqdm.write(f"  Test  => MSE={test_mse:.4f}, CI={test_ci:.4f}, Pearson={test_pearson:.4f}")

#         # Save new checkpoint
#         ckpt_name = f"model_epoch{epoch}.pt"
#         ckpt_path = os.path.join(training_model_dir, ckpt_name)
#         torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#         }, ckpt_path)
#         tqdm.write(f"Checkpoint saved at epoch {epoch}")

#         # Delete the previous checkpoint after saving the new one
#         if last_ckpt_path and os.path.exists(last_ckpt_path):
#             os.remove(last_ckpt_path)
#             tqdm.write(f"Deleted previous checkpoint: {last_ckpt_path}")

#         # Update last checkpoint path
#         last_ckpt_path = ckpt_path

#         # Save/Update the metrics after each epoch
#         torch.save({
#             'train_metrics': train_metrics,
#             'test_metrics': test_metrics
#         }, metrics_path)

#     # Final evaluations
#     final_train_mse, final_train_ci, final_train_pearson = evaluate(train_files)
#     final_test_mse, final_test_ci, final_test_pearson = evaluate(test_files)
#     print(f"\nFinal Train => MSE={final_train_mse:.4f}, CI={final_train_ci:.4f}, Pearson={final_train_pearson:.4f}")
#     print(f"Final Test  => MSE={final_test_mse:.4f}, CI={final_test_ci:.4f}, Pearson={final_test_pearson:.4f}")

#     # Save final metrics
#     torch.save({
#         'train_metrics': train_metrics,
#         'test_metrics': test_metrics
#     }, metrics_path)
#     print(f"Metrics saved to {metrics_path}")

#     return train_metrics, test_metrics


# ##############################################################################
# #               5. OPTIONAL: PLOT THE SAVED METRICS
# ##############################################################################

# def plot_metrics(checkpoint_dir='TrainingModelMul'):
#     """
#     Load training_metrics.pt from the checkpoint_dir and plot MSE, CI, Pearson over epochs.
#     """
#     metrics_path = os.path.join(checkpoint_dir, "training_metrics.pt")
#     if not os.path.exists(metrics_path):
#         print(f"No metrics file found at {metrics_path}!")
#         return

#     saved_data = torch.load(metrics_path)
#     train_metrics = saved_data['train_metrics']
#     test_metrics = saved_data['test_metrics']
#     epochs = train_metrics['epoch']

#     # Plot MSE
#     plt.figure(figsize=(8, 6))
#     plt.plot(epochs, train_metrics['mse'], 'o-', label='Train MSE')
#     plt.plot(epochs, test_metrics['mse'], 'o-', label='Test MSE')
#     plt.xlabel('Epoch')
#     plt.ylabel('MSE')
#     plt.title('Mean Squared Error over Epochs')
#     plt.legend()
#     plt.savefig(os.path.join(checkpoint_dir, "MSE_plot.png"))  
#     plt.show()

#     # Plot CI
#     plt.figure(figsize=(8, 6))
#     plt.plot(epochs, train_metrics['ci'], 'o-', label='Train CI')
#     plt.plot(epochs, test_metrics['ci'], 'o-', label='Test CI')
#     plt.xlabel('Epoch')
#     plt.ylabel('Concordance Index')
#     plt.title('CI over Epochs')
#     plt.savefig(os.path.join(checkpoint_dir, "CI_plot.png"))  
#     plt.legend()
#     plt.show()

#     # Plot Pearson
#     plt.figure(figsize=(8, 6))
#     plt.plot(epochs, train_metrics['pearson'], 'o-', label='Train Pearson')
#     plt.plot(epochs, test_metrics['pearson'], 'o-', label='Test Pearson')
#     plt.xlabel('Epoch')
#     plt.ylabel('Pearson Correlation')
#     plt.title('Pearson Correlation over Epochs')
#     plt.legend()
#     plt.savefig(os.path.join(checkpoint_dir, "PEARSON_plot.png"))  
#     plt.show()

# ##############################################################################
# #                                 MAIN
# ##############################################################################

# if __name__ == "__main__":
#     # Adjust the paths/parameters as needed
#     SAMPLE_DIR = "prepared_samples"   # Directory with your .pt samples
#     NUM_EPOCHS = 250
#     TEST_SPLIT = 0.2
#     LR = 0.001

#     # 1) Train and evaluate
#     train_metrics, test_metrics = train_and_evaluate(
#         sample_dir=SAMPLE_DIR,
#         num_epochs=NUM_EPOCHS,
#         test_size=TEST_SPLIT,
#         lr=LR
#     )

#     # 2) Plot the metrics
#     plot_metrics('TrainingModelMul')