# Task 3: Citation Span Extraction with SciBERT-QA

**Model:** allenai/scibert_scivocab_uncased (Question Answering)

**Task:** Extract text span that citation supports

**Features:**
- ‚úÖ Memory efficient (streaming data)
- ‚úÖ Auto resume from checkpoint
- ‚úÖ Works on Colab Free

---

In [1]:
# Kaggle already has these packages installed!
# No need to install: transformers, datasets, accelerate

# Verify versions (optional)
import transformers, datasets, accelerate
print(f"‚úÖ transformers: {transformers.__version__}")
print(f"‚úÖ datasets: {datasets.__version__}")
print(f"‚úÖ accelerate: {accelerate.__version__}")

‚úÖ transformers: 4.57.1
‚úÖ datasets: 4.4.2
‚úÖ accelerate: 1.11.0


In [2]:
# Data already unzipped by Kaggle - verify it
import os

train_path = '/kaggle/input/thesis-data-task3/train/train'
val_path = '/kaggle/input/thesis-data-task3/val/val'

train_count = len([f for f in os.listdir(train_path) if f.endswith('.in')])
val_count = len([f for f in os.listdir(val_path) if f.endswith('.in')])

print(f"‚úÖ Train: {train_count} files")
print(f"‚úÖ Val: {val_count} files")

‚úÖ Train: 55556 files
‚úÖ Val: 3000 files


In [3]:
# Load data - STREAMING (memory efficient)
import json
from pathlib import Path
from datasets import IterableDataset

def generate_task3_examples(data_dir):
    """
    Generator for Task 3 span extraction examples - STREAMING mode
    """
    data_path = Path(data_dir)
    in_files = sorted(data_path.glob("*.in"))

    total_files = len(in_files)
    print(f"üìä Found {total_files:,} .in files - streaming mode")

    for i, in_file in enumerate(in_files):
        if (i+1) % 5000 == 0:
            print(f"‚è≥ Processed {i+1:,}/{total_files:,} files ({(i+1)*100//total_files}%)")

        with open(in_file) as f:
            in_data = json.load(f)

        label_file = in_file.with_suffix('.label')
        if not label_file.exists():
            continue

        with open(label_file) as f:
            label_data = json.load(f)

        text = in_data['text']
        citation_spans = label_data.get('citation_spans', [])

        for span_info in citation_spans:
            citation = span_info['citation_id']
            span_text = span_info['span_text']
            start_char = span_info.get('start_char', -1)
            end_char = span_info.get('end_char', -1)

            if start_char == -1 or end_char == -1:
                start_char = text.find(span_text)
                if start_char != -1:
                    end_char = start_char + len(span_text)
                else:
                    continue

            question = f"What does citation {citation} support?"

            yield {
                'question': question,
                'context': text,
                'answer': span_text,
                'start_char': start_char,
                'end_char': end_char
            }

    print(f"‚úÖ Finished processing all {total_files:,} files")

print("=" * 60)
print("Creating TRAIN dataset (streaming)...")
print("=" * 60)
train_dataset = IterableDataset.from_generator(
    generate_task3_examples,
    gen_kwargs={'data_dir': '/kaggle/input/thesis-data-task3/train/train'}
)
print("‚úÖ Train dataset ready")

print("\n" + "=" * 60)
print("Creating VAL dataset (streaming)...")
print("=" * 60)
val_dataset = IterableDataset.from_generator(
    generate_task3_examples,
    gen_kwargs={'data_dir': '/kaggle/input/thesis-data-task3/val/val'}
)
print("‚úÖ Val dataset ready")

print("\nüí° Using IterableDataset for memory-efficient streaming!")

Creating TRAIN dataset (streaming)...
‚úÖ Train dataset ready

Creating VAL dataset (streaming)...
‚úÖ Val dataset ready

üí° Using IterableDataset for memory-efficient streaming!


