In [50]:
import os
os.chdir('../')
import argparse
import json
from glob import glob
import logging
from pathlib import Path
from typing import Any, Dict, Literal, List, Tuple

import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from tqdm.auto import tqdm

import sys
import os

from proteinclip import contrastive
from proteinclip import data_utils
from proteinclip import hparams

from proteinclip.ppi_data import load_ppi_data
from proteinclip.model_utils import ONNXModel

PPI_DATA_DIR = data_utils.DATA_DIR / "ppi"


'/storage/ice1/6/2/savunuri3/cse7850-compbio-project/proteinclip/proteinclip'

In [2]:
def infer_esm_size_from_path(path: str) -> int:
    bname = os.path.basename(path)
    if "33layer" in bname and "30layer" not in bname:
        return 33
    elif "30layer" in bname and "33layer" not in bname:
        return 30
    raise ValueError(f"Could not infer ESM size from path {path}")


In [3]:
def load_model_config(model_dir: Path | str) -> Dict[str, Any]:
    """Load model and training configuration."""
    model_config_path = Path(model_dir) / "model_config.json"
    if model_config_path.exists():
        with open(model_config_path) as source:
            model_config = json.load(source)

    training_config_path = Path(model_dir) / "training_config.json"
    with open(training_config_path) as source:
        training_config = json.load(source)
    return model_config, training_config


In [4]:
train_pairs, valid_pairs, test_pairs = (
    load_ppi_data(split) for split in ("train", "valid", "test")
)

In [6]:
all_keys = set()
for p in train_pairs + valid_pairs + test_pairs:
    all_keys.add(p[0])
    all_keys.add(p[1])


In [54]:
full_embed_map = pd.read_parquet('text+ESM+protclip.parquet')

In [59]:
full_embed_map.columns

Index(['organism', 'organism_id', 'name', 'evidence', 'function', 'id',
       'textual_embedding', 'sequence', 'proteinclip_embed', 'ESM_embed'],
      dtype='object')

In [60]:
embed_map = {
    row['id']: torch.from_numpy(row['ESM_embed']).float().numpy()  # Cast to Float32
    for _, row in tqdm(full_embed_map.iterrows(), desc="Loading relevant embeddings")
    if row['id'] in all_keys
}


Loading relevant embeddings: 10000it [00:00, 36050.70it/s]


In [61]:
embed_map = {k: v / np.linalg.norm(v) for k, v in embed_map.items()}


In [62]:
import onnxruntime as ort

session = ort.InferenceSession("pc_project_1_10000.onnx", providers=["CUDAExecutionProvider"])  # or "CPUExecutionProvider"


[1;31m2025-04-25 23:08:49.341692656 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3842959, index: 0, mask: {24, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-25 23:08:49.341754252 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3842966, index: 7, mask: {4, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-25 23:08:49.341763165 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3842962, index: 3, mask: {32, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2025-04-25 23:08:49.341692875 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3842960, index: 1, mask: {48, }, error code: 22 er

In [63]:
input_name = session.get_inputs()[0].name
input_shape = session.get_inputs()[0].shape
input_type = session.get_inputs()[0].type
print(f"Input name: {input_name}, shape: {input_shape}, type: {input_type}")


Input name: input, shape: ['batch_size', 1280], type: tensor(float)


In [64]:
def generate_embeds(mp,batch_size=32):
    clip_embeds = []
    ls = []
    keys = []
    for i in mp:
        keys.append(i)
        ls.append(mp[i])

    for i in tqdm(range(0, len(ls), batch_size)):
        batch_seqs = ls[i:i+batch_size]


        batch_clip_plus_plus_embs = session.run(None, {input_name: batch_seqs})[0]  


        clip_embeds.extend(batch_clip_plus_plus_embs)
    
    ans = {}
    for ind,i in enumerate(keys):
        ans[i] = clip_embeds[ind]

    return ans


In [65]:
embed_map = generate_embeds(embed_map)

100%|██████████| 171/171 [00:00<00:00, 4947.55it/s]

(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)
(32, 128)





In [66]:
    train_dset, valid_dset, test_dset = [
        data_utils.LabeledPairDataset(p, embed_map)
        for p in (train_pairs, valid_pairs, test_pairs)
    ]




In [67]:
    train_dl, valid_dl, test_dl = [
        DataLoader(
            d, batch_size=128, shuffle=(i == 0), num_workers=16
        )
        for i, d in enumerate([train_dset, valid_dset, test_dset])
    ]




In [68]:
    pl.seed_everything(seed=6489)


[rank: 0] Seed set to 6489


6489

In [69]:
net = contrastive.MLPForClassification(
    init_dim=next(iter(train_dl))[0].shape[-1],
    dims=[128, 1],
    lr=1e-4,
    unit_norm=False,
)


In [70]:
ckpt_callback = ModelCheckpoint(
    dirpath=None,
    filename="{epoch}-{val_auprc:.4f}",
    monitor="val_auprc",
    mode="max",
    save_top_k=1,
    save_weights_only=True,
    auto_insert_metric_name=True,
)

early_stop_callback = EarlyStopping(
    monitor='val_auprc',         # metric name to monitor
    patience=50,                 # epochs with no improvement after which to stop
    mode='max',                 # 'min' if you want to minimize the metric
    verbose=True                # optional: prints a message when triggered
)


In [71]:
    trainer = pl.Trainer(
        max_epochs=1000,
        accelerator="cuda",
        enable_progress_bar=True,
        # logger=logger,
        callbacks=[ckpt_callback, early_stop_callback],
        # log_every_n_steps=1,
        deterministic=True,
        devices=1
    )


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [73]:
trainer.fit(net, train_dataloaders=train_dl, val_dataloaders=valid_dl)

/home/hice1/savunuri3/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /storage/ice1/6/2/savunuri3/cse7850-compbio-project/proteinclip/lightning_logs/version_2509887/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params | Mode 
----------------------------------------------
0 | layers | Sequential | 33.3 K | train
----------------------------------------------
33.3 K    Trainable params
0         Non-trainable params
33.3 K    Total params
0.133     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



                                                                            



Epoch 262: 100%|██████████| 284/284 [00:02<00:00, 113.29it/s, v_num=2509887, train_loss=0.337, val_loss=0.906, val_auroc=0.573, val_auprc=0.588]

Monitored metric val_auprc did not improve in the last 51 records. Best score: 0.594. Signaling Trainer to stop.


Epoch 262: 100%|██████████| 284/284 [00:02<00:00, 113.22it/s, v_num=2509887, train_loss=0.337, val_loss=0.906, val_auroc=0.573, val_auprc=0.588]
