In [9]:
import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import RobertaTokenizer, RobertaForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.utils.class_weight import compute_class_weight
import torch.nn as nn
import torch.nn.functional as F

In [10]:
def load_data(data_folder, variable_code, exclude_classes=None, include_classes=None):
    """
    Loads question-response pairs for a given ANES variable code.
    Adds feeling thermometer scores as contextual input features.
    """
    examples = []
    label_map = {}
    next_label_id = 0

    excluded_count = 0
    included_count = 0
    missing_answer_count = 0
    not_included_count = 0
    matched_count = 0

    if exclude_classes is None:
        exclude_classes = ['Inapplicable', 'Refused', "Don't know", 'Error', "Don't know"]

    json_files = [f for f in os.listdir(data_folder) if f.endswith('.json')]
    print(f"Processing {len(json_files)} JSON files for variable {variable_code}")

    # Helper to extract thermometer score
    def extract_thermometer_score(responses, code):
        for r in responses:
            if r["variable_code"] == code:
                ans = r.get("respondent_answer")
                if ans in exclude_classes or ans is None:
                    return "NA"
                return str(ans)
        return "NA"

    for i, fname in enumerate(json_files):
        if i % 500 == 0:
            print(f"Progress: {i}/{len(json_files)} files processed")

        try:
            with open(os.path.join(data_folder, fname)) as f:
                respondent = json.load(f)
        except (json.JSONDecodeError, FileNotFoundError):
            continue

        responses = respondent.get("responses", [])
        found = False
        for item in responses:
            if item.get("variable_code") != variable_code:
                continue

            question = item.get("full_question_text", "")
            possible_answers = [opt["text"] for opt in item.get("possible_answers", [])]
            respondent_answer = item.get("respondent_answer", None)

            if not respondent_answer:
                missing_answer_count += 1
                continue

            if respondent_answer in exclude_classes:
                excluded_count += 1
                continue

            if include_classes and respondent_answer not in include_classes:
                not_included_count += 1
                continue

            included_count += 1

            if respondent_answer not in label_map:
                label_map[respondent_answer] = next_label_id
                next_label_id += 1
            label = label_map[respondent_answer]

            # Add feeling thermometer variables
            harris = extract_thermometer_score(responses, "V241156")
            trump = extract_thermometer_score(responses, "V241157")
            biden = extract_thermometer_score(responses, "V241158")
            dem_party = extract_thermometer_score(responses, "V241166")
            rep_party = extract_thermometer_score(responses, "V241167")

            input_text = (
                f"Kamala Harris rating: {harris}\n"
                f"Donald Trump rating: {trump}\n"
                f"Joe Biden rating: {biden}\n"
                f"Democratic Party: {dem_party}\n"
                f"Republican Party: {rep_party}\n"
                f"Q: Who would this respondent vote for in a Harris vs Trump election?"
            )

            examples.append((input_text, label))
            matched_count += 1
            found = True
            break  # Only use first match per respondent

    # Summary logging
    print(f"\n📊 Summary for variable {variable_code}:")
    print(f"  ➤ Total JSON files: {len(json_files)}")
    print(f"  ➤ Valid examples collected: {matched_count}")
    print(f"  ➤ Unique labels: {len(label_map)}")
    print(f"  ➤ Skipped due to missing answers: {missing_answer_count}")
    print(f"  ➤ Skipped due to exclusion list: {excluded_count}")
    print(f"  ➤ Skipped (not in include_classes): {not_included_count}")
    if include_classes:
        print(f"  ➤ Included only: {include_classes}")
    print(f"  ➤ Final label map: {label_map}")

    # Class distribution
    label_counts = Counter([label for _, label in examples])
    print("\n🔍 Class distribution (label IDs):", label_counts)
    for label, count in label_counts.items():
        for key, val in label_map.items():
            if val == label:
                print(f"  ➤ '{key}': {count} samples")

    return examples, label_map


