In [4]:
import os
import json
import argparse
import logging
import pickle
import random

import pandas as pd
import torch
from torch.utils import data
import lightning.pytorch as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

from proteinclip import data_utils, fasta_utils, swissprot, hparams
from proteinclip import contrastive

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.96s/it]
Device set to use cuda:0


In [3]:
# Work from the directory one level above the notebook's starting directory, so the
# relative paths used in later cells resolve from there.
# The starting directory is recorded only on the first run; re-executing this cell
# therefore always lands in the same place instead of climbing one level per run.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [6]:
# NOTE(review): `struct` is not used by any later cell visible here; the same file is
# loaded again (via pickle.load) for training. Only unpickle files you trust —
# unpickling can execute arbitrary code.
struct = pd.read_pickle(
    '../structural_embeddings_0_2000.pkl'
)

In [43]:
# Run configuration consumed by the cells below.
args = dict(
    structural_embed='../structural_embeddings_0_2000.pkl',  # pickled structural embeddings
    proteinclip_embed='protclip_embed_dataset.parquet',      # ProteinCLIP embeddings (parquet)
    unitnorm=True,        # passed to CLIPDataset as enforce_unit_norm
    batch_size=64,        # DataLoader batch size
    dim=128,              # passed to the model as shared_dim
    nhidden=1,            # passed to the model as num_hidden
    learning_rate=1e-4,   # passed to the model as lr
    out='training',       # logger output directory (logger cell is currently disabled)
    name='testrun',       # logger run name (logger cell is currently disabled)
    max_epochs=1000,
)

In [None]:
# Global seed for Lightning's seed_everything (Python `random`, NumPy and torch RNGs).
pl.seed_everything(6489)


In [10]:
# Load the structural embeddings and index them by protein id.
# The unpickled object must support column access by "protein_id" and "embedding"
# (e.g. a DataFrame or a dict of equal-length lists) — TODO confirm the on-disk format.
# Only unpickle trusted files: unpickling can execute arbitrary code.
with open(args['structural_embed'], "rb") as f:
    structural_raw = pickle.load(f)
structural_embeddings = dict(zip(
    structural_raw["protein_id"],
    structural_raw["embedding"],
))
# dict(zip(...)) silently keeps only the last embedding of a repeated id; surface that.
n_structural_ids = len(structural_raw["protein_id"])
if len(structural_embeddings) != n_structural_ids:
    logging.warning(
        "Structural embeddings: %d duplicate protein ids collapsed (%d rows -> %d ids).",
        n_structural_ids - len(structural_embeddings),
        n_structural_ids,
        len(structural_embeddings),
    )

In [15]:
# Load the ProteinCLIP embeddings (columns "id" and "proteinclip_embed") and index
# them by protein id.
proteinclip_raw = pd.read_parquet(args['proteinclip_embed'])
proteinclip_embeddings = dict(zip(
    proteinclip_raw["id"],
    proteinclip_raw["proteinclip_embed"],
))
# dict(zip(...)) silently keeps only the last embedding of a repeated id; surface that.
if len(proteinclip_embeddings) != len(proteinclip_raw):
    logging.warning(
        "ProteinCLIP embeddings: %d duplicate ids collapsed (%d rows -> %d ids).",
        len(proteinclip_raw) - len(proteinclip_embeddings),
        len(proteinclip_raw),
        len(proteinclip_embeddings),
    )

In [16]:
# Ids present in both embedding maps. They are sorted so the seeded shuffle that follows
# starts from a deterministic order (set iteration order is not stable across runs).
shared_keys = sorted(proteinclip_embeddings.keys() & structural_embeddings.keys())


In [19]:
# Shuffle the shared ids in place with a fixed seed. Reseeding here makes the order
# independent of how much of the global `random` stream earlier cells consumed.
SHUFFLE_SEED = 42
random.seed(SHUFFLE_SEED)
random.shuffle(shared_keys)


In [25]:
# Paired dataset over the shared ids: ProteinCLIP embeddings as map1, structural
# embeddings as map2. CLIPDataset logs how many pairs it trims for missing/zero embeddings.
dset = data_utils.CLIPDataset(
    map1=proteinclip_embeddings,
    map2=structural_embeddings,
    pairs=shared_keys,
    enforce_unit_norm=args['unitnorm'],
)

