# Explicit Memory Modules in RL: Neural Turing Machines (NTMs) and Differentiable Neural Computers (DNCs)

## Motivation
Problems with RNNs and LSTMs:
- The hidden state has a fixed size
- Memory is stored implicitly in the hidden state, so it is hard to query or manipulate
- Long-term dependencies are fragile: the gradient signal tends to vanish when backpropagating through many time steps

We need:
- A system that can store, access, and update information over long time spans
- External memory whose capacity can be scaled without growing the controller's parameter count
- Differentiable operations for end-to-end training

This leads to architectures like:
- Neural Turing Machines (NTMs)
- Differentiable Neural Computers (DNCs)

## Neural Turing Machine (NTM)
### Architecture
- Controller: a neural network, usually an RNN/LSTM, that drives the memory operations
- Memory matrix: external memory $M \in \mathbb{R}^{N \times W}$
- Read/write heads: the interface to memory; they emit weightings over memory locations

The key idea is to couple a neural network controller with differentiable read/write access to an external memory.

### Math
**Memory Matrix**:\
At time t, memory is:
$$M_t \in \mathbb{R}^{N \times W}$$
- N: number of memory slots
- W: width (size of each memory vector)

**Read Head**\
Produces a read weight vector $w_t^r \in \Delta^N$ (a probability distribution over N locations).
The read vector is:
$$r_t = \sum_{i=1}^N w_t^r(i)M_t(i)$$
So the read operation is a weighted sum of memory rows.

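
A toy NumPy illustration of the read equation. The memory contents and the read weighting are made-up values, chosen only to keep the arithmetic easy to follow:

```python
import numpy as np

# Toy memory: N = 3 slots, each of width W = 2
M = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])

# Read weighting: a probability distribution over the 3 slots
w_r = np.array([0.5, 0.5, 0.0])

# r_t = sum_i w_t^r(i) * M_t(i), i.e. a weighted sum of rows
r = w_r @ M
print(r)  # [0.5 0.5]
```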
**Write Head**\
The write process has two steps:
1. Erase:\
Apply an erase vector $e_t \in [0,1]^W$ and write weights $w_t^w \in \Delta^N$:
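$$M'_t(i) = M_{t-1}(i) \odot \left(\mathbf{1} - w_t^w(i)\, e_t\right)$$
where $\odot$ is elementwise multiplication and $\mathbf{1}$ is a vector of ones.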
2. Add:\
Apply an add vector $a_t \in \mathbb{R}^W$:
$$M_t(i) = M'_t(i) + w_t^w(i)a_t$$
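
Together, they enable selective overwriting of memory: row $i$ changes only in proportion to $w_t^w(i)$, and only in the components selected by $e_t$ and $a_t$.
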
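
A minimal NumPy sketch of one erase-then-add write, assuming a single write head and a hand-picked one-hot write weighting. `ntm_write` is a hypothetical helper name, not part of any library:

```python
import numpy as np

def ntm_write(M_prev, w_w, e, a):
    """One NTM write: erase, then add. Rows are scaled by w_w, columns by e and a."""
    # Erase: M'(i) = M_{t-1}(i) * (1 - w(i) e); np.outer builds the (N, W) matrix w(i) e
    M_erased = M_prev * (1.0 - np.outer(w_w, e))
    # Add: M(i) = M'(i) + w(i) a
    return M_erased + np.outer(w_w, a)

M_prev = np.array([[1.0, 1.0],
                   [5.0, 5.0]])
w_w = np.array([0.0, 1.0])   # write entirely to slot 1
e = np.array([1.0, 0.0])     # erase only the first component
a = np.array([3.0, 0.0])     # then add 3 to the first component

M_new = ntm_write(M_prev, w_w, e, a)
# Slot 0 is untouched; slot 1 has its first component overwritten (5 -> 3)
# while its second component keeps its old value (5)
print(M_new)
```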
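**Addressing Mechanism**\
Each head emits a key $k_t \in \mathbb{R}^W$ and a key strength $\beta_t > 0$, then compares the key with every memory row:
$$w_t^c(i)=\frac{\exp\left(\beta_t \, K(k_t, M_t(i))\right)}{\sum_{j=1}^{N}\exp\left(\beta_t \, K(k_t, M_t(j))\right)}$$
The similarity $K$ is usually cosine similarity, $K(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$; a larger $\beta_t$ concentrates the weighting on the best-matching rows.
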
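
A small NumPy sketch of content addressing with cosine similarity. `content_weights` is a hypothetical helper; subtracting the maximum before `exp` is a standard numerical-stability trick and does not change the softmax. Comparing two key strengths shows how $\beta_t$ controls focus:

```python
import numpy as np

def content_weights(M, k, beta):
    """Softmax over beta * cosine similarity between key k and each row of M."""
    eps = 1e-8
    sim = (M @ k) / (np.linalg.norm(M, axis=1) * np.linalg.norm(k) + eps)
    z = beta * sim
    z = z - z.max()              # numerical stability only
    w = np.exp(z)
    return w / w.sum()

M = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
k = np.array([1.0, 0.1])         # key closest in direction to row 0

w_soft = content_weights(M, k, beta=1.0)    # spread over several rows
w_sharp = content_weights(M, k, beta=50.0)  # almost all mass on row 0
print(w_soft.round(3), w_sharp.round(3))
```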
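The content-based weighting can be refined with location-based addressing:
- Location-based shifts: rotate the weighting toward nearby locations using a shift distribution $s_t$
- Sharpening: raise the weights to a power $\gamma_t \ge 1$ and renormalize, making the soft weighting more focused (closer to deterministic)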
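
The refinements above can be sketched as follows. In the original NTM, the content weighting is first interpolated with the previous step's weighting through a gate $g_t \in [0, 1]$, then shifted circularly over the offsets $\{-1, 0, +1\}$, then sharpened with $\gamma_t \ge 1$. This NumPy sketch assumes those three offsets; `location_address` is a hypothetical name, and in a real model $g_t$, $s_t$ and $\gamma_t$ would be emitted by the controller rather than fixed by hand:

```python
import numpy as np

def location_address(w_c, w_prev, g, s, gamma):
    """NTM-style location addressing applied after content addressing.

    w_c:    content weighting, shape (N,)
    w_prev: previous step's weighting, shape (N,)
    g:      interpolation gate in [0, 1]
    s:      shift distribution over offsets (-1, 0, +1), shape (3,)
    gamma:  sharpening exponent, gamma >= 1
    """
    # 1. Interpolate between the content weighting and the previous weighting
    w_g = g * w_c + (1.0 - g) * w_prev
    # 2. Circular convolution: w~(i) = sum_j w_g(j) * s(i - j mod N)
    #    np.roll(w_g, off)[i] == w_g[(i - off) mod N]
    offsets = (-1, 0, 1)
    w_tilde = sum(s_k * np.roll(w_g, off) for s_k, off in zip(s, offsets))
    # 3. Sharpen and renormalize
    w = w_tilde ** gamma
    return w / w.sum()

w_c = np.array([1.0, 0.0, 0.0, 0.0])
w_prev = np.zeros(4)

# A hard +1 shift moves all mass from slot 0 to slot 1
shifted = location_address(w_c, w_prev, g=1.0, s=np.array([0.0, 0.0, 1.0]), gamma=1.0)

# A soft shift blurs the weighting; sharpening concentrates it again
blurred = location_address(w_c, w_prev, g=1.0, s=np.array([0.1, 0.8, 0.1]), gamma=1.0)
sharpened = location_address(w_c, w_prev, g=1.0, s=np.array([0.1, 0.8, 0.1]), gamma=3.0)
print(shifted, blurred.round(3), sharpened.round(3))
```

The last two calls show why sharpening is useful: a soft shift leaks weight onto neighbouring slots, and raising the weights to $\gamma_t > 1$ pushes most of it back onto the dominant slot.
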

## Differentiable Neural Computers (DNC)
DNC = NTM + Improvements\
Key Improvements:
- Multiple read heads
- Temporal memory linkage matrix $L_t$: tracks order of writes
- Usage vector: tracks which memory locations are free
- Dynamic memory allocation via least-used slot selection
- Better gradient flow and scalability
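
The usage vector and least-used allocation can be sketched as below, following the published DNC formulation: usage rises when a slot is written and falls when a read head with a high free gate $f_t$ has just read it; the allocation weighting then favours slots in order of increasing usage. This NumPy sketch assumes a single write head and leaves out the gates that blend allocation with content-based writing; both function names are hypothetical:

```python
import numpy as np

def update_usage(u_prev, w_w_prev, w_r_prev, free_gates):
    """DNC usage update.

    u_prev:     previous usage, shape (N,)
    w_w_prev:   previous write weighting, shape (N,)
    w_r_prev:   previous read weightings, one row per read head, shape (R, N)
    free_gates: one free gate in [0, 1] per read head, shape (R,)
    """
    # Retention psi: slots just read by a head with a high free gate are released
    psi = np.prod(1.0 - free_gates[:, None] * w_r_prev, axis=0)
    # Writing increases usage; retention scales it back down
    return (u_prev + w_w_prev - u_prev * w_w_prev) * psi

def allocation_weighting(u):
    """a[phi_j] = (1 - u[phi_j]) * prod_{i<j} u[phi_i], phi = slots sorted by usage."""
    phi = np.argsort(u)                  # "free list": least-used slots first
    sorted_u = u[phi]
    # Exclusive cumulative product: product of usages of all slots ranked before j
    cumprod = np.concatenate(([1.0], np.cumprod(sorted_u)[:-1]))
    a = np.zeros_like(u)
    a[phi] = (1.0 - sorted_u) * cumprod
    return a

u = np.array([0.9, 0.1, 0.5, 1.0])
a = allocation_weighting(u)
print(a.round(3))  # most weight on slot 1, the least-used slot
```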
### Temporal Link Matrix
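$L_t(i,j) \in [0,1]$ is the degree to which memory location $i$ was written immediately after location $j$.\
This allows both forward and backward traversal through memory, which is crucial for reasoning over sequences.\
The update is:
$$L_t(i,j) = \left(1-w_t^w(i)-w_t^w(j)\right) L_{t-1}(i,j)+w_t^w(i)\, p_{t-1}(j), \qquad L_t(i,i) = 0$$
where $p_t$ is the precedence weighting, the degree to which each location was the last one written to:
$$p_t = \Big(1 - \sum_{i=1}^{N} w_t^w(i)\Big)\, p_{t-1} + w_t^w, \qquad p_0 = \mathbf{0}$$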
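
A NumPy sketch of the link and precedence updates for a single write head, plus the forward and backward weightings $L_t w^r$ and $L_t^\top w^r$ that a DNC read head uses to step through writes in order. `update_links` is a hypothetical helper, and the writes are one-hot for readability (a trained DNC produces soft weightings):

```python
import numpy as np

def update_links(L_prev, p_prev, w_w):
    """One DNC temporal-link update.

    L_prev: previous link matrix, shape (N, N)
    p_prev: previous precedence weighting, shape (N,)
    w_w:    current write weighting, shape (N,)
    """
    # L[i, j] = (1 - w[i] - w[j]) * L_prev[i, j] + w[i] * p_prev[j]
    L = (1.0 - w_w[:, None] - w_w[None, :]) * L_prev + np.outer(w_w, p_prev)
    np.fill_diagonal(L, 0.0)             # a slot is never linked to itself
    # Precedence: how much each slot was the most recently written one
    p = (1.0 - w_w.sum()) * p_prev + w_w
    return L, p

N = 3
L, p = np.zeros((N, N)), np.zeros(N)
for slot in [0, 1, 2]:                   # write to slots 0, 1, 2 in that order
    L, p = update_links(L, p, np.eye(N)[slot])

w_r = np.eye(N)[1]                       # a read head currently focused on slot 1
forward = L @ w_r                        # slot written right after slot 1
backward = L.T @ w_r                     # slot written right before slot 1
print(forward, backward)
```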

## Applications in RL
Explicit memory modules can improve RL agents by enabling:
- One-shot learning
- Reasoning over temporal events
- Learning algorithms, not just reactive policies
- Handling partial observability and long-term dependencies

## Implementation

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
```

```python
class NTMReadHead(nn.Module):
    """Content-addressed read head: emits a key and strength, returns a weighted read."""
    def __init__(self, memory_units, memory_unit_size, controller_dim):
        super().__init__()
        self.key_proj = nn.Linear(controller_dim, memory_unit_size)
        self.beta_proj = nn.Linear(controller_dim, 1)

    def forward(self, memory, controller_state):
        # memory: (B, N, W), controller_state: (B, controller_dim)
        key = torch.tanh(self.key_proj(controller_state))          # (B, W)
        beta = F.softplus(self.beta_proj(controller_state))        # (B, 1), beta > 0

        # Cosine similarity: normalize each memory row and the key along the width dim
        norm_mem = F.normalize(memory, dim=-1)                     # (B, N, W)
        norm_key = F.normalize(key, dim=-1).unsqueeze(1)           # (B, 1, W)
        sim = torch.sum(norm_mem * norm_key, dim=-1)               # (B, N)

        weights = F.softmax(beta * sim, dim=-1)                    # (B, N)
        read_vec = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)  # (B, W)
        return read_vec, weights

class NTMWriteHead(nn.Module):
    """Content-addressed write head: erase, then add, at the addressed rows."""
    def __init__(self, memory_units, memory_unit_size, controller_dim):
        super().__init__()
        self.key_proj = nn.Linear(controller_dim, memory_unit_size)
        self.beta_proj = nn.Linear(controller_dim, 1)
        self.erase_proj = nn.Linear(controller_dim, memory_unit_size)
        self.add_proj = nn.Linear(controller_dim, memory_unit_size)

    def forward(self, memory, controller_state):
        key = torch.tanh(self.key_proj(controller_state))
        beta = F.softplus(self.beta_proj(controller_state))
        erase = torch.sigmoid(self.erase_proj(controller_state))   # e_t in [0, 1]^W
        add = torch.tanh(self.add_proj(controller_state))          # a_t

        norm_mem = F.normalize(memory, dim=-1)
        norm_key = F.normalize(key, dim=-1).unsqueeze(1)
        sim = torch.sum(norm_mem * norm_key, dim=-1)

        weights = F.softmax(beta * sim, dim=-1)                    # w_t^w: (B, N)
        # Outer products w(i) e_t and w(i) a_t for every row: (B, N, W)
        erase_matrix = torch.bmm(weights.unsqueeze(-1), erase.unsqueeze(1))
        add_matrix = torch.bmm(weights.unsqueeze(-1), add.unsqueeze(1))

        memory = memory * (1 - erase_matrix) + add_matrix
        return memory, weights

class NeuralTuringMachine(nn.Module):
    def __init__(self, input_dim, controller_dim, output_dim, memory_units=128, memory_unit_size=20):
        super().__init__()
        # The controller sees the current input plus the previous read vector
        self.controller = nn.LSTMCell(input_dim + memory_unit_size, controller_dim)
        self.read_head = NTMReadHead(memory_units, memory_unit_size, controller_dim)
        self.write_head = NTMWriteHead(memory_units, memory_unit_size, controller_dim)
        self.output_layer = nn.Linear(controller_dim + memory_unit_size, output_dim)

        self.memory_units = memory_units
        self.memory_unit_size = memory_unit_size

    def forward(self, x, memory, state):
        h, c, read_vec = state
        controller_input = torch.cat([x, read_vec], dim=-1)
        h, c = self.controller(controller_input, (h, c))

        # Read from the current memory, then write the updated memory for the next step
        read_vec, _ = self.read_head(memory, h)
        memory, _ = self.write_head(memory, h)

        output = self.output_layer(torch.cat([h, read_vec], dim=-1))
        return output, memory, (h, c, read_vec)
```

```python
# Minimal integration: a single forward step
input_dim = 8
controller_dim = 64
output_dim = 8
memory_units = 16
memory_unit_size = 20

ntm = NeuralTuringMachine(input_dim, controller_dim, output_dim, memory_units, memory_unit_size)

batch_size = 1
x = torch.randn(batch_size, input_dim)
memory = torch.zeros(batch_size, memory_units, memory_unit_size)
h = torch.zeros(batch_size, controller_dim)
c = torch.zeros(batch_size, controller_dim)
read_vec = torch.zeros(batch_size, memory_unit_size)
state = (h, c, read_vec)

output, update_memory, new_state = ntm(x, memory, state)

print("Output:", output)
print("Memory diff norm:", (update_memory - memory).norm())
```

```
Output: tensor([[ 0.0206, -0.1185,  0.1093, -0.0591, -0.0119,  0.0607, -0.0992, -0.0116]],
       grad_fn=<AddmmBackward0>)
Memory diff norm: tensor(0.0904, grad_fn=<LinalgVectorNormBackward0>)
```


```python
# Sequence copying task for the NTM

SEQ_LEN = 5
INPUT_DIM = 8
OUTPUT_DIM = INPUT_DIM
BATCH_SIZE = 1
EPOCHS = 500

ntm = NeuralTuringMachine(INPUT_DIM, controller_dim, OUTPUT_DIM, memory_units, memory_unit_size)
optimizer = optim.Adam(ntm.parameters(), lr=0.01)
loss_fn = nn.BCEWithLogitsLoss()

def generate():
    # Random binary sequences; the target is an exact copy of the input
    seq = torch.bernoulli(torch.full((BATCH_SIZE, SEQ_LEN, INPUT_DIM), 0.5))
    return seq, seq.clone()

for epoch in range(EPOCHS):
    x_seq, y_seq = generate()
    # Fresh memory and controller state per sequence, sized to match the model
    memory = torch.zeros(BATCH_SIZE, memory_units, memory_unit_size)
    h = torch.zeros(BATCH_SIZE, controller_dim)
    c = torch.zeros(BATCH_SIZE, controller_dim)
    read_vec = torch.zeros(BATCH_SIZE, memory_unit_size)
    state = (h, c, read_vec)

    outputs = []
    for t in range(SEQ_LEN):
        out, memory, state = ntm(x_seq[:, t], memory, state)
        outputs.append(out)

    outputs = torch.stack(outputs, dim=1)  # (B, SEQ_LEN, OUTPUT_DIM) logits
    loss = loss_fn(outputs, y_seq)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 50 == 0:
        print(f'Epoch [{epoch + 1}/{EPOCHS}], Loss: {loss.item():.4f}')
```

```
Epoch [50/500], Loss: 0.3286
Epoch [100/500], Loss: 0.0588
Epoch [150/500], Loss: 0.0202
Epoch [200/500], Loss: 0.0110
Epoch [250/500], Loss: 0.0074
Epoch [300/500], Loss: 0.0067
Epoch [350/500], Loss: 0.0048
Epoch [400/500], Loss: 0.0034
Epoch [450/500], Loss: 0.0024
Epoch [500/500], Loss: 0.0021
```
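
In the training loop above, the target at each step equals the input at that same step, so the controller can solve it without consulting memory. The copy task from the original NTM paper instead presents the whole sequence first and asks for it back only after a delimiter. A minimal NumPy sketch of such a data generator, assuming one extra input channel for the delimiter flag (`make_copy_example` is a hypothetical helper):

```python
import numpy as np

def make_copy_example(seq_len, bit_dim, rng):
    """One delayed-copy example.

    Inputs have bit_dim + 1 channels; the extra channel carries a delimiter flag.
    Write phase (seq_len steps): random bits are shown, target is all zeros.
    Delimiter (1 step):          delimiter flag is on, target is all zeros.
    Recall phase (seq_len steps): inputs are blank, target is the original bits.
    """
    bits = rng.integers(0, 2, size=(seq_len, bit_dim)).astype(np.float32)
    total = 2 * seq_len + 1
    x = np.zeros((total, bit_dim + 1), dtype=np.float32)
    y = np.zeros((total, bit_dim), dtype=np.float32)
    x[:seq_len, :bit_dim] = bits         # write phase
    x[seq_len, bit_dim] = 1.0            # delimiter marks the start of recall
    y[seq_len + 1:] = bits               # recall phase: must come from memory
    return x, y

rng = np.random.default_rng(0)
x, y = make_copy_example(seq_len=5, bit_dim=8, rng=rng)
print(x.shape, y.shape)  # (11, 9) (11, 8)
```

To train on it, the model would need `input_dim = bit_dim + 1`, and the loss is often restricted to the recall steps; both are assumptions about how one might adapt the loop above, not something the notebook does.
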


```python
# Integrate an NTM-augmented RL agent
class NTMPolicy(nn.Module):
    def __init__(self, obs_dim, action_dim, controller_dim=64, memory_units=32, memory_unit_size=20):
        super().__init__()
        self.ntm = NeuralTuringMachine(obs_dim, controller_dim, action_dim, memory_units, memory_unit_size)
        # Actor and critic heads read the controller state plus the latest read vector
        self.actor = nn.Linear(controller_dim + memory_unit_size, action_dim)
        self.critic = nn.Linear(controller_dim + memory_unit_size, 1)

    def forward(self, obs, memory, state):
        # The NTM's own output layer is not used here; only its memory and state are
        _, memory, state = self.ntm(obs, memory, state)
        h, _, read_vec = state
        features = torch.cat([h, read_vec], dim=-1)
        logits = self.actor(features)
        value = self.critic(features)  # not used by REINFORCE below
        return logits, value, memory, state

env = gym.make('CartPole-v1')
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy = NTMPolicy(obs_dim, action_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.001)

def reinforce_train(episodes=500):
    gamma = 0.99
    ntm = policy.ntm
    hidden = ntm.controller.hidden_size
    for ep in range(episodes):
        obs, _ = env.reset()
        # Fresh memory and (h, c, read_vec) state at the start of every episode
        memory = torch.zeros(1, ntm.memory_units, ntm.memory_unit_size)
        state = (torch.zeros(1, hidden), torch.zeros(1, hidden), torch.zeros(1, ntm.memory_unit_size))
        log_probs = []
        rewards = []
        total = 0

        done = False
        while not done:
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            logits, _, memory, state = policy(obs_tensor, memory, state)
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_probs.append(dist.log_prob(action))
            # Gymnasium's step returns (obs, reward, terminated, truncated, info)
            obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            rewards.append(reward)
            total += reward

        # Discounted returns G_t = r_t + gamma * G_{t+1}, normalized to reduce variance
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # REINFORCE loss: -sum_t log pi(a_t | s_t) * G_t
        loss = -(torch.cat(log_probs) * returns).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (ep + 1) % 50 == 0:
            print(f'Episode {ep+1}, Total Reward: {total}')

reinforce_train()
```

```
Episode 50, Total Reward: 14.0
Episode 100, Total Reward: 10.0
Episode 150, Total Reward: 67.0
Episode 200, Total Reward: 59.0
Episode 250, Total Reward: 22.0
Episode 300, Total Reward: 103.0
Episode 350, Total Reward: 320.0
Episode 400, Total Reward: 213.0
Episode 450, Total Reward: 270.0
Episode 500, Total Reward: 359.0
```
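
`NTMPolicy` also computes a critic value that `reinforce_train` never uses. One common way to use it is as a baseline, replacing the normalized return with the advantage $A_t = G_t - V(s_t)$. A minimal NumPy sketch of that computation, with hypothetical helper names; in the PyTorch loop the values would come from the policy's second output:

```python
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    G, out = 0.0, []
    for r in reversed(rewards):
        G = r + gamma * G
        out.append(G)
    return np.array(out[::-1])

def advantages(rewards, values, gamma=0.99):
    """A_t = G_t - V(s_t): how much better the episode went than the critic predicted."""
    return discounted_returns(rewards, gamma) - np.asarray(values)

rewards = [1.0, 1.0, 1.0]
values = [2.0, 1.5, 1.0]
print(advantages(rewards, values, gamma=0.5))  # [-0.25  0.    0.  ]
```

In an actor-critic variant of the loop, the actor term would weight each log-probability by a detached advantage, and the critic would be trained with a squared error between $G_t$ and $V(s_t)$; this is an assumption about how one might extend the notebook, not what it does.
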
