In [1]:
from agent import StrategicMemoryAgent
from environments import MemoryTaskEnv
from benchmark import AgentPerformanceBenchmark
from memory import StrategicMemoryBuffer,StrategicMemoryTransformerPolicy


In [12]:
# SETUP ===================================
# Sizing knobs shared by the example cell and the batch benchmark below.
DELAY = 16          # environment delay; episodes last this many steps (see ep_len_mean in the logs)
MEM_DIM = 32        # mem_dim of the StrategicMemoryBuffer
N_EPISODES = 2500   # episodes used to size the single-agent example run
N_MEMORIES = 16     # max_entries of the StrategicMemoryBuffer

# Generic agent hyper-parameters, unpacked into every StrategicMemoryAgent
# constructor in this notebook ("lam" is presumably the GAE lambda — confirm).
AGENT_KWARGS = {
    "device": "cpu",
    "verbose": 0,
    "lam": 0.95,
    "gamma": 0.99,
    "ent_coef": 0.01,
    "learning_rate": 1e-3,
}

# Options specific to StrategicMemoryAgent.
MEMORY_AGENT_KWARGS = {
    "her": False,
    "reward_norm": False,
    "aux_modules": None,
    # intrinsic-exploration bonus
    "intrinsic_expl": True,
    "intrinsic_eta": 0.01,
    # RND settings (logged under rnd_net_dist/ during training)
    "use_rnd": True,
    "rnd_emb_dim": 32,
    "rnd_lr": 1e-3,
}

# HELPERS =================================
def total_timesteps(delay: int, n_episodes: int) -> int:
    """Return the environment-step budget for ``n_episodes`` episodes.

    Each episode of the memory task is assumed to last ``delay`` steps, so the
    budget is ``delay * n_episodes``.
    """
    steps_per_episode = delay
    return n_episodes * steps_per_episode

## **Example:** Simple training setup

In [None]:
# ENVIRONMENT =============================
env = MemoryTaskEnv(delay=DELAY, difficulty=0)

# MEMORY BUFFER ===========================
# The buffer is placed on the same device the agent is configured for, so
# changing AGENT_KWARGS["device"] in the setup cell moves both together.
memory = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,          # For Discrete(2)
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    device=AGENT_KWARGS["device"],
)

# POLICY NETWORK (use class) ==============
# StrategicMemoryAgent takes `policy_class=`, i.e. the class itself, not an instance.
policy = StrategicMemoryTransformerPolicy

# (optional) AUXILIARY MODULES ============
# MEMORY_AGENT_KWARGS already carries `aux_modules` (None), so passing
# `aux_modules=` as a separate keyword next to **MEMORY_AGENT_KWARGS would raise
# TypeError ("got multiple values for keyword argument 'aux_modules'").
# To enable auxiliary heads, override the entry without mutating the shared dict:
#
#   aux_modules = [
#       CueAuxModule(feat_dim=MEM_DIM*2, n_classes=2),
#       ConfidenceAuxModule(feat_dim=MEM_DIM*2),
#   ]
#   ... **AGENT_KWARGS, **{**MEMORY_AGENT_KWARGS, "aux_modules": aux_modules}

# AGENT SETUP =============================
agent = StrategicMemoryAgent(
    policy_class=policy,
    env=env,
    memory=memory,
    memory_learn_retention=True,
    memory_retention_coef=0.01,
    **AGENT_KWARGS,
    **MEMORY_AGENT_KWARGS,
)

# TRAIN THE AGENT =========================
# total_timesteps() assumes DELAY steps per episode, so this is ~N_EPISODES episodes.
# Trailing ';' keeps the notebook from displaying learn()'s return value.
agent.learn(
    total_timesteps=total_timesteps(DELAY, N_EPISODES),
    log_interval=100,
);

## Benchmark this agent against standard PPO and RecurrentPPO

The benchmark uses an environment that requires the agent to remember past observations in order to choose its final action.

On the final step the agent receives a reward of +1 if its action matches the first item of the first observation, and -1 otherwise. Every other step yields 0 reward, so the cause (the first observation) and its effect (the reward) are separated by the full delay.

In [11]:




# --- Batch experiment setup ---
# `mode_name` is the label the benchmark prints in its logs and results table.
# Derive it from `difficulty` (0 -> "EASY", 1 -> "HARD", as paired in the
# delay=4 runs) so the label can never disagree with the environment setting.
DIFFICULTY_NAMES = {0: "EASY", 1: "HARD"}