In [11]:
class ANESDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=256):
        self.texts = list(texts)
        self.labels = list(labels)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )

        return {
            "input_ids": enc["input_ids"].squeeze(0).long(),
            "attention_mask": enc["attention_mask"].squeeze(0).float(),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }


In [12]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        alpha: 1D tensor of shape [num_classes] or None
        gamma: focusing parameter
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        logits: Tensor[B, C]
        targets: Tensor[B] with class indices 0 ≤ targets[i] < C
        """
        # move class weights if provided
        if self.alpha is not None:
            self.alpha = self.alpha.to(logits.device)

        # standard CE with no reduction → [B]
        ce = F.cross_entropy(logits, targets, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce)             # [B], pt = probability of the true class
        loss = (1 - pt) ** self.gamma * ce

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss  # [B]


In [13]:
def train_epoch(model, loader, optimizer, scheduler, device, loss_fn):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for batch in loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [B, C]

        loss = loss_fn(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item() * labels.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / total
    acc = correct / total
    return avg_loss, acc


def eval_epoch(model, loader, device, loss_fn):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_logits, all_preds, all_labels = [], [], []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = loss_fn(logits, labels)
            total_loss += loss.item() * labels.size(0)

            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_logits.append(logits.cpu())
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    avg_loss = total_loss / total
    acc = correct / total

    all_logits = torch.cat(all_logits)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    return avg_loss, acc, all_logits, all_preds, all_labels

In [14]:
def print_class_distribution(labels, label_map):
    """Print the distribution of classes in the dataset."""
    from collections import Counter
    reverse_label_map = {v: k for k, v in label_map.items()}
    
    label_counts = Counter(labels)
    print("\nClass Distribution:")
    print("-" * 50)
    for label_id, count in sorted(label_counts.items()):
        class_name = reverse_label_map.get(label_id, f"Unknown_{label_id}")
        percentage = (count / len(labels)) * 100
        print(f"{class_name}: {count} ({percentage:.1f}%)")


In [15]:
def apply_binary_threshold_and_report(val_logits, val_true, label_map, threshold=0.5):
    """
    Apply threshold-based prediction for binary classification and print report + confusion matrix.
    """
    # Convert logits to probabilities
    probs = torch.softmax(val_logits, dim=-1).cpu().numpy()

    # Get class indices
    idx_to_label = {v: k for k, v in label_map.items()}
    class_names = [idx_to_label[i] for i in range(len(idx_to_label))]
    
    # For binary classification, we can use the probability of class 1
    # If we have more than 2 classes, this needs to be adjusted
    if len(class_names) == 2:
        # Get index for the second class (typically index 1)
        class_idx = 1
        class_probs = probs[:, class_idx]
        
        # Binary prediction based on threshold
        preds = (class_probs > threshold).astype(int)
        
        # Convert true labels to binary format matching our predictions
        binary_true = (val_true.numpy() == class_idx).astype(int)
        
        # Print classification report
        print(f"\n✅ Classification Report (Thresholded @ {threshold:.2f}):")
        print(classification_report(binary_true, preds, target_names=class_names, zero_division=0))
        
        # Plot confusion matrix
        cm = confusion_matrix(binary_true, preds)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Confusion Matrix")
        plt.tight_layout()
        plt.savefig('confusion_matrix.png')
        plt.close()
    else:
        print("This function is designed for binary classification only.")


In [16]:
def main(data_folder, variable_code="V241049"):
    # Target variable and label filtering
    # V241049 is "WHO WOULD R VOTE FOR: HARRIS VS TRUMP"
    include_classes = ['Donald Trump', 'Kamala Harris']

    # Load data
    examples, label_map = load_data(data_folder, variable_code, include_classes=include_classes)

    # Split texts and labels
    texts = [ex[0] for ex in examples]
    labels = [ex[1] for ex in examples]

    # Print class distribution
    print_class_distribution(labels, label_map)

    # Train-test split with stratification
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Initialize tokenizer
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # Create datasets and dataloaders
    train_dataset = ANESDataset(train_texts, train_labels, tokenizer)
    val_dataset = ANESDataset(val_texts, val_labels, tokenizer)
    
    # Use a smaller batch size if memory is an issue
    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize model
    num_labels = len(label_map)
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Compute class weights for handling imbalance
    classes = np.unique(train_labels)
    weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)
    weights_tensor = torch.tensor(weights, dtype=torch.float)
    
    # Create loss function with class weights
    loss_fn = FocalLoss(alpha=weights_tensor, gamma=2.0)

    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=2e-5)
    total_steps = len(train_loader) * 4  # num_epochs = 4
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
    )

    # Training loop
    for epoch in range(1, 5):
        print(f"\nEpoch {epoch}")
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device, loss_fn)
        val_loss, val_acc, val_logits, val_preds, val_labels = eval_epoch(model, val_loader, device, loss_fn)
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    print("\n✅ Training completed.")
    
    # Evaluate with different thresholds
    print("\nEvaluating with different thresholds:")
    for threshold in [0.3, 0.4, 0.5, 0.6, 0.7]:
        apply_binary_threshold_and_report(val_logits, val_labels, label_map, threshold=threshold)
    
    # Save the model
    torch.save(model.state_dict(), 'anes_classifier_model.pt')
    print("\nModel saved to 'anes_classifier_model.pt'")

if __name__ == '__main__':
    # Replace with your data folder path
    data_folder = "/home/tsultanov/shared/datasets/respondents"
    main(data_folder)


Processing 3349 JSON files for variable V241049
Progress: 0/3349 files processed


Progress: 500/3349 files processed
Progress: 1000/3349 files processed
Progress: 1500/3349 files processed
Progress: 2000/3349 files processed
Progress: 2500/3349 files processed
Progress: 3000/3349 files processed

📊 Summary for variable V241049:
  ➤ Total JSON files: 3349
  ➤ Valid examples collected: 2959
  ➤ Unique labels: 2
  ➤ Skipped due to missing answers: 0
  ➤ Skipped due to exclusion list: 34
  ➤ Skipped (not in include_classes): 356
  ➤ Included only: ['Donald Trump', 'Kamala Harris']
  ➤ Final label map: {'Donald Trump': 0, 'Kamala Harris': 1}

🔍 Class distribution (label IDs): Counter({1: 1623, 0: 1336})
  ➤ 'Donald Trump': 1336 samples
  ➤ 'Kamala Harris': 1623 samples

Class Distribution:
--------------------------------------------------
Donald Trump: 1336 (45.2%)
Kamala Harris: 1623 (54.8%)


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1
  Train - Loss: 0.1242, Acc: 0.7592
  Val   - Loss: 0.1057, Acc: 0.8868

Epoch 2
  Train - Loss: 0.0891, Acc: 0.9020
  Val   - Loss: 0.1273, Acc: 0.8936

Epoch 3
  Train - Loss: 0.0840, Acc: 0.9117
  Val   - Loss: 0.1049, Acc: 0.8902

Epoch 4
  Train - Loss: 0.0779, Acc: 0.9142
  Val   - Loss: 0.1003, Acc: 0.8902

✅ Training completed.

Evaluating with different thresholds:

✅ Classification Report (Thresholded @ 0.30):
               precision    recall  f1-score   support

 Donald Trump       0.86      0.88      0.87       267
Kamala Harris       0.90      0.89      0.89       325

     accuracy                           0.89       592
    macro avg       0.88      0.89      0.88       592
 weighted avg       0.89      0.89      0.89       592


✅ Classification Report (Thresholded @ 0.40):
               precision    recall  f1-score   support

 Donald Trump       0.86      0.90      0.88       267
Kamala Harris       0.91      0.88      0.90       325

     accuracy       