In [18]:
import torch
import torch.nn as nn
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

In [19]:
if torch.cuda.is_available():
    device = 'cuda'
    print("GPU selected")
else:
    device = 'cpu'
    print("CPU selected for debugging")

GPU selected


In [57]:
class Encoder(nn.Module): 
    def __init__(self, hidden_state_dim, latent_num_rows, latent_num_columns, num_filters_1, num_filters_2, hidden_layer_nodes, device='cpu'):
        """
        Takes obseravtion (image in this class) and maps it to a latent state representation through a CNN.
        """ 
        super().__init__()
        self.latent_size = latent_num_rows * latent_num_columns
        self.latent_num_rows = latent_num_rows
        self.latent_num_columns = latent_num_columns
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_filters_1, kernel_size=3, stride=1, padding=1, device=device),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=num_filters_1, out_channels=num_filters_2, kernel_size=3, stride=1, padding=1, device=device),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AdaptiveAvgPool2d((2, 2))
        ) 
        flattened_feature_size = num_filters_2 * 2 * 2
        total_in_features = flattened_feature_size + hidden_state_dim
        self.flatten = nn.Flatten(start_dim=2)
        self.latent_mapper = nn.Sequential(
            nn.Linear(in_features=total_in_features, out_features=hidden_layer_nodes, device=device),
            nn.SiLU(),
            nn.Linear(in_features=hidden_layer_nodes, out_features=self.latent_size, device=device)
        )

    def forward(self, hidden, observation):
        B, S, C, H, W = observation.shape
        observation = observation.view(B * S, C, H, W)
        features = self.feature_extractor(observation)
        _, out_C, out_H, out_W = features.shape
        features = features.view(B, S, out_C, out_H, out_W)
        features = self.flatten(features)
        print(features.shape)
        print(hidden.shape)
        input = torch.cat((features, hidden), dim=-1)
        logits = self.latent_mapper(input)
        return logits
    
    def encode(self, hidden_state, observation):
        B, S, _ = hidden_state.shape
        logits = self.forward(hidden_state, observation)
        logits_reshaped = logits.view(B, S, self.latent_num_rows, self.latent_num_columns)
        dist = torch.distributions.Categorical(logits=logits)
        sampled_idx = dist.sample()
        latent_state_flat = torch.nn.functional.one_hot(sampled_idx, num_classes=self.latent_size)
        latent_state = latent_state_flat.view(B, S, self.latent_num_rows, self.latent_num_columns)
        return latent_state, logits_reshaped

In [21]:
env_id = "CarRacing-v3"
env = gym.make(env_id, continuous=True)

In [52]:
action = env.action_space.sample()
print(action.shape)
observation, _ = env.reset(seed=42)
observation_, reward, term, trun, _ = env.step(action)
observation = observation.transpose(2,0,1)
observation_ = observation_.transpose(2,0,1)
print(observation.shape)
print(observation_.shape)
print(action)
print(reward)
print(term)

observation = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
observation_ = torch.tensor(observation_, dtype=torch.float32, device=device).unsqueeze(0)

observations = torch.cat([observation, observation_], dim=0).unsqueeze(0)
observations = torch.cat([observations, observations], dim=0)
print(observations.shape)

(3,)
(3, 96, 96)
(3, 96, 96)
[0.7508009  0.20699303 0.16375159]
6.967137809187279
False
torch.Size([2, 2, 3, 96, 96])


In [58]:
encoder = Encoder(100, 32, 32, 32, 16, 128, device=device)

In [59]:
hiddens = torch.zeros(2, 2, 100, dtype=torch.float32, device=device)
latent, logits = encoder.encode(hiddens, observations)
print(latent)
print(latent.shape) 
print(logits.shape)

torch.Size([2, 2, 64])
torch.Size([2, 2, 100])
tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]]], device='cuda:0')
t