def make_experiment(delay, n_train_episodes, difficulty, eval_base, verbose=0):
    """Return one AgentPerformanceBenchmark experiment config.

    Args:
        delay: environment delay (episode length in steps).
        n_train_episodes: number of training episodes. The step budget is
            derived from it via ``total_timesteps(delay, n_train_episodes)``,
            so the two values cannot drift apart.
        difficulty: MemoryTaskEnv difficulty; must be a key of DIFFICULTY_NAMES.
        eval_base: presumably toggles the PPO / RecurrentPPO baselines (they
            appear in the results of the eval_base=True run).
        verbose: verbosity forwarded to the benchmark.

    Raises:
        KeyError: if ``difficulty`` has no entry in DIFFICULTY_NAMES.
    """
    return dict(
        delay=delay,
        n_train_episodes=n_train_episodes,
        total_timesteps=total_timesteps(delay, n_train_episodes),
        difficulty=difficulty,
        mode_name=DIFFICULTY_NAMES[difficulty],
        verbose=verbose,
        eval_base=eval_base,
    )


# Step budgets follow n_train_episodes; lower it if a run is too slow.
EXPERIMENTS = [
    make_experiment(delay=4, n_train_episodes=2000, difficulty=0, eval_base=True),
    make_experiment(delay=4, n_train_episodes=5000, difficulty=1, eval_base=True),
    make_experiment(delay=32, n_train_episodes=7500, difficulty=0, eval_base=False),
    make_experiment(delay=32, n_train_episodes=7500, difficulty=1, eval_base=False),
    # make_experiment(delay=64, n_train_episodes=15000, difficulty=0, eval_base=False),
    # NOTE(review): long-delay run uses difficulty=0 (EASY); set difficulty=1 if the hard variant was intended.
    make_experiment(delay=128, n_train_episodes=20000, difficulty=0, eval_base=False),
]

# --- Custom memory agent config (edit as needed) ---
memory_agent_config = dict(
    action_dim=1,          # For Discrete(2)
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    policy_class=StrategicMemoryTransformerPolicy,
    **AGENT_KWARGS,
    **MEMORY_AGENT_KWARGS,
    # Add more settings if needed
)

# Run every experiment with the same memory-agent configuration.
results = []
for exp in EXPERIMENTS:
    benchmark = AgentPerformanceBenchmark(exp, memory_agent_config=memory_agent_config)
    results.append(benchmark.run())



Training in EASY mode with delay of 4 steps



Training StrategicMemoryAgent:  57%|█████▋    | 4/7 [00:31<00:32, 10.78s/step]

-------------------------------------
| rollout/              |           |
|    ep_len_mean        |    4.000  |
|    ep_rew_mean        |    0.131  |
|    ep_rew_std         |    0.992  |
|    policy_entropy     |    0.439  |
|    advantage_mean     |    1.513  |
|    advantage_std      |    0.378  |
|    aux_loss_mean      |    0.000  |
| time/                 |           |
|    fps                |      249  |
|    episodes           |      250  |
|    time_elapsed       |        4  |
|    total_timesteps    |     1000  |
| train/                |           |
|    loss               |    6.677  |
|    policy_loss        |    5.483  |
|    value_loss         |    2.398  |
|    explained_variance |  -10.056  |
|    n_updates          |      250  |
|    progress           |  12.50%   |
| rnd_net_dist/         |           |
|    mean_rnd_bonus     |    0.000  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_

Evaluating StrategicMemoryAgent:  71%|███████▏  | 5/7 [01:13<00:31, 16.00s/step]

-------------------------------------
| rollout/              |           |
|    ep_len_mean        |    4.000  |
|    ep_rew_mean        |    0.992  |
|    ep_rew_std         |    0.126  |
|    policy_entropy     |    0.063  |
|    advantage_mean     |   -0.019  |
|    advantage_std      |    0.134  |
|    aux_loss_mean      |    0.000  |
| time/                 |           |
|    fps                |      195  |
|    episodes           |     2000  |
|    time_elapsed       |       41  |
|    total_timesteps    |     8000  |
| train/                |           |
|    loss               |    0.001  |
|    policy_loss        |   -0.006  |
|    value_loss         |    0.014  |
|    explained_variance |  -71.993  |
|    n_updates          |     2000  |
|    progress           | 100.00%   |
| rnd_net_dist/         |           |
|    mean_rnd_bonus     |    0.000  |
-------------------------------------
Training complete. Total episodes: 2000, total steps: 8000


Finalizing Results: 100%|██████████| 7/7 [01:13<00:00, 10.49s/step]             

╭────┬──────────────────────┬─────────┬────────┬───────────────┬──────────────┬────────────────╮
│    │ Agent                │   Delay │ Mode   │   Mean Ep Rew │   Std Ep Rew │   Duration (s) │
├────┼──────────────────────┼─────────┼────────┼───────────────┼──────────────┼────────────────┤
│  0 │ PPO                  │       4 │ EASY   │           0   │     1        │        7.40373 │
│  1 │ RecurrentPPO         │       4 │ EASY   │           0   │     1        │       23.9085  │
│  2 │ StrategicMemoryAgent │       4 │ EASY   │          -0.2 │     0.979796 │       41.7879  │
╰────┴──────────────────────┴─────────┴────────┴───────────────┴──────────────┴────────────────╯



