# Recurrent Reinforcement Learning (LSTM, GRU, GRU-RSSM)

Most RL algorithms assume the agent observes a complete Markov state $s_t$. But in reality:
- Observations are partial or noisy
- Current observations $o_t$ is not sufficient for optimal decisions
The soulution is to use a recurrent model that can accumulate information over time to form a belief about the hidden state.

## Partial Observable MArkov Decision Process (POMDP)
A POMDP is a tuple $(S, A, O, T, R, \gamma)$ where:
- $S$: set of states
- $A$: set of actions
- $O$: set of observations
- $T$: state transition function $T(s'|s,a)$
- $R$: reward function $R(s,a)$
- $\gamma$: discount factor

Because the agent can't directly observe $s_t$, it must infer a belief state $b_t$ using all past observations and actions:
$$b_t = f(o_{1:t}, a_{1:t})$$
This is where RNNs, LSTMs, and GRUs come in, they compress the sequence history into a hidden state.

## Recurrent Policies: LSTM and GRU
A recurrent policy in RL looks like this:
$$h_T = RNN(h_{t-1}, o_t)$$
$$a_t \sim \pi(a_t|h_t)$$
$$V_t = V(h_t)$$
Where:
- $h_t$: hidden state of the RNN
- $\pi$: policy function
- $V$: value function

### LSTM
LSTM introduces gated mechanisms to control information flow:
- Input gate: Controls how much of the new input to keep
- Forget gate: Controls how much of the previous state to keep
- Output gate: Controls how much of the current state to output
Formally, at time t:
$$f_t = \sigma(W_f[h_{t-1}, x_t] + b_f)$$
$$i_t = \sigma(W_i[h_{t-1}, x_t] + b_i)$$
$$o_t = \sigma(W_o[h_{t-1}, x_t] + b_o)$$
$$\tilde{c}_t = \tanh(W_c[h_{t-1}, x_t] + b_c)$$
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
$$h_t = o_t \odot \tanh(c_t)$$
Where $\odot$ is the element-wise product, $W$ are weight matrices, and $b$ are biases.

### GRU
A lighter alternative to LSTM with fewer gates:
$$z_t = \sigma(W_zx_t + U_zh_{t-1})\ (update\ gate)$$
$$r_t = \sigma(W_rx_t + U_rh_{t-1})\ (reset\ gate)$$
$$\tilde{h}_t = \tanh(W_hx_t + U(r_t \odot h_{t-1}))$$
$$h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$
GRUs are computationally cheaper and often perform just as well as LSTMs.

## GRU-RSSM: Recurrent State Space Model
The GRU-RSSM is a latent dynamics model combining:
- Deterministic memory (GRU)
- Stochastic latent states (variational)
Used in PlaNet and Dreamer, it models:
$$h_t = GRU(h_{t-1}, a_{t-1}, s_{t-1})$$
$$s_t \sim p(s_t|h_t)\ (prior)$$
$$s_t \sim q(s_t|h_t, e(o_t))\ (posterior)$$
Then use $(s_t, h_t)$ to:
- Predict future latent states
- Decode rewards, observations
- Imagine rollouts for planning or policy learning

## RL Algorithms + Recurrent Policies
Recurrent policies can be trained using:
- **REINFORCE** (policy gradient with history)
- **A2C/PPO** with recurrent rollouts
- **SAC** (with RNN or GRU-based actor/critic)
- **Dreamer** (learned latent rollouts from GRU-RSSM)

## Implementation

### LSTM-based Policy with REINFORCE


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym

In [2]:
class LSTMPolicy(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden_dim, batch_first=True)
        self.actor = nn.Linear(hidden_dim, action_dim)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x,hidden)
        logits = self.actor(out)
        return logits, hidden

env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy = LSTMPolicy(obs_dim, action_dim)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

def train_lstm_reinforce(episode=300):
    gamma = 0.99
    for ep in range(episode):
        obs, _ = env.reset()
        hidden = (torch.zeros(1,1,64), torch.zeros(1,1,64))
        rewards, log_probs = [], []
        done = False

        while not done:
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            logits, hidden = policy(obs_tensor, hidden)
            probs = torch.softmax(logits[0,0], dim=-1)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            rewards.append(reward)
            log_probs.append(log_prob)

        R = 0
        returns = []
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-6)

        loss = -sum(lp * R for lp, R in zip(log_probs, returns))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ep % 50 == 0:
            print(f"Episode {ep}, Loss: {loss.item()}, Return: {sum(rewards)}")
    env.close()

