In [2]:
# ============================================
# CELL 0: ULTIMATE PATCH (Ver 3) - RESTART KERNEL FIRST
# ============================================
import sys
import os
import shutil

# Remove old paths
sys.path = [p for p in sys.path if 'bdh-code' not in p.lower()]

CODE_SRC = '/kaggle/input/bdh-code/BDH'
CODE_DST = '/kaggle/working/BDH'

# Fresh copy
if os.path.exists(CODE_DST):
    shutil.rmtree(CODE_DST)
shutil.copytree(CODE_SRC, CODE_DST)

# 1. Patch backstory_parser.py
bp_path = os.path.join(CODE_DST, 'narrative_reasoning', 'backstory_parser.py')
with open(bp_path, 'r') as f:
    content = f.read()
content = content.replace('import re\n', 'import re\nimport torch\n')
with open(bp_path, 'w') as f:
    f.write(content)

# 2. Patch modeling_bdh.py (Added missing get_state/update_state)
mbdh_path = os.path.join(CODE_DST, 'master_bdh', 'modeling_bdh.py')
with open(mbdh_path, 'r') as f:
    lines = f.readlines()

new_lines = []
skip = False
for line in lines:
    if 'from transformers.cache_utils import Cache, CacheLayerMixin' in line:
        new_lines.append('from transformers.cache_utils import Cache\n')
        new_lines.append('try:\n    from transformers.cache_utils import CacheLayerMixin\nexcept ImportError:\n    class CacheLayerMixin: pass\n')
        continue
    if 'class BDHCache(Cache):' in line:
        skip = True
        new_lines.append('''class BDHCache(Cache):
    def __init__(self, config, max_batch_size=None, max_cache_len=None, device=None, dtype=None):
        try: super().__init__()
        except: pass
        self.layers = []
        self.config = config
        self.dtype = dtype or torch.float32
        self._max_batch_size = max_batch_size
        self._max_cache_len = max_cache_len
        self._seen_tokens = 0
    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        while len(self.layers) <= layer_idx: self.layers.append(BDHCacheLayer())
        if layer_idx == 0: self._seen_tokens += key_states.shape[-2]
        return key_states, value_states
    def get_seq_length(self, layer_idx=0): return self._seen_tokens
    def detach_(self):
        for layer in self.layers:
            if hasattr(layer, 'recurrent_state') and layer.recurrent_state is not None:
                layer.recurrent_state = layer.recurrent_state.detach()
    def update_state(self, layer_idx, new_state):
        while len(self.layers) <= layer_idx: self.layers.append(BDHCacheLayer())
        self.layers[layer_idx].recurrent_state = new_state.to(torch.float32)
    def get_state(self, layer_idx):
        if layer_idx < len(self.layers): return self.layers[layer_idx].recurrent_state
        return None
    @property
    def max_batch_size(self): return self._max_batch_size
    @property
    def max_cache_len(self): return self._max_cache_len
''')
        continue
    if skip and (line.startswith('class ') or line.startswith('# Note')):
        skip = False
    if not skip: new_lines.append(line)

with open(mbdh_path, 'w') as f:
    f.writelines(new_lines)