Checking for missing/zero embeddings: 100%|██████████| 1997/1997 [00:00<00:00, 23063.08it/s]
INFO:root:Trimmed 0 pairs with missing/zero embeddings, 1997 remain.


In [27]:
# Index-level 90% / 5% / 5% split, unpacked later as train / valid / test.
split_fractions = [0.9, 0.05, 0.05]
split_indices = data_utils.random_split(len(dset), split_fractions)
dset_splits = [data.Subset(dset, subset_idx) for subset_idx in split_indices]

In [31]:
# Create data loaders for the train / valid / test splits (in that order).
# Only the training split is shuffled.
# Workers are kept alive between epochs (persistent_workers) instead of being re-spawned
# for every epoch of a run that can last up to max_epochs; the batches are unchanged.
num_workers = args.get('num_workers', 8)
train_dl, valid_dl, _test_dl = [
    data.DataLoader(
        ds,
        batch_size=args['batch_size'],
        shuffle=(i == 0),
        # Opt-in: drop the ragged final training batch. It can be much smaller than
        # batch_size, which gives the contrastive loss few negatives. Off by default.
        drop_last=(i == 0) and args.get('drop_last', False),
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    for i, ds in enumerate(dset_splits)
]




In [36]:
# Define network.
# Input widths are read from one training batch. Fetching the batch once (rather than
# calling next(iter(train_dl)) per key) avoids starting the loader's workers twice.
first_batch = next(iter(train_dl))
input_dim_1 = first_batch["x_1"].shape[-1]
input_dim_2 = first_batch["x_2"].shape[-1]
del first_batch

model_class = contrastive.ContrastiveEmbedding
net = model_class(
    input_dim_1=input_dim_1,
    input_dim_2=input_dim_2,
    shared_dim=args['dim'],
    num_hidden=args['nhidden'],
    lr=args['learning_rate'],
)



In [None]:
# Disabled: create a CSV logger and write the configuration files and data splits into
# its log directory. To re-enable, uncomment the code below and pass `logger=logger`
# to pl.Trainer.
# NOTE: `write_split_identifiers` is not defined or imported anywhere in this notebook;
# define or import it first. `args` is a plain dict here, so it is dumped directly —
# `vars(args)` only works on an argparse.Namespace and raises TypeError on a dict.
# logger = CSVLogger(save_dir=args['out'], name=args['name'])
# # logger.log_hyperparams(hyperparameters.as_dict())
# write_split_identifiers(
#     train_ids=[dset.pairs[i] for i in split_indices[0]],
#     valid_ids=[dset.pairs[i] for i in split_indices[1]],
#     test_ids=[dset.pairs[i] for i in split_indices[2]],
#     out_file=os.path.join(logger.log_dir, "data_splits.json"),
# )
# net.write_config_json(os.path.join(logger.log_dir, "model_config.json"))
# with open(os.path.join(logger.log_dir, "training_config.json"), "w") as sink:
#     json.dump(args, sink, indent=4)


In [60]:
# Training callbacks (EarlyStopping / ModelCheckpoint are imported in the top import cell).
# - Early stopping: stop once val_loss has not improved for `patience` validation
#   checks (one per epoch with the Trainer defaults used here).
# - Checkpointing: keep only the single best-val_loss checkpoint so its weights can be
#   restored after training.
patience = args.get('patience', 50)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=patience,
    mode='min',      # lower val_loss is better
    verbose=True,    # log improvements and the stop decision
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    filename='best-checkpoint',
)


In [61]:
# Train. Early stopping normally ends the run; args['max_epochs'] is only the upper bound
# (previously hard-coded here, duplicating the config value).
trainer = pl.Trainer(
    max_epochs=args['max_epochs'],
    accelerator="cuda",
    devices=1,
    enable_progress_bar=True,
    # Re-enable together with the disabled logger cell above:
    # logger=logger,
    # log_every_n_steps=10,
    deterministic=True,
    callbacks=[early_stop_callback, checkpoint_callback],
)
trainer.fit(net, train_dataloaders=train_dl, val_dataloaders=valid_dl)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type | Params | Mode 
----------------------------------------------
0 | project_1    | MLP  | 33.3 K | train
1 | project_2    | MLP  | 1.2 M  | train
  | other params | n/a  | 1      | n/a  
----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.865     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0: 100%|██████████| 29/29 [00:00<00:00, 61.43it/s, v_num=6, train_loss=0.215, val_loss=1.470]

Metric val_loss improved. New best score: 1.472