train_lstm_reinforce()

Episode 0, Loss: -0.2010430097579956, Return: 31.0
Episode 50, Loss: 0.03944242000579834, Return: 18.0
Episode 100, Loss: -0.5274049043655396, Return: 108.0
Episode 150, Loss: 2.6946613788604736, Return: 39.0
Episode 200, Loss: -0.16298528015613556, Return: 87.0
Episode 250, Loss: -8.582036972045898, Return: 500.0


### GRU-RSSM

In [78]:
class GRURSSM(nn.Module):
    def __init__(self, action_dim, latent_dim=30, hidden_dim=200, obs_embed_dim=1024):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        self.gru = nn.GRUCell(latent_dim + action_dim, hidden_dim)

        self.prior = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2*latent_dim)
        )

        self.posterior = nn.Sequential(
            nn.Linear(hidden_dim + obs_embed_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2*latent_dim)
        )

    def init_state(self, batch_size):
        h = torch.zeros(batch_size, self.hidden_dim)
        s = torch.zeros(batch_size, self.latent_dim)
        return h, s
    def get_dist(self, stats):
        mean, log_std = torch.chunk(stats, 2, dim=-1)
        std = F.softplus(log_std)
        return torch.distributions.Normal(mean, std)
    
    def forward(self, h_prev, s_prev, a_prev, obs_embed=None):
        x = torch.cat([s_prev, a_prev], dim=-1)
        h = self.gru(x, h_prev)

        prior_stats = self.prior(h)
        prior_dist = self.get_dist(prior_stats)

        if obs_embed is not None:
            post_input = torch.cat([h, obs_embed], dim=-1)
            post_stats = self.posterior(post_input)
            post_dist = self.get_dist(post_stats)
            s = post_dist.rsample()
        else:
            post_stats = prior_stats
            post_dist = prior_dist
            s = prior_dist.rsample()
        return (h,s), prior_dist, post_dist, prior_stats, post_stats




In [79]:
class ObsEncoder(nn.Module):
    def __init__(self, obs_shape=(3,), embed_dim=1024):
        super().__init__()
        input_dim = np.prod(obs_shape)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, embed_dim), nn.ReLU()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.encoder(x)



In [80]:
class ObsDecoder(nn.Module):
    def __init__(self, latent_dim=30, hidden_dim=200, obs_shape=(3,)):
        super().__init__()
        output_dim = np.prod(obs_shape)
        self.obs_shape = obs_shape
        self.net = nn.Sequential(
            nn.Linear(latent_dim + hidden_dim, 128), nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, s, h):
        x = torch.cat([s, h], dim=-1)
        return self.net(x).view(-1, *self.obs_shape)


In [81]:
class RewardDecoder(nn.Module):
    def __init__(self, latent_dim=30, hidden_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + hidden_dim, 200), nn.ReLU(),
            nn.Linear(200, 1)
        )

    def forward(self, s, h):
        x = torch.cat([s, h], dim=-1)
        return self.net(x)


In [82]:
class DreamerActor(nn.Module):
    def __init__(self, latent_dim=30, hidden_dim=200, action_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + hidden_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.mean = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)

    def forward(self, s, h):
        x = torch.cat([s, h], dim=-1)
        x = self.net(x)
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -5, 2)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        return dist


In [83]:
class DreamerCritic(nn.Module):
    def __init__(self, latent_dim=30, hidden_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + hidden_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, s, h):
        x = torch.cat([s, h], dim=-1)
        return self.net(x)


In [84]:
def imagine_rollout(h, s, actor, rssm, horizon=15):
    s_list, h_list, a_list, logp_list = [], [], [], []
    for _ in range(horizon):
        dist = actor(s, h)
        a = torch.tanh(dist.rsample())
        logp = dist.log_prob(a).sum(-1, keepdim=True) - torch.log(1 - a.pow(2) + 1e-6).sum(-1, keepdim=True)

        (h, s), *_ = rssm(h, s, a, obs_embed=None)
        s_list.append(s)
        h_list.append(h)
        a_list.append(a)
        logp_list.append(logp)

    return torch.stack(s_list), torch.stack(h_list), torch.stack(a_list), torch.stack(logp_list)


