In [23]:
import os
import json
import argparse
import logging
import pickle
import random

import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
from torch.utils import data
import lightning.pytorch as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger

from proteinclip import data_utils, fasta_utils, swissprot, hparams
from proteinclip import contrastive

In [1]:
import os
os.chdir(os.path.join(os.getcwd(), '../'))

In [None]:
# get proteinclip_trained embeddings


In [3]:
args = {
    'structural_embed':'../data/structural/structural_embeddings_0_10000.pkl',
    'proteinclip_embed':'protclip_embed_dataset.parquet',
    'unitnorm': True,
    'batch_size': 512,
    'dim': 128,
    'nhidden': 1,
    'learning_rate':1e-4,
    'out':'training',
    'name':'testrun',
    'max_epochs': 1000,
    
}

In [4]:
pl.seed_everything(seed=6489)


[rank: 0] Seed set to 6489


6489

In [5]:
structural_embeddings = None
with open(args['structural_embed'], "rb") as f:
    structural_embeddings = pickle.load(f)
structural_embeddings = dict(zip(
    structural_embeddings["protein_id"],
    structural_embeddings["embedding"]
))

In [6]:
df = pd.read_parquet('protclip_embed_dataset.parquet')

In [8]:
import onnxruntime as ort

session = ort.InferenceSession("pc_project_1_10000.onnx", providers=["CUDAExecutionProvider"])  # or "CPUExecutionProvider"