# 3. Patch train_pipeline.py (Full file with Dataset class & OOM fix)
tp_path = os.path.join(CODE_DST, 'training', 'train_pipeline.py')
new_tp_content = r'''
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import Dict, Optional, List, Tuple
import os
from tqdm import tqdm
import gc

from master_bdh import BDHConfig, BDHForCausalLM
from narrative_reasoning.consistency_classifier import ConsistencyClassifier
from narrative_reasoning.representation_extractor import RepresentationExtractor
from narrative_reasoning.backstory_parser import BackstoryParser, BackstoryEmbedder
from training.pretrain_bdh import pretrain_bdh, NarrativeDataset

class ConsistencyDataset(Dataset):
    def __init__(self, narratives, backstories, labels, tokenizer, narrative_processor, backstory_parser, max_narrative_length=2048):
        self.narratives = narratives
        self.backstories = backstories
        self.labels = labels
        self.tokenizer = tokenizer
        self.narrative_processor = narrative_processor
        self.backstory_parser = backstory_parser
        self.max_narrative_length = max_narrative_length
        self.parsed_backstories = [self.backstory_parser.parse(b) for b in backstories]
    def __len__(self): return len(self.narratives)
    def __getitem__(self, idx):
        return {'narrative': self.narratives[idx], 'backstory': self.backstories[idx], 
                'backstory_claims': self.parsed_backstories[idx], 'label': self.labels[idx]}

def train_consistency_classifier(
    model, classifier, train_dataset, val_dataset=None,
    num_epochs=10, batch_size=4, learning_rate=1e-4, weight_decay=0.01,
    device=None, save_dir=None, use_bf16=True, freeze_bdh=True,
    tokenizer=None, narrative_processor=None
):
    if device is None: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache(); gc.collect()
    
    model = model.to(device).eval()
    classifier = classifier.to(device).train()
    
    if freeze_bdh:
        for p in model.parameters(): p.requires_grad = False
            
    base = train_dataset.dataset if hasattr(train_dataset, 'dataset') else train_dataset
    tokenizer = tokenizer or getattr(base, 'tokenizer', None)
    narrative_processor = narrative_processor or getattr(base, 'narrative_processor', None)
    
    def collate_fn(batch):
        return {
            'narrative': [b['narrative'] for b in batch],
            'backstory_claims': [b['backstory_claims'] for b in batch],
            'label': torch.tensor([b['label'] for b in batch], dtype=torch.long)
        }

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size, False, collate_fn=collate_fn, num_workers=0) if val_dataset else None
    
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    backstory_embedder = BackstoryEmbedder(tokenizer, model, device)
    
    best_acc = 0.0
    for epoch in range(num_epochs):
        classifier.train()
        total_loss, correct, total = 0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        
        for batch in pbar:
            labels = batch['label'].to(device)
            narr_reprs = []
            
            for narr in batch['narrative']:
                chunks, _ = narrative_processor.process_narrative(narr)
                if len(chunks)>200: chunks = chunks[:200]
                
                c_reps = []
                for c in chunks:
                    c = c.to(device)
                    if c.shape[-1]>2048: c = c[:, :2048]
                    with torch.no_grad():
                        out = model(input_ids=c, output_hidden_states=True)
                        c_reps.append(out.hidden_states[-1].mean(1).detach().cpu())
                    del out, c
                    torch.cuda.empty_cache()
                
                if c_reps: narr_reprs.append(torch.stack(c_reps).mean(0).to(device))
                else: narr_reprs.append(torch.zeros(model.config.hidden_size).to(device))
            
            back_reprs = [backstory_embedder.aggregate_claims(c) for c in batch['backstory_claims']]
            
            # Forward
            logits, _ = classifier(torch.stack(narr_reprs), torch.stack(back_reprs))
            loss = criterion(logits, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            correct += (logits.argmax(1) == labels).sum().item()
            total += len(labels)
            
            del logits, loss, narr_reprs, back_reprs
            pbar.set_postfix({'acc': correct/total})
            
        # Validation
        if val_loader:
            classifier.eval()
            v_correct, v_total = 0, 0
            with torch.no_grad():
                for batch in tqdm(val_loader, desc="Val"):
                    labels = batch['label'].to(device)
                    narr_reprs = []
                    for narr in batch['narrative']:
                        chunks, _ = narrative_processor.process_narrative(narr)
                        if len(chunks)>200: chunks = chunks[:200]
                        c_reps = [model(input_ids=c.to(device)[:,:2048], output_hidden_states=True).hidden_states[-1].mean(1).detach().cpu() for c in chunks]
                        narr_reprs.append(torch.stack(c_reps).mean(0).to(device) if c_reps else torch.zeros(model.config.hidden_size).to(device))
                    
                    back_reprs = [backstory_embedder.aggregate_claims(c) for c in batch['backstory_claims']]
                    logits, _ = classifier(torch.stack(narr_reprs), torch.stack(back_reprs))
                    v_correct += (logits.argmax(1)==labels).sum().item()
                    v_total += len(labels)
            
            val_acc = v_correct/v_total
            print(f"Epoch {epoch+1}: Train Acc={correct/total:.3f}, Val Acc={val_acc:.3f}")
            if val_acc > best_acc and save_dir:
                best_acc = val_acc
                os.makedirs(save_dir, exist_ok=True)
                torch.save(classifier.state_dict(), os.path.join(save_dir, 'best_classifier.pt'))
                
    return classifier, model
'''
with open(tp_path, 'w') as f:
    f.write(new_tp_content)

sys.path.insert(0, CODE_DST)
print("‚úÖ All patches applied successfully!")

‚úÖ All patches applied successfully!


In [4]:
# ============================================
# CELL 1: Setup and GPU Verification
# ============================================