Epoch 2: 100%|██████████| 29/29 [00:00<00:00, 40.72it/s, v_num=6, train_loss=0.0965, val_loss=1.440]

Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 1.442


Epoch 5: 100%|██████████| 29/29 [00:00<00:00, 42.70it/s, v_num=6, train_loss=0.150, val_loss=1.440] 

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.440


Epoch 11: 100%|██████████| 29/29 [00:00<00:00, 40.79it/s, v_num=6, train_loss=0.0821, val_loss=1.420]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 1.420


Epoch 15: 100%|██████████| 29/29 [00:00<00:00, 38.90it/s, v_num=6, train_loss=0.0564, val_loss=1.410]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.407


Epoch 18: 100%|██████████| 29/29 [00:00<00:00, 40.07it/s, v_num=6, train_loss=0.0877, val_loss=1.390]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 1.393


Epoch 20: 100%|██████████| 29/29 [00:00<00:00, 38.48it/s, v_num=6, train_loss=0.103, val_loss=1.390] 

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 1.387


Epoch 21: 100%|██████████| 29/29 [00:00<00:00, 39.91it/s, v_num=6, train_loss=0.202, val_loss=1.390]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.386


Epoch 24: 100%|██████████| 29/29 [00:00<00:00, 39.75it/s, v_num=6, train_loss=0.108, val_loss=1.380] 

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 1.382


Epoch 25: 100%|██████████| 29/29 [00:00<00:00, 38.29it/s, v_num=6, train_loss=0.0832, val_loss=1.380]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 1.378


Epoch 26: 100%|██████████| 29/29 [00:00<00:00, 40.26it/s, v_num=6, train_loss=0.0838, val_loss=1.340]

Metric val_loss improved by 0.042 >= min_delta = 0.0. New best score: 1.336


Epoch 33: 100%|██████████| 29/29 [00:00<00:00, 41.07it/s, v_num=6, train_loss=0.108, val_loss=1.320] 

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 1.325


Epoch 35: 100%|██████████| 29/29 [00:00<00:00, 39.65it/s, v_num=6, train_loss=0.0877, val_loss=1.310]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 1.314


Epoch 40: 100%|██████████| 29/29 [00:00<00:00, 38.87it/s, v_num=6, train_loss=0.0791, val_loss=1.290]

Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 1.292


Epoch 45: 100%|██████████| 29/29 [00:00<00:00, 40.13it/s, v_num=6, train_loss=0.245, val_loss=1.280] 

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 1.282


Epoch 49: 100%|██████████| 29/29 [00:00<00:00, 39.88it/s, v_num=6, train_loss=0.0586, val_loss=1.270]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.269


Epoch 50: 100%|██████████| 29/29 [00:00<00:00, 40.05it/s, v_num=6, train_loss=0.0766, val_loss=1.260]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 1.260


Epoch 51: 100%|██████████| 29/29 [00:00<00:00, 40.32it/s, v_num=6, train_loss=0.081, val_loss=1.250] 

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 1.251


Epoch 53: 100%|██████████| 29/29 [00:00<00:00, 38.27it/s, v_num=6, train_loss=0.0423, val_loss=1.230]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 1.232


Epoch 58: 100%|██████████| 29/29 [00:00<00:00, 39.72it/s, v_num=6, train_loss=0.111, val_loss=1.210] 

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 1.214


Epoch 64: 100%|██████████| 29/29 [00:00<00:00, 39.39it/s, v_num=6, train_loss=0.0607, val_loss=1.190]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 1.186


Epoch 65: 100%|██████████| 29/29 [00:00<00:00, 38.15it/s, v_num=6, train_loss=0.0454, val_loss=1.180]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 1.178


Epoch 91: 100%|██████████| 29/29 [16:31<00:00,  0.03it/s, v_num=1, train_loss=0.133, val_loss=1.790] 
Epoch 91: 100%|██████████| 29/29 [16:31<00:00,  0.03it/s, v_num=1, train_loss=0.133, val_loss=1.790]
Epoch 75: 100%|██████████| 29/29 [00:01<00:00, 21.41it/s, v_num=6, train_loss=0.0982, val_loss=1.160]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 1.163


Epoch 78: 100%|██████████| 29/29 [00:00<00:00, 38.74it/s, v_num=6, train_loss=0.0539, val_loss=1.150]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.151