[1;31m2025-04-26 11:44:54.455889158 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 396065, index: 0, mask: {1, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-26 11:44:54.455904838 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 396080, index: 15, mask: {16, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-26 11:44:54.455922336 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 396100, index: 35, mask: {36, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-26 11:44:54.455932614 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 396119, index: 54, mask: {55, }, error code: 22 err

In [9]:
input_name = session.get_inputs()[0].name
input_shape = session.get_inputs()[0].shape
input_type = session.get_inputs()[0].type
print(f"Input name: {input_name}, shape: {input_shape}, type: {input_type}")


Input name: input, shape: ['batch_size', 1280], type: tensor(float)


In [58]:
def compute_proteinclip_trained_embeddings(ls, batch_size=32, device="cuda"):
    clip_embeds = []

    # Make sure input is a numpy array
    ls = np.array(ls)

    for i in tqdm(range(0, len(ls), batch_size)):
        esm_embs = ls[i:i+batch_size]  # shape (batch_size, emb_dim)

        # Step 2: Normalize
        esm_embs = esm_embs / np.linalg.norm(esm_embs, axis=1, keepdims=True)

        # Ensure type is float32 (many ONNX models expect float32)
        esm_embs = esm_embs.astype(np.float32)

        # Step 3: Run inference
        batch_clip_embs = session.run(None, {input_name: esm_embs})[0]  

        clip_embeds.append(batch_clip_embs)

    return np.concatenate(clip_embeds, axis=0)  # shape (N, 128)


In [62]:
dataset = pd.read_parquet('text+ESM+protclip.parquet')

In [64]:
esm_embeddings = np.stack(dataset['ESM_embed'].values)  # (N, emb_dim)

# Step 2: Compute proteinclip embeddings
proteinclip_embeds = compute_proteinclip_trained_embeddings(esm_embeddings)

# Step 3: Save back into DataFrame
dataset['proteinclip_trained_embed'] = proteinclip_embeds.tolist()


100%|██████████| 313/313 [00:00<00:00, 4862.57it/s]


In [67]:
proteinclip_embeddings = proteinclip_embeds

In [69]:
# proteinclip_embeddings = pd.read_parquet(args['proteinclip_embed'])
proteinclip_embeddings = dict(zip(
    dataset["id"],
    dataset["proteinclip_trained_embed"]
))

In [70]:
shared_keys = sorted(set(proteinclip_embeddings.keys()).intersection(set(structural_embeddings.keys())))


In [71]:
random.seed(42)  # Set your desired seed here
random.shuffle(shared_keys)


In [72]:
dset = data_utils.CLIPDataset(
    pairs=shared_keys,
    map1=proteinclip_embeddings,
    map2=structural_embeddings,
    enforce_unit_norm=args['unitnorm'],
)

Checking for missing/zero embeddings: 100%|██████████| 9986/9986 [00:00<00:00, 21247.53it/s]
INFO:root:Trimmed 0 pairs with missing/zero embeddings, 9986 remain.


In [73]:
split_indices = data_utils.random_split(len(dset), [0.9, 0.05, 0.05])
dset_splits = [data.Subset(dset, idx) for idx in split_indices]

In [74]:
# Create data loaders
train_dl, valid_dl, _test_dl = [
    data.DataLoader(
        ds,
        batch_size=args['batch_size'],
        shuffle=(i == 0),
        # drop_last=(i == 0),
        num_workers=8,
        pin_memory=True,
    )
    for i, ds in enumerate(dset_splits)
]




In [75]:
# Define network
input_dim_1 = next(iter(train_dl))["x_1"].shape[-1]
input_dim_2 = next(iter(train_dl))["x_2"].shape[-1]
model_class = (
    contrastive.ContrastiveEmbedding
)
net = model_class(
    input_dim_1=input_dim_1,
    input_dim_2=input_dim_2,
    shared_dim=args['dim'],
    num_hidden=args['nhidden'],
    lr=args['learning_rate'],
)

In [None]:
# # Define logger, write configuration files and data splits
# logger = CSVLogger(save_dir=args['out'], name=args['name'])
# # logger.log_hyperparams(hyperparameters.as_dict())
# write_split_identifiers(
#     train_ids=[dset.pairs[i] for i in split_indices[0]],
#     valid_ids=[dset.pairs[i] for i in split_indices[1]],
#     test_ids=[dset.pairs[i] for i in split_indices[2]],
#     out_file=os.path.join(logger.log_dir, "data_splits.json"),
# )
# net.write_config_json(os.path.join(logger.log_dir, "model_config.json"))
# with open(os.path.join(logger.log_dir, "training_config.json"), "w") as sink:
#     json.dump(vars(args), sink, indent=4)


In [76]:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
early_stop_callback = EarlyStopping(
    monitor='val_loss',         # metric name to monitor
    patience=50,                 # epochs with no improvement after which to stop
    mode='min',                 # 'min' if you want to minimize the metric
    verbose=True                # optional: prints a message when triggered
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    filename='best-checkpoint'
)


In [78]:
# Train
trainer = pl.Trainer(
    max_epochs=10000,
    accelerator="cuda",
    devices=1,
    enable_progress_bar=True,
    # logger=logger,
    # log_every_n_steps=10,
    deterministic=True,
    callbacks=[early_stop_callback, checkpoint_callback],

)
trainer.fit(net, train_dataloaders=train_dl, val_dataloaders=valid_dl)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/hice1/savunuri3/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /storage/ice1/6/2/savunuri3/cse7850-compbio-project/proteinclip/lightning_logs/version_2509889/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type | Params | Mode 
----------------------------------------------
0 | project_1    | MLP  | 33.3 K | train
1 | project_2    | MLP  | 1.2 M  | train
  | other params | n/a  | 1      | n/a  
----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.865     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 1: 100%|██████████| 18/18 [00:00<00:00, 26.03it/s, v_num=2509890, train_loss=0.263, val_loss=1.040]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 1.045


Epoch 2: 100%|██████████| 18/18 [00:00<00:00, 26.39it/s, v_num=2509890, train_loss=0.197, val_loss=1.040]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 1.039


Epoch 3: 100%|██████████| 18/18 [00:00<00:00, 26.63it/s, v_num=2509890, train_loss=0.204, val_loss=1.030]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 1.028


Epoch 5: 100%|██████████| 18/18 [00:00<00:00, 27.17it/s, v_num=2509890, train_loss=0.221, val_loss=1.030]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.026


Epoch 38: 100%|██████████| 18/18 [00:00<00:00, 26.71it/s, v_num=2509890, train_loss=0.164, val_loss=1.020]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 1.018


Epoch 41: 100%|██████████| 18/18 [00:00<00:00, 26.07it/s, v_num=2509890, train_loss=0.168, val_loss=1.000]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 1.004


Epoch 43: 100%|██████████| 18/18 [00:00<00:00, 26.28it/s, v_num=2509890, train_loss=0.187, val_loss=1.000]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 1.001


Epoch 46: 100%|██████████| 18/18 [00:00<00:00, 27.08it/s, v_num=2509890, train_loss=0.179, val_loss=0.996]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.996


Epoch 49: 100%|██████████| 18/18 [00:00<00:00, 26.83it/s, v_num=2509890, train_loss=0.175, val_loss=0.994]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.994


Epoch 54: 100%|██████████| 18/18 [00:00<00:00, 26.79it/s, v_num=2509890, train_loss=0.174, val_loss=0.992]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.992


Epoch 68: 100%|██████████| 18/18 [00:00<00:00, 27.08it/s, v_num=2509890, train_loss=0.165, val_loss=0.991]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.991


Epoch 70: 100%|██████████| 18/18 [00:00<00:00, 27.15it/s, v_num=2509890, train_loss=0.146, val_loss=0.963]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 0.963


Epoch 71: 100%|██████████| 18/18 [00:00<00:00, 25.15it/s, v_num=2509890, train_loss=0.132, val_loss=0.958]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.958


Epoch 104: 100%|██████████| 18/18 [00:00<00:00, 27.60it/s, v_num=2509890, train_loss=0.154, val_loss=0.950]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.950


Epoch 106: 100%|██████████| 18/18 [00:00<00:00, 25.41it/s, v_num=2509890, train_loss=0.118, val_loss=0.931]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.931


Epoch 107: 100%|██████████| 18/18 [00:00<00:00, 22.67it/s, v_num=2509890, train_loss=0.143, val_loss=0.929]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.929


Epoch 124: 100%|██████████| 18/18 [00:00<00:00, 25.79it/s, v_num=2509890, train_loss=0.128, val_loss=0.921]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.921


Epoch 132: 100%|██████████| 18/18 [00:00<00:00, 26.37it/s, v_num=2509890, train_loss=0.135, val_loss=0.917]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.917


Epoch 153: 100%|██████████| 18/18 [00:00<00:00, 25.24it/s, v_num=2509890, train_loss=0.106, val_loss=0.902] 

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.902


Epoch 191: 100%|██████████| 18/18 [00:00<00:00, 26.99it/s, v_num=2509890, train_loss=0.103, val_loss=0.886] 

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.886


Epoch 201: 100%|██████████| 18/18 [00:00<00:00, 27.14it/s, v_num=2509890, train_loss=0.0735, val_loss=0.882]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.882


Epoch 236: 100%|██████████| 18/18 [00:00<00:00, 25.88it/s, v_num=2509890, train_loss=0.0779, val_loss=0.875]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.875


Epoch 258: 100%|██████████| 18/18 [00:00<00:00, 26.87it/s, v_num=2509890, train_loss=0.0757, val_loss=0.858]

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 0.858


Epoch 269: 100%|██████████| 18/18 [00:00<00:00, 24.53it/s, v_num=2509890, train_loss=0.0647, val_loss=0.854]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.854


Epoch 319: 100%|██████████| 18/18 [00:00<00:00, 27.02it/s, v_num=2509890, train_loss=0.0609, val_loss=0.854]

Monitored metric val_loss did not improve in the last 50 records. Best score: 0.854. Signaling Trainer to stop.


Epoch 319: 100%|██████████| 18/18 [00:00<00:00, 26.95it/s, v_num=2509890, train_loss=0.0609, val_loss=0.854]


In [79]:



# Export model as ONNX files
contrastive.model_to_onnx(
    net.project_1,
    "pc+_project_1_10000.onnx",
    input_shape=(input_dim_1,),
)
contrastive.model_to_onnx(
    net.project_2,
    "pc+_project_2_10000.onnx",
    input_shape=(input_dim_2,),
)


[1;31m2025-04-26 12:24:14.637051914 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 493951, index: 0, mask: {1, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-26 12:24:14.637062159 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 493952, index: 1, mask: {2, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-26 12:24:14.638514746 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 494010, index: 59, mask: {60, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-26 12:24:14.642024115 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 493957, index: 6, mask: {7, }, error code: 22 error m