In [85]:
def compute_elbo_loss(obs_seq, act_seq, rew_seq, encoder, decoder_obs, decoder_r, rssm, beta=1.0):
    B, T, *_ = obs_seq.shape
    h, s = rssm.init_state(B)

    loss_obs, loss_r, loss_kl = 0.0, 0.0, 0.0
    for t in range(T):
        o_t = obs_seq[:, t]
        a_t = act_seq[:, t]
        r_t = rew_seq[:, t]

        emb = encoder(o_t)
        (h, s), prior, post, _, _ = rssm(h, s, a_t, emb)

        o_hat = decoder_obs(s, h)
        r_hat = decoder_r(s, h)

        loss_obs += F.mse_loss(o_hat, o_t, reduction="mean")
        loss_r += F.mse_loss(r_hat, r_t, reduction="mean")
        loss_kl += torch.distributions.kl.kl_divergence(post, prior).mean()

    total = loss_obs + loss_r + beta * loss_kl
    return total, loss_obs, loss_r, loss_kl


In [86]:
def actor_loss(critic, s_rollout, h_rollout, logp_rollout, entropy_weight=1.0):
    values = critic(s_rollout, h_rollout).squeeze(-1)
    logp = logp_rollout.squeeze(-1)
    return -(values + entropy_weight * logp).mean()

def value_loss(critic, s_rollout, h_rollout, rewards, gamma=0.99):
    target = []
    G = torch.zeros_like(rewards[0])
    for r in reversed(rewards):
        G = r + gamma * G
        target.insert(0, G.clone())
    target = torch.stack(target)
    values = critic(s_rollout, h_rollout).squeeze(-1)
    return F.mse_loss(values, target.detach())


In [87]:
def dreamer_train_step(batch, encoder, decoder_obs, decoder_r, rssm, actor, critic, 
                       optimizer_world, optimizer_actor, optimizer_critic):
    obs_seq, act_seq, rew_seq = batch  # (B, T, ...)
    
    # === World Model Update ===
    elbo_loss, _, _, _ = compute_elbo_loss(obs_seq, act_seq, rew_seq,
                                           encoder, decoder_obs, decoder_r, rssm)
    optimizer_world.zero_grad()
    elbo_loss.backward()
    optimizer_world.step()

    # === Imagination ===
    o_t = obs_seq[:, -1]
    a_t = act_seq[:, -1]
    emb = encoder(o_t)
    h, s = rssm.init_state(obs_seq.size(0))
    (h, s), *_ = rssm(h, s, a_t, emb)

    s_imag, h_imag, a_imag, logp = imagine_rollout(h, s, actor, rssm, horizon=15)
    rewards = decoder_r(s_imag, h_imag).squeeze(-1)

    # === Value Function Update ===
    loss_v = value_loss(critic, s_imag.detach(), h_imag.detach(), rewards.detach())
    optimizer_critic.zero_grad()
    loss_v.backward()
    optimizer_critic.step()

    # === Policy Update ===
    loss_a = actor_loss(critic, s_imag, h_imag, logp)
    optimizer_actor.zero_grad()
    loss_a.backward()
    optimizer_actor.step()

    return elbo_loss.item(), loss_a.item(), loss_v.item()


In [88]:
import numpy as np
def sample_batch(replay, batch_size, seq_len):
    idxs = np.random.randint(0, len(replay) - seq_len, size=batch_size)
    obs_batch = []
    act_batch = []
    rew_batch = []

    for i in idxs:
        seq = replay[i:i+seq_len]
        obs_seq = torch.tensor([s[0] for s in seq], dtype=torch.float32)
        act_seq = torch.tensor([s[1] for s in seq], dtype=torch.float32)
        rew_seq = torch.tensor([[s[2]] for s in seq], dtype=torch.float32)
        obs_batch.append(obs_seq)
        act_batch.append(act_seq)
        rew_batch.append(rew_seq)

    return (torch.stack(obs_batch),
            torch.stack(act_batch),
            torch.stack(rew_batch))


