# Disjoint Clustering Fine-Tuning

**Model: MedCPT**

**Method: D-CAT (Disjoint Cluster)**

**Data: OC-Mini**

**Tuner: [Your Name]**

In [None]:
# IMPORTS AND DIRECTORY INITIALIZATION
import numpy as np
import torch
from pathlib import Path
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from torch.nn import TripletMarginLoss
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import warnings
import sys
warnings.filterwarnings('ignore')

# Paths are resolved relative to the working directory: this assumes the notebook
# is launched from two directory levels below the project root — adjust if not.
BASE_DIR = Path.cwd().parent.parent
DATA_DIR = BASE_DIR / "oc_mini"

# Global seed so triplet mining (np.random), the torch train/val split and dropout
# are reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Add dcat module to path (guarded so re-running this cell does not keep
# prepending duplicate entries to sys.path).
_dcat_dir = str(BASE_DIR / "dcat")
if _dcat_dir not in sys.path:
    sys.path.insert(0, _dcat_dir)

## Load Model and Data

In [None]:
def compute_embeddings(texts, model, tokenizer, device):
    """Embed ``texts`` with ``model`` and return the first-token ([CLS]) vectors.

    An empty input short-circuits to ``np.array([])`` without touching the
    tokenizer or model. Otherwise every text is tokenized as one padded batch
    (truncated to 512 tokens), moved to ``device`` and encoded under
    ``torch.no_grad()``. The result is a NumPy array on the CPU.
    """
    if not texts:
        return np.array([])

    batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
    batch = batch.to(device)

    with torch.no_grad():
        hidden_states = model(**batch).last_hidden_state

    # Position 0 of every sequence holds the [CLS] token.
    cls_vectors = hidden_states[:, 0, :]
    return cls_vectors.cpu().numpy()

In [None]:
# CLUSTER LOADING
# Disjoint cluster assignments (one row per node with its cluster) and the path
# of the node metadata table that is read further below.

disjoint_dir = DATA_DIR / "clustering" / "disjoint"
cluster_path = disjoint_dir / "oc_mini_clusters_0.001.csv"
metadata_path = DATA_DIR / "metadata" / "oc_mini_node_metadata.csv"

cluster_df = pd.read_csv(cluster_path)
cluster_sizes = cluster_df['cluster'].value_counts()

# Sanity check
print(f"Loaded cluster data: {len(cluster_df)} nodes")
print("\nFirst few rows:")
print(cluster_df.head(10))

print(
    "\nCluster statistics:",
    f"  - Total unique nodes: {cluster_df['node'].nunique()}",
    f"  - Total unique clusters: {cluster_df['cluster'].nunique()}",
    "\nCluster size distribution:",
    f"  - Mean cluster size: {cluster_sizes.mean():.2f}",
    f"  - Median cluster size: {cluster_sizes.median():.0f}",
    f"  - Largest cluster: {cluster_sizes.max()} nodes",
    f"  - Smallest cluster: {cluster_sizes.min()} nodes",
    "\nCluster assignments loaded!",
    sep="\n",
)

In [None]:
# Load the MedCPT article encoder (tokenizer + weights) onto the best available device.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda") if has_gpu else torch.device("cpu")
print(f"Using device: {device}")

model_name = "ncbi/MedCPT-Article-Encoder"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model = model.to(device)

print(f"MedCPT model loaded: {model_name}")

# Smoke test: embed one sentence and inspect the output shape and leading values.
sample_sentence = "The relationship between quantum mechanics and general relativity remains one of the most important unsolved problems in theoretical physics."

embedding = compute_embeddings([sample_sentence], model, tokenizer, device)
leading_dims = embedding[0][:10]

print(f"\nSample sentence: {sample_sentence}")
print(f"Embedding shape: {embedding.shape}")
print(f"Embedding (first 10 dimensions): {leading_dims}")

In [None]:
# Node metadata; later cells read at least its `id`, `title` and `abstract` columns.
metadata_df = pd.read_csv(filepath_or_buffer=metadata_path)
metadata_df.head(n=5)

## Baseline Evaluation

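MedCPT is trained for biomedical literature retrieval. Before any fine-tuning, how well do its off-the-shelf embeddings perform on a downstream network task: link prediction on the held-out test-cluster nodes?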

In [None]:
# ============================================================================
# EVALUATE ORIGINAL MEDCPT EMBEDDINGS (BEFORE FINE-TUNING)
# ============================================================================

print("="*60)
print("EVALUATING ORIGINAL MEDCPT ON LINK PREDICTION")
print("="*60)

# Import evaluation functions from the project-local utils/evaluation directory.
# The membership check keeps this cell idempotent: re-running it no longer appends
# another copy of the same path to sys.path.
_eval_dir = str(BASE_DIR / "utils" / "evaluation")
if _eval_dir not in sys.path:
    sys.path.append(_eval_dir)

from link_prediction import (
    evaluate_network_link_prediction,
    plot_link_prediction_results,
    get_node_degree,
    load_network_edges
)

In [None]:
# ============================================================================
# STEP 1: COMPUTE EMBEDDINGS FOR ALL NODES
# ============================================================================

def _combine_title_abstract(title, abstract) -> str:
    """Join title and abstract with a space, treating missing (NaN/None) values as ""."""
    title = str(title) if pd.notna(title) else ""
    abstract = str(abstract) if pd.notna(abstract) else ""
    return f"{title} {abstract}".strip()


def compute_all_node_embeddings(model, tokenizer, metadata_df, device, batch_size=32, max_length=512):
    """Compute [CLS] embeddings for every node in ``metadata_df`` that has text.

    A node's text is ``"<title> <abstract>"`` with missing fields dropped; nodes
    whose text ends up empty are skipped and get no entry in the result.

    Args:
        model: HuggingFace-style encoder; this function puts it into eval mode.
        tokenizer: Tokenizer matching ``model``.
        metadata_df: DataFrame with ``id``, ``title`` and ``abstract`` columns.
            If an id occurs more than once, the text of its first row is used.
        device: Device the tokenized batches are moved to.
        batch_size: Number of texts encoded per forward pass.
        max_length: Tokenizer truncation length (defaults to the previous 512).

    Returns:
        dict mapping ``str(id)`` -> 1-D NumPy array holding the first-token
        ([CLS]) hidden state. These vectors are NOT L2-normalised, unlike the
        output of ``encode_batch`` used during training. NOTE(review): this only
        matters if the link-prediction scorer is not cosine-based — confirm.
    """
    embeddings_dict = {}

    # Get all node IDs from metadata
    node_ids = metadata_df['id'].astype(str).values

    # Build the id -> text lookup in a single pass. Filtering the whole frame with a
    # boolean mask once per node was O(n^2); keep='first' reproduces the previous
    # `.iloc[0]` choice when an id is duplicated.
    first_rows = metadata_df.drop_duplicates(subset='id', keep='first')
    text_by_id = {
        node_id: _combine_title_abstract(title, abstract)
        for node_id, title, abstract in zip(
            first_rows['id'].astype(str), first_rows['title'], first_rows['abstract']
        )
    }

    texts = []
    valid_ids = []

    print(f"Preparing texts for {len(node_ids)} nodes...")
    for node_id in tqdm(node_ids, desc="Preparing"):
        text = text_by_id[node_id]
        if text:  # Only embed nodes that actually have some text
            texts.append(text)
            valid_ids.append(node_id)

    print(f"Computing embeddings for {len(texts)} nodes...")

    # Compute embeddings in batches
    model.eval()
    with torch.no_grad():
        for start in tqdm(range(0, len(texts), batch_size), desc="Computing embeddings"):
            batch_texts = texts[start:start + batch_size]
            batch_ids = valid_ids[start:start + batch_size]

            inputs = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(device)

            outputs = model(**inputs)
            # Use the CLS (first) token as the document embedding
            embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

            embeddings_dict.update(zip(batch_ids, embeddings))

    return embeddings_dict

# Embed every node with the untouched (pre-fine-tuning) MedCPT weights.
print("\nComputing embeddings with original MedCPT...")
original_embeddings = compute_all_node_embeddings(model, tokenizer, metadata_df, device, batch_size=32)

n_original_embedded = len(original_embeddings)
print(f"\n✓ Computed embeddings for {n_original_embedded} nodes")

In [None]:
# ============================================================================
# STEP 2: CREATE TEST SET (CLUSTER-BASED SPLIT)
# ============================================================================

# Import split utilities (project-local module)
from split_utils import create_cluster_based_split

# Cluster-based train/test split: the splitter returns train/test *cluster* ids, so
# whole clusters are held out. test_ratio is presumably the held-out fraction of
# clusters — confirm in split_utils.
SPLIT_TEST_RATIO = 0.1
SPLIT_SEED = 42
train_cluster_ids, test_cluster_ids, train_node_ids, test_node_ids = \
    create_cluster_based_split(cluster_df, test_ratio=SPLIT_TEST_RATIO, seed=SPLIT_SEED)

# Back the "no data leakage" claim with cheap invariants instead of trusting it.
assert set(train_cluster_ids).isdisjoint(test_cluster_ids), "train/test clusters overlap"
assert set(train_node_ids).isdisjoint(test_node_ids), "train/test nodes overlap"
assert len(test_node_ids) > 0, "test split is empty"

print(f"\n✓ Test set: {len(test_node_ids)} nodes from {len(test_cluster_ids)} clusters")

In [None]:
# ============================================================================
# STEP 3: RUN BASELINE LINK PREDICTION EVALUATION
# ============================================================================

edgelist_path = DATA_DIR / "network" / "oc_mini_edgelist.csv"

# Keep these settings identical to the fine-tuned evaluation so the comparison is fair.
baseline_eval_kwargs = dict(
    edgelist_path=str(edgelist_path),
    embeddings_dict=original_embeddings,
    test_nodes=test_node_ids,  # score only the held-out test-cluster nodes
    k_values=[5, 10, 20, 50, 100],
    compute_auc=True,
    num_negative_samples=10,
)
baseline_results = evaluate_network_link_prediction(**baseline_eval_kwargs)

## D-CAT Fine-Tuning

### Part 1: Disjoint Cluster Dataset

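Build a triplet dataset from the training clusters only. Each anchor and positive come from the same cluster and each negative from a different one, so no test-cluster node is seen during fine-tuning.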

In [None]:
# ============================================================================
# PART 1: DISJOINT CLUSTER TRIPLET DATASET
# ============================================================================

from typing import List, Tuple, Dict, Optional

class DisjointClusterTripletDataset(Dataset):
    """Triplet dataset built from disjoint (one-cluster-per-node) cluster assignments.

    Each triplet is ``(anchor_text, positive_text, negative_text)``. The anchor and
    positive are two distinct nodes of the same cluster, and the negative is a node
    drawn from a different cluster. Clusters with fewer than two nodes are dropped
    because they cannot supply an anchor/positive pair, so they are never used as
    negatives either. Triplets are sampled once, at construction time;
    ``__getitem__`` only tokenizes.

    Args:
        cluster_df: DataFrame with ``node`` and ``cluster`` columns.
        metadata_df: DataFrame with ``id``, ``title`` and ``abstract`` columns.
        tokenizer: HuggingFace-style tokenizer, called as ``tokenizer(text, ...)``.
        max_length: Token length every text is padded/truncated to.
        samples_per_cluster: Number of triplets sampled per eligible cluster.
        train_clusters: If given, only these cluster IDs are used, which keeps
            test clusters out of training.
        seed: Optional seed for a private RNG. ``None`` (the default) draws from
            NumPy's global RNG, so ``np.random.seed`` still controls sampling.

    Raises:
        ValueError: if triplets are requested but fewer than two eligible
            clusters remain, since no negative could be drawn.
    """

    def __init__(
        self,
        cluster_df: pd.DataFrame,
        metadata_df: pd.DataFrame,
        tokenizer,
        max_length: int = 512,
        samples_per_cluster: int = 5,
        train_clusters: Optional[List[int]] = None,
        seed: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Keep only the first row per id so `.loc[id]` always yields a single row.
        # Duplicated ids would otherwise return a DataFrame, make `pd.notna` ambiguous
        # and silently fall back to the placeholder text.
        self.metadata_df = metadata_df.drop_duplicates(subset='id', keep='first').set_index('id')
        # Private RNG when seeded, otherwise NumPy's global RNG (previous behaviour).
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

        # Filter to train clusters if specified
        if train_clusters is not None:
            self.cluster_df = cluster_df[cluster_df['cluster'].isin(train_clusters)]
        else:
            self.cluster_df = cluster_df

        # Keep only clusters that can provide an (anchor, positive) pair.
        self.cluster_to_nodes = {
            cid: nodes for cid, nodes in self._build_cluster_mapping().items()
            if len(nodes) >= 2
        }
        self.cluster_ids = list(self.cluster_to_nodes.keys())

        print(f"Generating triplets from {len(self.cluster_ids)} clusters...")
        if samples_per_cluster > 0 and len(self.cluster_ids) < 2:
            raise ValueError(
                "Need at least 2 clusters with >= 2 nodes to mine triplets, "
                f"found {len(self.cluster_ids)}."
            )
        self.triplets = self._generate_triplets(samples_per_cluster)
        print(f"Generated {len(self.triplets)} triplets")

    def _build_cluster_mapping(self) -> Dict[int, List[str]]:
        """Map each cluster ID to its node IDs (as strings), in first-seen order."""
        # One groupby pass instead of re-filtering the frame once per cluster.
        return {
            cluster_id: group.astype(str).tolist()
            for cluster_id, group in self.cluster_df.groupby('cluster', sort=False)['node']
        }

    def _get_text(self, node_id: str) -> str:
        """Return ``"title abstract"`` for ``node_id``, with missing fields dropped.

        Falls back to the placeholder ``"Document <id>"`` when the id is not an
        integer string or is absent from the metadata.
        """
        try:
            row = self.metadata_df.loc[int(node_id)]
        except (KeyError, ValueError):
            # NOTE(review): the placeholder is still used as training text;
            # consider dropping such nodes instead.
            return f"Document {node_id}"
        title = str(row['title']) if pd.notna(row['title']) else ""
        abstract = str(row['abstract']) if pd.notna(row['abstract']) else ""
        return f"{title} {abstract}".strip()

    def _generate_triplets(self, samples_per_cluster: int) -> List[Tuple[str, str, str]]:
        """Sample ``samples_per_cluster`` (anchor, positive, negative) text triplets per cluster."""
        triplets = []
        n_clusters = len(self.cluster_ids)

        for own_idx, cluster_id in enumerate(tqdm(self.cluster_ids, desc="Mining triplets")):
            cluster_nodes = self.cluster_to_nodes[cluster_id]

            for _ in range(samples_per_cluster):
                # Two distinct nodes from the same cluster.
                anchor_node, positive_node = self._rng.choice(
                    cluster_nodes, size=2, replace=False
                )

                # Uniform draw over the *other* clusters in O(1): pick one of the
                # n_clusters - 1 remaining slots and skip over this cluster's index.
                neg_idx = self._rng.randint(n_clusters - 1)
                if neg_idx >= own_idx:
                    neg_idx += 1
                negative_node = self._rng.choice(
                    self.cluster_to_nodes[self.cluster_ids[neg_idx]]
                )

                triplets.append((
                    self._get_text(anchor_node),
                    self._get_text(positive_node),
                    self._get_text(negative_node),
                ))

        return triplets

    def __len__(self):
        return len(self.triplets)

    def _encode(self, text: str):
        """Tokenize one text, padded/truncated to ``max_length``; return 1-D ids and mask."""
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        anchor, positive, negative = self.triplets[idx]

        item = {}
        for role, text in (('anchor', anchor), ('positive', positive), ('negative', negative)):
            input_ids, attention_mask = self._encode(text)
            item[f'{role}_input_ids'] = input_ids
            item[f'{role}_attention_mask'] = attention_mask
        return item

### Part 2: Training Functions

In [None]:
# ============================================================================
# PART 2: TRAINING FUNCTIONS
# ============================================================================

def mean_pooling(model_output, attention_mask):
    """Average token states over the positions where ``attention_mask`` is set.

    ``model_output[0]`` is taken as the token-level hidden states. The mask is
    broadcast across the hidden dimension, and the per-row token count is
    clamped to 1e-9, so an all-zero mask yields zeros instead of NaN.
    """
    hidden = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def encode_batch(model, input_ids, attention_mask, pooling='cls'):
    """Run ``model`` on a tokenized batch and return L2-normalised embeddings.

    With ``pooling='mean'`` token states are averaged via ``mean_pooling``. Any
    other value uses the hidden state of the first ([CLS]) token.
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    pooled = (
        mean_pooling(outputs, attention_mask)
        if pooling == 'mean'
        else outputs.last_hidden_state[:, 0, :]
    )
    return torch.nn.functional.normalize(pooled, p=2, dim=1)


def train_epoch(model, dataloader, optimizer, device, margin=1.0, pooling='cls'):
    """Run one optimisation pass over ``dataloader`` with a triplet margin loss.

    Each batch must be a dict with ``anchor_``/``positive_``/``negative_``
    ``input_ids`` and ``attention_mask`` tensors.

    Args:
        model: Encoder passed to ``encode_batch``; switched to train mode here.
        dataloader: Iterable of triplet batches.
        optimizer: Optimiser stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.
        margin: Margin of ``TripletMarginLoss``.
        pooling: Pooling strategy forwarded to ``encode_batch``.

    Returns:
        Mean per-triplet loss over the epoch, or ``nan`` if ``dataloader``
        yields no batches. Batch losses are weighted by batch size, so a short
        final batch is not over-weighted.
    """
    model.train()
    triplet_loss_fn = TripletMarginLoss(margin=margin)
    total_loss = 0.0
    n_triplets = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        # Move to device
        anchor_ids = batch['anchor_input_ids'].to(device)
        anchor_mask = batch['anchor_attention_mask'].to(device)
        pos_ids = batch['positive_input_ids'].to(device)
        pos_mask = batch['positive_attention_mask'].to(device)
        neg_ids = batch['negative_input_ids'].to(device)
        neg_mask = batch['negative_attention_mask'].to(device)

        # Forward pass
        anchor_emb = encode_batch(model, anchor_ids, anchor_mask, pooling)
        pos_emb = encode_batch(model, pos_ids, pos_mask, pooling)
        neg_emb = encode_batch(model, neg_ids, neg_mask, pooling)

        # Compute loss (default reduction='mean' -> mean over this batch)
        loss = triplet_loss_fn(anchor_emb, pos_emb, neg_emb)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = anchor_ids.size(0)
        total_loss += loss.item() * batch_size
        n_triplets += batch_size
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    if n_triplets == 0:
        # Empty loader: report NaN rather than raising ZeroDivisionError.
        return float('nan')
    return total_loss / n_triplets


def evaluate(model, dataloader, device, margin=1.0, pooling='cls'):
    """Compute the triplet margin loss of ``model`` over ``dataloader`` without training.

    Args:
        model: Encoder passed to ``encode_batch``; switched to eval mode here.
        dataloader: Iterable of triplet batch dicts (same keys as in training).
        device: Device the batch tensors are moved to.
        margin: Margin of ``TripletMarginLoss``.
        pooling: Pooling strategy forwarded to ``encode_batch``.

    Returns:
        Mean per-triplet loss, or ``nan`` if ``dataloader`` yields no batches.
        Batch losses are weighted by batch size.
    """
    model.eval()
    triplet_loss_fn = TripletMarginLoss(margin=margin)
    total_loss = 0.0
    n_triplets = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            anchor_ids = batch['anchor_input_ids'].to(device)
            anchor_mask = batch['anchor_attention_mask'].to(device)
            pos_ids = batch['positive_input_ids'].to(device)
            pos_mask = batch['positive_attention_mask'].to(device)
            neg_ids = batch['negative_input_ids'].to(device)
            neg_mask = batch['negative_attention_mask'].to(device)

            anchor_emb = encode_batch(model, anchor_ids, anchor_mask, pooling)
            pos_emb = encode_batch(model, pos_ids, pos_mask, pooling)
            neg_emb = encode_batch(model, neg_ids, neg_mask, pooling)

            loss = triplet_loss_fn(anchor_emb, pos_emb, neg_emb)
            batch_size = anchor_ids.size(0)
            total_loss += loss.item() * batch_size
            n_triplets += batch_size

    if n_triplets == 0:
        # Empty loader (e.g. empty validation split): NaN instead of ZeroDivisionError.
        return float('nan')
    return total_loss / n_triplets

### Part 3: Training Loop

In [None]:
# ============================================================================
# PART 3: MAIN TRAINING LOOP
# ============================================================================

def train_disjoint_triplet_loss(
    cluster_df,
    metadata_df,
    train_cluster_ids,
    model,
    tokenizer,
    device,
    batch_size=16,
    epochs=3,
    lr=2e-5,
    margin=1.0,
    samples_per_cluster=5,
    pooling='cls',
    train_split=0.9,
    max_length=512,
    seed=None,
    restore_best=True,
):
    """
    Fine-tune ``model`` with a triplet margin loss on disjoint-cluster triplets.

    Args:
        cluster_df: DataFrame with node and cluster columns
        metadata_df: DataFrame with id, title, abstract
        train_cluster_ids: List of cluster IDs for training (NO TEST LEAKAGE!)
        model: Pretrained transformer model (updated in place)
        tokenizer: Corresponding tokenizer
        device: torch device
        batch_size: Training batch size
        epochs: Number of epochs
        lr: Learning rate
        margin: Triplet loss margin
        samples_per_cluster: Triplets to generate per cluster
        pooling: 'cls' or 'mean'
        train_split: Train/validation split ratio (over triplets)
        max_length: Token length triplet texts are padded/truncated to
        seed: Optional seed for the train/val ``random_split``; ``None`` uses
            torch's global RNG
        restore_best: If True, reload the weights of the epoch with the lowest
            validation loss before returning; otherwise the last epoch's weights
            are kept

    Returns:
        ``(model, history)``. ``model`` is the same object that was passed in,
        fine-tuned in place. ``history`` holds per-epoch ``'train_loss'`` and
        ``'val_loss'`` lists; ``val_loss`` is NaN when the validation split is empty.

    Raises:
        ValueError: if the split leaves no training triplets.
    """
    # Create dataset with ONLY train clusters
    print("Creating dataset from training clusters...")
    dataset = DisjointClusterTripletDataset(
        cluster_df=cluster_df,
        metadata_df=metadata_df,
        tokenizer=tokenizer,
        max_length=max_length,
        samples_per_cluster=samples_per_cluster,
        train_clusters=train_cluster_ids  # CRITICAL: Only train clusters!
    )

    # Split train/val
    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        raise ValueError(
            f"No training triplets: dataset has {len(dataset)} triplets "
            f"and train_split={train_split}."
        )
    # A dedicated generator makes the split reproducible regardless of how much of
    # torch's global RNG earlier cells consumed; without a seed, keep the default.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )

    print(f"Train: {len(train_dataset)} triplets, Val: {len(val_dataset)} triplets")

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=lr)

    # Training loop
    print(f"\nStarting training for {epochs} epochs...")
    best_val_loss = float('inf')
    best_epoch = None
    best_state = None
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch + 1}/{epochs}")
        print('='*60)

        train_loss = train_epoch(
            model, train_loader, optimizer, device, margin=margin, pooling=pooling
        )
        print(f"Train loss: {train_loss:.4f}")

        if val_size > 0:
            val_loss = evaluate(model, val_loader, device, margin=margin, pooling=pooling)
        else:
            # Nothing to validate on (e.g. train_split=1.0); avoid averaging an empty loader.
            val_loss = float('nan')
        print(f"Val loss: {val_loss:.4f}")

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch + 1
            if restore_best:
                # Snapshot on CPU so keeping the best weights does not double
                # accelerator memory use.
                best_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
            print("✓ New best validation loss!")

    # Previously the "best" epoch was only reported; actually return its weights.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights from best epoch {best_epoch} (val loss {best_val_loss:.4f}).")

    print(f"\n{'='*60}")
    print(f"Training complete! Best val loss: {best_val_loss:.4f}")
    print('='*60)

    return model, history

In [None]:
# ============================================================================
# RUN TRAINING (ONLY ON TRAIN CLUSTERS!)
# ============================================================================

# Train the model - CRITICAL: Only use train_cluster_ids! Test clusters stay unseen.
# NOTE: training updates `model` in place and returns that same object, so afterwards
# `finetuned_model` and `model` both hold the fine-tuned weights. Re-running the
# baseline cells after this point would NOT reflect the original MedCPT.
finetuned_model, history = train_disjoint_triplet_loss(
    cluster_df=cluster_df,
    metadata_df=metadata_df,
    train_cluster_ids=train_cluster_ids,  # Only train clusters!
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=16,
    epochs=3,
    lr=2e-5,
    margin=1.0,
    samples_per_cluster=5,
    pooling='cls',
    train_split=0.9
)

# Plot training history against 1-based epoch numbers, matching the "Epoch k/N" logs.
epoch_numbers = list(range(1, len(history['train_loss']) + 1))
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_numbers, history['train_loss'], marker='o', label='Train Loss', linewidth=2)
ax.plot(epoch_numbers, history['val_loss'], marker='s', label='Val Loss', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Disjoint Cluster Triplet Loss Training', fontsize=14, fontweight='bold')
ax.set_xticks(epoch_numbers)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.show()

## Evaluation

How well do the D-CAT fine-tuned embeddings perform on the same held-out link-prediction test set used for the baseline?

In [None]:
# Re-embed every node with the fine-tuned weights, using the same routine as the baseline.
print("\nComputing embeddings with fine-tuned model...")
finetuned_embeddings = compute_all_node_embeddings(
    finetuned_model, tokenizer, metadata_df, device, batch_size=32
)

n_finetuned_embedded = len(finetuned_embeddings)
print(f"✓ Computed embeddings for {n_finetuned_embedded} nodes")

In [None]:
# ============================================================================
# EVALUATE FINE-TUNED MODEL ON TEST SET
# ============================================================================

# Score link prediction on the SAME held-out test nodes and settings as the
# baseline, so the two result dicts are directly comparable.
finetuned_eval_kwargs = {
    'edgelist_path': str(edgelist_path),
    'embeddings_dict': finetuned_embeddings,
    'test_nodes': test_node_ids,
    'k_values': [5, 10, 20, 50, 100],
    'compute_auc': True,
    'num_negative_samples': 10,
}
finetuned_results = evaluate_network_link_prediction(**finetuned_eval_kwargs)

In [None]:
# ============================================================================
# COMPARE BASELINE VS FINE-TUNED
# ============================================================================

print("\n" + "="*80)
print("PERFORMANCE COMPARISON: BASELINE VS FINE-TUNED")
print("="*80)


def _relative_change_pct(new, old):
    """Percent change from ``old`` to ``new``, or None when ``old`` is 0 (undefined)."""
    if old == 0:
        return None
    return (new - old) / old * 100


def _format_change(pct):
    """Render a percent change as '+x.xx%', or 'n/a' when it is undefined."""
    return "n/a" if pct is None else f"{pct:+.2f}%"


# Create comparison table
comparison_data = {
    'K': [],
    'Baseline Precision@K': [],
    'Fine-tuned Precision@K': [],
    'Improvement': []
}

for k in [5, 10, 20, 50, 100]:
    bl_prec = baseline_results['topk']['summary'][k]['precision@k']
    ft_prec = finetuned_results['topk']['summary'][k]['precision@k']
    # With a zero baseline the relative change is undefined: report "n/a"
    # instead of a misleading "+0.00%" when the fine-tuned model improves.
    improvement = _relative_change_pct(ft_prec, bl_prec)

    comparison_data['K'].append(k)
    comparison_data['Baseline Precision@K'].append(f"{bl_prec:.4f}")
    comparison_data['Fine-tuned Precision@K'].append(f"{ft_prec:.4f}")
    comparison_data['Improvement'].append(_format_change(improvement))

comparison_df = pd.DataFrame(comparison_data)
print("\n" + comparison_df.to_string(index=False))

# AUC comparison
if 'auc' in baseline_results and 'auc' in finetuned_results:
    bl_auc = baseline_results['auc']['auc_roc']
    ft_auc = finetuned_results['auc']['auc_roc']
    auc_improvement = _relative_change_pct(ft_auc, bl_auc)

    print("\nAUC-ROC:")
    print(f"  Baseline: {bl_auc:.4f}")
    print(f"  Fine-tuned: {ft_auc:.4f}")
    print(f"  Improvement: {_format_change(auc_improvement)}")

print("="*80)

In [None]:
# Plot comparison: absolute Precision@K (left) and fine-tuned minus baseline (right).
k_values = [5, 10, 20, 50, 100]


def _precision_at_k(results):
    """Pull precision@k for every K in ``k_values`` from a link-prediction result dict."""
    return [results['topk']['summary'][k]['precision@k'] for k in k_values]


bl_precs = _precision_at_k(baseline_results)
ft_precs = _precision_at_k(finetuned_results)
improvements = [ft - bl for bl, ft in zip(bl_precs, ft_precs)]
colors = ['#2ecc71' if delta > 0 else '#e74c3c' for delta in improvements]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: Precision@K curves
for series, marker, label, color in (
    (bl_precs, 'o', 'Baseline', '#e74c3c'),
    (ft_precs, 's', 'Fine-tuned', '#2ecc71'),
):
    ax1.plot(k_values, series, marker=marker, linewidth=2, label=label, color=color)
ax1.set_xlabel('K', fontsize=12)
ax1.set_ylabel('Precision@K', fontsize=12)
ax1.set_title('Link Prediction Performance', fontsize=13, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Right panel: per-K delta bars, green for gains and red otherwise
bar_positions = range(len(k_values))
ax2.bar(bar_positions, improvements, color=colors, alpha=0.7)
ax2.set_xlabel('K', fontsize=12)
ax2.set_ylabel('Δ Precision@K', fontsize=12)
ax2.set_title('Performance Improvement', fontsize=13, fontweight='bold')
ax2.set_xticks(bar_positions)
ax2.set_xticklabels(k_values)
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.3)
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()