In [2]:
pip install transformers PyPDF2 python-docx matplotlib scikit-learn seaborn torch numpy

Collecting transformers
  Using cached transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting PyPDF2
  Using cached pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Collecting python-docx
  Using cached python_docx-1.2.0-py3-none-any.whl.metadata (2.0 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.7-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting scikit-learn
  Using cached scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting torch
  Using cached torch-2.9.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting numpy
  Using cached numpy-2.3.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting filelock (from transformers)
  Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)


In [7]:
# IMPORTS AND DIRECTORY INITIALIZATION
import numpy as np
import torch
from pathlib import Path
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
from docx import Document
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
from torch.nn import TripletMarginLoss
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

BASE_DIR = Path.cwd().parent
DATA_DIR = BASE_DIR / "data"
INPUT_DIR = DATA_DIR / "input" / "clustering_info"
OUTPUT_DIR = DATA_DIR / "output"

In [8]:
def compute_embeddings(texts, model, tokenizer, device):
    """Compute embeddings for a list of texts."""
    if not texts:
        return np.array([])
    
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0, :]
    
    return embeddings.cpu().numpy()

In [9]:
# CLUSTER LOADING

cluster_path = INPUT_DIR / "oc_mini_clusters_0.001.csv"
metadata_path = INPUT_DIR / "oc_mini_node_metadata.csv"

cluster_df = pd.read_csv(cluster_path)

# Sanity Check
print(f"Loaded cluster data: {cluster_df.shape[0]} nodes")
print(f"\nFirst few rows:")
print(cluster_df.head(10))

print(f"\nCluster statistics:")
print(f"  - Total unique nodes: {cluster_df['node'].nunique()}")
print(f"  - Total unique clusters: {cluster_df['cluster'].nunique()}")
print(f"\nCluster size distribution:")
cluster_sizes = cluster_df['cluster'].value_counts()
print(f"  - Mean cluster size: {cluster_sizes.mean():.2f}")
print(f"  - Median cluster size: {cluster_sizes.median():.0f}")
print(f"  - Largest cluster: {cluster_sizes.max()} nodes")
print(f"  - Smallest cluster: {cluster_sizes.min()} nodes")

print(f"\nCluster assignments loaded!")

Loaded cluster data: 19705 nodes

First few rows:
      node  cluster
0    45066        5
1   989648        0
2  1146632        0
3  3732252        0
4  9488729        5
5  9489474        5
6  9489060        5
7  6382148        5
8  6382959        5
9  1623959        5

Cluster statistics:
  - Total unique nodes: 19705
  - Total unique clusters: 5

Cluster size distribution:
  - Mean cluster size: 3941.00
  - Median cluster size: 2480
  - Largest cluster: 8989 nodes
  - Smallest cluster: 2068 nodes

Cluster assignments loaded!


In [10]:
# LINKING NODES TO ACTUAL PAPERS
metadata_df = pd.read_csv(metadata_path)

print(f"Loaded metadata: {metadata_df.shape[0]} papers")
print(f"\nColumns: {list(metadata_df.columns)}")
print(f"\nFirst few rows:")
print(metadata_df.head())

# Missing data sanity check
print(f"\nMissing data:")
print(metadata_df.isnull().sum())

# Merge to get clustering
papers_df = metadata_df.merge(
    cluster_df, 
    left_on='id',
    right_on='node', 
    how='inner'
)

print(f"\nMerged dataset: {papers_df.shape[0]} papers with cluster assignments")
print(f"  - Papers with clusters: {papers_df.shape[0]}")
print(f"  - Papers without clusters: {metadata_df.shape[0] - papers_df.shape[0]}")


# Combine title and abstract for later embeddings, can switch to full papers later if we want
papers_df['text'] = papers_df['title'].fillna('') + ' ' + papers_df['abstract'].fillna('')

print(f"\nFinal Merged Table for Fine-Tuning")
print(papers_df.head())

Loaded metadata: 14442 papers

Columns: ['id', 'doi', 'title', 'abstract']

First few rows:
     id                        doi  \
0   128  10.1101/2021.05.10.443415   
1   163  10.1101/2021.05.07.443114   
2   200  10.1101/2021.05.11.443555   
3   941       10.3390/ijms20020449   
4  1141       10.3390/ijms20040865   

                                               title  \