import os
import sys
import gc
import torch
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Enable H100 optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Verify GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("=" * 60)
print("GPU VERIFICATION")
print("=" * 60)
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"bfloat16 Supported: {torch.cuda.is_bf16_supported()}")
else:
    print("WARNING: Running on CPU - will be very slow!")

# Set directories
WORK_DIR = Path("/kaggle/working")
INPUT_DIR = Path("/kaggle/input")
CHECKPOINT_DIR = WORK_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

print(f"\nWorking directory: {WORK_DIR}")
print(f"Input directory: {INPUT_DIR}")

GPU VERIFICATION
Device: cuda
GPU: NVIDIA H100 80GB HBM3
Memory: 85.3 GB
CUDA Version: 12.4
bfloat16 Supported: True

Working directory: /kaggle/working
Input directory: /kaggle/input


In [5]:
# -------------------------------------------------
# HARD GPU STATE RESET
# Purpose:
#   Ensure a clean GPU memory slate before
#   re-instantiating large models.
#
# Why this matters:
#   - Python keeps references alive longer than expected
#   - PyTorch does not immediately return memory to CUDA
#   - Reinitializing models without cleanup ‚Üí silent OOMs
#
# What this block does:
#   1. Explicitly deletes existing model objects (if present)
#   2. Flushes PyTorch's CUDA memory cache
#   3. Forces Python garbage collection to reclaim orphaned tensors
# -------------------------------------------------

# Remove existing model instance from global scope, if it exists
if 'model' in globals():
    print("üßπ Deleting existing model instance from memory")
    del model

# Remove existing classifier instance from global scope, if it exists
if 'classifier' in globals():
    print("üßπ Deleting existing classifier instance from memory")
    del classifier

# Release all cached CUDA memory held by PyTorch
print("üßπ Clearing CUDA memory cache")
torch.cuda.empty_cache()

# Force garbage collection to clean up dangling Python references
print("üßπ Running Python garbage collection")
gc.collect()

print("‚úÖ GPU memory reset complete")

üßπ Clearing CUDA memory cache
üßπ Running Python garbage collection
‚úÖ GPU memory reset complete


In [6]:
# # ============================================
# # CELL 2: Install Dependencies
# # ============================================

# # Uncomment these in Kaggle:
# !pip install -q transformers datasets accelerate
# !pip install -q scikit-learn matplotlib seaborn networkx

# import transformers
# print(f"Transformers version: {transformers.__version__}")

In [7]:
# ============================================
# CELL 3: Add Project to Path & Import
# ============================================
import sys
# For Kaggle, assuming code is uploaded as a dataset
# Adjust this path based on your dataset name
CODE_DATASET_NAME = "bdh-code/BDH"  # CHANGE THIS!
DATA_DATASET_NAME = "kdsh-data/Dataset_kdsh"  # CHANGE THIS!

code_path = INPUT_DIR / CODE_DATASET_NAME
data_path = INPUT_DIR / DATA_DATASET_NAME

# Add code to Python path
if code_path.exists():
    sys.path.insert(0, str(code_path))
    print(f"Added {code_path} to Python path")
else:
    # Fallback: code might be in working directory
    sys.path.insert(0, str(WORK_DIR))
    print(f"Code dataset not found at {code_path}")
    print("Please ensure your code is uploaded as a Kaggle dataset")

Added /kaggle/input/bdh-code/BDH to Python path


In [8]:
# ============================================
# CELL 4: Load Dataset
# ============================================

def load_and_prepare_data(data_dir: Path):
    """Load train/test CSVs and Books directory."""
    
    # Try to find the data
    possible_dirs = [
        data_dir,
        data_dir / "BDH",
        INPUT_DIR,
    ]
    
    train_df = None
    test_df = None
    books_dir = None
    
    for d in possible_dirs:
        train_path = d / "train.csv"
        test_path = d / "test.csv"
        books_path = d / "Books"
        
        if train_path.exists():
            train_df = pd.read_csv(train_path)
            print(f"Found train.csv at {train_path}")
        if test_path.exists():
            test_df = pd.read_csv(test_path)
            print(f"Found test.csv at {test_path}")
        if books_path.exists():
            books_dir = books_path
            print(f"Found Books/ at {books_path}")
    
    return train_df, test_df, books_dir

# Load data
train_df, test_df, books_dir = load_and_prepare_data(data_path)

