# üß† Hybrid Titans Memory: Augmenting LLMs with Learning-Based Memory

## Introduction
Standard Large Language Models (LLMs) like GPT-4 or Mistral suffer from a "Fixed Context Window." Once text scrolls off the top, it is gone forever. RAG (Retrieval Augmented Generation) helps, but it is static database retrieval, not true "learning."

**Titans**, a new architecture from Google Research, proposes a "Neural Memory" that *learns* context in real-time.

In this notebook, we will implement a slightly simplified **Hybrid Architecture**:
1.  **The Frozen Brain**: A standard pre-trained LLM (GPT-2 for speed, scalable to Llama-3).
2.  **The Learning Sidecar**: A tiny Neural Network that "watches" the LLM's thoughts and updates its own weights using **Test-Time Training (TTT)**.

### The Objective
We will teach the model 3 random facts. Then, we will **delete** the conversation history (Input Context). Finally, we will ask the model to answer a question. If it answers correctly, it means the information didn't come from the prompt‚Äîit came from the **Neural Memory weights**.

Let's build it! üöÄ

In [1]:
# 1. Environment Setup
# We need PyTorch for the memory module and Transformers for the LLM.
# 'accelerate' and 'bitsandbytes' are highly recommended for loading larger models efficiently.

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import copy

# Determine if we have a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚úÖ Running on device: {device}")

# Set seeds for reproducibility
torch.manual_seed(42)

# If you need to install packages, uncomment below:
# !pip install transformers accelerate bitsandbytes

‚úÖ Running on device: cuda


<torch._C.Generator at 0xee9378dcba90>

## üß† 2. The Neural Memory Module (Sidecar)

This is the core innovation. Instead of just "storing" vectors (like a vector DB), we train a neural network to memorize them.

### "Surprise" Metric
We define learning as minimizing **Surprise**.
If the LLM sees a new concept (like "The sky is green"), its internal state will be "surprised" (Novelty).
Our Memory Module attempts to **predict** or **reconstruct** this hidden state.
-   If reconstruction is good -> Low Surprise (No learning needed).
-   If reconstruction is bad -> High Surprise (Update weights!).

We will implement **Test-Time Training (TTT)**. The `memorize()` function runs a backward pass *during inference*.

In [None]:
class NeuralMemory(nn.Module):
    """
    A Neural Memory Module (The "Sidecar").
    
    This acts as a dynamic, learnable memory that runs alongside the frozen LLM.
    Functionally, it is a simple Autoencoder or MLP that maps an input states (query)
    to a memory context.
    
    Key Feature:
    It contains its own optimizer. This allows it to update its weights 
    on-the-fly (Test-Time Training) based on the context of the current conversation,
    effectively "memorizing" new information in its weights.
    """
    def __init__(self, input_dim, memory_dim=None, learning_rate=0.01):
        super().__init__()
        
        # If no specific memory dimension is given, we keep it same as input
        # flexible for compression or expansion
        if memory_dim is None:
            memory_dim = input_dim 
            
        # The Architecture: A simple Encoder-Decoder style network
        # 1. Compress/Transform input to memory space
        self.encoder = nn.Linear(input_dim, memory_dim)
        # 2. Non-linearity to capture complex relationships
        self.activation = nn.GELU() 
        # 3. Project back to input space (or context space)
        self.decoder = nn.Linear(memory_dim, input_dim)
        
        # Internal Optimizer:
        # Standard PyTorch models don't usually hold their own optimizer.
        # We do this here to encapsulate the "Learning" capability within the module itself.
        # SGD is used here for simplicity and stability in small batch updates.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
        
    def forward(self, x):
        """
        Forward pass: Generates a 'Memory Context' from the input query.
        """
        encoded = self.encoder(x)
        activated = self.activation(encoded)
        reconstruction = self.decoder(activated)
        return reconstruction

    def memorize(self, target_state):
        """
        The core learning mechanism: Test-Time Training (TTT).
        
        This function performs a single gradient descent step to minimize 
        reconstruction error (Surprise) on a given target state.
        
        Args:
            target_state: The vector we want the memory to 'remember'.
        """
        self.train()
        self.optimizer.zero_grad()
        
        # Detach target to ensure we don't backpropagate into the entity generating the target
        target = target_state.detach()
        
        # Try to predict/reconstruct the target
        reconstruction = self.forward(target)
        
        # Calculate 'Surprise' (Loss): How different is our memory's prediction from reality?
        loss = F.mse_loss(reconstruction, target)
        
        # Update weights to reduce surprise next time
        loss.backward()
        self.optimizer.step()
        
        return loss.item()

