In [None]:
from agent import TraceRL
from environments import MemoryTaskEnv
from benchmark import AgentPerformanceBenchmark
from memory import StrategicMemoryBuffer,StrategicMemoryTransformerPolicy


In [None]:
# SETUP ===================================
# Sizing constants shared by the cells below.
DELAY = 16          # passed as `delay` to MemoryTaskEnv and to total_timesteps()
MEM_DIM = 32        # passed as `mem_dim` to the memory buffer / benchmark config
N_EPISODES = 2500   # episode count used to derive the training budget
N_MEMORIES = 16     # passed as `max_entries` to the memory buffer

# Keyword arguments unpacked into the agent constructor and the benchmark config.
AGENT_KWARGS = {
    "device": "cpu",
    "verbose": 0,
    "lam": 0.95,
    "gamma": 0.99,
    "ent_coef": 0.01,
    "learning_rate": 1e-3,
}

# Additional keyword arguments unpacked alongside AGENT_KWARGS for the
# memory-based agent.
MEMORY_AGENT_KWARGS = {
    "her": False,
    "reward_norm": False,
    "aux_modules": None,

    "intrinsic_expl": True,
    "intrinsic_eta": 0.01,

    "use_rnd": True,
    "rnd_emb_dim": 32,
    "rnd_lr": 1e-3,
}

# HELPERS =================================
def total_timesteps(delay, n_episodes):
    """Return the training budget as ``delay * n_episodes``.

    Args:
        delay: The environment delay setting (presumably the number of steps
            in one episode; confirm against the environment).
        n_episodes: Number of episodes to budget for.

    Returns:
        The product of the two arguments.
    """
    steps = delay * n_episodes
    return steps

## **Example:** Simple training setup

In [None]:
# ENVIRONMENT =============================
env = MemoryTaskEnv(delay=DELAY, difficulty=0)

# MEMORY BUFFER ===========================
# The buffer takes its device from AGENT_KWARGS so that it always matches the
# device the agent is constructed with; changing AGENT_KWARGS["device"] can no
# longer leave the two on different devices.
memory = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,          # For Discrete(2)
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    device=AGENT_KWARGS["device"],
)

# POLICY NETWORK (use class) ==============
# The class itself is passed to the agent, not an instance.
policy = StrategicMemoryTransformerPolicy

# (optional) AUXILIARY MODULES ============
# To enable, import CueAuxModule / ConfidenceAuxModule (they do not appear in
# the import cell), uncomment the list below, and pass it as `aux_modules`.
# aux_modules = [
#     CueAuxModule(feat_dim=MEM_DIM*2, n_classes=2),
#     ConfidenceAuxModule(feat_dim=MEM_DIM*2)
# ]

# AGENT SETUP =============================
agent = TraceRL(
    policy_class=policy,
    env=env,
    memory=memory,
    memory_learn_retention=True,
    memory_retention_coef=0.01,
    # aux_modules=aux_modules,
    # NOTE: MEMORY_AGENT_KWARGS already supplies aux_modules=None; remove that
    # entry before uncommenting the line above, otherwise the call raises
    # TypeError for a duplicated keyword argument.
    **AGENT_KWARGS,
    **MEMORY_AGENT_KWARGS
)

# TRAIN THE AGENT =========================
# Budget: DELAY * N_EPISODES timesteps.
agent.learn(
    total_timesteps=total_timesteps(DELAY, N_EPISODES),
    log_interval=100
)

## Benchmark this agent against a regular PPO and a RecurrentPPO

The benchmark uses an environment that requires the agent to remember past observations in order to decide what to do on the last action.

The reward is +1 or -1 depending on whether the agent's action matches the first item of the first observation; all other steps give 0 reward, so the cause-and-effect relationship is very delayed.

In [None]:




# --- Batch experiment setup ---
if __name__ == "__main__":
    # One configuration dict per benchmark run; each is handed unchanged to
    # AgentPerformanceBenchmark.
    # NOTE(review): in several entries n_train_episodes differs from the
    # episode count passed to total_timesteps(...) (e.g. 5000 vs 2500);
    # confirm this is intentional.
    # NOTE(review): mode_name does not track difficulty consistently (one
    # difficulty=1 entry is labelled "EASY", one difficulty=0 entry "HARD");
    # confirm the labels.
    EXPERIMENTS = [
        {
            "delay": 4,
            "n_train_episodes": 2000,
            "total_timesteps": total_timesteps(4, 2000),
            "difficulty": 0,
            "mode_name": "EASY",
            "verbose": 0,
            "eval_base": True,
        },
        {
            "delay": 4,
            "n_train_episodes": 5000,
            "total_timesteps": total_timesteps(4, 2500),
            "difficulty": 1,
            "mode_name": "HARD",
            "verbose": 0,
            "eval_base": True,
        },
        {
            "delay": 32,
            "n_train_episodes": 7500,
            "total_timesteps": total_timesteps(32, 3000),
            "difficulty": 0,
            "mode_name": "EASY",
            "verbose": 0,
            "eval_base": False,
        },
        {
            "delay": 32,
            "n_train_episodes": 7500,
            "total_timesteps": total_timesteps(32, 3500),
            "difficulty": 1,
            "mode_name": "EASY",
            "verbose": 0,
            "eval_base": False,
        },
        #dict(delay=64, n_train_episodes=15000, total_timesteps=15000*64, difficulty=0, mode_name="HARD", verbose=0, eval_base=False),
        {
            "delay": 256,
            "n_train_episodes": 20000,
            "total_timesteps": total_timesteps(256, 5000),
            "difficulty": 0,
            "mode_name": "HARD",
            "verbose": 0,
            "eval_base": False,
        },
    ]

    # --- Custom memory agent config (edit as needed) ---
    # Buffer sizing and policy class, followed by the shared agent settings.
    # Built with dict(...) so a key duplicated across the unpacked kwargs
    # raises instead of being silently overridden.
    memory_agent_config = dict(
        action_dim=1,  # For Discrete(2)
        mem_dim=MEM_DIM,
        max_entries=N_MEMORIES,
        policy_class=StrategicMemoryTransformerPolicy,
        **AGENT_KWARGS,
        **MEMORY_AGENT_KWARGS,
        # Add more settings if needed
    )

    # Run the experiments in order, collecting whatever each run() returns.
    # Appending one at a time keeps completed results if a later run fails.
    results = []
    for exp in EXPERIMENTS:
        benchmark = AgentPerformanceBenchmark(
            exp,
            memory_agent_config=memory_agent_config,
        )
        results.append(benchmark.run())
