# Strategic Memory Retrieval: An Agent With Active, Learnable Memory

**Strategic Memory Retrieval** is a reinforcement learning (RL) agent designed for **environments where optimal decisions require remembering and leveraging information from the distant past**—not just recent history. The agent maintains an **external, actively-managed episodic memory** where it stores compressed summaries of entire experiences (trajectories) and learns **which memories to retain and which to discard** as training progresses.
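
As a rough illustration of the retention side, here is a minimal sketch of a fixed-capacity episodic buffer with score-based eviction. It assumes each entry is a fixed-size trajectory summary (a list of floats) paired with a scalar usefulness score. The class name `RetentionBuffer` and the rule "replace the lowest-scoring entry only if the newcomer scores higher" are hypothetical, not the `memory` module's actual implementation. In the agent itself the scores are learned rather than supplied by hand.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RetentionBuffer:
    """Hypothetical fixed-capacity episodic memory with score-based eviction."""
    max_entries: int
    summaries: List[List[float]] = field(default_factory=list)
    scores: List[float] = field(default_factory=list)

    def add(self, summary: List[float], score: float) -> None:
        # Below capacity: always store the new trajectory summary.
        if len(self.summaries) < self.max_entries:
            self.summaries.append(summary)
            self.scores.append(score)
            return
        # At capacity: locate the least useful stored entry.
        worst = min(range(len(self.scores)), key=self.scores.__getitem__)
        # Replace it only if the newcomer is scored as more useful.
        if score > self.scores[worst]:
            self.summaries[worst] = summary
            self.scores[worst] = score

    def __len__(self) -> int:
        return len(self.summaries)


buf = RetentionBuffer(max_entries=2)
buf.add([1.0, 0.0], score=0.9)
buf.add([0.0, 1.0], score=0.1)
buf.add([0.5, 0.5], score=0.5)  # evicts the entry scored 0.1
print(buf.summaries)  # [[1.0, 0.0], [0.5, 0.5]]
```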

### The Problem:

Classic RL agents—like DQN, LSTM PPO, or even transformers—struggle when rewards are delayed, sparse, or depend on events far in the past. In such tasks, remembering the right event, state, or action at the right time is crucial for success. Most RL methods either forget, overfit to short-term cues, or retain irrelevant information, resulting in poor performance on long-horizon, memory-based tasks.

### The Solution:

Strategic Memory Retrieval **actively learns**:

- **What to store:** Which episodes or sequences are worth keeping in memory.
- **What to forget:** Which are unhelpful and can be safely discarded.
- **How to retrieve:** At every decision point, the agent uses attention to retrieve relevant past experiences from memory and integrates them with the current observation before acting.

**All of this is trained end-to-end**, so the agent autonomously discovers how to use its memory buffer for strategic decision-making—**no hints, flags, or engineered memory cues are required**.
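
The retrieval step from the list above can be sketched as scaled dot-product attention over the stored entries. This assumes the current observation has already been embedded into a query vector of the same width as the memory keys. The function names are hypothetical, and concatenating the memory read with the observation embedding is one plausible way to "integrate" memory, not necessarily what `StrategicMemoryTransformerPolicy` does. The returned attention weights are the kind of signal that makes recall inspectable.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def retrieve(query: Vector, keys: List[Vector], values: List[Vector]) -> Tuple[Vector, Vector]:
    """Scaled dot-product attention read over an episodic memory."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    # Attention-weighted sum of the value vectors gives one memory read.
    read = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]
    return read, weights


def policy_input(obs_embedding: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    # List concatenation: [observation embedding | memory read].
    read, _ = retrieve(obs_embedding, keys, values)
    return obs_embedding + read


keys = [[4.0, 0.0], [0.0, 4.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
read, weights = retrieve([4.0, 0.0], keys, values)
print(weights)  # almost all attention mass on the first memory
```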

---

## **Comparison Table**

| Feature / Method              | LSTM PPO     | DNC/NTM        | Decision Transformer | GTrXL             | NEC / DND | Neural Map | **Strategic Memory Retrieval** |
| ----------------------------- | ------------ | -------------- | -------------------- | ----------------- | --------- | ---------- | ------------------------------ |
| Core Memory Type              | Hidden state | External R/W   | In-Context (GPT)     | Segment history   | kNN table | 2D spatial | Episodic buffer + retention    |
| Memory Retention              | Fades        | Manual/learned | None                 | History window    | FIFO      | Manual     | _Learnable, optimized_         |
| Retrieval                     | Implicit     | Soft/explicit  | Implicit             | History attention | kNN/soft  | Soft/read  | _Soft attention_               |
| Retention Learning            | No           | Partial        | No                   | No                | No        | No         | **Yes**                        |
| Interpretable Recall          | No           | Hard           | No                   | Some              | Some      | No         | **Yes (attention, use)**       |
| Persistent Memory             | No           | Partial        | No                   | Partial           | Yes       | Yes        | **Yes**                        |
| Sequence Length               | Short/medium | Short          | _Long_               | _Long_            | Medium    | Medium     | _Long_                         |
| No Hints/Flags                | Yes          | Yes            | Yes                  | Yes               | Yes       | Yes        | **Yes**                        |
| Outperforms on Delayed Reward | ✗            | ±              | ±                    | ±                 | ±         | ±          | **✓✓✓**                        |

---

## Literature & Reference Models

This agent builds upon and advances the following lines of research:

| **Approach / Paper**                                               | **Core Idea**                                   | **Key Weakness vs This**                                        |
| ------------------------------------------------------------------ | ----------------------------------------------- | --------------------------------------------------------------- |
| **DQN/LSTM-based RL**<br>Hausknecht & Stone, 2015                  | RNN hidden state as memory                      | Struggles with long delays, limited memory                      |
| **Neural Episodic Control**<br>Pritzel et al., 2017                | Non-parametric DND table, kNN retrieval         | No learnable retention, no end-to-end training                  |
| **Differentiable Neural Computer**<br>Graves et al., 2016          | RNN w/ differentiable read/write memory         | Expensive, hard to scale, hard to tune                          |
| **Neural Map / Memory-Augmented RL**<br>Parisotto et al., 2018     | Spatially structured memory, soft addressing    | Retention/static, not fully learnable, not cue-driven           |
| **MERLIN (Unsupervised Predictive Memory)**<br>Wayne et al., 2018  | Latent predictive memory trained with unsupervised auxiliary objectives | Retention not explicit; memory not strategic                    |
| **Decision Transformer**<br>Chen et al., 2021                      | GPT-style transformer over return-conditioned trajectories | No explicit, persistent external memory; not episodic retrieval |
| **GTrXL (Gated Transformer-XL)**<br>Parisotto et al., 2020         | Gated Transformer-XL that stabilizes transformer training in RL | "Memory" = recent history, not explicit retention or recall     |
| **MVP: Memory Value Propagation**<br>Oh et al., 2020               | Learnable memory with value propagation         | Not as interpretable, not retention-focused                     |
| **Recurrent Independent Mechanisms (RIMs)**<br>Goyal et al., 2021  | Modular memory units, attention-based gating    | No persistent, recallable episodic buffer                       |
| **Active Memory / Episodic Control (EC)**<br>Blundell et al., 2016 | Episodic memory with tabular kNN access         | No differentiable retention, no meta-learning                   |

---

## **Additional References**

- **Hausknecht & Stone, 2015**: “Deep Recurrent Q-Learning for Partially Observable MDPs”
- **Pritzel et al., 2017**: “Neural Episodic Control”, [arXiv:1703.01988](https://arxiv.org/abs/1703.01988)
- **Parisotto et al., 2018**: “Neural Map: Structured Memory for Deep Reinforcement Learning”, [ICLR 2018](https://openreview.net/forum?id=B14TlG-RW)
- **Wayne et al., 2018**: “Unsupervised Predictive Memory in a Goal-Directed Agent” (MERLIN), [arXiv:1803.10760](https://arxiv.org/abs/1803.10760)
- **Chen et al., 2021**: “Decision Transformer: Reinforcement Learning via Sequence Modeling”, [arXiv:2106.01345](https://arxiv.org/abs/2106.01345)
- **Parisotto et al., 2020**: “Stabilizing Transformers for Reinforcement Learning”, [ICML 2020 (GTrXL)](http://proceedings.mlr.press/v119/parisotto20a.html)
- **Oh et al., 2020**: “Value Propagation Networks”, [ICLR 2020](https://openreview.net/forum?id=B1xSperKvB)
- **Goyal et al., 2021**: “Recurrent Independent Mechanisms”, [ICLR 2021](https://openreview.net/forum?id=mLcmdlEUxy-)
- **Blundell et al., 2016**: “Model-Free Episodic Control”, [arXiv:1606.04460](https://arxiv.org/abs/1606.04460)
- **Graves et al., 2016**: “Hybrid computing using a neural network with dynamic external memory” (DNC), [Nature 2016](https://www.nature.com/articles/nature20101)
- **Sukhbaatar et al., 2015**: “End-To-End Memory Networks”, [arXiv:1503.08895](https://arxiv.org/abs/1503.08895)

---

## TL;DR

- **First to jointly optimize both memory retention (what to keep/discard) and retrieval (what to attend to) in a single, end-to-end RL agent**.
- **Flexible plug-and-play memory**: Can be swapped for many memory architectures (transformers, graph attention, learned compression).
- **No task-specific hacks**: Outperforms the PPO and Recurrent PPO baselines on the synthetic delayed-reward memory tasks benchmarked in this notebook, _without using any domain knowledge_ or “cheat” features.
- **Interpretable, practical, and scalable**: Suitable for real-world problems where “what matters” is unknown and must be discovered.

---

**Author:** Filipe Sá  
**Contact:** filipemotasa@hotmail.com | [GitHub](https://github.com/pihh/)

---


```python
import warnings
warnings.filterwarnings('ignore')

from agent import StrategicMemoryAgent
from environments import MemoryTaskEnv
from benchmark import AgentPerformanceBenchmark
from memory import StrategicMemoryBuffer, StrategicMemoryTransformerPolicy
```




```python
# SETUP ===================================
DELAY = 4          # memory delay (episode length in steps)
MEM_DIM = 32       # width of each memory embedding
N_EPISODES = 2500
N_MEMORIES = 32    # capacity of the episodic buffer

AGENT_KWARGS = dict(
    device="cpu",
    verbose=0,
    lam=0.95,
    gamma=0.99,
    ent_coef=0.01,
    learning_rate=1e-3,
)
MEMORY_AGENT_KWARGS = dict(
    her=False,
    reward_norm=False,
    aux_modules=None,

    intrinsic_expl=False,
    intrinsic_eta=0.01,

    use_rnd=False,
    rnd_emb_dim=32,
    rnd_lr=1e-3,
)

# HELPERS =================================
def total_timesteps(delay, n_episodes):
    """Step budget for `n_episodes` episodes, assuming each episode lasts `delay` steps."""
    return delay * n_episodes
```
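
The `total_timesteps` helper treats one episode as exactly `delay` environment steps, which matches the `ep_len_mean` of 4.000 logged at `DELAY = 4`. Under that assumption the step budgets used in this notebook work out as follows; `progress` is a hypothetical helper that reproduces the `progress` value in the training log (6.25% after 250 episodes).

```python
def total_timesteps(delay: int, n_episodes: int) -> int:
    # One episode == `delay` environment steps.
    return delay * n_episodes


def progress(steps_done: int, budget: int) -> float:
    # Percentage of the training budget consumed so far.
    return 100.0 * steps_done / budget


budget = total_timesteps(4, 4000)               # simple training example
steps_after_250_eps = total_timesteps(4, 250)   # first log_interval
print(budget, steps_after_250_eps, progress(steps_after_250_eps, budget))  # 16000 1000 6.25

# Budgets of the batch experiments (delay, episodes used in total_timesteps).
for delay, eps in [(4, 2500), (16, 3500), (32, 10000), (256, 10000)]:
    print(delay, total_timesteps(delay, eps))
```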

## **Example:** Simple training setup

```python
# ENVIRONMENT =============================
env = MemoryTaskEnv(delay=DELAY, difficulty=0)

# MEMORY BUFFER ===========================
memory = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,          # For Discrete(2)
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    device="cpu",
)

# POLICY NETWORK (pass the class, not an instance)
policy = StrategicMemoryTransformerPolicy

# (optional) AUXILIARY MODULES ============
# CueAuxModule and ConfidenceAuxModule are not imported in this notebook.
# To enable them, import them, uncomment the list below, and set
# MEMORY_AGENT_KWARGS["aux_modules"] = aux_modules. Do not also pass
# aux_modules=... to the agent: MEMORY_AGENT_KWARGS already contains that
# key, so Python would raise "got multiple values for keyword argument".
# aux_modules = [
#     CueAuxModule(feat_dim=MEM_DIM * 2, n_classes=2),
#     ConfidenceAuxModule(feat_dim=MEM_DIM * 2),
# ]

# AGENT SETUP =============================
agent = StrategicMemoryAgent(
    policy_class=policy,
    env=env,
    memory=memory,
    memory_learn_retention=True,   # learn which memories to keep or evict
    memory_retention_coef=0.01,    # coefficient of the retention objective
    device="cpu",
    verbose=1,
    lam=0.95,
    gamma=0.99,
    ent_coef=0.01,
    learning_rate=1e-3,
    **MEMORY_AGENT_KWARGS,
)

# TRAIN THE AGENT =========================
agent.learn(
    total_timesteps=total_timesteps(DELAY, 4000),  # 4 * 4000 = 16,000 steps
    log_interval=250,
)
```

```text
-------------------------------------
| rollout/              |           |
|    ep_len_mean        |    4.000  |
|    ep_rew_mean        |    0.000  |
|    ep_rew_std         |    1.000  |
|    policy_entropy     |    0.200  |
|    advantage_mean     |    0.023  |
|    advantage_std      |    0.178  |
|    aux_loss_mean      |    0.000  |
| time/                 |           |
|    fps                |      199  |
|    episodes           |      250  |
|    time_elapsed       |        5  |
|    total_timesteps    |     1000  |
| train/                |           |
|    loss               |   -0.009  |
|    policy_loss        |   -0.019  |
|    value_loss         |    0.024  |
|    explained_variance |  -89.842  |
|    n_updates          |      250  |
|    progress           |    6.250  |
| memory/               |           |
|    usefulness_loss    |    0.012  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_
```
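
`explained_variance` compares the critic's value predictions with the empirical returns. A minimal sketch of the usual definition, 1 − Var[returns − values] / Var[returns], which is what Stable-Baselines3 logs under this name and presumably what this agent logs too: 1 means perfect predictions, 0 means no better than predicting the mean return, and large negative values such as −89.842 early in training mean the critic is far worse than a constant guess. The example inputs below are made up for illustration.

```python
import math
from statistics import pvariance
from typing import Sequence


def explained_variance(values: Sequence[float], returns: Sequence[float]) -> float:
    """1 - Var[returns - values] / Var[returns]; NaN when the returns are constant."""
    var_returns = pvariance(returns)
    if var_returns == 0:
        return math.nan
    residuals = [r - v for r, v in zip(returns, values)]
    return 1.0 - pvariance(residuals) / var_returns


returns = [1.0, -1.0, 1.0, -1.0]
print(explained_variance([1.0, -1.0, 1.0, -1.0], returns))  # 1.0   perfect critic
print(explained_variance([0.0, 0.0, 0.0, 0.0], returns))    # 0.0   constant prediction
print(explained_variance([-5.0, 5.0, -5.0, 5.0], returns))  # -35.0 worse than constant
```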

## Benchmark Overview
This benchmark systematically evaluates this agent on a synthetic memory task with varying levels of difficulty and delay. Each experiment defines a unique scenario (e.g., short vs. long memory delays, easy vs. hard distractors) and trains the agent to solve it using a fixed number of episodes and timesteps.

The goal is to compare performance and generalization as task complexity increases.

Optionally, the benchmark can also compare this agent against baseline models (e.g., standard PPO and Recurrent PPO with an LSTM policy) under the same conditions.

Results can be used to diagnose strengths and limitations of the agent, inform ablations, and guide further development.
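
For readers without the `environments` module, the sketch below shows one plausible shape for a delayed-recall task that is consistent with the logs (episodes last `delay` steps; episode rewards are ±1). It is a hypothetical stand-in, not `MemoryTaskEnv`: a binary cue appears only at step 0 and only the action on the last step is scored. How `difficulty` changes the real task is not documented here, so the `distractors` flag is an assumption.

```python
import random
from typing import Callable, Optional, Tuple

Obs = Tuple[float, float]  # (cue channel, fraction of the episode elapsed)


class ToyDelayedRecallEnv:
    """Hypothetical delayed-recall task (NOT the real MemoryTaskEnv)."""

    def __init__(self, delay: int, distractors: bool = False, seed: Optional[int] = None):
        self.delay = delay
        self.distractors = distractors  # assumed stand-in for a "hard" difficulty
        self.rng = random.Random(seed)
        self.t = 0
        self.cue = 0

    def _obs(self) -> Obs:
        if self.t == 0:
            cue_channel = 1.0 if self.cue == 1 else -1.0  # the only informative step
        elif self.distractors:
            cue_channel = self.rng.choice([-1.0, 1.0])    # misleading noise
        else:
            cue_channel = 0.0
        return (cue_channel, self.t / self.delay)

    def reset(self) -> Obs:
        self.t = 0
        self.cue = self.rng.randint(0, 1)
        return self._obs()

    def step(self, action: int) -> Tuple[Obs, float, bool]:
        self.t += 1
        done = self.t >= self.delay
        # Only the last action is scored: +1 if it matches the cue, -1 otherwise.
        reward = (1.0 if action == self.cue else -1.0) if done else 0.0
        return self._obs(), reward, done


def make_memory_policy() -> Callable[[Obs], int]:
    """Oracle that stores the cue at step 0 and replays it (an upper bound)."""
    remembered = {"cue": 0}

    def policy(obs: Obs) -> int:
        cue_channel, elapsed = obs
        if elapsed == 0.0:
            remembered["cue"] = 1 if cue_channel > 0 else 0
        return remembered["cue"]

    return policy


def run_episode(env: ToyDelayedRecallEnv, policy: Callable[[Obs], int]) -> Tuple[float, int]:
    obs, total, steps, done = env.reset(), 0.0, 0, False
    while not done:
        obs, reward, done = env.step(policy(obs))
        total += reward
        steps += 1
    return total, steps


toy_env = ToyDelayedRecallEnv(delay=4, distractors=True, seed=0)
oracle = make_memory_policy()
episode_returns = [run_episode(toy_env, oracle)[0] for _ in range(100)]
print(sum(episode_returns) / len(episode_returns))  # 1.0: the cue is always recalled
```

A memoryless policy such as `lambda obs: 1 if obs[0] > 0 else 0` only sees noise or zeros on the final step, so on this toy task it is stuck at chance.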

```python
# BATCH EXPERIMENT SETUP ==================
if __name__ == "__main__":
    EXPERIMENTS = [
        dict(delay=4, n_train_episodes=2000, total_timesteps=total_timesteps(4, 2500), difficulty=0, mode_name="EASY", verbose=0, eval_base=True),
        dict(delay=4, n_train_episodes=5000, total_timesteps=total_timesteps(4, 2500), difficulty=1, mode_name="HARD", verbose=0, eval_base=True),
        dict(delay=16, n_train_episodes=7500, total_timesteps=total_timesteps(16, 3500), difficulty=0, mode_name="EASY", verbose=0, eval_base=False),
        dict(delay=32, n_train_episodes=7500, total_timesteps=total_timesteps(32, 10000), difficulty=1, mode_name="HARD", verbose=0, eval_base=True),
        # dict(delay=64, n_train_episodes=15000, total_timesteps=15000*64, difficulty=0, mode_name="HARD", verbose=0, eval_base=False),
        dict(delay=256, n_train_episodes=20000, total_timesteps=total_timesteps(256, 10000), difficulty=0, mode_name="HARD", verbose=1, eval_base=False),
    ]

    # Custom memory agent config
    memory_agent_config = dict(
        action_dim=1,          # For Discrete(2)
        mem_dim=MEM_DIM,
        max_entries=N_MEMORIES,
        policy_class=StrategicMemoryTransformerPolicy,
        **AGENT_KWARGS,
        **MEMORY_AGENT_KWARGS,
    )

    results = []
    for exp in EXPERIMENTS:
        benchmark = AgentPerformanceBenchmark(exp, memory_agent_config=memory_agent_config)
        results.append(benchmark.run())
```



```text
Training in EASY mode with delay of 4 steps

Finalizing Results: 100%|██████████| 7/7 [02:09<00:00, 18.47s/step]

╭────┬──────────────────────┬─────────┬────────┬───────────────┬──────────────┬────────────────╮
│    │ Agent                │   Delay │ Mode   │   Mean Ep Rew │   Std Ep Rew │   Duration (s) │
├────┼──────────────────────┼─────────┼────────┼───────────────┼──────────────┼────────────────┤
│  0 │ PPO                  │       4 │ EASY   │             0 │            1 │        26.0689 │
│  1 │ RecurrentPPO         │       4 │ EASY   │             0 │            1 │        51.213  │
│  2 │ StrategicMemoryAgent │       4 │ EASY   │             1 │            0 │        51.6281 │
╰────┴──────────────────────┴─────────┴────────┴───────────────┴──────────────┴────────────────╯

Training in HARD mode with delay of 4 steps

Finalizing Results: 100%|██████████| 7/7 [02:12<00:00, 18.90s/step]

╭────┬──────────────────────┬─────────┬────────┬───────────────┬──────────────┬────────────────╮
│    │ Agent                │   Delay │ Mode   │   Mean Ep Rew │   Std Ep Rew │   Duration (s) │
├────┼──────────────────────┼─────────┼────────┼───────────────┼──────────────┼────────────────┤
│  0 │ PPO                  │       4 │ HARD   │           0.2 │     0.979796 │        25.6446 │
│  1 │ RecurrentPPO         │       4 │ HARD   │           0   │     1        │        55.5133 │
│  2 │ StrategicMemoryAgent │       4 │ HARD   │           1   │     0        │        50.7957 │
╰────┴──────────────────────┴─────────┴────────┴───────────────┴──────────────┴────────────────╯

Training in EASY mode with delay of 16 steps

Finalizing Results: 100%|██████████| 3/3 [04:46<00:00, 95.52s/step]

╭────┬──────────────────────┬─────────┬────────┬───────────────┬──────────────┬────────────────╮
│    │ Agent                │   Delay │ Mode   │   Mean Ep Rew │   Std Ep Rew │   Duration (s) │
├────┼──────────────────────┼─────────┼────────┼───────────────┼──────────────┼────────────────┤
│  0 │ StrategicMemoryAgent │      16 │ EASY   │             1 │            0 │        285.867 │
╰────┴──────────────────────┴─────────┴────────┴───────────────┴──────────────┴────────────────╯

Training in HARD mode with delay of 32 steps

Finalizing Results: 100%|██████████| 7/7 [1:29:04<00:00, 763.49s/step]

╭────┬──────────────────────┬─────────┬────────┬───────────────┬──────────────┬────────────────╮
│    │ Agent                │   Delay │ Mode   │   Mean Ep Rew │   Std Ep Rew │   Duration (s) │
├────┼──────────────────────┼─────────┼────────┼───────────────┼──────────────┼────────────────┤
│  0 │ PPO                  │      32 │ HARD   │          -0.1 │     0.994987 │        778.429 │
│  1 │ RecurrentPPO         │      32 │ HARD   │           0   │     1        │       2579     │
│  2 │ StrategicMemoryAgent │      32 │ HARD   │           1   │     0        │       1983.58  │
╰────┴──────────────────────┴─────────┴────────┴───────────────┴──────────────┴────────────────╯

Training in HARD mode with delay of 256 steps

Training StrategicMemoryAgent:   0%|          | 0/3 [00:00<?, ?step/s]
```
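
The reward columns above are consistent with per-episode rewards of exactly +1 or −1 (an inference from the numbers, not something the benchmark documents). For such rewards with mean μ, E[r²] = 1, so the standard deviation is √(1 − μ²): 0.2 gives 0.979796 and −0.1 gives 0.994987, matching the PPO rows. Under the same reading, the success rate is (1 + μ)/2, so a mean of 0 is chance level and 1 means every evaluation episode was solved. The helper names below are hypothetical.

```python
import math


def std_for_binary_rewards(mean_reward: float) -> float:
    # Rewards in {-1, +1}: E[r^2] = 1, so Var = 1 - mu^2.
    return math.sqrt(1.0 - mean_reward ** 2)


def success_rate(mean_reward: float) -> float:
    # mu = p * (+1) + (1 - p) * (-1) = 2p - 1  =>  p = (1 + mu) / 2
    return (1.0 + mean_reward) / 2.0


for label, mu in [("PPO, HARD, delay 4", 0.2),
                  ("PPO, HARD, delay 32", -0.1),
                  ("RecurrentPPO", 0.0),
                  ("StrategicMemoryAgent", 1.0)]:
    print(f"{label:22s} success={success_rate(mu):.2f} std={std_for_binary_rewards(mu):.6f}")
```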

## What's in the future?

### Already detects:
* Repeating event sequences (“if this and then that within 3 steps, reward later”)

* Rare event triggers (“when you see X-Y-Z in order, prepare for Z+N”)

* Long-term cues (“pattern at t=1…t=3, outcome at t=10”)

* Short sub-patterns that are only meaningful in context

### Next improvements:

* Store rolling subtrajectories or multi-scale embeddings.

* Use “motif mining” or sub-sequence attention in memory.

* Allow retrieval by partial match, not just by whole-trajectory match.
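
The first and third improvements can be sketched together: cut each stored trajectory into rolling windows at several scales, summarize each window, and match a query against every window instead of only against whole trajectories. This is a hypothetical illustration. The mean-pooling summary and Euclidean distance are simple placeholders for a learned encoder and an attention score.

```python
import math
from typing import Dict, List, Sequence, Tuple

Step = Sequence[float]  # one per-step feature vector


def rolling_windows(traj: List[Step], scales: Sequence[int]) -> Dict[int, List[List[Step]]]:
    """All contiguous sub-trajectories of each length in `scales` (too-long scales are skipped)."""
    return {w: [traj[i:i + w] for i in range(len(traj) - w + 1)]
            for w in scales if w <= len(traj)}


def summarize(window: List[Step]) -> List[float]:
    # Placeholder encoder: mean-pool the per-step features.
    dim = len(window[0])
    return [sum(step[d] for step in window) / len(window) for d in range(dim)]


def best_partial_match(query: List[Step], trajectories: List[List[Step]],
                       scales: Sequence[int]) -> Tuple[float, int, int, int]:
    """Return (distance, trajectory index, window length, start step) of the closest window."""
    q = summarize(query)
    best = (math.inf, -1, 0, 0)
    for t_idx, traj in enumerate(trajectories):
        for w, windows in rolling_windows(traj, scales).items():
            for start, window in enumerate(windows):
                dist = math.dist(q, summarize(window))
                if dist < best[0]:
                    best = (dist, t_idx, w, start)
    return best


stored_trajectories = [
    [[0, 0], [0, 0], [1, 0], [1, 0], [0, 0]],  # pattern appears mid-episode
    [[0, 1], [0, 1], [0, 1]],
]
query = [[1, 0], [1, 0]]
best = best_partial_match(query, stored_trajectories, scales=(2, 3))
print(best)  # (0.0, 0, 2, 2): trajectory 0, length-2 window starting at step 2
```

A whole-trajectory summary of the first episode would average the pattern away with the surrounding zeros; the windowed search recovers where it occurred.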

### Next step:

* Motif mining:
  * Trajectory Encoder: a transformer that maps a trajectory or subtrajectory to a fixed-dimensional vector.
  * Motif Memory Bank: a differentiable, learnable set of motif embeddings (shape: `[num_motifs, mem_dim]`). Optionally, each motif also learns a value or usefulness score.
  * Motif Mining Head: a module that, during or after a trajectory, produces a set of candidate motifs (sub-trajectories), encodes them with the encoder, clusters or scores them, and updates the memory bank (with gradient flow where possible). Motif selection can be based on similarity, usefulness (to reward), or novelty.
  * Differentiable DTW (Soft-DTW): allows “fuzzy” alignment scores between the current trajectory and stored motifs, all within the computation graph.
  * Motif Attention Integration
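
Soft-DTW (Cuturi & Blondel, 2017) replaces the hard `min` in the dynamic-time-warping recursion with a smoothed minimum, softmin_γ(a, b, c) = −γ · log(e^(−a/γ) + e^(−b/γ) + e^(−c/γ)), which makes the alignment cost differentiable with respect to both sequences. Below is a plain-Python sketch of the forward recursion with a squared-Euclidean local cost. For the motif bank it would have to run inside an autodiff framework so gradients reach the motif embeddings; this version only computes the value.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def softmin(values: Sequence[float], gamma: float) -> float:
    # -gamma * log(sum(exp(-v / gamma))) via log-sum-exp; inf terms contribute exp(-inf) = 0.
    scaled = [-v / gamma for v in values if v != math.inf]
    if not scaled:
        return math.inf
    m = max(scaled)
    return -gamma * (m + math.log(sum(math.exp(s - m) for s in scaled)))


def soft_dtw(x: List[Vec], y: List[Vec], gamma: float = 1.0) -> float:
    """Soft-DTW alignment cost between sequences x (length n) and y (length m)."""
    n, m = len(x), len(y)
    R = [[math.inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = sum((a - b) ** 2 for a, b in zip(x[i - 1], y[j - 1]))
            # Smoothed minimum over the match, insertion, and deletion predecessors.
            R[i][j] = cost + softmin([R[i - 1][j - 1], R[i - 1][j], R[i][j - 1]], gamma)
    return R[n][m]


motif = [[0.0], [1.0], [2.0]]
stretched = [[0.0], [0.0], [1.0], [2.0]]  # same shape, played more slowly
reversed_motif = [[2.0], [1.0], [0.0]]
print(soft_dtw(motif, stretched, gamma=0.01))       # close to 0: time warping is absorbed
print(soft_dtw(motif, reversed_motif, gamma=0.01))  # clearly larger
```

As γ shrinks, the value approaches classic DTW; larger γ smooths the cost surface, which helps gradient-based motif updates but makes the value negative even for identical sequences.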