# Task 3: Citation Span Extraction with BERT-QA + Special Tokens

**Model:** bert-base-uncased (Question Answering)

**Task:** Extract the text span that each citation supports

**NEW: Added citation special tokens** so that citation markers are kept as single tokens instead of being split into word pieces

**Features:**
- ✅ Citation tokens as special tokens (e.g., [CITATION_1], [CITATION_2])
- ✅ Memory efficient (streaming data)
- ✅ Auto resume from checkpoint
- ✅ Works on Kaggle

---

In [None]:
# Environment check: transformers, datasets and accelerate are expected to be
# available already, so nothing is installed here.

# Optional: report the version of each library.
import transformers, datasets, accelerate

for _lib in (transformers, datasets, accelerate):
    print(f"‚úÖ {_lib.__name__}: {_lib.__version__}")

In [None]:
# Sanity check of the input data: count the '.in' example files in each split.
import os

train_path = '/kaggle/input/thesis-data-task3/train/train'
val_path = '/kaggle/input/thesis-data-task3/val/val'

# One per directory entry whose name ends with '.in'.
train_count = sum(1 for entry in os.listdir(train_path) if entry.endswith('.in'))
val_count = sum(1 for entry in os.listdir(val_path) if entry.endswith('.in'))

print(f"‚úÖ Train: {train_count} files")
print(f"‚úÖ Val: {val_count} files")

In [None]:
# Load data - STREAMING (memory efficient)
import json
from pathlib import Path
from datasets import IterableDataset

def generate_task3_examples(data_dir):
    """
    Generator for Task 3 span extraction examples - STREAMING mode.

    Visits every ``*.in`` file in ``data_dir`` in sorted order. Each ``.in``
    file is a JSON object with a ``text`` field; the sibling ``.label`` file
    is a JSON object whose optional ``citation_spans`` list holds one entry
    per citation with ``citation_id``, ``span_text`` and, optionally,
    ``start_char`` / ``end_char``.

    Inputs without a ``.label`` file are skipped without being read. When an
    entry lacks either offset, the span is located with ``str.find`` (first
    occurrence); entries whose text cannot be found are dropped.

    Args:
        data_dir: Directory containing the ``.in`` / ``.label`` file pairs.

    Yields:
        dict: One example per usable citation span with the keys
        ``question``, ``context``, ``answer``, ``start_char`` and ``end_char``.
    """
    data_path = Path(data_dir)
    in_files = sorted(data_path.glob("*.in"))

    total_files = len(in_files)
    print(f"üìä Found {total_files:,} .in files - streaming mode")

    for i, in_file in enumerate(in_files):
        if (i+1) % 5000 == 0:
            print(f"‚è≥ Processed {i+1:,}/{total_files:,} files ({(i+1)*100//total_files}%)")

        # Look for the label first: an unlabeled input can never produce an
        # example, so there is no reason to read or parse it.
        label_file = in_file.with_suffix('.label')
        if not label_file.exists():
            continue

        # JSON is UTF-8; do not rely on the platform's default encoding.
        with open(in_file, encoding='utf-8') as f:
            in_data = json.load(f)

        with open(label_file, encoding='utf-8') as f:
            label_data = json.load(f)

        text = in_data['text']
        citation_spans = label_data.get('citation_spans', [])

        for span_info in citation_spans:
            citation = span_info['citation_id']
            span_text = span_info['span_text']
            start_char = span_info.get('start_char', -1)
            end_char = span_info.get('end_char', -1)

            # Offsets missing from the label: recover them from the text.
            if start_char == -1 or end_char == -1:
                start_char = text.find(span_text)
                if start_char == -1:
                    continue
                end_char = start_char + len(span_text)

            question = f"What does citation {citation} support?"

            yield {
                'question': question,
                'context': text,
                'answer': span_text,
                'start_char': start_char,
                'end_char': end_char
            }

    print(f"‚úÖ Finished processing all {total_files:,} files")

# Build the streaming train / val datasets from the example generator.
_rule = "=" * 60

print(_rule)
print("Creating TRAIN dataset (streaming)...")
print(_rule)
train_dataset = IterableDataset.from_generator(
    generate_task3_examples,
    gen_kwargs=dict(data_dir='/kaggle/input/thesis-data-task3/train/train'),
)
print("‚úÖ Train dataset ready")

