# Task 3: Citation Span Extraction with BERT-QA

**Model:** bert-base-uncased (Question Answering - Baseline)

**Task:** Extract the text span that a citation supports

**Features:**
- ✅ Memory efficient (streaming data)
- ✅ Auto resume from checkpoint
- ✅ Works on Colab Free

---

In [None]:
# Mount Google Drive
# Mounts the user's Drive at /content/drive. Later cells read the zipped
# train/val data from that mount and write checkpoints and the final model
# back to it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install dependencies
# IPython shell escape (the leading "!" is not plain Python). Later cells
# import from `transformers` and `datasets`; `accelerate` is presumably
# needed by the Trainer -- confirm. `-q` asks pip for quiet output.
!pip install transformers datasets accelerate -q

In [None]:
# Unzip data
import os
import zipfile

# Local working directory for the extracted splits (no-op if it already exists).
os.makedirs('/content/data/task3', exist_ok=True)

# Train split: extract the archive from Drive, then count the extracted inputs.
print("Unzipping train data...")
with zipfile.ZipFile('/content/drive/MyDrive/THESIS/data/task3/train.zip', 'r') as train_zip:
    train_zip.extractall('/content/data/task3/')
train_count = sum(
    1 for entry in os.listdir('/content/data/task3/train') if entry.endswith('.in')
)
print(f"âœ… Train: {train_count} files")

# Validation split: same procedure.
print("Unzipping val data...")
with zipfile.ZipFile('/content/drive/MyDrive/THESIS/data/task3/val.zip', 'r') as val_zip:
    val_zip.extractall('/content/data/task3/')
val_count = sum(
    1 for entry in os.listdir('/content/data/task3/val') if entry.endswith('.in')
)
print(f"âœ… Val: {val_count} files")

In [None]:
# Load data (memory efficient)
import json
from pathlib import Path
from datasets import Dataset

def generate_task3_examples(data_dir):
    """Yield one extractive-QA example per labelled citation span.

    Every ``*.in`` file in *data_dir* is visited in sorted order and loaded as
    JSON together with the ``.label`` file that shares its stem. Each entry of
    the label's ``citation_spans`` list becomes one example whose context is
    the document's ``text``.

    Args:
        data_dir: Directory holding paired ``<name>.in`` / ``<name>.label``
            JSON files.

    Yields:
        dict with the keys ``question``, ``context``, ``answer``,
        ``start_char`` and ``end_char``. When the label does not carry usable
        offsets they are derived from the first occurrence of the span text in
        the context; spans that cannot be located are skipped.

    Raises:
        FileNotFoundError: If an ``.in`` file has no matching ``.label`` file.
        KeyError: If ``text``, ``citation`` or ``span_text`` is missing.
    """
    data_path = Path(data_dir)

    for in_file in sorted(data_path.glob("*.in")):
        # Read as UTF-8 explicitly so the character offsets do not depend on
        # the platform's default encoding.
        with open(in_file, encoding="utf-8") as f:
            in_data = json.load(f)

        with open(in_file.with_suffix('.label'), encoding="utf-8") as f:
            label_data = json.load(f)

        text = in_data['text']
        # `or []` also covers an explicit JSON null for the span list.
        citation_spans = label_data.get('citation_spans') or []

        for span_info in citation_spans:
            citation = span_info['citation']
            span_text = span_info['span_text']
            start_char = span_info.get('start_char')
            end_char = span_info.get('end_char')

            # Absent, null and -1 offsets all mean "unknown": fall back to
            # locating the span text inside the document.
            offsets_missing = (
                start_char is None or end_char is None
                or start_char == -1 or end_char == -1
            )
            if offsets_missing:
                start_char = text.find(span_text)
                if start_char == -1:
                    # Span text does not occur in the context; nothing to learn from.
                    continue
                end_char = start_char + len(span_text)

            yield {
                'question': f"What does citation {citation} support?",
                'context': text,
                'answer': span_text,
                'start_char': start_char,
                'end_char': end_char
            }

def create_dataset(data_dir):
    """Collect every example generated for *data_dir* into a ``Dataset``.

    The examples produced by ``generate_task3_examples`` are gathered column by
    column and handed to ``Dataset.from_dict``. The resulting dataset has the
    columns ``question``, ``context``, ``answer``, ``start_char`` and
    ``end_char``, in that order.
    """
    column_names = ('question', 'context', 'answer', 'start_char', 'end_char')
    columns = {name: [] for name in column_names}

    for example in generate_task3_examples(data_dir):
        for name in column_names:
            columns[name].append(example[name])

    return Dataset.from_dict(columns)

# Build both splits. Each create_dataset call walks the extracted directory
# and materialises all of its examples before the count is printed.
print("Loading train dataset...")
train_dataset = create_dataset('/content/data/task3/train')
print(f"âœ… Train: {len(train_dataset):,} examples")

print("Loading val dataset...")
val_dataset = create_dataset('/content/data/task3/val')
print(f"âœ… Val: {len(val_dataset):,} examples")

In [None]:
# Tokenize for QA
from transformers import AutoTokenizer

