# Stage 3: Fine-Tune Two DistilBERT Sentiment Models

**Run this notebook in Google Colab** with a T4 GPU (free tier).

Trains two separate `distilbert-base-uncased` models sequentially:
1. **Labor stance model**: pro_labor / anti_labor / neutral
2. **Railroad outlook model**: optimistic / pessimistic / neutral

Validation sets are **100% hand-labeled**. Long articles use **keyword-centered truncation** — the token
window is centered on the highest-weighted keyword match rather than keeping the first/last tokens.

In [None]:
# --- Paths ---
# Local defaults, relative to the notebook's working directory.
DATA_DIR = 'data/verified_labels'
MODEL_BASE_DIR = 'models'

# --- Colab Setup ---
# When running in Colab, uncomment the lines below. They deliberately come
# AFTER the local defaults so the Drive paths override them.

# %pip install -q transformers datasets accelerate scikit-learn

# from google.colab import drive
# drive.mount('/content/drive')

# DATA_DIR = '/content/drive/MyDrive/sentiment_analysis/data/verified_labels'
# MODEL_BASE_DIR = '/content/drive/MyDrive/sentiment_analysis/models'

In [None]:
import json
import os
import re
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)

# Report the accelerator up front so a missing GPU runtime is obvious before training.
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (bytes); `total_mem`
    # does not exist and would raise AttributeError here.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# --- Keyword definitions for keyword-centered truncation ---
# keyword -> weight. best_keyword_position centers the truncation window on the
# highest-weighted keyword found; on weight ties the entry listed first wins,
# so the insertion order of these dicts matters.
LABOR_KEYWORDS = {
    # weight 3
    'labor union': 3, 'trade union': 3, 'labor strike': 3, 'labor riot': 3,
    'collective bargaining': 3, 'labor movement': 3, 'strikebreaker': 3,
    'scab labor': 3, 'working men': 3, 'workingmen': 3,
    'knights of labor': 3, 'eight hour': 3,
    # weight 2
    'striker': 2, 'strikers': 2, 'picket': 2, 'lockout': 2,
    'boycott': 2, 'walkout': 2, 'arbitration': 2, 'picketing': 2,
    # weight 1
    'strike': 1, 'strikes': 1, 'wage': 1, 'wages': 1,
    'workers': 1, 'laborers': 1,
}
RAILROAD_KEYWORDS = {
    # weight 3
    'railroad company': 3, 'railroad strike': 3, 'railroad workers': 3,
    'railway company': 3, 'union pacific': 3, 'central pacific': 3,
    'northern pacific': 3, 'pennsylvania railroad': 3,
    'baltimore and ohio': 3, 'railroad line': 3,
    # weight 2
    'locomotive': 2, 'locomotives': 2, 'brakeman': 2,
    'freight car': 2, 'passenger car': 2, 'rail road': 2,
    # weight 1
    'railroad': 1, 'railway': 1, 'train': 1, 'trains': 1,
}


def _compile_keyword_patterns(weighted_keywords):
    """Map each keyword to a (case-insensitive whole-word regex, weight) pair, keeping order."""
    return {
        keyword: (re.compile(rf'\b{re.escape(keyword)}\b', re.IGNORECASE), weight)
        for keyword, weight in weighted_keywords.items()
    }


LABOR_PATTERNS = _compile_keyword_patterns(LABOR_KEYWORDS)
RAILROAD_PATTERNS = _compile_keyword_patterns(RAILROAD_KEYWORDS)


def best_keyword_position(text, axis):
    """Find the char position of the highest-weighted keyword for the given axis.

    Args:
        text: Article text to search.
        axis: 'labor' (uses LABOR_PATTERNS) or 'railroad' (uses RAILROAD_PATTERNS).

    Returns:
        Start offset (in characters) of the first occurrence of the
        highest-weighted keyword that appears in `text`. Among keywords that
        share the top weight, the one listed first in the keyword table wins.
        Returns 0 when nothing matches, so truncation falls back to the start
        of the text.

    Raises:
        ValueError: if `axis` is not 'labor' or 'railroad'. This avoids silently
            using the railroad keywords for a typo such as 'Labor'.
    """
    patterns_by_axis = {'labor': LABOR_PATTERNS, 'railroad': RAILROAD_PATTERNS}
    if axis not in patterns_by_axis:
        raise ValueError(f"Unknown axis {axis!r}; expected one of {sorted(patterns_by_axis)}")

    best_pos, best_weight = 0, 0
    for regex, weight in patterns_by_axis[axis].values():
        # A keyword can only win with a strictly higher weight, so there is no
        # need to scan the text for it otherwise (this also keeps ties on the
        # earliest-listed keyword).
        if weight <= best_weight:
            continue
        match = regex.search(text)
        if match:
            best_weight, best_pos = weight, match.start()
    return best_pos