## ü§ñ 3. Initialize the Frozen LLM

We will use **GPT-2** (Small) for this demonstration. It is fast, lightweight, and perfect for testing concepts.
*   **Frozen**: We will set `requires_grad = False` for the LLM. It will *not* change.
*   **Tokenizer**: Standard GPT-2 tokenizer.

In [3]:
# Load Model & Tokenizer
model_name = "gpt2" # Can swap with "TinyLlama/TinyLlama-1.1B-Chat-v1.0" if GPU permits
print(f"‚¨áÔ∏è Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Freeze the LLM
for param in llm.parameters():
    param.requires_grad = False

print("‚ùÑÔ∏è LLM parameters frozen.")
    
# Get Hidden Dimension Size (e.g., 768 for GPT-2)
hidden_dim = llm.config.n_embd
print(f"üìè Hidden Dimension: {hidden_dim}")

# Initialize our Trainable Memory Sidecar
memory_module = NeuralMemory(input_dim=hidden_dim).to(device)
print("üß† Neural Memory initialized.")

‚¨áÔ∏è Loading gpt2...


    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


‚ùÑÔ∏è LLM parameters frozen.
üìè Hidden Dimension: 768
üß† Neural Memory initialized.


## ‚öôÔ∏è 4. The Hybrid Inference Engine

This is the glue that binds them. We need a custom generation loop.
Unlike `model.generate()`, we need to inspect the internals step-by-step.

**The Loop:**
1.  **Embed**: Term text -> Vectors.
2.  **Recall**: Pass Vectors -> Memory -> `memory_context`.
3.  **Mix**: Combine `[Embeddings + memory_context]` (Soft Prompting).
4.  **Forward**: Run LLM.
5.  **Learn**: Take the output hidden state, calculate surprise, and call `memory.memorize()`.

*Note: For simplicity in this demo, we will perform the memory update on the INPUT embeddings to "remember input facts", a common simplification for these demos.*

In [4]:
import random

class HybridTitansEngine:
    def __init__(self, llm, memory, tokenizer):
        self.llm = llm
        self.memory = memory
        self.tokenizer = tokenizer
        
    def process_and_learn(self, text, steps=20):
        inputs = self.tokenizer(text, return_tensors="pt").to(device)
        input_ids = inputs.input_ids
        seq_len = input_ids.shape[1]
        
        self.memory.train()
        losses = []
        
        for _ in range(steps):
             self.memory.optimizer.zero_grad()
             
             with torch.no_grad():
                 embeds = self.llm.transformer.wte(input_ids)
            
             min_len = 3 
             if seq_len > min_len:
                 cut_point = random.randint(min_len, seq_len)
                 query_embeds = embeds[:, :cut_point, :]
             else:
                 query_embeds = embeds
                 
             query = query_embeds.mean(dim=1).detach()
             
             soft_memory = self.memory(query)
             
             # Inject
             soft_memory = soft_memory.unsqueeze(1) 
             combined_embeds = torch.cat([soft_memory, embeds], dim=1)
             
             # Labels
             ignore_token = torch.full((1, 1), -100, dtype=torch.long, device=device)
             combined_labels = torch.cat([ignore_token, input_ids], dim=1)
             
             outputs = self.llm(inputs_embeds=combined_embeds, labels=combined_labels)
             
             loss = outputs.loss
             loss.backward()
             self.memory.optimizer.step()
             losses.append(loss.item())
             
        return sum(losses) / len(losses)

    def generate(self, prompt, max_new_tokens=20):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = inputs.input_ids 
        
        # Ensure we don't exceed model context
        input_ids = input_ids[:, -1020:] 
        
        embeds = self.llm.transformer.wte(input_ids)
        query = embeds.mean(dim=1)
        
        self.memory.eval()
        with torch.no_grad():
            memory_context = self.memory(query)
        
        memory_context = memory_context.unsqueeze(1)
        current_embeds = torch.cat([memory_context, embeds], dim=1)
        
        generated_ids = []
        for _ in range(max_new_tokens):
            outputs = self.llm(inputs_embeds=current_embeds)
            logits = outputs.logits[:, -1, :] 
            
            # Greedy search
            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(0)
            
            generated_ids.append(next_token_id.item())
            if next_token_id.item() == self.tokenizer.eos_token_id:
                break
                
            next_embed = self.llm.transformer.wte(next_token_id)
            current_embeds = torch.cat([current_embeds, next_embed], dim=1)
            
        return self.tokenizer.decode(generated_ids)

# Robust parameters: High LR for one-shot/few-shot learning
memory_module = NeuralMemory(input_dim=hidden_dim, learning_rate=0.04).to(device)
engine = HybridTitansEngine(llm, memory_module, tokenizer)
print("‚öôÔ∏è Optimized Engine Ready.")

‚öôÔ∏è Optimized Engine Ready.


## üß™ 5. Demo: The "Total Recall" Experiment

We will now perform the experiment.

**Phase 1: Learning**
We will feed the engine 3 random facts.
It will use `process_and_learn` to update `NeuralMemory` weights.

**Phase 2: Context Clearance**
We will NOT pass these facts into the generation prompt. The context window is effectively empty.

**Phase 3: Testing**
We ask the model questions. If it answers, it retrieved the concept from the Neural Network weights.

In [6]:
# The Facts
facts = [
    "The secret project code is Omega-99.",
    "The CEO's favorite fruit is a apple.", 
    "My shopping list is eggs, milk, bread", 
    "The meeting is at 4:32 PM exactly."
]

print("üü¶ Phase 1: Learning...")
# We use Interleaved Training to prevent Catastrophic Forgetting.
# We increase epochs significantly to force 'Overfitting' on these specific facts.

n_epochs = 60 
for epoch in range(n_epochs):
    epoch_loss = 0
    random.shuffle(facts) # Semantic Mixing
    for fact in facts:
        # We train with fewer steps per 'visit' but more visits overall
        loss = engine.process_and_learn(fact, steps=10) 
        epoch_loss += loss
    
    if epoch % 10 == 0:
        print(f"   Epoch {epoch}: Avg Surprise {epoch_loss/3:.4f}")

print("   (Learning Complete)")
print("\n-------------------------------------------------\n")

print("üü• Phase 2: Clearing Context...")
# We do nothing here. The 'facts' variable is just a python list. 
print("   (Brain Wiped. Only Neural Sidecar retains weights.)")

print("\n-------------------------------------------------\n")

print("üü© Phase 3: Testing Recall...")

questions = [
    "The secret project code is",
    "The CEO's favorite fruit is",
    "My shopping list is",
    "The meeting is at"
]

for q in questions:
    print(f"\nQuestion: {q}")
    answer = engine.generate(q, max_new_tokens=15)
    print(f"Titans Answer: {answer}")
    
print("\n-------------------------------------------------\n")
print("üéâ Experiment Complete. Check if the answers match the facts!")

üü¶ Phase 1: Learning...
   Epoch 0: Avg Surprise 2.8367
   Epoch 10: Avg Surprise 0.0601
   Epoch 20: Avg Surprise 0.0187
   Epoch 30: Avg Surprise 0.0106
   Epoch 40: Avg Surprise 0.0074
   Epoch 50: Avg Surprise 0.0057
   (Learning Complete)

-------------------------------------------------

üü• Phase 2: Clearing Context...
   (Brain Wiped. Only Neural Sidecar retains weights.)

-------------------------------------------------

üü© Phase 3: Testing Recall...

Question: The secret project code is
Titans Answer:  Omega-99.

A secret project is a secret project secret.

Question: The CEO's favorite fruit is
Titans Answer:  a apple. apple. apple.org is a apple.org is a

Question: My shopping list is
Titans Answer:  eggs, milk, bread, bread and milk.] My shopping list:

Question: The meeting is at
Titans Answer:  4:32 PM exactly. the second, it's not.



-------------------------------------------------

üéâ Experiment Complete. Check if the answers match the facts!
