In [1]:
# Cross-Encoder Reranker for Ontology Disambiguation
# Training script for Google Colab

# ============================================================================
# 1. INSTALLATION
# ============================================================================

!pip install -q sentence-transformers accelerate


In [2]:
# ============================================================================
# 2. IMPORTS
# ============================================================================

import json
import torch
from datasets import Dataset
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder import losses
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from typing import Dict, List

# ============================================================================
# 3. LOAD DATA
# ============================================================================

def load_jsonl(file_path: str) -> List[Dict]:
    """Load JSONL file into a list of dictionaries."""
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data

# Load datasets
print("Loading datasets...")
train_data = load_jsonl('/content/train.jsonl')
dev_data = load_jsonl('/content/dev.jsonl')
test_data = load_jsonl('/content/test.jsonl')

print(f"Train samples: {len(train_data)}")
print(f"Dev samples: {len(dev_data)}")
print(f"Test samples: {len(test_data)}")

# Display sample
print("\nSample training example:")
print(json.dumps(train_data[0], indent=2))

# ============================================================================
# 4. PREPARE DATASETS FOR CROSS-ENCODER
# ============================================================================

def prepare_cross_encoder_dataset(data: List[Dict]) -> Dataset:
    """
    Convert raw data to cross-encoder format.

    Each example should have:
    - sentence1: query
    - sentence2: candidate
    - label: binary label (0 or 1)
    """
    prepared_data = {
        'sentence1': [],
        'sentence2': [],
        'label': []
    }

    for item in data:
        prepared_data['sentence1'].append(item['query'])
        prepared_data['sentence2'].append(item['candidate'])
        prepared_data['label'].append(float(item['label']))  # Ensure float for BCE loss

    return Dataset.from_dict(prepared_data)

# Prepare datasets
print("\nPreparing datasets for cross-encoder training...")
train_dataset = prepare_cross_encoder_dataset(train_data)
dev_dataset = prepare_cross_encoder_dataset(dev_data)
test_dataset = prepare_cross_encoder_dataset(test_data)

print(f"Prepared train dataset: {len(train_dataset)} samples")
print(f"Prepared dev dataset: {len(dev_dataset)} samples")
print(f"Prepared test dataset: {len(test_dataset)} samples")

# Display prepared sample
print("\nPrepared sample:")
print(train_dataset[0])

Loading datasets...
Train samples: 2401485
Dev samples: 132647
Test samples: 134290

Sample training example:
{
  "query": "cell_type: cDC1; tissue: tonsil; organism: Homo sapiens",
  "candidate": "label: tonsil germinal center B cell; definition: Any germinal center B cell that is part of a tonsil.",
  "candidate_id": "CL:2000006",
  "correct_id": "CL:0000990",
  "label": 0,
  "retrieval_score": 1.0,
  "retrieval_rank": 0,
  "example_type": "hard_negative"
}

Preparing datasets for cross-encoder training...
Prepared train dataset: 2401485 samples
Prepared dev dataset: 132647 samples
Prepared test dataset: 134290 samples

Prepared sample:
{'sentence1': 'cell_type: cDC1; tissue: tonsil; organism: Homo sapiens', 'sentence2': 'label: tonsil germinal center B cell; definition: Any germinal center B cell that is part of a tonsil.', 'label': 0.0}


In [3]:
# ============================================================================
# 5. INITIALIZE MODEL
# ============================================================================

