<a href="https://colab.research.google.com/github/weagan/Engram/blob/main/Copy2_of_Engram_Conditional_Memory_Module_Setup.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell gOaM79sX72tR
import torch
import torch.nn as nn
import torch.nn.functional as F

class EngramModule(nn.Module):
    """Hashed, gated memory lookup whose result is added residually to a hidden stream.

    Each token id is mapped by ``n_heads`` multiplicative hashes to rows of a
    learnable ``memory_table``. The rows fetched for a token are averaged,
    scaled by a per-token sigmoid gate computed from ``hidden_states``, passed
    through ``merge_proj`` and added to ``hidden_states``.

    Args:
        table_size: Number of rows in the memory table (the hash range).
        d_model: Width of each memory row and of the hidden states.
        n_heads: Number of hash functions, i.e. memory rows fetched per token.

    Raises:
        ValueError: If ``table_size`` or ``n_heads`` is less than 1.
    """

    def __init__(self, table_size=100000, d_model=512, n_heads=4):
        super().__init__()
        # Validate before allocating the (table_size x d_model) table. With
        # n_heads == 0 the forward pass would otherwise fail much later inside
        # torch.stack with an unhelpful error.
        if table_size < 1:
            raise ValueError(f"table_size must be >= 1, got {table_size}")
        if n_heads < 1:
            raise ValueError(f"n_heads must be >= 1, got {n_heads}")
        self.table_size = table_size
        self.d_model = d_model
        self.n_heads = n_heads

        # The static memory table: one learnable d_model vector per hash slot.
        self.memory_table = nn.Parameter(torch.randn(table_size, d_model))

        # Context-aware gating: one scalar per token, computed from its hidden state.
        self.gate = nn.Linear(d_model, 1)

        # Projection applied to the gated memory before the residual add.
        self.merge_proj = nn.Linear(d_model, d_model)

    def multi_head_hash(self, input_ids):
        """Map token ids to memory-table row indices, one index per head.

        Head ``i`` computes ``(input_ids * (i + 13)) % table_size``, so head 0
        sends id ``x`` to ``(13 * x) % table_size``. Tensor ``%`` follows the
        sign of the divisor, so results lie in ``[0, table_size)``.

        Args:
            input_ids: Integer tensor of token ids (typically ``[batch, seq]``).

        Returns:
            Tensor of shape ``input_ids.shape + (n_heads,)``.
        """
        # NOTE(review): multipliers 14, 15, 16, ... share factors with the
        # default table_size of 100000, so those heads can only reach a
        # fraction of the table (e.g. 16 -> 6250 distinct slots). Multipliers
        # coprime with table_size would spread ids more evenly.
        hashes = [(input_ids * (i + 13)) % self.table_size for i in range(self.n_heads)]
        return torch.stack(hashes, dim=-1)

    def forward(self, hidden_states, input_ids):
        """Retrieve, gate and merge memory for every token position.

        Args:
            hidden_states: Float tensor ``[..., d_model]``, typically
                ``[batch, seq, d_model]``.
            input_ids: Integer tensor whose shape matches
                ``hidden_states.shape[:-1]`` (typically ``[batch, seq]``).

        Returns:
            ``hidden_states + merge_proj(gate * mean_over_heads(memory_rows))``,
            with the same shape as ``hidden_states``.
        """
        # [..., n_heads] row indices into the memory table.
        indices = self.multi_head_hash(input_ids)

        # F.embedding accepts index tensors of any shape and appends the
        # embedding dim, so no flatten/reshape is needed: [..., n_heads, d_model].
        retrieved_mem = F.embedding(indices, self.memory_table)

        # Aggregate the heads with a simple mean. The head axis is always
        # second-to-last, so this also works for unbatched input: [..., d_model].
        retrieved_mem = retrieved_mem.mean(dim=-2)

        # Per-token gate in (0, 1), shape [..., 1], broadcast over d_model.
        gate_score = torch.sigmoid(self.gate(hidden_states))
        gated_memory = retrieved_mem * gate_score

        # Residual merge back into the transformer stream.
        return hidden_states + self.merge_proj(gated_memory)

# --- Verification ---
# Smoke test: one forward pass on random inputs must preserve the hidden-state shape.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EngramModule(n_heads=1).to(device)

# Mock inputs: token ids in [0, 1000) and random hidden states, batch=2, seq=8.
mock_ids = torch.randint(0, 1000, (2, 8)).to(device)
mock_hidden = torch.randn(2, 8, model.d_model).to(device)

# Test forward pass; fail loudly instead of printing "Success!" unconditionally.
out = model(mock_hidden, mock_ids)
assert out.shape == mock_hidden.shape, f"unexpected output shape {tuple(out.shape)}"
print(f"Success! Output shape: {out.shape}")

