In [18]:
import torch
import torch.nn as nn
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

In [19]:
if torch.cuda.is_available():
    device = 'cuda'
    print("GPU selected")
else:
    device = 'cpu'
    print("CPU selected for debugging")

GPU selected


In [57]:
class Encoder(nn.Module): 
    def __init__(self, hidden_state_dim, latent_num_rows, latent_num_columns, num_filters_1, num_filters_2, hidden_layer_nodes, device='cpu'):
        """
        Takes obseravtion (image in this class) and maps it to a latent state representation through a CNN.
        """ 
        super().__init__()
        self.latent_size = latent_num_rows * latent_num_columns
        self.latent_num_rows = latent_num_rows
        self.latent_num_columns = latent_num_columns
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_filters_1, kernel_size=3, stride=1, padding=1, device=device),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=num_filters_1, out_channels=num_filters_2, kernel_size=3, stride=1, padding=1, device=device),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AdaptiveAvgPool2d((2, 2))
        ) 
        flattened_feature_size = num_filters_2 * 2 * 2
        total_in_features = flattened_feature_size + hidden_state_dim
        self.flatten = nn.Flatten(start_dim=2)
        self.latent_mapper = nn.Sequential(
            nn.Linear(in_features=total_in_features, out_features=hidden_layer_nodes, device=device),
            nn.SiLU(),
            nn.Linear(in_features=hidden_layer_nodes, out_features=self.latent_size, device=device)
        )

    def forward(self, hidden, observation):
        B, S, C, H, W = observation.shape
        observation = observation.view(B * S, C, H, W)
        features = self.feature_extractor(observation)
        _, out_C, out_H, out_W = features.shape
        features = features.view(B, S, out_C, out_H, out_W)
        features = self.flatten(features)
        print(features.shape)
        print(hidden.shape)
        input = torch.cat((features, hidden), dim=-1)
        logits = self.latent_mapper(input)
        return logits
    
    def encode(self, hidden_state, observation):
        B, S, _ = hidden_state.shape
        logits = self.forward(hidden_state, observation)
        logits_reshaped = logits.view(B, S, self.latent_num_rows, self.latent_num_columns)
        dist = torch.distributions.Categorical(logits=logits)
        sampled_idx = dist.sample()
        latent_state_flat = torch.nn.functional.one_hot(sampled_idx, num_classes=self.latent_size)
        latent_state = latent_state_flat.view(B, S, self.latent_num_rows, self.latent_num_columns)
        return latent_state, logits_reshaped

In [21]:
env_id = "CarRacing-v3"
env = gym.make(env_id, continuous=True)

In [52]:
action = env.action_space.sample()
print(action.shape)
observation, _ = env.reset(seed=42)
observation_, reward, term, trun, _ = env.step(action)
observation = observation.transpose(2,0,1)
observation_ = observation_.transpose(2,0,1)
print(observation.shape)
print(observation_.shape)
print(action)
print(reward)
print(term)

observation = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
observation_ = torch.tensor(observation_, dtype=torch.float32, device=device).unsqueeze(0)

observations = torch.cat([observation, observation_], dim=0).unsqueeze(0)
observations = torch.cat([observations, observations], dim=0)
print(observations.shape)

(3,)
(3, 96, 96)
(3, 96, 96)
[0.7508009  0.20699303 0.16375159]
6.967137809187279
False
torch.Size([2, 2, 3, 96, 96])


In [58]:
encoder = Encoder(100, 32, 32, 32, 16, 128, device=device)

In [59]:
hiddens = torch.zeros(2, 2, 100, dtype=torch.float32, device=device)
latent, logits = encoder.encode(hiddens, observations)
print(latent)
print(latent.shape) 
print(logits.shape)

torch.Size([2, 2, 64])
torch.Size([2, 2, 100])
tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]]], device='cuda:0')
t

