# Task 2: Citation-Reference Linking with BERT-base

**Model:** bert-base-uncased (Baseline)

**Task:** Binary classification — given a citation's surrounding text and one bibliography entry, predict whether that entry is the one the citation refers to.

---

In [None]:
# The Kaggle image ships transformers, datasets and accelerate preinstalled,
# so nothing is pip-installed here; this cell only reports their versions.
import transformers
import datasets
import accelerate

for _pkg in (transformers, datasets, accelerate):
    print(f"✅ {_pkg.__name__}: {_pkg.__version__}")

In [None]:
# The Kaggle dataset is mounted already extracted; as a sanity check, count
# the `.in` input files in each split directory.
import os

train_path = '/kaggle/input/thesis-data-task2/train/train'
val_path = '/kaggle/input/thesis-data-task2/val/val'


def _count_inputs(folder):
    """Return how many entries directly inside *folder* end with ``.in``."""
    return sum(1 for name in os.listdir(folder) if name.endswith('.in'))


train_count = _count_inputs(train_path)
val_count = _count_inputs(val_path)

print(f"✅ Train: {train_count} files")
print(f"✅ Val: {val_count} files")

In [None]:
# Load data - LIMITED to 100 files for testing
import json
from pathlib import Path
from datasets import Dataset

def _load_json(path):
    """Read and parse one JSON file, decoding it explicitly as UTF-8."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _bib_text(bib_data):
    """Return the text representing a bib entry: its abstract, else its title, else ''.

    A missing, null or empty abstract falls through to the title, so the
    tokenizer never receives ``None``.
    """
    return bib_data.get('abstract') or bib_data.get('title') or ''


def generate_task2_examples(data_dir, max_files=100, *, context_chars=200, progress_every=20):
    """Yield labelled (citation context, bib entry) pairs for Task 2.

    Every ``*.in`` JSON file in *data_dir* is paired with the sibling
    ``*.label`` file of the same stem. For each ``citation -> bib_id`` mapping
    under ``correct_citation`` in the label file:

    - The first occurrence of the citation string in the document ``text`` is
      located. Citations not found in the text are skipped.
    - A window of ``context_chars`` characters on each side is taken as context.
    - The correct bib entry (if present in ``bib_entries``) yields one example
      with label 1.
    - Every other entry in the same document yields an example with label 0.

    Args:
        data_dir: Directory containing the ``*.in`` / ``*.label`` pairs.
        max_files: Only the first ``max_files`` files in sorted order are used;
            ``None`` processes all of them.
        context_chars: Number of characters kept on each side of the citation.
        progress_every: Print a progress line every this many files (0 disables).

    Yields:
        dicts with keys ``'context'`` (str), ``'bib_entry'`` (str) and
        ``'label'`` (1 for the correct entry, 0 otherwise).
    """
    in_files = sorted(Path(data_dir).glob("*.in"))[:max_files]
    total = len(in_files)
    print(f"📊 Processing {total:,} files (limited)")

    for file_no, in_file in enumerate(in_files, start=1):
        if progress_every and file_no % progress_every == 0:
            print(f"⏳ Progress: {file_no}/{total}")

        in_data = _load_json(in_file)
        label_data = _load_json(in_file.with_suffix('.label'))

        text = in_data.get('text', '')
        bib_entries = in_data.get('bib_entries', {})
        citation_to_bib = label_data.get('correct_citation', {})
        # Resolve each entry's text once per document rather than once per citation.
        bib_texts = {bib_id: _bib_text(bib) for bib_id, bib in bib_entries.items()}

        for citation, correct_bib_id in citation_to_bib.items():
            # Only the first occurrence of the citation marker is used as context.
            citation_pos = text.find(citation)
            if citation_pos == -1:
                continue

            start = max(0, citation_pos - context_chars)
            end = min(len(text), citation_pos + len(citation) + context_chars)
            context = text[start:end]

            if correct_bib_id in bib_texts:
                yield {
                    'context': context,
                    'bib_entry': bib_texts[correct_bib_id],
                    'label': 1,
                }

            # Every other entry of the document becomes a negative, so label-0
            # examples vastly outnumber label-1 examples.
            for bib_id, bib_text in bib_texts.items():
                if bib_id != correct_bib_id:
                    yield {
                        'context': context,
                        'bib_entry': bib_text,
                        'label': 0,
                    }

# One shared file limit, so the printed banners always match the limit that is
# actually passed to the generator.
MAX_FILES = 100  # set to None to process every file in a split


def _build_split(split_name, data_dir, max_files=MAX_FILES):
    """Build one split as a Hugging Face Dataset, printing a banner and its size.

    Args:
        split_name: Display name such as 'Train' (upper-cased in the banner).
        data_dir: Directory forwarded to ``generate_task2_examples``.
        max_files: File limit forwarded to the generator (None = all files).

    Returns:
        The ``Dataset`` produced by ``Dataset.from_generator``.
    """
    limit = "all files" if max_files is None else f"{max_files} files only"
    print("=" * 60)
    print(f"Creating {split_name.upper()} dataset ({limit})...")
    print("=" * 60)
    dataset = Dataset.from_generator(
        generate_task2_examples,
        gen_kwargs={'data_dir': data_dir, 'max_files': max_files},
    )
    print(f"✅ {split_name}: {len(dataset):,} examples")
    return dataset


# Reuse the split paths defined in the data-check cell instead of repeating them.
train_dataset = _build_split('Train', train_path)
print()
val_dataset = _build_split('Val', val_path)

In [None]:
# Tokenize
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples, *, max_length=512, padding='max_length', truncation=True):
    """Tokenize a batch of (citation context, bib entry) pairs with ``tokenizer``.

    The context is passed as the first sequence and the bib entry as the
    second, so both are packed into one input pair.

    Args:
        examples: Batched dataset columns with 'context' and 'bib_entry' lists.
        max_length: Maximum number of tokens per encoded pair.
        padding: Padding strategy forwarded to the tokenizer. 'max_length' pads
            every example to ``max_length``. False leaves padding to the batch
            collator (presumably DataCollatorWithPadding when the Trainer is
            given a tokenizer — confirm for your transformers version).
        truncation: Truncation strategy forwarded to the tokenizer.

    Returns:
        The tokenizer's encoding of the batch.
    """
    return tokenizer(
        examples['context'],
        examples['bib_entry'],
        max_length=max_length,
        padding=padding,
        truncation=truncation,
    )

print("Tokenizing datasets...")
train_dataset = train_dataset.map(tokenize_function, batched=True, batch_size=1000)
val_dataset = val_dataset.map(tokenize_function, batched=True, batch_size=1000)

# BERT uses token_type_ids to tell the first sequence (citation context) from the
# second (bib entry). Omitting them from the torch format drops them from every
# batch, and the model then falls back to all-zero segment ids. Keep every model
# input the tokenizer actually produced.
model_columns = [
    col for col in ('input_ids', 'token_type_ids', 'attention_mask', 'label')
    if col in train_dataset.column_names
]
train_dataset.set_format(type='torch', columns=model_columns)
val_dataset.set_format(type='torch', columns=model_columns)

print("✅ Tokenization complete!")

In [None]:
# Load model
from transformers import AutoModelForSequenceClassification
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Pretrained BERT encoder with a 2-way classification head
# (label 1 = correct bib entry, label 0 = any other entry).
BASE_CHECKPOINT = 'bert-base-uncased'
model = AutoModelForSequenceClassification.from_pretrained(BASE_CHECKPOINT, num_labels=2)

def compute_metrics(eval_pred):
    """Compute accuracy and binary precision/recall/F1 for Trainer evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` from the Trainer. Predictions are
            assumed to be per-class logits of shape (n_examples, 2), or a tuple
            whose first element is those logits.

    Returns:
        dict with 'accuracy', 'precision', 'recall' and 'f1', where the last
        three are computed for the positive class (label 1).
    """
    logits, labels = eval_pred
    # If the model returns extra outputs, predictions arrive as a tuple with the
    # logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)
    accuracy = accuracy_score(labels, predictions)
    # zero_division=0: early in training the model may predict no positives,
    # which leaves precision undefined. Report 0 explicitly instead of warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', zero_division=0
    )
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

param_count = model.num_parameters()
print(f"✅ BERT-base loaded: {param_count:,} parameters")

In [None]:
# Training setup
from transformers import TrainingArguments, Trainer
from pathlib import Path
import os

import torch

# fp16 mixed precision needs a CUDA GPU, and TrainingArguments rejects
# fp16=True on a CPU-only session. Enable it only when a GPU is present.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir='/kaggle/working/checkpoints/task2_bert',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=500,
    # Evaluate and checkpoint on the same 500-step schedule, because
    # load_best_model_at_end needs the save and eval schedules to line up.
    eval_strategy='steps',
    eval_steps=500,
    logging_steps=100,
    save_strategy='steps',
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model='f1',  # key returned by compute_metrics
    fp16=use_fp16,
    report_to='none',
    seed=42,
)

# Locate the most recent checkpoint to resume from, if any.


def _checkpoint_step(path):
    """Return the training step encoded in a ``checkpoint-<step>`` name (-1 if not numeric)."""
    suffix = path.name.rsplit('-', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


def _find_checkpoints(directory):
    """List the ``checkpoint-*`` sub-directories of *directory*, lowest step first."""
    directory = Path(directory)
    if not directory.exists():
        return []
    found = (p for p in directory.glob('checkpoint-*') if p.is_dir())
    # Sort numerically: a plain string sort puts 'checkpoint-1000' before
    # 'checkpoint-500', so the run would resume from an older checkpoint.
    return sorted(found, key=_checkpoint_step)


# Option 1: From previous run in /kaggle/working/
checkpoint_dir = Path(training_args.output_dir)
checkpoints = _find_checkpoints(checkpoint_dir)

# Option 2: From Kaggle dataset (if you uploaded a checkpoint)
# Uncomment this if you added a checkpoint dataset:
# if not checkpoints:
#     checkpoints = _find_checkpoints('/kaggle/input/task2-bert-checkpoint')

resume_checkpoint = str(checkpoints[-1]) if checkpoints else None

if resume_checkpoint:
    print(f"🔄 Resuming from: {Path(resume_checkpoint).name}")
else:
    print("🆕 Starting fresh training")

import inspect

# Newer transformers releases renamed Trainer's `tokenizer` argument to
# `processing_class` and deprecated the old name. Pass the tokenizer under
# whichever name the installed version accepts.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kwarg = 'processing_class' if 'processing_class' in _trainer_params else 'tokenizer'

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

In [None]:
# Train
# Start fine-tuning, or continue it; resume_checkpoint is None for a fresh run.
banner = "=" * 60
for line in (banner, "🚀 STARTING TRAINING - BERT-base", banner):
    print(line)

trainer.train(resume_from_checkpoint=resume_checkpoint)

print("\n✅ Training complete!")

In [None]:
# Evaluate
# Score the current model on the validation split and print every metric.
print("📊 VALIDATION RESULTS")
eval_results = trainer.evaluate()
for metric_name, metric_value in eval_results.items():
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
# Save final model
final_model_path = '/kaggle/working/models/task2_bert_final'
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"✅ Model saved to: {final_model_path}")
print("\n💡 TIP: Click 'Save Version' to commit and save this model permanently!")
print("\n" + "="*60)
print("✅ TASK 2 - BERT COMPLETE!")
print("="*60)