# Memory-Advanced RL (MARL):

In traditional RL, policies are often memoryless (Markovian), making decisions based solely on the current state. However, manu real-world tasks require remembering past observations to act effectively. Memory-Augmented RL introduces explicit memory mechanisms withing RL agents, allowing them to capture temporal dependencies and make more informed decisionas based on past experiences.
- **Long-term dependencies**: Better handles tasks requiring remembering events from the distant past.
- **Partial observability**: Deals effectively with partially observable environments.
- **Generalization**: Enhances the agent's capability to adapt to complex, dynamic, and structured tasks.

## Background:

- **Markovian vs Non-Markovian**: Memory-Aug RL explicitly addresses non-Markovian environments by introducing memory structures into the agent's policy.
- **RNNs**
- **Extenral Memory (Neural Turing Machines, Memory Networks)**: Networks that explicitly read/write to external memory.

## Theory

Memory-Aug RL agetns parameterize a policy $\pi$ as:
$$a_t \sim \pi(a_t|s_t,h_T)$$
where $h_T$ is a memory state that symmarizes past interactions:
- Implicit memory (RNN/LSTM):
$$h_{t+1} = f_{\theta}(h_t,s_t,a_t,r_t)$$
- Explicit memory (Memory Networks/NTM):
Memory is a separate storage structure explicitly written to and read from.
Aims to optimize policy paramters $\theta$ while incorporating past experiences effectively.

## Math

### Recurrent Policy (RNN/LSTM)
A recurrent policy integrates past experienves into a hidden memory state:
$$h_{t+1} = LSTM(h_t,s_t,a_t,r_t;\theta_h)$$
Policy output is conditioned on hidden state:
$$\pi(a_t|s_t,h_t;\theta_{\pi})$$ 
$$V(s_t,h_t;\theta_V)$$

### External Memory
Involve explicit read/write operations:
- Write:
$$M_t = W(M_{t-1},s_t,a_t,r_t)$$
- Read:
$$m_t = R(M_t,s_t)$$
Then, policy decisions use both current state and retrieved memory $m_t$:
$$a_t \sim \pi(a_t|s_t,m_t;\theta_{\pi})$$

## Implementations

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class LSTMActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_size=128):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden_size, batch_first=True)
        self.actor = nn.Linear(hidden_size, act_dim)
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, obs, hidden):
        obs = obs.view(1, 1, -1)  # Correct shape: (seq_len=1, batch=1, features)
        lstm_out, hidden = self.lstm(obs, hidden)
        action_mean = torch.tanh(self.actor(lstm_out[:, -1]))
        value = self.critic(lstm_out[:, -1])
        return action_mean, value, hidden
    
    def init_hidden(self):
        return (torch.zeros(1, 1, 128).to(device), 
                torch.zeros(1, 1, 128).to(device))


In [5]:
env = gym.make("Pendulum-v1")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

agent = LSTMActorCritic(obs_dim, act_dim).to(device)
optimizer = optim.Adam(agent.parameters(), lr=3e-4)

def run_episode(env, agent, optimizer):
    obs, _ = env.reset()
    hidden = agent.init_hidden()

    log_probs, values, rewards = [], [], []
    ep_reward = 0

    for _ in range(200):
        obs_tensor = torch.tensor(obs, dtype = torch.float32).to(device)
        action_mean, value, hidden = agent(obs_tensor, hidden)

        dist = torch.distributions.Normal(action_mean, 0.1)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum()

        next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
        done = terminated or truncated

        log_probs.append(log_prob)
        values.append(value.squeeze(0))
        rewards.append(torch.tensor(reward, dtype=torch.float32).to(device))

        obs = next_obs
        ep_reward += reward
        if done:
            break

    returns, R = [], 0
    gamma = 0.99
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.stack(returns)
    log_probs = torch.stack(log_probs)
    values = torch.stack(values)

    advantage = returns - values

    actor_loss = -(log_probs * advantage.detach()).mean()
    critic_loss = advantage.pow(2).mean()

    loss = actor_loss + critic_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return ep_reward

