In [5]:
import os
import json
import argparse
import logging
import pickle
import random

import pandas as pd
import torch
from torch.utils import data
import lightning.pytorch as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger

from proteinclip import data_utils, fasta_utils, swissprot, hparams
from proteinclip import contrastive

Downloading shards: 100%|██████████| 2/2 [00:41<00:00, 20.77s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.81s/it]
Device set to use cuda:0


In [4]:
os.chdir('../')


In [45]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
os.chdir(os.path.join(os.getcwd(), '../'))

In [4]:
struct = pd.read_pickle('../data/structural/structural_embeddings_0_2000.pkl')
del struct

In [10]:
args = {
    "sequence_embeddings": "../data/esm2/sequence_embedding_t33_0_2000.pkl",
    'structural_embed':'../data/structural/structural_embeddings_0_2000.pkl',
    "functional_embeddings_file": "protclip_embed_dataset.parquet",
    'unitnorm': True,
    'batch_size': 256,
    'dim': 128,
    'nhidden': 1,
    'learning_rate':1e-4,
    'out':'training',
    'name':'testrun',
    'max_epochs': 1000,
    
}

In [11]:
pl.seed_everything(seed=6489)


[rank: 0] Seed set to 6489


6489

In [12]:
structural_embeddings = None
with open(args['structural_embed'], "rb") as f:
    structural_embeddings = pickle.load(f)
structural_embeddings = dict(zip(
    structural_embeddings["protein_id"],
    structural_embeddings["embedding"]
))
list(structural_embeddings.items())[0]

('Q15375',
 array([ -5.27637  ,   9.696514 ,  11.407264 , ...,  -5.5494313,
        -11.50502  , -10.721539 ], dtype=float32))

In [13]:
sequence_embeddings = None
with open(args['sequence_embeddings'], "rb") as f:
    sequence_embeddings = pickle.load(f)

list(sequence_embeddings.items())[0]

('Q15375',
 array([ 0.00049137, -0.04332284, -0.01638226, ..., -0.10839974,
         0.02102051,  0.13916047], dtype=float32))

In [14]:
functional_embeddings = None
df = pd.read_parquet(args["functional_embeddings_file"])
functional_embeddings: dict[str, str] = dict(df[["id", "textual_embedding"]].values)


In [16]:
shared_keys = sorted(set(sequence_embeddings.keys()).intersection(set(structural_embeddings.keys())).intersection(functional_embeddings.keys()))


In [17]:
random.seed(42)  # Set your desired seed here
random.shuffle(shared_keys)


In [18]:
type(functional_embeddings)

dict

In [19]:
from proteinclip.data_utils import CLIPDataset3D
dset = data_utils.CLIPDataset3D(
    triples=shared_keys,
    map1=functional_embeddings,
    map2=sequence_embeddings,
    map3=structural_embeddings,
    enforce_unit_norm=args['unitnorm'],
)

Checking for missing/zero embeddings: 100%|██████████| 1997/1997 [00:00<00:00, 13776.90it/s]
INFO:root:Trimmed 0 triples with missing/zero embeddings; 1997 remain.


In [20]:
split_indices = data_utils.random_split(len(dset), [0.9, 0.05, 0.05])
dset_splits = [data.Subset(dset, idx) for idx in split_indices]

In [21]:
# Create data loaders
train_dl, valid_dl, _test_dl = [
    data.DataLoader(
        ds,
        batch_size=args['batch_size'],
        shuffle=(i == 0),
        # drop_last=(i == 0),
        num_workers=8,
        pin_memory=True,
    )
    for i, ds in enumerate(dset_splits)
]




In [22]:
# Define network
input_dim_1 = next(iter(train_dl))["x_1"].shape[-1]
input_dim_2 = next(iter(train_dl))["x_2"].shape[-1]
input_dim_3 = next(iter(train_dl))["x_3"].shape[-1]
model_class = (
    contrastive.ContrastiveEmbedding3D
)
net = model_class(
    input_dim_1=input_dim_1,
    input_dim_2=input_dim_2,
    input_dim_3=input_dim_3,
    shared_dim=args['dim'],
    num_hidden=args['nhidden'],
    lr=args['learning_rate'],
)

