In [None]:
import os
import ray
import math
import time
import random
import shutil

import numpy as np
import pandas as pd
import gymnasium as gym
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from gymnasium.spaces import Box, Discrete
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

from ray import tune, air
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.registry import register_env
from ray.tune.schedulers import ASHAScheduler

import ray.rllib.algorithms.ppo as dqn
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.annotations import override
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

# Current working directory of the kernel at the moment this cell runs.
path = os.getcwd()

In [None]:
# Obtain `torch` and `torch.nn` through RLlib's import helper.
torch, nn = try_import_torch()
# ignore_reinit_error=True keeps this cell re-runnable in a live kernel: without it,
# a second ray.init() in the same process raises because Ray is already running.
ray.init(ignore_reinit_error=True)

In [None]:
class NLL_EpistemicNNModel(TorchModelV2, nn.Module):
    """RLlib Torch model: Gaussian-mixture critic plus an epistemic (ENN) value term.

    Networks built in ``__init__``:

    * ``actor_network`` -- RLlib fully connected net with ``2 * action_space.shape[0]``
      outputs; its output is returned unchanged from ``forward``.
    * ``base_in`` / ``base_1`` / ``base_out`` -- critic trunk. ``base_out`` emits
      ``3 * num_gaussians`` numbers which are split into mixture means, sigmas
      (``ELU(x) + adder``) and mixture-weight logits.
    * ``enn_learnable_*`` and ``prior_*`` -- two MLPs that receive the *detached*
      ``base_1`` features concatenated with a sampled index ``z``. Each output is
      projected back onto ``z`` and the sum of both is averaged into one scalar
      per sample.

    ``value_function`` returns ``sum(means * softmax(weights)) + enn_term``. The mixture
    part is detached in ``forward``; the trunk is trained through the negative
    log-likelihood term that ``custom_loss`` adds to the policy loss.

    The prior network runs with gradients enabled while ``step_number`` (incremented on
    every ``forward`` call) is below ``step_cut_off`` and under ``torch.no_grad()``
    afterwards.

    Keys read from ``model_config['custom_model_config']`` (defaults in parentheses):
    ``z_dim`` (5), ``num_gaussians`` (3), ``step_cut_off`` (200), ``detect_anomaly`` (True).
    ``model_config['fcnet_hiddens']`` must contain at least two entries.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Initialise each base exactly once (the previous code ran TorchModelV2.__init__
        # twice, via super() and again explicitly).
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        custom_cfg = model_config['custom_model_config']

        # Global autograd debugging switch; it slows training, so it can be disabled
        # with custom_model_config['detect_anomaly'] = False. Default keeps it on.
        torch.autograd.set_detect_anomaly(custom_cfg.get('detect_anomaly', True))

        # get layers from config
        hidden_layer0 = model_config['fcnet_hiddens'][0]
        hidden_layer1 = model_config['fcnet_hiddens'][1]
        enn_layer = 50

        # object instance variables
        self.std = 1.0
        self.mean = 0.0
        self.gamma = 0.99
        self.step_number = 0
        self.z_indices = None
        # Number of forward calls during which the prior network still receives gradients.
        self.step_cut_off = custom_cfg.get('step_cut_off', 200)
        # Offset added to ELU(x) to form sigma. ELU's lower bound is -1, so this aims to
        # keep sigma positive; the likelihood clamps sigma again for safety.
        self.adder = 1.000000001
        self.elu = torch.nn.ELU()
        self.num_actions = action_space.shape[0]
        self.initializer = torch.nn.init.xavier_normal_
        self.activation_fn = model_config['fcnet_activation']
        self.z_dim = custom_cfg.get('z_dim', 5)
        # Kept for backward compatibility. Tensors are now placed on the device of the
        # incoming observations instead, so a CPU-resident model also works on a machine
        # where CUDA happens to be available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_gaussians = custom_cfg.get('num_gaussians', 3)
        # Index distribution: z_dim independent normals with the mean/std set above.
        self.distribution = Normal(torch.full((self.z_dim,), self.mean), torch.full((self.z_dim,), self.std))

        self.actor_network = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + "_actor")
        # Every SlimFC below, output layers included, is given the configured activation.
        self.base_in = SlimFC(obs_space.shape[0], hidden_layer0, initializer=self.initializer, activation_fn=self.activation_fn)
        self.base_1 = SlimFC(hidden_layer0, hidden_layer1, initializer=self.initializer, activation_fn=self.activation_fn)
        self.base_out = SlimFC(hidden_layer1, self.num_gaussians*3, initializer=self.initializer, activation_fn=self.activation_fn)
        # enn learnable network (input = trunk features + one z component)
        self.enn_learnable_in = SlimFC(hidden_layer1 + 1, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_1 = SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_out = SlimFC(enn_layer, 1, initializer=self.initializer, activation_fn=self.activation_fn)
        # prior network (learnable for step_cut_off forward calls, static afterwards)
        self.prior_in = SlimFC(hidden_layer1 + 1, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_1 = SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_out = SlimFC(enn_layer, 1, initializer=self.initializer, activation_fn=self.activation_fn)

    def _prior_net(self, enn_input):
        """Run the three prior layers on ``enn_input``."""
        return self.prior_out(self.prior_1(self.prior_in(enn_input)))

    def _epistemic_value(self, features):
        """Return the per-sample ENN term for trunk ``features``.

        Uses the index sample stored in ``self.z_unsqueeze`` by the latest ``forward``
        call, so ``features`` must have the same batch size as that call. ``features``
        is detached, so no gradient flows from here into the critic trunk.
        """
        z = self.z_unsqueeze  # (batch, z_dim, 1)
        tiled = features.detach().unsqueeze(1).expand(-1, self.z_dim, -1)
        enn_input = torch.cat((z, tiled), dim=2)

        learnable = self.enn_learnable_out(self.enn_learnable_1(self.enn_learnable_in(enn_input)))
        if self.step_number < self.step_cut_off:
            # prior value (learnable)
            prior = self._prior_net(enn_input)
        else:
            # prior value (static)
            with torch.no_grad():
                prior = self._prior_net(enn_input)

        # Project each network's (batch, z_dim, 1) output onto z -> (batch, 1, 1) -> (batch, 1).
        learnable_out = torch.bmm(torch.transpose(learnable, 1, 2), z).squeeze(-1)
        prior_out = torch.bmm(torch.transpose(prior, 1, 2), z).squeeze(-1)
        # Mean over the trailing singleton dimension yields a (batch,) tensor.
        return torch.mean(learnable_out + prior_out, dim=-1)

    def _split_gmm_params(self, base_out):
        """Split ``base_out`` into (means, sigmas, weight logits) along the last axis."""
        k = self.num_gaussians
        means = base_out[:, :k]
        sigmas = self.elu(base_out[:, k:k*2]) + self.adder
        alpha_logits = base_out[:, k*2:]
        return means, sigmas, alpha_logits

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute action logits and cache everything ``value_function`` needs.

        Side effects: draws a fresh index sample (``z_indices`` / ``z_unsqueeze``), stores
        ``enn_out``, ``_u``, ``_sigmas`` and ``_alphas``, and increments ``step_number``.

        Returns:
            ``(action_logits, state)`` with ``state`` passed through unchanged.
        """
        obs = input_dict["obs_flat"].float()
        action_logits, _ = self.actor_network(input_dict, state, seq_lens)
        batch_size = obs.shape[0]

        features = self.base_1(self.base_in(obs))
        # Detached: the value head does not train the trunk; the NLL loss does.
        base_out_detached = self.base_out(features).detach()

        # The distribution's parameters are CPU tensors, so move the sample to the input's device.
        self.z_indices = self.distribution.sample((batch_size,)).to(obs.device)
        self.z_unsqueeze = torch.unsqueeze(self.z_indices, -1)

        # ENN term that value_function adds to the mixture mean.
        self.enn_out = self._epistemic_value(features)

        means, sigmas, alpha_logits = self._split_gmm_params(base_out_detached)
        self._u = means
        self._sigmas = sigmas
        self._alphas = torch.nn.functional.softmax(alpha_logits, dim=-1)

        self.step_number += 1

        return action_logits, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the mixture mean plus the ENN term cached by the latest ``forward``."""
        multiply = self._u * self._alphas
        self.critic_value = torch.sum(multiply, dim = 1)
        self.final_critic_value = self.critic_value + self.enn_out
        return self.final_critic_value

    def predict_gmm_params(self, observation):
        """Run the critic trunk on ``observation`` with gradients enabled.

        Returns:
            ``(means, sigmas, alphas)`` where ``alphas`` are the *raw* weight outputs
            (no softmax applied).
        """
        features = self.base_1(self.base_in(observation.float()))
        return self._split_gmm_params(self.base_out(features))

    def compute_log_likelihood(self, td_targets, mu_pred, sigma_pred, alphas_pred):
        """Return the per-sample *negative* log-likelihood of ``td_targets`` under the mixture.

        Args:
            td_targets: targets, one per sample; a component axis is added internally.
            mu_pred: component means.
            sigma_pred: component standard deviations (clamped to >= 1e-7 here).
            alphas_pred: raw mixture-weight logits; ``log_softmax`` is applied here.

        Returns:
            ``-logsumexp(log N(target | mu_k, sigma_k) + log w_k)`` over the last axis.
        """
        td_targets_expanded = td_targets.unsqueeze(1)

        sigma_clamped = torch.clamp(sigma_pred, min=1e-7)
        denominator = (2*torch.square(sigma_clamped)).clamp_min(1e-7)
        log_2_pi = math.log(2*math.pi)

        residual = td_targets_expanded - mu_pred

        # Gaussian log-density: -log(sigma) - 0.5*log(2*pi) - (x - mu)^2 / (2*sigma^2).
        # Only the squared residual is divided by 2*sigma^2; the previous code divided
        # the whole expression by it.
        logp = -torch.log(sigma_clamped) - .5 * log_2_pi - torch.square(residual) / denominator
        # Weights are logits, so they go straight into log_softmax. The previous
        # clamp_min(1e-7) on the logits flattened every negative logit.
        loga = torch.nn.functional.log_softmax(alphas_pred, dim=-1)

        return -torch.logsumexp(logp + loga, dim=-1)

    def enn_loss(self, next_obs, rewards, dones):
        """Mean squared TD error between the cached value and a one-step target.

        Must run after ``forward`` and ``value_function`` for the same batch: it reads
        ``self.z_unsqueeze`` and ``self.final_critic_value``. Stores the next-state ENN
        term in ``self.enn_target``.
        """
        next_features = self.base_1(self.base_in(next_obs.float()))
        next_base_out = self.base_out(next_features).detach()

        # ENN term for the next observation, using the same z as the forward pass.
        self.enn_target = self._epistemic_value(next_features)

        means, _, alpha_logits = self._split_gmm_params(next_base_out)
        # Softmax the weights exactly as forward() does; the previous code multiplied
        # the means by raw logits here.
        alphas = torch.nn.functional.softmax(alpha_logits, dim=-1)
        next_values = torch.sum(means * alphas, dim = -1) + self.enn_target

        # TD target; no gradient flows through the bootstrap value.
        enn_target = rewards + self.gamma * next_values.detach() * (1 - dones.float())
        enn_loss = torch.mean(torch.square(self.final_critic_value - enn_target))

        return enn_loss

    @override(TorchModelV2)
    def custom_loss(self, policy_loss, sample_batch):
        """Add the NLL and ENN losses to every entry of ``policy_loss``.

        Args:
            policy_loss: iterable of loss tensors produced by the policy.
            sample_batch: batch providing CUR_OBS, NEXT_OBS, REWARDS and DONES.

        Returns:
            A list with ``nll_loss + enn_loss`` added to each policy loss.
        """
        cur_obs = sample_batch[SampleBatch.CUR_OBS]
        next_states = sample_batch[SampleBatch.NEXT_OBS]
        rewards = sample_batch[SampleBatch.REWARDS]
        dones = sample_batch[SampleBatch.DONES]

        enn_loss = self.enn_loss(next_states, rewards, dones)

        mu_pred, sigma_pred, w_pred = self.predict_gmm_params(cur_obs)
        mu_target, _, w_target = self.predict_gmm_params(next_states)
        # only the target alphas go through softmax since predicted alphas go through log softmax later
        w_target = torch.nn.functional.softmax(w_target, dim = -1)

        next_state_value = torch.sum(mu_target * w_target, dim = 1)
        td_targets = rewards + self.gamma * next_state_value.detach() * (1 - dones.float())

        negative_log_likelihood = self.compute_log_likelihood(td_targets, mu_pred, sigma_pred, w_pred)
        # Bound each sample's contribution before averaging.
        negative_log_likelihood = torch.clamp(negative_log_likelihood, -10, 80)
        nll_loss = torch.mean(negative_log_likelihood)

        if self.step_number % 1_000 == 0:
            print(f"policy loss: {policy_loss} enn loss: {enn_loss} nll loss: {nll_loss}")

        total_loss = [loss + (nll_loss + enn_loss) for loss in policy_loss]

        return total_loss


# Register the custom model with RLlib's ModelCatalog under the name "NLL_ENNModel";
# the PPO config below selects it through model['custom_model'].
ModelCatalog.register_custom_model("NLL_ENNModel", NLL_EpistemicNNModel)

In [None]:
class EpistemicNNModel(TorchModelV2, nn.Module):
    """RLlib Torch model: scalar base critic plus an epistemic (ENN) value term.

    Networks built in ``__init__``:

    * ``actor_network`` -- RLlib fully connected net with ``2 * action_space.shape[0]``
      outputs; its output is returned unchanged from ``forward``.
    * ``base_in`` / ``base_1`` / ``base_out`` -- critic trunk ending in one output.
    * ``enn_learnable_*`` and ``prior_*`` -- two MLPs that receive the *detached*
      ``base_1`` features concatenated with a sampled index ``z``; each output is
      projected back onto ``z``.

    ``value_function`` returns ``base_out + learnable + prior`` with one value per sample.

    NOTE(review): ``base_out`` is detached both in ``forward`` and in ``custom_loss``, so
    nothing in this class sends gradients into ``base_in`` / ``base_1`` / ``base_out``.
    Confirm that an untrained base critic is intended.

    The prior network runs with gradients enabled while ``step_number`` (incremented on
    every ``forward`` call) is below ``step_cut_off`` and under ``torch.no_grad()``
    afterwards.

    Keys read from ``model_config['custom_model_config']`` (defaults in parentheses):
    ``z_dim`` (5), ``step_cut_off`` (200), ``detect_anomaly`` (True).
    ``model_config['fcnet_hiddens']`` must contain at least two entries.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Initialise each base exactly once (the previous code ran TorchModelV2.__init__
        # twice, via super() and again explicitly).
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        custom_cfg = model_config['custom_model_config']

        # Global autograd debugging switch; it slows training, so it can be disabled
        # with custom_model_config['detect_anomaly'] = False. Default keeps it on.
        torch.autograd.set_detect_anomaly(custom_cfg.get('detect_anomaly', True))

        # get layers from config
        hidden_layer0 = model_config['fcnet_hiddens'][0]
        hidden_layer1 = model_config['fcnet_hiddens'][1]
        enn_layer = 50

        # object instance variables
        self.std = 1.0
        self.mean = 0.0
        self.gamma = 0.99
        self.step_number = 0
        self.z_indices = None
        # Number of forward calls during which the prior network still receives gradients
        # (previously the literal 200 repeated in forward and custom_loss).
        self.step_cut_off = custom_cfg.get('step_cut_off', 200)
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.action_space_size = 4
        self.num_actions = action_space.shape[0]
        self.initializer = torch.nn.init.xavier_normal_
        # Not read anywhere in this class; kept for backward compatibility.
        self.initializer_old = torch.nn.init.kaiming_normal_
        self.activation_fn = model_config['fcnet_activation']
        self.z_dim = custom_cfg.get('z_dim', 5)
        # Kept for backward compatibility. Tensors are now placed on the device of the
        # incoming observations instead, so a CPU-resident model also works on a machine
        # where CUDA happens to be available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Index distribution: z_dim independent normals with the mean/std set above.
        self.distribution = Normal(torch.full((self.z_dim,), self.mean), torch.full((self.z_dim,), self.std))

        # build main actor/critic networks
        self.actor_network = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + "_q_testing")
        # Every SlimFC below, output layers included, is given the configured activation.
        self.base_in = SlimFC(obs_space.shape[0], hidden_layer0, initializer=self.initializer, activation_fn=self.activation_fn)
        self.base_1 = SlimFC(hidden_layer0, hidden_layer1, initializer=self.initializer, activation_fn=self.activation_fn)
        self.base_out = SlimFC(hidden_layer1, 1, initializer=self.initializer, activation_fn=self.activation_fn)
        # enn learnable network (input = trunk features + one z component)
        self.enn_learnable_in = SlimFC(hidden_layer1 + 1, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_1 = SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_out = SlimFC(enn_layer, 1, initializer=self.initializer, activation_fn=self.activation_fn)
        # prior network (learnable for step_cut_off forward calls, static afterwards)
        self.prior_in = SlimFC(hidden_layer1 + 1, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_1 = SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_out = SlimFC(enn_layer, 1, initializer=self.initializer, activation_fn=self.activation_fn)

    def _prior_net(self, enn_input):
        """Run the three prior layers on ``enn_input``."""
        return self.prior_out(self.prior_1(self.prior_in(enn_input)))

    def _epinet_terms(self, features):
        """Return ``(learnable_out, prior_out)`` for trunk ``features``, each ``(batch, 1)``.

        Uses the index sample stored in ``self.z_unsqueeze`` by the latest ``forward``
        call, so ``features`` must have the same batch size as that call. ``features``
        is detached, so no gradient flows from here into the critic trunk.
        """
        z = self.z_unsqueeze  # (batch, z_dim, 1)
        tiled = features.detach().unsqueeze(1).expand(-1, self.z_dim, -1)
        enn_input = torch.cat((z, tiled), dim=2)

        if self.step_number < self.step_cut_off:
            # prior value (learnable)
            prior = self._prior_net(enn_input)
        else:
            # prior value (static)
            with torch.no_grad():
                prior = self._prior_net(enn_input)
        learnable = self.enn_learnable_out(self.enn_learnable_1(self.enn_learnable_in(enn_input)))

        # Project each network's (batch, z_dim, 1) output onto z -> (batch, 1, 1) -> (batch, 1).
        prior_out = torch.bmm(torch.transpose(prior, 1, 2), z).squeeze(-1)
        learnable_out = torch.bmm(torch.transpose(learnable, 1, 2), z).squeeze(-1)
        return learnable_out, prior_out

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute action logits and cache the critic value for ``value_function``.

        Side effects: draws a fresh index sample (``z_indices`` / ``z_unsqueeze``), stores
        ``critic_value`` with shape ``(batch, 1)`` and increments ``step_number``.

        Returns:
            ``(action_logits, state)`` with ``state`` passed through unchanged.
        """
        obs = input_dict["obs_flat"].float()
        action_logits, _ = self.actor_network(input_dict, state, seq_lens)
        batch_size = obs.shape[0]

        features = self.base_1(self.base_in(obs))
        base_out = self.base_out(features).detach()

        # The distribution's parameters are CPU tensors, so move the sample to the input's device.
        self.z_indices = self.distribution.sample((batch_size,)).to(obs.device)
        self.z_unsqueeze = torch.unsqueeze(self.z_indices, -1)

        learnable_out, prior_output = self._epinet_terms(features)

        # add all networks together to get the final state value
        self.critic_value = base_out + learnable_out + prior_output

        self.step_number += 1

        return action_logits, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the value cached by the latest ``forward`` call, shape ``(batch,)``."""
        return self.critic_value.squeeze(-1)

    @override(TorchModelV2)
    def custom_loss(self, policy_loss, sample_batch):
        """Add the ENN TD loss to every entry of ``policy_loss``.

        Must run after ``forward`` for the same batch: it reads ``self.z_unsqueeze`` and
        ``self.critic_value``.

        Args:
            policy_loss: iterable of loss tensors produced by the policy.
            sample_batch: batch providing NEXT_OBS, REWARDS and DONES.

        Returns:
            A list with the ENN loss added to each policy loss.
        """
        next_obs = sample_batch[SampleBatch.NEXT_OBS].float()
        rewards = sample_batch[SampleBatch.REWARDS]
        dones = sample_batch[SampleBatch.DONES]

        # base target value
        next_features = self.base_1(self.base_in(next_obs))
        next_base_out = self.base_out(next_features).detach()

        # enn + prior target values, using the same z as the forward pass
        next_learnable_out, next_prior_out = self._epinet_terms(next_features)

        # Drop the trailing singleton axis before mixing with rewards/dones. Those are
        # assumed to be 1-D (batch,) tensors; left as (batch, 1), the arithmetic below
        # would broadcast to (batch, batch) and average over all cross-sample pairs.
        next_values = (next_base_out + next_learnable_out + next_prior_out).squeeze(-1)
        current_values = self.critic_value.squeeze(-1)

        enn_target = rewards + self.gamma * next_values.detach() * (1 - dones.float())
        enn_loss = torch.mean(torch.square(current_values - enn_target))

        total_loss = [loss + enn_loss for loss in policy_loss]

        if self.step_number % 1_000 == 0:
            print(f"policy loss: {policy_loss} enn loss: {enn_loss}")

        return total_loss


# Register the custom model with RLlib's ModelCatalog under the name "ENNModel" so it
# can be selected through model['custom_model'] in an algorithm config.
ModelCatalog.register_custom_model("ENNModel", EpistemicNNModel)

In [None]:
# Configure PPO on HalfCheetah-v4 with the registered custom model, train for a fixed
# number of iterations while recording the mean episode reward, then shut Ray down.
config = (
    PPOConfig()
    .training(
        gamma=0.99,
        lambda_=0.95,
        num_sgd_iter=30,
        # Piecewise learning-rate schedule given as [timestep, lr] pairs.
        lr_schedule=[
            [0, 0.0003],
            [15_000_000, 0.00025],
            [25_000_000, 0.0002],
            [30_000_000, 0.0001],
        ],
        vf_loss_coeff=1.0,
        vf_clip_param=15.0,
        clip_param=0.3,
        grad_clip_by='norm',
        train_batch_size=65_500,
        sgd_minibatch_size=12_500,
        grad_clip=1.0,
        model={
            # Either registered name works here: 'ENNModel' or 'NLL_ENNModel'.
            'custom_model': 'NLL_ENNModel',
            'fcnet_hiddens': [512, 512],
            'fcnet_activation': 'LeakyReLU',
            'vf_share_layers': False,
            'custom_model_config': {
                'z_dim': 5,
                'num_gaussians': 3,
            },
        },
    )
    .environment(env='HalfCheetah-v4')
    .rollouts(num_rollout_workers=28)
    .resources(num_gpus=1)
)
algo = config.build()

num_iterations = 200
rewards = []

for i in range(num_iterations):
    result = algo.train()
    rewards.append(result['episode_reward_mean'])
    print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")

# NOTE(review): Ray is shut down here, yet a later cell still calls algo.get_weights();
# confirm that call does not need the Ray runtime.
ray.shutdown()

In [None]:
# Fetch the weights held by `algo` (built and trained in the training cell); the cells
# below index the result by policy id ('default_policy') and then by weight name.
policy_weights = algo.get_weights()

In [None]:
# Names of the 'default_policy' weight entries whose key contains the substring 'prior'.
prior_weight_keys = list(filter(lambda name: 'prior' in name, policy_weights['default_policy']))

In [None]:
# Map each selected prior-network key to its weight value from the default policy.
prior_weights = dict((name, policy_weights['default_policy'][name]) for name in prior_weight_keys)

In [None]:
# Display the entry at index 1 of the weight stored under 'prior_1._model.0.weight'.
prior_weights['prior_1._model.0.weight'][1]