# Task 3: Citation Span Extraction with BERT-QA (WITH POSITIONS + METRICS)

**Model:** bert-base-uncased (Question Answering)

**Task:** Extract text span that citation supports

**KEY IMPROVEMENTS:**
- ‚úÖ Uses pre-computed s_span/e_span positions (NO MORE text.find()!)
- ‚úÖ F1 + Exact Match metrics for proper evaluation
- ‚úÖ Early stopping based on F1 score
- ‚úÖ No data loss from failed text.find()
- ‚úÖ Memory efficient (streaming data)

---

In [1]:
import transformers, datasets, accelerate
print(f"‚úÖ transformers: {transformers.__version__}")
print(f"‚úÖ datasets: {datasets.__version__}")
print(f"‚úÖ accelerate: {accelerate.__version__}")

‚úÖ transformers: 4.57.1
‚úÖ datasets: 4.4.2
‚úÖ accelerate: 1.11.0


In [2]:
import os

train_path = '/kaggle/input/thesis-data-task3-with-positions/data/train'
val_path = '/kaggle/input/thesis-data-task3-with-positions/data/val'

train_count = len([f for f in os.listdir(train_path) if f.endswith('.label')])
val_count = len([f for f in os.listdir(val_path) if f.endswith('.label')])

print(f"‚úÖ Train: {train_count:,} files")
print(f"‚úÖ Val: {val_count:,} files")

‚úÖ Train: 55,556 files
‚úÖ Val: 3,000 files


In [3]:
# Load data with positions
import json
from pathlib import Path
from datasets import IterableDataset

def generate_task3_examples(data_dir):
    data_path = Path(data_dir)
    label_files = sorted(data_path.glob("*.label"))
    total_files = len(label_files)
    print(f"üìä Found {total_files:,} .label files")
    
    skipped = 0
    successful = 0

    for i, label_file in enumerate(label_files):
        if (i+1) % 5000 == 0:
            print(f"‚è≥ {i+1:,}/{total_files:,} | Success: {successful:,} | Skipped: {skipped}")

        try:
            with open(label_file) as f:
                label_data = json.load(f)
        except:
            skipped += 1
            continue

        text = label_data.get('text', '')
        if not text:
            skipped += 1
            continue
            
        citation_spans = label_data.get('citation_spans', [])

        for span_info in citation_spans:
            citation_id = span_info.get('citation_id', '')
            span_text = span_info.get('span_text', '')
            s_span = span_info.get('s_span', -1)
            e_span = span_info.get('e_span', -1)
            
            if s_span == -1 or e_span == -1 or s_span >= e_span:
                skipped += 1
                continue

            question = f"What does citation {citation_id} support?"
            successful += 1
            
            yield {
                'question': question,
                'context': text,
                'answer': span_text,
                'start_char': s_span,
                'end_char': e_span
            }

    print(f"‚úÖ {successful:,} examples | Skipped: {skipped}")

print("=" * 60)
print("Creating datasets...")
train_dataset = IterableDataset.from_generator(
    generate_task3_examples,
    gen_kwargs={'data_dir': '/kaggle/input/thesis-data-task3-with-positions/data/train'}
)
val_dataset = IterableDataset.from_generator(
    generate_task3_examples,
    gen_kwargs={'data_dir': '/kaggle/input/thesis-data-task3-with-positions/data/val'}
)
print("‚úÖ Datasets ready")

Creating datasets...
‚úÖ Datasets ready


In [4]:
# Tokenize
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def prepare_train_features(examples):
    tokenized = tokenizer(
        examples['question'],
        examples['context'],
        max_length=512,
        truncation='only_second',
        padding=False,
        return_offsets_mapping=True
    )

    offset_mapping = tokenized['offset_mapping']
    start_positions = []
    end_positions = []

    for i in range(len(examples['question'])):
        start_char = examples['start_char'][i]
        end_char = examples['end_char'][i]
        offsets = offset_mapping[i]

        start_token = 0
        for idx, (offset_start, offset_end) in enumerate(offsets):
            if offset_start <= start_char < offset_end:
                start_token = idx
                break

        end_token = 0
        for idx, (offset_start, offset_end) in enumerate(offsets):
            if offset_start < end_char <= offset_end:
                end_token = idx
                break

        start_positions.append(start_token)
        end_positions.append(end_token)

    tokenized['start_positions'] = start_positions
    tokenized['end_positions'] = end_positions
    return tokenized

