In [1]:
# This notebook fuses the drug and protein embeddings by element-wise multiplication instead of concatenation.

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj


# GCN based model
class GNNNet(torch.nn.Module):
    """Two-branch GCN regressor for (molecule, protein) graph pairs.

    Each branch runs three GCNConv layers (feature width x1 -> x2 -> x4),
    mean-pools node features per graph with ``global_mean_pool``, and
    projects through a 1024-unit hidden layer to ``output_dim``. The two
    branch embeddings are fused by element-wise multiplication (rather than
    concatenation) and fed to a 1024 -> 512 -> ``n_output`` MLP head.

    Args:
        n_output: number of values predicted per pair.
        num_features_pro: node-feature width of the protein graphs.
        num_features_mol: node-feature width of the molecule graphs.
        output_dim: width of each branch embedding, and hence of the fused vector.
        dropout: dropout probability applied after the dense layers.
    """

    def __init__(self, n_output=1, num_features_pro=54, num_features_mol=78, output_dim=128, dropout=0.2):
        super().__init__()
        print('GNNNet Loaded')
        self.n_output = n_output

        # NOTE: attribute names and creation order are kept stable -- they
        # determine the state_dict keys and the parameter order that saved
        # model/optimizer checkpoints rely on.
        f_mol = num_features_mol
        self.mol_conv1 = GCNConv(f_mol, f_mol)
        self.mol_conv2 = GCNConv(f_mol, 2 * f_mol)
        self.mol_conv3 = GCNConv(2 * f_mol, 4 * f_mol)
        self.mol_fc_g1 = torch.nn.Linear(4 * f_mol, 1024)
        self.mol_fc_g2 = torch.nn.Linear(1024, output_dim)

        f_pro = num_features_pro
        self.pro_conv1 = GCNConv(f_pro, f_pro)
        self.pro_conv2 = GCNConv(f_pro, 2 * f_pro)
        self.pro_conv3 = GCNConv(2 * f_pro, 4 * f_pro)
        self.pro_fc_g1 = torch.nn.Linear(4 * f_pro, 1024)
        self.pro_fc_g2 = torch.nn.Linear(1024, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Prediction head on the fused (output_dim-wide) vector.
        self.fc1 = nn.Linear(output_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, self.n_output)

    def _encode_graph(self, node_x, edge_index, graph_batch, convs, fc_hidden, fc_embed):
        """Encode one batch of graphs into a (num_graphs, output_dim) embedding."""
        h = node_x
        for conv in convs:
            h = self.relu(conv(h, edge_index))
        h = gep(h, graph_batch)  # mean over the nodes of each graph
        h = self.dropout(self.relu(fc_hidden(h)))
        return self.dropout(fc_embed(h))

    def forward(self, data_mol, data_pro):
        """Predict ``n_output`` values for each (molecule, protein) pair.

        ``data_mol`` and ``data_pro`` are batched graphs exposing ``x``,
        ``edge_index`` and ``batch``. Row i of the pooled molecule embeddings
        is multiplied with row i of the pooled protein embeddings.
        """
        drug_emb = self._encode_graph(
            data_mol.x, data_mol.edge_index, data_mol.batch,
            (self.mol_conv1, self.mol_conv2, self.mol_conv3),
            self.mol_fc_g1, self.mol_fc_g2,
        )
        prot_emb = self._encode_graph(
            data_pro.x, data_pro.edge_index, data_pro.batch,
            (self.pro_conv1, self.pro_conv2, self.pro_conv3),
            self.pro_fc_g1, self.pro_fc_g2,
        )

        fused = drug_emb * prot_emb  # element-wise product, not torch.cat
        h = self.dropout(self.relu(self.fc1(fused)))
        h = self.dropout(self.relu(self.fc2(h)))
        return self.out(h)


In [5]:
import os
import torch
import torch.optim as optim
from torch.nn import MSELoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool as gep
from scipy.stats import pearsonr
import warnings
import itertools
from sklearn.model_selection import train_test_split


# Suppress FutureWarning related to torch.load
warnings.filterwarnings('ignore', category=FutureWarning)


# Define the load_sample function
def _prepare_graph(graph, name):
    """Normalize one loaded graph into a PyG ``Data`` object usable by GCNConv.

    Steps: convert a plain dict to ``Data``; fall back to a ``features``
    attribute when ``x`` is missing; make ``x`` a float tensor; make
    ``edge_index`` a long tensor of shape ``[2, num_edges]``; set ``num_nodes``.

    Args:
        graph: a ``Data`` object or a dict of its attributes.
        name: label used in error messages ('mol_data' or 'pro_data').

    Returns:
        The normalized graph (a dict input is replaced by a new ``Data``).

    Raises:
        ValueError: if neither ``x`` nor ``features`` is present.
    """
    if isinstance(graph, dict):
        graph = Data(**graph)

    # Node features may have been saved under 'features' instead of 'x'.
    if not hasattr(graph, 'x') or graph.x is None:
        if hasattr(graph, 'features'):
            graph.x = graph.features
            del graph.features
        else:
            raise ValueError(f"{name} does not have 'x' or 'features' attribute")

    # GCNConv needs float node features.
    if not isinstance(graph.x, torch.Tensor):
        graph.x = torch.tensor(graph.x)
    if graph.x.dtype != torch.float:
        graph.x = graph.x.float()

    edge_index = graph.edge_index
    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.long()
    else:
        edge_index = torch.tensor(edge_index, dtype=torch.long)

    if edge_index.numel() == 0:
        # Edgeless graph (e.g. a single-atom molecule): an empty list becomes a
        # 1-D tensor of shape [0], which .t() cannot fix; PyG expects [2, 0].
        edge_index = edge_index.reshape(2, 0)
    elif edge_index.shape[0] != 2:
        # Stored as [num_edges, 2]; transpose to the [2, num_edges] layout.
        # NOTE(review): a [2, 2] edge list is ambiguous and is assumed to be
        # in [2, num_edges] layout already.
        edge_index = edge_index.t()
    graph.edge_index = edge_index

    # Set 'num_nodes' explicitly to suppress PyG's inference warnings.
    graph.num_nodes = graph.x.size(0)
    return graph


def load_sample(path):
    """Load one ``(mol_data, pro_data, target)`` sample saved with ``torch.save``.

    The file must hold an indexable object whose items 0 and 1 are the
    molecule and protein graphs (``Data`` objects or attribute dicts) and
    whose item 2 is the regression target (returned unchanged).

    Raises:
        ValueError: if a graph has neither an ``x`` nor a ``features`` attribute.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects; only use it
    # on trusted, locally prepared sample files.
    sample = torch.load(path)
    mol_data = _prepare_graph(sample[0], 'mol_data')
    pro_data = _prepare_graph(sample[1], 'pro_data')
    target = sample[2]
    return (mol_data, pro_data, target)

# Define the batch_loader function
def batch_loader(file_list, sample_dir, batch_size):
    """Lazily load samples and yield them in lists of ``batch_size``.

    Files are read one at a time with ``load_sample`` (joined onto
    ``sample_dir``), preserving the order of ``file_list``. The final list may
    be shorter than ``batch_size``; nothing is yielded for an empty list.
    """
    pending = []
    for name in file_list:
        pending.append(load_sample(os.path.join(sample_dir, name)))
        if len(pending) == batch_size:
            yield pending
            pending = []
    if pending:
        yield pending

# Define the evaluation metrics functions
def get_mse(y_true, y_pred):
    """Mean squared error between targets and predictions.

    Both inputs are flattened first, so labels of shape (N,) and model
    outputs of shape (N, 1) are compared element-wise. Without this, NumPy
    broadcasting turns ``(N,) - (N, 1)`` into an (N, N) matrix and the result
    is the mean over every label/prediction pair rather than the MSE.

    Raises:
        ValueError: if the inputs do not contain the same number of values.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true has {y_true.size} values but y_pred has {y_pred.size}")
    return np.mean((y_true - y_pred) ** 2)

def get_ci(y_true, y_pred):
    """
    Compute the concordance index between true and predicted values.

    Over all pairs (i, j) with different true values, a pair scores 1 when the
    predictions are ordered the same way as the targets, 0.5 when the
    predictions are tied, and 0 otherwise; the CI is the mean score. Returns 0
    when there is no pair with distinct true values.

    Inputs are flattened, so predictions of shape (N, 1) are accepted. Each
    row is compared against all later rows with vectorized NumPy operations,
    which avoids a pure-Python loop over all N*(N-1)/2 pairs.

    Raises:
        ValueError: if the inputs do not contain the same number of values.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true has {y_true.size} values but y_pred has {y_pred.size}")

    concordant = 0.0
    comparable = 0
    for i in range(len(y_true) - 1):
        t_i, p_i = y_true[i], y_pred[i]
        t_rest, p_rest = y_true[i + 1:], y_pred[i + 1:]
        is_comparable = t_i != t_rest
        is_concordant = ((t_i < t_rest) & (p_i < p_rest)) | ((t_i > t_rest) & (p_i > p_rest))
        # Strict orderings above already exclude ties, so a tie can only add 0.5.
        is_pred_tie = is_comparable & (p_i == p_rest)
        comparable += int(is_comparable.sum())
        concordant += int(is_concordant.sum()) + 0.5 * int(is_pred_tie.sum())
    return concordant / comparable if comparable != 0 else 0

def get_pearson(y_true, y_pred):
    """Pearson correlation coefficient of the flattened inputs (NumPy arrays)."""
    r, _p_value = pearsonr(y_true.flatten(), y_pred.flatten())
    return r



# FutureWarning is already filtered once near the top of this cell; repeating
# the identical warnings filter here had no effect.


In [None]:
def _collate_samples(batch_samples, device):
    """Turn a list of ``(mol_data, pro_data, target)`` samples into model inputs.

    Returns ``(mol_batch, pro_batch, target)``: two PyG ``Batch`` objects built
    with ``Batch.from_data_list`` and a 1-D float32 target tensor, all moved
    to ``device``.
    """
    mol_batch = Batch.from_data_list([s[0] for s in batch_samples]).to(device)
    pro_batch = Batch.from_data_list([s[1] for s in batch_samples]).to(device)
    target = torch.tensor([s[2] for s in batch_samples], dtype=torch.float32).view(-1).to(device)
    return mol_batch, pro_batch, target


# Calculate metrics on training and testing data at the end of training.
def train_and_evaluate(sample_dir, num_epochs=100, test_size=0.2, lr=0.001,
                       batch_size=200, shuffle=True):
    """Train GNNNet on the ``.pt`` samples in ``sample_dir``, then report metrics.

    Training resumes after the highest-numbered ``model_epoch<N>.pt``
    checkpoint in ``sample_dir/TrainingModel``; one checkpoint (model and
    optimizer state) is written per epoch. After training, MSE, CI and
    Pearson r are printed for the train and test splits.

    Args:
        sample_dir: directory holding one ``*.pt`` sample per file.
        num_epochs: last epoch number to train (inclusive).
        test_size: fraction of samples held out for testing (split fixed by
            ``random_state=42``).
        lr: Adam learning rate.
        batch_size: samples per mini-batch, for training and evaluation. This
            is a parameter (not set inside the epoch loop) so evaluation still
            works when a resumed run has no epochs left to train.
        shuffle: reshuffle the training files every epoch. The permutation is
            seeded with the epoch number, so a resumed run uses the same order
            an uninterrupted run would have.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}.")

    sample_files = [f for f in os.listdir(sample_dir) if f.endswith('.pt')]

    # Split data into training and test sets
    train_files, test_files = train_test_split(sample_files, test_size=test_size, random_state=42)

    # Create a directory for the model checkpoint
    training_model_dir = os.path.join(sample_dir, 'TrainingModel')
    if not os.path.exists(training_model_dir):
        os.makedirs(training_model_dir)
        print(f"Created directory for checkpoints at {training_model_dir}")
    else:
        print(f"Using existing TrainingModel directory at {training_model_dir}")

    # Infer input feature widths from the first training sample.
    probe_mol, probe_pro, _ = load_sample(os.path.join(sample_dir, train_files[0]))
    num_features_mol = probe_mol.x.size(1)
    num_features_pro = probe_pro.x.size(1)

    model = GNNNet(
        num_features_mol=num_features_mol,
        num_features_pro=num_features_pro
    ).to(device)
    print(f"Model is on device: {next(model.parameters()).device}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = MSELoss()

    start_epoch = 1

    # Resume from the latest checkpoint, if any. Temporary '.pt.tmp' files
    # do not end with '.pt' and are therefore ignored here.
    existing_checkpoints = [f for f in os.listdir(training_model_dir)
                            if f.endswith('.pt') and f.startswith('model_epoch')]

    if existing_checkpoints:
        latest_checkpoint = max(existing_checkpoints,
                                key=lambda x: int(x.split('_epoch')[1].split('.pt')[0]))
        checkpoint_path = os.path.join(training_model_dir, latest_checkpoint)
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
    else:
        print("No checkpoint found, starting training from scratch.")

    for epoch in tqdm(range(start_epoch, num_epochs + 1),
                      desc="Training", unit="epoch"):
        model.train()
        running_loss = 0.0

        if shuffle:
            order = np.random.default_rng(epoch).permutation(len(train_files))
            epoch_files = [train_files[k] for k in order]
        else:
            epoch_files = train_files

        for batch_samples in batch_loader(epoch_files, sample_dir, batch_size=batch_size):
            mol_batch, pro_batch, target = _collate_samples(batch_samples, device)

            optimizer.zero_grad()
            output = model(mol_batch, pro_batch)
            loss = loss_fn(output.view(-1), target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * len(batch_samples)

        avg_loss = running_loss / len(train_files)
        tqdm.write(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Save the model and optimizer states after each epoch. Write to a
        # temp file and rename it, so an interrupted save cannot leave a
        # truncated model_epoch<N>.pt for the resume logic to pick up.
        checkpoint_filename = f"model_epoch{epoch}.pt"
        checkpoint_path = os.path.join(training_model_dir, checkpoint_filename)
        tmp_path = checkpoint_path + '.tmp'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        tqdm.write(f"Checkpoint saved at epoch {epoch}")

    print("\nEvaluating model after training...")
    model.eval()

    def evaluate(files):
        """Return (MSE, CI, Pearson r) of the current model on ``files``."""
        total_preds, total_labels = [], []
        with torch.no_grad():
            for batch_samples in batch_loader(files, sample_dir, batch_size=batch_size):
                mol_batch, pro_batch, target = _collate_samples(batch_samples, device)
                output = model(mol_batch, pro_batch)
                # Flatten (B, n_output) -> 1-D to line up with the 1-D labels;
                # an (N, 1) vs (N,) pair would broadcast to (N, N) in the metrics.
                total_preds.append(output.view(-1).cpu().numpy())
                total_labels.append(target.cpu().numpy())

        total_preds = np.concatenate(total_preds)
        total_labels = np.concatenate(total_labels)

        mse = get_mse(total_labels, total_preds)
        ci = get_ci(total_labels, total_preds)
        pearson = get_pearson(total_labels, total_preds)

        return mse, ci, pearson

    train_mse, train_ci, train_pearson = evaluate(train_files)
    test_mse, test_ci, test_pearson = evaluate(test_files)

    print(f"Final Training Metrics: MSE: {train_mse:.4f}, CI: {train_ci:.4f}, Pearson: {train_pearson:.4f}")
    print(f"Final Test Metrics: MSE: {test_mse:.4f}, CI: {test_ci:.4f}, Pearson: {test_pearson:.4f}")

if __name__ == "__main__":
    # Run configuration -- edit these values before launching a run.
    sample_dir = 'prepared_samples'   # directory holding the prepared *.pt samples
    num_epochs = 300                  # last epoch to train (inclusive)
    test_size = 0.2                   # fraction of samples held out for the test split
    learning_rate = 0.001             # Adam learning rate

    train_and_evaluate(sample_dir=sample_dir,
                       num_epochs=num_epochs,
                       test_size=test_size,
                       lr=learning_rate)


Running on cuda.
Using existing TrainingModel directory at prepared_samples/TrainingModel
GNNNet Loaded
Model is on device: cuda:0
Loading checkpoint from prepared_samples/TrainingModel/model_epoch1.pt
Resuming training from epoch 2


Training:   0%|                          | 1/299 [04:05<20:19:58, 245.63s/epoch]

Epoch 2/300, Loss: 1.1101
Checkpoint saved at epoch 2


Training:   1%|▏                         | 2/299 [06:20<14:53:31, 180.51s/epoch]

Epoch 3/300, Loss: 0.8842
Checkpoint saved at epoch 3


Training:   1%|▎                         | 3/299 [08:31<12:58:37, 157.83s/epoch]

Epoch 4/300, Loss: 0.8255
Checkpoint saved at epoch 4


Training:   1%|▎                         | 4/299 [12:06<14:46:30, 180.31s/epoch]

Epoch 5/300, Loss: 0.8007
Checkpoint saved at epoch 5


Training:   2%|▍                         | 5/299 [14:39<13:55:03, 170.42s/epoch]

Epoch 6/300, Loss: 0.7767
Checkpoint saved at epoch 6


Training:   2%|▌                         | 6/299 [17:04<13:11:20, 162.05s/epoch]

Epoch 7/300, Loss: 0.7626
Checkpoint saved at epoch 7


Training:   2%|▌                         | 7/299 [18:59<11:52:30, 146.41s/epoch]

Epoch 8/300, Loss: 0.7402
Checkpoint saved at epoch 8


Training:   3%|▋                         | 8/299 [20:43<10:45:34, 133.11s/epoch]

Epoch 9/300, Loss: 0.7296
Checkpoint saved at epoch 9


Training:   3%|▊                         | 9/299 [22:30<10:03:24, 124.84s/epoch]

Epoch 10/300, Loss: 0.7236
Checkpoint saved at epoch 10


Training:   3%|▊                         | 10/299 [24:15<9:32:16, 118.81s/epoch]

Epoch 11/300, Loss: 0.6977
Checkpoint saved at epoch 11


Training:   4%|▉                         | 11/299 [26:06<9:18:24, 116.33s/epoch]

Epoch 12/300, Loss: 0.6882
Checkpoint saved at epoch 12


Training:   4%|█                        | 12/299 [28:32<10:00:28, 125.53s/epoch]

Epoch 13/300, Loss: 0.6697
Checkpoint saved at epoch 13


Training:   4%|█                        | 13/299 [31:23<11:02:57, 139.08s/epoch]

Epoch 14/300, Loss: 0.6518
Checkpoint saved at epoch 14


Training:   5%|█▏                       | 14/299 [34:11<11:42:43, 147.94s/epoch]

Epoch 15/300, Loss: 0.6387
Checkpoint saved at epoch 15


Training:   5%|█▎                       | 15/299 [37:07<12:20:05, 156.36s/epoch]

Epoch 16/300, Loss: 0.6215
Checkpoint saved at epoch 16


Training:   5%|█▎                       | 16/299 [39:00<11:15:15, 143.16s/epoch]

Epoch 17/300, Loss: 0.6064
Checkpoint saved at epoch 17


Training:   6%|█▍                       | 17/299 [40:46<10:21:04, 132.14s/epoch]

Epoch 18/300, Loss: 0.6013
Checkpoint saved at epoch 18


Training:   6%|█▌                        | 18/299 [42:31<9:39:57, 123.83s/epoch]

Epoch 19/300, Loss: 0.5855
Checkpoint saved at epoch 19


Training:   6%|█▋                        | 19/299 [44:16<9:12:40, 118.43s/epoch]

Epoch 20/300, Loss: 0.5683
Checkpoint saved at epoch 20


Training:   7%|█▋                        | 20/299 [45:59<8:49:00, 113.77s/epoch]

Epoch 21/300, Loss: 0.5583
Checkpoint saved at epoch 21


Training:   7%|█▊                        | 21/299 [47:41<8:30:48, 110.25s/epoch]

Epoch 22/300, Loss: 0.5596
Checkpoint saved at epoch 22


Training:   7%|█▉                        | 22/299 [49:27<8:22:03, 108.75s/epoch]

Epoch 23/300, Loss: 0.5375
Checkpoint saved at epoch 23


Training:   8%|██                        | 23/299 [51:13<8:16:27, 107.92s/epoch]

Epoch 24/300, Loss: 0.5332
Checkpoint saved at epoch 24


Training:   8%|██                        | 24/299 [52:58<8:11:37, 107.27s/epoch]

Epoch 25/300, Loss: 0.5190
Checkpoint saved at epoch 25


Training:   8%|██▏                       | 25/299 [54:45<8:08:30, 106.97s/epoch]

Epoch 26/300, Loss: 0.5193
Checkpoint saved at epoch 26


Training:   9%|██▎                       | 26/299 [56:32<8:07:39, 107.18s/epoch]

Epoch 27/300, Loss: 0.5018
Checkpoint saved at epoch 27


Training:   9%|██▎                       | 27/299 [58:18<8:03:37, 106.68s/epoch]

Epoch 28/300, Loss: 0.4891
Checkpoint saved at epoch 28


Training:   9%|██▏                     | 28/299 [1:00:02<7:59:06, 106.08s/epoch]

Epoch 29/300, Loss: 0.4877
Checkpoint saved at epoch 29


Training:  10%|██▎                     | 29/299 [1:01:48<7:56:50, 105.96s/epoch]

Epoch 30/300, Loss: 0.4802
Checkpoint saved at epoch 30


Training:  10%|██▍                     | 30/299 [1:03:34<7:55:14, 106.00s/epoch]

Epoch 31/300, Loss: 0.4723
Checkpoint saved at epoch 31


Training:  10%|██▍                     | 31/299 [1:05:19<7:51:24, 105.54s/epoch]

Epoch 32/300, Loss: 0.4636
Checkpoint saved at epoch 32


Training:  11%|██▌                     | 32/299 [1:07:03<7:48:04, 105.19s/epoch]

Epoch 33/300, Loss: 0.4579
Checkpoint saved at epoch 33


Training:  11%|██▋                     | 33/299 [1:08:49<7:47:36, 105.48s/epoch]

Epoch 34/300, Loss: 0.4496
Checkpoint saved at epoch 34


Training:  11%|██▋                     | 34/299 [1:10:34<7:44:51, 105.25s/epoch]

Epoch 35/300, Loss: 0.4485
Checkpoint saved at epoch 35


Training:  12%|██▊                     | 35/299 [1:12:19<7:43:01, 105.23s/epoch]

Epoch 36/300, Loss: 0.4295
Checkpoint saved at epoch 36


Training:  12%|██▉                     | 36/299 [1:14:04<7:40:14, 105.00s/epoch]

Epoch 37/300, Loss: 0.4308
Checkpoint saved at epoch 37


Training:  12%|██▉                     | 37/299 [1:15:48<7:37:13, 104.71s/epoch]

Epoch 38/300, Loss: 0.4243
Checkpoint saved at epoch 38


Training:  13%|███                     | 38/299 [1:17:32<7:35:10, 104.64s/epoch]

Epoch 39/300, Loss: 0.4183
Checkpoint saved at epoch 39


Training:  13%|███▏                    | 39/299 [1:19:17<7:33:12, 104.59s/epoch]

Epoch 40/300, Loss: 0.4117
Checkpoint saved at epoch 40


Training:  13%|███▏                    | 40/299 [1:21:00<7:30:19, 104.32s/epoch]

Epoch 41/300, Loss: 0.4029
Checkpoint saved at epoch 41


Training:  14%|███▎                    | 41/299 [1:22:45<7:29:09, 104.46s/epoch]

Epoch 42/300, Loss: 0.3967
Checkpoint saved at epoch 42


Training:  14%|███▎                    | 42/299 [1:24:30<7:27:43, 104.53s/epoch]

Epoch 43/300, Loss: 0.3893
Checkpoint saved at epoch 43


Training:  14%|███▍                    | 43/299 [1:26:15<7:27:21, 104.85s/epoch]

Epoch 44/300, Loss: 0.3826
Checkpoint saved at epoch 44


Training:  15%|███▌                    | 44/299 [1:27:59<7:24:20, 104.55s/epoch]

Epoch 45/300, Loss: 0.3844
Checkpoint saved at epoch 45


Training:  15%|███▌                    | 45/299 [1:29:43<7:22:15, 104.47s/epoch]

Epoch 46/300, Loss: 0.3793
Checkpoint saved at epoch 46


Training:  15%|███▋                    | 46/299 [1:31:29<7:21:18, 104.66s/epoch]

Epoch 47/300, Loss: 0.3716
Checkpoint saved at epoch 47


Training:  16%|███▊                    | 47/299 [1:33:12<7:17:30, 104.17s/epoch]

Epoch 48/300, Loss: 0.3694
Checkpoint saved at epoch 48


Training:  16%|███▊                    | 48/299 [1:34:56<7:15:58, 104.22s/epoch]

Epoch 49/300, Loss: 0.3654
Checkpoint saved at epoch 49


Training:  16%|███▉                    | 49/299 [1:36:40<7:13:35, 104.06s/epoch]

Epoch 50/300, Loss: 0.3579
Checkpoint saved at epoch 50


Training:  17%|████                    | 50/299 [1:38:26<7:14:39, 104.74s/epoch]

Epoch 51/300, Loss: 0.3486
Checkpoint saved at epoch 51


Training:  17%|████                    | 51/299 [1:40:12<7:14:51, 105.21s/epoch]

Epoch 52/300, Loss: 0.3485
Checkpoint saved at epoch 52


Training:  17%|████▏                   | 52/299 [1:41:57<7:12:55, 105.17s/epoch]

Epoch 53/300, Loss: 0.3381
Checkpoint saved at epoch 53


Training:  18%|████▎                   | 53/299 [1:43:41<7:09:57, 104.87s/epoch]

Epoch 54/300, Loss: 0.3353
Checkpoint saved at epoch 54


Training:  18%|████▎                   | 54/299 [1:45:24<7:05:13, 104.14s/epoch]

Epoch 55/300, Loss: 0.3334
Checkpoint saved at epoch 55


Training:  18%|████▍                   | 55/299 [1:47:06<7:00:37, 103.43s/epoch]

Epoch 56/300, Loss: 0.3300
Checkpoint saved at epoch 56


Training:  19%|████▍                   | 56/299 [1:48:51<7:01:35, 104.10s/epoch]

Epoch 57/300, Loss: 0.3206
Checkpoint saved at epoch 57


Training:  19%|████▌                   | 57/299 [1:50:35<6:59:48, 104.09s/epoch]

Epoch 58/300, Loss: 0.3145
Checkpoint saved at epoch 58


Training:  19%|████▋                   | 58/299 [1:52:19<6:57:43, 104.00s/epoch]

Epoch 59/300, Loss: 0.3161
Checkpoint saved at epoch 59


Training:  20%|████▋                   | 59/299 [1:54:03<6:55:42, 103.93s/epoch]

Epoch 60/300, Loss: 0.3079
Checkpoint saved at epoch 60


Training:  20%|████▊                   | 60/299 [1:55:46<6:53:15, 103.75s/epoch]

Epoch 61/300, Loss: 0.3033
Checkpoint saved at epoch 61


Training:  20%|████▉                   | 61/299 [1:57:30<6:51:16, 103.68s/epoch]

Epoch 62/300, Loss: 0.3036
Checkpoint saved at epoch 62


Training:  21%|████▉                   | 62/299 [1:59:15<6:51:36, 104.20s/epoch]

Epoch 63/300, Loss: 0.2958
Checkpoint saved at epoch 63


Training:  21%|█████                   | 63/299 [2:00:58<6:47:38, 103.64s/epoch]

Epoch 64/300, Loss: 0.2929
Checkpoint saved at epoch 64


Training:  21%|█████▏                  | 64/299 [2:02:43<6:47:39, 104.08s/epoch]

Epoch 65/300, Loss: 0.2887
Checkpoint saved at epoch 65


Training:  22%|█████▏                  | 65/299 [2:04:27<6:46:11, 104.15s/epoch]

Epoch 66/300, Loss: 0.3137
Checkpoint saved at epoch 66


Training:  22%|█████▎                  | 66/299 [2:06:10<6:43:06, 103.81s/epoch]

Epoch 67/300, Loss: 0.2858
Checkpoint saved at epoch 67


Training:  22%|█████▍                  | 67/299 [2:07:56<6:43:48, 104.43s/epoch]

Epoch 68/300, Loss: 0.2799
Checkpoint saved at epoch 68


Training:  23%|█████▍                  | 68/299 [2:09:39<6:41:08, 104.19s/epoch]

Epoch 69/300, Loss: 0.2750
Checkpoint saved at epoch 69


Training:  23%|█████▌                  | 69/299 [2:11:23<6:38:57, 104.08s/epoch]

Epoch 70/300, Loss: 0.2767
Checkpoint saved at epoch 70


Training:  23%|█████▌                  | 70/299 [2:13:09<6:39:12, 104.60s/epoch]

Epoch 71/300, Loss: 0.2733
Checkpoint saved at epoch 71


Training:  24%|█████▋                  | 71/299 [2:14:56<6:39:52, 105.23s/epoch]

Epoch 72/300, Loss: 0.2692
Checkpoint saved at epoch 72


Training:  24%|█████▊                  | 72/299 [2:16:40<6:37:22, 105.03s/epoch]

Epoch 73/300, Loss: 0.2690
Checkpoint saved at epoch 73


Training:  24%|█████▊                  | 73/299 [2:18:26<6:35:56, 105.12s/epoch]

Epoch 74/300, Loss: 0.2615
Checkpoint saved at epoch 74


Training:  25%|█████▉                  | 74/299 [2:20:10<6:33:33, 104.95s/epoch]

Epoch 75/300, Loss: 0.2596
Checkpoint saved at epoch 75


Training:  25%|██████                  | 75/299 [2:21:55<6:31:22, 104.83s/epoch]

Epoch 76/300, Loss: 0.2588
Checkpoint saved at epoch 76


Training:  25%|██████                  | 76/299 [2:23:39<6:28:43, 104.59s/epoch]

Epoch 77/300, Loss: 0.2534
Checkpoint saved at epoch 77


Training:  26%|██████▏                 | 77/299 [2:25:24<6:27:13, 104.65s/epoch]

Epoch 78/300, Loss: 0.2502
Checkpoint saved at epoch 78


Training:  26%|██████▎                 | 78/299 [2:27:05<6:21:57, 103.70s/epoch]

Epoch 79/300, Loss: 0.2547
Checkpoint saved at epoch 79


Training:  26%|██████▎                 | 79/299 [2:28:49<6:20:26, 103.76s/epoch]

Epoch 80/300, Loss: 0.2472
Checkpoint saved at epoch 80


Training:  27%|██████▍                 | 80/299 [2:30:31<6:16:33, 103.17s/epoch]

Epoch 81/300, Loss: 0.2453
Checkpoint saved at epoch 81


Training:  27%|██████▌                 | 81/299 [2:32:18<6:19:05, 104.34s/epoch]

Epoch 82/300, Loss: 0.2428
Checkpoint saved at epoch 82


Training:  27%|██████▌                 | 82/299 [2:34:03<6:18:22, 104.62s/epoch]

Epoch 83/300, Loss: 0.2444
Checkpoint saved at epoch 83


Training:  28%|██████▋                 | 83/299 [2:35:47<6:15:26, 104.29s/epoch]

Epoch 84/300, Loss: 0.2393
Checkpoint saved at epoch 84


Training:  28%|██████▋                 | 84/299 [2:37:32<6:14:35, 104.54s/epoch]

Epoch 85/300, Loss: 0.2318
Checkpoint saved at epoch 85


Training:  28%|██████▊                 | 85/299 [2:39:16<6:12:42, 104.50s/epoch]

Epoch 86/300, Loss: 0.2314
Checkpoint saved at epoch 86


Training:  29%|██████▉                 | 86/299 [2:41:00<6:10:26, 104.35s/epoch]

Epoch 87/300, Loss: 0.2293
Checkpoint saved at epoch 87


Training:  29%|██████▉                 | 87/299 [2:42:44<6:08:39, 104.34s/epoch]

Epoch 88/300, Loss: 0.2312
Checkpoint saved at epoch 88


Training:  29%|███████                 | 88/299 [2:44:29<6:07:09, 104.41s/epoch]

Epoch 89/300, Loss: 0.2279
Checkpoint saved at epoch 89


Training:  30%|███████▏                | 89/299 [2:46:16<6:07:57, 105.13s/epoch]

Epoch 90/300, Loss: 0.2248
Checkpoint saved at epoch 90


Training:  30%|███████▏                | 90/299 [2:48:01<6:06:41, 105.27s/epoch]

Epoch 91/300, Loss: 0.2332
Checkpoint saved at epoch 91


Training:  30%|███████▎                | 91/299 [2:49:46<6:04:26, 105.13s/epoch]

Epoch 92/300, Loss: 0.2227
Checkpoint saved at epoch 92


Training:  31%|███████▍                | 92/299 [2:52:21<6:54:17, 120.08s/epoch]

Epoch 93/300, Loss: 0.2204
Checkpoint saved at epoch 93


Training:  31%|███████▍                | 93/299 [2:55:20<7:52:27, 137.61s/epoch]

Epoch 94/300, Loss: 0.2167
Checkpoint saved at epoch 94


In [None]:
# Scratch check: row-wise dot product of two embedding batches, an alternative
# to the element-wise product fusion in GNNNet.forward. batch1/batch2 are not
# defined in any earlier cell, so deterministic stand-ins are created here
# (200 rows = training batch_size, 128 columns = GNNNet's default output_dim).
_demo_gen = torch.Generator().manual_seed(0)
batch1 = torch.randn(200, 128, generator=_demo_gen)
batch2 = torch.randn(200, 128, generator=_demo_gen)
dot_products = torch.sum(batch1 * batch2, dim=1, keepdim=True)
print("Dot Products Shape with keepdim=True:", dot_products.shape)  # Output: torch.Size([200, 1])
