In [2]:
import torch
import torch.nn as nn
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
    print("GPU selected")
else:
    device = 'cpu'
    print("CPU selected for debugging")

GPU selected


In [None]:
class Encoder(nn.Module): 
    def __init__(self, hidden_state_dim, latent_num_rows, latent_num_columns, num_filters_1, num_filters_2, hidden_layer_nodes, device='cpu'):
        """
        Takes obseravtion (image in this class) and maps it to a latent state representation through a CNN.
        """ 
        super().__init__()
        self.latent_size = latent_num_rows * latent_num_columns
        self.latent_num_rows = latent_num_rows
        self.latent_num_columns = latent_num_columns
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_filters_1, kernel_size=3, stride=1, padding=1, device=device),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=num_filters_1, out_channels=num_filters_2, kernel_size=3, stride=1, padding=1, device=device),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AdaptiveAvgPool2d((2, 2))
        ) 
        flattened_feature_size = num_filters_2 * 2 * 2
        total_in_features = flattened_feature_size + hidden_state_dim
        self.flatten = nn.Flatten(start_dim=2)
        self.latent_mapper = nn.Sequential(
            nn.Linear(in_features=total_in_features, out_features=hidden_layer_nodes, device=device),
            nn.SiLU(),
            nn.Linear(in_features=hidden_layer_nodes, out_features=self.latent_size, device=device)
        )

    def forward(self, hidden, observation):
        B, S, C, H, W = observation.shape
        observation = observation.view(B * S, C, H, W)
        features = self.feature_extractor(observation)
        _, out_C, out_H, out_W = features.shape
        features = features.view(B, S, out_C, out_H, out_W)
        features = self.flatten(features)
        input = torch.cat((features, hidden), dim=-1)
        logits = self.latent_mapper(input)
        return logits
    
    def encode(self, hidden_state, observation):
        B, S, _ = hidden_state.shape
        logits = self.forward(hidden_state, observation)
        logits_reshaped = logits.view(B, S, self.latent_num_rows, self.latent_num_columns)
        dist = torch.distributions.Categorical(logits=logits)
        sampled_idx = dist.sample()
        latent_state_flat = torch.nn.functional.one_hot(sampled_idx, num_classes=self.latent_size)
        latent_state = latent_state_flat.view(B, S, self.latent_num_rows, self.latent_num_columns)
        return latent_state, logits_reshaped

In [5]:
env_id = "CarRacing-v3"
env = gym.make(env_id, continuous=True)

In [6]:
action = env.action_space.sample()
print(action.shape)
observation, _ = env.reset(seed=42)
observation_, reward, term, trun, _ = env.step(action)
observation = observation.transpose(2,0,1)
observation_ = observation_.transpose(2,0,1)
print(observation.shape)
print(observation_.shape)
print(action)
print(reward)
print(term)

observation = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
observation_ = torch.tensor(observation_, dtype=torch.float32, device=device).unsqueeze(0)

observations = torch.cat([observation, observation_], dim=0).unsqueeze(0)
observations = torch.cat([observations, observations], dim=0)
print(observations.shape)

(3,)
(3, 96, 96)
(3, 96, 96)
[-0.16066334  0.939331    0.6702185 ]
6.967137809187279
False
torch.Size([2, 2, 3, 96, 96])


In [7]:
encoder = Encoder(100, 32, 32, 32, 16, 128, device=device)

In [8]:
hiddens = torch.zeros(2, 2, 100, dtype=torch.float32, device=device)
latent, logits = encoder.encode(hiddens, observations)
print(latent)
print(latent.shape) 
print(logits.shape)

torch.Size([2, 2, 64])
torch.Size([2, 2, 100])
tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]]], device='cuda:0')
t