for ep in range(100):
    ep_reward = run_episode(env, agent, optimizer)
    if ep % 10 == 0:
        print(f"Episode {ep}: {ep_reward}")

  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(


Episode 0: [-890.9534]
Episode 10: [-1374.8759]
Episode 20: [-747.6549]
Episode 30: [-1120.7549]
Episode 40: [-1889.1458]
Episode 50: [-1095.8055]
Episode 60: [-1864.7867]
Episode 70: [-906.3905]
Episode 80: [-1055.7147]
Episode 90: [-1070.7229]


## Next-Steps

### Transformer-Based Memory-Aug RL

Introduce Transformer-based policy to enhance memory capabilities via attention mechanisms, explicitly capturing long-term dependencies in the environment.
Transformers utilize self-attention mechanisms, allowing the agent to selectively focus on relevant past experiences, significantly improving long-range memory ober traditional recurrent models.

In [6]:
class TransformerActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, seq_len=10, hidden_dim=128, n_heads=4, n_layers=2):
        super().__init__()
        self.seq_len = seq_len
        self.obs_dim = obs_dim
        self.act_dim = act_dim

        self.input_proj = nn.Linear(obs_dim, hidden_dim)
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.actor = nn.Linear(hidden_dim, act_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, obs_seq):
        batch_size = obs_seq.shape[0]
        x = self.input_proj(obs_seq) + self.pos_encoding[:obs_seq.size(1),:]

        memory = self.transformer_encoder(x)
        latest_memory = memory[:, -1]

        action = torch.tanh(self.actor(latest_memory))
        value = self.critic(latest_memory)
        return action, value

In [10]:
env = gym.make("Pendulum-v1")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

agent = TransformerActorCritic(obs_dim, act_dim).to(device)
optimizer = optim.Adam(agent.parameters(), lr=3e-4)

def collect_episode(env, policy, seq_len):
    obs_seq, action_seq, reward_seq = [], [], []
    obs, _ = env.reset()
    ep_reward = 0

    for step in range(seq_len):
        obs_seq.append(obs)
        if len(obs_seq) < seq_len:
            # Before having enough observations, take random actions
            action = env.action_space.sample()
        else:
            obs_tensor = torch.tensor(obs_seq[-seq_len:], dtype=torch.float32).unsqueeze(0).to(device)

            with torch.no_grad():
                action_pred, _ = policy(obs_tensor)

            action = action_pred.cpu().numpy().flatten()

        next_obs, reward, terminated, truncated, _ = env.step(action)
        reward_seq.append(reward)
        action_seq.append(action)
        obs = next_obs
        ep_reward += reward

        if terminated or truncated:
            break

    return obs_seq, action_seq, reward_seq, ep_reward



def train_step(policy, optimizer, obs_seq, action_seq, reward_seq, seq_len):
    if len(obs_seq) < seq_len:
        return  # Skip if insufficient length

    obs_tensor = torch.tensor(obs_seq[-seq_len:], dtype=torch.float32).unsqueeze(0).to(device)
    action_tensor = torch.tensor(action_seq[-seq_len:], dtype=torch.float32).to(device)
    rewards = torch.tensor(reward_seq[-seq_len:], dtype=torch.float32).to(device)

    action_preds, values = policy(obs_tensor)
    values = values.squeeze()

    actor_loss = ((action_preds - action_tensor)**2).mean()

    returns, R = [], 0
    gamma = 0.99
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns).to(device)

    critic_loss = ((returns - values)**2).mean()

    loss = actor_loss + critic_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


seq_len = 10
episodes = 100

for ep in range(episodes):
    obs_seq, act_seq, rew_seq, ep_reward = collect_episode(env, agent, seq_len)
    train_step(agent, optimizer, obs_seq, act_seq, rew_seq, seq_len)

    if ep % 10 == 0:
        print(f"Episode {ep}, Reward: {ep_reward:.2f}")

Episode 0, Reward: -7.75
Episode 10, Reward: -25.13
Episode 20, Reward: -56.99
Episode 30, Reward: -1.89
Episode 40, Reward: -43.98
Episode 50, Reward: -69.21
Episode 60, Reward: -36.11
Episode 70, Reward: -88.89
Episode 80, Reward: -83.48
Episode 90, Reward: -55.05


In [11]:
test_rewards = []
for ep in range(10):
    _, _, _, ep_reward = collect_episode(env, agent, seq_len=10)
    test_rewards.append(ep_reward)
    print(f"Test Episode {ep}, Reward: {ep_reward:.2f}")

print(f"Average Test Reward: {np.mean(test_rewards):.2f}")

Test Episode 0, Reward: -67.22
Test Episode 1, Reward: -87.23
Test Episode 2, Reward: -45.17
Test Episode 3, Reward: -67.43
Test Episode 4, Reward: -43.21
Test Episode 5, Reward: -82.03
Test Episode 6, Reward: -48.55
Test Episode 7, Reward: -25.49
Test Episode 8, Reward: -18.65
Test Episode 9, Reward: -81.23
Average Test Reward: -56.62


### Explicit Memory-Aug RL