Success! Output shape: torch.Size([2, 8, 512])


In [None]:
# Cell 88998a27
print('Demonstrating the value of EngramModule:')

# 1. Manually set specific entries in the memory_table to distinct patterns.
#    With n_heads=1, multi_head_hash maps input_id x to (x * 13) % table_size.
#    Parameters are edited in place under torch.no_grad() rather than through
#    `.data`, whose in-place changes are invisible to autograd.
with torch.no_grad():
    model.memory_table[0] = 1.0    # input_id 0 -> index 0
    model.memory_table[13] = 0.0   # input_id 1 -> index 13
    model.memory_table[26] = -1.0  # input_id 2 -> index 26

print("Memory table entries 0, 13, and 26 have been set to distinct values.")

# --- Configure gate and merge_proj so retrieved memory passes straight through ---
# Gate: zero weights and bias 10 give a constant sigmoid(10) ~= 0.99995 for
# every token. That is an almost fully open gate, not a true identity.
# merge_proj: identity weight matrix and zero bias.
with torch.no_grad():
    model.gate.weight.zero_()
    model.gate.bias.fill_(10.0)
    model.merge_proj.weight.copy_(torch.eye(model.d_model))
    model.merge_proj.bias.zero_()
print("Gating and Merge Projection modified to be identity mappings.")
# ------------------------------------------------------------------


# 2. Prepare mock inputs directly on the model's device
#    - mock_hidden_demo: neutral (all zeros) hidden states
#    - mock_ids_demo: input IDs that will query the specific memory locations we set
mock_hidden_demo = torch.zeros(1, 3, model.d_model, device=device)  # Batch=1, Seq=3
mock_ids_demo = torch.tensor([[0, 1, 2]], device=device)  # Query input_ids 0, 1, 2

print(f"Initial mock hidden state (first element of first item): {mock_hidden_demo[0, 0, 0].item():.4f}")
print(f"Mock input IDs: {mock_ids_demo.tolist()}")

# 3. Perform forward pass (inference only, so no autograd graph is needed)
with torch.no_grad():
    output_demo = model(mock_hidden_demo, mock_ids_demo)

# 4. Analyze output
#    The hidden states are zeros and the gate is ~1, so each output row should
#    be ~sigmoid(10) times the memory row its input id hashes to.

print("\nOutput after EngramModule forward pass:")
print("--------------------------------------")

# Print a slice of the output to observe the influence
# For input_id=0, we expect output to be influenced by 'all ones'
print(f"Output for input_id 0 (first 5 dims): {[round(x, 4) for x in output_demo[0, 0, :5].tolist()]}")
# For input_id=1, we expect output to be influenced by 'all zeros'
print(f"Output for input_id 1 (first 5 dims): {[round(x, 4) for x in output_demo[0, 1, :5].tolist()]}")
# For input_id=2, we expect output to be influenced by 'all negative ones'
print(f"Output for input_id 2 (first 5 dims): {[round(x, 4) for x in output_demo[0, 2, :5].tolist()]}")

print("\nObservation: The output hidden states for each position reflect the distinct values manually set in the memory table for their corresponding input IDs. This demonstrates how the EngramModule can inject specific, contextually relevant information from its memory into the main data stream, even when the initial hidden state is neutral.")

Demonstrating the value of EngramModule:
Memory table entries 0, 13, and 26 have been set to distinct values.
Gating and Merge Projection modified to be identity mappings.
Initial mock hidden state (first element of first item): 0.0000
Mock input IDs: [[0, 1, 2]]

Output after EngramModule forward pass:
--------------------------------------
Output for input_id 0 (first 5 dims): [1.0, 1.0, 1.0, 1.0, 1.0]
Output for input_id 1 (first 5 dims): [0.0, 0.0, 0.0, 0.0, 0.0]
Output for input_id 2 (first 5 dims): [-1.0, -1.0, -1.0, -1.0, -1.0]

Observation: The output hidden states for each position reflect the distinct values manually set in the memory table for their corresponding input IDs. This demonstrates how the EngramModule can inject specific, contextually relevant information from its memory into the main data stream, even when the initial hidden state is neutral.


### Conceptual Training Loop Steps:

1.  **Initialize Model**: Create an instance of your `EngramModule`.
2.  **Initialize Loss Function**: Choose an appropriate loss function for your task.
3.  **Initialize Optimizer**: Tell the optimizer which model parameters to update.
4.  **Loop over epochs**: An epoch is one full pass through your entire training dataset. (In the demo below, the dataset is a single fixed mock batch, so each epoch is exactly one optimizer step.)
    1.  **Loop over batches**: Divide your data into smaller batches. For each batch:
        1.  **Forward Pass**: Feed a batch of inputs through the `EngramModule` to get predictions.
        2.  **Calculate Loss**: Compare predictions to true targets using the loss function.
        3.  **Zero Gradients**: Clear out old gradients from the previous step.
        4.  **Backward Pass (Backpropagation)**: Calculate gradients (how much each parameter contributed to the loss).
        5.  **Optimizer Step**: Adjust model parameters based on the gradients to reduce the loss.


