# Soft Actor-Critic with Autoencoders (SAC-AE)

## Autoencoders

An autoencoder is a neural network that compresses high-dimensional data into a lower-dimensional latent representation and then reconstructs the original data from that compressed form.
It has two components:
- Encoder: maps the input $x$ to a latent code $z=E_{\phi}(x)$.
- Decoder: reconstructs the input as $\hat{x}=D_{\theta}(z)$.

### Mathematical formulation
Training minimizes the reconstruction error over the data distribution $\mathcal{D}$, typically the mean squared error (MSE):
$$\min_{\phi,\theta}\ \mathbb{E}_{x\sim\mathcal{D}}\left[\lVert x-D_{\theta}(E_{\phi}(x))\rVert_2^2\right]$$
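Regularization terms can be added, for example an $\ell_1$ sparsity penalty on the latent code and weight decay on the decoder parameters:
$$L_{AE}(\phi,\theta)=\lVert x-\hat{x}\rVert_2^2+\lambda_{z}\lVert z\rVert_1+\lambda_{\theta}\lVert\theta\rVert_2^2$$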
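A minimal NumPy sketch of $L_{AE}$ for a single sample, taking the latent penalty as an $\ell_1$ norm. The function name `ae_loss` and the default values of `lam_z` and `lam_theta` are illustrative assumptions, not settings taken from SAC-AE.

```python
import numpy as np

def ae_loss(x, x_hat, z, theta, lam_z=1e-3, lam_theta=1e-4):
    # ||x - x_hat||_2^2: squared reconstruction error
    recon = np.sum((x - x_hat) ** 2)
    # lam_z * ||z||_1: sparsity penalty on the latent code
    latent_pen = lam_z * np.sum(np.abs(z))
    # lam_theta * ||theta||_2^2: weight decay on the (flattened) decoder parameters
    weight_pen = lam_theta * np.sum(theta ** 2)
    return recon + latent_pen + weight_pen

# Reconstruction error 1.0, latent L1 norm 1.0, parameter squared norm 2.0
loss = ae_loss(np.array([1.0, 2.0]), np.array([0.0, 2.0]),
               np.array([0.5, -0.5]), np.array([1.0, 1.0]), lam_z=0.1, lam_theta=0.01)
```

With these numbers the three terms are 1.0, 0.1 and 0.02, so the regularizers only nudge the objective; in practice the λ values are tuned so reconstruction dominates.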

1. **Latent Feature Extraction**
   Autoencoders learn compressed representations that capture the essential features of the data. High-dimensional inputs such as images are encoded into low-dimensional latent vectors that preserve the information needed to reconstruct them.
2. **Why do autoencoders help in RL?**
   RL from raw pixels is sample-inefficient because the reward alone is a sparse, weak signal for learning visual features. The reconstruction loss supplies a dense, self-supervised signal on every training batch, guiding the network toward stable and informative features.
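For a sense of scale, using the 96 × 96 RGB CarRacing frames and the 50-dimensional latent code from the code cells in this notebook:

$$96 \times 96 \times 3 = 27{,}648 \text{ pixel values} \;\longrightarrow\; 50 \text{ latent features}, \qquad \frac{27{,}648}{50} \approx 553$$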

## Theoretical Background of SAC-AE
Running standard SAC directly on pixels is sample-inefficient and needs large amounts of data: sparse rewards are a poor signal for shaping the CNN feature extractor, so training is slow and unstable.
SAC-AE integrates an autoencoder into SAC. Images $s_t$ are encoded into latent states $z_t = E_{\phi}(s_t)$, which the actor and critic networks take as input. The autoencoder adds supervision through the reconstruction loss, which stabilizes and accelerates training.
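A minimal PyTorch sketch of how gradients flow through the shared encoder, using hypothetical toy linear modules in place of the conv encoder and MLP heads, and random tensors in place of real data. The critic loss and the β-weighted reconstruction loss both update the encoder, while the actor is fed a detached latent; detaching the actor's input mirrors the SAC-AE design, in which actor gradients are not propagated into the encoder. The entropy term of the actor loss is omitted here for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes and linear modules standing in for the real networks
obs_dim, latent_dim, action_dim = 8, 4, 2
beta = 1.0  # weight on the reconstruction term (assumed value)
encoder = nn.Linear(obs_dim, latent_dim)
decoder = nn.Linear(latent_dim, obs_dim)
critic = nn.Linear(latent_dim + action_dim, 1)
actor = nn.Linear(latent_dim, action_dim)

s = torch.randn(16, obs_dim)       # batch of flattened observations
a = torch.randn(16, action_dim)    # actions from a replay buffer
q_target = torch.randn(16, 1)      # placeholder soft Bellman targets

# Critic loss + beta * reconstruction loss: both terms backpropagate into the encoder
z = encoder(s)
critic_loss = F.mse_loss(critic(torch.cat([z, a], dim=-1)), q_target)
recon_loss = F.mse_loss(decoder(z), s)
(critic_loss + beta * recon_loss).backward()
enc_grad_joint = encoder.weight.grad.clone()

# Actor loss on a detached latent: the actor is trained, the encoder is not
encoder.zero_grad()
z_det = encoder(s).detach()
actor_loss = -critic(torch.cat([z_det, torch.tanh(actor(z_det))], dim=-1)).mean()
actor_loss.backward()
```

After the last backward pass the encoder has no gradient, while the actor does.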
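SAC-AE combines the SAC objective with the autoencoder reconstruction loss, weighted by $\beta$:
$$J_{\text{SAC-AE}} = J_{\text{SAC}}(\phi,\psi)+\beta\,L_{AE}(\phi,\theta)$$
where $\phi$ are the shared encoder parameters, $\psi$ the actor and critic parameters, and $\theta$ the decoder parameters, as in $L_{AE}$ above.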
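- Critic loss: minimizes the soft Bellman error on encoded states.
- Actor loss: maximizes the Q-value of the policy's actions plus the policy entropy.
- Entropy temperature loss: tunes the temperature $\alpha$ that weights the entropy bonus, keeping the policy entropy near a target and thereby controlling exploration.
- Autoencoder loss: the reconstruction loss forces the latent representation to preserve image information.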
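The four losses can be written as plain functions of precomputed quantities. This is a minimal NumPy sketch in standard SAC form; the function names and the default `gamma` and `alpha` values are assumptions for illustration, and in a real agent these expressions are computed on tensors so they can be differentiated.

```python
import numpy as np

def soft_bellman_target(r, done, q_next, logp_next, gamma=0.99, alpha=0.1):
    # y = r + gamma * (1 - done) * (Q_target(z', a') - alpha * log pi(a' | z'))
    return r + gamma * (1.0 - done) * (q_next - alpha * logp_next)

def critic_loss(q, y):
    # Mean squared soft Bellman error
    return np.mean((q - y) ** 2)

def actor_loss(q_pi, logp_pi, alpha=0.1):
    # Minimizing this maximizes Q(z, a) + alpha * entropy, since entropy ~ -log pi
    return np.mean(alpha * logp_pi - q_pi)

def temperature_loss(log_alpha, logp_pi, target_entropy):
    # Gradient descent raises alpha when the entropy (-log pi) is below the target
    return -np.mean(log_alpha * (logp_pi + target_entropy))

def reconstruction_loss(s, s_hat):
    # Pixel-wise mean squared error between frames and reconstructions
    return np.mean((s - s_hat) ** 2)

# The second transition is terminal, so its target is just its reward
y = soft_bellman_target(np.array([1.0, 0.5]), np.array([0.0, 1.0]),
                        np.array([2.0, 3.0]), np.array([-1.0, -1.0]))
```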

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, latent_dim=50):
        super(Encoder, self).__init__()

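        # For a 3x96x96 CarRacing frame: conv1 (k=3, s=2) gives 47x47,
        # and each stride-1 conv shrinks it by 2: 45x45 -> 43x43 -> 41x41
        self.conv1 = nn.Conv2d(3, 32, 3, 2)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 32, 3, 1)
        self.conv4 = nn.Conv2d(32, 32, 3, 1)

        # Flattened conv output: 32 * 41 * 41 = 53792 features
        self.fc = nn.Linear(32*41*41, latent_dim)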
        self.ln = nn.LayerNorm(latent_dim)

    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.ln(x)
        x = torch.tanh(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self, latent_dim=50):
        super(Decoder, self).__init__()

        # Mirror of the encoder: latent -> 32x41x41 feature map
        self.fc = nn.Linear(latent_dim, 32*41*41)

        # Spatial sizes: 41 -> 43 -> 45 -> 47 -> 96 (stride 2 with output_padding=1)
        self.deconv1 = nn.ConvTranspose2d(32, 32, 3, 1)
        self.deconv2 = nn.ConvTranspose2d(32, 32, 3, 1)
        self.deconv3 = nn.ConvTranspose2d(32, 32, 3, 1)
        self.deconv4 = nn.ConvTranspose2d(32, 3, 3, 2, output_padding=1)

    def forward(self, z):
        z = F.relu(self.fc(z))
        z = z.view(z.size(0), 32, 41, 41)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        z = F.relu(self.deconv3(z))
        z = self.deconv4(z)

        # Sigmoid matches observations scaled to [0, 1] by dividing by 255
        recon = torch.sigmoid(z)
        return recon
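The encoder's `fc` layer must accept exactly the number of features the conv stack produces. A small pure-Python helper that applies PyTorch's documented output-size formulas for `Conv2d` and `ConvTranspose2d` (no padding, dilation 1) to the layer settings above makes the check explicit; the helper names are illustrative.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Conv2d with dilation 1: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0):
    # ConvTranspose2d with dilation 1: (size - 1)*stride - 2*padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder stack (kernel, stride) applied to a 96x96 CarRacing frame
h = 96
for k, s in [(3, 2), (3, 1), (3, 1), (3, 1)]:
    h = conv2d_out(h, k, s)
flat_features = 32 * h * h
print(h, flat_features)  # 41 53792

# Decoder stack (kernel, stride, output_padding) starting from the 41x41 map
d = h
for k, s, op in [(3, 1, 0), (3, 1, 0), (3, 1, 0), (3, 2, 1)]:
    d = conv_transpose2d_out(d, k, s, output_padding=op)
print(d)  # 96
```

So for 96×96 inputs the flattened size is `32*41*41`, and the decoder maps a 41×41 map back to a full 96×96 frame.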

In [2]:
class PolicyNetwork(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 1024), nn.ReLU(),
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, 2*action_dim)
        )

    def forward(self, z):
        out = self.net(z)
        mu, log_std = out.chunk(2, dim=-1)  # split into mean and log-std halves
        log_std = torch.clamp(log_std, -10, 2)  # bound the std for numerical stability
        std = torch.exp(log_std)
        return mu, std
    
class QNetwork(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(QNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim+action_dim, 1024), nn.ReLU(),
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, 1)
        )

    def forward(self, z, a):
        z_a = torch.cat([z,a], dim=-1)
        q = self.net(z_a)
        return q
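`PolicyNetwork` returns an unbounded Gaussian `(mu, std)`, while CarRacing expects bounded actions. The usual SAC recipe samples with the reparameterization trick, squashes through `tanh`, corrects the log-probability for the change of variables, and then rescales to the environment bounds. A minimal NumPy sketch; the function names are hypothetical, and the CarRacing bounds are steering in [-1, 1] with gas and brake in [0, 1].

```python
import numpy as np

def squashed_gaussian(mu, std, eps):
    # Reparameterization: u = mu + std * eps with eps ~ N(0, I)
    u = mu + std * eps
    a = np.tanh(u)  # squash into (-1, 1)
    # Log-density of the diagonal Gaussian at u
    log_normal = -0.5 * (((u - mu) / std) ** 2 + 2 * np.log(std) + np.log(2 * np.pi))
    # Tanh change of variables: subtract log(1 - tanh(u)^2); 1e-6 avoids log(0)
    log_prob = np.sum(log_normal - np.log(1 - a ** 2 + 1e-6), axis=-1)
    return a, log_prob

def to_env_bounds(a, low, high):
    # Affine map from [-1, 1] to [low, high] per action dimension
    return low + (a + 1.0) * 0.5 * (high - low)

rng = np.random.default_rng(0)
mu, std = np.zeros((4, 3)), np.ones((4, 3))  # stand-ins for PolicyNetwork outputs
a, logp = squashed_gaussian(mu, std, rng.standard_normal((4, 3)))
env_a = to_env_bounds(a, np.array([-1.0, 0.0, 0.0]), np.array([1.0, 1.0, 1.0]))
```

Storing the squashed action `a` (rather than the rescaled one) in the replay buffer keeps the critic's inputs in the same range the policy produces.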

In [3]:
import numpy as np

class ReplayBuffer:
    def __init__(self, capacity, image_shape, action_dim):
        self.capacity = capacity
        self.ptr = 0
        self.size = 0

        # Store raw frames as uint8 (1 byte per value); convert to float only when sampling
        self.obs_buf = np.zeros((capacity, *image_shape), dtype=np.uint8)
        self.next_obs_buf = np.zeros((capacity, *image_shape), dtype=np.uint8)
        self.acts_buf = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rews_buf = np.zeros(capacity, dtype=np.float32)
        self.done_buf = np.zeros(capacity, dtype=np.float32)

    def add(self, obs, action, reward, next_obs, done):
        idx = self.ptr
        self.obs_buf[idx] = obs
        self.next_obs_buf[idx] = next_obs
        self.acts_buf[idx] = action
        self.rews_buf[idx] = reward
        self.done_buf[idx] = done

        self.ptr = (self.ptr+1) % self.capacity
        self.size = min(self.size+1, self.capacity)

    def sample(self, batch_size=32):
        idxs = np.random.choice(self.size, batch_size, replace=False)
        return dict(obs=self.obs_buf[idxs],
                    next_obs=self.next_obs_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])
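Pixel replay buffers are dominated by the two frame arrays, so their dtype matters. A quick back-of-the-envelope check for 100000 transitions of 96×96×3 frames (the helper `buffer_bytes` is hypothetical):

```python
import numpy as np

def buffer_bytes(capacity, image_shape, dtype, frames_per_transition=2):
    # obs_buf and next_obs_buf each hold `capacity` frames of `image_shape`
    return capacity * frames_per_transition * int(np.prod(image_shape)) * np.dtype(dtype).itemsize

f32 = buffer_bytes(100_000, (96, 96, 3), np.float32)
u8 = buffer_bytes(100_000, (96, 96, 3), np.uint8)
print(f"float32: {f32 / 1e9:.1f} GB, uint8: {u8 / 1e9:.1f} GB")  # float32: 22.1 GB, uint8: 5.5 GB
```

Keeping frames as `uint8` and dividing by 255 after sampling cuts memory by 4× with no loss of information, since the environment already returns 8-bit pixels.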
    

In [4]:
import gymnasium as gym
import torch.optim as optim

In [9]:
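env = gym.make("CarRacing-v3")  # continuous actions: [steering, gas, brake]
action_dim = env.action_space.shape[0]
low, high = torch.as_tensor(env.action_space.low), torch.as_tensor(env.action_space.high)
encoder, decoder, actor = Encoder(50), Decoder(50), PolicyNetwork(50, action_dim)
# Simplification: one critic plus a target copy (standard SAC uses two critics)
critic, critic_target = QNetwork(50, action_dim), QNetwork(50, action_dim)
critic_target.load_state_dict(critic.state_dict())
buffer = ReplayBuffer(100000, (96, 96, 3), action_dim)
# The encoder is updated by both the critic loss and the reconstruction loss
ae_optim = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
actor_optim = optim.Adam(actor.parameters(), lr=3e-4)
critic_optim = optim.Adam(list(critic.parameters()) + list(encoder.parameters()), lr=3e-4)
log_alpha = torch.zeros(1, requires_grad=True)  # learned entropy temperature (log scale)
alpha_optim = optim.Adam([log_alpha], lr=3e-4)
target_entropy, gamma, tau, batch_size = -action_dim, 0.99, 0.005, 32

def to_tensor(obs):
    # uint8 HWC frame or NHWC batch -> float32 NCHW in [0, 1]
    x = torch.as_tensor(obs, dtype=torch.float32)
    return (x.unsqueeze(0) if x.dim() == 3 else x).permute(0, 3, 1, 2) / 255.

def sample_action(z):
    # Reparameterized tanh-squashed Gaussian: action in [-1, 1] and its log-probability
    mu, std = actor(z)
    dist = torch.distributions.Normal(mu, std)
    u = dist.rsample()
    a = torch.tanh(u)
    return a, (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(-1, keepdim=True)

def update(optimizer, loss):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

for ep in range(500):
    obs, _ = env.reset()
    total_reward = 0.0
    for t in range(1000):
        with torch.no_grad():
            a = sample_action(encoder(to_tensor(obs)))[0].squeeze(0)
        # Rescale from [-1, 1] to the env bounds (gas and brake lie in [0, 1])
        env_action = (low + (a + 1) * 0.5 * (high - low)).numpy()
        next_obs, reward, terminated, truncated, _ = env.step(env_action)
        buffer.add(obs, a.numpy(), reward, next_obs, float(terminated))
        obs, total_reward = next_obs, total_reward + reward
        if buffer.size >= batch_size:
            b = buffer.sample(batch_size)
            o, o2 = to_tensor(b["obs"]), to_tensor(b["next_obs"])
            acts = torch.as_tensor(b["acts"])
            rews, done = (torch.as_tensor(b[k]).unsqueeze(-1) for k in ("rews", "done"))
            alpha = log_alpha.exp().detach()
            # Critic: soft Bellman error; gradients also flow into the encoder
            with torch.no_grad():
                z2 = encoder(o2)  # online encoder for next states (simplification)
                a2, logp2 = sample_action(z2)
                target = rews + gamma * (1 - done) * (critic_target(z2, a2) - alpha * logp2)
            update(critic_optim, F.mse_loss(critic(encoder(o), acts), target))
            # Actor and temperature: the latent is detached so the actor does not train the encoder
            z = encoder(o).detach()
            a_pi, logp = sample_action(z)
            update(actor_optim, (alpha * logp - critic(z, a_pi)).mean())
            update(alpha_optim, -(log_alpha * (logp.detach() + target_entropy)).mean())
            # Autoencoder: reconstruction loss updates encoder and decoder
            update(ae_optim, F.mse_loss(decoder(encoder(o)), o))
            # Polyak averaging of the target critic
            with torch.no_grad():
                for p, tp in zip(critic.parameters(), critic_target.parameters()):
                    tp.mul_(1 - tau).add_(tau * p)
        if terminated or truncated:
            break
    print(f"Episode {ep+1}: {total_reward:.1f}")
env.close()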

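Running this cell with the encoder's `fc` layer sized for `32*36*36` raises `RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x53792 and 41472x50)`. For a 3×96×96 CarRacing frame the convolution stack produces a 32×41×41 map (96 → 47 → 45 → 43 → 41), i.e. 53792 features, whereas `32*36*36` is only 41472; the encoder's `fc` layer and the decoder's mirrored layer must both use `32*41*41`.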
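Standard SAC, which this notebook applies to latent states, computes its Bellman targets with a slowly moving copy of the critic, updated by Polyak (soft) averaging: target parameters ← τ · online parameters + (1 − τ) · target parameters. A minimal NumPy sketch; the name `polyak_update`, the list-of-arrays parameter format, and the default τ = 0.005 are assumptions for illustration.

```python
import numpy as np

def polyak_update(target_params, online_params, tau=0.005):
    # theta_target <- tau * theta_online + (1 - tau) * theta_target, in place
    for t, p in zip(target_params, online_params):
        t *= (1.0 - tau)
        t += tau * p

target = [np.zeros(3)]
online = [np.ones(3)]
polyak_update(target, online, tau=0.1)  # target moves 10% of the way toward online
```

A small τ keeps the targets nearly stationary between updates, which is what stabilizes the critic's regression.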