print("Keyword patterns loaded.")

In [None]:
MODEL_NAME = "distilbert-base-uncased"
# Module-level tokenizer shared by smart_truncate and tokenize_function;
# train_and_evaluate also saves it alongside each fine-tuned model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def load_json(filepath):
    """Read the UTF-8 encoded JSON file at `filepath` and return the decoded object."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        payload = handle.read()
    return json.loads(payload)


def smart_truncate(text: str, axis: str, max_length: int = 512) -> str:
    """
    Keyword-centered truncation: keep a token window centered on the
    highest-weighted keyword match for this axis.

    Args:
        text: Raw article text.
        axis: Keyword axis passed to best_keyword_position ('labor' or 'railroad').
        max_length: Model input length including [CLS] and [SEP]; the kept
            window is ``max_length - 2`` wordpiece tokens.

    Returns:
        ``text`` unchanged when it already fits. Otherwise, the detokenized
        wordpiece window around the keyword. With no keyword match the window
        starts at the beginning of the text (plain head truncation).
    """
    window = max_length - 2  # leave room for [CLS] and [SEP]
    tokens = tokenizer.tokenize(text)
    if len(tokens) <= window:
        return text

    center_char = best_keyword_position(text, axis)
    # Count the wordpieces that precede the match instead of scaling the char
    # offset linearly. Token density varies within an article (numbers, names,
    # OCR noise), so a proportional estimate can drift far enough to push the
    # keyword out of the window. Keyword regexes start on a word boundary, so
    # tokenizing the prefix normally splits the same way as the full text.
    center_token = min(len(tokenizer.tokenize(text[:center_char])), len(tokens) - 1)

    half = window // 2
    start = max(0, center_token - half)
    end = min(len(tokens), start + window)
    start = max(0, end - window)  # slide left when the window hit the end

    # A window opening on a '##' continuation piece detokenizes to '##xyz',
    # which re-tokenizes into '#', '#', 'xyz'. Drop the partial word instead.
    while start < end - 1 and tokens[start].startswith('##'):
        start += 1

    return tokenizer.convert_tokens_to_string(tokens[start:end])


def tokenize_function(examples):
    """Batched `Dataset.map` callback: encode ``examples['text']`` for the model.

    Sequences are truncated to 512 tokens but are NOT padded here. The Trainer
    in train_and_evaluate uses DataCollatorWithPadding, which pads each batch
    to its longest member. Padding every example to 512 up front defeated that
    collator and pushed hundreds of pad tokens through every forward/backward
    pass for short articles. Pad positions are masked via ``attention_mask``
    either way.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=512,
    )


def compute_metrics(eval_pred):
    """Trainer metric hook: accuracy plus support-weighted precision/recall/F1.

    ``eval_pred`` is unpacked as (logits, labels); the predicted class is the
    argmax over the last axis of the logits. Classes with no predictions score
    0 instead of raising a warning (``zero_division=0``).
    """
    logits, labels = eval_pred
    y_pred = np.argmax(logits, axis=-1)
    precision, recall, f1, _support = precision_recall_fscore_support(
        labels, y_pred, average='weighted', zero_division=0
    )
    return dict(
        accuracy=accuracy_score(labels, y_pred),
        precision=precision,
        recall=recall,
        f1=f1,
    )


print(f"Tokenizer ready: {MODEL_NAME}")