train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=['question', 'context', 'answer', 'start_char', 'end_char'])
val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=['question', 'context', 'answer', 'start_char', 'end_char'])
train_dataset = train_dataset.map(lambda x: {k: v for k, v in x.items() if k != 'offset_mapping'}, batched=True)
val_dataset = val_dataset.map(lambda x: {k: v for k, v in x.items() if k != 'offset_mapping'}, batched=True)
print("‚úÖ Tokenization complete")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

‚úÖ Tokenization complete


In [5]:
# Load model
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')
print(f"‚úÖ BERT loaded: {model.num_parameters():,} parameters")

2026-01-28 13:31:49.831886: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769607110.032273      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769607110.090083      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769607110.583384      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769607110.583431      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769607110.583434      24 computation_placer.cc:177] computation placer alr

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ BERT loaded: 108,893,186 parameters


In [6]:
# Define F1 + EM metrics
import numpy as np

def compute_metrics(pred):
    """
    Compute Exact Match (EM) and Token-level F1 for span extraction.
    
    For QA models:
    - pred.predictions[0]: start_logits
    - pred.predictions[1]: end_logits
    - pred.label_ids: tuple of (start_positions, end_positions)
    """
    start_logits, end_logits = pred.predictions
    start_predictions = np.argmax(start_logits, axis=1)
    end_predictions = np.argmax(end_logits, axis=1)
    
    start_labels = pred.label_ids[0] if isinstance(pred.label_ids, tuple) else pred.label_ids[:, 0]
    end_labels = pred.label_ids[1] if isinstance(pred.label_ids, tuple) else pred.label_ids[:, 1]
    
    exact_match = 0
    f1_total = 0.0
    total = len(start_labels)
    
    for i in range(total):
        pred_start = start_predictions[i]
        pred_end = end_predictions[i]
        true_start = start_labels[i]
        true_end = end_labels[i]
        
        # Exact Match: both start and end must match
        if pred_start == true_start and pred_end == true_end:
            exact_match += 1
            f1_total += 1.0
        else:
            # F1: Calculate token overlap
            if pred_end < pred_start:
                pred_end = pred_start
            
            pred_tokens = set(range(pred_start, pred_end + 1))
            true_tokens = set(range(true_start, true_end + 1))
            
            if len(pred_tokens) > 0 and len(true_tokens) > 0:
                overlap = pred_tokens & true_tokens
                
                if len(overlap) > 0:
                    precision = len(overlap) / len(pred_tokens)
                    recall = len(overlap) / len(true_tokens)
                    f1 = 2 * precision * recall / (precision + recall)
                    f1_total += f1
    
    return {
        'exact_match': exact_match / total,
        'f1': f1_total / total
    }

print("‚úÖ Metrics function defined: EM + F1")

‚úÖ Metrics function defined: EM + F1


In [7]:
# Training setup with metrics
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, EarlyStoppingCallback
from pathlib import Path

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir='/kaggle/working/checkpoints/task3_bert_with_positions',
    max_steps=10000,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=500,
    eval_strategy='steps',
    eval_steps=500,
    logging_steps=100,
    save_strategy='steps',
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    greater_is_better=True,
    fp16=True,
    report_to='none',
    seed=42
)

early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

checkpoint_dir = Path(training_args.output_dir)
checkpoints = sorted(checkpoint_dir.glob('checkpoint-*')) if checkpoint_dir.exists() else []
resume_checkpoint = str(checkpoints[-1]) if checkpoints else None

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping]
)

print(f"üí° Training config:")
print(f"   - Model: BERT")
print(f"   - Positions: s_span/e_span ‚úÖ")
print(f"   - Metrics: F1 + EM ‚úÖ")
print(f"   - Best model selection: F1 score")
print(f"   - Early stopping: patience=3")
print(f"   - Batch size: {8 * 4}")