In [23]:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
early_stop_callback = EarlyStopping(
    monitor='val_loss',         # metric name to monitor
    patience=50,                 # epochs with no improvement after which to stop
    mode='min',                 # 'min' if you want to minimize the metric
    verbose=True                # optional: prints a message when triggered
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    filename='best-checkpoint'
)


In [24]:
# Train
trainer = pl.Trainer(
    max_epochs=1000,
    accelerator="cuda",
    devices=1,
    enable_progress_bar=True,
    # logger=logger,
    # log_every_n_steps=10,
    deterministic=True,
    callbacks=[early_stop_callback, checkpoint_callback],

)
trainer.fit(net, train_dataloaders=train_dl, val_dataloaders=valid_dl)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/hice1/savunuri3/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float

                                                                           

/home/hice1/savunuri3/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (29) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 29/29 [00:04<00:00,  7.01it/s, v_num=2509879, train_loss=1.190, val_loss=3.220]

Metric val_loss improved. New best score: 3.223


Epoch 1: 100%|██████████| 29/29 [00:00<00:00, 43.26it/s, v_num=2509879, train_loss=0.922, val_loss=2.950]

Metric val_loss improved by 0.277 >= min_delta = 0.0. New best score: 2.946


Epoch 2: 100%|██████████| 29/29 [00:00<00:00, 43.75it/s, v_num=2509879, train_loss=0.923, val_loss=2.860]

Metric val_loss improved by 0.085 >= min_delta = 0.0. New best score: 2.861


Epoch 3: 100%|██████████| 29/29 [00:00<00:00, 44.22it/s, v_num=2509879, train_loss=0.734, val_loss=2.800]

Metric val_loss improved by 0.059 >= min_delta = 0.0. New best score: 2.802


Epoch 4: 100%|██████████| 29/29 [00:00<00:00, 43.92it/s, v_num=2509879, train_loss=0.779, val_loss=2.720]

Metric val_loss improved by 0.079 >= min_delta = 0.0. New best score: 2.723


Epoch 5: 100%|██████████| 29/29 [00:00<00:00, 44.16it/s, v_num=2509879, train_loss=0.525, val_loss=2.720]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.721


Epoch 6: 100%|██████████| 29/29 [00:00<00:00, 43.30it/s, v_num=2509879, train_loss=0.433, val_loss=2.670]

Metric val_loss improved by 0.055 >= min_delta = 0.0. New best score: 2.666


Epoch 7: 100%|██████████| 29/29 [00:00<00:00, 44.50it/s, v_num=2509879, train_loss=0.702, val_loss=2.660]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 2.660


Epoch 9: 100%|██████████| 29/29 [00:00<00:00, 45.08it/s, v_num=2509879, train_loss=0.562, val_loss=2.630]

Metric val_loss improved by 0.027 >= min_delta = 0.0. New best score: 2.633


Epoch 10: 100%|██████████| 29/29 [00:00<00:00, 44.45it/s, v_num=2509879, train_loss=0.679, val_loss=2.620]

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 2.616


Epoch 11: 100%|██████████| 29/29 [00:00<00:00, 41.72it/s, v_num=2509879, train_loss=0.894, val_loss=2.560]

Metric val_loss improved by 0.055 >= min_delta = 0.0. New best score: 2.562


Epoch 20: 100%|██████████| 29/29 [00:00<00:00, 44.65it/s, v_num=2509879, train_loss=0.401, val_loss=2.560]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 2.558


Epoch 25: 100%|██████████| 29/29 [00:00<00:00, 45.27it/s, v_num=2509879, train_loss=0.404, val_loss=2.540]

Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 2.537


Epoch 27: 100%|██████████| 29/29 [00:00<00:00, 44.48it/s, v_num=2509879, train_loss=0.285, val_loss=2.530]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 2.535


Epoch 28: 100%|██████████| 29/29 [00:00<00:00, 43.11it/s, v_num=2509879, train_loss=0.463, val_loss=2.530]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 2.530


Epoch 30: 100%|██████████| 29/29 [00:00<00:00, 45.25it/s, v_num=2509879, train_loss=0.433, val_loss=2.530]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.529


Epoch 32: 100%|██████████| 29/29 [00:00<00:00, 45.74it/s, v_num=2509879, train_loss=0.235, val_loss=2.520]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 2.516


Epoch 36: 100%|██████████| 29/29 [00:00<00:00, 42.02it/s, v_num=2509879, train_loss=0.273, val_loss=2.500]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 2.504