print("\n" + _rule)
print("Creating VAL dataset (streaming)...")
print(_rule)
val_dataset = IterableDataset.from_generator(
    generate_task3_examples,
    gen_kwargs=dict(data_dir='/kaggle/input/thesis-data-task3/val/val'),
)
print("‚úÖ Val dataset ready")

print("\nüí° Using IterableDataset for memory-efficient streaming!")

In [None]:
# Load tokenizer and ADD SPECIAL TOKENS for citations
from transformers import AutoTokenizer

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

print(f"Original vocab size: {len(tokenizer)}")

# Register the citation markers [CITATION_1] .. [CITATION_100] with the tokenizer.
print("\nüîß Adding citation special tokens...")
_max_citations = 100  # highest citation index that gets its own token
citation_tokens = [f'[CITATION_{idx}]' for idx in range(1, _max_citations + 1)]
num_added = tokenizer.add_tokens(citation_tokens)

print(f"‚úÖ Added {num_added} citation tokens")
print(f"New vocab size: {len(tokenizer)}")

# Smoke test: tokenize a sentence containing two markers and show the result
# so it can be inspected by eye.
test_text = "This research [CITATION_1] shows that [CITATION_2] improves performance."
test_tokens = tokenizer.tokenize(test_text)
print("\nüìã Test tokenization:")
print(f"Text: {test_text}")
print(f"Tokens: {test_tokens}")
print("\nüí° Citation tokens should appear as single tokens now!")

In [None]:
# Tokenize for QA - DYNAMIC PADDING (memory efficient)
def _locate_answer_tokens(offsets, sequence_ids, start_char, end_char):
    """Map a character span of the context to (start_token, end_token).

    Only tokens whose sequence id is 1 (the second sequence, i.e. the context)
    are considered. Question-token offsets are relative to the question
    string and must not be compared with context character positions.

    Returns (0, 0) when the span is empty or negative, when it is not fully
    inside the (possibly truncated) context window, or when no token overlaps it.
    """
    context_tokens = [idx for idx, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_tokens or start_char < 0 or start_char >= end_char:
        return 0, 0

    # The context is truncated at its end; if the last kept token stops before
    # the span does, the answer is not (fully) inside this window.
    if offsets[context_tokens[-1]][1] < end_char:
        return 0, 0

    # First context token that ends after the span start, and last context
    # token that begins before the span end. This tolerates span boundaries
    # that fall on whitespace between tokens.
    start_token = next(
        (idx for idx in context_tokens if offsets[idx][1] > start_char), None
    )
    end_token = next(
        (idx for idx in reversed(context_tokens) if offsets[idx][0] < end_char), None
    )

    if start_token is None or end_token is None or end_token < start_token:
        return 0, 0
    return start_token, end_token


def prepare_train_features(examples):
    """Tokenize a batch of QA examples and attach answer token positions.

    Uses the module-level ``tokenizer``, which must be a fast tokenizer
    (offset mappings and ``sequence_ids`` are required).

    Args:
        examples: Batch (dict of lists) with the keys ``question``,
            ``context``, ``start_char`` and ``end_char``.

    Returns:
        The tokenizer output for (question, context) pairs, truncated on the
        context side to 512 tokens and left unpadded, with two added lists:
        ``start_positions`` and ``end_positions``. Both are 0 for an example
        whose answer span cannot be located inside the kept context tokens.
        ``offset_mapping`` is still present in the returned mapping.
    """
    tokenized = tokenizer(
        examples['question'],
        examples['context'],
        max_length=512,
        truncation='only_second',
        padding=False,        # Dynamic padding is applied later by the collator
        return_offsets_mapping=True
    )

    start_positions = []
    end_positions = []

    per_example = zip(
        tokenized['offset_mapping'], examples['start_char'], examples['end_char']
    )
    for i, (offsets, start_char, end_char) in enumerate(per_example):
        start_token, end_token = _locate_answer_tokens(
            offsets, tokenized.sequence_ids(i), start_char, end_char
        )
        start_positions.append(start_token)
        end_positions.append(end_token)

    tokenized['start_positions'] = start_positions
    tokenized['end_positions'] = end_positions

    return tokenized

print("Tokenizing datasets...")
# Raw text columns are replaced by the tokenizer output.
_raw_columns = ['question', 'context', 'answer', 'start_char', 'end_char']
train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=_raw_columns)
val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=_raw_columns)