In [None]:
def prepare_dataset(train_path, val_path, label_map, sentiment_key, axis):
    """Load data, apply keyword-centered truncation, and return HF Datasets.

    Args:
        train_path, val_path: JSON files, each holding a list of article dicts
            with a 'text' field.
        label_map: label string -> integer class id.
        sentiment_key: key in each article dict that holds its label string.
        axis: keyword axis used by smart_truncate ('labor' or 'railroad').

    Returns:
        (train_dataset, val_dataset): tokenized ``datasets.Dataset`` objects with
        a 'label' column. Articles whose label is missing or not in
        ``label_map`` are dropped, and the number dropped is printed.
    """
    train_data = load_json(train_path)
    val_data = load_json(val_path)

    print(f"Training samples: {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")

    for name, data in [('Train', train_data), ('Val', val_data)]:
        _print_label_distribution(name, data, label_map, sentiment_key)

    human_count = _count_labeler(train_data, 'human')
    gemini_count = _count_labeler(train_data, 'gemini')
    print(f"\nTraining labeler mix: {human_count} human, {gemini_count} gemini")

    # Check the "validation is 100% hand-labeled" premise instead of printing it
    # unconditionally. Items without a 'labeler' field are not counted against
    # it. NOTE(review): assumes validation records may omit 'labeler'; confirm.
    non_human_val = sum(1 for d in val_data if d.get('labeler') not in (None, 'human'))
    if non_human_val == 0:
        print(f"Validation labeler: 100% human ({len(val_data)} samples)")
    else:
        print(f"** WARNING: {non_human_val}/{len(val_data)} validation samples have a "
              f"non-human labeler -- validation is NOT fully hand-labeled")

    # Keyword-centered truncation using the model's axis
    train_dataset = _to_tokenized_dataset('Train', train_data, label_map, sentiment_key, axis)
    val_dataset = _to_tokenized_dataset('Val', val_data, label_map, sentiment_key, axis)

    print(f"\nTokenized train: {len(train_dataset)}, val: {len(val_dataset)}")

    # > 510 wordpieces + [CLS]/[SEP] exceeds 512: the same threshold smart_truncate uses.
    long_count = sum(1 for item in train_data if len(tokenizer.tokenize(item['text'])) > 510)
    long_pct = long_count / len(train_data) * 100 if train_data else 0.0
    print(f"Articles exceeding 512 tokens (keyword-centered): {long_count}/{len(train_data)} "
          f"({long_pct:.1f}%)")

    return train_dataset, val_dataset


def _print_label_distribution(split_name, data, label_map, sentiment_key):
    """Print the count and percentage of every label in `label_map` for one split."""
    print(f"\n{split_name} distribution:")
    for label in label_map:
        count = sum(1 for d in data if d.get(sentiment_key) == label)
        pct = count / len(data) * 100 if data else 0
        print(f"  {label}: {count} ({pct:.1f}%)")


def _count_labeler(data, labeler):
    """Number of items whose 'labeler' field equals `labeler`."""
    return sum(1 for d in data if d.get('labeler') == labeler)


def _to_tokenized_dataset(split_name, data, label_map, sentiment_key, axis):
    """Truncate and encode one split; skip (and report) items without a usable label."""
    rows = [
        {'text': smart_truncate(item['text'], axis), 'label': label_map[item[sentiment_key]]}
        for item in data
        if item.get(sentiment_key) in label_map
    ]
    dropped = len(data) - len(rows)
    if dropped:
        print(f"{split_name}: dropped {dropped} item(s) with a missing/unknown "
              f"'{sentiment_key}' label")
    frame = pd.DataFrame(rows, columns=['text', 'label'])
    return Dataset.from_pandas(frame).map(tokenize_function, batched=True)

