<a href="https://colab.research.google.com/github/sivaratrisrinivas/ttt-playground/blob/main/notebooks/all_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TTT Playground - All Tests

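Combined notebook for all phase tests. Sections:
1. **Setup** - Clone, install, verify GPU
2. **Phase 2** - Document Processing (PDF, Chunker, Validator)
3. **Phase 3** - TTT-Linear Layer
4. **Phase 4** - TinyLlama Integration
5. **Phase 5** - LaCT (Large Chunk TTT)
6. **Phase 6** - Learning Pipeline
7. **Phase 7** - Inference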

---
# 1. Setup

In [39]:
# Clone repo (or pull latest if exists)
import os
if os.path.exists('/content/ttt-playground'):
    !cd /content/ttt-playground && git pull
    %cd /content/ttt-playground
else:
    !git clone https://github.com/sivaratrisrinivas/ttt-playground.git
    %cd ttt-playground

# IMPORTANT: if this runtime previously imported src.*, force reload after git pull
import importlib
import sys
importlib.invalidate_caches()
for _m in [m for m in list(sys.modules.keys()) if m == 'src' or m.startswith('src.')]:
    del sys.modules[_m]
print('✓ Cleared cached src.* modules')

from pathlib import Path
sys.path.insert(0, str(Path.cwd()))
print(f"✓ Working directory: {os.getcwd()}")

remote: Enumerating objects: 15, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (2/2), done.
remote: Total 8 (delta 5), reused 8 (delta 5), pack-reused 0 (from 0)
Unpacking objects:  25% (2/8)
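
The cache-clearing step in the cell above can be wrapped in a small helper. The following is a minimal sketch, assuming a hypothetical function name `purge_modules` (it is not part of the repo). It is demonstrated on throwaway module names instead of the real `src` package.

```python
import importlib
import sys
import types


def purge_modules(prefix: str) -> list:
    """Drop `prefix` and every `prefix.*` entry from sys.modules.

    Hypothetical helper for illustration; not part of ttt-playground.
    """
    doomed = [name for name in list(sys.modules)
              if name == prefix or name.startswith(prefix + ".")]
    for name in doomed:
        del sys.modules[name]
    # Make the import system re-scan the filesystem on the next import
    importlib.invalidate_caches()
    return sorted(doomed)


# Register fake modules so the demo does not touch real packages
sys.modules["fakepkg"] = types.ModuleType("fakepkg")
sys.modules["fakepkg.sub"] = types.ModuleType("fakepkg.sub")
sys.modules["fakepkg_other"] = types.ModuleType("fakepkg_other")

removed = purge_modules("fakepkg")
print(removed)
```

Matching on `prefix + "."` rather than a bare `startswith(prefix)` keeps unrelated packages that merely share a name prefix (here `fakepkg_other`) in the cache.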

In [40]:
# Install dependencies
!pip install -q -r requirements.txt
print("✓ Dependencies installed")

✓ Dependencies installed


In [41]:
# Verify GPU
!nvidia-smi
import torch
print(f"\nCUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Sun Jan 11 23:44:19 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   64C    P0             30W /   70W |    7490MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [42]:
# Verify all imports
import torch
import transformers
import fitz  # PyMuPDF
import gradio
import tiktoken
import tqdm
from loguru import logger
import pydantic
print("✓ All imports successful!")

✓ All imports successful!


---
# 2. Phase 2: Document Processing

In [43]:
# Import document processing modules
from src.document.pdf_parser import PDFParser, PDFExtractionError
from src.document.chunker import DocumentChunker
from src.document.validator import DocumentValidator
from src.config import DocumentConstraints, DocumentChunk
from transformers import AutoTokenizer
import fitz
print("✓ Document processing imports successful")

✓ Document processing imports successful


## 2.1 Generate Test PDFs

In [44]:
def create_test_pdf(filename: str, num_pages: int, text_per_page: str):
    """Create a test PDF with specified pages and text"""
    doc = fitz.open()
    for i in range(num_pages):
        page = doc.new_page()
        page.insert_text((50, 50), f"Page {i+1}")
        page.insert_text((50, 100), text_per_page)
    doc.save(filename)
    doc.close()
    print(f"Created {filename} ({num_pages} pages)")

text_short = "This is a short test document. " * 50
create_test_pdf("test_short.pdf", 3, text_short)

text_medium = "This is a medium test document with more content. " * 100
create_test_pdf("test_medium.pdf", 20, text_medium)

with open("test_corrupt.pdf", "wb") as f:
    f.write(b"not a valid pdf file")

print("\n✓ Test PDFs created")

Created test_short.pdf (3 pages)
Created test_medium.pdf (20 pages)

✓ Test PDFs created


## 2.2-2.3 PDFParser Tests

In [45]:
parser = PDFParser()

# Test valid PDF
with open("test_short.pdf", "rb") as f:
    pdf_bytes = f.read()

text, page_count = parser.parse(pdf_bytes)
print(f"✓ Parsed test_short.pdf:")
print(f"  - Pages: {page_count}")
print(f"  - Text length: {len(text)} chars")
assert page_count > 0 and len(text) > 0

# Test error handling
try:
    with open("test_corrupt.pdf", "rb") as f:
        parser.parse(f.read())
    assert False, "Should have raised PDFExtractionError"
except PDFExtractionError:
    print("✓ Error handling works")

✓ Parsed test_short.pdf:
  - Pages: 3
  - Text length: 374 chars
✓ Error handling works


## 2.4-2.6 DocumentChunker Tests

In [46]:
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
chunker = DocumentChunker(tokenizer, chunk_size=2048)
print(f"✓ Chunker initialized with chunk_size={chunker.chunk_size}")

# Test short text (single chunk)
short_text = "This is a short text. " * 10
chunks_short = chunker.chunk(short_text)
print(f"✓ Short text: {len(chunks_short)} chunk(s)")

# Test large text (multiple chunks)
large_text = "word " * 5000
chunks_large = chunker.chunk(large_text)
print(f"✓ Large text (~5000 tokens): {len(chunks_large)} chunks")
for i, chunk in enumerate(chunks_large):
    assert chunk.token_count <= 2048, f"Chunk {i} exceeds limit"

# Verify token preservation
original_ids = tokenizer.encode(large_text, add_special_tokens=False)
reconstructed_ids = []
for chunk in chunks_large:
    reconstructed_ids.extend(chunk.token_ids)
assert reconstructed_ids == original_ids, "Token preservation failed!"
print(f"✓ Token preservation verified: {len(original_ids)} tokens")

Token indices sequence length is longer than the specified maximum sequence length for this model (5001 > 2048). Running this sequence through the model will result in indexing errors


✓ Chunker initialized with chunk_size=2048
✓ Short text: 1 chunk(s)
✓ Large text (~5000 tokens): 3 chunks
✓ Token preservation verified: 5001 tokens
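
The token-preservation property checked above can be illustrated without a tokenizer. This is a minimal sketch, assuming the chunker splits on fixed-size, non-overlapping token windows; `Chunk` and `chunk_token_ids` are hypothetical names, and the real `DocumentChunker` may choose boundaries differently.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Chunk:
    """Hypothetical stand-in for DocumentChunk, holding token IDs only."""
    index: int
    token_ids: List[int]

    @property
    def token_count(self) -> int:
        return len(self.token_ids)


def chunk_token_ids(token_ids: List[int], chunk_size: int) -> List[Chunk]:
    """Split token IDs into consecutive windows of at most chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [
        Chunk(index=i, token_ids=token_ids[start:start + chunk_size])
        for i, start in enumerate(range(0, len(token_ids), chunk_size))
    ]


# 5001 fake token IDs, mirroring the token count reported in the test above
token_ids = list(range(5001))
chunks = chunk_token_ids(token_ids, chunk_size=2048)

# Concatenating the chunks must give back the original sequence
rebuilt = [t for c in chunks for t in c.token_ids]
print(len(chunks), [c.token_count for c in chunks])
```

With 5001 tokens and a window of 2048, this yields three chunks, which is consistent with the chunk count printed by the test.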


## 2.7 DocumentValidator Tests

In [47]:
validator = DocumentValidator()

with open("test_short.pdf", "rb") as f:
    pdf_bytes = f.read()

# Test valid (relaxed constraints)
is_valid, _ = validator.validate(pdf_bytes, DocumentConstraints(min_tokens=50))
assert is_valid, "Should pass relaxed validation"
print("✓ Valid PDF passes")

# Test max_pages violation
is_valid, msg = validator.validate(pdf_bytes, DocumentConstraints(max_pages=2, min_tokens=50))
assert not is_valid
print(f"✓ max_pages violation detected: {msg}")

# Test min_tokens violation
is_valid, msg = validator.validate(pdf_bytes, DocumentConstraints(min_tokens=500))
assert not is_valid
print(f"✓ min_tokens violation detected: {msg}")

# Test corrupt PDF
with open("test_corrupt.pdf", "rb") as f:
    is_valid, msg = validator.validate(f.read(), DocumentConstraints(min_tokens=50))
assert not is_valid
print(f"✓ Corrupt PDF rejected: {msg}")

print("\n" + "="*50)
print("✓ ALL PHASE 2 TESTS PASSED!")
print("="*50)

✓ Valid PDF passes
✓ max_pages violation detected: Page count (3) exceeds maximum (2)
✓ min_tokens violation detected: Estimated token count (93) below minimum (500)
✓ Corrupt PDF rejected: Invalid PDF: Failed to extract text from PDF: Failed to open stream

✓ ALL PHASE 2 TESTS PASSED!
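
The `(is_valid, message)` contract exercised above can be sketched as a pure function over already-extracted counts. This is only an illustration: `validate_counts` is a hypothetical name, the order of the checks and the empty message on success are assumptions, and the `max_pages=10` used in the demo is an arbitrary value, not the repo's default. The message formats are copied from the test output.

```python
from typing import Tuple


def validate_counts(page_count: int, estimated_tokens: int,
                    max_pages: int, min_tokens: int) -> Tuple[bool, str]:
    """Return (is_valid, message); message is empty when valid (assumed)."""
    if page_count > max_pages:
        return False, f"Page count ({page_count}) exceeds maximum ({max_pages})"
    if estimated_tokens < min_tokens:
        return False, (f"Estimated token count ({estimated_tokens}) "
                       f"below minimum ({min_tokens})")
    return True, ""


# Same counts as test_short.pdf above: 3 pages, an estimated 93 tokens
ok = validate_counts(3, 93, max_pages=10, min_tokens=50)
too_long = validate_counts(3, 93, max_pages=2, min_tokens=50)
too_short = validate_counts(3, 93, max_pages=10, min_tokens=500)
print(ok, too_long, too_short, sep="\n")
```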


---
# 3. Phase 3: TTT-Linear Layer

## 3.1 Import models package

In [48]:
from src.models import *
print("✓ Step 3.1: from src.models import * succeeds")

✓ Step 3.1: from src.models import * succeeds


## 3.2 TTTLinear.__init__

In [49]:
import importlib
import src.models.ttt_linear as _ttt_linear
# Reload first, then import the class, so TTTLinear refers to the reloaded definition
importlib.reload(_ttt_linear)
from src.models.ttt_linear import TTTLinear

layer = TTTLinear(768, 2048, 768)
print(f"✓ TTTLinear instantiated")
print(f"  W_h.shape: {layer.W_h.shape}")
assert layer.W_h.shape == (2048, 768)
print("✓ Step 3.2: W_h.shape == (2048, 768) verified")

✓ TTTLinear instantiated
  W_h.shape: torch.Size([2048, 768])
✓ Step 3.2: W_h.shape == (2048, 768) verified


## 3.3 TTTLinear.forward (inference mode)

In [50]:
import torch

x = torch.randn(1, 128, 768)
y = layer(x, learning=False)
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {y.shape}")
assert y.shape == (1, 128, 768)
print("✓ Step 3.3: Output shape [1, 128, 768] verified")

  Input shape: torch.Size([1, 128, 768])
  Output shape: torch.Size([1, 128, 768])
✓ Step 3.3: Output shape [1, 128, 768] verified


## 3.4 Initial weights stored for reset

In [51]:
assert hasattr(layer, '_W_h_initial'), "Missing _W_h_initial attribute"
assert torch.allclose(layer.W_h, layer._W_h_initial)
print("✓ Step 3.4: _W_h_initial stored and matches W_h")

✓ Step 3.4: _W_h_initial stored and matches W_h


## 3.5 TTTLinear.forward (learning mode)

In [52]:
layer = TTTLinear(768, 2048, 768)
w_before = layer.W_h.clone()

x = torch.randn(1, 128, 768)
y = layer(x, learning=True)

assert not torch.allclose(layer.W_h, w_before), "W_h should change after learning=True"
print("✓ Step 3.5: W_h differs from initial after learning=True")

✓ Step 3.5: W_h differs from initial after learning=True


## 3.6 reset_weights()

In [53]:
layer.reset_weights()
assert torch.allclose(layer.W_h, layer._W_h_initial)
print("✓ Step 3.6: reset_weights() restores initial W_h")

✓ Step 3.6: reset_weights() restores initial W_h


## 3.7 get_weight_delta()

In [54]:
layer = TTTLinear(768, 2048, 768)
x = torch.randn(1, 128, 768)
layer(x, learning=True)

delta = layer.get_weight_delta()
print(f"  Weight delta: {delta}")
assert delta > 0
print("✓ Step 3.7: get_weight_delta() > 0 after learning")

  Weight delta: 0.06187882274389267
✓ Step 3.7: get_weight_delta() > 0 after learning
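
Steps 3.4 to 3.7 all rely on the same bookkeeping: keep a copy of the initial weights, let the live weights drift, and measure or undo the drift. Below is a dependency-free sketch of that pattern. `TinyFastWeights` is a hypothetical class, and treating the delta as a Frobenius norm is an assumption; the actual `TTTLinear` may compute it differently.

```python
import math


class TinyFastWeights:
    """Hypothetical illustration of snapshot / update / reset / delta."""

    def __init__(self, weights):
        self.W = [row[:] for row in weights]
        # Independent copy, so later updates to W never touch the snapshot
        self._W_initial = [row[:] for row in weights]

    def step(self, grads, lr):
        """Plain gradient-descent update on the live weights."""
        for i, row in enumerate(grads):
            for j, g in enumerate(row):
                self.W[i][j] -= lr * g

    def reset_weights(self):
        self.W = [row[:] for row in self._W_initial]

    def get_weight_delta(self):
        """Frobenius norm of (W - W_initial); the norm choice is assumed."""
        return math.sqrt(sum(
            (w - w0) ** 2
            for row, row0 in zip(self.W, self._W_initial)
            for w, w0 in zip(row, row0)
        ))


fw = TinyFastWeights([[1.0, 2.0], [3.0, 4.0]])
delta_initial = fw.get_weight_delta()

fw.step([[3.0, 0.0], [0.0, 4.0]], lr=1.0)
delta_learned = fw.get_weight_delta()

fw.reset_weights()
delta_reset = fw.get_weight_delta()
print(delta_initial, delta_learned, delta_reset)
```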


## 3.8 Gradient flow

In [55]:
layer = TTTLinear(768, 2048, 768)
x = torch.randn(1, 128, 768, requires_grad=True)
y = layer(x, learning=False)
loss = y.sum()
loss.backward()

assert x.grad is not None
print("✓ Step 3.8: Gradient flows through layer")

print("\n" + "="*50)
print("✓ ALL PHASE 3 TESTS PASSED!")
print("="*50)

✓ Step 3.8: Gradient flows through layer

✓ ALL PHASE 3 TESTS PASSED!


---
# 4. Phase 4: TinyLlama Integration

## 4.1 TTTModel class skeleton

In [56]:
from src.models.ttt_model import TTTModel
print('✓ Step 4.1: TTTModel class imports successfully')

✓ Step 4.1: TTTModel class imports successfully


## 4.2 TTTModel.from_pretrained() - load TinyLlama

In [57]:
# Load TinyLlama with TTT layers
model = TTTModel.from_pretrained(
    model_name='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
    device='cuda'
)
print('✓ Step 4.2: TTTModel loaded')

# Test generate
output = model.generate('Hello', max_new_tokens=20)
print(f'  Generated: {output[:100]}...')

✓ Step 4.2: TTTModel loaded
  Generated: Hello futureailableailableailable future future futurereetpreviewielleadowspreviewielleadowspreviewi...


## 4.3 Identify MLP layers in TinyLlama

In [58]:
# Print MLP layer names
print('MLP layers in TinyLlama:')
for name, module in model.model.named_modules():
    if 'mlp' in name.lower():
        print(f'  {name}: {type(module).__name__}')
print('✓ Step 4.3: MLP layers identified')

MLP layers in TinyLlama:
  model.layers.0.mlp: TTTLinear
  model.layers.0.mlp.W_out: Linear
  model.layers.0.mlp.activation: SiLU
  model.layers.1.mlp: TTTLinear
  model.layers.1.mlp.W_out: Linear
  model.layers.1.mlp.activation: SiLU
  model.layers.2.mlp: TTTLinear
  model.layers.2.mlp.W_out: Linear
  model.layers.2.mlp.activation: SiLU
  model.layers.3.mlp: TTTLinear
  model.layers.3.mlp.W_out: Linear
  model.layers.3.mlp.activation: SiLU
  model.layers.4.mlp: TTTLinear
  model.layers.4.mlp.W_out: Linear
  model.layers.4.mlp.activation: SiLU
  model.layers.5.mlp: TTTLinear
  model.layers.5.mlp.W_out: Linear
  model.layers.5.mlp.activation: SiLU
  model.layers.6.mlp: TTTLinear
  model.layers.6.mlp.W_out: Linear
  model.layers.6.mlp.activation: SiLU
  model.layers.7.mlp: TTTLinear
  model.layers.7.mlp.W_out: Linear
  model.layers.7.mlp.activation: SiLU
  model.layers.8.mlp: TTTLinear
  model.layers.8.mlp.W_out: Linear
  model.layers.8.mlp.activation: SiLU
  model.layers.9.mlp: TTTLinea

## 4.4-4.5 Replace MLP with TTTLinear

In [59]:
# Check that TTT layers were installed
print(f'Number of TTT layers: {len(model.ttt_layers)}')
assert len(model.ttt_layers) > 0, 'Should have TTT layers'
print('✓ Step 4.4-4.5: TTT layers installed')

# Test forward pass still works
output2 = model.generate('The capital of France is', max_new_tokens=10)
print(f'  Generated: {output2}')

Number of TTT layers: 22
✓ Step 4.4-4.5: TTT layers installed
  Generated: The capital of France isirieведеuclideʻuclidetransformictionaryсь️ictionary


## 4.6 All MLP layers replaced

In [60]:
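# TinyLlama has 22 transformer layers; every one should now carry a TTT layer
num_layers = len(model.ttt_layers)
print(f'TTT layers: {num_layers}')
expected_layers = model.model.config.num_hidden_layers
assert num_layers == expected_layers, f'Expected {expected_layers} TTT layers, got {num_layers}'
print('✓ Step 4.6: TTT layer count verified')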

TTT layers: 22
✓ Step 4.6: TTT layer count verified


## 4.7 reset_learning()

In [61]:
# Get initial delta (should be 0)
delta_before = model.get_total_weight_delta()
print(f'  Delta before learning: {delta_before}')

# No learning step runs in this cell (that would normally happen via learn_from_chunks),
# so this only checks that reset_learning() leaves the weight delta at 0
model.reset_learning()
delta_after_reset = model.get_total_weight_delta()
print(f'  Delta after reset: {delta_after_reset}')
assert delta_after_reset == 0, 'Delta should be 0 after reset'
print('✓ Step 4.7: reset_learning() works')

  Delta before learning: 0.0
  Delta after reset: 0.0
✓ Step 4.7: reset_learning() works


## 4.8 clear_context()

In [62]:
# Generate, clear, generate again
out1 = model.generate('Test', max_new_tokens=5)
model.clear_context()
out2 = model.generate('Test', max_new_tokens=5)
print(f'  Before clear: {out1}')
print(f'  After clear: {out2}')
print('✓ Step 4.8: clear_context() works')

print('\n' + '='*50)
print('✓ ALL PHASE 4 TESTS PASSED!')
print('='*50)

  Before clear: Test terminalitaireútirie transformations
  After clear: Test terminalitaireútirie transformations
✓ Step 4.8: clear_context() works

✓ ALL PHASE 4 TESTS PASSED!


---
# 5. Phase 5: LaCT (Large Chunk TTT)

## 5.1 LaCTUpdater class skeleton

In [63]:
from src.models.lact import LaCTUpdater
print('✓ Step 5.1: LaCTUpdater class imports successfully')

✓ Step 5.1: LaCTUpdater class imports successfully


## 5.2 process_chunk() - forward + loss

In [64]:
# Create updater
updater = LaCTUpdater(model)

# Create dummy tokens (2048 tokens)
dummy_tokens = list(range(100, 2148))  # 2048 token IDs

# Process chunk
loss = updater.process_chunk(dummy_tokens)
print(f'  Loss: {loss}')
assert isinstance(loss, float), 'Loss should be a float'
print('✓ Step 5.2: process_chunk() returns scalar loss')

  Loss: 12.58847713470459
✓ Step 5.2: process_chunk() returns scalar loss


## 5.3 Gradient accumulation

In [65]:
# Check accumulated gradients exist
has_grads = any(g is not None for g in updater._accumulated_grads)
print(f'  Has accumulated grads: {has_grads}')
assert has_grads, 'Should have accumulated gradients after process_chunk'
print('✓ Step 5.3: Gradients accumulated')

  Has accumulated grads: True
✓ Step 5.3: Gradients accumulated


## 5.4 apply_update()

In [66]:
# Reset model first
model.reset_learning()
delta_before = model.get_total_weight_delta()
print(f'  Delta before update: {delta_before}')

# Process and apply
updater.reset()
loss = updater.process_chunk(dummy_tokens)
updater.apply_update()

delta_after = model.get_total_weight_delta()
print(f'  Delta after update: {delta_after}')
assert delta_after > 0, 'Weight delta should be > 0 after update'
print('✓ Step 5.4: apply_update() changes weights')

  Delta before update: 0.0
  Delta after update: 0.01703643798828125
✓ Step 5.4: apply_update() changes weights
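
Steps 5.2 to 5.4 follow an accumulate-then-apply pattern: `process_chunk()` computes a loss and stores gradients, and `apply_update()` changes the weights later. The toy below sketches that control flow on a single scalar weight. `ChunkUpdater`, the squared-error loss, and the plain gradient-descent step are all assumptions for illustration, not the `LaCTUpdater` implementation.

```python
class ChunkUpdater:
    """Hypothetical scalar version of accumulate-then-apply."""

    def __init__(self, weight=0.0, lr=0.1):
        self.weight = weight
        self.lr = lr
        self._accumulated_grad = 0.0
        self._chunks_seen = 0

    def process_chunk(self, samples):
        """Return mean squared error on the chunk and accumulate its gradient."""
        n = len(samples)
        errors = [self.weight * x - y for x, y in samples]
        loss = sum(e * e for e in errors) / n
        grad = sum(2.0 * e * x for e, (x, _) in zip(errors, samples)) / n
        self._accumulated_grad += grad
        self._chunks_seen += 1
        return loss

    def apply_update(self):
        """Apply one step using the mean of the accumulated gradients."""
        if self._chunks_seen == 0:
            return
        self.weight -= self.lr * self._accumulated_grad / self._chunks_seen
        self.reset()

    def reset(self):
        self._accumulated_grad = 0.0
        self._chunks_seen = 0


samples = [(1.0, 2.0), (2.0, 4.0)]  # generated by y = 2 * x
toy_updater = ChunkUpdater(weight=0.0, lr=0.1)

loss_first = toy_updater.process_chunk(samples)   # weights unchanged so far
weight_before_apply = toy_updater.weight
toy_updater.apply_update()                        # weights move only here
loss_second = toy_updater.process_chunk(samples)
print(loss_first, toy_updater.weight, loss_second)
```

Separating the two calls means the weights stay fixed while a chunk is being scored, which is why the test above sees a delta of 0.0 until `apply_update()` runs.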


## 5.5 process_document()

In [67]:
from src.config import DocumentChunk

# Reset model
model.reset_learning()
updater.reset()

# Create 5 dummy chunks
chunks = []
for i in range(5):
    chunk = DocumentChunk(
        index=i,
        text=f'Chunk {i} ' * 100,
        token_ids=list(range(100 + i*500, 600 + i*500)),
        token_count=500
    )
    chunks.append(chunk)

# Process document
metrics = updater.process_document(chunks)
print(f'  Metrics: {metrics}')

# Verify one loss value was recorded per chunk (a decrease is not asserted here)
loss_history = updater.get_loss_history()
print(f'  Loss history: {loss_history}')
assert len(loss_history) == 5, 'Should have 5 loss values'

# Check weights changed
delta = model.get_total_weight_delta()
print(f'  Final weight delta: {delta}')
assert delta > 0, 'Weights should have changed'

print('✓ Step 5.5: process_document() works')

print('\n' + '='*50)
print('✓ ALL PHASE 5 TESTS PASSED!')
print('='*50)

  Metrics: {'initial_loss': 12.49267578125, 'final_loss': 12.023128509521484, 'total_chunks': 5}
  Loss history: [12.49267578125, 12.994003295898438, 12.408771514892578, 11.982575416564941, 12.023128509521484]
  Final weight delta: 0.5355224609375
✓ Step 5.5: process_document() works

✓ ALL PHASE 5 TESTS PASSED!


---
# 6. Phase 6: Learning Pipeline

## 6.1 learning package import

In [68]:
from src.learning import *
print('✓ Step 6.1: from src.learning import * succeeds')

✓ Step 6.1: from src.learning import * succeeds


## 6.2 MetricsTracker

In [69]:
from src.learning.metrics import MetricsTracker
m = MetricsTracker()
m.record_loss(0.5)
print(m.loss_history)
assert m.loss_history == [0.5]
print('✓ Step 6.2: MetricsTracker records loss')

[0.5]
✓ Step 6.2: MetricsTracker records loss


## 6.3 MetricsTracker.get_metrics()

In [70]:
from src.learning.metrics import MetricsTracker
m = MetricsTracker()
m.record_loss(2.0)
m.record_loss(1.0)
metrics = m.get_metrics(tokens_processed=123, learning_time_seconds=4.0, weight_delta_norm=0.5)
print(metrics)
assert metrics.initial_loss == 2.0
assert metrics.final_loss == 1.0
print('✓ Step 6.3: get_metrics() returns correct initial/final loss')

initial_loss=2.0 final_loss=1.0 loss_history=[2.0, 1.0] chunks_processed=2 tokens_processed=123 learning_time_seconds=4.0 weight_delta_norm=0.5
✓ Step 6.3: get_metrics() returns correct initial/final loss
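
The behaviour checked in 6.2 and 6.3 can be reproduced with a few lines of plain Python. This is a sketch only: `SimpleTracker` and `MetricsSnapshot` are hypothetical names built on dataclasses, and the repo's own types may differ. The field names are taken from the printed metrics above; raising on an empty history is an assumption.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MetricsSnapshot:
    initial_loss: float
    final_loss: float
    loss_history: List[float]
    chunks_processed: int
    tokens_processed: int
    learning_time_seconds: float
    weight_delta_norm: float


@dataclass
class SimpleTracker:
    """Hypothetical tracker: one recorded loss per processed chunk."""
    loss_history: List[float] = field(default_factory=list)

    def record_loss(self, loss: float) -> None:
        self.loss_history.append(float(loss))

    def get_metrics(self, tokens_processed: int,
                    learning_time_seconds: float,
                    weight_delta_norm: float) -> MetricsSnapshot:
        if not self.loss_history:
            raise ValueError("no losses recorded yet")
        return MetricsSnapshot(
            initial_loss=self.loss_history[0],
            final_loss=self.loss_history[-1],
            loss_history=list(self.loss_history),
            chunks_processed=len(self.loss_history),
            tokens_processed=tokens_processed,
            learning_time_seconds=learning_time_seconds,
            weight_delta_norm=weight_delta_norm,
        )


tracker = SimpleTracker()
tracker.record_loss(2.0)
tracker.record_loss(1.0)
snapshot = tracker.get_metrics(tokens_processed=123,
                               learning_time_seconds=4.0,
                               weight_delta_norm=0.5)
print(snapshot)
```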


## 6.4 TTTTrainer.__init__()

In [71]:
from src.learning.trainer import TTTTrainer
from src.config import LearningConfig

trainer = TTTTrainer(model=model, config=LearningConfig())
print(trainer)
print('✓ Step 6.4: TTTTrainer instantiated')

TTTTrainer(model=TTTModel(
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 2048)
      (layers): ModuleList(
        (0-21): 22 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          )
          (mlp): TTTLinear(
            (W_out): Linear(in_features=5632, out_features=2048, bias=False)
            (activation): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((2048,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=2048, out_feat

## 6.5 TTTTrainer.train_on_document()

In [72]:
from src.learning.trainer import TTTTrainer
from src.config import LearningConfig, Document, DocumentChunk, DocumentStatus

trainer = TTTTrainer(model=model, config=LearningConfig())

# small dummy doc (3 chunks) to keep runtime low
chunks = []
for i in range(3):
    token_ids = list(range(200 + i*300, 200 + i*300 + 256))
    chunks.append(DocumentChunk(index=i, text=f'chunk {i}', token_ids=token_ids, token_count=len(token_ids)))

doc = Document(id='doc1', filename='dummy', page_count=1, total_tokens=sum(c.token_count for c in chunks), chunks=chunks, status=DocumentStatus.READY)
metrics = trainer.train_on_document(doc)
print(metrics)
assert metrics.chunks_processed == 3
assert metrics.final_loss <= metrics.initial_loss
assert metrics.weight_delta_norm > 0
print('✓ Step 6.5: train_on_document() returns metrics with learning signal')

initial_loss=12.66667366027832 final_loss=12.50511646270752 loss_history=[12.66667366027832, 12.592181205749512, 12.50511646270752] chunks_processed=3 tokens_processed=768 learning_time_seconds=0.720818060999818 weight_delta_norm=0.407012939453125
✓ Step 6.5: train_on_document() returns metrics with learning signal


## 6.6 train_on_document(progress_callback=...)

In [73]:
from src.learning.trainer import TTTTrainer
from src.config import LearningConfig, Document, DocumentChunk, DocumentStatus

trainer = TTTTrainer(model=model, config=LearningConfig())

chunks = []
for i in range(3):
    token_ids = list(range(500 + i*300, 500 + i*300 + 256))
    chunks.append(DocumentChunk(index=i, text=f'chunk {i}', token_ids=token_ids, token_count=len(token_ids)))

doc = Document(id='doc2', filename='dummy2', page_count=1, total_tokens=sum(c.token_count for c in chunks), chunks=chunks, status=DocumentStatus.READY)

calls = []
def cb(chunk_idx, total, loss):
    calls.append((chunk_idx, total, loss))

metrics = trainer.train_on_document(doc, progress_callback=cb)
print('calls:', calls)
assert len(calls) == 3
assert [c[0] for c in calls] == [0, 1, 2]
assert all(c[1] == 3 for c in calls)
assert all(isinstance(c[2], float) for c in calls)
print('✓ Step 6.6: callback called with (chunk_idx, total, loss)')

calls: [(0, 3, 12.905935287475586), (1, 3, 12.775334358215332), (2, 3, 12.204532623291016)]
✓ Step 6.6: callback called with (chunk_idx, total, loss)
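
The callback contract tested above is `(chunk_idx, total, loss)`, called once per chunk with a zero-based index. A minimal sketch of a loop honouring that contract follows. `run_chunks` and `ProgressCallback` are hypothetical names, and the per-chunk training step is replaced by precomputed, made-up loss values.

```python
from typing import Callable, List, Optional

# Signature observed in the test above: (chunk_idx, total, loss)
ProgressCallback = Callable[[int, int, float], None]


def run_chunks(chunk_losses: List[float],
               progress_callback: Optional[ProgressCallback] = None) -> List[float]:
    """Iterate over chunks and report progress after each one."""
    total = len(chunk_losses)
    history = []
    for chunk_idx, loss in enumerate(chunk_losses):
        history.append(loss)  # stand-in for the real per-chunk training step
        if progress_callback is not None:
            progress_callback(chunk_idx, total, loss)
    return history


progress_calls = []
loss_log = run_chunks(
    [12.9, 12.7, 12.2],  # illustrative values, not real results
    progress_callback=lambda i, n, loss: progress_calls.append((i, n, loss)),
)
print(progress_calls)
```

Making the callback optional keeps the loop usable both from a UI (which wants progress updates) and from tests that only need the returned history.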


---
# 7. Phase 7: Inference

## 7.1 inference package import

In [74]:
from src.inference import *
print('✓ Step 7.1: from src.inference import * succeeds')

✓ Step 7.1: from src.inference import * succeeds


## 7.2 Generator.__init__()

In [75]:
from src.inference.generator import Generator

gen = Generator(model=model, tokenizer=model.tokenizer)
print(f'  model: {type(gen.model).__name__}')
print(f'  tokenizer: {type(gen.tokenizer).__name__}')
print('✓ Step 7.2: Generator instantiated with TTTModel')

  model: TTTModel
  tokenizer: LlamaTokenizerFast
✓ Step 7.2: Generator instantiated with TTTModel


## 7.3 Generator.generate() -> Answer

In [76]:
from src.inference.generator import Generator
from src.config import Answer

gen = Generator(model=model, tokenizer=model.tokenizer)
answer = gen.generate('What is the capital of France?', max_tokens=50, temperature=0.7)

print(f'  text: {answer.text[:100]}...' if len(answer.text) > 100 else f'  text: {answer.text}')
print(f'  tokens_generated: {answer.tokens_generated}')
print(f'  generation_time: {answer.generation_time_seconds:.2f}s')

assert isinstance(answer, Answer), 'Should return Answer'
assert isinstance(answer.text, str), 'text should be string'
assert len(answer.text) > 0, 'text should be non-empty'
print('✓ Step 7.3: Generator.generate() returns Answer with non-empty text')

  text: teammountталиteam recommendationʻ flowalleryétat Teamteam microirie transformations transformationse...
  tokens_generated: 51
  generation_time: 3.44s
✓ Step 7.3: Generator.generate() returns Answer with non-empty text


## 7.4 Generator.compare() -> (ttt_answer, base_answer)

In [77]:
# First do some learning so TTT weights differ from base
from src.learning.trainer import TTTTrainer
from src.config import LearningConfig, Document, DocumentChunk, DocumentStatus

trainer = TTTTrainer(model=model, config=LearningConfig())
chunks = []
for i in range(3):
    token_ids = list(range(800 + i*300, 800 + i*300 + 256))
    chunks.append(DocumentChunk(index=i, text=f'chunk {i}', token_ids=token_ids, token_count=len(token_ids)))
doc = Document(id='compare_test', filename='test', page_count=1, total_tokens=sum(c.token_count for c in chunks), chunks=chunks, status=DocumentStatus.READY)
trainer.train_on_document(doc)
print(f'  Weight delta after learning: {model.get_total_weight_delta():.4f}')

# Now compare
from src.inference.generator import Generator
gen = Generator(model=model, tokenizer=model.tokenizer)
ttt_answer, base_answer = gen.compare('Hello world', max_tokens=20, temperature=0.0)

print(f'  TTT answer: {ttt_answer.text[:50]}...')
print(f'  Base answer: {base_answer.text[:50]}...')

# After compare, learned weights should be restored
delta_after = model.get_total_weight_delta()
print(f'  Weight delta after compare: {delta_after:.4f}')
assert delta_after > 0, 'Learned weights should be restored after compare'

# Both should be Answer objects
from src.config import Answer
assert isinstance(ttt_answer, Answer) and isinstance(base_answer, Answer)
print('✓ Step 7.4: compare() returns two different answers, restores TTT weights')

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


  Weight delta after learning: 0.4151
  TTT answer: future future await future transvaricca transird t...
  Base answer: awaitcollectionscollectionscollectionscollectionsc...
  Weight delta after compare: 0.4151
✓ Step 7.4: compare() returns two different answers, restores TTT weights
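
The output above shows the weight delta unchanged after `compare()`, so the learned weights are put back once the base answer has been generated. One way to get that guarantee is a context manager that swaps in the initial weights and always restores the learned ones. The sketch below is an assumption about the pattern, not the repo's `Generator.compare()`; `base_weights`, `toy_generate`, and the dict-based state are hypothetical.

```python
from contextlib import contextmanager


@contextmanager
def base_weights(state):
    """Temporarily swap learned weights for the initial ones."""
    learned = state["W"][:]
    state["W"] = state["W_initial"][:]
    try:
        yield state
    finally:
        # Restore even if generation raises
        state["W"] = learned


def toy_generate(state):
    """Stand-in for text generation: any function of the current weights."""
    return sum(state["W"])


state = {"W": [1.5, 2.5], "W_initial": [1.0, 2.0]}

ttt_out = toy_generate(state)
with base_weights(state):
    base_out = toy_generate(state)
print(ttt_out, base_out, state["W"])
```

Putting the restore in `finally` means an exception during the base generation cannot leave the model stuck on base weights.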


In [78]:
print('\n' + '='*50)
print('✓ ALL PHASE 7 TESTS PASSED!')
print('='*50)


✓ ALL PHASE 7 TESTS PASSED!