Epoch 80: 100%|██████████| 29/29 [00:00<00:00, 40.43it/s, v_num=6, train_loss=0.0377, val_loss=1.150]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 1.146


Epoch 82: 100%|██████████| 29/29 [00:00<00:00, 40.18it/s, v_num=6, train_loss=0.0599, val_loss=1.140]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.144


Epoch 83: 100%|██████████| 29/29 [00:00<00:00, 39.18it/s, v_num=6, train_loss=0.0726, val_loss=1.130]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 1.133


Epoch 84: 100%|██████████| 29/29 [00:00<00:00, 37.82it/s, v_num=6, train_loss=0.0566, val_loss=1.130]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 1.129


Epoch 85: 100%|██████████| 29/29 [00:00<00:00, 39.25it/s, v_num=6, train_loss=0.0421, val_loss=1.120]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.116


Epoch 90: 100%|██████████| 29/29 [00:00<00:00, 40.08it/s, v_num=6, train_loss=0.0871, val_loss=1.110]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 1.107


Epoch 92: 100%|██████████| 29/29 [00:00<00:00, 39.58it/s, v_num=6, train_loss=0.0212, val_loss=1.110]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.106


Epoch 96: 100%|██████████| 29/29 [00:00<00:00, 40.13it/s, v_num=6, train_loss=0.0241, val_loss=1.090]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.094


Epoch 97: 100%|██████████| 29/29 [00:00<00:00, 41.14it/s, v_num=6, train_loss=0.0412, val_loss=1.090]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 1.090


Epoch 98: 100%|██████████| 29/29 [00:00<00:00, 39.36it/s, v_num=6, train_loss=0.129, val_loss=1.070] 

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 1.075


Epoch 99: 100%|██████████| 29/29 [00:00<00:00, 42.43it/s, v_num=6, train_loss=0.0475, val_loss=1.060]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 1.061


Epoch 107: 100%|██████████| 29/29 [00:00<00:00, 42.44it/s, v_num=6, train_loss=0.103, val_loss=1.040] 

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 1.043


Epoch 108: 100%|██████████| 29/29 [00:00<00:00, 41.93it/s, v_num=6, train_loss=0.035, val_loss=1.040]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.042


Epoch 115: 100%|██████████| 29/29 [00:00<00:00, 46.12it/s, v_num=6, train_loss=0.0276, val_loss=1.040]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 1.042


Epoch 117: 100%|██████████| 29/29 [00:00<00:00, 40.41it/s, v_num=6, train_loss=0.0202, val_loss=1.040]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 1.035


Epoch 119: 100%|██████████| 29/29 [00:00<00:00, 40.37it/s, v_num=6, train_loss=0.0484, val_loss=1.010]

Metric val_loss improved by 0.025 >= min_delta = 0.0. New best score: 1.010


Epoch 122: 100%|██████████| 29/29 [00:00<00:00, 41.75it/s, v_num=6, train_loss=0.0514, val_loss=0.994]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.994


Epoch 129: 100%|██████████| 29/29 [00:00<00:00, 42.13it/s, v_num=6, train_loss=0.019, val_loss=0.990] 

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.990


Epoch 131: 100%|██████████| 29/29 [00:00<00:00, 40.53it/s, v_num=6, train_loss=0.0597, val_loss=0.971]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.971


Epoch 133: 100%|██████████| 29/29 [00:00<00:00, 41.64it/s, v_num=6, train_loss=0.0353, val_loss=0.965]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.965


Epoch 137: 100%|██████████| 29/29 [00:00<00:00, 40.54it/s, v_num=6, train_loss=0.0894, val_loss=0.935]

Metric val_loss improved by 0.030 >= min_delta = 0.0. New best score: 0.935


Epoch 148: 100%|██████████| 29/29 [00:00<00:00, 42.04it/s, v_num=6, train_loss=0.0449, val_loss=0.932] 

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.932


Epoch 151: 100%|██████████| 29/29 [00:00<00:00, 41.48it/s, v_num=6, train_loss=0.0102, val_loss=0.918]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.918


Epoch 157: 100%|██████████| 29/29 [00:00<00:00, 41.62it/s, v_num=6, train_loss=0.0583, val_loss=0.908]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.908


Epoch 158: 100%|██████████| 29/29 [00:00<00:00, 41.25it/s, v_num=6, train_loss=0.065, val_loss=0.900] 

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.900


