# Motif Mining / Combined Memory Module

### What is a Motif (in time series/finance context):
A motif is a short, distinctive, recurring pattern in time series data (e.g., “three rising candles”, “hammer candlestick”).

Motif mining discovers these frequently recurring patterns, possibly ones that precede an important event (such as a price spike).

Motifs can have a fixed length (e.g., always 3 bars) or a variable length. They are usually not tied to any particular "agent experience" or reward: they are simply patterns that are statistically common or relevant to outcomes.
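One common way to find fixed-length motifs is the matrix profile (listed again in the comparison below). For every window it records the z-normalised Euclidean distance to the window's nearest non-overlapping neighbour; the smallest entries mark the top motif pair. The following is a minimal brute-force NumPy sketch on a synthetic series. The helper names, the planted pattern and the noise level are illustrative assumptions. Dedicated matrix-profile libraries compute the same quantity far faster.

```python
import numpy as np

def znorm(x: np.ndarray) -> np.ndarray:
    """Z-normalise a window so motifs are compared by shape, not level or scale."""
    std = x.std()
    return (x - x.mean()) / std if std > 1e-8 else np.zeros_like(x)

def naive_matrix_profile(series: np.ndarray, m: int):
    """Brute-force matrix profile, O(n^2 * m): for each length-m window, the
    z-normalised Euclidean distance to (and index of) its nearest neighbour,
    ignoring trivial matches that overlap the window itself."""
    n = len(series) - m + 1
    windows = np.array([znorm(series[i:i + m]) for i in range(n)])
    excl = m // 2  # exclusion zone around each window
    profile = np.full(n, np.inf)
    index = np.zeros(n, dtype=int)
    for i in range(n):
        d = np.linalg.norm(windows - windows[i], axis=1)
        d[max(0, i - excl):i + excl + 1] = np.inf
        j = int(np.argmin(d))
        profile[i], index[i] = d[j], j
    return profile, index

# Synthetic "price" series: small noise with a rise-then-drop shape planted twice
rng = np.random.default_rng(0)
series = rng.normal(0.0, 0.3, 80)
pattern = 3.0 * np.array([0, 1, 2, 3, 4, 5, 2, 0])
series[10:18] += pattern
series[50:58] += pattern

m = 8
profile, index = naive_matrix_profile(series, m)
best = int(np.argmin(profile))  # the top motif pair is (best, index[best])
print(best, index[best])        # two windows 40 bars apart, on the planted copies
```

Low profile values flag motif candidates. Z-normalisation makes the match depend on shape only, which suits candlestick-style patterns that recur at different price levels.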

### How does this differ from the agent’s current memory?
Motifs could become part of memory if they prove useful, but in RL, memory entries are scored and selected by usefulness to the agent’s task, not just frequency.

* **Motif:**
  * **Purely pattern-based:** “What sequence shapes show up often in the market?”
  * **Unsupervised:** Does not depend on what the agent did or the rewards/outcomes.
  * Often discovered with algorithms like matrix profile, SAX, or clustering over subsequences.

* **Strategic RL Memory:**
  * Stores sequences of observations, actions, rewards from actual episodes, tied to what the agent did and what outcome it got.
  * Is used for retrieval during decision-making, not just for pattern mining.
  * Memory can be trained to only keep those episodes/patterns that are useful for policy improvement, not just frequent.

* **Summary:**
  * **Motif:** Statistically recurring pattern in the world.
  * **Memory:** The agent’s own experienced or retained pattern, which it can choose to use, forget, or score for future use.


### Goal:

* Combine both kinds of retrieval (episodic and motif) in a single process.

### Summary:
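* All memory retrieval (episodic and motif) is neural, attention-based, and trainable.

* Motif memory can be populated by unsupervised offline mining (e.g. DTW-based, or the k-means clustering of encoded windows used in `mine_motifs_from_buffer` below) or learned end to end as trainable pattern embeddings.

* The retrieval path is differentiable and can be trained jointly with the RL and auxiliary losses. Stored episodic embeddings are detached, so gradients reach the encoders that build retrieval queries, the motif bank, and the usefulness scores, but not the stored entries.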
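Both memory stores in this notebook read memory the same way. They score every stored embedding against a query with a dot product, softmax the scores, and return the attention-weighted sum. Below is a minimal standalone sketch of that readout. The function name is illustrative, not part of the code below. The query is aligned with one entry, so the weights come out nearly one-hot.

```python
import torch

def soft_attention_readout(keys: torch.Tensor, query: torch.Tensor):
    """Dot-product attention over a memory bank.

    keys:  [N, d] stored embeddings (episodic entries or motif prototypes)
    query: [d]    encoded context
    Returns the readout [d] (a convex combination of the keys) and weights [N].
    """
    logits = keys @ query                # [N] similarity scores
    attn = torch.softmax(logits, dim=0)  # [N] non-negative, sums to 1
    readout = attn @ keys                # [d] weighted sum of keys
    return readout, attn

keys = torch.eye(5, 8, requires_grad=True)  # 5 orthogonal toy memory entries
query = 10 * keys[2].detach()               # a query strongly aligned with entry 2
readout, attn = soft_attention_readout(keys, query)
readout.sum().backward()                    # gradients reach the memory bank
print(attn.argmax().item(), keys.grad is not None)  # 2 True
```

Because the readout is a soft mixture rather than a hard lookup, gradients flow into both the query and the stored keys. This is what lets the motif bank be learned end to end.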





In [10]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np

sys.path.append('../')
from environments import MemoryTaskEnv
from memory import StrategicMemoryBuffer, BaseMemoryBuffer,StrategicMemoryTransformerPolicy
from agent import TraceRL

In [11]:
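import torch
import torch.nn as nn
import numpy as np

# Default device for memory tensors. DEFAULT_DEVICE is used below but not defined
# elsewhere in this notebook; "cpu" is assumed here.
DEFAULT_DEVICE = "cpu"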

class BaseMemoryBuffer(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def get_last_attention(self):
        """Return the latest attention weights, if available."""
        raise NotImplementedError("Memory module must implement get_last_attention()")

class StrategicMemoryBuffer(BaseMemoryBuffer):
    """
    Episodic memory buffer for RL agents using neural trajectory encoding and
    attention-based retrieval, with learnable usefulness/retention scores.

    Features:
        - Stores episode trajectories and outcomes.
        - Each trajectory is encoded as a vector with a Transformer.
        - Each memory entry gets a trainable usefulness parameter.
        - When full, discards the least-useful entry.
        - Returns soft-attended memory readout given a context trajectory.

    Args:
        obs_dim (int): Observation dimension.
        action_dim (int): Action dimension (scalar=1).
        mem_dim (int): Embedding size for memory entries.
        max_entries (int): Max entries to keep; when full, the least-useful slot is overwritten.
        device (str): Device for tensors ("cpu" or "cuda").
    """

    __version__     = "1.3.0"
    __description__ = "Memory usefulness is now trainable and retention is based on it"


    def __init__(self, obs_dim, action_dim, mem_dim=32, max_entries=100, device=DEFAULT_DEVICE):
        super().__init__()
      
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.mem_dim = mem_dim
        self.max_entries = max_entries
        self.device = device

        self.embedding_proj = nn.Linear(obs_dim + action_dim + 1, mem_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=mem_dim, nhead=2, batch_first=True),
            num_layers=1
        )

        # One learnable usefulness score per slot (size max_entries)
        self.usefulness_vec = nn.Parameter(torch.zeros(max_entries, device=self.device), requires_grad=True)
        self.entries = []
        self._entry_indices = []  # Maps entries to slot indices
        self.last_attn = None     # Attention weights from the most recent retrieval

    def reset(self):
        """Clears all memory entries (keeps usefulness param, but marks buffer empty)."""
        self.entries = []
        self._entry_indices = []
        self.last_attn = None

    def add_entry(self, trajectory, outcome):
        """
        Stores a new trajectory/outcome. If full, replaces the least useful slot.
        """
        traj_np = np.array([np.concatenate([obs, [action], [reward]]) for obs, action, reward in trajectory], dtype=np.float32)
        traj = torch.from_numpy(traj_np).to(self.device)
        with torch.no_grad():  # stored embeddings are detached anyway; skip building a graph
            traj_proj = self.embedding_proj(traj)
            mem_embed = self.encoder(traj_proj.unsqueeze(0)).mean(dim=1).squeeze(0)
        entry = {
            'trajectory': trajectory,
            'outcome': outcome,
            'embedding': mem_embed
        }
        if len(self.entries) < self.max_entries:
            # Use next available slot (by index)
            slot_idx = len(self.entries)
            self.entries.append(entry)
            self._entry_indices.append(slot_idx)
        else:
            # Overwrite least-useful slot
            idx_remove = self.usefulness_vec.detach().argmin().item()
            self.entries[idx_remove] = entry
            self._entry_indices[idx_remove] = idx_remove
            with torch.no_grad():
                # Start the new entry at the average score instead of inheriting the
                # evicted entry's minimum (which would make it the next one evicted)
                self.usefulness_vec[idx_remove] = self.usefulness_vec.mean()

    def retrieve(self, context_trajectory):
        if len(self.entries) == 0:
            self.last_attn = None
            return torch.zeros(self.mem_dim, device=self.device), None, None
        traj_np = np.array([np.concatenate([obs, [action], [reward]]) for obs, action, reward in context_trajectory], dtype=np.float32)
        traj = torch.from_numpy(traj_np).to(self.device)
        traj_proj = self.embedding_proj(traj)
        context_embed = self.encoder(traj_proj.unsqueeze(0)).mean(dim=1).squeeze(0)
        mem_embeddings = torch.stack([e['embedding'] for e in self.entries])  # [N, mem_dim]
        attn_logits = torch.matmul(mem_embeddings, context_embed)
        attn = torch.softmax(attn_logits, dim=0)
        mem_readout = (attn.unsqueeze(1) * mem_embeddings).sum(dim=0)
        self.last_attn = attn.detach().cpu().numpy()
        return mem_readout, attn, None  # (readout, episodic attn, motif attn): no motif attention here

    def usefulness_loss(self, attn, reward):
        """
        Compute usefulness loss: encourages usefulness score to match (attn * reward).
        Only updates the active slots (those with entries).
        """
        # attn: [N] attention weights used for the memory readout
        # reward: scalar, or [N]
        N = len(self.entries)
        if N == 0:
            return torch.tensor(0.0, device=self.device)
        idxs = self._entry_indices[:N]
        usefulness_vec = self.usefulness_vec[idxs]  # Only the in-use slots
        if isinstance(reward, (float, int)):
            reward = torch.tensor(reward, dtype=torch.float32, device=self.device)
        targets = attn.detach() * reward  # [N]
        loss = ((usefulness_vec - targets) ** 2).mean()
        return loss

    def usefulness_parameters(self):
        return [self.usefulness_vec]

    def get_trainable_parameters(self):
        # self.parameters() already includes usefulness_vec; adding it again would duplicate it
        return list(self.parameters())

    def get_last_attention(self):
        return getattr(self, "last_attn", None)
        
    def retrieve_with_custom_query(self, custom_query):
        """
        Attend over memory entries using a custom query vector (not trajectory context).
        """
        if len(self.entries) == 0:
            self.last_attn = None
            return torch.zeros(self.mem_dim, device=self.device), None
        mem_embeddings = torch.stack([e['embedding'] for e in self.entries])  # [N, mem_dim]
        attn_logits = torch.matmul(mem_embeddings, custom_query)
        attn = torch.softmax(attn_logits, dim=0)
        mem_readout = (attn.unsqueeze(1) * mem_embeddings).sum(dim=0)
        self.last_attn = attn.detach().cpu().numpy()
        return mem_readout, attn
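The retention rule can be checked in isolation. `usefulness_loss` regresses each slot's score toward `attn * reward`, and `add_entry` overwrites the `argmin` slot once the buffer is full. The toy sketch below uses made-up attention weights and return, and plain SGD instead of the agent's Adam optimizer. It shows the scores converging to those targets and the least-attended slot becoming the eviction candidate.

```python
import torch

# Hypothetical toy setting: 3 occupied slots and one episode's statistics
attn = torch.tensor([0.7, 0.2, 0.1])             # attention the episode paid to each slot
episode_return = 1.0                              # total (scalar) reward of that episode
usefulness = torch.zeros(3, requires_grad=True)   # one learnable score per slot
opt = torch.optim.SGD([usefulness], lr=0.5)

for _ in range(200):
    targets = attn * episode_return               # same target as usefulness_loss
    loss = ((usefulness - targets) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

evict = int(usefulness.detach().argmin())  # slot add_entry would overwrite when full
print(usefulness.detach(), evict)
```

Because the targets scale with the episode return, a negative return pushes heavily attended slots toward eviction rather than protecting them.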

In [12]:
class MotifMemoryBank(BaseMemoryBuffer):
    """
    Motif memory: learnable bank of pattern embeddings, attention-retrieved.

    Features:
        - Stores K motif embeddings, trainable.
        - Neural encoder to encode subtrajectories as motifs.
        - Attention over motifs given current context trajectory.
    """
    def __init__(self, obs_dim, action_dim, mem_dim=32, n_motifs=32, motif_len=4, device='cpu'):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.mem_dim = mem_dim
        self.n_motifs = n_motifs
        self.motif_len = motif_len
        self.device = device
        self.last_attn = None
        # Learnable motif memory bank (one embedding per motif), created on the target device
        self.motif_embeds = nn.Parameter(torch.randn(n_motifs, mem_dim, device=device))
        # Neural encoder for extracting motifs from subtrajectories
        self.embedding_proj = nn.Linear(obs_dim + action_dim + 1, mem_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=mem_dim, nhead=2, batch_first=True),
            num_layers=1
        )

    def retrieve(self, context_traj):
        if len(context_traj) < self.motif_len:
            pad = [context_traj[0]] * (self.motif_len - len(context_traj))
            motif_traj = pad + context_traj
        else:
            motif_traj = context_traj[-self.motif_len:]

        motif_np = np.array([np.concatenate([obs, [a], [r]]) for obs, a, r in motif_traj], dtype=np.float32)
        motif_input = torch.from_numpy(motif_np).unsqueeze(0).to(self.device)
        motif_embed = self.encoder(self.embedding_proj(motif_input)).mean(dim=1).squeeze(0)  # [mem_dim]
        attn_logits = torch.matmul(self.motif_embeds, motif_embed)
        attn = torch.softmax(attn_logits, dim=0)
        motif_readout = (attn.unsqueeze(1) * self.motif_embeds).sum(dim=0)
        self.last_attn = attn.detach().cpu().numpy()
        return motif_readout, None, attn  # (readout, episodic attn, motif attn): no episodic attention here

    def motif_parameters(self):
        return [self.motif_embeds]

    def get_trainable_parameters(self):
        # self.parameters() already includes motif_embeds; adding it again would duplicate it
        return list(self.parameters())

    def get_last_attention(self):
        return self.last_attn
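`MotifMemoryBank.retrieve` always encodes exactly `motif_len` steps. It keeps the most recent steps or, for a short context, left-pads by repeating the first step. Below is a small sketch of that windowing and of the `[motif_len, obs_dim + 2]` step matrix fed to `embedding_proj`. The helper names are hypothetical, and `action_dim=1` is assumed as in the notebook.

```python
import numpy as np

def last_window(traj, motif_len):
    """Return exactly motif_len steps: the most recent ones or, for short
    trajectories, the first step repeated on the left (as in MotifMemoryBank.retrieve)."""
    if len(traj) < motif_len:
        return [traj[0]] * (motif_len - len(traj)) + list(traj)
    return list(traj[-motif_len:])

def to_step_matrix(window):
    """Stack (obs, action, reward) steps into a [motif_len, obs_dim + 2] float32 array."""
    return np.array([np.concatenate([obs, [a], [r]]) for obs, a, r in window], dtype=np.float32)

# Two-step trajectory with 3-dim observations, padded to a 4-step motif window
traj = [(np.zeros(3), 0, 0.0), (np.ones(3), 1, 0.5)]
window = last_window(traj, motif_len=4)
X = to_step_matrix(window)
print(X.shape)  # (4, 5)
```

Note that `traj[0]` requires a non-empty context. The policy only calls memory retrieval once at least one action has been taken, so the context always has at least two steps there.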

In [13]:
class CombinedMemoryModule(BaseMemoryBuffer):
    def __init__(self, episodic_buffer, motif_bank):
        super().__init__()
        self.episodic_buffer = episodic_buffer
        self.motif_bank = motif_bank
        self.last_attn = None


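    def retrieve(self, context_trajectory):
        # MotifMemoryBank.retrieve returns (readout, None, motif_attn)
        motif_readout, _, motif_attn = self.motif_bank.retrieve(context_trajectory)
        # The motif readout is the query into episodic memory; this call returns (readout, attn)
        epi_readout, epi_attn = self.episodic_buffer.retrieve_with_custom_query(motif_readout)
        combined = torch.cat([epi_readout, motif_readout], dim=-1)  # [2 * mem_dim]: episodic first, motif second
        self.last_attn = (epi_attn, motif_attn)
        return combined, epi_attn, motif_attn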

    def add_entry(self, trajectory, outcome):
        self.episodic_buffer.add_entry(trajectory, outcome)
        # Only the episodic buffer stores new experience; the motif bank is refreshed
        # separately (e.g. by periodic mining). If motifs should also learn from each episode,
        # define MotifMemoryBank.add_entry and call self.motif_bank.add_entry(trajectory, outcome) here.

    def get_trainable_parameters(self):
        params = []
        if hasattr(self, "episodic_buffer"):
            params += self.episodic_buffer.get_trainable_parameters()
        if hasattr(self, "motif_bank"):
            params += self.motif_bank.get_trainable_parameters()
        return params

    def get_last_attention(self):
        return self.last_attn  # tuple: (episodic, motif)
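The data flow of `CombinedMemoryModule.retrieve` can be sketched with plain tensors. Motif attention is driven by the encoded context, episodic attention is driven by the motif readout, and the two readouts are concatenated. The function below is a hypothetical stand-in, with random tensors in place of the encoders and stored entries. It checks the `2 * mem_dim` output contract the policy relies on, including the empty-buffer case.

```python
import torch

def combined_retrieve(motif_bank, episodic_keys, context_embed):
    """Two-stage retrieval sketch (hypothetical tensors, not the classes above):
    1) attend over motif prototypes with the encoded context,
    2) use the motif readout as the query into episodic memory,
    3) concatenate [episodic, motif] -> 2 * mem_dim features for the policy."""
    motif_attn = torch.softmax(motif_bank @ context_embed, dim=0)
    motif_readout = motif_attn @ motif_bank
    if episodic_keys.shape[0] == 0:  # empty buffer -> zero episodic readout
        epi_readout, epi_attn = torch.zeros_like(motif_readout), None
    else:
        epi_attn = torch.softmax(episodic_keys @ motif_readout, dim=0)
        epi_readout = epi_attn @ episodic_keys
    return torch.cat([epi_readout, motif_readout], dim=-1), epi_attn, motif_attn

mem_dim = 32
torch.manual_seed(0)
combined, epi_attn, motif_attn = combined_retrieve(
    torch.randn(16, mem_dim), torch.randn(10, mem_dim), torch.randn(mem_dim)
)
print(combined.shape)  # torch.Size([64])
```

Keeping the episodic half first and the motif half second matters, because the policy's fallback padding assumes exactly that order.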


In [14]:
class StrategicCombinedMemoryPolicy(nn.Module):
    def __init__(self, obs_dim, mem_dim=32, nhead=4, memory=None, aux_modules=None, **kwargs):
        super().__init__()
        self.mem_dim = mem_dim
        self.embed = nn.Linear(obs_dim, mem_dim)
        self.pos_embed = nn.Embedding(256, mem_dim)  # positional table: episodes of up to 256 steps
        encoder_layer = nn.TransformerEncoderLayer(d_model=mem_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # Heads see [policy feature (mem_dim), episodic readout (mem_dim), motif readout (mem_dim)];
        # the policy head assumes 2 discrete actions
        self.policy_head = nn.Linear(mem_dim + 2 * mem_dim, 2)
        self.value_head = nn.Linear(mem_dim + 2 * mem_dim, 1)
        self.aux_modules = aux_modules if aux_modules is not None else []
        self.memory = memory

    def forward(self, trajectory, obs_t=None, actions=None, rewards=None):
        T = trajectory.shape[0]
        x = self.embed(trajectory)
        pos = torch.arange(T, device=trajectory.device)
        x = x + self.pos_embed(pos)
        x = x.unsqueeze(0)
        x = self.transformer(x)
        feat = x[0, -1]

        mem_feat = torch.zeros(2 * self.mem_dim, device=feat.device)
        epi_attn, motif_attn = None, None
        if self.memory is not None and actions is not None and rewards is not None:
            actions_list = actions.tolist()
            rewards_list = rewards.tolist()
            if len(actions_list) < T:
                actions_list = [0] * (T - len(actions_list)) + actions_list
            if len(rewards_list) < T:
                rewards_list = [0.0] * (T - len(rewards_list)) + rewards_list
            context_traj = [
                (trajectory[i].cpu().numpy(), actions_list[i], rewards_list[i]) for i in range(T)
            ]
            mem_raw, epi_attn, motif_attn = self.memory.retrieve(context_traj)
            # Robust: shape of mem_raw
            if mem_raw.shape[0] == 2 * self.mem_dim:
                mem_feat = mem_raw
            elif mem_raw.shape[0] == self.mem_dim:
                # Pad: put episodic first, motif second (or vice versa, as you wish)
                # Here: [episodic, motif], pad motif with zeros if only episodic is present
                if epi_attn is not None and motif_attn is None:
                    mem_feat = torch.cat([mem_raw, torch.zeros(self.mem_dim, device=mem_raw.device)], dim=0)
                elif motif_attn is not None and epi_attn is None:
                    mem_feat = torch.cat([torch.zeros(self.mem_dim, device=mem_raw.device), mem_raw], dim=0)
                else:
                    # fallback: both None, just zeros
                    mem_feat = torch.zeros(2 * self.mem_dim, device=mem_raw.device)
        final_feat = torch.cat([feat, mem_feat], dim=-1)
        logits = self.policy_head(final_feat)
        value = self.value_head(final_feat)
        aux_preds = {}
        for aux in self.aux_modules:
            aux_preds[aux.name] = aux.head(final_feat)
        return logits, value.squeeze(-1), aux_preds
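`forward` rebuilds the memory context from the episode so far. Each observation is paired with the action and reward that preceded it, and the first step is left-padded with `(0, 0.0)`, matching how `run_episode` fills `context_traj`. Below is a standalone sketch of that alignment; `build_context` is a hypothetical helper.

```python
import numpy as np

def build_context(obs_seq, actions, rewards):
    """Pair each observation with the action/reward that preceded it, left-padding
    the first step with (0, 0.0), mirroring StrategicCombinedMemoryPolicy.forward."""
    T = len(obs_seq)
    acts = [0] * (T - len(actions)) + list(actions)
    rews = [0.0] * (T - len(rewards)) + list(rewards)
    return [(obs_seq[i], acts[i], rews[i]) for i in range(T)]

obs_seq = [np.array([float(i)]) for i in range(3)]  # o0, o1, o2
ctx = build_context(obs_seq, actions=[1, 0], rewards=[0.5, 1.0])
print([(o.item(), a, r) for o, a, r in ctx])
# [(0.0, 0, 0.0), (1.0, 1, 0.5), (2.0, 0, 1.0)]
```

The heads then receive `mem_dim + 2 * mem_dim` features. Memory is skipped on the first step of an episode, because `actions` and `rewards` are still `None` there.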


In [15]:
from sklearn.cluster import KMeans

def mine_motifs_from_buffer(episodic_buffer, motif_bank, motif_len=4, n_motifs=32, min_windows=100):
    """
    Mine most frequent motifs from episodic buffer and refresh the motif bank.
    - motif_bank: instance of MotifMemoryBank
    - episodic_buffer: should have .entries (each has 'trajectory')
    """
    subtraj_embeds = []
    for entry in episodic_buffer.entries:
        traj = entry['trajectory']
        if len(traj) >= motif_len:
            for i in range(len(traj) - motif_len + 1):
                window = traj[i:i+motif_len]
                # Convert window to tensor and embed
                window_np = np.array([np.concatenate([obs, [a], [r]]) for obs, a, r in window], dtype=np.float32)
                window_tensor = torch.from_numpy(window_np).unsqueeze(0).to(motif_bank.device)
                with torch.no_grad():
                    embed = motif_bank.encoder(motif_bank.embedding_proj(window_tensor)).mean(dim=1).squeeze(0).cpu().numpy()
                subtraj_embeds.append(embed)
    if len(subtraj_embeds) < max(n_motifs, min_windows):
        print(f"Motif mining: not enough motif windows ({len(subtraj_embeds)}) for {n_motifs} motifs.")
        return
    X = np.stack(subtraj_embeds)
    kmeans = KMeans(n_clusters=n_motifs, random_state=0, n_init="auto").fit(X)
    centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=motif_bank.device)
    with torch.no_grad():
        motif_bank.motif_embeds.copy_(centroids)
    print(f"[Motif mining] Updated motif bank with {n_motifs} clusters from buffer.")
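The mining step is sliding windows followed by k-means. The sketch below runs the same pipeline on a raw toy series, with two stated simplifications. First, windows are used directly instead of encoder embeddings. Second, a plain NumPy Lloyd's k-means initialised from the first `k` windows replaces scikit-learn's `KMeans`, which defaults to k-means++ initialisation.

```python
import numpy as np

def sliding_windows(series, motif_len):
    """All overlapping windows of length motif_len: shape [n - motif_len + 1, motif_len]."""
    return np.lib.stride_tricks.sliding_window_view(series, motif_len)

def kmeans(X, k, n_iter=50):
    """Plain Lloyd's k-means, deterministically initialised from the first k rows."""
    centroids = X[:k].astype(float)  # astype copies, so the read-only view is untouched
    for _ in range(n_iter):
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)  # [n, k]
        labels = dists.argmin(axis=1)
        for j in range(k):
            members = X[labels == j]
            if len(members):  # keep the old centroid if a cluster empties
                centroids[j] = members.mean(axis=0)
    return centroids, labels

# Toy series: a repeating up-up-down pattern, windowed and clustered into motifs
series = np.tile([0.0, 1.0, 2.0, 1.0], 25)
X = sliding_windows(series, motif_len=4)  # raw windows instead of encoder embeddings
centroids, labels = kmeans(X, k=4)
print(X.shape, centroids.shape)  # (97, 4) (4, 4)
```

On this periodic series, the four centroids are exactly the four phases of the repeating pattern. On real data, each centroid is an average of similar windows rather than a specific observed window.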


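The training cell below calls `compute_gae` from `core_calculations`, which is not shown here. Assuming it implements standard Generalised Advantage Estimation, the recursion looks like the sketch below; `gae` is a hypothetical re-implementation. With `lam=1.0`, `advantages + values` reduces to the discounted Monte Carlo return. This gives a useful sanity check for the `returns` target in `learn()`.

```python
import torch

def gae(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Generalised Advantage Estimation for one episode:
    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t);  A_t = delta_t + gamma * lam * A_{t+1}."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    next_value, running = last_value, 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages

rewards = torch.tensor([0.0, 0.0, 1.0])
values = torch.tensor([0.5, 0.5, 0.5])
adv = gae(rewards, values)
returns = adv + values  # regression targets for the value head, as in learn()
```

Passing detached values (for example `values_t.detach()`) keeps value-head gradients out of the advantages and the value targets.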
In [16]:
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tabulate import tabulate

from core_modules import RewardNormalizer, StateCounter, RNDModule
from core_calculations import compute_gae, compute_explained_variance
from callbacks import print_sb3_style_log_box
MINING_INTERVAL = 100  # re-mine motifs from the episodic buffer every N episodes

class TraceRL:
    """
    Actor-critic policy-gradient agent (GAE advantages, no PPO-style ratio clipping) with integrated external memory retrieval.

    Features:
        - Supports auxiliary losses, HER, reward normalization, and RND-based exploration.
        - Episodic or contextual memory (passed as `memory`) for strategic RL.
        - Plug-and-play auxiliary modules (e.g., cue, event, confidence).
        - Stable training with reward normalization and intrinsic/extrinsic reward mixing.

    Args:
        policy_class (nn.Module): Policy network class (should accept obs_dim, memory, aux_modules).
        env (gym.Env): Gymnasium environment.
        verbose (int): Logging verbosity (0 = silent, 1 = logs).
        learning_rate (float): Adam optimizer learning rate.
        gamma (float): Discount factor.
        lam (float): GAE lambda.
        device (str): Torch device.
        her (bool): Enable Hindsight Experience Replay (if supported by env).
        reward_norm (bool): Normalize reward with running stats.
        intrinsic_expl (bool): Use count-based intrinsic reward.
        intrinsic_eta (float): Scaling for intrinsic bonus.
        ent_coef (float): Entropy coefficient.
        memory: Memory module for contextual/episodic learning (optional).
        aux_modules (list): List of auxiliary task modules (optional).
        use_rnd (bool): Enable Random Network Distillation intrinsic reward.
        rnd_emb_dim (int): Embedding dim for RND networks.
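        rnd_lr (float): Learning rate for RND predictor.
        memory_learn_retention (bool): Train the episodic buffer's usefulness scores.
        memory_retention_coef (float): Weight of the usefulness loss in the total loss.
        early_stop (bool): Stop once recent episode rewards are high and stable.
        early_stop_n_samples (int): Window (in episodes) for the early-stop statistics.
        early_stop_mean_threshold (float): Mean reward that must be exceeded to stop.
        early_stop_std_threshold (float): Maximum reward std allowed to stop.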
    """


    __version__ = "1.4.0"

    def __init__(
        self, 
        policy_class, 
        env, 
        verbose=0,
        learning_rate=1e-3, 
        gamma=0.99, 
        lam=0.95, 
        ent_coef=0.01,
        device="cpu",
        her=False,
        reward_norm=False,
        intrinsic_expl=True,
        intrinsic_eta=0.01,
        memory=None,
        aux_modules=None,
        use_rnd=False, 
        rnd_emb_dim=32, 
        rnd_lr=1e-3,
        memory_learn_retention=False,      
        memory_retention_coef=0.01,
        early_stop=True,
        early_stop_n_samples=100,
        early_stop_mean_threshold=0.95,
        early_stop_std_threshold=0.05,
    ):
        self.env = env
        self.device = torch.device(device)
        self.gamma = gamma
        self.lam = lam
        self.ent_coef = ent_coef
        self.verbose = verbose
        self.memory = memory
        self.memory_learn_retention = memory_learn_retention
        self.memory_retention_coef = memory_retention_coef
        self.aux_modules = aux_modules if aux_modules is not None else []
        self.aux = len(self.aux_modules) > 0
        self.early_stop = early_stop
        self.early_stop_n_samples = early_stop_n_samples
        self.early_stop_mean_threshold = early_stop_mean_threshold
        self.early_stop_std_threshold = early_stop_std_threshold
        # Policy: must accept obs_dim, memory, aux_modules
        self.policy = policy_class(
            obs_dim=env.observation_space.shape[0], 
            memory=memory,
            aux_modules=self.aux_modules
        ).to(self.device)

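        # Optimizer: policy parameters (if the policy registers the memory as a submodule, its
        # parameters are already included) plus the usefulness scores when retention learning is on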

        params = list(self.policy.parameters())
        if self.memory_learn_retention and hasattr(self.memory, "usefulness_parameters"):
            params += list(self.memory.usefulness_parameters())
        params = list({id(p): p for p in params}.values())  # REMOVE DUPLICATES
        self.optimizer = torch.optim.Adam(params, lr=learning_rate)

        self.training_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []
        self.her = her
        self.reward_norm = reward_norm
        self.intrinsic_expl = intrinsic_expl
        self.intrinsic_eta = intrinsic_eta
        self.reward_normalizer = RewardNormalizer()
        self.state_counter = StateCounter()
        self.use_rnd = use_rnd
        if self.use_rnd:
            self.rnd = RNDModule(env.observation_space.shape[0], emb_dim=rnd_emb_dim).to(self.device)
            self.rnd_optimizer = torch.optim.Adam(self.rnd.predictor.parameters(), lr=rnd_lr)
        self.trajectory_buffer = []

    def reset_trajectory(self):
        self.trajectory_buffer = []

    def run_episode(self, her_target=None):
        obs, _ = self.env.reset()
        if her_target is not None:
            obs[0] = her_target

        done = False
        trajectory, actions, rewards, log_probs, values = [], [], [], [], []
        entropies_ep, aux_preds_list = [], []
        gate_history, memory_size_history = [], []
        attn_weights = None
        initial_cue = int(obs[0])
        aux_targets_ep = {aux.name: [] for aux in self.aux_modules}
        context_traj = []  # For memory module

        while not done:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
            trajectory.append(obs_t)
            traj = torch.stack(trajectory)
            action_for_mem = actions[-1].item() if len(actions) > 0 else 0
            reward_for_mem = rewards[-1].item() if len(rewards) > 0 else 0.0
            context_traj.append((obs_t.cpu().numpy(), action_for_mem, reward_for_mem))

            logits, value, aux_preds = self.policy(
                traj, obs_t,
                actions=torch.tensor([a.item() for a in actions], device=self.device) if actions else None,
                rewards=torch.tensor([r.item() for r in rewards], device=self.device) if rewards else None
            )
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
            obs, reward, terminated, truncated, _ = self.env.step(action.item())
            done = terminated or truncated  # Gymnasium reports episode end with two flags

            # Intrinsic reward: count-based and/or RND
            if self.intrinsic_expl:
                reward += self.intrinsic_eta * self.state_counter.intrinsic_reward(obs)
            rnd_intrinsic = 0.0
            if self.use_rnd:
                with torch.no_grad():
                    obs_rnd = obs_t.unsqueeze(0)
                    rnd_intrinsic = self.rnd(obs_rnd).item()
                    reward += self.intrinsic_eta * rnd_intrinsic

            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(torch.tensor(reward, dtype=torch.float32, device=self.device))
            values.append(value)
            entropies_ep.append(entropy)
            aux_preds_list.append(aux_preds)

            # Auxiliary targets (for supervised heads)
            for aux in self.aux_modules:
                if aux.name == "cue":
                    aux_targets_ep[aux.name].append(initial_cue)
                elif aux.name == "next_obs":
                    aux_targets_ep[aux.name].append(torch.tensor(obs, dtype=torch.float32))
                elif aux.name == "confidence":
                    # Heuristic target: 1 - policy entropy at this step (reuses `entropy`
                    # from above instead of rebuilding and shadowing `dist`/`entropy`)
                    confidence = 1.0 - entropy.item()
                    aux_targets_ep[aux.name].append(confidence)
                elif aux.name == "event":
                    event_flag = getattr(self.env, "event_flag", 0)
                    aux_targets_ep[aux.name].append(event_flag)
                elif aux.name == "oracle_action":
                    oracle_action = getattr(self.env, "oracle_action", None)
                    aux_targets_ep[aux.name].append(oracle_action)
                else:
                    aux_targets_ep[aux.name].append(0)

        # Store full trajectory in memory module (episodic buffer)
        if self.memory is not None:
            outcome = sum([r.item() for r in rewards])
            # Modular handling: always update episode buffer if available, else fallback
            
            if hasattr(self.memory, "episodic_buffer") and hasattr(self.memory.episodic_buffer, "add_entry"):
                self.memory.episodic_buffer.add_entry(context_traj, outcome)
            elif hasattr(self.memory, "add_entry"):
                self.memory.add_entry(context_traj, outcome)
            # Optionally: update motifs if needed (usually not online, but up to you)
            # if hasattr(self.memory, "motif_bank") and hasattr(self.memory.motif_bank, "add_entry"):
            #     self.memory.motif_bank.add_entry(context_traj, outcome)
        if self.memory is not None and hasattr(self.memory, 'get_last_attention'):
            attn_weights = self.memory.get_last_attention()

        # RND predictor update (only predictor trained)
        if self.use_rnd:
            obs_batch = torch.stack(trajectory)  # trajectory already holds float32 observation tensors on self.device
            rnd_loss = self.rnd(obs_batch).mean()
            self.rnd_optimizer.zero_grad()
            rnd_loss.backward()
            self.rnd_optimizer.step()

        return {
            "trajectory": trajectory,
            "actions": actions,
            "rewards": rewards,
            "log_probs": log_probs,
            "values": values,
            "entropies": entropies_ep,
            "aux_preds": aux_preds_list,
            "aux_targets": aux_targets_ep,
            "initial_cue": initial_cue,
            "gate_history": gate_history,
            "memory_size_history": memory_size_history,
            "attn_weights": attn_weights
        }

    def get_episodic_buffer(self):
        """Return the episodic buffer: memory.episodic_buffer if present, else the memory itself."""
        if self.memory is None:
            return None
        return getattr(self.memory, "episodic_buffer", self.memory)

    def learn(self, total_timesteps=2000, log_interval=100):
        steps = 0
        episodes = 0
        all_returns = []
        start_time = time.time()
        aux_losses = []
        unlock_early_stopping = len(self.episode_rewards)+self.early_stop_n_samples
        while steps < total_timesteps:
            try:
                #if hasattr(sys, 'last_traceback'):  # Quick hack: set by IPython on error/stop
                #    print("Interrupted in Jupyter (sys.last_traceback). Exiting.")
                #    break
                episode = self.run_episode()
                if self.reward_norm:
                    self.reward_normalizer.update([r.item() for r in episode["rewards"]])
                    episode["rewards"] = [
                        torch.tensor(rn, dtype=torch.float32, device=self.device)
                        for rn in self.reward_normalizer.normalize([r.item() for r in episode["rewards"]])
                    ]
    
                trajectory = episode["trajectory"]
                actions = episode["actions"]
                rewards = episode["rewards"]
                log_probs = episode["log_probs"]
                values = episode["values"]
                entropies_ep = episode["entropies"]
                aux_preds = episode["aux_preds"]
                aux_targets = episode["aux_targets"]
                T = len(rewards)
                rewards_t = torch.stack(rewards)
                values_t = torch.stack(values)
                log_probs_t = torch.stack(log_probs)
                actions_t = torch.stack(actions)
                last_value = 0.0
                advantages = compute_gae(rewards_t, values_t, gamma=self.gamma, lam=self.lam, last_value=last_value)
                # Detach both terms so the value target is a constant (no gradient through returns)
                returns = advantages.detach() + values_t.detach()
    
                policy_loss = -(log_probs_t * advantages.detach()).sum()
                value_loss = F.mse_loss(values_t, returns)
                entropy_mean = torch.stack(entropies_ep).mean()
                explained_var = compute_explained_variance(values_t, returns)
    
                # Auxiliary losses
                aux_loss_total = torch.tensor(0.0, device=self.device)
                aux_metrics_log = {}
                if self.aux:
                    for aux in self.aux_modules:
                        preds = torch.stack([ap[aux.name] for ap in aux_preds])
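                        raw_targets = aux_targets[aux.name]
                        if len(raw_targets) > 0 and torch.is_tensor(raw_targets[0]):
                            targets = torch.stack(raw_targets).to(self.device)  # e.g. next_obs vectors
                        else:
                            targets = torch.tensor(raw_targets, device=self.device)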
                        if preds.dim() != targets.dim():
                            targets = targets.squeeze(-1)
                        loss = aux.aux_loss(preds, targets)
                        aux_loss_total += loss
                        metrics = aux.aux_metrics(preds, targets)
                        aux_metrics_log[aux.name] = metrics
                    aux_losses.append(aux_loss_total.item())
    
                # Memory usefulness (if enabled) =====
                episodic_buffer = self.get_episodic_buffer()
                last_attn = (
                    episodic_buffer.get_last_attention()
                    if episodic_buffer is not None and hasattr(episodic_buffer, "get_last_attention")
                    else None
                )
                if (
                    self.memory_learn_retention
                    and last_attn is not None
                    and hasattr(episodic_buffer, "usefulness_loss")
                    # last_attn comes from the last retrieval, i.e. before this episode's
                    # add_entry; only use it while its length matches the stored entries
                    and len(last_attn) == len(episodic_buffer.entries)
                ):
                    total_reward = sum(r.item() for r in rewards)
                    attn_tensor = torch.tensor(last_attn, dtype=torch.float32, device=self.device)
                    mem_loss = episodic_buffer.usefulness_loss(attn_tensor, total_reward)
                else:
                    mem_loss = torch.tensor(0.0, device=self.device)
                    
    
                loss = (
                    policy_loss 
                    + 0.5 * value_loss 
                    + 0.1 * aux_loss_total 
                    - self.ent_coef * entropy_mean
                    + (self.memory_retention_coef * mem_loss if self.memory_learn_retention else 0.0)
                )
    
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
    
                total_reward = sum([r.item() for r in rewards])
                self.episode_rewards.append(total_reward)
                self.episode_lengths.append(T)
                episodes += 1
                steps += T

                if self.early_stop and len(self.episode_rewards) >= unlock_early_stopping:
                    mean_rew = np.mean(self.episode_rewards[-self.early_stop_n_samples:])
                    std_rew = np.std(self.episode_rewards[-self.early_stop_n_samples:])
                    if mean_rew > self.early_stop_mean_threshold and std_rew <= self.early_stop_std_threshold:
                        mean_len = np.mean(self.episode_lengths[-self.early_stop_n_samples:])
                        elapsed = int(time.time() - start_time)
                        ep_duration = elapsed/episodes
                        table = [
                                ["Train duration",f"{elapsed}s"],
                                ["Avg episode duration",f"{ep_duration:.2f}s"],
                                ["Rolling ep rew mean", f"{mean_rew:.2f}"],
                                ["Rolling ep rew std",f"{std_rew:.2f}"],
                                ["Rolling ep length",f"{mean_len:.2f}"],
                                ["N updates", episodes]]
                        
                        print(tabulate(table ,tablefmt="rounded_outline" , headers=["Early Stop",""]))
                        return

                if hasattr(self.memory, 'motif_bank') and hasattr(self, 'get_episodic_buffer'):
                    if episodes > 0 and episodes % MINING_INTERVAL == 0:
                        mine_motifs_from_buffer(
                            self.get_episodic_buffer(),
                            self.memory.motif_bank,
                            motif_len=self.memory.motif_bank.motif_len,
                            n_motifs=self.memory.motif_bank.n_motifs
                        )
                # LOGGING (SB3-STYLE) =====================
                if episodes % log_interval == 0 and self.verbose == 1:
                    elapsed = int(time.time() - start_time)
                    mean_rew = np.mean(self.episode_rewards[-log_interval:])
                    std_rew = np.std(self.episode_rewards[-log_interval:])
                    mean_len = np.mean(self.episode_lengths[-log_interval:])
                
                    fps = int(steps / (elapsed + 1e-8))
                    adv_mean = advantages.mean().item()
                    adv_std = advantages.std().item()
                    mean_entropy = entropy_mean.item()
                    mean_aux = np.mean(aux_losses[-log_interval:]) if aux_losses else 0.0
                    stats = [{
                        "header": "rollout",
                        "stats": dict(
                            ep_len_mean=mean_len,
                            ep_rew_mean=mean_rew,
                            ep_rew_std=std_rew,
                            policy_entropy=mean_entropy,
                            advantage_mean=adv_mean,
                            advantage_std=adv_std,
                            aux_loss_mean=mean_aux
                        )}, {
                        "header": "time",
                        "stats": dict(
                            fps=fps,
                            episodes=episodes,
                            time_elapsed=elapsed,
                            total_timesteps=steps
                        )}, {
                        "header": "train",
                        "stats": dict(
                            loss=loss.item(),
                            policy_loss=policy_loss.item(),
                            value_loss=value_loss.item(),
                            explained_variance=explained_var.item(),
                            n_updates=episodes,
                            progress=100 * steps / total_timesteps
                        )}
                    ]
                    if aux_metrics_log:
                        aux_stats = {
                            "header": "aux_train",
                            "stats": {}
                        }
                        for aux_name, metrics in aux_metrics_log.items():
                            for k, v in metrics.items():
                                aux_stats["stats"][f"aux_{aux_name}_{k}"] = v
                        stats.append(aux_stats)
                    if self.use_rnd:
                        # Logging only: keep the RND forward passes out of the autograd graph
                        with torch.no_grad():
                            mean_rnd_bonus = np.mean([
                                self.rnd(torch.as_tensor(np.array(o), dtype=torch.float32, device=self.device).unsqueeze(0)).item()
                                for o in trajectory
                            ])
                        stats.append({
                            "header": "rnd_net_dist",
                            "stats": {"mean_rnd_bonus": mean_rnd_bonus}
                        })
                    if self.memory_learn_retention:
                        stats.append({
                            "header": "memory",
                            "stats": {
                                "usefulness_loss": mem_loss.item()}
                        })
                    
                    print_sb3_style_log_box(stats)
                    
            except KeyboardInterrupt:
                print("\n[Stopped by user] Gracefully exiting training loop...")
                return
            
        if self.verbose == 1:
            print(f"Training complete. Total episodes: {episodes}, total steps: {steps}")

    def predict(self, obs, deterministic=False, done=False, reward=0.0):
        """
        Computes an action for a single observation, conditioning on the trajectory so far.

        Args:
            obs (np.ndarray): Environment observation.
            deterministic (bool): Use argmax instead of sampling.
            done (bool): If True, the trajectory buffer is cleared after acting.
            reward (float): Last received reward (stored for the memory context).

        Returns:
            int: Action index.
        """
        obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
        # Track the full trajectory as (observation, last action, last stored reward) tuples
        if getattr(self, "trajectory_buffer", None) is None:
            self.trajectory_buffer = []
        if len(self.trajectory_buffer) == 0:
            self.trajectory_buffer.append((obs_t.cpu().numpy(), 0, 0.0))
        else:
            last_action = getattr(self, "last_action", 0)
            last_reward = getattr(self, "last_reward", 0.0)
            self.trajectory_buffer.append((obs_t.cpu().numpy(), last_action, last_reward))
        context_traj = self.trajectory_buffer.copy()
        actions_int = [a for _, a, _ in context_traj]
        rewards_float = [r for _, _, r in context_traj]
        obs_stack = torch.stack([torch.tensor(o, dtype=torch.float32, device=self.device) for o, _, _ in context_traj])
        # Inference only: skip building the autograd graph
        with torch.no_grad():
            logits, _, _ = self.policy(
                obs_stack, obs_t,
                actions=torch.tensor(actions_int, device=self.device),
                rewards=torch.tensor(rewards_float, device=self.device)
            )
        if deterministic:
            action = torch.argmax(logits).item()
        else:
            action = Categorical(logits=logits).sample().item()
        self.last_action = action
        self.last_reward = reward
        if done:
            self.trajectory_buffer = []
        return action

    def save(self, path="memoryppo.pt"):
        """Save policy weights to file."""
        torch.save(self.policy.state_dict(), path)

    def load(self, path="memoryppo.pt"):
        """Load policy weights from file."""
        self.policy.load_state_dict(torch.load(path, map_location=self.device))

    def evaluate(self, n_episodes=10, deterministic=False, verbose=True):
        """
        Evaluates policy over several episodes, reporting mean/std return.

        Args:
            n_episodes (int): Number of test episodes.
            deterministic (bool): Use argmax instead of sampling.
            verbose (bool): Print results to console.

        Returns:
            mean_return (float): Average reward.
            std_return (float): Std deviation of rewards.
        """
        returns = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            self.trajectory_buffer = []
            done = False
            total_reward = 0.0
            last_reward = 0.0
            while not done:
                action = self.predict(obs, deterministic=deterministic, reward=last_reward)
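                # Gymnasium step(): (obs, reward, terminated, truncated, info)
                obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated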
                total_reward += reward
                last_reward = reward
                if done:
                    self.trajectory_buffer = []
            returns.append(total_reward)
        mean_return = np.mean(returns)
        std_return = np.std(returns)
        if verbose:
            print(f"Evaluation over {n_episodes} episodes: mean return {mean_return:.2f}, std {std_return:.2f}")
        return mean_return, std_return
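For reference, the update step in `learn` minimizes the weighted objective below, read directly off its `loss = (...)` expression. Here $c_{\text{ent}}$ is `ent_coef`, $c_{\text{mem}}$ is `memory_retention_coef`, $\bar{H}[\pi]$ is `entropy_mean`, and the indicator is `memory_learn_retention`:

$$
\mathcal{L} = \mathcal{L}_{\text{policy}} + 0.5\,\mathcal{L}_{\text{value}} + 0.1\,\mathcal{L}_{\text{aux}} - c_{\text{ent}}\,\bar{H}[\pi] + \mathbb{1}_{\text{retention}}\,c_{\text{mem}}\,\mathcal{L}_{\text{mem}}
$$

Subtracting the entropy term rewards exploration. With the example settings (`ent_coef=0.01`, `memory_retention_coef=0.01`), the memory usefulness loss carries the same weight as the entropy bonus.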

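The early-stop test in `learn` can be isolated as a pure function, which makes the criterion easy to check on its own. This is a sketch only: `should_early_stop` is a hypothetical helper, and its parameters `n_samples`, `mean_threshold`, `std_threshold`, and `unlock_after` stand in for `early_stop_n_samples`, `early_stop_mean_threshold`, `early_stop_std_threshold`, and `unlock_early_stopping`.

```python
import numpy as np


def should_early_stop(episode_rewards, n_samples, mean_threshold, std_threshold, unlock_after):
    """Hypothetical helper mirroring the early-stop check in `learn`."""
    # Do not evaluate the criterion until enough episodes have been collected
    if len(episode_rewards) < unlock_after:
        return False
    window = np.asarray(episode_rewards[-n_samples:], dtype=np.float64)
    # Stop only when rewards are both high (mean) and stable (std) over the window
    return bool(window.mean() > mean_threshold and window.std() <= std_threshold)
```

Requiring a low standard deviation as well as a high mean keeps a lucky streak of partially successful episodes from ending training; the early-stop tables printed during the curriculum run show both statistics at the moment the check fired.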

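`learn` delegates motif discovery to `mine_motifs_from_buffer` every `MINING_INTERVAL` episodes, and that function's internals are not shown in this section. As a rough sketch of the unsupervised idea only, the hypothetical `mine_top_motifs` below scores every length-`motif_len` window of a 1-D series by the distance to its nearest non-overlapping match (a brute-force matrix profile) and keeps the `n_motifs` best windows that do not overlap each other. It assumes a scalar series rather than the buffer's multi-dimensional entries and uses plain Euclidean distance in place of DTW.

```python
import numpy as np


def mine_top_motifs(series, motif_len, n_motifs):
    """Hypothetical sketch: start indices of the most-repeated windows (Euclidean, not DTW)."""
    x = np.asarray(series, dtype=np.float64)
    # All subsequences of length motif_len, shape (n_windows, motif_len)
    windows = np.lib.stride_tricks.sliding_window_view(x, motif_len)
    n_windows = windows.shape[0]
    # Pairwise Euclidean distances between all windows
    diffs = windows[:, None, :] - windows[None, :, :]
    dist = np.sqrt((diffs ** 2).sum(axis=-1))
    # Exclusion zone: a window trivially matches itself and its overlapping neighbours
    for i in range(n_windows):
        lo, hi = max(0, i - motif_len + 1), min(n_windows, i + motif_len)
        dist[i, lo:hi] = np.inf
    profile = dist.min(axis=1)  # distance to the nearest non-trivial match
    chosen = []
    for idx in np.argsort(profile, kind="stable"):
        # Skip candidates that overlap an already selected motif
        if all(abs(int(idx) - c) >= motif_len for c in chosen):
            chosen.append(int(idx))
        if len(chosen) == n_motifs:
            break
    return chosen
```

For example, planting the same four-step shape at positions 10 and 30 of a low-noise series makes those two windows exact matches (profile value 0), so they are returned first. The O(n²) distance matrix is fine for short episodic buffers; longer histories would call for a dedicated matrix-profile algorithm.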

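`evaluate` relies on the Gymnasium convention in which `step` returns `(obs, reward, terminated, truncated, info)`, and an episode is over when either flag is set. The minimal loop below illustrates that convention with a hypothetical `ToyEnv` that ends only by truncation (a time limit); a loop that ignored `truncated` would never exit on it. Neither class nor function here is part of the original code.

```python
class ToyEnv:
    """Hypothetical stand-in following the Gymnasium 5-tuple step() convention."""

    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return 0.0, {}

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        terminated = False                     # no natural terminal state in this toy
        truncated = self.t >= self.max_steps   # episode cut off by a time limit
        return float(self.t), reward, terminated, truncated, {}


def evaluate_returns(env, policy_fn, n_episodes):
    """Return the undiscounted return of each episode."""
    returns = []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        done, total = False, 0.0
        while not done:
            obs, reward, terminated, truncated, _ = env.step(policy_fn(obs))
            total += reward
            done = terminated or truncated     # stop on either signal
        returns.append(total)
    return returns
```

Usage: `evaluate_returns(ToyEnv(max_steps=5), lambda obs: 1, n_episodes=3)` yields one return of 5.0 per episode.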

In [17]:

# ──────────────────────────────────────────────────────────────
# Example training loop
# ──────────────────────────────────────────────────────────────

# SETUP ===================================
DELAY = 4
MEM_DIM = 32
N_EPISODES = 2500
N_MEMORIES = 32

AGENT_KWARGS = dict(
    device="cpu",
    verbose=0,
    lam=0.95,
    gamma=0.99,
    ent_coef=0.01,
    learning_rate=1e-3,
)
MEMORY_AGENT_KWARGS=dict(
    her=False,
    reward_norm=False,
    aux_modules=None,
    
    intrinsic_expl=False,
    intrinsic_eta=0.01,
    
    use_rnd=False, 
    rnd_emb_dim=32, 
    rnd_lr=1e-3,
)

# HELPERS =================================
def total_timesteps(delay, n_episodes):
    """Timestep budget for n_episodes, given that one episode lasts `delay` steps."""
    return delay * n_episodes

# ENVIRONMENT =============================
env = MemoryTaskEnv(delay=DELAY, difficulty=0)

# MEMORY BUFFER ===========================
episodic_buffer = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    device="cpu"
)
motif_bank = MotifMemoryBank(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,
    mem_dim=MEM_DIM,
    n_motifs=32,
    motif_len=4,
    device="cpu"
)
combined_memory = CombinedMemoryModule(episodic_buffer, motif_bank)


# POLICY NETWORK (use class) ==============
policy = StrategicCombinedMemoryPolicy


# AGENT SETUP =============================
agent = TraceRL(
    policy_class=policy,
    env=env,
    memory=combined_memory,
    memory_learn_retention=True,
    memory_retention_coef=0.01,
    # aux_modules=aux_modules,
    device="cpu",
    verbose=1,
    lam=0.95,
    gamma=0.99,
    ent_coef=0.01,
    learning_rate=1e-3,
    **MEMORY_AGENT_KWARGS
)

# TRAIN THE AGENT =========================
#agent.learn(
#    total_timesteps=total_timesteps(DELAY, 1000),
#    log_interval=50
#)

In [None]:
from benchmark import AgentPerformanceBenchmark
from tabulate import tabulate
env = MemoryTaskEnv(delay=4, difficulty=0)

# MEMORY BUFFER ===========================
memory = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,          # For Discrete(2)
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    device="cpu"
)

# POLICY NETWORK (use class) ==============
#policy = StrategicMemoryTransformerPolicy
policy = StrategicCombinedMemoryPolicy
agent = TraceRL(
    policy_class=policy,
    env=env,
    memory=memory,
    memory_learn_retention=True,
    memory_retention_coef=0.01,
    # aux_modules=aux_modules,
    device="cpu",
    verbose=0,
    lam=0.95,
    gamma=0.99,
    ent_coef=0.01,
    learning_rate=1e-3,
    **MEMORY_AGENT_KWARGS
)
curriculum = [2, 4, 8, 16, 32, 64, 128, 256]
for delay in curriculum:
    agent.env.delay = delay
    #agent.get_episodic_buffer().reset()
    print(f"\n--- Training with delay={delay} ---")
    agent.learn(total_timesteps=total_timesteps(delay, 100000), log_interval=50)

    benchmark = AgentPerformanceBenchmark(dict(delay=delay, n_train_episodes=2000, total_timesteps=1_000_000, difficulty=0, mode_name="EASY", verbose=0, eval_base=True))
    e_r, e_s = benchmark.evaluate(agent, 'motif')
    table = [["Avg reward", e_r], ["Std reward", e_s]]
    print(tabulate(table, headers=["Evaluation", ""], tablefmt="rounded_outline"))


--- Training with delay=2 ---
╭──────────────────────┬───────╮
│ Early Stop           │       │
├──────────────────────┼───────┤
│ Train duration       │ 7s    │
│ Avg episode duration │ 0.01s │
│ Rolling ep rew mean  │ 1.00  │
│ Rolling ep rew std   │ 0.00  │
│ Rolling ep length    │ 2.00  │
│ N updates            │ 553   │
╰──────────────────────┴───────╯
╭──────────────┬────╮
│ Evaluation   │    │
├──────────────┼────┤
│ Avg reward   │  1 │
│ Std reward   │  0 │
╰──────────────┴────╯

--- Training with delay=4 ---
╭──────────────────────┬───────╮
│ Early Stop           │       │
├──────────────────────┼───────┤
│ Train duration       │ 13s   │
│ Avg episode duration │ 0.02s │
│ Rolling ep rew mean  │ 1.00  │
│ Rolling ep rew std   │ 0.00  │
│ Rolling ep length    │ 4.00  │
│ N updates            │ 530   │
╰──────────────────────┴───────╯
╭──────────────┬────╮
│ Evaluation   │    │
├──────────────┼────┤
│ Avg reward   │  1 │
│ Std reward   │  0 │
╰──────────────┴────╯

--- Training 

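Because `total_timesteps(delay, n)` scales linearly with the delay, the per-stage budgets in the curriculum loop grow geometrically. The arithmetic is redone standalone below (the helper is copied from the setup cell). These numbers are upper bounds only: early stopping can end a stage far sooner, as the delay=2 stage did after 553 updates in the output above.

```python
def total_timesteps(delay, n_episodes):
    # Same helper as in the setup cell: one episode lasts `delay` steps
    return delay * n_episodes


curriculum = [2, 4, 8, 16, 32, 64, 128, 256]
# Per-stage budget with 100000 episodes per stage, as in the curriculum loop
budgets = {delay: total_timesteps(delay, 100000) for delay in curriculum}
total_budget = sum(budgets.values())
print(budgets[2], budgets[256], total_budget)  # 200000 25600000 51000000
```

The final stage alone (delay=256) accounts for about half of the 51M-step total, so most of the compute budget goes to the longest delays.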
In [None]:
from benchmark import AgentPerformanceBenchmark
from tabulate import tabulate

curriculum = [2, 4, 8, 16, 32, 64, 128, 256]
# Start at the first curriculum stage; the loop below updates env.delay per stage
env = MemoryTaskEnv(delay=curriculum[0], difficulty=0)

# MEMORY BUFFER ===========================
memory = StrategicMemoryBuffer(
    obs_dim=env.observation_space.shape[0],
    action_dim=1,          # For Discrete(2)
    mem_dim=MEM_DIM,
    max_entries=N_MEMORIES,
    device="cpu"
)

# POLICY NETWORK (use class) ==============
#policy = StrategicMemoryTransformerPolicy
policy = StrategicCombinedMemoryPolicy
agent = TraceRL(
    policy_class=policy,
    env=env,
    memory=memory,
    memory_learn_retention=True,
    memory_retention_coef=0.01,
    # aux_modules=aux_modules,
    device="cpu",
    verbose=0,
    lam=0.95,
    gamma=0.99,
    ent_coef=0.01,
    learning_rate=1e-3,
    **MEMORY_AGENT_KWARGS
)
for delay in curriculum:
    agent.env.delay = delay
    #agent.get_episodic_buffer().reset()
    print(f"\n--- Training with delay={delay} ---")
    agent.learn(total_timesteps=total_timesteps(delay, 100000), log_interval=50)

    benchmark = AgentPerformanceBenchmark(dict(delay=delay, n_train_episodes=2000, total_timesteps=1_000_000, difficulty=0, mode_name="EASY", verbose=0, eval_base=True))
    e_r, e_s = benchmark.evaluate(agent, 'motif')
    table = [["Avg reward", e_r], ["Std reward", e_s]]
    print(tabulate(table, headers=["Evaluation", ""], tablefmt="rounded_outline"))