In [None]:
import ray
import time
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune, air
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from gymnasium.spaces import Box
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
import gymnasium as gym
import matplotlib.pyplot as plt
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
import numpy as np
import pandas as pd
from ray import tune
import math
from torch.distributions.normal import Normal
from ray.tune.schedulers import ASHAScheduler
from torch.utils.tensorboard import SummaryWriter
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import os 
import shutil
# Absolute path of the notebook's working directory; outputs are written here.
path = os.path.abspath(os.curdir)

In [None]:
# RLlib helper that provides the torch module and torch.nn used by the model below.
torch, nn = try_import_torch()
# ignore_reinit_error lets this cell be re-run in the same kernel without Ray
# raising because it is already initialised.
ray.init(ignore_reinit_error=True)

In [None]:
class EpistemicNNModel(TorchModelV2, nn.Module):
    """PPO actor-critic whose critic is augmented with an epistemic network (epinet).

    Sub-networks (all built in ``__init__``):
      * actor -- an RLlib ``TorchFC`` with ``2 * action_space.shape[0]`` outputs
        (presumably mean / log-std of a Gaussian over a Box action space --
        confirm against the action distribution in use).
      * base critic -- ``critic_in -> critic_1 -> critic_out`` with a linear output.
      * learnable epinet -- ``enn_learnable_1 -> enn_learnable_out`` applied to
        ``concat(stop_grad(critic features), z)``; its ``z_dim`` outputs are dotted
        with ``z`` to give one scalar per sample.
      * prior -- same shape as the epinet; after the warm-up it is evaluated
        without gradients.

    ``z`` is drawn once per sample from N(mean, std^2 I) in R^z_dim. For the first
    ``prior_warmup_steps`` forward calls the value is the base critic alone;
    afterwards it is ``base + epinet + prior``.

    Optional ``custom_model_config`` keys (defaults in brackets):
    ``z_dim`` [10], ``enn_layer`` [50], ``prior_warmup_steps`` [50], ``gamma`` [0.99].

    ``custom_loss`` adds the quadratic TD loss from "Approximate Thompson Sampling
    via Epistemic Neural Networks" (arXiv:2302.09205) plus a TD loss for the base
    critic.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Standard RLlib pattern for torch models: initialise both bases explicitly.
        # (A preceding super().__init__ call only re-ran TorchModelV2.__init__.)
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        custom_config = model_config.get("custom_model_config") or {}
        hiddens = model_config['fcnet_hiddens']

        # Index distribution: z ~ N(mean, std^2 I), z in R^z_dim.
        self.mean = 0.0
        self.std = 1.0
        self.z_dim = custom_config.get("z_dim", 10)
        self.critic_output_dims = hiddens[-1]
        self.actor_output_dims = hiddens[-1]
        self.initializer = torch.nn.init.xavier_normal_
        self.activation_fn = model_config['fcnet_activation']
        # Hidden width of both the learnable epinet and the prior.
        self.enn_layer = custom_config.get("enn_layer", 50)
        self.action_outputs = action_space.shape[0] * 2
        # Number of forward() calls on this model instance (each worker has its
        # own counter; RLlib's dummy-batch initialisation counts too) before the
        # epinet + prior are added to the value and the prior stops receiving
        # gradients.
        self.prior_warmup_steps = custom_config.get("prior_warmup_steps", 50)
        self.step_number = 0
        self.z_indices = None
        self.critic_out_value = None
        self.obs_space = obs_space
        self.action_space = action_space
        self.gamma = custom_config.get("gamma", 0.99)
        # Kept for backward compatibility only: z is placed on the device of the
        # incoming observations, which is where the model actually runs.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Built once and reused for every draw of z.
        self.distribution = Normal(torch.full((self.z_dim,), self.mean), torch.full((self.z_dim,), self.std))

        self.actor = TorchFC(obs_space, action_space, self.action_outputs, model_config, name + "_actor")

        # Base critic; kept as separate layers because custom_loss has no input_dict.
        self.critic_in = SlimFC(obs_space.shape[0], hiddens[0], initializer=self.initializer, activation_fn=self.activation_fn)
        self.critic_1 = SlimFC(hiddens[0], self.critic_output_dims, initializer=self.initializer, activation_fn=self.activation_fn)
        self.critic_out = SlimFC(self.critic_output_dims, 1, initializer=self.initializer, activation_fn=None)

        enn_in = self.critic_output_dims + self.z_dim
        self.enn_learnable_1 = SlimFC(enn_in, self.enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_out = SlimFC(self.enn_layer, self.z_dim, initializer=self.initializer, activation_fn=None)

        self.prior_1 = SlimFC(enn_in, self.enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_out = SlimFC(self.enn_layer, self.z_dim, initializer=self.initializer, activation_fn=None)

    # -- private helpers shared by forward() and custom_loss() ----------------

    def _critic_features(self, obs):
        """Base critic trunk (critic_in -> critic_1); output feeds critic_out and the epinet."""
        return self.critic_1(self.critic_in(obs))

    @staticmethod
    def _contract_with_index(head_out, z):
        """Dot each row of ``head_out`` (B, z_dim) with its index ``z`` (B, z_dim) -> (B, 1).

        Same result as ``bmm(head_out.unsqueeze(-1).transpose(1, 2), z.unsqueeze(-1)).squeeze(-1)``.
        """
        return (head_out * z).sum(dim=-1, keepdim=True)

    def _epinet_value(self, critic_features, z):
        """Learnable epinet term (B, 1); features are detached so it cannot train the base critic."""
        z_concat = torch.cat((critic_features.detach(), z), 1)
        head_out = self.enn_learnable_out(self.enn_learnable_1(z_concat))
        return self._contract_with_index(head_out, z)

    def _prior_value(self, critic_features, z, trainable):
        """Prior term (B, 1); evaluated without autograd when ``trainable`` is False."""
        z_concat = torch.cat((critic_features.detach(), z), 1)
        # Never switch autograd on inside an outer no_grad context (e.g. rollouts).
        with torch.set_grad_enabled(trainable and torch.is_grad_enabled()):
            head_out = self.prior_out(self.prior_1(z_concat))
            return self._contract_with_index(head_out, z)

    # -- RLlib ModelV2 API ----------------------------------------------------

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Return the actor's action logits and cache the value estimate.

        Samples a fresh index batch into ``self.z_indices`` (reused by
        custom_loss for the same batch). It stores the value in
        ``self.critic_out_value``: base critic during the warm-up, then
        base + epinet + prior.
        """
        action_logits, _ = self.actor(input_dict, state, seq_lens)
        obs = input_dict["obs_flat"].float()

        # One index per sample, created on the observations' device so it can be
        # concatenated with the critic features. A fixed cuda-if-available device
        # breaks when the model runs on CPU on a GPU host.
        batch_size = action_logits.shape[0]
        self.z_indices = self.distribution.sample((batch_size,)).to(obs.device)

        self.step_number += 1

        critic_features = self._critic_features(obs)
        value_out = self.critic_out(critic_features)

        if self.step_number < self.prior_warmup_steps:
            # Warm-up: the value is the base critic alone.
            self.critic_out_value = value_out
        else:
            # The prior is frozen from here on, preserving the uncertainty it
            # captured during the warm-up.
            self.critic_out_value = (
                value_out
                + self._epinet_value(critic_features, self.z_indices)
                + self._prior_value(critic_features, self.z_indices, trainable=False)
            )

        return action_logits, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the value cached by the last forward() call, shape (B,).

        The value is the base critic during the warm-up, then base + epinet + prior.
        """
        assert self.critic_out_value is not None, "must call forward() first"
        return self.critic_out_value.squeeze(-1)

    @override(TorchModelV2)
    def custom_loss(self, policy_loss, sample_batch):
        """Add quadratic TD losses for the epinet(+prior) and for the base critic.

        Args:
            policy_loss: Loss(es) computed by the policy -- a list/tuple (one entry
                per optimizer) or a single tensor.
            sample_batch: The train batch; reads CUR_OBS, NEXT_OBS, REWARDS and DONES.

        Returns:
            ``td_loss + td_base_loss`` added to every policy loss. The result is a
            list when ``policy_loss`` is a list/tuple; otherwise a single tensor.
        """
        cur_obs = sample_batch[SampleBatch.CUR_OBS].float()
        next_obs = sample_batch[SampleBatch.NEXT_OBS].float()
        # Keep these as flat (B,) vectors. The value tensors below are squeezed
        # to (B,) too; mixing (B,) with (B, 1) would silently broadcast every
        # TD term into a (B, B) matrix.
        rewards = sample_batch[SampleBatch.REWARDS].float().reshape(-1)
        # NOTE(review): bootstrapping should only be cut on true termination, not on
        # time-limit truncation -- confirm what DONES holds in this RLlib version.
        not_done = 1.0 - sample_batch[SampleBatch.DONES].float().reshape(-1)

        # Reuse the indices forward() sampled for this batch; resample if they are
        # missing or belong to a differently sized batch.
        z = self.z_indices
        if z is None or z.shape[0] != cur_obs.shape[0]:
            z = self.distribution.sample((cur_obs.shape[0],)).to(cur_obs.device)

        # The prior only learns during the warm-up, mirroring forward().
        prior_trainable = self.step_number < self.prior_warmup_steps

        features = self._critic_features(cur_obs)
        value_out = self.critic_out(features).squeeze(-1)
        enn_out = (
            self._epinet_value(features, z)
            + self._prior_value(features, z, trainable=prior_trainable)
        ).squeeze(-1)

        # TD targets are treated as constants (semi-gradient TD), so the loss only
        # moves the current-state predictions.
        with torch.no_grad():
            next_features = self._critic_features(next_obs)
            next_value_out = self.critic_out(next_features).squeeze(-1)
            next_enn_out = (
                self._epinet_value(next_features, z)
                + self._prior_value(next_features, z, trainable=False)
            ).squeeze(-1)
            td_target = rewards + self.gamma * next_enn_out * not_done
            td_target_base = rewards + self.gamma * next_value_out * not_done

        td_loss = torch.mean(torch.square(enn_out - td_target))
        td_base_loss = torch.mean(torch.square(value_out - td_target_base))
        extra_loss = td_loss + td_base_loss

        if isinstance(policy_loss, (list, tuple)):
            return [loss + extra_loss for loss in policy_loss]
        return policy_loss + extra_loss

    # NOTE: L2 weight decay restricted to the "enn" parameters cannot be set from
    # the model -- the optimizer is created by the policy object. The intended setup
    # was Adam with two param groups: params whose name contains "enn" with
    # weight_decay=0.01, and all other params with weight_decay=0.


# Make EpistemicNNModel available to RLlib under the name that the PPO config's
# model={'custom_model': 'ENNModel'} entry refers to.
ENN_MODEL_NAME = "ENNModel"
ModelCatalog.register_custom_model(ENN_MODEL_NAME, EpistemicNNModel)


''' FIXES
-Changed how the prior network functions by gating it with an if statement on a forward-pass step counter (the code currently uses 50 steps, not 10)
--This ensures that the stop gradient is applied after the warm-up steps to encapsulate the "uncertainty" of pre-training
-Fixed the learnable ENN dimensions to output (Dz X C).T * Dz(vector)
-Add loss function from page 4 in "Approximate Thompson Sampling via Epistemic Neural Networks" (quadratic TD loss)
-Had to break down the critic branch to be multiple SlimFC's due to needing it in the custom loss where we do not have access to input_dict
-Updated the draw from the Pz distribution to not instantiate a new distribution every forward pass
--i.e. create a Pz in the beginning and draw from the same one throughout training
-Added the prior function to the loss function to encapsulate all of the ENN
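-Added L2 regularization (weight_decay for the optimizer step for the ENN network only) -- NOTE: currently inactive; the optimizer override is commented out in the model because the optimizer is built by the policy object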
-Fixed prior network ENN to output correct dimensions (Dz X C).T * Dz(vector) like the ENN learnable in forward and loss function (including "next")
-Critic output to have no activation (changed from self.activation_fn --> None)
-Detach the features from the epinet's gradient map
-Added base network TD loss to custom loss


OBSERVATIONS:
-The first trial seemed to work the best so far -- this had no loss function, incorrect dim outputs (only hiddens[-1] x 1)
-Training became exponentially slow while adding the loss function


THOUGHTS:
-Should I only do weight_decay for the ENN and not over all networks? answer: just the ENN learnable
-Are the dimensions logical for the step (Dz X C).T * Dz(vector)?     answer: yes -- this is correct where C is the mean value (or it could be the MoGs)
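-What slowed down computation so much while adding the loss?          answer: (maybe) -- I had the ENN networks backpropagating into the base network; also, the TD loss as first written combined (B,1) network outputs with (B,) rewards/dones, which broadcasts each TD term to a (B,B) matrix per minibatch (cost quadratic in minibatch size)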
-Check over the loss function                                         answer: checked and correct per page 4 within ARXIV: 2302.09205 TD loss for RL
-"we optimize L through SGD -- at each gradient step we sample a mini-batch of data D and a batch of indices Z from Pz and we take a gradient step wrt the quadratic TD loss"
'''

In [None]:
# Model settings handed to RLlib; 'custom_model' names the class registered
# with ModelCatalog.register_custom_model.
enn_model_settings = {
    'custom_model': 'ENNModel',
    'fcnet_hiddens': [512, 512],
    'fcnet_activation': 'LeakyReLU',
}

# [timestep, learning-rate] breakpoints for PPO's lr_schedule.
lr_breakpoints = [
    [0, 0.0003],
    [5_000_000, 0.00025],
    [15_000_000, 0.00020],
    [30_000_000, 0.00015],
]

config = (
    PPOConfig()
    .training(
        gamma=0.99,
        lambda_=0.95,
        num_sgd_iter=30,
        lr_schedule=lr_breakpoints,
        vf_loss_coeff=1.0,
        vf_clip_param=15.0,  # value-function clip param (re-introduced)
        clip_param=0.3,
        grad_clip_by='norm',
        grad_clip=1.0,
        train_batch_size=65_500,
        sgd_minibatch_size=4_096,
        model=enn_model_settings,
    )
    .environment(env='HalfCheetah-v4')
    .rollouts(num_rollout_workers=28)
)
algo = config.build()

# Train for a fixed number of iterations, recording the mean episode reward of each.
num_iterations = 400
rewards = []
for i in range(num_iterations):
    result = algo.train()
    print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
    rewards.append(result['episode_reward_mean'])

# Tear down the Ray runtime once training has finished.
ray.shutdown()


In [None]:
# Fetch the trained weights; the next cells index the result by policy ID
# ('default_policy').
# NOTE(review): this cell runs after ray.shutdown() in the training cell -- it
# relies on get_weights() not needing the Ray cluster; confirm.
policy_weights = algo.get_weights()

In [None]:
# Names of every default-policy weight whose name contains 'prior' (prior_1 / prior_out layers).
prior_weight_keys = list(filter(lambda name: 'prior' in name, policy_weights['default_policy']))

In [None]:
# Map each prior weight name to its value from the default policy.
prior_weights = dict((name, policy_weights['default_policy'][name]) for name in prior_weight_keys)

In [None]:
# Display entry 1 (along the first axis) of the first prior layer's weight tensor.
# NOTE(review): the key follows SlimFC's internal sub-module naming ('_model.0') --
# confirm it still matches if SlimFC changes.
prior_weights['prior_1._model.0.weight'][1]

In [None]:
# One row per training iteration: the mean episode reward recorded in the training loop.
results_df_dim_change = pd.DataFrame(data=rewards)

In [None]:
# Save the per-iteration rewards next to the notebook (os.path.join instead of
# hand-built '/'-separated strings keeps the path portable).
results_df_dim_change.to_csv(os.path.join(path, 'dim_change.csv'))