Epoch 164: 100%|██████████| 29/29 [00:00<00:00, 42.42it/s, v_num=6, train_loss=0.0839, val_loss=0.891] 

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.891


Epoch 165: 100%|██████████| 29/29 [00:00<00:00, 41.25it/s, v_num=6, train_loss=0.0194, val_loss=0.873]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 0.873


Epoch 171: 100%|██████████| 29/29 [00:00<00:00, 42.31it/s, v_num=6, train_loss=0.0171, val_loss=0.864] 

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.864


Epoch 175: 100%|██████████| 29/29 [00:00<00:00, 41.32it/s, v_num=6, train_loss=0.088, val_loss=0.844] 

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.844


Epoch 181: 100%|██████████| 29/29 [00:00<00:00, 42.09it/s, v_num=6, train_loss=0.0124, val_loss=0.838] 

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.838


Epoch 185: 100%|██████████| 29/29 [00:00<00:00, 41.96it/s, v_num=6, train_loss=0.0149, val_loss=0.831] 

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.831


Epoch 190: 100%|██████████| 29/29 [00:00<00:00, 43.34it/s, v_num=6, train_loss=0.0285, val_loss=0.783]

Metric val_loss improved by 0.048 >= min_delta = 0.0. New best score: 0.783


Epoch 201: 100%|██████████| 29/29 [00:00<00:00, 42.41it/s, v_num=6, train_loss=0.0261, val_loss=0.779] 

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.779


Epoch 206: 100%|██████████| 29/29 [00:00<00:00, 42.59it/s, v_num=6, train_loss=0.0188, val_loss=0.749]

Metric val_loss improved by 0.030 >= min_delta = 0.0. New best score: 0.749


Epoch 216: 100%|██████████| 29/29 [00:00<00:00, 42.62it/s, v_num=6, train_loss=0.012, val_loss=0.739]  

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.739


Epoch 217: 100%|██████████| 29/29 [00:00<00:00, 41.66it/s, v_num=6, train_loss=0.0575, val_loss=0.730]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.730


Epoch 226: 100%|██████████| 29/29 [00:00<00:00, 47.22it/s, v_num=6, train_loss=0.00866, val_loss=0.727]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.727


Epoch 227: 100%|██████████| 29/29 [00:00<00:00, 41.11it/s, v_num=6, train_loss=0.00802, val_loss=0.721]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.721


Epoch 238: 100%|██████████| 29/29 [00:00<00:00, 41.03it/s, v_num=6, train_loss=0.00315, val_loss=0.705]

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.705


Epoch 239: 100%|██████████| 29/29 [00:00<00:00, 41.34it/s, v_num=6, train_loss=0.00775, val_loss=0.683]

Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 0.683


Epoch 247: 100%|██████████| 29/29 [00:00<00:00, 43.50it/s, v_num=6, train_loss=0.0176, val_loss=0.676] 

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.676


Epoch 249: 100%|██████████| 29/29 [00:00<00:00, 40.68it/s, v_num=6, train_loss=0.00884, val_loss=0.673]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.673


Epoch 253: 100%|██████████| 29/29 [00:00<00:00, 41.59it/s, v_num=6, train_loss=0.0152, val_loss=0.659] 

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.659


Epoch 254: 100%|██████████| 29/29 [00:00<00:00, 42.08it/s, v_num=6, train_loss=0.00481, val_loss=0.652]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.652


Epoch 270: 100%|██████████| 29/29 [00:00<00:00, 41.86it/s, v_num=6, train_loss=0.00963, val_loss=0.638]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.638


Epoch 273: 100%|██████████| 29/29 [00:00<00:00, 41.67it/s, v_num=6, train_loss=0.00854, val_loss=0.629]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.629


Epoch 279: 100%|██████████| 29/29 [00:00<00:00, 43.39it/s, v_num=6, train_loss=0.00915, val_loss=0.626]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.626


Epoch 281: 100%|██████████| 29/29 [00:00<00:00, 41.88it/s, v_num=6, train_loss=0.00442, val_loss=0.616]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.616


Epoch 286: 100%|██████████| 29/29 [00:00<00:00, 42.55it/s, v_num=6, train_loss=0.00265, val_loss=0.608]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.608


Epoch 288: 100%|██████████| 29/29 [00:00<00:00, 42.07it/s, v_num=6, train_loss=0.00398, val_loss=0.593]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.593