In [89]:
def train_dreamer_agent(env, num_epochs, steps_per_epoch, batch_size,
                        encoder, decoder_obs, decoder_r,
                        rssm, actor, critic):

    # Initialize replay buffer
    replay = []  # list of (obs, act, rew)
    max_buffer_size = 100000

    # Initialize optimizers
    optim_world = torch.optim.Adam(
        list(encoder.parameters()) +
        list(decoder_obs.parameters()) +
        list(decoder_r.parameters()) +
        list(rssm.parameters()), lr=1e-3)

    optim_actor = torch.optim.Adam(actor.parameters(), lr=8e-5)
    optim_critic = torch.optim.Adam(critic.parameters(), lr=8e-5)

    obs = env.reset()[0]
    done = False
    total_steps = 0

    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            h, s = rssm.init_state(1)
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            emb = encoder(obs_tensor)
            a = torch.zeros(1, action_dim)  # dummy initial action
            (h, s), prior, post, _, _ = rssm(h, s, a, emb)

            dist = actor(s, h)
            action = torch.tanh(dist.sample()).squeeze(0).cpu().numpy()

            next_obs, reward, done, _, _ = env.step(action)
            replay.append((obs, action, reward))
            if len(replay) > max_buffer_size:
                replay.pop(0)

            obs = next_obs
            if done:
                obs = env.reset()[0]

            total_steps += 1

        # Sample batches from replay
        obs_seq, act_seq, rew_seq = sample_batch(replay, batch_size=batch_size, seq_len=15)

        elbo, loss_a, loss_v = dreamer_train_step(
            (obs_seq, act_seq, rew_seq),
            encoder, decoder_obs, decoder_r,
            rssm, actor, critic,
            optim_world, optim_actor, optim_critic
        )

        print(f"[Epoch {epoch}] ELBO: {elbo:.2f} | Actor: {loss_a:.2f} | Critic: {loss_v:.2f}")


In [90]:
env = gym.make("Pendulum-v1")
obs_shape = env.observation_space.shape
action_dim = env.action_space.shape[0]

encoder = ObsEncoder(obs_shape=obs_shape)
decoder_obs = ObsDecoder(obs_shape=obs_shape)
decoder_r = RewardDecoder()
rssm = GRURSSM(action_dim=action_dim)
actor = DreamerActor(action_dim=action_dim)
critic = DreamerCritic()

train_dreamer_agent(
    env=env,
    num_epochs=100,
    steps_per_epoch=500,
    batch_size=16,
    encoder=encoder,
    decoder_obs=decoder_obs,
    decoder_r=decoder_r,
    rssm=rssm,
    actor=actor,
    critic=critic
)
env.close()

[Epoch 0] ELBO: 902.70 | Actor: 0.25 | Critic: 1.64
[Epoch 1] ELBO: 639.60 | Actor: 0.39 | Critic: 2.77
[Epoch 2] ELBO: 516.12 | Actor: 0.40 | Critic: 5.46
[Epoch 3] ELBO: 425.52 | Actor: 0.34 | Critic: 8.47
[Epoch 4] ELBO: 589.68 | Actor: 0.35 | Critic: 10.92
[Epoch 5] ELBO: 457.91 | Actor: 0.36 | Critic: 20.20
[Epoch 6] ELBO: 366.66 | Actor: 0.31 | Critic: 39.98
[Epoch 7] ELBO: 267.97 | Actor: 0.42 | Critic: 51.31
[Epoch 8] ELBO: 368.64 | Actor: 0.43 | Critic: 56.83
[Epoch 9] ELBO: 255.15 | Actor: 0.38 | Critic: 46.16
[Epoch 10] ELBO: 239.39 | Actor: 0.31 | Critic: 42.86
[Epoch 11] ELBO: 286.59 | Actor: 0.38 | Critic: 50.22
[Epoch 12] ELBO: 215.89 | Actor: 0.36 | Critic: 57.63
[Epoch 13] ELBO: 207.47 | Actor: 0.38 | Critic: 57.44
[Epoch 14] ELBO: 185.93 | Actor: 0.46 | Critic: 81.13
[Epoch 15] ELBO: 179.80 | Actor: 0.39 | Critic: 115.11
[Epoch 16] ELBO: 155.46 | Actor: 0.53 | Critic: 186.59
[Epoch 17] ELBO: 120.53 | Actor: 0.45 | Critic: 309.65
[Epoch 18] ELBO: 112.69 | Actor: 0.43 |