if train_df is not None:
    print(f"\nTrain samples: {len(train_df)}")
    print(f"Columns: {train_df.columns.tolist()}")
    print("\nSample row:")
    print(train_df.head(1).to_dict('records')[0])

if test_df is not None:
    print(f"\nTest samples: {len(test_df)}")

if books_dir:
    book_files = list(books_dir.glob("*.txt"))
    print(f"\nBooks available: {len(book_files)}")
    print(f"Sample books: {[f.stem for f in book_files[:5]]}")

Found train.csv at /kaggle/input/kdsh-data/Dataset_kdsh/train.csv
Found test.csv at /kaggle/input/kdsh-data/Dataset_kdsh/test.csv
Found Books/ at /kaggle/input/kdsh-data/Dataset_kdsh/Books

Train samples: 80
Columns: ['id', 'book_name', 'char', 'caption', 'content', 'label']

Sample row:
{'id': 46, 'book_name': 'In Search of the Castaways', 'char': 'Thalcave', 'caption': nan, 'content': 'Thalcave‚Äôs people faded as colonists advanced; his father, last of the tribal guides, knew the pampas geography and animal ways, while his mother died giving birth. Boyhood was spent roaming the plains with his father, learning to track, tame horses and steer by the stars.', 'label': 'consistent'}

Test samples: 60

Books available: 2
Sample books: ['In search of the castaways', 'The Count of Monte Cristo']


In [3]:
# ============================================
# CELL 5: Import BDH Components
# ============================================

try:
    from master_bdh import BDHConfig, BDHForCausalLM
    from master_bdh.continual_learning import ContinualLearningWrapper
    from narrative_reasoning.narrative_processor import NarrativeProcessor
    from narrative_reasoning.backstory_parser import BackstoryParser, BackstoryEmbedder
    from narrative_reasoning.representation_extractor import RepresentationExtractor
    from narrative_reasoning.consistency_classifier import ConsistencyClassifier
    from training.pretrain_bdh import pretrain_bdh, NarrativeDataset
    from training.train_pipeline import train_consistency_classifier, ConsistencyDataset
    from utils.data_loader import (
        load_dataset_from_path,
        prepare_training_data,
        prepare_test_data
    )
    print("‚úÖ All BDH components imported successfully")
except ImportError as e:
    print(f"‚ùå Import error: {e}")
    print("Make sure code dataset path is correct")

2026-01-08 22:59:34.211487: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767913174.226552    2983 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767913174.231071    2983 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

‚úÖ All BDH components imported successfully


In [9]:
# ============================================
# CELL 6: Initialize Model and Tokenizer (FIXED)
# ============================================

import torch
import gc
from transformers import AutoTokenizer
from master_bdh import BDHConfig, BDHForCausalLM
from narrative_reasoning.consistency_classifier import ConsistencyClassifier

# ----------------------------
# Hard reset GPU state
# ----------------------------
if 'model' in globals():
    del model
if 'classifier' in globals():
    del classifier

torch.cuda.empty_cache()
gc.collect()

# ----------------------------
# Lean configuration (OOM-proof)
# ----------------------------
CONFIG = {
    'hidden_size': 256,                 # ‚Üì was 512
    'num_hidden_layers': 6,             # ‚Üì was 8
    'num_attention_heads': 4,           # ‚Üì was 8
    'mlp_internal_dim_multiplier': 64,  # ‚Üì was 256 (primary memory hog)
    'dropout': 0.1,
    'max_position_embeddings': 2048,    # realistic context
    'batch_size': 16,
    'use_bf16': True,
}

# ----------------------------
# Tokenizer
# ----------------------------
tokenizer_name = "bert-base-uncased"
print(f"Loading tokenizer: {tokenizer_name}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ----------------------------
# BDH Model
# ----------------------------
config = BDHConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=CONFIG['hidden_size'],
    num_hidden_layers=CONFIG['num_hidden_layers'],
    num_attention_heads=CONFIG['num_attention_heads'],
    mlp_internal_dim_multiplier=CONFIG['mlp_internal_dim_multiplier'],
    dropout=CONFIG['dropout'],
    max_position_embeddings=CONFIG['max_position_embeddings'],
    attn_implementation="bdh_recurrent"  # ‚Üì memory-efficient
)

print("Initializing BDH model...")
model = BDHForCausalLM(config)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

# ----------------------------
# Consistency Classifier
# ----------------------------
classifier = ConsistencyClassifier(
    narrative_dim=CONFIG['hidden_size'],
    backstory_dim=CONFIG['hidden_size'],
    hidden_dim=256,
    num_layers=3,
    dropout=0.1
)