Epoch 39: 100%|██████████| 29/29 [00:00<00:00, 44.26it/s, v_num=2509879, train_loss=0.584, val_loss=2.480]

Metric val_loss improved by 0.024 >= min_delta = 0.0. New best score: 2.480


Epoch 48: 100%|██████████| 29/29 [00:00<00:00, 42.29it/s, v_num=2509879, train_loss=0.265, val_loss=2.470]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 2.475


Epoch 50: 100%|██████████| 29/29 [00:00<00:00, 45.35it/s, v_num=2509879, train_loss=0.195, val_loss=2.470]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.467


Epoch 57: 100%|██████████| 29/29 [00:00<00:00, 45.03it/s, v_num=2509879, train_loss=0.284, val_loss=2.460]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 2.457


Epoch 62: 100%|██████████| 29/29 [00:00<00:00, 44.57it/s, v_num=2509879, train_loss=0.224, val_loss=2.440]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 2.442


Epoch 66: 100%|██████████| 29/29 [00:00<00:00, 44.62it/s, v_num=2509879, train_loss=0.301, val_loss=2.430]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.433


Epoch 67: 100%|██████████| 29/29 [00:00<00:00, 44.54it/s, v_num=2509879, train_loss=0.317, val_loss=2.430]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 2.427


Epoch 68: 100%|██████████| 29/29 [00:00<00:00, 44.16it/s, v_num=2509879, train_loss=0.203, val_loss=2.420]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 2.422


Epoch 69: 100%|██████████| 29/29 [00:00<00:00, 43.16it/s, v_num=2509879, train_loss=0.206, val_loss=2.410]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.414


Epoch 72: 100%|██████████| 29/29 [00:00<00:00, 45.78it/s, v_num=2509879, train_loss=0.312, val_loss=2.400]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 2.397


Epoch 77: 100%|██████████| 29/29 [00:00<00:00, 44.62it/s, v_num=2509879, train_loss=0.284, val_loss=2.390]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 2.391


Epoch 78: 100%|██████████| 29/29 [00:00<00:00, 44.31it/s, v_num=2509879, train_loss=0.315, val_loss=2.390]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 2.386


Epoch 85: 100%|██████████| 29/29 [00:00<00:00, 43.64it/s, v_num=2509879, train_loss=0.206, val_loss=2.370]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 2.373


Epoch 88: 100%|██████████| 29/29 [00:00<00:00, 45.55it/s, v_num=2509879, train_loss=0.248, val_loss=2.360]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 2.362


Epoch 92: 100%|██████████| 29/29 [00:00<00:00, 45.71it/s, v_num=2509879, train_loss=0.227, val_loss=2.350]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 2.349


Epoch 96: 100%|██████████| 29/29 [00:00<00:00, 45.16it/s, v_num=2509879, train_loss=0.157, val_loss=2.340]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.340


Epoch 99: 100%|██████████| 29/29 [00:00<00:00, 44.97it/s, v_num=2509879, train_loss=0.171, val_loss=2.340]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 2.338


Epoch 105: 100%|██████████| 29/29 [00:00<00:00, 45.72it/s, v_num=2509879, train_loss=0.168, val_loss=2.330]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.330


Epoch 114: 100%|██████████| 29/29 [00:00<00:00, 46.05it/s, v_num=2509879, train_loss=0.305, val_loss=2.310]

Metric val_loss improved by 0.023 >= min_delta = 0.0. New best score: 2.307


Epoch 119: 100%|██████████| 29/29 [00:00<00:00, 45.77it/s, v_num=2509879, train_loss=0.196, val_loss=2.300]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 2.301


Epoch 125: 100%|██████████| 29/29 [00:00<00:00, 45.75it/s, v_num=2509879, train_loss=0.170, val_loss=2.300]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 2.296


Epoch 130: 100%|██████████| 29/29 [00:00<00:00, 46.20it/s, v_num=2509879, train_loss=0.214, val_loss=2.290]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.289


Epoch 131: 100%|██████████| 29/29 [00:00<00:00, 43.69it/s, v_num=2509879, train_loss=0.114, val_loss=2.280]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 2.279


Epoch 132: 100%|██████████| 29/29 [00:00<00:00, 43.39it/s, v_num=2509879, train_loss=0.144, val_loss=2.270]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 2.268