In [4]:
# Tokenize for QA - DYNAMIC PADDING (memory efficient)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')

def prepare_train_features(examples):
    tokenized = tokenizer(
        examples['question'],
        examples['context'],
        max_length=512,
        truncation='only_second',
        padding=False,        # Changed from 'max_length' to False for dynamic padding
        return_offsets_mapping=True
    )

    offset_mapping = tokenized['offset_mapping']
    start_positions = []
    end_positions = []

    for i in range(len(examples['question'])):
        start_char = examples['start_char'][i]
        end_char = examples['end_char'][i]
        offsets = offset_mapping[i]

        start_token = 0
        for idx, (offset_start, offset_end) in enumerate(offsets):
            if offset_start <= start_char < offset_end:
                start_token = idx
                break

        end_token = 0
        for idx, (offset_start, offset_end) in enumerate(offsets):
            if offset_start < end_char <= offset_end:
                end_token = idx
                break

        start_positions.append(start_token)
        end_positions.append(end_token)

    tokenized['start_positions'] = start_positions
    tokenized['end_positions'] = end_positions

    return tokenized

print("Tokenizing datasets...")
train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=['question', 'context', 'answer', 'start_char', 'end_char'])
val_dataset = val_dataset.map(prepare_train_features, batched=True, remove_columns=['question', 'context', 'answer', 'start_char', 'end_char'])

# Remove offset_mapping (not needed for training)
train_dataset = train_dataset.map(lambda x: {k: v for k, v in x.items() if k != 'offset_mapping'}, batched=True)
val_dataset = val_dataset.map(lambda x: {k: v for k, v in x.items() if k != 'offset_mapping'}, batched=True)

print("‚úÖ Tokenization complete (dynamic padding will be applied during training)")

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

Tokenizing datasets...
‚úÖ Tokenization complete (dynamic padding will be applied during training)


In [5]:
# Load model
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained(
    'allenai/scibert_scivocab_uncased'
)

print(f"‚úÖ Model loaded: {model.num_parameters():,} parameters")

2026-01-20 09:17:29.551237: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768900649.753930      23 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768900649.816537      23 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768900650.272488      23 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768900650.272530      23 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768900650.272533      23 computation_placer.cc:177] computation placer alr

pytorch_model.bin:   0%|          | 0.00/442M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Model loaded: 109,329,410 parameters


In [6]:
# Training setup - OPTIMIZED for IterableDataset
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding
from pathlib import Path

# Dynamic padding collator (saves VRAM!)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir='/kaggle/working/checkpoints/task3_scibert',
    max_steps=10000,              # Use max_steps for IterableDataset (not epochs)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4, # Effective batch size = 32
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=500,
    eval_strategy='steps',
    eval_steps=500,
    logging_steps=100,
    save_strategy='steps',
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    fp16=True,                    # Mixed precision to save VRAM
    report_to='none',
    seed=42
)

# Check for checkpoint to resume from
checkpoint_dir = Path(training_args.output_dir)
checkpoints = sorted(checkpoint_dir.glob('checkpoint-*')) if checkpoint_dir.exists() else []

resume_checkpoint = str(checkpoints[-1]) if checkpoints else None

if resume_checkpoint:
    print(f"üîÑ Resuming from: {Path(resume_checkpoint).name}")
else:
    print("üÜï Starting fresh training")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator  # Use dynamic padding!
)

print(f"\nüí° Training config:")
print(f"   - Effective batch size: {8 * 4} (per_device={8} √ó accumulation={4})")
print(f"   - Max steps: {10000}")
print(f"   - Dynamic padding: ON (saves VRAM)")
print(f"   - FP16: ON (saves VRAM)")

model.safetensors:   0%|          | 0.00/442M [00:00<?, ?B/s]

üÜï Starting fresh training