In [None]:
# Cell 53b3fff4
import torch.optim as optim

# Re-initialize a fresh EngramModule for training demonstration
# We will use the default n_heads=4 to show a more typical scenario
model_for_training = EngramModule(n_heads=4).to(device)

# --- 1. Create Mock Data for Demonstration ---
# Task: at every token position, learn to output a d_model vector filled with
# that token's id (the target for input_id=k is a vector of all k's).

# Sequence length of each example. Ids are drawn from [0, 100), so the batch
# can contain up to 100 distinct ids, not `num_items` of them.
num_items = 5

# Generate mock input_ids (e.g., word tokens): batch size 32, seq_len num_items
mock_train_ids = torch.randint(0, 100, (32, num_items)).to(device)

# Generate mock hidden_states (e.g., from a preceding Transformer layer)
mock_train_hidden = torch.randn(32, num_items, model_for_training.d_model).to(device)

# The target at position (b, i) is mock_train_ids[b, i] repeated d_model times.
# Broadcasting the ids along a new trailing axis replaces a per-position loop.
# .contiguous() materializes the stride-0 expanded view into real storage.
target_embeddings = (
    mock_train_ids.float()
    .unsqueeze(-1)
    .expand(-1, -1, model_for_training.d_model)
    .contiguous()
)

print(f"Mock Training Data Prepared:\n  Input IDs shape: {mock_train_ids.shape}\n  Hidden States shape: {mock_train_hidden.shape}\n  Target Embeddings shape: {target_embeddings.shape}")

# --- 2. Define Loss Function ---
# Mean Squared Error (MSE) is suitable for learning to output specific target vectors
criterion = nn.MSELoss()

# --- 3. Define Optimizer ---
# Adam optimizer will update the learnable parameters of our model
optimizer = optim.Adam(model_for_training.parameters(), lr=0.001)

# --- 4. Training Loop (minimal example) ---
# Each "epoch" here is one optimizer step on the same fixed mock batch.
num_epochs = 500  # Increased epochs for better demonstration

print(f"\nStarting a simple training loop for {num_epochs} epochs...")

# Training mode is set once; nothing inside the loop switches it off.
model_for_training.train()
for epoch in range(num_epochs):
    # Forward pass
    output = model_for_training(mock_train_hidden, mock_train_ids)

    # Calculate loss
    loss = criterion(output, target_embeddings)

    # Zero gradients, backpropagate, then update parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 50 == 0 or epoch == 0:  # Print loss every 50 epochs or at the start
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

print("\nTraining complete!\n")

# --- Optional: verify learning on the first training example ---
# After training, the memory table, gate and merge projection should have moved
# toward producing the targets for these hash-mapped input_ids.

# Set model to evaluation mode
model_for_training.eval()
with torch.no_grad():  # No need to calculate gradients during inference
    # Take the first item from our mock data
    sample_hidden = mock_train_hidden[0:1, :, :]
    sample_ids = mock_train_ids[0:1, :]
    sample_target = target_embeddings[0:1, :, :]

    retrieved_output = model_for_training(sample_hidden, sample_ids)

    print("Verification after training (first sample):")
    print(f"  Input ID (first element): {sample_ids[0, 0].item()}")
    print(f"  Target (first element, first 5 dims): {[round(x, 4) for x in sample_target[0, 0, :5].tolist()]}")
    print(f"  Output (first element, first 5 dims): {[round(x, 4) for x in retrieved_output[0, 0, :5].tolist()]}")
    print("\nObservation: The output after even a few epochs should show values closer to the target values compared to random initialization, indicating that the EngramModule's parameters (especially the memory table) have started to learn.")

Mock Training Data Prepared:
  Input IDs shape: torch.Size([32, 5])
  Hidden States shape: torch.Size([32, 5, 512])
  Target Embeddings shape: torch.Size([32, 5, 512])

Starting a simple training loop for 500 epochs...
Epoch 1/500, Loss: 3501.6450
Epoch 50/500, Loss: 3203.0532
Epoch 100/500, Loss: 2375.9204
Epoch 150/500, Loss: 1376.3779
Epoch 200/500, Loss: 631.4255
Epoch 250/500, Loss: 238.7297
Epoch 300/500, Loss: 82.6029
Epoch 350/500, Loss: 28.9877
Epoch 400/500, Loss: 10.9647
Epoch 450/500, Loss: 4.7411
Epoch 500/500, Loss: 2.4118