# ----------------------------
# Move to device
# ----------------------------
model = model.to(device)
classifier = classifier.to(device)

print(f"‚úÖ Models successfully loaded on {device}")

Loading tokenizer: bert-base-uncased
Initializing BDH model...
Model parameters: 20,396,544
‚úÖ Models successfully loaded on cuda


In [10]:
# ============================================
# CELL 7: Prepare Training Data
# ============================================

# Load narratives from books
narratives, backstories, labels = prepare_training_data(
    train_df, 
    books_dir,
    verbose=True
)

print(f"\nNarrative lengths:")
print(f"  Min: {min(len(n) for n in narratives):,} chars")
print(f"  Max: {max(len(n) for n in narratives):,} chars")
print(f"  Avg: {sum(len(n) for n in narratives)//len(narratives):,} chars")

Loaded 80 training examples
Unique books: 2
Label distribution: 1=51, 0=29

Narrative lengths:
  Min: 826,131 chars
  Max: 2,646,614 chars
  Avg: 1,531,568 chars


In [11]:
# ============================================
# CELL 8: Pretrain BDH on Narratives
# ============================================

from torch.utils.data import random_split

# Create pretraining dataset
print("Creating pretraining dataset...")
pretrain_dataset = NarrativeDataset(
    narratives=narratives,
    tokenizer=tokenizer,
    max_length=2048
)

# Split for validation
train_size = int(0.9 * len(pretrain_dataset))
val_size = len(pretrain_dataset) - train_size
train_ds, val_ds = random_split(pretrain_dataset, [train_size, val_size])
print(f"Pretraining: {len(train_ds)} train, {len(val_ds)} val")

# Pretrain
print("\n" + "=" * 60)
print("STARTING BDH PRETRAINING")
print("=" * 60)

pretrain_bdh(
    model=model,
    train_dataset=train_ds,
    val_dataset=val_ds,
    num_epochs=3,
    batch_size=4,
    learning_rate=1e-4,
    weight_decay=0.1,
    grad_accum_steps=4,
    max_grad_norm=1.0,
    device=device,
    save_dir=str(CHECKPOINT_DIR / "pretrained"),
    use_bf16=CONFIG['use_bf16']
)
print("‚úÖ Pretraining complete!")

Creating pretraining dataset...
Pretraining: 72 train, 8 val

STARTING BDH PRETRAINING


Epoch 1/3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 18/18 [00:22<00:00,  1.26s/it, loss=9.82]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:03<00:00,  1.81s/it]


Epoch 1: Train Loss = 10.0823, Val Loss = 9.8158


Epoch 2/3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 18/18 [00:23<00:00,  1.28s/it, loss=9.58]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:03<00:00,  1.85s/it]


Epoch 2: Train Loss = 9.7002, Val Loss = 9.5581


Epoch 3/3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 18/18 [00:21<00:00,  1.20s/it, loss=9.35]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:03<00:00,  1.85s/it]


Epoch 3: Train Loss = 9.4662, Val Loss = 9.3354
‚úÖ Pretraining complete!


In [21]:
# ============================================
# HOTFIX 4: Redefine Training Function
# ============================================
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import gc
from narrative_reasoning.backstory_parser import BackstoryEmbedder