In [9]:
class Decoder(nn.Module):
    """
    Takes a latent state and maps it to the image it was created by.
    """
    def __init__(self, latent_num_rows, latent_num_columns, observation_dim, hidden_state_dim, num_filters_1, num_filters_2, hidden_layer_nodes, device='cpu'):
        super().__init__()
        self.upscale_starting_dim = observation_dim[0] // 4
        self.num_filters_2 = num_filters_2

        self.hidden_dim = hidden_state_dim
        self.latent_row_dim = latent_num_rows
        self.latent_col_dim = latent_num_columns

        self.flatten = nn.Flatten(start_dim=1)
        self.upscaler = nn.Sequential(
            nn.Linear(in_features=latent_num_rows * latent_num_columns + hidden_state_dim, out_features=hidden_layer_nodes, device=device),
            nn.SiLU(),
            nn.Linear(in_features=hidden_layer_nodes, out_features=num_filters_2 * self.upscale_starting_dim * self.upscale_starting_dim, device=device),
            nn.SiLU()
        )
        self.image_builder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=num_filters_2, out_channels=num_filters_1, kernel_size=4, stride=2, padding=1, device=device),
            nn.SiLU(),
            nn.ConvTranspose2d(in_channels=num_filters_1, out_channels=6, kernel_size=4, stride=2, padding=1, device=device)
        )
        self.softplus = nn.Softplus()

    def forward(self, hidden: torch.tensor, latent: torch.tensor):
        B, S, _ = hidden.shape
        hidden = hidden.view(B * S, self.hidden_dim)
        latent = latent.view(B * S, self.latent_row_dim, self.latent_col_dim)

        latent = self.flatten(latent)
        input = torch.cat((hidden, latent), dim=-1)

        x = self.upscaler(input)
        x = x.view(-1, self.num_filters_2, self.upscale_starting_dim, self.upscale_starting_dim)

        obs_params = self.image_builder(x)
        mu, sigma_logits = torch.chunk(obs_params, chunks=2, dim=1)
        sigma = self.softplus(sigma_logits) + 1e-4

        _, C, H, W = mu.shape
        mu = mu.view(B, S, C, H, W)
        sigma = sigma.view(B, S, C, H, W)
        return mu, sigma
    
    def decode(self, hidden_state: torch.tensor, latent_state: torch.tensor):
        mu, sigma = self.forward(hidden_state, latent_state)
        dist = torch.distributions.Independent(torch.distributions.Normal(loc=mu, scale=sigma), 3)
        observation = dist.rsample()
        return observation, mu, sigma

In [10]:
decoder = Decoder(32, 32, (96, 96), 100, 32, 16, 128, device=device)

In [None]:
hiddens = torch.zeros(2, 2, 100, dtype=torch.float32, device=device)
observation_decode, mu, sigma = decoder.decode(hiddens, latent)

In [13]:
print(observation_decode.shape)
print(mu.shape)
print(sigma.shape)
print(observation_decode)

torch.Size([2, 2, 3, 96, 96])
torch.Size([2, 2, 3, 96, 96])
torch.Size([2, 2, 3, 96, 96])
tensor([[[[[ 6.7391e-01, -4.7964e-02, -5.7319e-01,  ..., -4.0895e-01,
            -5.4148e-01, -3.7130e-01],
           [-1.4428e+00, -4.7047e-01, -5.6548e-01,  ...,  5.4288e-01,
             1.3301e+00,  1.4066e+00],
           [ 6.3714e-01, -7.5179e-01, -9.0108e-01,  ..., -7.5696e-01,
             6.3301e-01, -5.8308e-02],
           ...,
           [-1.1645e-01,  1.7642e-01, -5.4705e-01,  ...,  2.0863e-01,
             1.2980e-01, -6.3522e-01],
           [ 6.8784e-01,  2.6905e-01,  5.9810e-01,  ..., -2.4836e-01,
             3.8437e-01, -1.5331e-01],
           [ 5.2590e-01,  5.5076e-02,  4.4435e-01,  ..., -1.1557e+00,
             5.0378e-02, -6.5019e-01]],

          [[-4.2576e-01,  5.0399e-01,  1.7178e-01,  ...,  6.1553e-01,
             1.0620e-01, -4.8056e-01],
           [ 3.6043e-01,  6.3996e-01, -3.9948e-01,  ...,  9.1081e-01,
             8.0584e-03,  1.5238e+00],
           [-5.4729e