In [None]:
def train_and_evaluate(train_dataset, val_dataset, label_map, label_names,
                       model_save_dir, model_display_name, axis, seed=42):
    """
    Train a DistilBERT model, evaluate, print report, save, and free GPU memory.

    Args:
        train_dataset, val_dataset: tokenized datasets with a 'label' column
            (as returned by prepare_dataset).
        label_map: label string -> class id (must be 0..n-1); stored as the
            model's label2id.
        label_names: label strings in class-id order (index i == id i); used
            for the report and the confusion-matrix rows/columns.
        model_save_dir: directory receiving the model, the tokenizer and
            training_metadata.json.
        model_display_name: human-readable name, also used to derive the
            Trainer output/log directory names.
        axis: keyword axis the data was truncated on (recorded in metadata).
        seed: RNG seed, applied *before* the model is built. The Trainer seeds
            only inside its own constructor, after from_pretrained has already
            randomly initialised the classification head, so without this the
            head depends on whatever RNG state earlier cells left behind.

    Returns:
        dict of final validation metrics from ``Trainer.evaluate()`` (keys such
        as 'eval_accuracy' and 'eval_f1').

    Raises:
        ValueError: if label_map ids are not 0..n-1, or if label_names is not
            in class-id order.
    """
    import gc
    from transformers import set_seed

    id2label = {v: k for k, v in label_map.items()}
    if sorted(id2label) != list(range(len(label_map))):
        raise ValueError(f"label_map ids must be unique and cover 0..{len(label_map) - 1}: {label_map}")
    expected_names = [id2label[i] for i in range(len(label_map))]
    if list(label_names) != expected_names:
        raise ValueError(
            f"label_names {list(label_names)} must be ordered by label_map id: {expected_names}"
        )

    set_seed(seed)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(label_map),
        id2label=id2label,
        label2id=label_map,
    )

    run_slug = model_display_name.replace(" ", "_").lower()
    training_args = TrainingArguments(
        output_dir=f'./results_{run_slug}',
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=32,
        num_train_epochs=5,
        weight_decay=0.01,
        eval_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
        fp16=torch.cuda.is_available(),
        logging_steps=50,
        logging_dir=f'./logs_{run_slug}',
        report_to='none',
        seed=seed,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # NOTE(review): newer transformers releases deprecate `tokenizer=` in
        # favour of `processing_class=`; kept for compatibility with older ones.
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    print(f"\n{'='*60}")
    print(f"Training: {model_display_name}")
    print(f"{'='*60}")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Labels: {label_names}")
    print(f"  Epochs: {training_args.num_train_epochs}")
    print(f"  Batch size: {training_args.per_device_train_batch_size}")
    print(f"  FP16: {training_args.fp16}")
    print(f"  Seed: {seed}")
    print(f"  Truncation: keyword-centered on '{axis}' keywords")

    trainer.train()

    # With load_best_model_at_end=True this scores the best checkpoint (by accuracy).
    results = trainer.evaluate()
    print(f"\n=== {model_display_name} Validation Results (100% hand-labeled) ===")
    print(f"Accuracy:  {results['eval_accuracy']:.3f}")
    print(f"Precision: {results['eval_precision']:.3f}")
    print(f"Recall:    {results['eval_recall']:.3f}")
    print(f"F1:        {results['eval_f1']:.3f}")

    if results['eval_accuracy'] < 0.60:
        print(f"\n** WARNING: {model_display_name} accuracy below 60%. Consider:")
        print("   - Collapsing to 2-class (biased vs neutral)")
        print("   - Increasing hand-labeled training data")
        print("   - Switching to RoBERTa-base")

    predictions = trainer.predict(val_dataset)
    y_pred = np.argmax(predictions.predictions, axis=-1)
    y_true = np.array(val_dataset['label'])
    _print_validation_report(y_true, y_pred, label_names)

    # Save model, tokenizer and a small metadata record next to each other.
    os.makedirs(model_save_dir, exist_ok=True)
    trainer.save_model(model_save_dir)
    tokenizer.save_pretrained(model_save_dir)

    metadata = {
        'model_name': MODEL_NAME,
        'display_name': model_display_name,
        'axis': axis,
        'num_labels': len(label_map),
        'label_map': label_map,
        'label_names': label_names,
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'val_accuracy': results['eval_accuracy'],
        'val_f1': results['eval_f1'],
        'epochs': int(training_args.num_train_epochs),
        'learning_rate': training_args.learning_rate,
        'truncation': 'keyword_centered_512',
        'seed': seed,
    }
    with open(os.path.join(model_save_dir, 'training_metadata.json'), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nModel saved to: {model_save_dir}")

    # Free GPU memory before training the next model. gc.collect() drops any
    # reference cycles still holding CUDA tensors so empty_cache can release them.
    del model, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results


def _print_validation_report(y_true, y_pred, label_names):
    """Print sklearn's classification report and a labelled confusion matrix.

    ``labels=`` is passed explicitly so every class gets a row/column even if it
    never occurs in y_true or y_pred. Without it, classification_report raises
    on a target_names size mismatch, and the confusion matrix shrinks so that
    the ``cm[i][j]`` indexing below fails.
    """
    class_ids = list(range(len(label_names)))

    print("\n=== Classification Report ===")
    print(classification_report(y_true, y_pred, labels=class_ids,
                                target_names=label_names, zero_division=0))

    print("=== Confusion Matrix ===")
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    header = ''.join(f'{f"pred_{n[:5]}":>12}' for n in label_names)
    print(f"{'':>14}{header}")
    for i, label in enumerate(label_names):
        row = ''.join(f'{cm[i][j]:>12}' for j in class_ids)
        print(f"{label:>14}{row}")

---
## Model A: Labor Stance

Trained on articles in the *labor* and *both* (labor + railroad) topic groups. Measures pro-labor vs. anti-labor editorial stance.

In [None]:
LABOR_LABEL_MAP = {'pro_labor': 0, 'anti_labor': 1, 'neutral': 2}
# Names in class-id order: LABOR_LABEL_NAMES[i] is the label whose id is i.
LABOR_LABEL_NAMES = sorted(LABOR_LABEL_MAP, key=LABOR_LABEL_MAP.get)

print("=== Loading Labor Stance Data ===")
labor_train_ds, labor_val_ds = prepare_dataset(
    os.path.join(DATA_DIR, 'labor_train.json'),
    os.path.join(DATA_DIR, 'labor_val.json'),
    axis='labor',
    sentiment_key='labor_sentiment',
    label_map=LABOR_LABEL_MAP,
)

In [None]:
# Fine-tune and evaluate the labor-stance model; the metrics feed the summary cell.
labor_results = train_and_evaluate(
    labor_train_ds,
    labor_val_ds,
    model_display_name='Labor Stance',
    axis='labor',
    label_map=LABOR_LABEL_MAP,
    label_names=LABOR_LABEL_NAMES,
    model_save_dir=os.path.join(MODEL_BASE_DIR, 'labor_stance_model'),
)

---
## Model B: Railroad Outlook

Trained on articles in the *railroad* and *both* (labor + railroad) topic groups. Measures optimism vs. pessimism about railroads.

In [None]:
RAILROAD_LABEL_MAP = {'optimistic': 0, 'pessimistic': 1, 'neutral': 2}
# Names in class-id order: RAILROAD_LABEL_NAMES[i] is the label whose id is i.
RAILROAD_LABEL_NAMES = sorted(RAILROAD_LABEL_MAP, key=RAILROAD_LABEL_MAP.get)

print("=== Loading Railroad Outlook Data ===")
rr_train_ds, rr_val_ds = prepare_dataset(
    os.path.join(DATA_DIR, 'railroad_train.json'),
    os.path.join(DATA_DIR, 'railroad_val.json'),
    axis='railroad',
    sentiment_key='railroad_sentiment',
    label_map=RAILROAD_LABEL_MAP,
)

In [None]:
# Fine-tune and evaluate the railroad-outlook model; the metrics feed the summary cell.
railroad_results = train_and_evaluate(
    rr_train_ds,
    rr_val_ds,
    model_display_name='Railroad Outlook',
    axis='railroad',
    label_map=RAILROAD_LABEL_MAP,
    label_names=RAILROAD_LABEL_NAMES,
    model_save_dir=os.path.join(MODEL_BASE_DIR, 'railroad_outlook_model'),
)

---
## Summary

In [None]:
print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)

# One stanza per model: (display title, evaluate() results, save subdirectory).
for _title, _res, _subdir in (
    ("Labor Stance", labor_results, 'labor_stance_model'),
    ("Railroad Outlook", railroad_results, 'railroad_outlook_model'),
):
    print(f"\n{_title} Model:")
    print(f"  Accuracy: {_res['eval_accuracy']:.3f}")
    print(f"  F1:       {_res['eval_f1']:.3f}")
    print(f"  Saved to: {os.path.join(MODEL_BASE_DIR, _subdir)}")

both_ok = all(r['eval_accuracy'] >= 0.60 for r in (labor_results, railroad_results))
print(f"\nBoth models >= 60% accuracy: {'YES' if both_ok else 'NO -- consider 2-class fallback below'}")
print("\nNext step: Run 04_sentiment_inference.ipynb")