In [1]:
import torch
import torch.nn as nn
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Report which backend was picked, keyed by the chosen device string.
print({'cuda': "GPU selected", 'cpu': "CPU selected for debugging"}[device])

GPU selected


In [18]:
class SequenceModel(nn.Module):
    """GRU that advances a recurrent hidden state from a latent grid and an action.

    The latent grid is flattened, concatenated with the action along the
    feature axis, and fed through an ``nn.GRU`` (``batch_first=True``). Only
    the final hidden state is returned; the per-step GRU outputs are discarded.

    Args:
        latent_num_rows: Number of rows in the latent grid.
        latent_num_columns: Number of columns in the latent grid.
        hidden_dim: Size of the GRU hidden state.
        action_dim: Number of features in the action vector.
        num_layers: Number of stacked GRU layers (keyword-only).
        device: Device on which the GRU parameters are created (keyword-only).
    """

    def __init__(self, latent_num_rows, latent_num_columns, hidden_dim, action_dim, *, num_layers=1, device='cpu'):
        super().__init__()
        # Flattened size of one latent grid (rows * columns).
        self.latent_dim = latent_num_columns * latent_num_rows
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device

        # Collapse every axis from index 2 onward, i.e. (B, S, LR, LC) -> (B, S, LR * LC).
        self.flatten = nn.Flatten(start_dim=2)
        self.GRU = nn.GRU(
            input_size=self.latent_dim + action_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            device=device
        )

    def forward(
        self,
        last_latent_state: torch.Tensor,
        last_hidden_state: torch.Tensor,
        last_action: torch.Tensor,
    ) -> torch.Tensor:
        """Run the GRU over the given latent/action sequence and return the new hidden state.

        Args:
            last_latent_state: Latent grids; flattened from axis 2, so expected
                as (B, S, LR, LC).
            last_hidden_state: Initial GRU hidden state, passed straight to
                ``nn.GRU`` (which expects (num_layers, B, hidden_dim)).
            last_action: Actions, concatenated on the last axis with the
                flattened latent, so expected as (B, S, action_dim).

        Returns:
            The final hidden state produced by ``nn.GRU``.
        """
        last_latent_state = self.flatten(last_latent_state)
        input_tensor = torch.cat((last_latent_state, last_action), dim=-1) # (B, S, LR * LC + a)
        # The per-step outputs are not needed; keep only the final hidden state.
        _, hidden = self.GRU(input_tensor, last_hidden_state)
        return hidden

In [4]:
# Build the CarRacing environment, requesting it with continuous=True.
env_id = "CarRacing-v3"
env = gym.make(id=env_id, continuous=True)

In [7]:
# Draw one random action and add two leading singleton axes in front of it.
last_action = torch.tensor(
    env.action_space.sample(),
    dtype=torch.float32,
    device=device,
)[None, None]
# Zero-filled latent grid and zero-filled initial hidden state.
last_latent = torch.zeros((1, 1, 32, 32), dtype=torch.float32, device=device)
last_hidden = torch.zeros((1, 1, 100), dtype=torch.float32, device=device)

# Show the action's shape first, then its values, one per line.
print(last_action.shape, last_action, sep="\n")

torch.Size([1, 1, 3])
tensor([[[0.7011, 0.8323, 0.7808]]], device='cuda:0')


In [19]:
# Instantiate the sequence model; sizes are spelled out by keyword for readability.
seq_model = SequenceModel(
    latent_num_rows=32,
    latent_num_columns=32,
    hidden_dim=100,
    action_dim=3,
    num_layers=1,
    device=device,
)

In [20]:
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks around forward, whereas a direct .forward() call skips them.
hidden = seq_model(last_latent, last_hidden, last_action)

In [21]:
# Show the hidden state's shape followed by its values, one per line.
print(hidden.shape, hidden, sep="\n")

torch.Size([1, 1, 100])
tensor([[[-0.0148,  0.0346,  0.0425, -0.0593,  0.0131,  0.0412, -0.0003,
          -0.0260, -0.0390,  0.0393, -0.0533, -0.0210,  0.0254, -0.0484,
           0.0676, -0.0813,  0.0615,  0.0104, -0.0602,  0.0287,  0.0646,
          -0.0297,  0.0726, -0.0370, -0.0611, -0.0099, -0.0132, -0.0336,
           0.0432,  0.0183, -0.0070, -0.0020, -0.0910, -0.0823,  0.0788,
          -0.0059,  0.0723, -0.0125, -0.0208, -0.0840,  0.1058, -0.0017,
           0.0103, -0.0262,  0.0736, -0.1168, -0.0258, -0.0325, -0.0951,
           0.0036, -0.0589,  0.0432, -0.0712, -0.0021,  0.0692,  0.0194,
          -0.0708,  0.0711,  0.0533, -0.0279, -0.0116, -0.0601,  0.0055,
           0.0932, -0.0118, -0.0057, -0.0492,  0.0030,  0.0181, -0.0617,
          -0.0276,  0.0537, -0.0631,  0.0726,  0.0374,  0.0003, -0.0430,
           0.0530, -0.0532, -0.0038,  0.0385, -0.0308, -0.0023, -0.0657,
           0.0002,  0.0027,  0.0427, -0.0411, -0.1083,  0.0702, -0.0729,
          -0.0623,  0.0091,