# Latent-Space Reinforcement Learning

In traditional model-based reinforcement learning (MBRL) approaches such as Model-Based Policy Optimization (MBPO), the agent learns a predictive model of the environment in the raw state space. The model directly predicts the next state and reward from the current state and action. This approach struggles in two settings:
1. High-dimensional environments:
    - Raw observations contain redundant and unstructured information.
    - Learning directly from raw pixels is computationally expensive and requires extensive data.
2. Long-horizon planning:
    - Model errors accumulate over long rollouts, leading to unreliable predictions.

To overcome these challenges, latent-space MBRL learns a compressed representation of the state and models the dynamics in that representation rather than in the raw state space. The model only has to capture the features relevant to predicting future states and rewards, so this can yield more accurate predictions from fewer samples.

## Background

A latent space is a low-dimensional, compressed representation of the environment state that preserves only essential features. Instead of working with full observations, the model encodes states into this latent space and makes predictions there.
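Common methods for learning latent representations include:
- **Autoencoders**: Compress inputs with an encoder and reconstruct them with a decoder, using the reconstruction error as the training signal.
- **Variational Autoencoders (VAEs)**: Learn a probabilistic encoder that outputs a distribution over latent states, regularized toward a prior (typically a standard Gaussian).
- **Contrastive Learning**: Learn embeddings that maximize similarity between positive pairs of observations (e.g. two views of the same state) while minimizing similarity to negative pairs.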
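
The contrastive option can be made concrete with an InfoNCE-style loss. The following is a minimal sketch. It assumes that row `i` of `anchor` and row `i` of `positive` embed two views of the same underlying state, and it treats every other row in the batch as a negative. The function name `info_nce_loss` and the temperature of 0.1 are illustrative choices and are not part of this notebook.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(anchor: torch.Tensor, positive: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """InfoNCE: row i of `positive` is the positive for row i of `anchor`;
    all other rows in the batch act as negatives."""
    # L2-normalize so dot products are cosine similarities
    anchor = F.normalize(anchor, dim=-1)
    positive = F.normalize(positive, dim=-1)
    # (B, B) similarity matrix; the diagonal holds the positive pairs
    logits = anchor @ positive.T / temperature
    targets = torch.arange(anchor.shape[0], device=anchor.device)
    # Cross-entropy raises the diagonal and lowers the off-diagonal similarities
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
z = torch.randn(8, 16)
aligned = info_nce_loss(z, z + 0.01 * torch.randn(8, 16))  # near-identical pairs
random_pairs = info_nce_loss(z, torch.randn(8, 16))        # unrelated pairs
print(aligned.item(), random_pairs.item())
```

Well-aligned pairs should give a much lower loss than unrelated pairs.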

## Theory

Instead of modeling the environment in raw state space, we define four components:
1. **Encoder** $E(s_t)$: maps the high-dimensional state $s_t$ to a latent representation $z_t$.
2. **Latent Transition Model** $f(z_t, a_t)$: predicts the next latent state $z_{t+1}$.
3. **Reward Model** $R(z_t, a_t)$: predicts the expected reward $r_t$.
4. **Decoder** $D(z_t)$: optionally reconstructs $s_t$ from $z_t$.
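
The following shape-level sketch shows how the four components fit together for a single step. It assumes small MLPs and uses a 17-dimensional state with a 6-dimensional action, which are the sizes of `HalfCheetah-v5`, the environment used later in this notebook. The latent size of 8 and the `mlp` helper are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

state_dim, action_dim, latent_dim = 17, 6, 8

def mlp(in_dim: int, out_dim: int, hidden: int = 64) -> nn.Sequential:
    # Hypothetical helper: a small two-layer MLP used for every component
    return nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

encoder = mlp(state_dim, latent_dim)                    # E(s_t) -> z_t
transition = mlp(latent_dim + action_dim, latent_dim)   # f(z_t, a_t) -> z_{t+1}
reward_model = mlp(latent_dim + action_dim, 1)          # R(z_t, a_t) -> r_t
decoder = mlp(latent_dim, state_dim)                    # D(z_t) -> s_t (optional)

s_t = torch.randn(32, state_dim)   # a batch of 32 states
a_t = torch.randn(32, action_dim)  # and the actions taken in them

z_t = encoder(s_t)
za = torch.cat([z_t, a_t], dim=-1)
z_next_hat = transition(za)
r_hat = reward_model(za)
s_hat = decoder(z_t)

print(z_t.shape, z_next_hat.shape, r_hat.shape, s_hat.shape)
```

The transition and reward models never see the raw state. They operate only on $z_t$ and $a_t$.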

## Math

1. **Encoding observations into latent space.** We first encode a raw state $s_t$ into a low-dimensional latent state $z_t$:
   $$z_t = E_\phi(s_t)$$
   where $E_\phi$ is the encoder network with parameters $\phi$.
2. **Learning a latent transition model.** We train a dynamics model that operates in the latent space:
   $$\hat{z}_{t+1} = f_\theta(z_t, a_t)$$
   where $f_\theta$ is the latent transition model with parameters $\theta$.
3. **Predicting rewards in latent space.** Instead of predicting rewards from raw observations, we predict them from the latent state:
   $$\hat{r}_t = R_\psi(z_t, a_t)$$
   where $R_\psi$ is the reward model with parameters $\psi$.
4. **Reconstructing states (optional).** If needed, we reconstruct the original state from the latent representation:
   $$\hat{s}_t = D_\omega(z_t)$$
   where $D_\omega$ is the decoder network with parameters $\omega$.

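### Loss functions
To train the latent model, we minimize the following losses over sampled transitions:
- Latent transition loss, where the target is the encoded next state $z_{t+1} = E_\phi(s_{t+1})$:
  $$\mathcal{L}_{\text{trans}} = \sum_t \| \hat{z}_{t+1} - z_{t+1} \|^2$$
- Reconstruction loss (if using a decoder):
  $$\mathcal{L}_{\text{rec}} = \sum_t \| D_\omega(z_t) - s_t \|^2$$
- Reward prediction loss:
  $$\mathcal{L}_{\text{rew}} = \sum_t ( \hat{r}_t - r_t )^2$$

We optimize these losses jointly by minimizing $\mathcal{L} = \mathcal{L}_{\text{trans}} + \mathcal{L}_{\text{rec}} + \mathcal{L}_{\text{rew}}$. This encourages the latent space to retain the information needed to predict future latent states and rewards, and, with the decoder, to reconstruct observations. The reward and reconstruction terms matter because the transition loss alone can be driven to zero trivially by mapping every state to the same latent vector.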
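
The following minimal sketch computes the joint loss for one batch and runs a few optimizer steps. It makes three assumptions. First, linear layers stand in for the networks. Second, a fixed random batch stands in for replay-buffer samples. Third, the encoded next state is treated as a fixed target (stop-gradient), a design choice the equations above leave open. The function name `latent_model_loss` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def latent_model_loss(encoder, transition, reward_model, decoder, s, a, r, s_next):
    """Joint loss L = L_trans + L_rec + L_rew for one batch of transitions.

    F.mse_loss averages over the batch instead of summing over t, which only
    rescales each term. The target z_{t+1} is computed without gradients
    (an assumed stop-gradient choice).
    """
    z = encoder(s)
    with torch.no_grad():
        z_next_target = encoder(s_next)
    za = torch.cat([z, a], dim=-1)
    l_trans = F.mse_loss(transition(za), z_next_target)
    l_rec = F.mse_loss(decoder(z), s)
    l_rew = F.mse_loss(reward_model(za), r)
    return l_trans + l_rec + l_rew

torch.manual_seed(0)
state_dim, action_dim, latent_dim, batch = 17, 6, 8, 32
encoder = nn.Linear(state_dim, latent_dim)
transition = nn.Linear(latent_dim + action_dim, latent_dim)
reward_model = nn.Linear(latent_dim + action_dim, 1)
decoder = nn.Linear(latent_dim, state_dim)

params = [p for m in (encoder, transition, reward_model, decoder) for p in m.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-2)

# A fixed random batch stands in for samples from the replay buffer
s, a = torch.randn(batch, state_dim), torch.randn(batch, action_dim)
r, s_next = torch.randn(batch, 1), torch.randn(batch, state_dim)

losses = []
for _ in range(200):
    loss = latent_model_loss(encoder, transition, reward_model, decoder, s, a, r, s_next)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(f"initial loss {losses[0]:.3f}, final loss {losses[-1]:.3f}")
```

A single optimizer over all four modules is one way to realize "jointly". Separate optimizers stepped after one shared backward pass are equivalent for plain Adam.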

## Implementation

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
import gymnasium as gym
import matplotlib.pyplot as plt
```

```python
class ReplayBuffer:
    def __init__(self, max_size=int(1e6)):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        state = np.array(state, dtype=np.float32).flatten()
        action = np.array(action, dtype=np.float32).flatten()
        next_state = np.array(next_state, dtype=np.float32).flatten()
        reward = np.float32(reward)
        done = np.float32(done)

        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return (torch.tensor(np.stack(state), dtype=torch.float32),
                torch.tensor(np.stack(action), dtype=torch.float32),
                torch.tensor(reward, dtype=torch.float32).unsqueeze(1),
                torch.tensor(np.stack(next_state), dtype=torch.float32),
                torch.tensor(done, dtype=torch.float32).unsqueeze(1))

    def __len__(self):
        return len(self.buffer)
```

```python
class Encoder(nn.Module):
    """Stochastic (VAE-style) encoder: returns a sample z ~ N(mu, sigma^2)."""
    def __init__(self, state_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        std = torch.exp(0.5 * logvar)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        return mu + std * torch.randn_like(std)
```
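
The `Encoder` samples with the reparameterization trick, but the notebook never adds the KL term that a VAE uses to regularize the posterior, so the variance head is unconstrained. The following sketch shows the standard closed-form KL against $\mathcal{N}(0, I)$ and cross-checks it against `torch.distributions.kl_divergence`. Using it would require `Encoder.forward` to also return `mu` and `logvar`, which is a hypothetical change; the function name `gaussian_kl` is likewise illustrative.

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)),
    summed over latent dimensions and averaged over the batch."""
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum(dim=-1).mean()

torch.manual_seed(0)
mu = torch.randn(4, 3)
logvar = torch.randn(4, 3)

closed_form = gaussian_kl(mu, logvar)
# Cross-check with PyTorch's registered Normal-vs-Normal KL
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * logvar)),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum(dim=-1).mean()
print(closed_form.item(), reference.item())
```

In training, this term would be added to the model loss with a weight, for example `loss + beta * gaussian_kl(mu, logvar)`. The coefficient `beta` is a hypothetical hyperparameter.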


```python
class LatentTransitionModel(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(LatentTransitionModel, self).__init__()
        self.fc1 = nn.Linear(latent_dim+action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, latent_dim)

    def forward(self, z, a):
        x = torch.cat([z, a], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```
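
The transition model is trained here on one-step targets only, while the introduction notes that errors compound over long rollouts. A common remedy is a multi-step, open-loop prediction loss, in which the model's own predictions are fed back in for `K` steps and compared with the encoded true states. The following is a minimal sketch. It assumes a hypothetical sequence sampler that returns `(B, K+1, state_dim)` state segments, which the `ReplayBuffer` above does not provide because it stores single transitions, and it uses linear stand-ins for the networks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def multi_step_latent_loss(encoder, transition, states, actions):
    """Open-loop K-step latent prediction loss.

    states:  (B, K+1, state_dim) consecutive states from trajectory segments
    actions: (B, K, action_dim)  actions taken between them
    """
    K = actions.shape[1]
    z = encoder(states[:, 0])
    with torch.no_grad():
        targets = encoder(states[:, 1:])  # encoded true states, (B, K, latent_dim)
    loss = 0.0
    for k in range(K):
        # Feed back the model's own prediction, as in an imagined rollout
        z = transition(torch.cat([z, actions[:, k]], dim=-1))
        loss = loss + F.mse_loss(z, targets[:, k])
    return loss / K

torch.manual_seed(0)
state_dim, action_dim, latent_dim = 17, 6, 8
encoder = nn.Linear(state_dim, latent_dim)
transition = nn.Linear(latent_dim + action_dim, latent_dim)

states = torch.randn(16, 6, state_dim)    # B=16 segments of K+1=6 states
actions = torch.randn(16, 5, action_dim)  # K=5 actions per segment
loss = multi_step_latent_loss(encoder, transition, states, actions)
loss.backward()
print(loss.item())
```

With `K = 1` this reduces to the one-step transition loss used in `LatentMBPO.train`.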

```python
class RewardModel(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(RewardModel, self).__init__()
        self.fc1 = nn.Linear(latent_dim+action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, z, a):
        x = torch.cat([z, a], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```

```python
class Actor(nn.Module):
    def __init__(self, latent_dim, action_dim, max_action, hidden_dim=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action

    def forward(self, latent):
        x = F.relu(self.fc1(latent))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = self.log_std(x)
        # Clamp log-std for numerical stability
        log_std = torch.clamp(log_std, -20, 2)
        return mean, log_std

    def sample(self, latent):
        mean, log_std = self.forward(latent)
        std = torch.exp(log_std)
        normal = torch.distributions.Normal(mean, std)
        x = normal.rsample()  # reparameterized sample, so gradients flow through the action
        y = torch.tanh(x)     # squash to (-1, 1)
        action = y * self.max_action
        # Change-of-variables correction for tanh (the constant -log(max_action) term is omitted)
        log_prob = normal.log_prob(x) - torch.log(1 - y.pow(2) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob
```
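
`Actor.sample` corrects the Gaussian log-density for the tanh squashing by using $\log \pi(a \mid z) = \sum_i \left[\log \mathcal{N}(x_i) - \log(1 - \tanh^2 x_i)\right]$. The following sketch checks that manual correction against PyTorch's `TransformedDistribution` with a `TanhTransform`, assuming `max_action = 1` as in HalfCheetah. For other values, the scaling adds a constant $-\log(\text{max\_action})$ per action dimension, which `Actor.sample` leaves out; that constant does not change the gradient with respect to the policy parameters.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)
base = Normal(torch.zeros(5, 2), torch.full((5, 2), 0.5))

x = base.rsample()   # pre-squash Gaussian sample
y = torch.tanh(x)    # squashed action in (-1, 1), i.e. max_action = 1

# Manual correction, written as in Actor.sample (including the 1e-6 stabilizer)
manual = (base.log_prob(x) - torch.log(1 - y.pow(2) + 1e-6)).sum(dim=-1)

# Reference: PyTorch's tanh-transformed Normal evaluated at the same actions
squashed = TransformedDistribution(base, [TanhTransform()])
reference = squashed.log_prob(y).sum(dim=-1)

print(torch.max(torch.abs(manual - reference)).item())
```

The two agree up to the small bias introduced by the `1e-6` stabilizer.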



```python
class LatentMBPO:
    def __init__(self, state_dim, action_dim, latent_dim, max_action):
        self.encoder = Encoder(state_dim, latent_dim)
        self.dynamics_model = LatentTransitionModel(latent_dim, action_dim)
        self.reward_model = RewardModel(latent_dim, action_dim)
        # The policy acts on latent states, so its input size is latent_dim (not state_dim)
        self.policy = Actor(latent_dim, action_dim, max_action)

        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=3e-4)
        self.dynamics_model_optimizer = optim.Adam(self.dynamics_model.parameters(), lr=3e-4)
        self.reward_model_optimizer = optim.Adam(self.reward_model.parameters(), lr=3e-4)

    def select_action(self, state):
        # No gradients are needed when acting in the environment
        with torch.no_grad():
            z = self.encoder(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            action, _ = self.policy.sample(z)
        return action.cpu().numpy().squeeze(0)

    def train(self, real_buffer, batch_size=256):
        state, action, reward, next_state, done = real_buffer.sample(batch_size)

        z = self.encoder(state)
        # Treat the encoded next state as a fixed target (stop-gradient)
        with torch.no_grad():
            z_next = self.encoder(next_state)

        z_pred = self.dynamics_model(z, action)
        r_pred = self.reward_model(z, action)

        transition_loss = F.mse_loss(z_pred, z_next)
        reward_loss = F.mse_loss(r_pred, reward)
        # Without a reconstruction or KL term, only the reward loss keeps z from collapsing
        loss = transition_loss + reward_loss

        # Joint update: one backward pass, then step the encoder, dynamics and reward models
        self.encoder_optimizer.zero_grad()
        self.dynamics_model_optimizer.zero_grad()
        self.reward_model_optimizer.zero_grad()
        loss.backward()
        self.encoder_optimizer.step()
        self.dynamics_model_optimizer.step()
        self.reward_model_optimizer.step()

        # Only the latent model is trained here; the policy is not updated in this method
        return transition_loss.item(), reward_loss.item()
```
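
MBPO uses its learned model to generate short rollouts branched from states in the real buffer and trains the policy on that synthetic data. `LatentMBPO` above learns the latent model but does not yet perform this step. The following is a minimal sketch of such a rollout in latent space. It assumes linear stand-ins for the transition and reward models and a tanh-MLP in place of `Actor`. The function name `latent_rollout` and the horizon of 5 are hypothetical. The sketch ignores termination, which a full implementation would need to model.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def latent_rollout(z0, policy_fn, transition, reward_model, horizon=5):
    """Roll the learned latent model forward from a batch of start latents.
    Returns (z, a, r, z_next) stacked as (horizon, B, ...) for a model replay buffer."""
    zs, actions, rewards, next_zs = [], [], [], []
    z = z0
    for _ in range(horizon):
        a = policy_fn(z)
        za = torch.cat([z, a], dim=-1)
        z_next = transition(za)
        r = reward_model(za)
        zs.append(z)
        actions.append(a)
        rewards.append(r)
        next_zs.append(z_next)
        z = z_next  # continue from the model's own prediction
    return torch.stack(zs), torch.stack(actions), torch.stack(rewards), torch.stack(next_zs)

torch.manual_seed(0)
latent_dim, action_dim, batch = 8, 6, 32
transition = nn.Linear(latent_dim + action_dim, latent_dim)
reward_model = nn.Linear(latent_dim + action_dim, 1)
policy = nn.Sequential(nn.Linear(latent_dim, action_dim), nn.Tanh())  # stand-in for Actor

z0 = torch.randn(batch, latent_dim)  # in practice: encoder(states sampled from the real buffer)
z, a, r, z_next = latent_rollout(z0, policy, transition, reward_model, horizon=5)
print(z.shape, a.shape, r.shape, z_next.shape)
```

Keeping the horizon short limits how far compounding model error can drift, which is the motivation for MBPO's short branched rollouts.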

```python
env = gym.make('HalfCheetah-v5')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
latent_dim = 17

agent = LatentMBPO(state_dim, action_dim, latent_dim, max_action)
real_buffer = ReplayBuffer()

num_episodes = 500
rewards = []

for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0.0

    for step in range(1000):
        action = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        # Store `terminated` (not `truncated`) so time-limit cutoffs are not treated as terminal
        real_buffer.add(state, action, reward, next_state, terminated)
        state = next_state

        # Update the latent model once enough real transitions have been collected
        if len(real_buffer) > 10000:
            agent.train(real_buffer)

        # End the episode; env.reset() is called at the start of the next one
        if terminated or truncated:
            break

    rewards.append(total_reward)
    print(f'Episode: {episode+1}, Reward: {total_reward:.2f}')

plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('HalfCheetah-v5')
plt.show()
```