üí° Training config:
   - Effective batch size: 32 (per_device=8 √ó accumulation=4)
   - Max steps: 10000
   - Dynamic padding: ON (saves VRAM)
   - FP16: ON (saves VRAM)


  trainer = Trainer(


In [7]:
# Train
print("="*60)
print("üöÄ STARTING TRAINING")
print("="*60)

trainer.train(resume_from_checkpoint=resume_checkpoint)

print("\n‚úÖ Training complete!")

üöÄ STARTING TRAINING
üìä Found 55,556 .in files - streaming mode
‚è≥ Processed 5,000/55,556 files (8%)
‚è≥ Processed 10,000/55,556 files (17%)
‚è≥ Processed 15,000/55,556 files (26%)


Step,Training Loss,Validation Loss
500,0.2842,0.532571
1000,0.0673,0.714739
1500,0.0269,0.6454
2000,0.0144,0.795877
2500,0.012,0.898423
3000,0.008,0.886054
3500,0.0035,0.945257
4000,0.0017,0.973507
4500,0.0035,0.948465
5000,0.0002,1.09387


‚è≥ Processed 20,000/55,556 files (35%)
‚è≥ Processed 25,000/55,556 files (44%)
‚è≥ Processed 30,000/55,556 files (53%)
‚è≥ Processed 35,000/55,556 files (62%)
‚è≥ Processed 40,000/55,556 files (71%)
‚è≥ Processed 45,000/55,556 files (80%)
‚è≥ Processed 50,000/55,556 files (89%)
‚è≥ Processed 55,000/55,556 files (98%)
‚úÖ Finished processing all 55,556 files
üìä Found 55,556 .in files - streaming mode
‚è≥ Processed 5,000/55,556 files (8%)
‚è≥ Processed 10,000/55,556 files (17%)
‚è≥ Processed 15,000/55,556 files (26%)
‚è≥ Processed 20,000/55,556 files (35%)
‚è≥ Processed 25,000/55,556 files (44%)
‚è≥ Processed 30,000/55,556 files (53%)
‚è≥ Processed 35,000/55,556 files (62%)
‚è≥ Processed 40,000/55,556 files (71%)
‚è≥ Processed 45,000/55,556 files (80%)
‚è≥ Processed 50,000/55,556 files (89%)
‚è≥ Processed 55,000/55,556 files (98%)
‚úÖ Finished processing all 55,556 files
üìä Found 55,556 .in files - streaming mode
‚è≥ Processed 5,000/55,556 files (8%)
‚è≥ Processed 10,000/55,556 file

In [8]:
# Evaluate
print("üìä VALIDATION RESULTS")
eval_results = trainer.evaluate()
for key, value in eval_results.items():
    print(f"{key}: {value:.4f}")

üìä VALIDATION RESULTS
üìä Found 3,000 .in files - streaming mode
‚úÖ Finished processing all 3,000 files
eval_loss: 0.5326
eval_runtime: 7.1003
eval_samples_per_second: 20.2810
eval_steps_per_second: 1.2680
epoch: 109.0081


In [9]:
# Save final model
final_model_path = '/kaggle/working/models/task3_scibert_final'
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"‚úÖ Model saved to: {final_model_path}")

‚úÖ Model saved to: /kaggle/working/models/task3_scibert_final


In [10]:
# Test inference
import torch
from transformers import pipeline

qa_pipeline = pipeline(
    'question-answering',
    model=final_model_path,
    tokenizer=final_model_path,
    device=0 if torch.cuda.is_available() else -1
)

# Test
result = qa_pipeline(
    question="What does citation [1] support?",
    context="Test paper text with citation [1] here."
)

print("\nüìã Test Inference:")
print(f"Predicted Answer: {result['answer']}")
print(f"Confidence: {result['score']:.4f}")
print("\n‚úÖ TASK 3 COMPLETE!")

Device set to use cuda:0



üìã Test Inference:
Predicted Answer: here.
Confidence: 0.2178

‚úÖ TASK 3 COMPLETE!
