# 03 — Hybrid Engine Integration
This notebook wires up a frozen LLM (GPT-2) with the `NeuralMemory` module. The engine implements the read→surprise→learn→recall loop, where memory weights adapt online during inference based on the Surprise signal from the LLM's hidden states.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")

# Pick the compute device: GPU when PyTorch can see one, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

In [None]:
# Load the LLM (GPT-2 for speed) and its tokenizer, then freeze the LLM.
model_id = "gpt2"  # or "gpt2-medium" for better quality
tokenizer = AutoTokenizer.from_pretrained(model_id)
# GPT-2 defines no pad token; reuse EOS so that padding=True works.
tokenizer.pad_token = tokenizer.eos_token

# float16 on GPU, float32 on CPU.
llm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16 if device=="cuda" else torch.float32)
llm.to(device)
llm.eval()  # inference mode (disables dropout)
# eval() alone does not freeze anything: every parameter still has
# requires_grad=True. Turn gradients off so the LLM is genuinely frozen and
# autograd never builds a graph over its weights, even outside torch.no_grad().
llm.requires_grad_(False)

hidden_dim = llm.config.n_embd
print(f"Loaded {model_id}, hidden_dim={hidden_dim}")

In [None]:
# NeuralMemory, copied from Notebook 2 (redefined here so this notebook is self-contained)
class NeuralMemory(nn.Module):
    """Two-layer MLP memory that is trained online, one optimizer step per call.

    ``memorize(x, y)`` takes a single AdamW step on ``MSE(net(x), y)`` and
    returns that loss, measured before the step (the "surprise").
    ``recall(x)`` evaluates ``net(x)`` without tracking gradients.

    The network is always kept in float32. Inputs in other floating dtypes
    (e.g. float16 hidden states from the LLM) are cast on entry.

    Args:
        input_dim: Size of the last dimension of ``x``.
        hidden_dim: Width of the MLP's hidden layer.
        output_dim: Size of the last dimension of the prediction / target ``y``.
        lr: AdamW learning rate.
        device_str: Torch device string such as ``"cpu"`` or ``"cuda"``. If
            omitted, CUDA is used when available, otherwise CPU.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, lr: float = 1e-3, device_str: str = None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Always use float32 for stable training (inputs are cast in recall/memorize)
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.to(self.device, torch.float32)
        self.optim = torch.optim.AdamW(self.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

    @torch.no_grad()
    def recall(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``net(x)`` as a float32 tensor that does not require grad."""
        x = x.to(self.device, torch.float32)
        return self.net(x)

    def memorize(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """Take one AdamW step pulling ``net(x)`` toward ``y``; return the loss.

        ``x`` and ``y`` are detached first because they are data, not
        parameters. Without this, a tensor that requires grad (a leaf with
        ``requires_grad=True``, or one with autograd history) would receive
        gradients from ``loss.backward()``, or have them propagated into the
        graph that produced it.

        Returns:
            The MSE loss computed before the optimizer step.
        """
        # Cast (e.g. float16 LLM outputs) to float32 for stable training.
        x = x.detach().to(self.device, torch.float32)
        y = y.detach().to(self.device, torch.float32)
        pred = self.net(x)
        loss = self.loss_fn(pred, y)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return float(loss.item())

# Build the memory: an MLP mapping a hidden_dim vector to a hidden_dim soft-prompt vector
memory = NeuralMemory(
    input_dim=hidden_dim,
    hidden_dim=256,
    output_dim=hidden_dim,
    lr=5e-4,
    device_str=device,
)
print("Memory initialized")

## The Hybrid Loop
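1. **Read**: Tokenize the input text and run the frozen LLM to get its final-layer hidden states.
2. **Surprise**: Compute the prediction error (MSE) between the memory's prediction and the actual hidden state. In this demo, the memory is asked to reproduce the last token's hidden state from that same state. The loss therefore measures how well the memory has absorbed states like ones it has already seen.
3. **Learn**: Update the memory weights with one optimizer step on the Surprise loss. Backpropagation runs through the memory only; the LLM stays frozen.
4. **Recall**: The memory produces a soft-prompt vector from the current hidden state. In this notebook the vector is only computed and inspected; it is not yet fed back into the LLM to condition generation.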

In [None]:
def run_step_with_memory(text: str, use_memory: bool = True, verbose: bool = True):
    """Run one read -> surprise -> learn -> recall step on ``text``.

    1. Tokenize ``text`` and run the frozen LLM with ``output_hidden_states=True``.
    2. Take the final layer's hidden state at the last real (non-padding) token.
    3. If ``use_memory``: ``memory.memorize(last_hidden, last_hidden)``. The
       memory is trained to reproduce the current hidden state from itself (a
       deliberately simple target for this demo). The returned loss is the
       surprise measured before the update.
    4. If ``use_memory``: ``memory.recall(last_hidden)`` yields the soft-prompt
       vector. It is returned but not fed back into the LLM here.

    No text is generated. Uses the module-level ``tokenizer``, ``llm``,
    ``memory`` and ``device``.

    Args:
        text: Input text.
        use_memory: If False, skip memorize/recall. The surprise stays 0.0 and
            the soft prompt is None.
        verbose: Print a short summary of the step.

    Returns:
        ``(last_hidden, surprise_loss, soft_prompt)``. ``last_hidden`` is
        ``(batch, hidden_dim)`` in the LLM's dtype. ``surprise_loss`` is a float.
        ``soft_prompt`` is the float32 output of ``memory.recall``, or ``None``.

    Raises:
        ValueError: If ``text`` tokenizes to zero tokens.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids = inputs["input_ids"]
    if input_ids.shape[1] == 0:
        raise ValueError("run_step_with_memory: text produced no tokens")

    with torch.no_grad():
        outputs = llm(**inputs, output_hidden_states=True)

    # Final-layer hidden states: (batch, seq_len, hidden_dim)
    hidden_states = outputs.hidden_states[-1]

    # Select the last *real* token of each row. Plain hidden_states[:, -1] would
    # land on a pad token for shorter rows of a padded batch. Real tokens have
    # mask 1, so the largest masked position is the last real token; this works
    # for both left and right padding.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    positions = torch.arange(input_ids.shape[1], device=input_ids.device)
    last_idx = (attention_mask * positions).argmax(dim=1)
    batch_idx = torch.arange(input_ids.shape[0], device=input_ids.device)
    last_hidden = hidden_states[batch_idx, last_idx]  # (batch, hidden_dim)

    surprise_loss = 0.0
    soft_prompt = None

    if use_memory:
        # Simplification: predict last_hidden from itself. A fuller system
        # would predict the *next* hidden state from the current context.
        surprise_loss = memory.memorize(last_hidden, last_hidden)
        soft_prompt = memory.recall(last_hidden)

    if verbose:
        preview = text if len(text) <= 60 else f"{text[:60]}..."
        print(f"Text: {preview}")
        print(f"Surprise loss: {surprise_loss:.6f}")
        if soft_prompt is not None:
            print(f"Soft prompt norm: {soft_prompt.norm().item():.4f}")

    return last_hidden, surprise_loss, soft_prompt

# Smoke test: one memorize/recall step on a single prompt, with the summary printed.
text = "The Titans architecture enables long-term memory by"
(h, loss, sp) = run_step_with_memory(text, verbose=True, use_memory=True)

In [None]:
# Multi-step adaptation demo. Feed a sequence of sentences (the last repeats the
# first) and print the surprise loss reported at every step.
sentences = [
    "The quick brown fox jumps over the lazy dog.",
    "Neural networks learn patterns from data.",
    "Titans use a surprise metric to decide what to remember.",
    "Memory modules can adapt online during inference.",
    "The quick brown fox jumps over the lazy dog.",  # same as sentence 1
]

print("=== Multi-step Memory Adaptation ===")
losses = []
for i, sent in enumerate(sentences, start=1):
    loss = run_step_with_memory(sent, use_memory=True, verbose=False)[1]
    losses.append(loss)
    print(f"Step {i}: loss={loss:.6f}  |  {sent[:50]}")

# Compare the first sentence's surprise with that of its repeat at the end.
print(f"\nFirst loss: {losses[0]:.6f}, Last loss: {losses[-1]:.6f}")

# Next steps
In Notebook 4, we'll build an interactive chat demo that memorizes distinct facts, clears the LLM context window, and retrieves memorized information purely from the neural memory weights.