def train_consistency_classifier(
    model, classifier, train_dataset, val_dataset=None,
    num_epochs=10, batch_size=4, learning_rate=1e-4, weight_decay=0.01,
    device=None, save_dir=None, use_bf16=True, freeze_bdh=True,
    tokenizer=None, narrative_processor=None
):
    if device is None: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache(); gc.collect()
    
    model = model.to(device).eval()
    classifier = classifier.to(device).train()
    
    if freeze_bdh:
        for p in model.parameters(): p.requires_grad = False
            
    base = train_dataset.dataset if hasattr(train_dataset, 'dataset') else train_dataset
    tokenizer = tokenizer or getattr(base, 'tokenizer', None)
    narrative_processor = narrative_processor or getattr(base, 'narrative_processor', None)
    
    def collate_fn(batch):
        return {
            'narrative': [b['narrative'] for b in batch],
            'backstory_claims': [b['backstory_claims'] for b in batch],
            'label': torch.tensor([b['label'] for b in batch], dtype=torch.long)
        }

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size, False, collate_fn=collate_fn) if val_dataset else None
    
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    backstory_embedder = BackstoryEmbedder(tokenizer, model, device)
    
    best_acc = 0.0
    for epoch in range(num_epochs):
        classifier.train()
        total_loss, correct, total = 0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        
        for batch in pbar:
            labels = batch['label'].to(device)
            narr_reprs = []
            
            for narr in batch['narrative']:
                chunks, _ = narrative_processor.process_narrative(narr)
                if len(chunks)>200: chunks = chunks[:200]
                
                c_reps = []
                for c in chunks:
                    c = c.to(device)
                    if c.shape[-1]>2048: c = c[:, :2048]
                    with torch.no_grad():
                        out = model(input_ids=c, output_hidden_states=True)
                        # FIX: Squeeze the batch dimension [1, D] -> [D]
                        c_reps.append(out.hidden_states[-1].mean(1).detach().squeeze(0).cpu())
                    del out, c
                    torch.cuda.empty_cache()
                
                # Stack chunks [N, D] -> Mean [D]
                if c_reps: narr_reprs.append(torch.stack(c_reps).mean(0).to(device))
                else: narr_reprs.append(torch.zeros(model.config.hidden_size).to(device))
            
            back_reprs = [backstory_embedder.aggregate_claims(c) for c in batch['backstory_claims']]
            
            # Forward
            logits, _ = classifier(torch.stack(narr_reprs), torch.stack(back_reprs))
            loss = criterion(logits, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            correct += (logits.argmax(1) == labels).sum().item()
            total += len(labels)
            
            del logits, loss, narr_reprs, back_reprs
            pbar.set_postfix({'acc': correct/total})
            
        # Validation
        if val_loader:
            classifier.eval()
            v_correct, v_total = 0, 0
            with torch.no_grad():
                for batch in tqdm(val_loader, desc="Val"):
                    labels = batch['label'].to(device)
                    narr_reprs = []
                    for narr in batch['narrative']:
                        chunks, _ = narrative_processor.process_narrative(narr)
                        if len(chunks)>200: chunks = chunks[:200]
                        # FIX: Squeeze here too
                        c_reps = [model(input_ids=c.to(device)[:,:2048], output_hidden_states=True).hidden_states[-1].mean(1).squeeze(0).detach().cpu() for c in chunks]
                        narr_reprs.append(torch.stack(c_reps).mean(0).to(device) if c_reps else torch.zeros(model.config.hidden_size).to(device))
                    
                    back_reprs = [backstory_embedder.aggregate_claims(c) for c in batch['backstory_claims']]
                    logits, _ = classifier(torch.stack(narr_reprs), torch.stack(back_reprs))
                    v_correct += (logits.argmax(1)==labels).sum().item()
                    v_total += len(labels)
            
            val_acc = v_correct/v_total
            print(f"Epoch {epoch+1}: Train Acc={correct/total:.3f}, Val Acc={val_acc:.3f}")
            if val_acc > best_acc and save_dir:
                best_acc = val_acc
                os.makedirs(save_dir, exist_ok=True)
                torch.save(classifier.state_dict(), os.path.join(save_dir, 'best_classifier.pt'))
                
    return classifier, model

print("‚úÖ Training function patched!")

‚úÖ Training function patched!


In [22]:
# ============================================
# CELL 9: Train Consistency Classifier
# ============================================

# Initialize processors
narrative_processor = NarrativeProcessor(
    tokenizer=tokenizer,
    chunk_size=2048,
    overlap_size=256
)
backstory_parser = BackstoryParser()

# Fix: Override process_narrative to ensure proper chunking
_original_process = narrative_processor.process_narrative
def _fixed_process(text, *args, **kwargs):
    # Tokenize and chunk manually
    tokens = tokenizer(text, return_tensors='pt', truncation=False, max_length=None)['input_ids']
    chunks = []
    chunk_size = 2048
    for i in range(0, tokens.shape[1], chunk_size - 256):  # overlap
        chunk = tokens[:, i:i+chunk_size]
        if chunk.shape[1] > 0:
            chunks.append(chunk)
    if not chunks:
        chunks = [tokens[:, :512]]  # fallback
    return chunks, None
narrative_processor.process_narrative = _fixed_process

# Create consistency dataset
print("Creating consistency dataset...")
consistency_dataset = ConsistencyDataset(
    narratives=narratives,
    backstories=backstories,
    labels=labels,
    tokenizer=tokenizer,
    narrative_processor=narrative_processor,
    backstory_parser=backstory_parser,
    max_narrative_length=2048
)

# Split
train_size = int(0.8 * len(consistency_dataset))
val_size = len(consistency_dataset) - train_size
train_cons_ds, val_cons_ds = random_split(consistency_dataset, [train_size, val_size])
print(f"Classifier training: {len(train_cons_ds)} train, {len(val_cons_ds)} val")

# Train
print("\n" + "=" * 60)
print("STARTING CLASSIFIER TRAINING")
print("=" * 60)

# Quick fix - access base dataset from Subset
train_cons_ds.dataset.tokenizer = tokenizer
train_cons_ds.dataset.narrative_processor = narrative_processor

classifier, model = train_consistency_classifier(
    model=model,
    classifier=classifier,
    train_dataset=train_cons_ds,
    val_dataset=val_cons_ds,
    num_epochs=10,
    batch_size=4,
    learning_rate=1e-4,
    weight_decay=0.01,
    device=device,
    save_dir=str(CHECKPOINT_DIR / "classifier"),
    use_bf16=CONFIG['use_bf16'],
    freeze_bdh=False
)
print("‚úÖ Classifier training complete!")

Creating consistency dataset...
Classifier training: 64 train, 16 val

STARTING CLASSIFIER TRAINING


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:23<00:00, 16.46s/it, acc=0.562]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:03<00:00, 15.89s/it]