While RNN implicitly store past information in their hidden states, explicit memory methods (Neural Turing Machines (NTM)) allow an agent to clearly and explicitly read from and write to a memory buffer, perform more structured memory operations and capture long-range dependencies better than typical RNNs or LSTMs.
An explicit memory model typically consists of:
- Memory matrix $M\in R^{N\times M}$ with N memory slots, each of dimension M.
- Read and write heads, controlling what info is written/read from memory.
- Controller (NN) deciding read/write actions.

In [16]:
class ExplicitMemoryAgent(nn.Module):
    def __init__(self, obs_dim, act_dim, memory_slots=20, memory_dim=32, hidden_dim=128):
        super().__init__()
        self.memory_slots = memory_slots
        self.memory_dim = memory_dim
        self.hidden_dim = hidden_dim

        self.controller = nn.GRUCell(obs_dim+memory_dim, hidden_dim)

        self.memory = torch.zeros(memory_slots, memory_dim).to(device)

        self.read_head = nn.Linear(hidden_dim, memory_slots)
        self.write_head = nn.Linear(hidden_dim, memory_slots)

        self.write_projection= nn.Linear(hidden_dim, memory_dim)

        self.actor = nn.Linear(hidden_dim+memory_dim, act_dim)
        self.critic = nn.Linear(hidden_dim+memory_dim, 1)

    def reset_memory(self):
        self.memory = torch.zeros(self.memory_slots, self.memory_dim).to(device)

    def forward(self, obs, hidden):
        obs = obs.unsqueeze(0)  # shape: (1, obs_dim)
        
        # Read from memory
        read_weights = torch.softmax(self.read_head(hidden), dim=-1)  # (1, memory_slots)
        read_vector = torch.matmul(read_weights, self.memory)         # (1, memory_dim)
        
        # Controller update (ensure matching dimensions)
        controller_input = torch.cat([obs, read_vector], dim=-1)      # (1, obs_dim + memory_dim)
        hidden = self.controller(controller_input, hidden)            # (1, hidden_dim)
        
        # Write to memory
        write_weights = torch.softmax(self.write_head(hidden), dim=-1)  # (1, memory_slots)
        write_content = torch.tanh(self.write_projection(hidden))       # (1, memory_dim)

        # update memory explicitly
        self.memory = self.memory + write_weights.squeeze(0).unsqueeze(-1) * write_content

        # Actor and Critic
        actor_input = torch.cat([hidden, read_vector], dim=-1)
        action = torch.tanh(self.actor(actor_input))
        value = self.critic(actor_input)

        return action.squeeze(0), value.squeeze(0), hidden

    
    def init_hidden(self):
        return torch.zeros(1, self.hidden_dim).to(device)

In [18]:
env = gym.make("Pendulum-v1")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

agent = ExplicitMemoryAgent(obs_dim, act_dim).to(device)
optimizer = optim.Adam(agent.parameters(), lr=1e-4)

def run_episode(env, agent, optimizer):
    obs, _ = env.reset()
    agent.reset_memory()
    hidden = agent.init_hidden()

    log_probs, values, rewards = [], [], []
    ep_reward = 0

    for step in range(200):
        obs_tensor = torch.tensor(obs, dtype = torch.float32).to(device)
        action, value, hidden = agent(obs_tensor, hidden)

        dist = torch.distributions.Normal(action, 0.1)
        sampled_action = dist.sample()
        log_prob = dist.log_prob(sampled_action).sum()

        next_obs, reward, terminated, truncated, _ = env.step(sampled_action.cpu().numpy())
        done = terminated or truncated

        log_probs.append(log_prob)
        values.append(value.squeeze())
        rewards.append(torch.tensor(reward, dtype=torch.float32).to(device))

        obs = next_obs
        ep_reward += reward
        if done:
            break

    returns, R = [], 0
    gamma = 0.99
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.stack(returns)
    log_probs = torch.stack(log_probs)
    values = torch.stack(values)

    advantage = returns - values

    actor_loss = -(log_probs * advantage.detach()).mean()
    critic_loss = advantage.pow(2).mean()

    loss = actor_loss + critic_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return ep_reward

for ep in range(100):
    ep_reward = run_episode(env, agent, optimizer)
    if ep % 10 == 0:
        print(f"Episode {ep}: {ep_reward:.2f}")

Episode 0: -889.102543802381
Episode 10: -1921.4417541711236
Episode 20: -907.8367765463909
Episode 30: -968.2965106261779
Episode 40: -1663.0584427264087
Episode 50: -1074.7426779459788
Episode 60: -897.1711201635997
Episode 70: -1759.5716529330393
Episode 80: -1170.2580476895098
Episode 90: -1803.3554597443747
