In [2]:
import torch
import torch.nn as nn
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from DreamerUtils import symexp, symlog

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
    print("GPU selected")
else:
    device = 'cpu'
    print("CPU selected for debugging")

GPU selected


In [9]:
class Critic(nn.Module):
    def __init__(self, latent_row_dim, latent_column_dim, hidden_state_dim, hidden_layer_num_nodes_1, hidden_layer_num_nodes_2, num_buckets=255, device='cpu'):
        super().__init__()
        self.latent_row_dim = latent_row_dim
        self.latent_column_dim = latent_column_dim
        self.num_buckets = num_buckets
        self.flatten = nn.Flatten(start_dim=2)
        self.value_net = nn.Sequential(
            nn.Linear(in_features=latent_column_dim * latent_row_dim + hidden_state_dim, out_features=hidden_layer_num_nodes_1, device=device),
            nn.SiLU(),
            nn.Linear(in_features=hidden_layer_num_nodes_1, out_features=hidden_layer_num_nodes_2, device=device),
            nn.SiLU(),
            nn.Linear(in_features=hidden_layer_num_nodes_2, out_features=num_buckets, device=device)
        )
        bucket_values = symexp(torch.linspace(-20, 20, num_buckets, device=device))
        self.register_buffer('buckets_crit', bucket_values)

    def forward(self, ht, zt):
        flattened_zt = self.flatten(zt)
        st = torch.cat([ht, flattened_zt], dim=-1)
        logits = self.value_net(st)
        return logits
    
    def value(self, ht, zt):
        logits = self.forward(ht, zt)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        value = torch.sum(probs * self.buckets_crit, dim=-1, keepdim=True)
        return value 

In [5]:
env_id = "CarRacing-v3"
env = gym.make(env_id, continuous=True)

In [10]:
critic = Critic(32, 32, 100, 128, 64, 255, device=device)

In [13]:
hiddens = torch.zeros(2, 2, 100, dtype=torch.float32, device=device)
latents = torch.zeros(2, 2, 32, 32, dtype=torch.float32, device=device)
print(latents)
print(latents.shape) 
print(hiddens.shape)

hiddens2 = torch.zeros(100, dtype=torch.float32, device=device)
latents2 = torch.zeros(32, 32, dtype=torch.float32, device=device)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
        

In [14]:
values = critic.value(hiddens, latents)
values2 = critic.value(hiddens2, latents2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [12]:
print(values.shape)
print(values)

torch.Size([2, 2, 1])
tensor([[[-74055.7500],
         [-74055.2500]],

        [[-74055.7500],
         [-74055.5000]]], device='cuda:0', grad_fn=<SumBackward1>)