Epoch 1: Train Acc=0.562, Val Acc=0.750


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:23<00:00, 16.50s/it, acc=0.594]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:04<00:00, 16.01s/it]


Epoch 2: Train Acc=0.594, Val Acc=0.750


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:25<00:00, 16.57s/it, acc=0.609]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:03<00:00, 15.94s/it]


Epoch 3: Train Acc=0.609, Val Acc=0.750


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:24<00:00, 16.51s/it, acc=0.562]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:04<00:00, 16.00s/it]


Epoch 4: Train Acc=0.562, Val Acc=0.688


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:24<00:00, 16.55s/it, acc=0.641]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:03<00:00, 15.84s/it]


Epoch 5: Train Acc=0.641, Val Acc=0.750


Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:20<00:00, 16.29s/it, acc=0.609]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:03<00:00, 15.83s/it]


Epoch 6: Train Acc=0.609, Val Acc=0.750


Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:17<00:00, 16.11s/it, acc=0.594]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:01<00:00, 15.47s/it]


Epoch 7: Train Acc=0.594, Val Acc=0.750


Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:15<00:00, 15.98s/it, acc=0.641]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:01<00:00, 15.43s/it]


Epoch 8: Train Acc=0.641, Val Acc=0.500


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:08<00:00, 15.54s/it, acc=0.609]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:02<00:00, 15.52s/it]


Epoch 9: Train Acc=0.609, Val Acc=0.750


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [04:13<00:00, 15.83s/it, acc=0.625]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [01:02<00:00, 15.68s/it]

Epoch 10: Train Acc=0.625, Val Acc=0.750
‚úÖ Classifier training complete!





In [28]:
# ============================================
# CELL 10: Run Inference on Test Set (FIXED)
# ============================================

import gc
from tqdm import tqdm

# Clear GPU memory from training
torch.cuda.empty_cache()
gc.collect()

# Load best models
checkpoint_path = CHECKPOINT_DIR / "classifier" / "best_classifier.pt"
if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and 'classifier_state_dict' in checkpoint:
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
    else:
        classifier.load_state_dict(checkpoint)
    print("Loaded best checkpoint")

model.eval()
classifier.eval()

# Load test data
test_data = prepare_test_data(test_df, books_dir, verbose=True)

print(f"\nProcessing {len(test_data)} test examples...")
results = []