# Remove offset_mapping (not needed for training).
# map() merges the function's returned dict into the existing batch, so
# returning a dict without the key does not drop the column; it has to be
# removed explicitly.
train_dataset = train_dataset.remove_columns(['offset_mapping'])
val_dataset = val_dataset.remove_columns(['offset_mapping'])

print("‚úÖ Tokenization complete (dynamic padding will be applied during training)")

In [None]:
# Load model and RESIZE embeddings for new tokens
from transformers import AutoModelForQuestionAnswering

print("Loading BERT-QA model...")
model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')

print(f"Original parameters: {model.num_parameters():,}")

# The tokenizer was extended with citation tokens, so the embedding table is
# resized to the tokenizer's current length.
_vocab_size = len(tokenizer)
print("\nüîß Resizing model embeddings...")
model.resize_token_embeddings(_vocab_size)

print(f"‚úÖ New parameters: {model.num_parameters():,}")
print(f"‚úÖ Model ready with {_vocab_size} tokens!")

In [None]:
# Training setup - OPTIMIZED for IterableDataset
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding
from pathlib import Path

# Collator that pads at batch-assembly time (the datasets were tokenized
# with padding=False), which saves VRAM.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    # --- checkpointing (own directory for this run) ---
    output_dir='/kaggle/working/checkpoints/task3_bert_special_tokens',
    save_strategy='steps',
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    # --- schedule: a streaming dataset is driven by max_steps ---
    max_steps=10000,
    warmup_steps=500,
    learning_rate=3e-5,
    weight_decay=0.01,
    # --- batching: 8 per device x 4 accumulation steps = 32 effective ---
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    fp16=True,  # mixed precision to save VRAM
    # --- evaluation / logging ---
    eval_strategy='steps',
    eval_steps=500,
    logging_steps=100,
    report_to='none',
    seed=42,
)

# Check for checkpoint to resume from
checkpoint_dir = Path(training_args.output_dir)


def _checkpoint_step(path):
    """Return the integer after the last '-' in the folder name, or -1 if it is not numeric."""
    suffix = path.name.rpartition('-')[2]
    return int(suffix) if suffix.isdigit() else -1


# Order by numeric step, not by name: as strings 'checkpoint-9500' sorts after
# 'checkpoint-10000', which would resume from a stale checkpoint.
checkpoints = sorted(
    (
        candidate
        for candidate in checkpoint_dir.glob('checkpoint-*')
        if candidate.is_dir() and _checkpoint_step(candidate) >= 0
    ),
    key=_checkpoint_step,
) if checkpoint_dir.exists() else []

resume_checkpoint = str(checkpoints[-1]) if checkpoints else None

if resume_checkpoint:
    print(f"üîÑ Resuming from: {Path(resume_checkpoint).name}")
else:
    print("üÜï Starting fresh training")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator  # Use dynamic padding!
)

# Report the values actually held by training_args so the summary cannot
# drift from the configuration.
_per_device = training_args.per_device_train_batch_size
_accumulation = training_args.gradient_accumulation_steps

print(f"\nüí° Training config:")
print(f"   - Effective batch size: {_per_device * _accumulation} (per_device={_per_device} √ó accumulation={_accumulation})")
print(f"   - Max steps: {training_args.max_steps}")
print(f"   - Dynamic padding: ON (saves VRAM)")
print(f"   - FP16: ON (saves VRAM)")
print(f"   - Special tokens: {num_added} citation tokens added")

In [None]:
# Run training. resume_checkpoint is passed straight through to Trainer.train;
# it is None when no checkpoint was found during setup.
_train_rule = "=" * 60
print(_train_rule)
print("üöÄ STARTING TRAINING - BERT-QA + SPECIAL TOKENS")
print(_train_rule)

trainer.train(resume_from_checkpoint=resume_checkpoint)

print("\n‚úÖ Training complete!")

In [None]:
# Run evaluation with the trainer and print every returned metric to 4 decimals.
print("üìä VALIDATION RESULTS")
eval_results = trainer.evaluate()
for metric_name, metric_value in eval_results.items():
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
# Persist the trained model and the tokenizer side by side in one directory.
final_model_path = '/kaggle/working/models/task3_bert_special_tokens_final'
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"‚úÖ Model saved to: {final_model_path}")
print(f"‚úÖ Tokenizer saved (includes {num_added} special tokens)")

_done_rule = "=" * 60
print("\n" + _done_rule)
print("‚úÖ TASK 3 - BERT-QA + SPECIAL TOKENS COMPLETE!")
print(_done_rule)