Training complete!

Verification after training (first sample):
  Input ID (first element): 77
  Target (first element, first 5 dims): [77.0, 77.0, 77.0, 77.0, 77.0]
  Output (first element, first 5 dims): [77.7014, 77.5675, 77.1594, 77.327, 76.8516]

Observation: The output after even a few epochs should show values closer to the target values compared to random initialization, indicating that the EngramModule's parameters (especially the memory table

In [None]:
# Cell 4a28b389
import math  # NOTE(review): unused in this cell; kept in case later cells rely on it

# Create a new instance of EngramModule for this demonstration
# We set n_heads=1 to avoid multi-head averaging for simplicity
model_direct = EngramModule(n_heads=1).to(device)

print("--- Modified EngramModule for Direct Retrieval (Animal Embeddings) ---")

# 1. Manually set specific entries in the memory_table with distinct numerical
#    vectors representing 'dog', 'cat', 'horse', 'cow'.
#    With n_heads=1, multi_head_hash maps input_id x to (x * 13) % table_size:
#    0 -> 0, 1 -> 13, 2 -> 26, 3 -> 39.
#    Edits go through torch.no_grad() rather than `.data`, which hides
#    in-place changes from autograd.
with torch.no_grad():
    model_direct.memory_table[0] = 1.0   # Represents 'dog' (all ones)
    model_direct.memory_table[13] = 2.0  # Represents 'cat' (all twos)
    model_direct.memory_table[26] = 3.0  # Represents 'horse' (all threes)
    model_direct.memory_table[39] = 4.0  # Represents 'cow' (all fours)

print("Memory table entries 0, 13, 26, and 39 have been set to distinct values.")
print("These vectors conceptually represent 'dog' (1s), 'cat' (2s), 'horse' (3s), and 'cow' (4s).")

# 2. Gate: zero weights and bias 10 give a constant sigmoid(10) ~= 0.99995,
#    i.e. an almost fully open gate (hence outputs like 0.99995 rather than 1.0).
with torch.no_grad():
    model_direct.gate.weight.zero_()
    model_direct.gate.bias.fill_(10.0)

print("Gating mechanism modified to pass through most of the memory.")

# 3. merge_proj: identity weight matrix and zero bias.
with torch.no_grad():
    model_direct.merge_proj.weight.copy_(torch.eye(model_direct.d_model))
    model_direct.merge_proj.bias.zero_()

print("Merge projection modified to be an identity mapping.")

# Prepare mock inputs (neutral hidden states and input_ids for our animals)
animal_names = ["dog", "cat", "horse", "cow"]  # index == input_id
mock_hidden_demo_direct = torch.zeros(1, len(animal_names), model_direct.d_model, device=device)
mock_ids_demo_direct = torch.tensor([[0, 1, 2, 3]], device=device)  # Input IDs to query the animals

# Perform forward pass with the modified model (inference only)
with torch.no_grad():
    output_direct = model_direct(mock_hidden_demo_direct, mock_ids_demo_direct)

print("\nOutput after Modified EngramModule forward pass (Animal Embeddings):")
print("------------------------------------------------------------------")

# Print a slice of the output to observe the influence, one animal at a time
for pos, animal in enumerate(animal_names):
    print(f"Output for input_id {pos} (representing '{animal}', first 5 dims):")
    for val in output_direct[0, pos, :5].tolist():
        print(f"  {val}")

print("\nObservation: The output values for each input ID now closely reflect the distinct numerical vectors stored")
print("(1.0 for 'dog', 2.0 for 'cat', 3.0 for 'horse', and 4.0 for 'cow' respectively),")
print("demonstrating the conceptual storage and retrieval of 'embeddings' for words.")

--- Modified EngramModule for Direct Retrieval (Animal Embeddings) ---
Memory table entries 0, 13, 26, and 39 have been set to distinct values.
These vectors conceptually represent 'dog' (1s), 'cat' (2s), 'horse' (3s), and 'cow' (4s).
Gating mechanism modified to pass through most of the memory.
Merge projection modified to be an identity mapping.

Output after Modified EngramModule forward pass (Animal Embeddings):
------------------------------------------------------------------
Output for input_id 0 (representing 'dog', first 5 dims):
  0.9999545812606812
  0.9999545812606812
  0.9999545812606812
  0.9999545812606812
  0.9999545812606812
Output for input_id 1 (representing 'cat', first 5 dims):
  1.9999091625213623
  1.9999091625213623
  1.9999091625213623
  1.9999091625213623
  1.9999091625213623
Output for input_id 2 (representing 'horse', first 5 dims):
  2.999863624572754
  2.999863624572754
  2.999863624572754
  2.999863624572754
  2.999863624572754
Output for input_id 3 (repr