Epoch 305: 100%|██████████| 29/29 [00:00<00:00, 41.61it/s, v_num=6, train_loss=0.00267, val_loss=0.589]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.589


Epoch 311: 100%|██████████| 29/29 [00:00<00:00, 41.85it/s, v_num=6, train_loss=0.00335, val_loss=0.563]

Metric val_loss improved by 0.026 >= min_delta = 0.0. New best score: 0.563


Epoch 316: 100%|██████████| 29/29 [00:00<00:00, 42.42it/s, v_num=6, train_loss=0.00174, val_loss=0.547]

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.547


Epoch 333: 100%|██████████| 29/29 [00:00<00:00, 43.22it/s, v_num=6, train_loss=0.00385, val_loss=0.544]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.544


Epoch 338: 100%|██████████| 29/29 [00:00<00:00, 40.50it/s, v_num=6, train_loss=0.00793, val_loss=0.543]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.543


Epoch 344: 100%|██████████| 29/29 [00:00<00:00, 41.61it/s, v_num=6, train_loss=0.00259, val_loss=0.543]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.543


Epoch 348: 100%|██████████| 29/29 [00:00<00:00, 41.25it/s, v_num=6, train_loss=0.00147, val_loss=0.528]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.528


Epoch 349: 100%|██████████| 29/29 [00:00<00:00, 42.34it/s, v_num=6, train_loss=0.00164, val_loss=0.521]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.521


Epoch 353: 100%|██████████| 29/29 [00:00<00:00, 41.36it/s, v_num=6, train_loss=0.00363, val_loss=0.511]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.511


Epoch 365: 100%|██████████| 29/29 [00:00<00:00, 53.34it/s, v_num=6, train_loss=0.0014, val_loss=0.505]  

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.505


Epoch 370: 100%|██████████| 29/29 [00:00<00:00, 52.42it/s, v_num=6, train_loss=0.00145, val_loss=0.503]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.503


Epoch 372: 100%|██████████| 29/29 [00:00<00:00, 52.70it/s, v_num=6, train_loss=0.00599, val_loss=0.491]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.491


Epoch 386: 100%|██████████| 29/29 [00:00<00:00, 49.92it/s, v_num=6, train_loss=0.00086, val_loss=0.484] 

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.484


Epoch 391: 100%|██████████| 29/29 [00:00<00:00, 53.21it/s, v_num=6, train_loss=0.000494, val_loss=0.472]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.472


Epoch 401: 100%|██████████| 29/29 [00:00<00:00, 52.60it/s, v_num=6, train_loss=0.0124, val_loss=0.463]  

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.463


Epoch 405: 100%|██████████| 29/29 [00:00<00:00, 53.21it/s, v_num=6, train_loss=0.00172, val_loss=0.446]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.446


Epoch 434: 100%|██████████| 29/29 [00:00<00:00, 50.40it/s, v_num=6, train_loss=0.000323, val_loss=0.442]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.442


Epoch 439: 100%|██████████| 29/29 [00:00<00:00, 53.16it/s, v_num=6, train_loss=0.00462, val_loss=0.436] 

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.436


Epoch 446: 100%|██████████| 29/29 [00:00<00:00, 52.09it/s, v_num=6, train_loss=0.000746, val_loss=0.435]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.435


Epoch 447: 100%|██████████| 29/29 [00:00<00:00, 51.14it/s, v_num=6, train_loss=0.000474, val_loss=0.427]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.427


Epoch 459: 100%|██████████| 29/29 [00:00<00:00, 52.15it/s, v_num=6, train_loss=0.00488, val_loss=0.410] 

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.410


Epoch 484: 100%|██████████| 29/29 [00:00<00:00, 49.76it/s, v_num=6, train_loss=0.00227, val_loss=0.407] 

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.407


Epoch 486: 100%|██████████| 29/29 [00:00<00:00, 48.89it/s, v_num=6, train_loss=0.00225, val_loss=0.396]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.396


Epoch 490: 100%|██████████| 29/29 [00:00<00:00, 49.10it/s, v_num=6, train_loss=0.00046, val_loss=0.390] 

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.390


Epoch 501: 100%|██████████| 29/29 [00:00<00:00, 48.82it/s, v_num=6, train_loss=0.000198, val_loss=0.387]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.387


Epoch 505: 100%|██████████| 29/29 [00:00<00:00, 49.09it/s, v_num=6, train_loss=0.0004, val_loss=0.372]  

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.372


