# StrategicMemoryBuffer with Trainable Retention

Key ideas:
* Each memory entry now stores a learned “usefulness” score (a scalar).

* When the buffer is full, discard the entry with the lowest usefulness rather than simply the oldest one.

* The usefulness score is updated as the agent interacts with the environment:

* For now, the score is a simple trainable parameter per entry; it could instead be produced by a neural network that depends on the context, attention frequency, outcome, or reward.

* Optionally, usefulness can be updated every time the memory is attended to (e.g., with a moving average or a learned auxiliary head).
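To make the eviction rule concrete, here is a minimal, framework-free sketch. It is not the notebook's `StrategicMemoryBuffer`: the names `Entry`, `UsefulnessBuffer`, `attend`, and `ema_alpha` are hypothetical, and the score is updated with a plain moving average of attention weights rather than learned by gradient descent.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Entry:
    payload: Any        # whatever the agent stores (e.g. an observation)
    usefulness: float   # scalar retention score (hypothetical field name)


class UsefulnessBuffer:
    """Toy buffer that evicts the lowest-usefulness entry instead of the oldest."""

    def __init__(self, max_size: int, ema_alpha: float = 0.1):
        self.max_size = max_size
        self.ema_alpha = ema_alpha
        self.entries: List[Entry] = []

    def add(self, payload: Any, usefulness: float = 0.0) -> None:
        if len(self.entries) >= self.max_size:
            # Index of the least useful entry; ties go to the oldest one.
            worst = min(range(len(self.entries)),
                        key=lambda i: self.entries[i].usefulness)
            self.entries.pop(worst)
        self.entries.append(Entry(payload, usefulness))

    def attend(self, weights: List[float]) -> None:
        """Move each score toward the attention weight it just received."""
        a = self.ema_alpha
        for entry, w in zip(self.entries, weights):
            entry.usefulness = (1 - a) * entry.usefulness + a * w


buf = UsefulnessBuffer(max_size=2, ema_alpha=0.5)
buf.add("cue")
buf.add("noise")
buf.attend([0.9, 0.1])   # "cue" receives far more attention than "noise"
buf.add("new")           # buffer is full: "noise" is evicted, "cue" survives
kept = [e.payload for e in buf.entries]
```

With a first-in-first-out buffer the oldest entry, `"cue"`, would have been dropped instead. Replacing the moving average with a trainable scalar per entry or a small scoring network changes only how `usefulness` is produced; the eviction rule stays the same.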

In [6]:
from environments import MemoryTaskEnv

In [7]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch.nn.functional as F
import time
from collections import defaultdict

# ──────────────────────────────────────────────────────────────
# 2. Reward Normalizer
# ──────────────────────────────────────────────────────────────

class RewardNormalizer:
    """
    Running mean/variance normalizer for rewards.
    Stabilizes learning by normalizing returns online.
    """
    def __init__(self, epsilon=1e-8):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4  # prevents division by zero
        self.epsilon = epsilon

    def update(self, rewards):
        """
        Updates running statistics given a batch of rewards.
        Args:
            rewards (array-like): list or np.ndarray of rewards
        """
        rewards = np.array(rewards)
        batch_mean = rewards.mean()
        batch_var = rewards.var()
        batch_count = len(rewards)
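        # Merge batch statistics into the running statistics (parallel-variance update)
        total = self.count + batch_count
        delta = batch_mean - self.mean
        self.mean = self.mean + delta * batch_count / total
        # The pooled variance needs the mean-shift term delta**2, not just a weighted average
        self.var = (self.var * self.count + batch_var * batch_count
                    + delta ** 2 * self.count * batch_count / total) / total
        self.count = total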

    def normalize(self, rewards):
        """
        Normalizes a batch of rewards based on running statistics.
        Args:
            rewards (array-like): list or np.ndarray of rewards
        Returns:
            List of normalized rewards
        """
        rewards = np.array(rewards)
        return ((rewards - self.mean) / (np.sqrt(self.var) + self.epsilon)).tolist()

# ──────────────────────────────────────────────────────────────
# 3. State Counter for Intrinsic Reward (Exploration Bonus)
# ──────────────────────────────────────────────────────────────

class StateCounter:
    """
    Counts state visitations for intrinsic motivation (exploration).
    Implements simple count-based exploration bonus.
    """
    def __init__(self):
        self.counts = defaultdict(int)

    def count(self, obs):
        """
        Increments and returns the visit count for a discretized observation.
        Args:
            obs (np.ndarray): observation
        Returns:
            int: visit count
        """
        key = tuple(np.round(obs, 2))  # discretize for generalization
        self.counts[key] += 1
        return self.counts[key]

    def intrinsic_reward(self, obs):
        """
        Returns intrinsic reward: inversely proportional to sqrt of count.
        Args:
            obs (np.ndarray): observation
        Returns:
            float: exploration bonus
        """
        c = self.count(obs)
        return 1.0 / np.sqrt(c)

# ──────────────────────────────────────────────────────────────
# 4. Generalized Advantage Estimation (GAE) and Explained Variance
# ──────────────────────────────────────────────────────────────

def compute_explained_variance(y_pred, y_true):
    """
    Computes explained variance between prediction and ground-truth.
    Used for value function diagnostics in RL.
    """
    var_y = torch.var(y_true)
    if var_y == 0:
        return torch.tensor(0.0)
    return 1 - torch.var(y_true - y_pred) / (var_y + 1e-8)

def compute_gae(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """
    Compute Generalized Advantage Estimation (GAE) for a trajectory.
    Args:
        rewards (torch.Tensor): reward sequence [T]
        values (torch.Tensor): value sequence [T]
        gamma (float): discount factor
        lam (float): GAE lambda
        last_value (float): bootstrap value after final state
    Returns:
        torch.Tensor: advantage sequence [T]
    """
    T = len(rewards)
    advantages = torch.zeros(T, dtype=torch.float32, device=values.device)
    gae = 0.0
    # concatenate last value for bootstrap
    values_ext = torch.cat([values, torch.tensor([last_value], dtype=torch.float32, device=values.device)])
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# ──────────────────────────────────────────────────────────────
# 5. Memory Transformer Policy with Auxiliary Head
# ──────────────────────────────────────────────────────────────

class MemoryTransformerPolicy(nn.Module):
    """
    MemoryTransformerPolicy
    ----------------------
    Transformer-based policy for sequence decision tasks.
    - Processes full trajectory as input (not just current state)
    - Supports auxiliary head for supervised tasks (e.g., cue recall)

    Args:
        obs_dim (int): Dimension of observation vector
        mem_dim (int): Transformer embedding size
        nhead (int): Number of attention heads
    """
    def __init__(self, obs_dim, mem_dim=32, nhead=4):
        super().__init__()
        self.mem_dim = mem_dim
        self.embed = nn.Linear(obs_dim, mem_dim)
        self.pos_embed = nn.Embedding(256, mem_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=mem_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.policy_head = nn.Linear(mem_dim, 2)   # Action logits
        self.value_head = nn.Linear(mem_dim, 1)    # Value prediction
        self.aux_head = nn.Linear(mem_dim, 2)      # Auxiliary: predict initial cue

    def forward(self, trajectory):
        """
        Forward pass: processes trajectory and outputs policy, value, and aux predictions.
        Args:
            trajectory (torch.Tensor): [T, obs_dim] trajectory
        Returns:
            logits (torch.Tensor): [2,] action logits
            value (torch.Tensor): scalar state value (shape [] after the squeeze)
            aux_pred (torch.Tensor): [2,] auxiliary logits (for cue prediction)
        """
        T = trajectory.shape[0]
        x = self.embed(trajectory)
        pos = torch.arange(T, device=trajectory.device)
        x = x + self.pos_embed(pos)
        x = x.unsqueeze(0)  # [1, T, mem_dim]
        x = self.transformer(x)
        feat = x[0, -1]  # use final state embedding for decision
        logits = self.policy_head(feat)
        value = self.value_head(feat)
        aux_pred = self.aux_head(feat)
        return logits, value.squeeze(-1), aux_pred

# ──────────────────────────────────────────────────────────────
# 6. HER-enabled MemoryPPO Trainer
# ──────────────────────────────────────────────────────────────

class MemoryPPO:
    """
    MemoryPPO
    ---------
    PPO trainer for memory-based RL tasks.
    - Supports Hindsight Experience Replay (HER)
    - Reward normalization, auxiliary losses, intrinsic/exploration bonuses
    - Sequence-based trajectory modeling

    Args:
        policy_class: policy network class
        env: Gymnasium environment
        device: PyTorch device
        (other RL hyperparameters)
    """
    def __init__(
        self, 
        policy_class, 
        env, 
        verbose=0,
        learning_rate=1e-3, 
        gamma=0.99, 
        lam=0.95, 
        device="cpu",
        her=True,                 # Enable HER
        reward_norm=False,         # Enable reward normalization
        aux=True,                 # Enable auxiliary loss
        intrinsic_expl=True,      # Enable intrinsic exploration
        intrinsic_eta=0.05,        # Intrinsic bonus multiplier
        ent_coef=0.01
    ):
        self.env = env
        self.device = torch.device(device)
        self.gamma = gamma
        self.lam = lam
        self.ent_coef = ent_coef
        self.verbose = verbose
        self.policy = policy_class(obs_dim=env.observation_space.shape[0]).to(self.device)
        self.policy = torch.jit.script(self.policy)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.training_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []
        self.her = her
        self.reward_norm = reward_norm
        self.aux = aux
        self.intrinsic_expl = intrinsic_expl
        self.intrinsic_eta = intrinsic_eta
        self.reward_normalizer = RewardNormalizer()
        self.state_counter = StateCounter()
        self.trajectory = []
    
    def reset_trajectory(self):
        """
        Resets internal trajectory buffer after each episode or reset.
        """
        self.trajectory = []
        
    def run_episode(self, her_target=None):
        """
        Executes a full episode in the environment.
        Supports HER by relabelling the initial cue if provided.
        Returns trajectory data for training.
        """
        obs, _ = self.env.reset()
        if her_target is not None:
            obs[0] = her_target  # Relabel initial cue with HER goal

        done = False
        trajectory = []
        rewards = []
        actions = []
        log_probs = []
        values = []
        entropies_ep = []
        aux_preds = []
        t = 0

        initial_cue = int(obs[0])  # Ground-truth cue for auxiliary prediction

        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
            trajectory.append(obs_t)
            traj = torch.stack(trajectory)
            logits, value, aux_pred = self.policy(traj)
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            done = terminated or truncated  # Gymnasium reports time-limit truncation separately
            # Intrinsic (exploration) reward
            if self.intrinsic_expl:
                reward += self.intrinsic_eta * self.state_counter.intrinsic_reward(obs)
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=self.device))
            values.append(value)
            entropies_ep.append(entropy)
            aux_preds.append(aux_pred)
            t += 1
        return {
            "trajectory": trajectory,
            "actions": actions,
            "rewards": rewards,
            "log_probs": log_probs,
            "values": values,
            "entropies": entropies_ep,
            "aux_preds": aux_preds,
            "initial_cue": initial_cue
        }

    def learn(self, total_timesteps=2000, log_interval=100):
        """
        Main training loop for PPO.
        Collects rollouts, computes losses, performs optimization, and logs metrics.
        Supports HER, reward normalization, auxiliary loss, and intrinsic bonuses.
        """
        steps = 0
        episodes = 0
        all_returns = []
        start_time = time.time()
        aux_losses = []

        while steps < total_timesteps:
            # ---- Run normal episode ----
            episode = self.run_episode()
            # ---- Reward normalization ----
            if self.reward_norm:
                self.reward_normalizer.update([r.item() for r in episode["rewards"]])
                episode["rewards"] = [torch.tensor(rn, dtype=torch.float32, device=self.device)
                                      for rn in self.reward_normalizer.normalize([r.item() for r in episode["rewards"]])]

            # ---- HER relabelling ----
            first_len = len(episode["rewards"])  # boundary between the normal and the HER episode
            if self.her:
                her_target = int(episode["actions"][-1].item())
                her_episode = self.run_episode(her_target=her_target)
                for k in ["trajectory", "actions", "rewards", "log_probs", "values", "entropies", "aux_preds"]:
                    episode[k] += her_episode[k]
                initial_cue = [episode["initial_cue"], her_target]
            else:
                initial_cue = [episode["initial_cue"]]

            # ---- Batchify for loss ----
            trajectory = episode["trajectory"]
            actions = episode["actions"]
            rewards = episode["rewards"]
            log_probs = episode["log_probs"]
            values = episode["values"]
            entropies_ep = episode["entropies"]
            aux_preds = episode["aux_preds"]
            T = len(rewards)

            rewards_t = torch.stack(rewards)
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            actions_t = torch.stack(actions)
            aux_preds_t = torch.stack(aux_preds)
            last_value = 0.0
            # Compute GAE per episode so advantages never bootstrap across the episode boundary
            advantages = torch.cat([
                compute_gae(rewards_t[s:e], values_t[s:e], gamma=self.gamma, lam=self.lam, last_value=last_value)
                for s, e in ((0, first_len), (first_len, T)) if e > s
            ])
            # Detach the targets so the value loss does not backpropagate through them
            returns = (advantages + values_t).detach()

            # ---- Losses ----
            # Policy-gradient loss with GAE advantages (no PPO ratio clipping in this version)
            policy_loss = -(log_probs_t * advantages.detach()).sum()
            value_loss = F.mse_loss(values_t, returns)
            entropy_mean = torch.stack(entropies_ep).mean()
            explained_var = compute_explained_variance(values_t, returns)

            # Auxiliary (supervised) loss for cue recall
            aux_loss = torch.tensor(0.0, device=self.device)
            if self.aux:
                if self.her:
                    cues = torch.tensor(initial_cue, dtype=torch.long, device=self.device)       # [2]
                    # First step of each episode; the HER episode starts at index first_len
                    aux_preds_to_use = torch.stack([aux_preds_t[0], aux_preds_t[first_len]])     # [2,2]
                else:
                    cues = torch.tensor([initial_cue[0]], dtype=torch.long, device=self.device)  # [1]
                    aux_preds_to_use = aux_preds_t[0].unsqueeze(0)                               # [1,2]
                assert aux_preds_to_use.shape[0] == cues.shape[0], \
                    f"Shape mismatch: preds {aux_preds_to_use.shape}, cues {cues.shape}"
                aux_loss = F.cross_entropy(aux_preds_to_use, cues)
                aux_losses.append(aux_loss.item())

            # ---- Total loss ----
            loss = policy_loss + 0.5 * value_loss + 0.1 * aux_loss - self.ent_coef * entropy_mean

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_reward = sum([r.item() for r in rewards])
            self.episode_rewards.append(total_reward)
            self.episode_lengths.append(T)
            episodes += 1
            steps += T

            # SB3-style logging
            if episodes % log_interval == 0 and self.verbose == 1:
                elapsed = time.time() - start_time
                mean_rew = np.mean(self.episode_rewards[-log_interval:])
                mean_len = np.mean(self.episode_lengths[-log_interval:])
                fps = int(steps / (elapsed + 1e-8))
                adv_mean = advantages.mean().item()
                adv_std = advantages.std().item()
                mean_entropy = entropy_mean.item()
                mean_aux = np.mean(aux_losses[-log_interval:]) if aux_losses else 0.0
                print("-" * 40)
                print(f"| rollout/               |")
                print(f"|    ep_len_mean         | {mean_len:8.2f}")
                print(f"|    ep_rew_mean         | {mean_rew:8.2f}")
                print(f"|    policy_entropy      | {mean_entropy:8.3f}")
                print(f"|    advantage_mean      | {adv_mean:8.3f}")
                print(f"|    advantage_std       | {adv_std:8.3f}")
                print(f"|    aux_loss_mean       | {mean_aux:8.3f}")
                print(f"| time/                  |")
                print(f"|    fps                 | {fps:8d}")
                print(f"|    episodes            | {episodes:8d}")
                print(f"|    time_elapsed        | {elapsed:8.1f}")
                print(f"|    total_timesteps     | {steps:8d}")
                print(f"| train/                 |")
                print(f"|    loss                | {loss.item():8.3f}")
                print(f"|    policy_loss         | {policy_loss.item():8.3f}")
                print(f"|    value_loss          | {value_loss.item():8.3f}")
                print(f"|    explained_variance  | {explained_var.item():8.3f}")
                print(f"|    n_updates           | {episodes:8d}")
                print("-" * 40)
        
        if self.verbose == 1:
            print(f"Training complete. Total episodes: {episodes}, total steps: {steps}")
    
    def predict(self, obs, deterministic=False, done=False):
        """
        Predicts action given observation using the full trajectory.
        Optionally resets buffer if episode done.
        """
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
        self.trajectory.append(obs_t)
        traj = torch.stack(self.trajectory)
        self.policy.eval()
        with torch.no_grad():
            logits, _, _ = self.policy(traj)
            if deterministic:
                action = torch.argmax(logits).item()
            else:
                dist = Categorical(logits=logits)
                action = dist.sample().item()
        self.policy.train()
        if done:
            self.reset_trajectory()
        return action

    def save(self, path="memoryppo.pt"):
        """
        Saves model parameters to file.
        """
        torch.save(self.policy.state_dict(), path)

    def load(self, path="memoryppo.pt"):
        """
        Loads model parameters from file.
        """
        self.policy.load_state_dict(torch.load(path, map_location=self.device))

    def evaluate(self, n_episodes=10, deterministic=False, verbose=True):
        """
        Runs evaluation episodes to estimate mean and std of return.
        """
        returns = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            self.reset_trajectory()
            done = False
            total_reward = 0.0
            while not done:
                action = self.predict(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated  # stop on time-limit truncation as well
                total_reward += reward
                if done:
                    self.reset_trajectory()
            returns.append(total_reward)
        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            print(f"Evaluation over {n_episodes} episodes: mean return {mean_return:.2f}, std {std_return:.2f}")
        return mean_return, std_return
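
To make the backward recursion in `compute_gae` concrete, here is a small pure-Python sketch of the same computation on a three-step toy episode. The helper name `gae_list` and the reward and value numbers are made up for illustration; they are not part of the notebook.

```python
def gae_list(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Same recursion as compute_gae, written with plain Python lists."""
    values_ext = list(values) + [last_value]   # append the bootstrap value
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error at time t
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        # Exponentially weighted sum of the current and all later TD errors
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


rewards = [0.0, 0.0, 1.0]   # reward only at the final step
values = [0.5, 0.5, 0.5]    # a constant value estimate

# lam = 1: the advantage equals (Monte Carlo return - value) at every step
adv_mc = gae_list(rewards, values, gamma=1.0, lam=1.0)
# lam = 0: the advantage collapses to the one-step TD error
adv_td = gae_list(rewards, values, gamma=1.0, lam=0.0)
```

The two extremes show the trade-off that `lam` controls: `lam=1` spreads the final reward back over all steps (low bias, high variance), while `lam=0` credits only the step whose TD error is non-zero.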


# New modules

## Differentiable Episodic Memory Buffer
A trainable external memory, with:

* Key/Value: Store each step as a (key, value) pair. Here both are learned projections of the observation, but the value could be a richer tuple.

* Gating: Agent learns when to write (memory filter).

* Query: At each step, agent emits a query vector; we compute soft attention (weights) over all stored keys.

* Retrieval: Output is a weighted sum over stored values.
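
The query and retrieval steps above can be sketched without any framework. This is an illustrative stand-in, not the notebook's `DifferentiableEpisodicMemory.query`: the function name `soft_read` and the toy keys, values, and query are assumptions chosen so the arithmetic is easy to follow.

```python
import math


def soft_read(keys, values, query):
    """Soft attention read: softmax over key/query dot products, then a weighted sum."""
    # Dot-product score for every stored key
    scores = [sum(k_i * q_i for k_i, q_i in zip(k, query)) for k in keys]
    # Numerically stable softmax over the memory slots
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the stored values, one output dimension at a time
    dim = len(values[0])
    readout = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, readout


keys = [[1.0, 0.0], [0.0, 1.0]]       # two stored keys
values = [[10.0, 0.0], [0.0, 10.0]]   # their associated values
weights, readout = soft_read(keys, values, query=[2.0, 0.0])
```

The query aligns with the first key, so the first slot gets most of the weight (about 0.88) and the readout is dominated by the first value. Because every operation is smooth, the same computation written with tensors lets gradients flow back into the key and value projections.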

## Policy update:
Policy now receives:

* The full trajectory (as before)

* The current observation

* The memory readout (retrieved with a query emitted by the policy)
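
As a rough sketch of how these inputs meet, the feature for the current step and the memory readout can be concatenated and passed to a head that is twice as wide as the embedding. The helper `linear` and all numbers below are hypothetical; the real policy uses `nn.Linear(mem_dim * 2, ...)` on tensors.

```python
def linear(weight_rows, bias, x):
    """Plain-Python affine map: one output per weight row."""
    return [sum(w * x_j for w, x_j in zip(row, x)) + b
            for row, b in zip(weight_rows, bias)]


mem_dim = 2
feat = [0.5, -0.5]      # feature of the last trajectory step (stand-in values)
mem_read = [1.0, 0.0]   # readout returned by the memory query (stand-in values)

# Concatenate: the heads see both the trajectory feature and the memory readout
full_feat = feat + mem_read

# Toy action head with input width 2 * mem_dim and two action logits
W = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]
b = [0.0, 0.0]
logits = linear(W, b, full_feat)
```

Concatenation keeps the two information sources separate, so the head can learn how much to rely on the readout. If the memory is empty, the readout is a zero vector and the head falls back on the trajectory feature alone.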

In [8]:
class DifferentiableEpisodicMemory(nn.Module):
    """
    Differentiable episodic memory: soft attention over stored keys/values.
    - Agent decides what to store (learnable gate).
    - Agent emits query vector; gets soft attention-weighted sum of stored values.
    """
    def __init__(self, obs_dim, mem_dim=32, max_size=32):
        super().__init__()
        self.last_attn = None 
        self.obs_dim = obs_dim
        self.mem_dim = mem_dim
        self.max_size = max_size
        self.key_proj = nn.Linear(obs_dim, mem_dim)
        self.val_proj = nn.Linear(obs_dim, mem_dim)
        self.gate = nn.Sequential(
            nn.Linear(obs_dim, 1), nn.Sigmoid()
        )
        self.reset()

    def reset(self):
        # Clear all memory at episode start
        self.keys = []
        self.values = []

    def write(self, obs):
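        # Stochastic write gate: the step is stored with probability gate(obs).
        # Note: the sampled decision goes through .item(), so no gradient reaches the gate from this path.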
        obs = obs.detach() if isinstance(obs, torch.Tensor) else torch.tensor(obs, dtype=torch.float32)
        key = self.key_proj(obs)
        val = self.val_proj(obs)
        prob = self.gate(obs)
        if prob.item() > np.random.rand():  # stochastic gate
            self.keys.append(key)
            self.values.append(val)
        # Cap memory size
        if len(self.keys) > self.max_size:
            self.keys = self.keys[-self.max_size:]
            self.values = self.values[-self.max_size:]

    def query(self, query_vec):
        if not self.keys:
            self.last_attn = None
            return torch.zeros(self.mem_dim, device=query_vec.device)
        keys = torch.stack(self.keys)
        values = torch.stack(self.values)
        attn_logits = torch.matmul(keys, query_vec)
        attn = torch.softmax(attn_logits, dim=0)
        self.last_attn = attn.detach().cpu().numpy()   # <--- Store for logging
        mem_read = (attn.unsqueeze(-1) * values).sum(0)
        return mem_read

class ExternalMemoryTransformerPolicy(nn.Module):
    def __init__(self, obs_dim, mem_dim=32, nhead=4, memory=None, aux_modules=None):
        super().__init__()
        self.mem_dim = mem_dim
        self.memory = memory
        self.embed = nn.Linear(obs_dim, mem_dim)
        self.pos_embed = nn.Embedding(256, mem_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=mem_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.policy_head = nn.Linear(mem_dim * 2, 2)   # Action logits
        self.value_head = nn.Linear(mem_dim * 2, 1)    # Value prediction
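        # Note: a plain Python list does not register the aux heads' parameters with this module;
        # if the aux modules are nn.Module instances, wrapping them in nn.ModuleList would.
        self.aux_modules = aux_modules if aux_modules is not None else []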

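    def forward(self, trajectory, curr_obs):
        # curr_obs is accepted but not used below: the trainer appends the current
        # observation to the trajectory first, so it is already the last row.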
        T = trajectory.shape[0]
        x = self.embed(trajectory)
        pos = torch.arange(T, device=trajectory.device)
        x = x + self.pos_embed(pos)
        x = x.unsqueeze(0)
        x = self.transformer(x)
        feat = x[0, -1]
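        # Detached query: the memory read does not backpropagate into the transformer
        # through the query (remove .detach() to train the query end to end).
        query_vec = feat.detach()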
        mem_read = self.memory.query(query_vec)
        full_feat = torch.cat([feat, mem_read], dim=-1)
        logits = self.policy_head(full_feat)
        value = self.value_head(full_feat)
   
        aux_preds = {}
        for aux in self.aux_modules:
            aux_preds[aux.name] = aux.head(feat)
        return logits, value.squeeze(-1), aux_preds



In [9]:
# ──────────────────────────────────────────────────────────────
# 6. HER-enabled MemoryPPO Trainer
# ──────────────────────────────────────────────────────────────

class ExternalMemoryPPO:
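    """
    ExternalMemoryPPO
    -----------------
    PPO trainer for memory-based RL tasks with an optional external memory.
    - Supports Hindsight Experience Replay (HER)
    - Reward normalization, auxiliary losses, intrinsic/exploration bonuses
    - Sequence-based trajectory modeling
    - Resets the external memory at the start of every episode and offers each observation to its write gate

    Args:
        policy_class: policy network class (called with obs_dim and memory)
        env: Gymnasium environment
        device: PyTorch device
        memory: external memory module shared with the policy, or None
        (other RL hyperparameters)
    """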
    def __init__(
        self, 
        policy_class, 
        env, 
        verbose=0,
        learning_rate=1e-3, 
        gamma=0.99, 
        lam=0.95, 
        device="cpu",
        her=True,                 # Enable HER
        reward_norm=False,         # Enable reward normalization
        aux=True,                 # Enable auxiliary loss
        intrinsic_expl=True,      # Enable intrinsic exploration
        intrinsic_eta=0.05,        # Intrinsic bonus multiplier
        memory=None,      
        ent_coef=0.01
    ):
        self.env = env
        self.device = torch.device(device)
        self.gamma = gamma
        self.lam = lam
        self.ent_coef = ent_coef
        self.verbose = verbose
        self.memory = memory
        self.policy = policy_class(obs_dim=env.observation_space.shape[0], memory=memory).to(self.device)
        #self.policy = torch.jit.script(self.policy)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.training_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []
        self.her = her
        self.reward_norm = reward_norm
        self.aux = aux
        self.intrinsic_expl = intrinsic_expl
        self.intrinsic_eta = intrinsic_eta
        self.reward_normalizer = RewardNormalizer()
        self.state_counter = StateCounter()
        self.trajectory = []
    
    def reset_trajectory(self):
        """
        Resets internal trajectory buffer after each episode or reset.
        """
        self.trajectory = []
        
    def run_episode(self, her_target=None):
        """
        Executes a full episode in the environment.
        Supports HER by relabelling the initial cue if provided.
        Returns trajectory data for training.
        """
        obs, _ = self.env.reset()
        if her_target is not None:
            obs[0] = her_target  # Relabel initial cue with HER goal
        if self.memory is not None:
            self.memory.reset()
        done = False
        trajectory = []
        rewards = []
        actions = []
        log_probs = []
        values = []
        entropies_ep = []
        aux_preds = []
        t = 0
        gate_history = []
        memory_size_history = []
        attn_weights = None
        initial_cue = int(obs[0])  # Ground-truth cue for auxiliary prediction

        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
            if self.memory is not None:
                self.memory.write(obs_t)
                # Log memory gate and size
                gate_prob = self.memory.gate(obs_t).item()
                gate_history.append(gate_prob)
                memory_size_history.append(len(self.memory.keys))
            trajectory.append(obs_t)
            traj = torch.stack(trajectory)
            logits, value, aux_pred = self.policy(traj, obs_t)#self.policy(traj)
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            done = terminated or truncated  # Gymnasium reports time-limit truncation separately
            # Intrinsic (exploration) reward
            if self.intrinsic_expl:
                reward += self.intrinsic_eta * self.state_counter.intrinsic_reward(obs)
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=self.device))
            values.append(value)
            entropies_ep.append(entropy)
            aux_preds.append(aux_pred)
            t += 1

        if self.memory is not None:
            attn_weights = self.memory.last_attn  # (from query at last policy call)
        return {
            "trajectory": trajectory,
            "actions": actions,
            "rewards": rewards,
            "log_probs": log_probs,
            "values": values,
            "entropies": entropies_ep,
            "aux_preds": aux_preds,
            "initial_cue": initial_cue,
            "gate_history": gate_history,
            "memory_size_history": memory_size_history,
            "attn_weights": attn_weights
        }

    def learn(self, total_timesteps=2000, log_interval=100):
        """
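        Main training loop. Despite the class name, each update is a single on-policy
        actor-critic step with GAE advantages; no clipped PPO surrogate is applied.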
        Collects rollouts, computes losses, performs optimization, and logs metrics.
        Supports HER, reward normalization, auxiliary loss, and intrinsic bonuses.
        """
        steps = 0
        episodes = 0
        all_returns = []
        start_time = time.time()
        aux_losses = []

        while steps < total_timesteps:
            # ---- Run normal episode ----
            episode = self.run_episode()
            # ---- Reward normalization ----
            if self.reward_norm:
                self.reward_normalizer.update([r.item() for r in episode["rewards"]])
                episode["rewards"] = [torch.tensor(rn, dtype=torch.float32, device=self.device)
                                      for rn in self.reward_normalizer.normalize([r.item() for r in episode["rewards"]])]

            # ---- HER relabelling ----
            if self.her:
                her_target = int(episode["actions"][-1].item())
                her_episode = self.run_episode(her_target=her_target)
                for k in ["trajectory", "actions", "rewards", "log_probs", "values", "entropies", "aux_preds"]:
                    episode[k] += her_episode[k]
                initial_cue = [episode["initial_cue"], her_target]
            else:
                initial_cue = [episode["initial_cue"]]

            # ---- Batchify for loss ----
            trajectory = episode["trajectory"]
            actions = episode["actions"]
            rewards = episode["rewards"]
            log_probs = episode["log_probs"]
            values = episode["values"]
            entropies_ep = episode["entropies"]
            aux_preds = episode["aux_preds"]
            T = len(rewards)

            rewards_t = torch.stack(rewards)
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            actions_t = torch.stack(actions)
            aux_preds_t = torch.stack(aux_preds)
            last_value = 0.0
            advantages = compute_gae(rewards_t, values_t, gamma=self.gamma, lam=self.lam, last_value=last_value)
            returns = (advantages + values_t).detach()  # value target must carry no gradient

            # ---- Losses ----
            policy_loss = -(log_probs_t * advantages.detach()).sum()
            value_loss = F.mse_loss(values_t, returns)
            entropy_mean = torch.stack(entropies_ep).mean()
            explained_var = compute_explained_variance(values_t, returns)

            # Auxiliary (supervised) loss for cue recall
            aux_loss = torch.tensor(0.0, device=self.device)
            if self.aux:
                if self.her:
                    cues = torch.tensor(initial_cue, dtype=torch.long, device=self.device)       # [2]
                    # NOTE: index T // 2 is the HER episode's first step only if both episodes have equal length
                    aux_preds_to_use = torch.stack([aux_preds_t[0], aux_preds_t[T // 2]])        # [2,2]
                else:
                    cues = torch.tensor([initial_cue[0]], dtype=torch.long, device=self.device)  # [1]
                    aux_preds_to_use = aux_preds_t[0].unsqueeze(0)                               # [1,2]
                assert aux_preds_to_use.shape[0] == cues.shape[0], \
                    f"Shape mismatch: preds {aux_preds_to_use.shape}, cues {cues.shape}"
                aux_loss = F.cross_entropy(aux_preds_to_use, cues)
                aux_losses.append(aux_loss.item())

            # ---- Total loss ----
            loss = policy_loss + 0.5 * value_loss + 0.1 * aux_loss - self.ent_coef * entropy_mean

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_reward = sum([r.item() for r in rewards])
            self.episode_rewards.append(total_reward)
            self.episode_lengths.append(T)
            episodes += 1
            steps += T

            # SB3-style logging
            if episodes % log_interval == 0 and self.verbose == 1:
                elapsed = time.time() - start_time
                mean_rew = np.mean(self.episode_rewards[-log_interval:])
                mean_len = np.mean(self.episode_lengths[-log_interval:])
                fps = int(steps / (elapsed + 1e-8))
                adv_mean = advantages.mean().item()
                adv_std = advantages.std().item()
                mean_entropy = entropy_mean.item()
                mean_aux = np.mean(aux_losses[-log_interval:]) if aux_losses else 0.0
                print("-" * 40)
                print(f"| rollout/               |")
                print(f"|    ep_len_mean         | {mean_len:8.2f}")
                print(f"|    ep_rew_mean         | {mean_rew:8.2f}")
                print(f"|    policy_entropy      | {mean_entropy:8.3f}")
                print(f"|    advantage_mean      | {adv_mean:8.3f}")
                print(f"|    advantage_std       | {adv_std:8.3f}")
                print(f"|    aux_loss_mean       | {mean_aux:8.3f}")
                print(f"| time/                  |")
                print(f"|    fps                 | {fps:8d}")
                print(f"|    episodes            | {episodes:8d}")
                print(f"|    time_elapsed        | {elapsed:8.1f}")
                print(f"|    total_timesteps     | {steps:8d}")
                print(f"| train/                 |")
                print(f"|    loss                | {loss.item():8.3f}")
                print(f"|    policy_loss         | {policy_loss.item():8.3f}")
                print(f"|    value_loss          | {value_loss.item():8.3f}")
                print(f"|    explained_variance  | {explained_var.item():8.3f}")
                print(f"|    n_updates           | {episodes:8d}")
                print(f"| memory/                |")
                print(f"|    gate_unique_values  | {len(set(episode['gate_history']))}")
                print(f"|    gate_history        | {episode['gate_history']}")
                print(f"|    memory_size_history | {episode['memory_size_history']}")
                print(f"|    attn_weights        | {episode['attn_weights']}")
                print("-" * 40)
        
        if self.verbose == 1:
            print(f"Training complete. Total episodes: {episodes}, total steps: {steps}")
    

    def predict(self, obs, deterministic=False, done=False, log_diagnostics=False):
        """
        Predicts action given observation using the full trajectory.
        Optionally resets buffer if episode done.
        """
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
        self.trajectory.append(obs_t)
        traj = torch.stack(self.trajectory)
        self.policy.eval()
        with torch.no_grad():
            logits, _, _ = self.policy(traj, obs_t)
            if deterministic:
                action = torch.argmax(logits).item()
            else:
                dist = Categorical(logits=logits)
                action = dist.sample().item()
            # Optionally, here you could log self.memory.gate(obs_t).item() and self.memory.last_attn
        self.policy.train()
        if done:
            self.reset_trajectory()
        return action
        
    def save(self, path="memoryppo.pt"):
        """
        Saves model parameters to file.
        """
        torch.save(self.policy.state_dict(), path)

    def load(self, path="memoryppo.pt"):
        """
        Loads model parameters from file.
        """
        self.policy.load_state_dict(torch.load(path, map_location=self.device))

    def evaluate(self, n_episodes=10, deterministic=False, verbose=True):
        """
        Runs evaluation episodes to estimate mean and std of return.
        """
        returns = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            self.reset_trajectory()
            done = False
            total_reward = 0.0
            while not done:
                action = self.predict(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated  # stop on time-limit truncation as well
                total_reward += reward
                if done:
                    self.reset_trajectory()
            returns.append(total_reward)
        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            print(f"Evaluation over {n_episodes} episodes: mean return {mean_return:.2f}, std {std_return:.2f}")
        return mean_return, std_return
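
The advantage estimates that `learn` obtains from `compute_gae` can be checked by hand on a tiny episode. The sketch below repeats the same backward recursion on plain Python lists; the helper name `gae_plain` and the reward and value numbers are made up for illustration and are not part of the trainer. With `gamma=1` and `lam=1` the advantage reduces to the Monte Carlo return minus the value estimate, which makes the result easy to verify.

```python
def gae_plain(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Backward GAE recursion on plain Python lists (illustrative helper)."""
    values_ext = list(values) + [last_value]  # bootstrap value after the final step
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error, then the discounted sum of later TD errors
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages


rewards = [0.0, 0.0, 1.0]   # made-up episode: reward only at the last step
values = [0.5, 0.5, 0.5]    # made-up critic outputs
advantages = gae_plain(rewards, values, gamma=1.0, lam=1.0)
returns = [a + v for a, v in zip(advantages, values)]
print(advantages)  # [0.5, 0.5, 0.5]
print(returns)     # [1.0, 1.0, 1.0]
```

Every step gets the same advantage here because the critic underestimates the undiscounted return of 1.0 by 0.5 at each step.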


# Auxiliary Logic as a Plugin

### Usage:
```
# Auxiliary heads ====================

class CueAuxModule(MraAuxiliaryModule):
    def __init__(self, feat_dim, n_classes=2):
        super().__init__("cue")
        self.head = nn.Linear(feat_dim, n_classes)

    def aux_loss(self, pred, target, context=None):
        # pred: [B, n_classes], target: [B]
        return F.cross_entropy(pred, target)

# Auxiliary head injection ============
aux_modules = [
    CueAuxModule(feat_dim=32),
    ConfidenceAuxModule(feat_dim=32),
    # MistakeAuxModule, FutureAuxModule, etc...
]


agent = ExternalMemoryPPO(
    policy_class=MemoryTransformerPolicy,
    env=env,
    aux_modules=aux_modules,
    # ...
)

```

In [10]:
class MraAuxiliaryModule:
    """
    Base class for auxiliary-loss plugins, so the agent can compute and
    optimize all auxiliary losses together.
    * Each module has its own head and loss.
    * Modules are plug-and-play.
    * Metrics are integrated into the rollout logs.
    """
    def __init__(self, name, task="classification"):
        self.name = name
        self.head = None  # will be nn.Module subclass
        self.task = task

    def aux_loss(self, pred, target, context=None):
        # Should be overridden by subclass
        raise NotImplementedError

    def get_head(self, feat_dim):
        # Should be overridden to return an nn.Module
        raise NotImplementedError

    def aux_metrics(self, pred, target, context=None):
        """
        Returns a dict of diagnostics (e.g. accuracy, MSE).
        Subclasses may override this; by default the metric matching
        `self.task` is used.
        """
        if self.task == "classification":
            return self.__classification_metric(pred,target,context)
        else:
            return self.__regression_metric(pred,target,context)

    def __classification_metric(self,pred,target,context=None):
        pred_label = pred.argmax(dim=-1)
        correct = (pred_label == target).float()
        acc = correct.mean().item()
        return {'acc': acc}
        
    def __regression_metric(self,pred,target,context=None):
        mse = F.mse_loss(torch.sigmoid(pred.squeeze(-1)), target.float()).item()
        return {'mse': mse}




In [11]:
class CueAuxModule(MraAuxiliaryModule):
    def __init__(self, feat_dim, n_classes=2):
        super().__init__("cue")
        self.head = nn.Linear(feat_dim, n_classes)

    def aux_loss(self, pred, target, context=None):
        return F.cross_entropy(pred, target)



class ConfidenceAuxModule(MraAuxiliaryModule):
    def __init__(self, feat_dim):
        super().__init__("confidence")
        self.head = nn.Linear(feat_dim, 1)

    def aux_loss(self, pred, target, context=None):
        return F.mse_loss(torch.sigmoid(pred.squeeze(-1)), target.float())

    def aux_metrics(self, pred, target, context=None):
        mse = F.mse_loss(torch.sigmoid(pred.squeeze(-1)), target.float()).item()
        return {'mse': mse}
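
Because `MraAuxiliaryModule` is a plain Python class rather than an `nn.Module`, a head stored on it is only seen by the optimizer if the owning policy registers it as a submodule. The sketch below illustrates the difference; `PlainHolder`, `ListPolicy`, and `RegisteredPolicy` are hypothetical names used only for this demonstration, and the feature size of 32 mirrors the usage example above.

```python
import torch.nn as nn


class PlainHolder:
    """Stand-in for a plugin object that is not an nn.Module (hypothetical name)."""
    def __init__(self, name, feat_dim, n_out):
        self.name = name
        self.head = nn.Linear(feat_dim, n_out)


class ListPolicy(nn.Module):
    """Keeps plugins in a plain Python list: their heads are not registered."""
    def __init__(self, plugins):
        super().__init__()
        self.plugins = plugins


class RegisteredPolicy(nn.Module):
    """Also exposes the heads through nn.ModuleDict so PyTorch tracks them."""
    def __init__(self, plugins):
        super().__init__()
        self.plugins = plugins
        self.heads = nn.ModuleDict({p.name: p.head for p in plugins})


plugins = [PlainHolder("cue", 32, 2), PlainHolder("confidence", 32, 1)]
n_list = sum(p.numel() for p in ListPolicy(plugins).parameters())
n_registered = sum(p.numel() for p in RegisteredPolicy(plugins).parameters())
print(n_list, n_registered)  # 0 99
```

Only the registered variant would hand the head weights to `optim.Adam(policy.parameters())` or move them with `.to(device)`.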

In [12]:
class EarlyStoppingMonitor:
    def __init__(self, target=0.95, min_logs=5):
        self.target = target
        self.min_logs = min_logs
        self.history = []

    def update(self, rew):
        self.history.append(rew)
        if len(self.history) > self.min_logs:
            self.history.pop(0)
        return self.should_stop()

    def should_stop(self):
        if len(self.history) < self.min_logs:
            return False
        # Stop if all recent rewards exceed target
        return all(r >= self.target for r in self.history)
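
A minimal usage sketch for the monitor, assuming it is fed one mean reward per logging interval. The class is repeated in condensed form so the snippet runs on its own, and the reward sequence is made up for illustration.

```python
class EarlyStoppingMonitor:
    """Condensed copy of the class above so the snippet runs on its own."""
    def __init__(self, target=0.95, min_logs=5):
        self.target, self.min_logs, self.history = target, min_logs, []

    def update(self, rew):
        self.history.append(rew)
        if len(self.history) > self.min_logs:
            self.history.pop(0)  # keep only the last min_logs entries
        return len(self.history) == self.min_logs and all(r >= self.target for r in self.history)


monitor = EarlyStoppingMonitor(target=0.95, min_logs=3)
mean_rewards = [0.50, 0.96, 0.97, 0.98, 0.99]  # made-up logged mean returns
stopped_at = None
for i, rew in enumerate(mean_rewards):
    if monitor.update(rew):
        stopped_at = i
        break
print(stopped_at)  # 3
```

The monitor fires at index 3, the first point at which the three most recent values all reach the target; the early 0.50 has been dropped from the window by then.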

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RNDModule(nn.Module):
    def __init__(self, obs_dim, emb_dim=32):
        super().__init__()
        # Fixed random target net
        self.target = nn.Sequential(
            nn.Linear(obs_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        # Predictor net
        self.predictor = nn.Sequential(
            nn.Linear(obs_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        # Freeze target net
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, obs):
        with torch.no_grad():
            target_emb = self.target(obs)
        pred_emb = self.predictor(obs)
        # Return MSE as novelty signal
        return F.mse_loss(pred_emb, target_emb, reduction='none').mean(dim=-1)
        
class RewardNormalizer:
    def __init__(self, epsilon=1e-8):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4
        self.epsilon = epsilon

    def update(self, rewards):
        rewards = np.array(rewards)
        batch_mean = rewards.mean()
        batch_var = rewards.var()
        batch_count = len(rewards)
        total = self.count + batch_count
        delta = batch_mean - self.mean
        # Parallel merge of two sample sets: include the shift between the two means
        m2 = self.var * self.count + batch_var * batch_count + delta ** 2 * self.count * batch_count / total
        self.mean = self.mean + delta * batch_count / total
        self.var = m2 / total
        self.count = total

    def normalize(self, rewards):
        rewards = np.array(rewards)
        return ((rewards - self.mean) / (np.sqrt(self.var) + self.epsilon)).tolist()

# ──────────────────────────────────────────────────────────────
# 3. State Counter for Intrinsic Reward (Exploration Bonus)
# ──────────────────────────────────────────────────────────────

class StateCounter:
    def __init__(self):
        self.counts = defaultdict(int)

    def count(self, obs):
        key = tuple(np.round(obs, 2))
        self.counts[key] += 1
        return self.counts[key]

    def intrinsic_reward(self, obs):
        c = self.count(obs)
        return 1.0 / np.sqrt(c)

# ──────────────────────────────────────────────────────────────
# 4. Generalized Advantage Estimation (GAE) and Explained Variance
# ──────────────────────────────────────────────────────────────

def compute_explained_variance(y_pred, y_true):
    var_y = torch.var(y_true)
    if var_y == 0:
        return torch.tensor(0.0)
    return 1 - torch.var(y_true - y_pred) / (var_y + 1e-8)

def compute_gae(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    T = len(rewards)
    advantages = torch.zeros(T, dtype=torch.float32, device=values.device)
    gae = 0.0
    values_ext = torch.cat([values, torch.tensor([last_value], dtype=torch.float32, device=values.device)])
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

def print_sb3_style_log_box(stats):
    # Flatten stats for max width calculation
    all_rows = []
    for section in stats:
        all_rows.append((section["header"] + "/", None, True))
        for k, v in section["stats"].items():
            all_rows.append((k, v, False))
    # Compute widths
    key_width = max(
        len("    " + k) if not is_section else len(k)
        for k, v, is_section in all_rows
    )
    val_width = 10
    box_width = 2 + key_width + 3 + val_width 

    def fmt_row(label, value, is_section):
        if is_section:
            return f"| {label:<{key_width}}|{' ' * val_width} |"
        else:
            # Format value: float = 8.3f, int = 8d, tensor fallback
            if hasattr(value, 'item'):
                value = value.item()
            if isinstance(value, float):
                s_value = f"{value:8.3f}"
            elif isinstance(value, int):
                s_value = f"{value:8d}"
            else:
                s_value = str(value)
            s_value_centered = f"{s_value:^{val_width}}"
            return f"|    {label:<{key_width-4}} |{s_value_centered} |"

    print("-" * box_width)
    for section in stats:
        print(fmt_row(section["header"] + "/", None, True))
        for k, v in section["stats"].items():
            print(fmt_row(k, v, False))
    print("-" * box_width)

# ──────────────────────────────────────────────────────────────
# 5. Memory Transformer Policy WITH MULTI-AUX SUPPORT
# ──────────────────────────────────────────────────────────────

class MemoryTransformerPolicy(nn.Module):
    __version__     = "1.1.0"
    __description__ = "Now supports injection of custom auxiliary modules"
    
    def __init__(self, obs_dim, mem_dim=32, nhead=4, aux_modules=None, **kwargs):
        super().__init__()
        self.mem_dim = mem_dim
        self.embed = nn.Linear(obs_dim, mem_dim)
        self.pos_embed = nn.Embedding(256, mem_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=mem_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.policy_head = nn.Linear(mem_dim, 2)
        self.value_head = nn.Linear(mem_dim, 1)
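        self.aux_modules = aux_modules if aux_modules is not None else []
        # Register each plugin head as a submodule; heads held only in a plain list
        # are invisible to .parameters() and .to(device), so they would never train.
        self.aux_heads = nn.ModuleDict({aux.name: aux.head for aux in self.aux_modules})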

    def forward(self, trajectory, obs_t=None):
        T = trajectory.shape[0]
        x = self.embed(trajectory)
        pos = torch.arange(T, device=trajectory.device)
        x = x + self.pos_embed(pos)
        x = x.unsqueeze(0)
        x = self.transformer(x)
        feat = x[0, -1]
        logits = self.policy_head(feat)
        value = self.value_head(feat)
        aux_preds = {}
        for aux in self.aux_modules:
            aux_preds[aux.name] = aux.head(feat)
        return logits, value.squeeze(-1), aux_preds

# ──────────────────────────────────────────────────────────────
# 6. Multi-Aux HER-enabled MemoryPPO Trainer (with metric logging)
# ──────────────────────────────────────────────────────────────

class ExternalMemoryPPO:
    def __init__(
        self, 
        policy_class, 
        env, 
        verbose=0,
        learning_rate=1e-3, 
        gamma=0.99, 
        lam=0.95, 
        device="cpu",
        her=False,
        reward_norm=False,
        intrinsic_expl=True,
        intrinsic_eta=0.01,
        ent_coef=0.01,
        memory=None,
        aux_modules=None,
        use_rnd=True, rnd_emb_dim=32, rnd_lr=1e-3,
    ):
        self.env = env
        self.device = torch.device(device)
        self.gamma = gamma
        self.lam = lam
        self.ent_coef=ent_coef
        self.verbose = verbose
        self.memory = memory
        self.aux_modules = aux_modules if aux_modules is not None else []
        self.aux = len(self.aux_modules) > 0
        self.policy = policy_class(
            obs_dim=env.observation_space.shape[0], 
            memory=memory,
            aux_modules=self.aux_modules
        ).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.training_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []
        self.her = her
        self.reward_norm = reward_norm
        self.intrinsic_expl = intrinsic_expl
        self.intrinsic_eta = intrinsic_eta
        self.reward_normalizer = RewardNormalizer()
        self.state_counter = StateCounter()
        self.use_rnd = use_rnd
        if self.use_rnd:
            self.rnd = RNDModule(env.observation_space.shape[0], emb_dim=rnd_emb_dim).to(self.device)
            self.rnd_optimizer = torch.optim.Adam(self.rnd.predictor.parameters(), lr=rnd_lr)
        self.trajectory = []

    def reset_trajectory(self):
        self.trajectory = []

    def run_episode(self, her_target=None):
        obs, _ = self.env.reset()
        if her_target is not None:
            obs[0] = her_target
        if self.memory is not None:
            self.memory.reset()
        done = False
        trajectory = []
        rewards = []
        actions = []
        log_probs = []
        values = []
        entropies_ep = []
        aux_preds_list = []
        gate_history = []
        memory_size_history = []
        attn_weights = None
        initial_cue = int(obs[0])
        aux_targets_ep = {aux.name: [] for aux in self.aux_modules}

        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
            if self.memory is not None:
                self.memory.write(obs_t)
                gate_prob = self.memory.gate(obs_t).item()
                gate_history.append(gate_prob)
                memory_size_history.append(len(self.memory.keys))
            trajectory.append(obs_t)
            traj = torch.stack(trajectory)
            logits, value, aux_preds = self.policy(traj, obs_t)
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            done = terminated or truncated  # Gymnasium reports time-limit truncation separately
     
            if self.intrinsic_expl:
                reward += self.intrinsic_eta * self.state_counter.intrinsic_reward(obs)
            rnd_intrinsic = 0.0
            if self.use_rnd:
                with torch.no_grad():
                    obs_rnd = obs_t.unsqueeze(0)
                    rnd_intrinsic = self.rnd(obs_rnd).item()
                    reward += self.intrinsic_eta * rnd_intrinsic
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=self.device))
            values.append(value)
            entropies_ep.append(entropy)
            aux_preds_list.append(aux_preds)
            # Assign per-step targets for each aux (replace with your logic)
            for aux in self.aux_modules:
                if aux.name == "cue":
                    aux_targets_ep[aux.name].append(initial_cue)
                elif aux.name == "next_obs":
                    aux_targets_ep[aux.name].append(torch.tensor(obs, dtype=torch.float32))
                elif aux.name == "confidence":
                    # Reuse the entropy computed above; 1 - H is only a rough proxy
                    # (H for two actions peaks at ln 2, so targets lie in about [0.31, 1])
                    confidence = 1.0 - entropy.item()
                    aux_targets_ep[aux.name].append(confidence)
                elif aux.name == "event":
                    event_flag = getattr(self.env, "event_flag", 0)
                    aux_targets_ep[aux.name].append(event_flag)
                elif aux.name == "oracle_action":
                    oracle_action = getattr(self.env, "oracle_action", None)
                    aux_targets_ep[aux.name].append(oracle_action)
                else:
                    aux_targets_ep[aux.name].append(0)  # Or np.nan
                    
        if self.memory is not None:
            attn_weights = self.memory.last_attn
        if self.use_rnd:
            obs_batch = torch.stack(trajectory)  # entries are already tensors on self.device
            rnd_loss = self.rnd(obs_batch).mean()
            self.rnd_optimizer.zero_grad()
            rnd_loss.backward()
            self.rnd_optimizer.step()
        return {
            "trajectory": trajectory,
            "actions": actions,
            "rewards": rewards,
            "log_probs": log_probs,
            "values": values,
            "entropies": entropies_ep,
            "aux_preds": aux_preds_list,
            "aux_targets": aux_targets_ep,
            "initial_cue": initial_cue,
            "gate_history": gate_history,
            "memory_size_history": memory_size_history,
            "attn_weights": attn_weights
        }

    def learn(self, total_timesteps=2000, log_interval=100):
        steps = 0
        episodes = 0
        all_returns = []
        start_time = time.time()
        aux_losses = []

        while steps < total_timesteps:
            episode = self.run_episode()
            if self.reward_norm:
                self.reward_normalizer.update([r.item() for r in episode["rewards"]])
                episode["rewards"] = [torch.tensor(rn, dtype=torch.float32, device=self.device)
                                      for rn in self.reward_normalizer.normalize([r.item() for r in episode["rewards"]])]

            trajectory = episode["trajectory"]
            actions = episode["actions"]
            rewards = episode["rewards"]
            log_probs = episode["log_probs"]
            values = episode["values"]
            entropies_ep = episode["entropies"]
            aux_preds = episode["aux_preds"]
            aux_targets = episode["aux_targets"]
            T = len(rewards)

            rewards_t = torch.stack(rewards)
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            actions_t = torch.stack(actions)
            last_value = 0.0
            advantages = compute_gae(rewards_t, values_t, gamma=self.gamma, lam=self.lam, last_value=last_value)
            returns = (advantages + values_t).detach()  # value target must carry no gradient

            policy_loss = -(log_probs_t * advantages.detach()).sum()
            value_loss = F.mse_loss(values_t, returns)
            entropy_mean = torch.stack(entropies_ep).mean()
            explained_var = compute_explained_variance(values_t, returns)

            # ---- Multi-Aux Loss and Metrics ----
            aux_loss_total = torch.tensor(0.0, device=self.device)
            aux_metrics_log = {}
            if self.aux:
                for aux in self.aux_modules:
                    preds = torch.stack([ap[aux.name] for ap in aux_preds])
                    targets = torch.tensor(aux_targets[aux.name], device=self.device)
                    # For confidence, targets might be float!
                    # Drop a trailing singleton dim only (e.g. [T, 1] -> [T]); never squeeze [T]
                    if targets.dim() > 1 and targets.shape[-1] == 1:
                        targets = targets.squeeze(-1)
                    loss = aux.aux_loss(preds, targets)
                    aux_loss_total += loss
                    # Metrics
                    metrics = aux.aux_metrics(preds, targets)
                    aux_metrics_log[aux.name] = metrics
                aux_losses.append(aux_loss_total.item())

            loss = policy_loss + 0.5 * value_loss + 0.1 * aux_loss_total - self.ent_coef * entropy_mean

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_reward = sum([r.item() for r in rewards])
            self.episode_rewards.append(total_reward)
            self.episode_lengths.append(T)
            episodes += 1
            steps += T

            # Logging
            if episodes % log_interval == 0 and self.verbose == 1:
                elapsed_s = time.time() - start_time
                elapsed = int(elapsed_s)  # integer seconds for the log box
                mean_rew = np.mean(self.episode_rewards[-log_interval:])
                std_rew = np.std(self.episode_rewards[-log_interval:])
                mean_len = np.mean(self.episode_lengths[-log_interval:])
                fps = int(steps / max(elapsed_s, 1e-8))  # unrounded time avoids a huge fps when elapsed rounds to 0
                adv_mean = advantages.mean().item()
                adv_std = advantages.std().item()
                mean_entropy = entropy_mean.item()
                mean_aux = np.mean(aux_losses[-log_interval:]) if aux_losses else 0.0
                stats = [{
                    "header":"rollout",
                    "stats":dict(
                        ep_len_mean=mean_len,
                        ep_rew_mean=mean_rew,
                        ep_rew_std=std_rew,
                        policy_entropy=mean_entropy,
                        advantage_mean=adv_mean,
                        advantage_std=adv_std,
                        aux_loss_mean=mean_aux
                    )},{
                    "header":"time",
                    "stats":dict(
                        fps=fps,
                        episodes=episodes,
                        time_elapsed=elapsed,
                        total_timesteps=steps,
                    )},{
                    "header":"train",
                    "stats":dict(
                        loss=loss.item(),
                        policy_loss=policy_loss.item(),
                        value_loss=value_loss.item(),
                        explained_variance=explained_var.item(),
                        n_updates=episodes
                    )}
                ]
                if len(aux_metrics_log.items()) > 0:
                    aux_stats = {
                        "header": "aux_train",
                        "stats": {}
                    }
                    for aux_name, metrics in aux_metrics_log.items():
                        for k, v in metrics.items():
                            aux_stats["stats"][f"aux_{aux_name}_{k}"] = v
                    stats.append(aux_stats)
                if self.use_rnd:
                    with torch.no_grad():
                        # trajectory entries are already tensors on self.device
                        mean_rnd_bonus = self.rnd(torch.stack(trajectory)).mean().item()
                    stats.append({
                        "header": "rnd_net_dist",
                        "stats": {"mean_rnd_bonus":mean_rnd_bonus}
                    })
                
                print_sb3_style_log_box(stats)
      
        if self.verbose == 1:
            print(f"Training complete. Total episodes: {episodes}, total steps: {steps}")

    def predict(self, obs, deterministic=False, done=False, log_diagnostics=False):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
        self.trajectory.append(obs_t)
        traj = torch.stack(self.trajectory)
        self.policy.eval()
        with torch.no_grad():
            logits, _, _ = self.policy(traj, obs_t)
            if deterministic:
                action = torch.argmax(logits).item()
            else:
                dist = Categorical(logits=logits)
                action = dist.sample().item()
        self.policy.train()
        if done:
            self.reset_trajectory()
        return action

    def save(self, path="memoryppo.pt"):
        torch.save(self.policy.state_dict(), path)

    def load(self, path="memoryppo.pt"):
        self.policy.load_state_dict(torch.load(path, map_location=self.device))

    def evaluate(self, n_episodes=10, deterministic=False, verbose=True):
        returns = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            self.reset_trajectory()
            done = False
            total_reward = 0.0
            while not done:
                action = self.predict(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated  # Gymnasium signals episode end with two flags
                total_reward += reward
                if done:
                    self.reset_trajectory()
            returns.append(total_reward)
        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            print(f"Evaluation over {n_episodes} episodes: mean return {mean_return:.2f}, std {std_return:.2f}")
        return mean_return, std_return


# Strategic Episodic Memory


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RNDModule(nn.Module):
    def __init__(self, obs_dim, emb_dim=32):
        super().__init__()
        # Fixed random target net
        self.target = nn.Sequential(
            nn.Linear(obs_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        # Predictor net
        self.predictor = nn.Sequential(
            nn.Linear(obs_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        # Freeze target net
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, obs):
        with torch.no_grad():
            target_emb = self.target(obs)
        pred_emb = self.predictor(obs)
        # Return MSE as novelty signal
        return F.mse_loss(pred_emb, target_emb, reduction='none').mean(dim=-1)
        
class RewardNormalizer:
    def __init__(self, epsilon=1e-8):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4
        self.epsilon = epsilon

    def update(self, rewards):
        rewards = np.array(rewards)
        batch_mean = rewards.mean()
        batch_var = rewards.var()
        batch_count = len(rewards)
        self.mean = (self.mean * self.count + batch_mean * batch_count) / (self.count + batch_count)
        self.var = (self.var * self.count + batch_var * batch_count) / (self.count + batch_count)
        self.count += batch_count

    def normalize(self, rewards):
        rewards = np.array(rewards)
        return ((rewards - self.mean) / (np.sqrt(self.var) + self.epsilon)).tolist()

# ──────────────────────────────────────────────────────────────
# 3. State Counter for Intrinsic Reward (Exploration Bonus)
# ──────────────────────────────────────────────────────────────

class StateCounter:
    def __init__(self):
        self.counts = defaultdict(int)

    def count(self, obs):
        key = tuple(np.round(obs, 2))
        self.counts[key] += 1
        return self.counts[key]

    def intrinsic_reward(self, obs):
        c = self.count(obs)
        return 1.0 / np.sqrt(c)

# ──────────────────────────────────────────────────────────────
# 4. Generalized Advantage Estimation (GAE) and Explained Variance
# ──────────────────────────────────────────────────────────────

def compute_explained_variance(y_pred, y_true):
    var_y = torch.var(y_true)
    if var_y == 0:
        return torch.tensor(0.0)
    return 1 - torch.var(y_true - y_pred) / (var_y + 1e-8)

def compute_gae(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    T = len(rewards)
    advantages = torch.zeros(T, dtype=torch.float32, device=values.device)
    gae = 0.0
    values_ext = torch.cat([values, torch.tensor([last_value], dtype=torch.float32, device=values.device)])
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

def print_sb3_style_log_box(stats):
    # Flatten stats for max width calculation
    all_rows = []
    for section in stats:
        all_rows.append((section["header"] + "/", None, True))
        for k, v in section["stats"].items():
            all_rows.append((k, v, False))
    # Compute widths
    key_width = max(
        len("    " + k) if not is_section else len(k)
        for k, v, is_section in all_rows
    )
    val_width = 10
    box_width = 2 + key_width + 3 + val_width 

    def fmt_row(label, value, is_section):
        if is_section:
            return f"| {label:<{key_width}}|{' ' * val_width} |"
        else:
            # Format value: float = 8.3f, int = 8d, tensor fallback
            if hasattr(value, 'item'):
                value = value.item()
            if isinstance(value, float):
                s_value = f"{value:8.3f}"
            elif isinstance(value, int):
                s_value = f"{value:8d}"
            else:
                s_value = str(value)
            s_value_centered = f"{s_value:^{val_width}}"
            return f"|    {label:<{key_width-4}} |{s_value_centered} |"

    print("-" * box_width)
    for section in stats:
        print(fmt_row(section["header"] + "/", None, True))
        for k, v in section["stats"].items():
            print(fmt_row(k, v, False))
    print("-" * box_width)

# ──────────────────────────────────────────────────────────────
# 5. Memory Transformer Policy WITH MULTI-AUX SUPPORT
# ──────────────────────────────────────────────────────────────

class MemoryTransformerPolicy(nn.Module):
    __version__     = "1.1.0"
    __description__ = "Now supports injection of custom auxiliary modules"
    
    def __init__(self, obs_dim, mem_dim=32, nhead=4, aux_modules=None, **kwargs):
        super().__init__()
        self.mem_dim = mem_dim
        self.embed = nn.Linear(obs_dim, mem_dim)
        self.pos_embed = nn.Embedding(256, mem_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=mem_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.policy_head = nn.Linear(mem_dim, 2)
        self.value_head = nn.Linear(mem_dim, 1)
        self.aux_modules = aux_modules if aux_modules is not None else []

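    def forward(self, trajectory, obs_t=None, actions=None, rewards=None):
        # `actions`/`rewards` are accepted for trainer compatibility and ignored here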
        T = trajectory.shape[0]
        x = self.embed(trajectory)
        pos = torch.arange(T, device=trajectory.device)
        x = x + self.pos_embed(pos)
        x = x.unsqueeze(0)
        x = self.transformer(x)
        feat = x[0, -1]
        logits = self.policy_head(feat)
        value = self.value_head(feat)
        aux_preds = {}
        for aux in self.aux_modules:
            aux_preds[aux.name] = aux.head(feat)
        return logits, value.squeeze(-1), aux_preds

# ──────────────────────────────────────────────────────────────
# 6. Multi-Aux HER-enabled MemoryPPO Trainer (with metric logging)
# ──────────────────────────────────────────────────────────────

class ExternalMemoryPPO:
    def __init__(
        self, 
        policy_class, 
        env, 
        verbose=0,
        learning_rate=1e-3, 
        gamma=0.99, 
        lam=0.95, 
        device="cpu",
        her=False,
        reward_norm=False,
        intrinsic_expl=True,
        intrinsic_eta=0.01,
        ent_coef=0.01,
        memory=None,
        aux_modules=None,
        use_rnd=True, rnd_emb_dim=32, rnd_lr=1e-3,
    ):
        self.env = env
        self.device = torch.device(device)
        self.gamma = gamma
        self.lam = lam
        self.ent_coef=ent_coef
        self.verbose = verbose
        self.memory = memory
        self.aux_modules = aux_modules if aux_modules is not None else []
        self.aux = len(self.aux_modules) > 0
        self.policy = policy_class(
            obs_dim=env.observation_space.shape[0], 
            memory=memory,
            aux_modules=self.aux_modules
        ).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.training_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []
        self.her = her
        self.reward_norm = reward_norm
        self.intrinsic_expl = intrinsic_expl
        self.intrinsic_eta = intrinsic_eta
        self.reward_normalizer = RewardNormalizer()
        self.state_counter = StateCounter()
        self.use_rnd = use_rnd
        if self.use_rnd:
            self.rnd = RNDModule(env.observation_space.shape[0], emb_dim=rnd_emb_dim).to(self.device)
            self.rnd_optimizer = torch.optim.Adam(self.rnd.predictor.parameters(), lr=rnd_lr)
        self.trajectory = []

    def reset_trajectory(self):
        self.trajectory = []

    def run_episode(self, her_target=None):
        obs, _ = self.env.reset()
        done = False
        trajectory = []
        rewards = []
        actions = []
        log_probs = []
        values = []
        entropies_ep = []
        aux_preds_list = []
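        # for memory context
        rewards_float = []
        actions_int = []
        # Bookkeeping used by the aux-target logic and the returned dict
        initial_cue = int(obs[0])  # assumes the cue is the first observation component
        aux_targets_ep = {aux.name: [] for aux in self.aux_modules}
        gate_history = []
        memory_size_history = []
        attn_weights = None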
    
        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
            trajectory.append(obs_t)
            # Store actions/rewards for memory context
            if len(actions) > 0:
                actions_int.append(actions[-1].item())
                rewards_float.append(rewards[-1].item())
            else:
                actions_int.append(0)
                rewards_float.append(0.0)
            traj = torch.stack(trajectory)
            logits, value, aux_preds = self.policy(
                traj, obs_t,
                actions=torch.tensor(actions_int, device=self.device) if actions_int else None,
                rewards=torch.tensor(rewards_float, device=self.device) if rewards_float else None,
            )
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            done = terminated or truncated  # Gymnasium signals episode end with two flags
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=self.device))
            values.append(value)
            entropies_ep.append(entropy)
            aux_preds_list.append(aux_preds)
            if self.aux:
                # Assign per-step targets for each aux (replace with your logic)
                for aux in self.aux_modules:
                    if aux.name == "cue":
                        aux_targets_ep[aux.name].append(initial_cue)
                    elif aux.name == "next_obs":
                        aux_targets_ep[aux.name].append(torch.tensor(obs, dtype=torch.float32))
                    elif aux.name == "confidence":
                        dist = Categorical(logits=logits)
                        entropy = dist.entropy().item()
                        confidence = 1.0 - entropy  # Or scale appropriately
                        aux_targets_ep[aux.name].append(confidence)
                    elif aux.name == "event":
                        event_flag = getattr(self.env, "event_flag", 0)
                        aux_targets_ep[aux.name].append(event_flag)
                    elif aux.name == "oracle_action":
                        oracle_action = getattr(self.env, "oracle_action", None)
                        aux_targets_ep[aux.name].append(oracle_action)
                    else:
                        aux_targets_ep[aux.name].append(0)  # Or np.nan
                        
        # Post-episode work runs once, after the rollout loop has finished
        if self.memory is not None:
            # Not every memory module caches its attention weights
            attn_weights = getattr(self.memory, "last_attn", None)
        if self.use_rnd:
            obs_batch = torch.stack([torch.tensor(np.array(o), dtype=torch.float32, device=self.device) for o in trajectory])
            rnd_loss = self.rnd(obs_batch).mean()
            self.rnd_optimizer.zero_grad()
            rnd_loss.backward()
            self.rnd_optimizer.step()
        return {
            "trajectory": trajectory,
            "actions": actions,
            "rewards": rewards,
            "log_probs": log_probs,
            "values": values,
            "entropies": entropies_ep,
            "aux_preds": aux_preds_list,
            "aux_targets": aux_targets_ep,
            "initial_cue": initial_cue,
            "gate_history": gate_history,
            "memory_size_history": memory_size_history,
            "attn_weights": attn_weights
        }

    def learn(self, total_timesteps=2000, log_interval=100):
        steps = 0
        episodes = 0
        all_returns = []
        start_time = time.time()
        aux_losses = []

        while steps < total_timesteps:
            episode = self.run_episode()
            if self.reward_norm:
                self.reward_normalizer.update([r.item() for r in episode["rewards"]])
                episode["rewards"] = [torch.tensor(rn, dtype=torch.float32, device=self.device)
                                      for rn in self.reward_normalizer.normalize([r.item() for r in episode["rewards"]])]

            trajectory = episode["trajectory"]
            actions = episode["actions"]
            rewards = episode["rewards"]
            log_probs = episode["log_probs"]
            values = episode["values"]
            entropies_ep = episode["entropies"]
            aux_preds = episode["aux_preds"]
            aux_targets = episode["aux_targets"]
            T = len(rewards)

            rewards_t = torch.stack(rewards)
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            actions_t = torch.stack(actions)
            last_value = 0.0
            # Detach the value estimates so advantages and returns act as fixed targets
            advantages = compute_gae(rewards_t, values_t.detach(), gamma=self.gamma, lam=self.lam, last_value=last_value)
            returns = advantages + values_t.detach()

            policy_loss = -(log_probs_t * advantages.detach()).sum()
            value_loss = F.mse_loss(values_t, returns)
            entropy_mean = torch.stack(entropies_ep).mean()
            explained_var = compute_explained_variance(values_t, returns)

            # ---- Multi-Aux Loss and Metrics ----
            aux_loss_total = torch.tensor(0.0, device=self.device)
            aux_metrics_log = {}
            if self.aux:
                for aux in self.aux_modules:
                    preds = torch.stack([ap[aux.name] for ap in aux_preds])
                    targets = torch.tensor(aux_targets[aux.name], device=self.device)
                    # For confidence, targets might be float!
                    if preds.dim() != targets.dim():
                        targets = targets.squeeze(-1)
                    loss = aux.aux_loss(preds, targets)
                    aux_loss_total += loss
                    # Metrics
                    metrics = aux.aux_metrics(preds, targets)
                    aux_metrics_log[aux.name] = metrics
                aux_losses.append(aux_loss_total.item())

            loss = policy_loss + 0.5 * value_loss + 0.1 * aux_loss_total - self.ent_coef * entropy_mean

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_reward = sum([r.item() for r in rewards])
            self.episode_rewards.append(total_reward)
            self.episode_lengths.append(T)
            episodes += 1
            steps += T

            # Logging
            if episodes % log_interval == 0 and self.verbose == 1:
                elapsed = int(time.time() - start_time)
                mean_rew = np.mean(self.episode_rewards[-log_interval:])
                std_rew = np.std(self.episode_rewards[-log_interval:])
                mean_len = np.mean(self.episode_lengths[-log_interval:])
                fps = int(steps / (elapsed + 1e-8))
                adv_mean = advantages.mean().item()
                adv_std = advantages.std().item()
                mean_entropy = entropy_mean.item()
                mean_aux = np.mean(aux_losses[-log_interval:]) if aux_losses else 0.0
                stats = [{
                    "header":"rollout",
                    "stats":dict(
                        ep_len_mean=mean_len,
                        ep_rew_mean=mean_rew,
                        ep_rew_std=std_rew,
                        policy_entropy=mean_entropy,
                        advantage_mean=adv_mean,
                        advantage_std=adv_std,
                        aux_loss_mean=mean_aux
                    )},{
                    "header":"time",
                    "stats":dict(
                        fps=fps,
                        episodes=episodes,
                        time_elapsed=elapsed,
                        total_timesteps=steps,
                    )},{
                    "header":"train",
                    "stats":dict(
                        loss=loss.item(),
                        policy_loss=policy_loss.item(),
                        value_loss=value_loss.item(),
                        explained_variance=explained_var.item(),
                        n_updates=episodes
                    )}
                ]
                if len(aux_metrics_log.items()) > 0:
                    aux_stats = {
                        "header": "aux_train",
                        "stats": {}
                    }
                    for aux_name, metrics in aux_metrics_log.items():
                        for k, v in metrics.items():
                            aux_stats["stats"][f"aux_{aux_name}_{k}"] = v
                    stats.append(aux_stats)
                if self.use_rnd:
                    mean_rnd_bonus = np.mean([self.rnd(torch.tensor(np.array(o), dtype=torch.float32, device=self.device).unsqueeze(0)).item() for o in trajectory])
                    stats.append({
                        "header": "rnd_net_dist",
                        "stats": {"mean_rnd_bonus":mean_rnd_bonus}
                    })
                
                print_sb3_style_log_box(stats)
      
        if self.verbose == 1:
            print(f"Training complete. Total episodes: {episodes}, total steps: {steps}")

    def predict(self, obs, deterministic=False, done=False, log_diagnostics=False):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
        self.trajectory.append(obs_t)
        traj = torch.stack(self.trajectory)
        self.policy.eval()
        with torch.no_grad():
            logits, _, _ = self.policy(traj, obs_t)
            if deterministic:
                action = torch.argmax(logits).item()
            else:
                dist = Categorical(logits=logits)
                action = dist.sample().item()
        self.policy.train()
        if done:
            self.reset_trajectory()
        return action

    def save(self, path="memoryppo.pt"):
        torch.save(self.policy.state_dict(), path)

    def load(self, path="memoryppo.pt"):
        self.policy.load_state_dict(torch.load(path, map_location=self.device))

    def evaluate(self, n_episodes=10, deterministic=False, verbose=True):
        returns = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            self.reset_trajectory()
            done = False
            total_reward = 0.0
            while not done:
                action = self.predict(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated  # Gymnasium signals episode end with two flags
                total_reward += reward
                if done:
                    self.reset_trajectory()
            returns.append(total_reward)
        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            print(f"Evaluation over {n_episodes} episodes: mean return {mean_return:.2f}, std {std_return:.2f}")
        return mean_return, std_return
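
A dependency-free sketch of the backward recursion that `compute_gae` implements may help when checking the advantage targets by hand. The function name `gae_advantages` and the toy rewards and values are illustrative assumptions, not part of the trainer.

```python
def gae_advantages(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Plain-Python mirror of the backward GAE recursion (illustrative helper)."""
    values_ext = list(values) + [last_value]  # bootstrap value after the final step
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error at time t
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        # Exponentially weighted sum of the current and all later TD errors
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages


# Toy episode (assumed values): a single reward at the last step, constant value estimates
rewards = [0.0, 0.0, 1.0]
values = [0.5, 0.5, 0.5]
adv = gae_advantages(rewards, values, gamma=1.0, lam=1.0)
returns = [a + v for a, v in zip(adv, values)]
```

With `gamma = lam = 1` the returns collapse to the undiscounted reward-to-go (1.0 at every step here), which makes a convenient sanity check.
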


In [27]:
import torch
import torch.nn as nn
import numpy as np

class StrategicMemoryBuffer(nn.Module):
    def __init__(self, obs_dim, action_dim, mem_dim=32, max_entries=100, device='cpu'):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.mem_dim = mem_dim
        self.max_entries = max_entries
        self.device = device
        self.reset()
        # Only need this! (input_proj to mem_dim, encoder uses mem_dim)
        self.embedding_proj = nn.Linear(obs_dim + action_dim + 1, mem_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=mem_dim, nhead=2, batch_first=True),
            num_layers=1
        )

    def reset(self):
        self.entries = []

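    def add_entry(self, trajectory, outcome):
        # Each step contributes obs_dim + 2 features (scalar action, scalar reward),
        # so embedding_proj only matches this layout when action_dim == 1.
        traj = torch.tensor(
            np.array([np.concatenate([obs, [action], [reward]]) for obs, action, reward in trajectory]),
            dtype=torch.float32, device=self.device
        )  # [T, obs_dim+action_dim+1]
        with torch.no_grad():  # stored embeddings are fixed snapshots, no graph needed
            traj_proj = self.embedding_proj(traj)  # [T, mem_dim]
            mem_embed = self.encoder(traj_proj.unsqueeze(0)).mean(dim=1).squeeze(0)  # [mem_dim]
        entry = {
            'trajectory': trajectory,
            'outcome': outcome,
            'embedding': mem_embed
        }
        self.entries.append(entry)
        # FIFO eviction: drop the oldest entries once the buffer exceeds capacity
        if len(self.entries) > self.max_entries:
            self.entries = self.entries[-self.max_entries:]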

    def retrieve(self, context_trajectory):
        if len(self.entries) == 0:
            return torch.zeros(self.mem_dim, device=self.device), None
        traj = torch.tensor(
            [np.concatenate([obs, [action], [reward]]) for obs, action, reward in context_trajectory],
            dtype=torch.float32, device=self.device
        )
        traj_proj = self.embedding_proj(traj)
        context_embed = self.encoder(traj_proj.unsqueeze(0)).mean(dim=1).squeeze(0)  # [mem_dim]
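        mem_embeddings = torch.stack([e['embedding'] for e in self.entries])  # [N, mem_dim]
        attn_logits = torch.matmul(mem_embeddings, context_embed)  # [N] dot-product scores
        attn = torch.softmax(attn_logits, dim=0)
        mem_readout = (attn.unsqueeze(1) * mem_embeddings).sum(dim=0)  # [mem_dim]
        # Cache the weights: the trainer reads `memory.last_attn` for diagnostics
        self.last_attn = attn.detach()
        return mem_readout, attn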
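
The retrieval step in `retrieve` is plain dot-product attention over the stored episode embeddings. A minimal sketch without torch, assuming list-based vectors and a hypothetical helper name `attention_readout`:

```python
import math


def attention_readout(mem_embeddings, context_embed):
    """Softmax over dot-product scores, then a weighted sum of the memory vectors."""
    logits = [sum(m_i * c_i for m_i, c_i in zip(m, context_embed)) for m in mem_embeddings]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    attn = [e / total for e in exps]
    dim = len(context_embed)
    readout = [sum(a * m[d] for a, m in zip(attn, mem_embeddings)) for d in range(dim)]
    return readout, attn


# Two stored "episodes" (toy embeddings) and a context that resembles the first one
memory = [[1.0, 0.0], [0.0, 1.0]]
readout, attn = attention_readout(memory, [2.0, 0.0])
```

The entry whose embedding aligns with the context receives the larger weight, so the readout is pulled toward it.
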


In [28]:
class MemoryTransformerPolicy(nn.Module):
    __version__     = "1.2.0"
    __description__ = "Hint-free memory retrieval using strategic memory buffer"

    def __init__(self, obs_dim, mem_dim=32, nhead=4, memory=None, aux_modules=None, **kwargs):
        super().__init__()
        self.mem_dim = mem_dim
        self.embed = nn.Linear(obs_dim, mem_dim)
        self.pos_embed = nn.Embedding(256, mem_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=mem_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # Memory will be concatenated, so double the input size to heads
        self.policy_head = nn.Linear(mem_dim + mem_dim, 2)
        self.value_head = nn.Linear(mem_dim + mem_dim, 1)
        self.aux_modules = aux_modules if aux_modules is not None else []
        self.memory = memory

    def forward(self, trajectory, obs_t=None, actions=None, rewards=None):
        T = trajectory.shape[0]
        x = self.embed(trajectory)
        pos = torch.arange(T, device=trajectory.device)
        x = x + self.pos_embed(pos)
        x = x.unsqueeze(0)
        x = self.transformer(x)
        feat = x[0, -1]  # [mem_dim]
    
        # ---- PATCH: robust history alignment ----
        if self.memory is not None and actions is not None and rewards is not None:
            actions_list = actions.tolist()
            rewards_list = rewards.tolist()
            # Pad if actions/rewards are shorter than T (common on first step)
            if len(actions_list) < T:
                actions_list = [0] * (T - len(actions_list)) + actions_list
            if len(rewards_list) < T:
                rewards_list = [0.0] * (T - len(rewards_list)) + rewards_list
            context_traj = []
            for i in range(T):
                context_traj.append((
                    trajectory[i].cpu().numpy(),
                    actions_list[i],
                    rewards_list[i]
                ))
            mem_readout, attn = self.memory.retrieve(context_traj)
        else:
            mem_readout = torch.zeros_like(feat)
        final_feat = torch.cat([feat, mem_readout], dim=-1)
        logits = self.policy_head(final_feat)
        value = self.value_head(final_feat)
        aux_preds = {}
        for aux in self.aux_modules:
            aux_preds[aux.name] = aux.head(final_feat)
        return logits, value.squeeze(-1), aux_preds
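
The left-padding used in `forward` keeps each observation paired with the action and reward that preceded it. A small sketch of that alignment, with `left_pad` and the toy episode as illustrative assumptions:

```python
def left_pad(history, length, pad_value):
    """Left-pad `history` so index t holds the action/reward that preceded obs t."""
    history = list(history)
    return [pad_value] * (length - len(history)) + history


# Three observations seen so far, but only two actions/rewards have been recorded
observations = ["o0", "o1", "o2"]
actions = [1, 0]        # taken after o0 and o1
rewards = [0.0, 1.0]    # received after those actions
T = len(observations)

context = list(zip(observations, left_pad(actions, T, 0), left_pad(rewards, T, 0.0)))
```

The first observation has no predecessor, so it is paired with the zero padding; every later observation is paired with the previous step's action and reward.
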



In [29]:

class ExternalMemoryPPO:
    def __init__(
        self, 
        policy_class, 
        env, 
        verbose=0,
        learning_rate=1e-3, 
        gamma=0.99, 
        lam=0.95, 
        device="cpu",
        her=False,
        reward_norm=False,
        intrinsic_expl=True,
        intrinsic_eta=0.01,
        ent_coef=0.01,
        memory=None,
        aux_modules=None,
        use_rnd=True, rnd_emb_dim=32, rnd_lr=1e-3,
    ):
        self.env = env
        self.device = torch.device(device)
        self.gamma = gamma
        self.lam = lam
        self.ent_coef=ent_coef
        self.verbose = verbose
        self.memory = memory
        self.aux_modules = aux_modules if aux_modules is not None else []
        self.aux = len(self.aux_modules) > 0
        self.policy = policy_class(
            obs_dim=env.observation_space.shape[0], 
            memory=memory,
            aux_modules=self.aux_modules
        ).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.training_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []
        self.her = her
        self.reward_norm = reward_norm
        self.intrinsic_expl = intrinsic_expl
        self.intrinsic_eta = intrinsic_eta
        self.reward_normalizer = RewardNormalizer()
        self.state_counter = StateCounter()
        self.use_rnd = use_rnd
        if self.use_rnd:
            self.rnd = RNDModule(env.observation_space.shape[0], emb_dim=rnd_emb_dim).to(self.device)
            self.rnd_optimizer = torch.optim.Adam(self.rnd.predictor.parameters(), lr=rnd_lr)
        self.trajectory = []

    def reset_trajectory(self):
        self.trajectory = []

    def run_episode(self, her_target=None):
        obs, _ = self.env.reset()
        if her_target is not None:
            obs[0] = her_target
        # self.memory.reset()  # REMOVE unless you want to clear between episodes
    
        done = False
        trajectory = []
        actions = []
        rewards = []
        log_probs = []
        values = []
        entropies_ep = []
        aux_preds_list = []
        gate_history = []
        memory_size_history = []
        attn_weights = None
        initial_cue = int(obs[0])
        aux_targets_ep = {aux.name: [] for aux in self.aux_modules}
    
        # For memory context: action/reward at t matches obs_t
        context_traj = []  # (obs, action, reward) for the memory buffer
    
        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
            trajectory.append(obs_t)
            traj = torch.stack(trajectory)
            # For correct context, fill with 0 if first step (no previous action/reward)
            action_for_mem = actions[-1].item() if len(actions) > 0 else 0
            reward_for_mem = rewards[-1].item() if len(rewards) > 0 else 0.0
            context_traj.append((obs_t.cpu().numpy(), action_for_mem, reward_for_mem))
    
            logits, value, aux_preds = self.policy(
                traj, obs_t,
                actions=torch.tensor([a.item() for a in actions], device=self.device) if actions else None,
                rewards=torch.tensor([r.item() for r in rewards], device=self.device) if rewards else None
            )
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            done = terminated or truncated  # Gymnasium signals episode end with two flags
    
            if self.intrinsic_expl:
                reward += self.intrinsic_eta * self.state_counter.intrinsic_reward(obs)
            rnd_intrinsic = 0.0
            if self.use_rnd:
                with torch.no_grad():
                    obs_rnd = obs_t.unsqueeze(0)
                    rnd_intrinsic = self.rnd(obs_rnd).item()
                    reward += self.intrinsic_eta * rnd_intrinsic
    
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=self.device))
            values.append(value)
            entropies_ep.append(entropy)
            aux_preds_list.append(aux_preds)
    
            # (aux logic stays the same)
            for aux in self.aux_modules:
                if aux.name == "cue":
                    aux_targets_ep[aux.name].append(initial_cue)
                elif aux.name == "next_obs":
                    aux_targets_ep[aux.name].append(torch.tensor(obs, dtype=torch.float32))
                elif aux.name == "confidence":
                    dist = Categorical(logits=logits)
                    entropy = dist.entropy().item()
                    confidence = 1.0 - entropy  # Or scale appropriately
                    aux_targets_ep[aux.name].append(confidence)
                elif aux.name == "event":
                    event_flag = getattr(self.env, "event_flag", 0)
                    aux_targets_ep[aux.name].append(event_flag)
                elif aux.name == "oracle_action":
                    oracle_action = getattr(self.env, "oracle_action", None)
                    aux_targets_ep[aux.name].append(oracle_action)
                else:
                    aux_targets_ep[aux.name].append(0)
    
        # ----- WRITE TO MEMORY: full episode only -----
        if self.memory is not None:
            # Use the full episode context (obs, action, reward) tuples
            # For strict "no hint", do NOT provide outcome label/score except as total reward (which agent already receives)
            outcome = sum([r.item() for r in rewards])
            self.memory.add_entry(context_traj, outcome)
    
        if self.memory is not None and hasattr(self.memory, 'last_attn'):
            attn_weights = self.memory.last_attn
        if self.use_rnd:
            obs_batch = torch.stack([torch.tensor(np.array(o), dtype=torch.float32, device=self.device) for o in trajectory])
            rnd_loss = self.rnd(obs_batch).mean()
            self.rnd_optimizer.zero_grad()
            rnd_loss.backward()
            self.rnd_optimizer.step()
    
        return {
            "trajectory": trajectory,
            "actions": actions,
            "rewards": rewards,
            "log_probs": log_probs,
            "values": values,
            "entropies": entropies_ep,
            "aux_preds": aux_preds_list,
            "aux_targets": aux_targets_ep,
            "initial_cue": initial_cue,
            "gate_history": gate_history,
            "memory_size_history": memory_size_history,
            "attn_weights": attn_weights
        }


    def learn(self, total_timesteps=2000, log_interval=100):
        steps = 0
        episodes = 0
        all_returns = []
        start_time = time.time()
        aux_losses = []

        while steps < total_timesteps:
            episode = self.run_episode()
            if self.reward_norm:
                self.reward_normalizer.update([r.item() for r in episode["rewards"]])
                episode["rewards"] = [torch.tensor(rn, dtype=torch.float32, device=self.device)
                                      for rn in self.reward_normalizer.normalize([r.item() for r in episode["rewards"]])]

            trajectory = episode["trajectory"]
            actions = episode["actions"]
            rewards = episode["rewards"]
            log_probs = episode["log_probs"]
            values = episode["values"]
            entropies_ep = episode["entropies"]
            aux_preds = episode["aux_preds"]
            aux_targets = episode["aux_targets"]
            T = len(rewards)

            rewards_t = torch.stack(rewards)
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            actions_t = torch.stack(actions)
            last_value = 0.0
            advantages = compute_gae(rewards_t, values_t, gamma=self.gamma, lam=self.lam, last_value=last_value)
            returns = (advantages + values_t).detach()  # value targets must not carry gradients

            policy_loss = -(log_probs_t * advantages.detach()).sum()
            value_loss = F.mse_loss(values_t, returns)
            entropy_mean = torch.stack(entropies_ep).mean()
            explained_var = compute_explained_variance(values_t, returns)

            # ---- Multi-Aux Loss and Metrics ----
            aux_loss_total = torch.tensor(0.0, device=self.device)
            aux_metrics_log = {}
            if self.aux:
                for aux in self.aux_modules:
                    preds = torch.stack([ap[aux.name] for ap in aux_preds])
                    targets = torch.tensor(aux_targets[aux.name], device=self.device)
                    # For confidence, targets might be float!
                    if preds.dim() != targets.dim():
                        targets = targets.squeeze(-1)
                    loss = aux.aux_loss(preds, targets)
                    aux_loss_total += loss
                    # Metrics
                    metrics = aux.aux_metrics(preds, targets)
                    aux_metrics_log[aux.name] = metrics
                aux_losses.append(aux_loss_total.item())

            loss = policy_loss + 0.5 * value_loss + 0.1 * aux_loss_total - self.ent_coef * entropy_mean

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_reward = sum([r.item() for r in rewards])
            self.episode_rewards.append(total_reward)
            self.episode_lengths.append(T)
            episodes += 1
            steps += T

            # Logging
            if episodes % log_interval == 0 and self.verbose == 1:
                elapsed = int(time.time() - start_time)
                mean_rew = np.mean(self.episode_rewards[-log_interval:])
                std_rew = np.std(self.episode_rewards[-log_interval:])
                mean_len = np.mean(self.episode_lengths[-log_interval:])
                fps = int(steps / (elapsed + 1e-8))
                adv_mean = advantages.mean().item()
                adv_std = advantages.std().item()
                mean_entropy = entropy_mean.item()
                mean_aux = np.mean(aux_losses[-log_interval:]) if aux_losses else 0.0
                stats = [{
                    "header":"rollout",
                    "stats":dict(
                        ep_len_mean=mean_len,
                        ep_rew_mean=mean_rew,
                        ep_rew_std=std_rew,
                        policy_entropy=mean_entropy,
                        advantage_mean=adv_mean,
                        advantage_std=adv_std,
                        aux_loss_mean=mean_aux
                    )},{
                    "header":"time",
                    "stats":dict(
                        fps=fps,
                        episodes=episodes,
                        time_elapsed=elapsed,
                        total_timesteps=steps,
                    )},{
                    "header":"train",
                    "stats":dict(
                        loss=loss.item(),
                        policy_loss=policy_loss.item(),
                        value_loss=value_loss.item(),
                        explained_variance=explained_var.item(),
                        n_updates=episodes
                    )}
                ]
                if aux_metrics_log:
                    aux_stats = {
                        "header": "aux_train",
                        "stats": {}
                    }
                    for aux_name, metrics in aux_metrics_log.items():
                        for k, v in metrics.items():
                            aux_stats["stats"][f"aux_{aux_name}_{k}"] = v
                    stats.append(aux_stats)
                if self.use_rnd:
                    mean_rnd_bonus = np.mean([self.rnd(torch.tensor(np.array(o), dtype=torch.float32, device=self.device).unsqueeze(0)).item() for o in trajectory])
                    stats.append({
                        "header": "rnd_net_dist",
                        "stats": {"mean_rnd_bonus":mean_rnd_bonus}
                    })
                
                print_sb3_style_log_box(stats)
      
        if self.verbose == 1:
            print(f"Training complete. Total episodes: {episodes}, total steps: {steps}")

    def predict(self, obs, deterministic=False, done=False, reward=0.0):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
        # --- Track trajectory buffer ---
        if not hasattr(self, "trajectory_buffer") or self.trajectory_buffer is None:
            self.trajectory_buffer = []
        if len(self.trajectory_buffer) == 0:
            # Initial step: action/reward are dummy
            self.trajectory_buffer.append((obs_t.cpu().numpy(), 0, 0.0))
        else:
            # Last taken action, last received reward
            last_action = getattr(self, "last_action", 0)
            last_reward = getattr(self, "last_reward", 0.0)
            self.trajectory_buffer.append((obs_t.cpu().numpy(), last_action, last_reward))
        # Build context for memory
        context_traj = self.trajectory_buffer.copy()
        # Actions/rewards arrays for input to policy
        actions_int = [a for _, a, _ in context_traj]
        rewards_float = [r for _, _, r in context_traj]
        obs_stack = torch.stack([torch.tensor(o, dtype=torch.float32, device=self.device) for o, _, _ in context_traj])
        logits, _, _ = self.policy(
            obs_stack, obs_t,
            actions=torch.tensor(actions_int, device=self.device),
            rewards=torch.tensor(rewards_float, device=self.device)
        )
        if deterministic:
            action = torch.argmax(logits).item()
        else:
            dist = Categorical(logits=logits)
            action = dist.sample().item()
        self.last_action = action
        self.last_reward = reward
        if done:
            self.trajectory_buffer = []
        return action

    def save(self, path="memoryppo.pt"):
        torch.save(self.policy.state_dict(), path)

    def load(self, path="memoryppo.pt"):
        self.policy.load_state_dict(torch.load(path, map_location=self.device))

    def evaluate(self, n_episodes=10, deterministic=False, verbose=True):
        returns = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            self.trajectory_buffer = []
            done = False
            total_reward = 0.0
            last_reward = 0.0
            while not done:
                action = self.predict(obs, deterministic=deterministic, reward=last_reward)
                obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated  # also stop on time-limit truncation
                total_reward += reward
                last_reward = reward
                if done:
                    self.trajectory_buffer = []
            returns.append(total_reward)
        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            print(f"Evaluation over {n_episodes} episodes: mean return {mean_return:.2f}, std {std_return:.2f}")
        return mean_return, std_return
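
The `learn` loop above relies on a `compute_gae` helper whose body is not part of this section. As a rough illustration of what such a helper usually computes, here is a plain-Python sketch of Generalized Advantage Estimation for one finished episode. The name `gae_sketch` and its list-based interface are assumptions made for this example, not the notebook's implementation.

```python
def gae_sketch(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Generalized Advantage Estimation for one episode (hypothetical helper).

    rewards, values: equal-length lists of floats.
    last_value: bootstrap value after the final step (0.0 for a terminal state).
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    # Walk backwards so each step can reuse the accumulated future advantage.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        gae = delta + gamma * lam * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages


rewards = [0.0, 0.0, 1.0]
values = [0.5, 0.5, 0.5]
advantages = gae_sketch(rewards, values, gamma=1.0, lam=1.0)
returns = [a + v for a, v in zip(advantages, values)]
print(advantages)  # [0.5, 0.5, 0.5]
print(returns)     # [1.0, 1.0, 1.0]
```

With `gamma=1` and `lam=1` the advantage reduces to "reward-to-go minus value", which makes the numbers easy to check by hand.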
    


In [33]:
delay = 64
env = MemoryTaskEnv(delay=delay, difficulty=0)
mem_dim = 32 #64*1.5

aux_modules = [
    CueAuxModule(feat_dim=mem_dim*2, n_classes=2),
    ConfidenceAuxModule(feat_dim=mem_dim*2)
]

policy = MemoryTransformerPolicy  # Or your patched ExternalMemoryTransformerPolicy

memory = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,        # Discrete(2)
    mem_dim=mem_dim,
    max_entries=16,
    device="cpu"
)


agent = ExternalMemoryPPO(
    policy_class=policy,
    use_rnd=True,
    env=env,
    memory=memory,
    #aux_modules=aux_modules,
    device="cpu",
    #learning_rate=1e-4,
    her=False,
    verbose=1,
    reward_norm=False,
    ent_coef=0.1
)

agent.learn(total_timesteps=delay*10_000, log_interval=100)

-------------------------------------
| rollout/              |           |
|    ep_len_mean        |   64.000  |
|    ep_rew_mean        |   -0.175  |
|    ep_rew_std         |    0.983  |
|    policy_entropy     |    0.245  |
|    advantage_mean     |    0.339  |
|    advantage_std      |    0.322  |
|    aux_loss_mean      |    0.000  |
| time/                 |           |
|    fps                |      168  |
|    episodes           |      100  |
|    time_elapsed       |       38  |
|    total_timesteps    |     6400  |
| train/                |           |
|    loss               |    5.045  |
|    policy_loss        |    4.962  |
|    value_loss         |    0.217  |
|    explained_variance |   -0.028  |
|    n_updates          |      100  |
| rnd_net_dist/         |           |
|    mean_rnd_bonus     |    0.000  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        |   64.000  |
|    ep_rew_
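
The run above uses a `StrategicMemoryBuffer` with `max_entries=16`, whose implementation is outside this section. The following is only a toy illustration of the retention rule from the key ideas: when the buffer is full, drop the entry with the lowest usefulness rather than the oldest one. The class `UsefulnessBufferSketch`, its `momentum` parameter, and the moving-average update are assumptions for the example; the real buffer uses a trainable score.

```python
class UsefulnessBufferSketch:
    """Toy buffer that evicts the least useful entry when full (hypothetical)."""

    def __init__(self, max_entries=16, momentum=0.9):
        self.max_entries = max_entries
        self.momentum = momentum
        self.entries = []      # stored payloads
        self.usefulness = []   # one scalar score per payload

    def add_entry(self, payload, score=0.0):
        if len(self.entries) >= self.max_entries:
            # Evict the entry with the lowest score, not the oldest one.
            worst = min(range(len(self.usefulness)), key=self.usefulness.__getitem__)
            self.entries.pop(worst)
            self.usefulness.pop(worst)
        self.entries.append(payload)
        self.usefulness.append(score)

    def observe_attention(self, attn_weights):
        # Moving-average update: entries that receive attention gain usefulness.
        for i, w in enumerate(attn_weights):
            self.usefulness[i] = self.momentum * self.usefulness[i] + (1 - self.momentum) * w


buf = UsefulnessBufferSketch(max_entries=3)
for name in ["a", "b", "c"]:
    buf.add_entry(name)
buf.observe_attention([0.7, 0.0, 0.3])  # "b" is never attended to
buf.add_entry("d")                      # buffer is full, so "b" is evicted
print(buf.entries)  # ['a', 'c', 'd']
```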

In [8]:
import gymnasium as gym
import numpy as np

class MarketHiddenRegimeMemoryEnv(gym.Env):
    """
    Custom Gymnasium environment:
    - At reset, samples a hidden regime {0: bull, 1: bear, 2: neutral}.
    - For N steps, agent observes noisy price changes (regime only weakly affects drift).
    - On final step, agent must pick regime; reward=1 if correct, else 0.
    - Agent can only win if it remembers early cues.
    """
    metadata = {"render_modes": ["human"]}
    def __init__(self, n_steps=40, obs_noise=1.0, drift_strength=0.05, seed=None):
        super().__init__()
        self.n_steps = n_steps
        self.obs_noise = obs_noise
        self.drift_strength = drift_strength
        self.rng = np.random.default_rng(seed)
        # Observations: [price_change] (can add more features)
        self.observation_space = gym.spaces.Box(low=-5, high=5, shape=(1,), dtype=np.float32)
        # Actions: Only on last step: 0=bull, 1=bear, 2=neutral
        self.action_space = gym.spaces.Discrete(3)
        self.current_step = 0
        self.regime = None
        self.price = 0.0

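    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        if seed is not None:
            # Re-seed the generator that drives the dynamics so reset(seed=...) is reproducible
            self.rng = np.random.default_rng(seed)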
        self.current_step = 0
        self.regime = int(self.rng.integers(0, 3))  # 0,1,2
        self.price = 0.0
        # Optionally: stronger initial drift for "hint"
        self.drift = {0: +self.drift_strength, 1: -self.drift_strength, 2: 0.0}[self.regime]
        obs = np.array([self._next_price()], dtype=np.float32)
        return obs, {}

    def _next_price(self):
        # Regime determines drift, but noise dominates
        noise = self.rng.normal(0, self.obs_noise)
        drift = self.drift
        price_change = drift + noise
        self.price += price_change
        return price_change

    def step(self, action):
        self.current_step += 1
        done = (self.current_step >= self.n_steps)
        reward = 0.0
        # Only give reward at the end, based on regime guess
        if done:
            reward = 1.0 if int(action) == self.regime else 0.0
        obs = np.array([self._next_price()], dtype=np.float32) if not done else np.zeros((1,), dtype=np.float32)
        info = {'regime': self.regime, 'step': self.current_step}
        return obs, reward, done, False, info

    def render(self, mode="human"):
        print(f"Step {self.current_step}, price={self.price:.2f}, regime={self.regime}")

    def close(self):
        pass

# Register with Gymnasium for convenience
from gymnasium.envs.registration import register
register(
    id="MarketHiddenRegimeMemory-v0",
    entry_point=MarketHiddenRegimeMemoryEnv,
    max_episode_steps=40,
)
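
Before training on this environment it helps to know how much signal it contains. The sketch below estimates the accuracy of a nearest-mean guess based on the sum of all observed price changes, using the defaults from the class above (`n_steps=40`, `obs_noise=1.0`, `drift_strength=0.05`). It assumes the three regimes are equally likely, that the agent sees all 40 price changes before its final action, and that observations are Gaussian; the function name `regime_accuracy_ceiling` is made up for this example.

```python
from statistics import NormalDist


def regime_accuracy_ceiling(n_steps=40, obs_noise=1.0, drift_strength=0.05):
    """Accuracy of a nearest-mean guess based on the summed price changes."""
    std_of_sum = obs_noise * n_steps ** 0.5
    mean_shift = drift_strength * n_steps   # |mean of the sum| for bull/bear
    threshold = mean_shift / 2              # midpoint between neutral and bull
    z = NormalDist()
    # Bull (and, by symmetry, bear): correct when the sum lands past the threshold.
    p_directional = 1 - z.cdf((threshold - mean_shift) / std_of_sum)
    # Neutral: correct when the sum stays between the two thresholds.
    p_neutral = z.cdf(threshold / std_of_sum) - z.cdf(-threshold / std_of_sum)
    return (2 * p_directional + p_neutral) / 3


print(f"{regime_accuracy_ceiling():.3f}")  # 0.417
```

Under these assumptions the ceiling is only about 0.42 against a chance level of 1/3, so small differences in `ep_rew_mean` on this task should be read with care. Raising `drift_strength` or `n_steps` widens the gap.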


In [22]:
#env = MemoryTaskEnv(delay=64, difficulty=1)
env = gym.make("MarketHiddenRegimeMemory-v0", n_steps=40)
memory = DifferentiableEpisodicMemory(obs_dim=env.observation_space.shape[0], mem_dim=32, max_size=16)
#policy = MemoryTransformerPolicy#ExternalMemoryTransformerPolicy  
policy = ExternalMemoryTransformerPolicy  
aux_modules = [
        #CueAuxModule(feat_dim=32, n_classes=3),
        ConfidenceAuxModule(feat_dim=32)
    ]
agent = ExternalMemoryPPO(
    policy_class=policy,
    use_rnd=True,
    env=env,
    memory=memory,   
    aux_modules=aux_modules,
    device="cpu",
    learning_rate=1e-4,
    her=False,
    intrinsic_expl=False,
    verbose=1,
    reward_norm=False,
    ent_coef=0.1
)
agent.learn(total_timesteps=1_000_000, log_interval=100)

-------------------------------------
| rollout/              |           |
|    ep_len_mean        |   40.000  |
|    ep_rew_mean        |    0.414  |
|    ep_rew_std         |    0.444  |
|    policy_entropy     |    0.576  |
|    advantage_mean     |   -0.109  |
|    advantage_std      |    0.139  |
|    aux_loss_mean      |    0.024  |
| time/                 |           |
|    fps                |      199  |
|    episodes           |      100  |
|    time_elapsed       |       20  |
|    total_timesteps    |     4000  |
| train/                |           |
|    loss               |   -2.757  |
|    policy_loss        |   -2.716  |
|    value_loss         |    0.031  |
|    explained_variance |   -5.454  |
|    n_updates          |      100  |
| aux_train/            |           |
|    aux_confidence_mse |    0.016  |
| rnd_net_dist/         |           |
|    mean_rnd_bonus     |    0.015  |
-------------------------------------
-------------------------------------
| rollout/  
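
The log above reports `explained_variance = -5.454`. The notebook's `compute_explained_variance` is not shown in this section; a common definition (the one used by Stable-Baselines3) is `1 - Var(returns - values) / Var(returns)`. The sketch below follows that definition on plain lists, with the hypothetical name `explained_variance_sketch`. A value of 1 means perfect value predictions, 0 means no better than a constant, and negative values mean the predictions add error.

```python
from statistics import pvariance


def explained_variance_sketch(values, returns):
    """1 - Var(returns - values) / Var(returns); NaN when returns are constant."""
    var_returns = pvariance(returns)
    if var_returns == 0:
        return float("nan")
    residuals = [r - v for r, v in zip(returns, values)]
    return 1.0 - pvariance(residuals) / var_returns


returns = [0.0, 1.0, 0.0, 1.0]
print(explained_variance_sketch([0.0, 1.0, 0.0, 1.0], returns))  # 1.0  (perfect)
print(explained_variance_sketch([0.5, 0.5, 0.5, 0.5], returns))  # 0.0  (constant guess)
print(explained_variance_sketch([1.0, 0.0, 1.0, 0.0], returns))  # -3.0 (anti-correlated)
```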

In [19]:
rews = []
for _ in range(100):
    obs, _ = env.reset()
    done = False
    agent.reset_trajectory()
    total_rew = 0.0
    while not done:
        action = agent.predict(obs, deterministic=True)
        obs, rew, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_rew += rew
    rews.append(total_rew)
np.mean(rews), np.std(rews)

(0.28, 0.4489988864128729)
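
To judge whether a mean of 0.28 over 100 episodes differs from guessing, a quick standard-error check is useful. The sketch assumes the episodes are independent and that this evaluation was run on the three-regime environment, where chance is 1/3; the reported standard deviation is consistent with 0/1 rewards. It uses a normal approximation for the interval.

```python
import math

mean_return, std_return, n_episodes = 0.28, 0.4489988864128729, 100

standard_error = std_return / math.sqrt(n_episodes)
# Normal-approximation 95% interval; an assumption that is reasonable for n = 100.
low = mean_return - 1.96 * standard_error
high = mean_return + 1.96 * standard_error
chance = 1 / 3  # three equally likely regimes

print(f"95% interval: [{low:.3f}, {high:.3f}]")            # [0.192, 0.368]
print("chance level inside interval:", low <= chance <= high)  # True
```

With only 100 episodes the interval still contains the chance level, so this result alone does not show the agent doing better or worse than guessing.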

# Benchmark 

**Agents:**

* PPO + MlpPolicy
* RecurrentPPO + MlpLstmPolicy
* PPO + MemoryTransformerPolicy

**Environment:**

* MemoryTaskEnv
    * Delays: 4, 32, 128
    * Difficulties: easy, hard

**Evaluation:**
* Eval loop will consist of:
    * 50% episodes with target = 1
    * 50% episodes with target = 0
    * same environment for all agents
    * A minimum of 20 episodes
* Will return:
    * Mean episode reward
    * Std episode reward 

In [11]:
from sb3_contrib import RecurrentPPO
from stable_baselines3 import PPO
from gymnasium.wrappers import RecordEpisodeStatistics
from tqdm import tqdm
from tabulate import tabulate

import pandas as pd
import time

class AgentPerformanceBenchmark:
    """
    Benchmark class for standardized evaluation and reporting of agent performance
    on memory-based RL tasks. Handles experiment setup, training, evaluation, and result display.
    """

    def __init__(self, env_config):
        """
        Initializes the benchmark runner with experiment parameters.

        Parameters
        ----------
        env_config : dict
            Configuration dictionary with environment and experiment parameters.
        """
        self.env_config = env_config
        self.env = MemoryTaskEnv(
            delay=env_config["delay"],
            difficulty=env_config.get("difficulty", 0)
        )
        self.n_eval_episodes = env_config.get("n_eval_episodes", 20)
        self.verbose = env_config.get("verbose", 0)
        self.log_interval = env_config.get("log_interval", 10)
        self.learning_rate = env_config.get("learning_rate", 1e-3)
        self.total_timesteps = env_config.get("total_timesteps", 10000)
        self.eval_base = env_config.get("eval_base", False)
        self.mode_name = env_config.get(
            "mode_name", "EASY" if env_config.get("difficulty", 0) == 0 else "HARD"
        )

    def print_train_results(self, reward, std, model_name):
        """
        Print formatted summary of evaluation results.

        Parameters
        ----------
        reward : float
            Mean episode reward as a fraction in [-1.0, 1.0]; printed as a percentage
        std : float
            Standard deviation of episode reward
        model_name : str
            Name of the agent/model
        """
        print(
            f"[{model_name} @ MemoryTaskEnv with delay={self.env_config['delay']} in {self.mode_name} Mode]  "
            f"mean_ep_rew: {reward*100:.1f}% .  std_ep_rew: {std:.2f} in {self.n_eval_episodes} episodes"
        )

    def evaluate(self, model, model_name, deterministic=True, verbose=False):
        """
        Evaluates an agent over multiple episodes, balancing both target classes.

        Parameters
        ----------
        model : object
            The RL agent (must implement `predict`)
        model_name : str
            Name for reporting
        deterministic : bool
            Use deterministic policy for evaluation
        verbose : bool
            Print result summary if True

        Returns
        -------
        mean_return : float
        std_return : float
        """
        returns = []
        target_counter = [0, 0]
        eval_complete = False
        eval_runs = 0
        n_target_samples = int(self.n_eval_episodes / 2)

        while not eval_complete:
            # Reset model's memory if possible (important for memory agents)
            if hasattr(model, "reset_trajectory") and callable(getattr(model, "reset_trajectory")):
                model.reset_trajectory()

            obs, _ = self.env.reset()
            target = int(obs[0])
            eval_runs += 1

            # Infinite loop guard
            if eval_runs > 1000:
                print("Warning: Evaluation ran over 1000 attempts, aborting early.")
                break

            # Balance classes: skip if this target has enough samples
            if target_counter[target] >= n_target_samples:
                continue
            target_counter[target] += 1

            done = False
            total_reward = 0.0
            
            while not done:
                
                action = model.predict(obs, deterministic=deterministic)
                if isinstance(action, tuple):
                    action = action[0]
                obs, reward, done, _, _ = self.env.step(action)
                total_reward += reward
                
            returns.append(total_reward)
            # Done once both classes have their quota (also terminates for odd n_eval_episodes)
            eval_complete = all(c >= n_target_samples for c in target_counter)

        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            self.print_train_results(mean_return, std_return, model_name)
        return mean_return, std_return

    def run(self):
        """
        Runs the full training and evaluation pipeline for all specified agents.
        Returns
        -------
        results : list of dicts
            Each dict contains experiment config and result metrics.
        """
        results = []
        print(
            f"\nTraining in {self.mode_name} mode with delay of {self.env_config['delay']} steps\n"
        )
    
        # List of agent configs to iterate through
        agents = []
        if self.eval_base:
            agents.append(('PPO', lambda: PPO(
                'MlpPolicy',
                RecordEpisodeStatistics(self.env),
                learning_rate=self.learning_rate,
                verbose=self.verbose
            )))
            agents.append(('RecurrentPPO', lambda: RecurrentPPO(
                "MlpLstmPolicy",
                RecordEpisodeStatistics(self.env),
                verbose=self.verbose,
                learning_rate=self.learning_rate
            )))
        agents.append(('MemoryPPO', lambda: MemoryPPO(
            MemoryTransformerPolicy,
            self.env,
            learning_rate=self.learning_rate,
            her=False,
            verbose=self.verbose,
            aux=False
        )))
    
        # Custom tqdm progress bar for agent training & evaluation
        with tqdm(total=len(agents)*2 + 1, desc="Benchmark Progress", unit="step") as pbar:
            for agent_name, agent_builder in agents:
                # TRAIN
                pbar.set_description(f"Training {agent_name}")
                start_time = time.time()
                model = agent_builder()
                model.learn(
                    total_timesteps=self.total_timesteps,
                    log_interval=self.log_interval
                )
                duration = time.time() - start_time
                pbar.update(1)
    
                # EVALUATE
                pbar.set_description(f"Evaluating {agent_name}")
                mean, std = self.evaluate(model, agent_name, verbose=False)
                pbar.update(1)
    
                results.append({
                    **self.env_config,
                    'mode_name': self.mode_name,  # present even when env_config omits it
                    'agent': agent_name,
                    'mean_return': mean,
                    'std_return': std,
                    'duration': duration
                })
    
            # Before printing table, update tqdm and close bar
            pbar.set_description("Finalizing Results")
            pbar.update(1)
    
        # --- Format & Print Table ---
        pdf = pd.DataFrame(results)
        # Select and rename in one chain to avoid renaming a slice in place
        pdf = pdf[['agent', 'delay', 'mode_name', 'mean_return', 'std_return', 'duration']].rename(
            columns={
                "agent": "Agent",
                "delay": "Delay",
                "mode_name": "Mode",
                "mean_return": "Mean Ep Rew",
                "std_return": "Std Ep Rew",
                "duration": "Duration (s)"
            }
        )
        print(tabulate(pdf, headers="keys", tablefmt="rounded_outline"))
    
        return results



In [12]:
import warnings
warnings.filterwarnings('ignore')

# ---- Batch experiment setup ----
if __name__ == "__main__":
    results = []
    EXPERIMENTS = [
        dict(
            delay=4,
            n_train_episodes=2000,
            total_timesteps=2000*4,
            difficulty=0,
            mode_name="EASY",
            verbose=0,
            eval_base=True
        ),
        dict(
            delay=4,
            n_train_episodes=5000,
            total_timesteps=5000*4,
            difficulty=1,
            mode_name="HARD",
            verbose=0,
            eval_base=True
        ),
        dict(
            delay=32,
            n_train_episodes=7500,
            total_timesteps=7500*32,
            difficulty=0,
            mode_name="EASY",
            verbose=0,
            eval_base=False
        ),
        dict(
            delay=32,
            n_train_episodes=10000,
            total_timesteps=10000*32,
            difficulty=1,
            mode_name="HARD",
            verbose=0,
            eval_base=False
        ),
        dict(
            delay=64,
            n_train_episodes=15000,
            total_timesteps=15000*64,
            difficulty=1,  # matches mode_name="HARD"
            mode_name="HARD",
            verbose=0,
            eval_base=False
        ),
        dict(
            delay=128,
            n_train_episodes=20000,
            total_timesteps=20000*128,
            difficulty=1,  # matches mode_name="HARD"
            mode_name="HARD",
            verbose=0,
            eval_base=False
        )]
    
    for exp in EXPERIMENTS:
        benchmark = AgentPerformanceBenchmark(exp)
        results.append(benchmark.run())


Training in EASY mode with delay of 4 steps



Training MemoryPPO:  57%|█████▋    | 4/7 [00:33<00:25,  8.44s/step]     


RuntimeError: 
Module 'MemoryTransformerPolicy' has no attribute 'aux_modules' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type. List trace inputs must have elements. Its type was inferred; try adding a type annotation for the attribute.):
  File "C:\Users\filip_a58djhu\AppData\Local\Temp\ipykernel_22968\1245403817.py", line 158
        value = self.value_head(feat)
        aux_preds = {}
        for aux in self.aux_modules:
                   ~~~~~~~~~~~~~~~~ <--- HERE
            aux_preds[aux.name] = aux.head(feat)
        return logits, value.squeeze(-1), aux_preds
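
The error message says `aux_modules` is an empty Python `list`, which TorchScript cannot type. A common way to hold a variable number of sub-modules in PyTorch is `nn.ModuleList`, which also registers their parameters with the parent module. The sketch below shows that pattern in eager mode. `AuxHeadSketch` and `PolicySketch` are hypothetical classes for illustration; whether this change removes the TorchScript error depends on where scripting is triggered, which this section does not show.

```python
import torch
import torch.nn as nn


class AuxHeadSketch(nn.Module):
    """Hypothetical auxiliary head: a named linear layer."""
    def __init__(self, name, feat_dim, out_dim):
        super().__init__()
        self.name = name
        self.head = nn.Linear(feat_dim, out_dim)


class PolicySketch(nn.Module):
    """Hypothetical policy tail that keeps its auxiliary heads in an nn.ModuleList."""
    def __init__(self, feat_dim, n_actions, aux_modules=None):
        super().__init__()
        self.policy_head = nn.Linear(feat_dim, n_actions)
        self.value_head = nn.Linear(feat_dim, 1)
        # ModuleList registers each head's parameters and may be empty.
        self.aux_modules = nn.ModuleList(aux_modules or [])

    def forward(self, feat):
        logits = self.policy_head(feat)
        value = self.value_head(feat)
        aux_preds = {aux.name: aux.head(feat) for aux in self.aux_modules}
        return logits, value.squeeze(-1), aux_preds


feat = torch.zeros(4, 8)
no_aux = PolicySketch(feat_dim=8, n_actions=2)
with_aux = PolicySketch(8, 2, [AuxHeadSketch("cue", 8, 2)])
print(no_aux(feat)[2])                 # {} when no auxiliary heads are configured
print(with_aux(feat)[2]["cue"].shape)  # torch.Size([4, 2])
```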


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

results_nested = results
flat_results = [item for sublist in results_nested for item in sublist]

df = pd.DataFrame(flat_results)

agent_list = sorted(df['agent'].unique())
agent_colors = {
    agent: color for agent, color in zip(
        agent_list,
        ['#0072B2', '#009E73', '#D55E00', '#CC79A7', '#F0E442', '#56B4E9']
    )
}

grouped = df.groupby(['delay', 'difficulty'])

n_groups = len(grouped)
fig, axes = plt.subplots(1, n_groups, figsize=(4*n_groups, 6), sharey=True)

if n_groups == 1:
    axes = [axes]  

for ax, ((delay, difficulty), group) in zip(axes, grouped):
    mode = group['mode_name'].iloc[0]
    agents = group['agent']
    means = group['mean_return']
    stds = group['std_return']
    
    # Get bar colors based on agent names
    bar_colors = [agent_colors[agent] for agent in agents]
    
    bars = ax.bar(
        agents,
        means,
        yerr=stds,
        capsize=6,
        color=bar_colors,
        width=0.7
    )
    ax.set_title(f"Delay={delay}\nMode={mode}", fontsize=12)
    ax.set_xlabel("Agent")
    ax.set_ylim(-1.2, 1.2)
    ax.axhline(0, color='grey', linestyle='--', linewidth=1)
    ax.set_xticks(range(len(agents)))
    ax.set_xticklabels(agents, rotation=20)
    if ax is axes[0]:
        ax.set_ylabel("Mean Return")


handles = [plt.Rectangle((0,0),1,1, color=agent_colors[agent]) for agent in agent_list]
labels = agent_list
fig.legend(handles, labels, title="Agent", loc='lower center', ncol=len(agent_list), bbox_to_anchor=(0.5, -0.02))

plt.suptitle("Agent Performance per Experiment Group", fontsize=16, y=1.04)
plt.tight_layout()
plt.subplots_adjust(bottom=0.18)  
plt.show()



In [None]:

df[['delay', 'difficulty', 'mode_name', 'agent', 'mean_return', 'std_return']]


In [None]:
from tabulate import tabulate

def print_training_stats(
    mean_len, mean_rew, mean_entropy, adv_mean, adv_std, mean_aux,
    aux_metrics_log, fps, episodes, elapsed, steps,
    loss, policy_loss, value_loss, explained_var
):
    # Gather rollout stats
    rollout_stats = [
        ["ep_len_mean",      f"{mean_len:8.2f}"],
        ["ep_rew_mean",      f"{mean_rew:8.2f}"],
        ["policy_entropy",   f"{mean_entropy:8.3f}"],
        ["advantage_mean",   f"{adv_mean:8.3f}"],
        ["advantage_std",    f"{adv_std:8.3f}"],
        ["aux_loss_mean",    f"{mean_aux:8.3f}"],
    ]

    # Add aux metrics
    for aux_name, metrics in aux_metrics_log.items():
        for k, v in metrics.items():
            rollout_stats.append([f"aux_{aux_name}_{k}", f"{v:8.4f}"])

    # Time stats
    time_stats = [
        ["fps",            f"{fps:8d}"],
        ["episodes",       f"{episodes:8d}"],
        ["time_elapsed",   f"{elapsed:8.1f}"],
        ["total_timesteps",f"{steps:8d}"],
    ]
    """
    # Training stats
    train_stats = [
        ["loss",              f"{loss.item():8.3f}"],
        ["policy_loss",       f"{policy_loss.item():8.3f}"],
        ["value_loss",        f"{value_loss.item():8.3f}"],
        ["explained_variance",f"{explained_var.item():8.3f}"],
        ["n_updates",         f"{episodes:8d}"],
    ]
    """
    print("\n" + "="*20 + " ROLLOUT " + "="*20)
    print(tabulate(rollout_stats, headers=["Metric", "Value"], tablefmt="github"))
    print("\n" + "="*20 + " TIME " + "="*22)
    print(tabulate(time_stats, headers=["Metric", "Value"], tablefmt="github"))
    print("\n" + "="*20 + " TRAIN " + "="*21)
    #print(tabulate(train_stats, headers=["Metric", "Value"], tablefmt="github"))
    print("="*50 + "\n")

In [None]:
def print_sb3_style_log(
    mean_len, mean_rew, mean_entropy, adv_mean, adv_std, mean_aux,
    aux_metrics_log, fps, episodes, elapsed, steps,
    loss, policy_loss, value_loss, explained_var
):
    print("-" * 38)
    print(f"| rollout/                 |")
    print(f"|    {'ep_len_mean':<21} | {mean_len:8.2f}")
    print(f"|    {'ep_rew_mean':<21} | {mean_rew:8.2f}")
    print(f"|    {'policy_entropy':<21} | {mean_entropy:8.3f}")
    print(f"|    {'advantage_mean':<21} | {adv_mean:8.3f}")
    print(f"|    {'advantage_std':<21} | {adv_std:8.3f}")
    print(f"|    {'aux_loss_mean':<21} | {mean_aux:8.3f}")
    # Print auxiliary metrics if present
    for aux_name, metrics in aux_metrics_log.items():
        for k, v in metrics.items():
            metric_name = f"aux_{aux_name}_{k}"
            print(f"|    {metric_name:<21} | {v:8.4f}")

    print(f"| time/                    |")
    print(f"|    {'fps':<21} | {fps:8d}")
    print(f"|    {'episodes':<21} | {episodes:8d}")
    print(f"|    {'time_elapsed':<21} | {elapsed:8.1f}")
    print(f"|    {'total_timesteps':<21} | {steps:8d}")

    print(f"| train/                   |")
    print(f"|    {'loss':<21} | {loss.item():8.3f}")
    print(f"|    {'policy_loss':<21} | {policy_loss.item():8.3f}")
    print(f"|    {'value_loss':<21} | {value_loss.item():8.3f}")
    print(f"|    {'explained_variance':<21} | {explained_var.item():8.3f}")
    print(f"|    {'n_updates':<21} | {episodes:8d}")
    print("-" * 38)

"""
--------------------------------------
| rollout/               |           |
|    ep_len_mean         |     8.00  |
|    ep_rew_mean         |     0.98  |
|    policy_entropy      |    0.563  |
|    advantage_mean      |   -1.623  |
|    advantage_std       |    0.229  |
|    aux_loss_mean       |    0.000  |
| time/                  |           |
|    fps                 |      281  |
|    episodes            |    61600  |
|    time_elapsed        |   1749.3  |
|    total_timesteps     |   492800  |
| train/                 |           |
|    loss                |  -14.634  |
|    policy_loss         |  -15.917  |
|    value_loss          |    2.679  |
|    explained_variance  |   -0.068  |
|    n_updates           |    61600  |
--------------------------------------
"""

In [None]:
def print_sb3_style_log_box(stats):
    # Flatten stats for max width calculation
    all_rows = []
    for section in stats:
        all_rows.append((section["header"] + "/", None, True))
        for k, v in section["stats"].items():
            all_rows.append((k, v, False))
    # Compute widths
    key_width = max(
        len("    " + k) if not is_section else len(k)
        for k, v, is_section in all_rows
    )
    val_width = 10
    box_width = 2 + key_width + 3 + val_width 

    def fmt_row(label, value, is_section):
        if is_section:
            return f"| {label:<{key_width}}|{' ' * val_width} |"
        else:
            # Format value: float = 8.3f, int = 8d, tensor fallback
            if hasattr(value, 'item'):
                value = value.item()
            if isinstance(value, float):
                s_value = f"{value:8.3f}"
            elif isinstance(value, int):
                s_value = f"{value:8d}"
            else:
                s_value = str(value)
            s_value_centered = f"{s_value:^{val_width}}"
            return f"|    {label:<{key_width-4}} |{s_value_centered} |"

    print("-" * box_width)
    for section in stats:
        print(fmt_row(section["header"] + "/", None, True))
        for k, v in section["stats"].items():
            print(fmt_row(k, v, False))
    print("-" * box_width)

# Example usage:
if __name__ == "__main__":
    import torch
    stats = [{
        "header":"rollout",
        "stats":dict(
            ep_len_mean=8.0,
            ep_rew_mean=0.98,
            policy_entropy=0.563,
            advantage_mean=-1.623,
            advantage_std=0.229,
            aux_loss_mean=0.0
        )},{
        "header":"time",
        "stats":dict(
            fps=281,
            episodes=61600,
            time_elapsed=1749.3,
            total_timesteps=492800,
        )},{
        "header":"train",
        "stats":dict(
            loss=torch.tensor(-14.634),
            policy_loss=torch.tensor(-15.917),
            value_loss=torch.tensor(2.679),
            explained_variance=torch.tensor(-0.068),
            n_updates=61600
        )}
    ]
    print_sb3_style_log_box(stats)


# Progress checklist

## Implemented Features

* [x] **GAE (Generalized Advantage Estimation)**
* [x] **Reward Normalization**
* [x] **Auxiliary Head (cue prediction)**
* [x] **Intrinsic State Novelty Bonus**
* [x] **HER (Hindsight Experience Replay)**
* [x] **Transformer Policy (sequence aware)**
* [x] **SB3-style Logging/Diagnostics**
* [x] **Clipped Surrogate Objective** *(PPO clip range)*
* [x] **Entropy Bonus Tuning**
* [x] **Value Function Clipping**
* [x] **Observation Normalization**
* [x] **Fine-grained Logging** *(KL, grad norm, action dist)*

---

## Next Tweaks

### 1. PPO Core Improvements

* [x] **Clipped Surrogate Objective:**
  (Apply the PPO `clip_range` to the policy probability ratio.)
* [x] **Entropy Bonus Tuning:**
  (Try different entropy coefficients.)

### 2. Policy/Value Stabilization

* [x] **Value Function Clipping:**
  (Clip value updates for more stability.)

### 3. Generalization & Robustness

* [x] **Observation Normalization:**
  (Normalize obs per feature dimension.)

### 4. Training Techniques

* [ ] **Curriculum Learning:**
  (Ramp up environment delay over training.)
* [ ] **KL Penalty / Early Stopping:**
  (Monitor KL divergence; adapt the learning rate or stop the update early.)
* [ ] **Multi-Task Heads:**
  (Add additional prediction heads, e.g., next-step, event marker.)
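
As a rough sketch of the first two items above, the helpers below ramp a delay linearly over training and flag when an approximate KL estimate exceeds a threshold. The function names, the default delays, `warmup_frac`, `target_kl`, and the 1.5 multiplier are all illustrative assumptions, not part of this notebook. How the delay is passed to `MemoryTaskEnv` depends on its constructor, which is not shown here.

```python
def curriculum_delay(step, total_steps, start_delay=2, final_delay=32, warmup_frac=0.8):
    """Linearly ramp the delay from start_delay to final_delay.

    The ramp finishes after warmup_frac of training; the delay then stays
    at final_delay. All defaults are hypothetical values for illustration.
    """
    if total_steps <= 0 or not 0.0 < warmup_frac <= 1.0:
        raise ValueError("total_steps must be > 0 and warmup_frac in (0, 1]")
    progress = min(1.0, max(0.0, step / (warmup_frac * total_steps)))
    return int(round(start_delay + progress * (final_delay - start_delay)))


def approx_kl(old_log_probs, new_log_probs):
    """Mean of (old_logp - new_logp) over the sampled actions.

    This is a simple sample-based estimate of KL(old || new); it can be
    slightly negative on small batches.
    """
    diffs = [o - n for o, n in zip(old_log_probs, new_log_probs)]
    return sum(diffs) / len(diffs)


def should_stop_early(old_log_probs, new_log_probs, target_kl=0.02, factor=1.5):
    """Return True when the estimated KL exceeds factor * target_kl (assumed heuristic)."""
    return approx_kl(old_log_probs, new_log_probs) > factor * target_kl
```

In a training loop, `curriculum_delay` would be queried before each environment is created or reset, and `should_stop_early` would be checked after each optimization epoch on the current rollout batch.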

### 5. Exploration/Intrinsic Motivation

* [ ] **Random Network Distillation (RND):**
  (Encourage exploring states with high prediction error.)
* [ ] **Parameter Noise:**
  (Add noise to weights for more diverse policies.)
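
One possible shape for the RND item above, written in plain Python so it runs without dependencies: a fixed random target network and a linear predictor trained to match it, with the squared prediction error used as the exploration bonus. `TinyRND`, its sizes, and its learning rate are hypothetical; in this notebook both networks would more naturally be small `nn.Module`s.

```python
import math
import random


class TinyRND:
    """Minimal Random Network Distillation sketch (illustrative, not tuned)."""

    def __init__(self, obs_dim, feat_dim=8, lr=0.05, seed=0):
        rng = random.Random(seed)
        # Fixed, randomly initialised target network (never trained)
        self.target = [[rng.gauss(0.0, 1.0) for _ in range(obs_dim)]
                       for _ in range(feat_dim)]
        # Trainable linear predictor that tries to imitate the target
        self.predictor = [[rng.gauss(0.0, 0.1) for _ in range(obs_dim)]
                          for _ in range(feat_dim)]
        self.lr = lr

    def _target_features(self, obs):
        return [math.tanh(sum(w * x for w, x in zip(row, obs)))
                for row in self.target]

    def _predicted_features(self, obs):
        return [sum(w * x for w, x in zip(row, obs)) for row in self.predictor]

    def bonus(self, obs):
        """Intrinsic reward: summed squared error between predictor and target."""
        t = self._target_features(obs)
        p = self._predicted_features(obs)
        return sum((pi - ti) ** 2 for pi, ti in zip(p, t))

    def update(self, obs):
        """One gradient step on the summed squared error for this observation."""
        t = self._target_features(obs)
        p = self._predicted_features(obs)
        for i, row in enumerate(self.predictor):
            err = p[i] - t[i]
            for j, x in enumerate(obs):
                row[j] -= self.lr * 2.0 * err * x


rnd = TinyRND(obs_dim=4)
seen = [1.0, 0.0, 0.0, 1.0]
novel = [0.0, 1.0, 1.0, 0.0]
bonus_before = rnd.bonus(seen)
for _ in range(50):
    rnd.update(seen)
bonus_after = rnd.bonus(seen)
```

States that are visited often end up with a near-zero bonus, while unseen states keep a larger one. The bonus would typically be scaled by a small coefficient before being added to the environment reward.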

### 6. Diagnostics/Analysis

* [ ] **Distributional RL Head:**
  (Predict return distribution, not just mean.)

---

## Real-World Stock/Market Adaptation

1. [ ] **Extend Observation/Action Spaces:**
   (Use richer features: price, volume, meta-events, etc.)
2. [ ] **Contextualize "Goal"/HER:**
   (E.g., relabel episodes based on "target" events or future outcomes.)
3. [ ] **Event/Pattern Detection Head:**
   (Auxiliary head for predicting/flagging event boundaries.)
4. [ ] **Longer Trajectories & Robust Memory:**
   (Train with variable or longer delays. Use a curriculum.)
5. [ ] **Batch Training/Efficient Sampling:**
   (Switch to multi-episode minibatches for efficiency.)
6. [ ] **Backtesting & Market Evaluation:**
   (Test generalization to unseen market regimes/periods.)
7. [ ] **Regime-Aware Curriculum:**
   (Adapt delay, noise, or event frequency according to market regime.)
8. [ ] **Add Macro/External Signals:**
   (Economic indicators, sector events, etc.)
9. [ ] **Production Robustness:**
   (Safe action constraints, stable inference, memory persistence.)