Epoch 134: 100%|██████████| 29/29 [00:00<00:00, 45.60it/s, v_num=2509879, train_loss=0.174, val_loss=2.250]

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 2.252


Epoch 143: 100%|██████████| 29/29 [00:00<00:00, 45.81it/s, v_num=2509879, train_loss=0.169, val_loss=2.240]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 2.245


Epoch 145: 100%|██████████| 29/29 [00:00<00:00, 43.77it/s, v_num=2509879, train_loss=0.0997, val_loss=2.240]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 2.243


Epoch 147: 100%|██████████| 29/29 [00:00<00:00, 43.61it/s, v_num=2509879, train_loss=0.212, val_loss=2.230] 

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 2.234


Epoch 148: 100%|██████████| 29/29 [00:00<00:00, 43.78it/s, v_num=2509879, train_loss=0.0949, val_loss=2.230]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 2.230


Epoch 150: 100%|██████████| 29/29 [00:00<00:00, 45.50it/s, v_num=2509879, train_loss=0.217, val_loss=2.230] 

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 2.225


Epoch 155: 100%|██████████| 29/29 [00:00<00:00, 45.47it/s, v_num=2509879, train_loss=0.130, val_loss=2.210] 

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 2.207


Epoch 163: 100%|██████████| 29/29 [00:00<00:00, 46.39it/s, v_num=2509879, train_loss=0.105, val_loss=2.210] 

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.206


Epoch 168: 100%|██████████| 29/29 [00:00<00:00, 43.67it/s, v_num=2509879, train_loss=0.131, val_loss=2.190] 

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 2.194


Epoch 170: 100%|██████████| 29/29 [00:00<00:00, 44.29it/s, v_num=2509879, train_loss=0.167, val_loss=2.190] 

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.186


Epoch 179: 100%|██████████| 29/29 [00:00<00:00, 45.61it/s, v_num=2509879, train_loss=0.085, val_loss=2.170] 

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 2.167


Epoch 184: 100%|██████████| 29/29 [00:00<00:00, 45.98it/s, v_num=2509879, train_loss=0.0917, val_loss=2.140]

Metric val_loss improved by 0.026 >= min_delta = 0.0. New best score: 2.140


Epoch 194: 100%|██████████| 29/29 [00:00<00:00, 44.59it/s, v_num=2509879, train_loss=0.148, val_loss=2.140] 

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 2.138


Epoch 198: 100%|██████████| 29/29 [00:00<00:00, 45.24it/s, v_num=2509879, train_loss=0.0865, val_loss=2.140]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.137


Epoch 199: 100%|██████████| 29/29 [00:00<00:00, 42.62it/s, v_num=2509879, train_loss=0.104, val_loss=2.130] 

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.130


Epoch 201: 100%|██████████| 29/29 [00:00<00:00, 45.54it/s, v_num=2509879, train_loss=0.0798, val_loss=2.130]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 2.129


Epoch 203: 100%|██████████| 29/29 [00:00<00:00, 46.12it/s, v_num=2509879, train_loss=0.179, val_loss=2.130] 

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.129


Epoch 204: 100%|██████████| 29/29 [00:00<00:00, 41.71it/s, v_num=2509879, train_loss=0.0709, val_loss=2.120]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 2.115


Epoch 211: 100%|██████████| 29/29 [00:00<00:00, 44.66it/s, v_num=2509879, train_loss=0.110, val_loss=2.100] 

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 2.098


Epoch 224: 100%|██████████| 29/29 [00:00<00:00, 45.80it/s, v_num=2509879, train_loss=0.0811, val_loss=2.090]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 2.088


Epoch 230: 100%|██████████| 29/29 [00:00<00:00, 45.57it/s, v_num=2509879, train_loss=0.0414, val_loss=2.070]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 2.073


Epoch 244: 100%|██████████| 29/29 [00:00<00:00, 46.00it/s, v_num=2509879, train_loss=0.0399, val_loss=2.070]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 2.070


Epoch 245: 100%|██████████| 29/29 [00:00<00:00, 44.13it/s, v_num=2509879, train_loss=0.0514, val_loss=2.060]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.063


Epoch 254: 100%|██████████| 29/29 [00:00<00:00, 44.78it/s, v_num=2509879, train_loss=0.0297, val_loss=2.060]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.062