In [None]:
def unroll_model(
            self,
            observation_sequence_batch: torch.tensor, 
            action_sequence_batch: torch.tensor, 
            reward_sequence_batch: torch.tensor, 
            continue_sequence_batch: torch.tensor
        ):
        def single_sequence_unroll(observation_sequence, action_sequence, reward_sequence, continue_sequence):
            posterior_logits = []
            prior_logits = []
            obs_likelyhood_seq = []
            rew_likelyhood_seq = [] 
            cont_likelyhood_seq = []

            observation_sequence = observation_sequence.unsqueeze(0)
            action_sequence = action_sequence.unsqueeze(0)
            reward_sequence = reward_sequence.unsqueeze(0)
            continue_sequence = continue_sequence.unsqueeze(0)

            hidden_state = torch.zeros(1, 1, self.hidden_dims, device=self.device)
            posterior_latent = torch.zeros(1, 1, self.latent_num_rows, self.latent_num_columns, device=self.device)

            for t in range(self.horizon):
                prev_action = action_sequence[:, t-1:t] if t > 0 else torch.zeros(1, 1, self.action_dims, device=self.device)
                posterior_latent, hidden_state, posterior_logits_t = self.observe_step(
                    posterior_latent,
                    hidden_state,
                    prev_action,
                    observation_sequence[:, t:t+1] 
                )
                reward_th = to_twohot(reward_sequence[:, t:t+1], self.reward_predictor.buckets_rew)

                prior_latent_logits = self.dynamics_predictor(hidden_state)
                dec_mu, dec_sig = self.decoder(hidden_state, posterior_latent)
                rew_logits = self.reward_predictor(hidden_state, posterior_latent) 
                _, cont_logits = self.continue_predictor(hidden_state, posterior_latent)

                observation_log_likelyhood = gaussian_log_probability(observation_sequence[:, t:t+1], dec_mu, dec_sig)
                continue_log_likelyhood = torch.nn.functional.binary_cross_entropy_with_logits(
                    cont_logits,
                    continue_sequence[:, t:t+1],
                    reduction='none'
                )
                reward_log_probs = torch.nn.functional.log_softmax(rew_logits, dim=-1)
                reward_log_likelyhood = reward_th * reward_log_probs

                posterior_logits.append(posterior_logits_t)
                prior_logits.append(prior_latent_logits)
                obs_likelyhood_seq.append(observation_log_likelyhood)
                rew_likelyhood_seq.append(reward_log_likelyhood)
                cont_likelyhood_seq.append(continue_log_likelyhood)
            
            prior_logits = torch.stack(prior_logits, dim=1)
            posterior_logits = torch.stack(posterior_logits, dim=1)
            obs_likelyhood_seq = torch.stack(obs_likelyhood_seq, dim=1)
            rew_likelyhood_seq = torch.stack(rew_likelyhood_seq, dim=1)
            cont_likelyhood_seq = torch.stack(cont_likelyhood_seq, dim=1)
            
            return (
                prior_logits.squeeze(0), 
                posterior_logits.squeeze(0), 
                obs_likelyhood_seq.squeeze(0), 
                rew_likelyhood_seq.squeeze(0), 
                cont_likelyhood_seq.squeeze(0)
            )

        batched_sequence_unroll = torch.vmap(single_sequence_unroll, in_dims=(0,0,0,0), randomness='different')
        prior_logits, posterior_logits, obs_log_lh, rew_log_lh, cont_log_lh = batched_sequence_unroll(
            observation_sequence_batch,
            action_sequence_batch,
            reward_sequence_batch,
            continue_sequence_batch
        )
        return (
            prior_logits, 
            posterior_logits, 
            obs_log_lh, 
            rew_log_lh, 
            cont_log_lh
        )

In [None]:
def unroll_model(
        self,
        observation_sequence_batch: torch.tensor, 
        action_sequence_batch: torch.tensor, 
        reward_sequence_batch: torch.tensor, 
        continue_sequence_batch: torch.tensor
    ):
    prior_logits = []
    posterior_logits = []
    obs_log_lh = []
    rew_log_lh = []
    cont_log_lh = []
    B = continue_sequence_batch.shape[0]
    hidden_state_batch = torch.zeros(B, 1, self.hidden_dims, dtype=torch.float32, device=self.device)
    posterior_latent_batch = torch.zeros(B, 1, self.latent_num_rows, self.latent_num_columns, dtype=torch.float32, device=self.device)
    for t in range(self.horizon):
        prev_action_batch = action_sequence_batch[:, t-1:t] if t > 0 else torch.zeros(B, 1, self.action_dims, device=self.device)
        posterior_latent_batch, hidden_state_batch, posterior_logits_batch = self.observe_step(
            posterior_latent_batch,
            hidden_state_batch,
            prev_action_batch,                    
            observation_sequence_batch[:, t:t+1] 
        )
        reward_th_batch = to_twohot(reward_sequence_batch[:, t:t+1], self.reward_predictor.buckets_rew)

        prior_latent_logits_batch = self.dynamics_predictor(hidden_state_batch)
        dec_mu_batch, dec_sig_batch = self.decoder(hidden_state_batch, posterior_latent_batch)
        reward_logits_batch = self.reward_predictor(hidden_state_batch, posterior_latent_batch) 
        _, cont_logits_batch = self.continue_predictor(hidden_state_batch, posterior_latent_batch)

        observation_log_likelyhood_batch = gaussian_log_probability(observation_sequence_batch[:, t:t+1], dec_mu_batch, dec_sig_batch)
        continue_log_likelyhood_batch = torch.nn.functional.binary_cross_entropy_with_logits(
            cont_logits_batch,
            continue_sequence_batch[:, t:t+1],
            reduction='none'
        )
        reward_log_probs_batch = torch.nn.functional.log_softmax(reward_logits_batch, dim=-1)
        reward_log_likelyhood_batch = torch.sum(reward_th_batch * reward_log_probs_batch, dim=-1, keepdim=True)

        posterior_logits.append(posterior_logits_batch)
        prior_logits.append(prior_latent_logits_batch)
        obs_log_lh.append(observation_log_likelyhood_batch)
        rew_log_lh.append(reward_log_likelyhood_batch)
        cont_log_lh.append(continue_log_likelyhood_batch)

    posterior_logits = torch.stack(posterior_logits, dim=1).squeeze(dim=2)
    prior_logits = torch.stack(prior_logits, dim=1).squeeze(dim=2)
    obs_log_lh = torch.stack(obs_log_lh, dim=1).squeeze(dim=2)
    rew_log_lh = torch.stack(rew_log_lh, dim=1).squeeze(dim=2)
    cont_log_lh = torch.stack(cont_log_lh, dim=1).squeeze(dim=2)
    return (
            prior_logits, 
            posterior_logits, 
            obs_log_lh, 
            rew_log_lh, 
            cont_log_lh
        )