Epoch 523: 100%|██████████| 29/29 [00:00<00:00, 45.29it/s, v_num=6, train_loss=9.85e-5, val_loss=0.363] 

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.363


Epoch 542: 100%|██████████| 29/29 [00:00<00:00, 45.04it/s, v_num=6, train_loss=0.000341, val_loss=0.351]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.351


Epoch 556: 100%|██████████| 29/29 [00:00<00:00, 44.40it/s, v_num=6, train_loss=0.000203, val_loss=0.344]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.344


Epoch 563: 100%|██████████| 29/29 [00:00<00:00, 43.38it/s, v_num=6, train_loss=0.000516, val_loss=0.343]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.343


Epoch 564: 100%|██████████| 29/29 [00:00<00:00, 46.15it/s, v_num=6, train_loss=3.93e-5, val_loss=0.332] 

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.332


Epoch 582: 100%|██████████| 29/29 [00:00<00:00, 44.16it/s, v_num=6, train_loss=0.00045, val_loss=0.319] 

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 0.319


Epoch 597: 100%|██████████| 29/29 [00:00<00:00, 45.32it/s, v_num=6, train_loss=0.000323, val_loss=0.301]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 0.301


Epoch 618: 100%|██████████| 29/29 [00:00<00:00, 50.61it/s, v_num=6, train_loss=0.000122, val_loss=0.296]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.296


Epoch 619: 100%|██████████| 29/29 [00:00<00:00, 49.64it/s, v_num=6, train_loss=0.000105, val_loss=0.296]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.296


Epoch 658: 100%|██████████| 29/29 [00:00<00:00, 52.23it/s, v_num=6, train_loss=8.8e-5, val_loss=0.284]  

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.284


Epoch 668: 100%|██████████| 29/29 [00:00<00:00, 52.56it/s, v_num=6, train_loss=1.4e-5, val_loss=0.263]  

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 0.263


Epoch 670: 100%|██████████| 29/29 [00:00<00:00, 52.99it/s, v_num=6, train_loss=0.00371, val_loss=0.252] 

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.252


Epoch 682: 100%|██████████| 29/29 [00:00<00:00, 53.28it/s, v_num=6, train_loss=4.71e-6, val_loss=0.252] 

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.252


Epoch 684: 100%|██████████| 29/29 [00:00<00:00, 52.62it/s, v_num=6, train_loss=0.000838, val_loss=0.238]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.238


Epoch 686: 100%|██████████| 29/29 [00:00<00:00, 52.65it/s, v_num=6, train_loss=1.53e-5, val_loss=0.226] 

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.226


Epoch 690: 100%|██████████| 29/29 [00:00<00:00, 53.08it/s, v_num=6, train_loss=1.42e-5, val_loss=0.224]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.224


Epoch 740: 100%|██████████| 29/29 [00:00<00:00, 47.54it/s, v_num=6, train_loss=1.41e-5, val_loss=0.287] 

Monitored metric val_loss did not improve in the last 50 records. Best score: 0.224. Signaling Trainer to stop.


Epoch 740: 100%|██████████| 29/29 [00:00<00:00, 47.38it/s, v_num=6, train_loss=1.41e-5, val_loss=0.287]


In [62]:



# Export both projection heads as ONNX files.
# When early stopping fires, `net` still holds the weights of the *last* epoch, which is
# `patience` epochs past the best val_loss. Restore the best checkpoint before exporting.
best_ckpt_path = checkpoint_callback.best_model_path
if best_ckpt_path:
    # Our own Lightning checkpoint; it stores more than tensors, hence weights_only=False.
    best_ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
    net.load_state_dict(best_ckpt["state_dict"])
else:
    logging.warning("No best checkpoint recorded; exporting final-epoch weights.")
net.eval()  # export inference behaviour (e.g. no dropout)

contrastive.model_to_onnx(
    net.project_1,
    "project_1.onnx",
    input_shape=(input_dim_1,),
)
contrastive.model_to_onnx(
    net.project_2,
    "project_2.onnx",
    input_shape=(input_dim_2,),
)


[1;31m2025-04-20 00:16:05.619536333 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 418099, index: 0, mask: {1, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-20 00:16:05.619552234 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 418100, index: 1, mask: {2, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-20 00:16:05.619569698 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 418101, index: 2, mask: {3, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-20 00:16:05.621094235 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 418119, index: 20, mask: {21, }, error code: 22 error m