Epoch 255: 100%|██████████| 29/29 [00:00<00:00, 44.35it/s, v_num=2509879, train_loss=0.0572, val_loss=2.040]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 2.042


Epoch 258: 100%|██████████| 29/29 [00:00<00:00, 44.75it/s, v_num=2509879, train_loss=0.0473, val_loss=2.040]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.042


Epoch 264: 100%|██████████| 29/29 [00:00<00:00, 44.73it/s, v_num=2509879, train_loss=0.0391, val_loss=2.040]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.040


Epoch 269: 100%|██████████| 29/29 [00:00<00:00, 45.85it/s, v_num=2509879, train_loss=0.0604, val_loss=2.030]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 2.029


Epoch 281: 100%|██████████| 29/29 [00:00<00:00, 45.92it/s, v_num=2509879, train_loss=0.074, val_loss=2.020] 

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 2.021


Epoch 286: 100%|██████████| 29/29 [00:00<00:00, 45.98it/s, v_num=2509879, train_loss=0.0514, val_loss=2.020]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 2.016


Epoch 293: 100%|██████████| 29/29 [00:00<00:00, 45.58it/s, v_num=2509879, train_loss=0.0196, val_loss=2.020]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 2.015


Epoch 302: 100%|██████████| 29/29 [00:00<00:00, 44.95it/s, v_num=2509879, train_loss=0.0231, val_loss=2.010]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 2.010


Epoch 305: 100%|██████████| 29/29 [00:00<00:00, 43.13it/s, v_num=2509879, train_loss=0.0175, val_loss=2.000]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 2.004


Epoch 307: 100%|██████████| 29/29 [00:00<00:00, 44.61it/s, v_num=2509879, train_loss=0.0226, val_loss=1.990]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 1.990


Epoch 313: 100%|██████████| 29/29 [00:00<00:00, 45.29it/s, v_num=2509879, train_loss=0.0328, val_loss=1.980]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 1.976


Epoch 315: 100%|██████████| 29/29 [00:00<00:00, 46.75it/s, v_num=2509879, train_loss=0.0112, val_loss=1.970]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 1.970


Epoch 351: 100%|██████████| 29/29 [00:00<00:00, 45.93it/s, v_num=2509879, train_loss=0.0118, val_loss=1.970] 

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 1.966


Epoch 354: 100%|██████████| 29/29 [00:00<00:00, 45.30it/s, v_num=2509879, train_loss=0.0204, val_loss=1.960] 

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.965


Epoch 403: 100%|██████████| 29/29 [00:00<00:00, 43.32it/s, v_num=2509879, train_loss=0.0106, val_loss=1.960] 

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 1.959


Epoch 411: 100%|██████████| 29/29 [00:00<00:00, 45.62it/s, v_num=2509879, train_loss=0.264, val_loss=1.960]  

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 1.955


Epoch 412: 100%|██████████| 29/29 [00:00<00:00, 42.04it/s, v_num=2509879, train_loss=0.0087, val_loss=1.940]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.942


Epoch 462: 100%|██████████| 29/29 [00:00<00:00, 38.22it/s, v_num=2509879, train_loss=0.00296, val_loss=2.010]

Monitored metric val_loss did not improve in the last 50 records. Best score: 1.942. Signaling Trainer to stop.


Epoch 462: 100%|██████████| 29/29 [00:00<00:00, 38.13it/s, v_num=2509879, train_loss=0.00296, val_loss=2.010]


In [26]:



# Export model as ONNX files
contrastive.model_to_onnx(
    net.project_1,
    "project3d_1.onnx",
    input_shape=(input_dim_1,),
)
contrastive.model_to_onnx(
    net.project_2,
    "project3d_2.onnx",
    input_shape=(input_dim_2,),
)
contrastive.model_to_onnx(
    net.project_3,
    "project3d_3.onnx",
    input_shape=(input_dim_3,),
)


[1;31m2025-04-21 22:52:53.612941261 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3066267, index: 36, mask: {37, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-21 22:52:53.616865080 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3066236, index: 5, mask: {6, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-21 22:52:53.618623127 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3066268, index: 37, mask: {38, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-21 22:52:53.620847119 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3066234, index: 3, mask: {4, }, error code: 22 e