for example in tqdm(test_data, desc="Inference"):
    try:
        if not example.get('narrative'):
            results.append({'id': example['id'], 'Prediction': 0, 'Rationale': "No narrative"})
            continue
        
        # Process narrative in chunks
        narrative_processor.chunk_size = 512  # Smaller chunks to save memory
        chunks, _ = narrative_processor.process_narrative(example['narrative'])
        
        if len(chunks) > 100:
            chunks = chunks[:100]  # Limit chunks
        
        # Extract narrative representation
        chunk_reprs = []
        with torch.no_grad():
            for chunk_ids in chunks:
                chunk_ids = chunk_ids.to(device)
                if chunk_ids.dim() == 1:
                    chunk_ids = chunk_ids.unsqueeze(0)
                if chunk_ids.shape[-1] > 2048:
                    chunk_ids = chunk_ids[:, :2048]
                
                outputs = model(input_ids=chunk_ids, output_hidden_states=True)
                hidden = outputs.hidden_states[-1]
                rep = hidden.mean(dim=1).squeeze(0)
                chunk_reprs.append(rep.cpu())
                
                del outputs, hidden
                torch.cuda.empty_cache()
        
        if chunk_reprs:
            narr_repr = torch.stack(chunk_reprs).max(dim=0)[0].unsqueeze(0).to(device)
        else:
            narr_repr = torch.zeros(1, model.config.hidden_size, device=device)
        
        # Extract backstory representation
        back_tokens = tokenizer(example['backstory'], return_tensors='pt', 
                                truncation=True, max_length=512)['input_ids'].to(device)
        with torch.no_grad():
            back_out = model(input_ids=back_tokens, output_hidden_states=True)
            back_repr = back_out.hidden_states[-1].mean(dim=1)
        
        # Classify
        with torch.no_grad():
            logits, probs = classifier(narr_repr, back_repr)
            pred = logits.argmax(dim=-1).item()
            conf = probs.max().item()
        
        results.append({
            'id': example['id'],
            'Prediction': pred,
            'Rationale': f"Confidence: {conf:.2f}"
        })
        
        del narr_repr, back_repr, logits, probs
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"Error {example['id']}: {e}")
        results.append({'id': example['id'], 'Prediction': 0, 'Rationale': f"Error: {str(e)[:50]}"})
        torch.cuda.empty_cache()

# Save results
results_df = pd.DataFrame(results)
results_df.to_csv(WORK_DIR / "results.csv", index=False)
print(f"\n‚úÖ Saved {len(results)} results")
print(results_df['Prediction'].value_counts())

Loaded best checkpoint
Loaded 60 test examples
Unique books: 2

Processing 60 test examples...


Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 60/60 [03:16<00:00,  3.28s/it]


‚úÖ Saved 60 results
Prediction
1    60
Name: count, dtype: int64





In [25]:
# # Quick fix - inspect the checkpoint keys
# checkpoint = torch.load(checkpoint_path, map_location=device)
# print(checkpoint.keys())  # See what keys exist

odict_keys(['classifier.0.weight', 'classifier.0.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.5.weight', 'classifier.5.bias', 'classifier.8.weight', 'classifier.8.bias', 'classifier.9.weight', 'classifier.9.bias', 'classifier.12.weight', 'classifier.12.bias'])


In [29]:
# ============================================
# CELL 11: Download Results
# ============================================

# Results are at /kaggle/working/results.csv
print(f"\n{'='*60}")
print("DONE! Results saved to: /kaggle/working/results.csv")
print("Download from Kaggle notebook output")
print(f"{'='*60}")

# Display first few results
print("\nFirst 10 predictions:")
print(results_df.head(10))


DONE! Results saved to: /kaggle/working/results.csv
Download from Kaggle notebook output

First 10 predictions:
    id  Prediction         Rationale
0   95           1  Confidence: 0.56
1  136           1  Confidence: 0.55
2   59           1  Confidence: 0.57
3   60           1  Confidence: 0.57
4  124           1  Confidence: 0.54
5  111           1  Confidence: 0.53
6  135           1  Confidence: 0.54
7   27           1  Confidence: 0.58
8  110           1  Confidence: 0.56
9   42           1  Confidence: 0.57


In [None]:
# That comment is a note from the original BDH authors about potential numerical instability in linear attention. They're saying the recurrent sum computation could overflow/underflow for very long sequences. It's a known limitation, not something broken - just a caveat for extreme use cases.

# About the loss (9.4) - it's actually not terrible for pretraining! Here's context:

# Loss ~9.4 ‚âà perplexity ~12,000 - sounds bad but...
# You're training a small model (20M params) on huge narratives (800K-2.6M chars each)
# Only 72 training samples for 3 epochs
# The vocab is ~30K tokens, so random guessing would give loss ~10.3
# The loss dropped (9.4 ‚Üí 9.3 val) which means it IS learning.

# To improve:

# Train more epochs (10-20)
# Use smaller sequence chunks in 
# NarrativeDataset
#  (512 instead of 2048)
# More data would help most
# But for your task (consistency classification), the pretraining loss doesn't need to be great - you just need the model to learn narrative representations. The classifier training is what matters most for accuracy on the final task.

# Continue to classifier training and see how the classification accuracy looks - that's your actual metric.