üí° Training config:
   - Model: BERT
   - Positions: s_span/e_span ‚úÖ
   - Metrics: F1 + EM ‚úÖ
   - Best model selection: F1 score
   - Early stopping: patience=3
   - Batch size: 32


  trainer = Trainer(


In [8]:
# Train
print("="*60)
print("üöÄ TRAINING SCIBERT WITH POSITIONS + METRICS")
print("="*60)
trainer.train(resume_from_checkpoint=resume_checkpoint)
print("\n‚úÖ Training complete!")

üöÄ TRAINING SCIBERT WITH POSITIONS + METRICS
üìä Found 55,556 .label files


Step,Training Loss,Validation Loss,Exact Match,F1
500,0.3551,0.338362,0.887729,0.938047
1000,0.2367,0.219623,0.919746,0.953528
1500,0.1643,0.164035,0.929478,0.96109
2000,0.1809,0.158572,0.937659,0.963579
2500,0.1316,0.163646,0.942595,0.966734
3000,0.1352,0.171162,0.937941,0.962203
3500,0.1263,0.130187,0.94598,0.969106
4000,0.1234,0.108042,0.951199,0.972817
4500,0.0863,0.130404,0.950494,0.972014
5000,0.1105,0.114314,0.939774,0.959034


‚è≥ 5,000/55,556 | Success: 11,808 | Skipped: 1
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 10,000/55,556 | Success: 23,503 | Skipped: 3
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 15,000/55,556 | Success: 35,021 | Skipped: 7
‚è≥ 20,000/55,556 | Success: 46,360 | Skipped: 9
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 25,000/55,556 | Success: 57,827 | Skipped: 10
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 30,000/55,556 | Success: 69,636 | Skipped: 10
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 35,000/55,556 | Success: 81,682 | Skipped: 23
‚è≥ 40,000/55,556 | Success: 93,582 | Skipped: 25
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 45,000/55,556 | Success: 105,296 | Skipped: 31
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
‚è≥ 50,000/55,556 | Success: 116,696 | Skipped: 43
‚è≥ 55,000/55,556 | Success: 128,441 | Skipped: 45
üìä Found 3,000 .

In [9]:
# Evaluate with metrics
print("üìä VALIDATION RESULTS")
eval_results = trainer.evaluate()
print("="*60)
for key, value in eval_results.items():
    print(f"{key}: {value:.4f}")
print("="*60)
print(f"\n‚úÖ F1 Score: {eval_results.get('eval_f1', 0):.2%}")
print(f"‚úÖ Exact Match: {eval_results.get('eval_exact_match', 0):.2%}")

üìä VALIDATION RESULTS
üìä Found 3,000 .label files
‚úÖ 7,090 examples | Skipped: 0
eval_loss: 0.0930
eval_exact_match: 0.9622
eval_f1: 0.9799
eval_runtime: 100.8773
eval_samples_per_second: 70.2830
eval_steps_per_second: 4.4010
epoch: 2.0890

‚úÖ F1 Score: 97.99%
‚úÖ Exact Match: 96.22%


In [10]:
# Save
final_model_path = '/kaggle/working/models/task3_bert_with_positions_final'
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)
print(f"‚úÖ Model saved to: {final_model_path}")

‚úÖ Model saved to: /kaggle/working/models/task3_bert_with_positions_final


In [11]:
# Test
import torch
from transformers import pipeline

qa_pipeline = pipeline(
    'question-answering',
    model=final_model_path,
    tokenizer=final_model_path,
    device=0 if torch.cuda.is_available() else -1
)

result = qa_pipeline(
    question="What does citation [CITATION_1] support?",
    context="Previous studies demonstrated significant improvements in model performance. These findings support our hypothesis."
)

print("\nüìã Test Inference:")
print(f"Answer: {result['answer']}")
print(f"Confidence: {result['score']:.4f}")
print("\n‚úÖ SCIBERT TRAINING COMPLETE (WITH METRICS)!")

Device set to use cuda:0



üìã Test Inference:
Answer: These findings support our hypothesis.
Confidence: 0.0055

‚úÖ SCIBERT TRAINING COMPLETE (WITH METRICS)!
