In [1]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
import warnings
import json
# Silence every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation notices from the
# imported libraries — confirm that is intended.
warnings.filterwarnings('ignore')

# Paths: BASE_DIR is three directory levels above the current working
# directory, so the notebook must be launched from its own folder.
notebook_dir = Path.cwd()
BASE_DIR = notebook_dir.parent.parent.parent
DATA_DIR = BASE_DIR / "ppi-assembly" / "processed_data"

# Put BASE_DIR/cat/dcat at the front of the import search path.
# (presumably where `notebook_utils` and `train` live — verify)
dcat_src_dir = BASE_DIR / "cat" / "dcat"
sys.path.insert(0, str(dcat_src_dir))

# Device: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Load the clustering table (CSV) and the node metadata (JSON) from DATA_DIR.
clustering_path = DATA_DIR / "clustering" / "disjoint" / "brca_ppi_leiden.csv"
metadata_path = DATA_DIR / "metadata" / "node_to_sequence.json"

# Parse the JSON first, then the CSV; the two reads are independent.
metadata = json.loads(metadata_path.read_text())
clustering_df = pd.read_csv(clustering_path)

# Report how many rows / keys each source contains.
n_clustering_rows = len(clustering_df)
n_metadata_keys = len(metadata)
print(f"Clustering loaded: {n_clustering_rows} entries")
print(f"Metadata loaded: {n_metadata_keys} entries")

Clustering loaded: 2394 entries
Metadata loaded: 2394 entries


In [3]:
# Build the table that is written to 'metadata.csv' for the training step.
#
# `pd` and `json` are already imported in the first cell, and the JSON at
# `metadata_path` was already parsed into `metadata` by the loading cell, so
# neither is repeated here. `metadata_dict` is kept as a name in case other
# cells refer to it.
metadata_dict = metadata

# One row per node:
#   id       - the JSON key converted to int
#   title    - always the empty string
#   abstract - the sequence with a single space between consecutive characters
# NOTE(review): the previous comment said the spacing is for ProteinBERT, but
# the training cell uses 'facebook/esm2_t12_35M_UR50D' — confirm that the
# tokenizer used by train_model expects space-separated residues.
metadata_df = pd.DataFrame(
    [
        {'id': int(node_id), 'title': '', 'abstract': ' '.join(sequence)}
        for node_id, sequence in metadata_dict.items()
    ],
    # Explicit columns keep the schema stable even when the mapping is empty,
    # so later `metadata_df['id']` lookups do not fail with a KeyError.
    columns=['id', 'title', 'abstract'],
)

# Save to CSV without the pandas index column
metadata_df.to_csv('metadata.csv', index=False)

metadata_df.head()

Unnamed: 0,id,title,abstract
0,0,,M G L T V S A L F S R I F G K K Q M R I L M V ...
1,1,,M T A E E M K A T E S G A Q S A P L P M E G V ...
2,2,,M T A E L Q Q D D A A G A A D G H G S S C Q M ...
3,3,,M A A N K P K G Q N S L A L H K V I M V G S G ...
4,4,,M N R G V P F R H L L L V L Q L A L L P A A T ...


In [4]:
# IMPORTANT: the held-out node set is fixed here, BEFORE training, so the same
# node ids can be handed to the training call for validation and reused as the
# test set.
# The import stays in this cell rather than with the top-level imports:
# presumably `notebook_utils` is only importable after the sys.path.insert in
# the setup cell — verify before moving it.
from notebook_utils import create_test_split

# Node ids are passed to the splitter as strings.
all_node_ids = list(map(str, metadata_df['id'].values))
test_val_nodes = create_test_split(all_node_ids, test_ratio=0.1, seed=42)

n_total = len(all_node_ids)
n_holdout = len(test_val_nodes)
print(f"\nTest/Val set: {n_holdout} nodes ({n_holdout/n_total*100:.1f}%)")
print(f"Train set: {n_total - n_holdout} nodes")

metadata_df.head()


Test/Val set: 239 nodes (10.0%)
Train set: 2155 nodes


Unnamed: 0,id,title,abstract
0,0,,M G L T V S A L F S R I F G K K Q M R I L M V ...
1,1,,M T A E E M K A T E S G A Q S A P L P M E G V ...
2,2,,M T A E L Q Q D D A A G A A D G H G S S C Q M ...
3,3,,M A A N K P K G Q N S L A L H K V I M V G S G ...
4,4,,M N R G V P F R H L L L V L Q L A L L P A A T ...


In [8]:
# Rename the clustering columns by position to 'node' and 'cluster', then
# write the table to the file the training cell reads.
# NOTE(review): unlike the metadata export, this write keeps the pandas index
# as an extra unnamed first column — confirm train_model tolerates it before
# switching to index=False.
renamed_columns = ['node', 'cluster']
clustering_df.columns = renamed_columns
reformatted_csv = 'reformatted_clustering.csv'
clustering_df.to_csv(reformatted_csv)

In [7]:
# Preview the first five rows of the clustering table.
clustering_df.head(n=5)

Unnamed: 0,node,cluster
0,0,11
1,271,11
2,652,11
3,751,11
4,1811,11


In [None]:
# Imported in this cell rather than at the top: presumably `train` is only
# importable after the sys.path.insert in the setup cell — verify before moving.
from train import train_model

# Directory handed to train_model as `output_dir`.
finetuned_output_dir = BASE_DIR / "cat" / "models" / "finetuned_dcat_esm_triplet"

# Every training setting in one place, using the standard (non-adaptive)
# triplet loss.
# NOTE(review): the model-load log reports 'pooler.dense.*' as newly
# initialized; if pooling='cls' reads the pooler output those weights start
# random — confirm against train.py.
# NOTE(review): val_nodes is the same node set intended for testing, so any
# model selection based on validation metrics also sees the test nodes —
# confirm this is acceptable.
train_kwargs = dict(
    clustering_csv_path='reformatted_clustering.csv',
    metadata_csv_path='metadata.csv',
    output_dir=str(finetuned_output_dir),
    model_name='facebook/esm2_t12_35M_UR50D',
    device=str(device),
    batch_size=16,
    epochs=3,
    lr=1e-5,
    margin=0.5,               # standard triplet margin
    samples_per_node=3,
    pooling='cls',
    loss_type='triplet',      # standard triplet loss (not adaptive)
    val_nodes=test_val_nodes  # same nodes for validation as for testing
)

finetuned_model, tokenizer, history = train_model(**train_kwargs)

Using device: cuda

Loading clustering from reformatted_clustering.csv...
Loaded clustering:
  Total nodes: 2394
  Total clusters: 251
  Avg cluster size: 9.5

Loading metadata from metadata.csv...
  Entries: 2394

Loading model: facebook/esm2_t12_35M_UR50D...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Creating dataset...


Generating triplets: 100%|██████████| 2394/2394 [00:14<00:00, 160.81it/s]


  Generated 7182 triplets
  Using 239 nodes for validation (node-level split)
  Train: 6465 triplets | Val: 717 triplets
  Train nodes: 2155 (approx)
  Val nodes: 239

Configuring loss function...
Loss: TripletLoss
  Margin: 0.5

Starting training for 3 epochs

Epoch 1/3
------------------------------------------------------------


Training:  65%|██████▌   | 264/405 [03:39<01:56,  1.21it/s, loss=0.4517]