0  Improved protein contact prediction using dime...   
1  Following the Trail of One Million Genomes: Fo...   
2  Mechanism of molnupiravir-induced SARS-CoV-2 m...   
3  Bactericidal and Cytotoxic Properties of Silve...   
4  Silver Nanoparticles: Synthesis and Applicatio...   

                                            abstract  
0  AbstractDeep residual learning has shown great...  
1  AbstractSevere acute respiratory syndrome coro...  
2  Molnupiravir is an orally available antiviral ...  
3  Silver nanoparticles (AgNPs) can be synthesize...  
4  Over the past few decades, metal nanoparticles...  

Missing dat

In [None]:
# TRIPLET GENERATION FROM CLUSTERS
def create_triplets_from_dataframe(papers_df, n_triplets=5000, seed=42):
    """
    Generate triplets (anchor, positive, negative) from clustered papers.
    
    Args:
        papers_df: DataFrame with columns ['id', 'text', 'cluster']
        n_triplets: number of triplets to generate
        seed: random seed for reproducibility
    
    Returns:
        List of (anchor_text, positive_text, negative_text)
    """
    np.random.seed(seed)
    
    # Group papers by cluster
    cluster_to_papers = papers_df.groupby('cluster')['text'].apply(list).to_dict()
    cluster_ids = list(cluster_to_papers.keys())
    
    # Filter out clusters with only 1 paper because there can be no pairing then
    cluster_ids = [cid for cid in cluster_ids if len(cluster_to_papers[cid]) >= 2]
    
    print(f"Creating triplets from {len(cluster_ids)} clusters...")
    print(f"Filtered out {len(cluster_to_papers) - len(cluster_ids)} single-paper clusters")
    
    triplets = []
    failed_attempts = 0
    max_attempts = n_triplets * 3  # Prevent infinite loops
    
    with tqdm(total=n_triplets, desc="Generating triplets") as pbar:
        while len(triplets) < n_triplets and failed_attempts < max_attempts:
            
            # Random Anchor Cluster
            anchor_cluster = np.random.choice(cluster_ids)
            
            # Select anchor and positive from same cluster
            anchor_text, positive_text = np.random.choice(
                cluster_to_papers[anchor_cluster], 
                size=2, 
                replace=False
            )
            
            # Select negative from different cluster
            negative_cluster = np.random.choice(
                [c for c in cluster_ids if c != anchor_cluster]
            )
            negative_text = np.random.choice(cluster_to_papers[negative_cluster])
            
            triplets.append((anchor_text, positive_text, negative_text))
            pbar.update(1)
    
    print(f"Generated {len(triplets)} triplets")
    return triplets

# Generate triplets
triplets = create_triplets_from_dataframe(papers_df, n_triplets = len(papers_df) * 2)

# Show example
print(f"\nExample triplet:")
print(f"ANCHOR: {triplets[0][0][:200]}...")
print(f"POSITIVE: {triplets[0][1][:200]}...")
print(f"NEGATIVE: {triplets[0][2][:200]}...")

Creating triplets from 5 clusters...
Filtered out 0 single-paper clusters


Generating triplets:   2%|▏         | 489/28884 [00:10<10:03, 47.06it/s]

In [7]:
# HYPERPARAMS AND MODEL LOADING
EPOCHS = 3
BATCH_SIZE = 32  # Adjust based on your GPU memory
LEARNING_RATE = 2e-5
MARGIN = 1.0  # Triplet loss margin
SAVE_STEPS = 500  # Save model every N steps

In [8]:
# TRIPLET LOSS TRAINING FUNCTION (WITH RUN-SPECIFIC DIRECTORIES)
from datetime import datetime