print("\nInitializing CrossEncoder model...")
model = CrossEncoder(
    "bioformers/bioformer-16L",
    num_labels=1,  # Binary classification (relevance score)
    max_length=512,  # Maximum sequence length
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print(f"Model loaded on device: {model.device}")
print(f"Model max length: {model.max_length}")

# ============================================================================
# 6. SETUP LOSS FUNCTION
# ============================================================================

# Binary Cross Entropy Loss for binary relevance prediction
loss = losses.BinaryCrossEntropyLoss(model)
print("\nUsing BinaryCrossEntropyLoss for training")

# ============================================================================
# 7. SETUP EVALUATOR
# ============================================================================

# Create evaluator for binary classification
evaluator = CEBinaryClassificationEvaluator(
    sentence_pairs=list(zip(dev_dataset['sentence1'], dev_dataset['sentence2'])),
    labels=dev_dataset['label'],
    name='dev'
)

print("Evaluator configured for development set")

# ============================================================================
# 8. CONFIGURE TRAINING ARGUMENTS
# ============================================================================

training_args = CrossEncoderTrainingArguments(
    output_dir='./results',

    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,

    # Evaluation and saving
    eval_strategy='steps',
    eval_steps=500,
    save_strategy='steps',
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='dev_accuracy',

    # Optimization
    fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,

    # Logging
    logging_dir='./logs',
    logging_steps=100,
    logging_first_step=True,
    report_to='none',  # Change to 'wandb' or 'tensorboard' if needed

    # Other settings
    seed=42,
    dataloader_drop_last=False,
)

print("\nTraining arguments configured:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  FP16: {training_args.fp16}")


Initializing CrossEncoder model...


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bioformers/bioformer-16L and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Model loaded on device: cuda:0
Model max length: 512

Using BinaryCrossEntropyLoss for training


  evaluator = CEBinaryClassificationEvaluator(


Evaluator configured for development set

Training arguments configured:
  Epochs: 3
  Batch size: 16
  Learning rate: 2e-05
  FP16: True


In [None]:
# ============================================================================
# 9. INITIALIZE TRAINER
# ============================================================================

trainer = CrossEncoderTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    loss=loss,
    evaluator=evaluator,
)

print("\nTrainer initialized successfully")

# ============================================================================
# 10. TRAIN MODEL
# ============================================================================

print("\n" + "="*80)
print("STARTING TRAINING")
print("="*80 + "\n")

trainer.train()

print("\n" + "="*80)
print("TRAINING COMPLETED")
print("="*80 + "\n")

# ============================================================================
# 11. EVALUATE ON TEST SET
# ============================================================================

print("Evaluating on test set...")

# Create test evaluator
test_evaluator = CEBinaryClassificationEvaluator(
    sentence_pairs=list(zip(test_dataset['sentence1'], test_dataset['sentence2'])),
    labels=test_dataset['label'],
    name='test'
)

# Evaluate
test_results = test_evaluator(model)

print("\nTest Set Results:")
for metric, value in test_results.items():
    print(f"  {metric}: {value:.4f}")

# ============================================================================
# 12. SAVE MODEL
# ============================================================================

output_path = "./ontology-reranker"
model.save_pretrained(output_path)
print(f"\nModel saved to: {output_path}")


Trainer initialized successfully

STARTING TRAINING



Step,Training Loss,Validation Loss,Dev Accuracy,Dev Accuracy Threshold,Dev F1,Dev F1 Threshold,Dev Precision,Dev Recall,Dev Average Precision
500,0.4958,0.4874,0.931186,0.487074,0.129276,0.354086,0.069178,0.98499,0.070679
1000,0.4226,0.404927,0.931186,0.341604,0.139779,0.289565,0.077367,0.723129,0.074146
1500,0.2789,0.256913,0.931193,0.11329,0.13556,0.100265,0.073352,0.892298,0.080669
2000,0.2556,0.250304,0.931208,0.085302,0.162005,0.075175,0.117172,0.262408,0.10788
2500,0.2593,0.247155,0.933425,0.09777,0.215296,0.073282,0.159943,0.329243,0.169606
3000,0.2254,0.244895,0.933508,0.125902,0.259713,0.058864,0.190814,0.406486,0.216926
3500,0.1939,0.230448,0.934224,0.171938,0.315088,0.081824,0.322046,0.308426,0.291652


In [None]:
# ============================================================================
# 13. INFERENCE EXAMPLE
# ============================================================================

print("\n" + "="*80)
print("INFERENCE EXAMPLE")
print("="*80 + "\n")

# Load saved model (optional - for demonstration)
# loaded_model = CrossEncoder(output_path)

# Example inference
example_query = "cell_type: C_BEST4; tissue: descending colon; organism: Homo sapiens"
example_candidates = [
    "label: smooth muscle fiber of descending colon; synonyms: non-striated muscle fiber of descending colon; definition: A smooth muscle cell that is part of the descending colon.",
    "label: smooth muscle cell of colon; synonyms: non-striated muscle fiber of colon; definition: A smooth muscle cell that is part of the colon.",
    "label: epithelial cell of colon; synonyms: colon epithelial cell; definition: An epithelial cell that is part of the colon."
]

print("Query:")
print(f"  {example_query}\n")

print("Ranking candidates...")
# Create pairs
pairs = [(example_query, candidate) for candidate in example_candidates]

# Get predictions
scores = model.predict(pairs)

# Rank by score
ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

print("\nRanked Results:")
for rank, idx in enumerate(ranked_indices, 1):
    print(f"\n{rank}. Score: {scores[idx]:.4f}")
    print(f"   Candidate: {example_candidates[idx][:100]}...")

# ============================================================================
# 14. BATCH RANKING EXAMPLE
# ============================================================================

print("\n" + "="*80)
print("BATCH RANKING WITH model.rank()")
print("="*80 + "\n")

# Using the convenient rank() method
ranked_results = model.rank(
    example_query,
    example_candidates,
    return_documents=True,
    top_k=3
)

print("Top 3 ranked results using model.rank():")
for result in ranked_results:
    print(f"\nRank: {result['corpus_id'] + 1}")
    print(f"Score: {result['score']:.4f}")
    print(f"Text: {result['text'][:100]}...")

print("\n" + "="*80)
print("SCRIPT COMPLETED SUCCESSFULLY")
print("="*80)