# Notebook-global tokenizer: prepare_train_features reads it as a global, and
# later cells pass it to the Trainer and save it next to the final model.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def prepare_train_features(examples):
    """Tokenize question/context pairs and attach answer token positions.

    Each question is paired with its context, padded to 512 tokens, and only
    the context is truncated when the pair is too long. The character offsets
    of the answer are then mapped to token indices.

    Args:
        examples: A batch with parallel lists under the keys ``question``,
            ``context``, ``start_char`` and ``end_char``.

    Returns:
        The tokenizer output (including ``offset_mapping``) extended with
        ``start_positions`` and ``end_positions``, one entry per example.
        When the answer cannot be located completely inside the tokenized
        context (for instance because truncation cut it off) both positions
        are 0, the first token of the input.
    """
    tokenized = tokenizer(
        examples['question'],
        examples['context'],
        max_length=512,
        truncation='only_second',
        padding='max_length',
        return_offsets_mapping=True
    )

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(tokenized['offset_mapping']):
        start_char = examples['start_char'][i]
        end_char = examples['end_char'][i]

        # Offsets are relative to the string each token came from, so question
        # tokens and context tokens cover overlapping numeric ranges. Only
        # tokens with sequence id 1 (the second sequence, i.e. the context)
        # may be matched against the answer's character span.
        sequence_ids = tokenized.sequence_ids(i)

        start_token = None
        end_token = None
        for idx, (offset_start, offset_end) in enumerate(offsets):
            if sequence_ids[idx] != 1:
                continue
            if start_token is None and offset_start <= start_char < offset_end:
                start_token = idx
            if offset_start < end_char <= offset_end:
                end_token = idx
                break

        # A half-found or inverted span would give the model an inconsistent
        # target; label such examples with position 0 for both ends instead.
        if start_token is None or end_token is None or end_token < start_token:
            start_token = 0
            end_token = 0

        start_positions.append(start_token)
        end_positions.append(end_token)

    tokenized['start_positions'] = start_positions
    tokenized['end_positions'] = end_positions

    return tokenized

# Replace the raw text columns with model inputs. The original columns are
# dropped inside map() via remove_columns, so only what prepare_train_features
# returns is kept.
print("Tokenizing datasets...")
train_dataset = train_dataset.map(prepare_train_features, batched=True, batch_size=1000, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(prepare_train_features, batched=True, batch_size=1000, remove_columns=val_dataset.column_names)

# offset_mapping was only needed to compute the answer token positions.
train_dataset = train_dataset.remove_columns(['offset_mapping'])
val_dataset = val_dataset.remove_columns(['offset_mapping'])

print("âœ… Tokenization complete!")

In [None]:
# Load model
from transformers import AutoModelForQuestionAnswering

# Load the same checkpoint name the tokenizer was created from.
# NOTE(review): bert-base-uncased is presumably a pretraining checkpoint, so
# the question-answering head would start untrained -- confirm.
model = AutoModelForQuestionAnswering.from_pretrained(
    'bert-base-uncased'
)

print(f"âœ… BERT-QA loaded: {model.num_parameters():,} parameters")

In [None]:
# Training setup
from transformers import TrainingArguments, Trainer
from pathlib import Path

training_args = TrainingArguments(
    # Checkpoints go to Drive; the resume logic after this call looks for
    # checkpoint-* entries in this same directory.
    output_dir='/content/drive/MyDrive/THESIS/checkpoints/task3_bert',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=500,
    # Evaluation and checkpointing share the same 500-step cadence.
    eval_strategy='steps',
    eval_steps=500,
    logging_steps=100,
    save_strategy='steps',
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    # NOTE(review): mixed precision presumably needs a GPU runtime -- confirm.
    fp16=True,
    report_to='none',
    seed=42
)

# Check for checkpoint
checkpoint_dir = Path(training_args.output_dir)


def _checkpoint_step(path):
    """Return the step number in a ``checkpoint-<step>`` name, or -1 if absent."""
    suffix = path.name.rpartition('-')[2]
    return int(suffix) if suffix.isdecimal() else -1


# Order by training step, not by name: a lexicographic sort ranks
# "checkpoint-500" after "checkpoint-1000", which would resume from a stale
# checkpoint. Entries that are not directories or carry no numeric step are
# ignored. A missing output directory simply yields no checkpoints.
checkpoints = sorted(
    (
        candidate
        for candidate in checkpoint_dir.glob('checkpoint-*')
        if candidate.is_dir() and _checkpoint_step(candidate) >= 0
    ),
    key=_checkpoint_step,
)
resume_checkpoint = str(checkpoints[-1]) if checkpoints else None

if resume_checkpoint:
    print(f"ðŸ”„ Resuming from: {checkpoints[-1].name}")
else:
    print("ðŸ†• Starting fresh training")

# Wire the model, the arguments and the tokenized splits together.
# NOTE(review): newer transformers releases may expect `processing_class`
# instead of `tokenizer` -- confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer
)

In [None]:
# Train
print("="*60)
print("ðŸš€ STARTING TRAINING - BERT-QA")
print("="*60)

# resume_checkpoint is either a checkpoint directory path or None; None is
# the case in which the setup cell printed "Starting fresh training".
trainer.train(resume_from_checkpoint=resume_checkpoint)

print("\nâœ… Training complete!")

In [None]:
# Evaluate
# Runs evaluation with the eval_dataset given to the Trainer (the validation
# split) and prints each reported metric with four decimals. The format spec
# assumes every value is numeric.
print("ðŸ“Š VALIDATION RESULTS")
eval_results = trainer.evaluate()
for key, value in eval_results.items():
    print(f"{key}: {value:.4f}")

In [None]:
# Save final model
# Model and tokenizer are written to the same Drive directory, which is
# separate from the checkpoint directory used during training.
final_model_path = '/content/drive/MyDrive/THESIS/models/task3_bert_final'
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"âœ… Model saved to: {final_model_path}")
print("\n" + "="*60)
print("âœ… TASK 3 - BERT-QA COMPLETE!")
print("="*60)