def train_with_triplet_loss(
    model, 
    tokenizer, 
    triplets, 
    device,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    margin=MARGIN
):
    """Fine-tune model using triplet loss."""
    
    # Create unique run directory with timestamp
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    run_dir = OUTPUT_DIR / "training_runs" / f"run_{timestamp}"
    checkpoint_dir = run_dir / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"Training run directory: {run_dir}")
    print(f"Checkpoints will be saved to: {checkpoint_dir}\n")
    
    # Save training configuration
    config = {
        'timestamp': timestamp,
        'epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'margin': margin,
        'n_triplets': len(triplets)
    }
    
    import json
    with open(run_dir / "config.json", 'w') as f:
        json.dump(config, f, indent=2)
    
    model.train()
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = TripletMarginLoss(margin=margin, p=2)
    
    losses = []
    step = 0
    
    for epoch in range(epochs):
        epoch_losses = []
        
        # Shuffle triplets each epoch
        shuffled_triplets = triplets.copy()
        np.random.shuffle(shuffled_triplets)
        
        # Process in batches
        progress_bar = tqdm(
            range(0, len(shuffled_triplets), batch_size), 
            desc=f"Epoch {epoch+1}/{epochs}"
        )
        
        for i in progress_bar:
            batch = shuffled_triplets[i:i+batch_size]
            
            # Separate anchor, positive, negative
            anchors = [t[0] for t in batch]
            positives = [t[1] for t in batch]
            negatives = [t[2] for t in batch]
            
            # Tokenize
            anchor_inputs = tokenizer(
                anchors, padding=True, truncation=True, 
                max_length=512, return_tensors="pt"
            ).to(device)
            
            positive_inputs = tokenizer(
                positives, padding=True, truncation=True,
                max_length=512, return_tensors="pt"
            ).to(device)
            
            negative_inputs = tokenizer(
                negatives, padding=True, truncation=True,
                max_length=512, return_tensors="pt"
            ).to(device)
            
            # Forward pass
            anchor_outputs = model(**anchor_inputs)
            positive_outputs = model(**positive_inputs)
            negative_outputs = model(**negative_inputs)
            
            # Extract [CLS] embeddings
            anchor_emb = anchor_outputs.last_hidden_state[:, 0, :]
            positive_emb = positive_outputs.last_hidden_state[:, 0, :]
            negative_emb = negative_outputs.last_hidden_state[:, 0, :]
            
            # Compute triplet loss
            loss = criterion(anchor_emb, positive_emb, negative_emb)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_losses.append(loss.item())
            step += 1
            
            # Update progress bar
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            # Save checkpoint
            if step % SAVE_STEPS == 0:
                checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"
                torch.save({
                    'step': step,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss.item(),
                }, checkpoint_path)
                print(f"\nCheckpoint saved: {checkpoint_path.name}")
        
        avg_loss = np.mean(epoch_losses)
        losses.append(avg_loss)
        print(f"\nEpoch {epoch+1}/{epochs} - Average Loss: {avg_loss:.4f}")
        
        # Save epoch checkpoint
        epoch_checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch+1}.pt"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, epoch_checkpoint_path)
    
    # Save final model
    final_model_path = run_dir / "final_model.pt"
    torch.save(model.state_dict(), final_model_path)
    
    # Save training losses
    loss_data = {
        'losses': losses,
        'epochs': list(range(1, len(losses) + 1))
    }
    with open(run_dir / "training_losses.json", 'w') as f:
        json.dump(loss_data, f, indent=2)
    
    # Plot training loss
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(losses) + 1), losses, marker='o', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title(f'Triplet Loss Training Progress - {timestamp}')
    plt.grid(True, alpha=0.3)
    plt.savefig(run_dir / "training_loss.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\n{'='*60}")
    print(f"Training complete!")
    print(f"Final model saved to: {final_model_path}")
    print(f"All outputs saved to: {run_dir}")
    print(f"{'='*60}")
    
    return model, losses, run_dir

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading MedCPT model...")
tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").to(device)
model.train()
print("Model loaded!")

Using device: cuda
Loading MedCPT model...
Model loaded!


In [None]:
# RUN TRAINING
print("Starting fine-tuning with triplet loss...\n")

finetuned_model, training_losses, run_dir = train_with_triplet_loss(
    model, 
    tokenizer, 
    triplets, 
    device,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    margin=MARGIN
)

print("\nTraining complete! Model is ready for evaluation.")

Starting fine-tuning with triplet loss...

Training run directory: /home/ajgrama2/data/output/training_runs/run_20251106_193924
Checkpoints will be saved to: /home/ajgrama2/data/output/training_runs/run_20251106_193924/checkpoints



Epoch 1/3:  10%|▉         | 89/903 [02:33<23:11,  1.71s/it, loss=0.0329]

In [10]:
# EMERGENCY: Clear GPU memory
import gc

# Clear any existing tensors
torch.cuda.empty_cache()
gc.collect()

# Check what's using memory
print("GPU Memory Status:")
print(f"Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
print(f"Max allocated: {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")

GPU Memory Status:
Allocated: 0.44 GB
Reserved: 0.49 GB
Max allocated: 0.44 GB
