In [None]:
import os
import ray
import time
import math
import shutil
import numpy as np
import pandas as pd
from ray import tune
import gymnasium as gym
from ray import tune, air
import plotly.express as px
import matplotlib.pyplot as plt
from gymnasium.spaces import Box
import plotly.graph_objects as go
from ENNWrapper import ENNWrapper
from ray.train import ScalingConfig
import ray.rllib.algorithms.ppo as ppo
from ray.train.torch import TorchTrainer
from ray.rllib.models import ModelCatalog
from torch.distributions.normal import Normal
from ray.tune.schedulers import ASHAScheduler
from ray.rllib.algorithms.ppo import PPOConfig
from torch.utils.tensorboard import SummaryWriter
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

# Working directory the notebook was launched from.
path = os.getcwd()
# RLlib helper that returns the torch module and torch.nn.
torch, nn = try_import_torch()
# ignore_reinit_error lets this cell be re-run on a live kernel without
# raising because Ray is already initialised.
ray.init(ignore_reinit_error=True)

In [None]:
# Module-level offset constant. A `global` statement is a no-op at module
# scope, so the name is simply assigned here.
# NOTE(review): the model class below uses its own `self.adder`; this
# module-level value is not referenced in the visible cells - confirm it is
# still needed.
adder = 1.000001

class CustomTorchModelMOG(TorchModelV2, nn.Module):
    """RLlib custom model pairing an actor MLP with a mixture-of-Gaussians critic.

    The actor is a ``TorchFC`` network with ``action_space.shape[0] * 2``
    outputs, returned from ``forward`` as the raw action logits. The critic is
    a second ``TorchFC`` network with ``num_gaussians * 3`` outputs, read as
    ``[means | raw sigmas | alpha logits]``. During ``forward`` the critic is
    evaluated through an ``ENNWrapper``, which also returns an extra term
    (``enn_out``) that ``value_function`` adds to the mixture mean.

    Keys read from ``model_config``: ``fcnet_hiddens`` (its first entry is
    passed to the wrapper as ``hidden_layer``), ``fcnet_activation`` and
    ``custom_model_config`` (optional ``z_dim``, default 5, and
    ``num_gaussians``, default 3).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Initialise each base class exactly once. For this class,
        # super().__init__ resolves to TorchModelV2.__init__, so calling it in
        # addition to the explicit call would run the same initialiser twice.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # NOTE(review): this flips a process-wide autograd debugging switch;
        # PyTorch documents it as slowing execution - confirm it is still
        # wanted for long training runs.
        torch.autograd.set_detect_anomaly(True)

        # Layer sizes handed to the ENN wrapper.
        hidden_layer0 = model_config['fcnet_hiddens'][0]
        enn_layer = 50

        # Plain instance state.
        self.std = 1.0
        self.mean = 0.0
        self.gamma = 0.99            # discount used for the TD targets in custom_loss
        self.step_number = 0         # counts forward() calls; throttles loss logging
        self.z_indices = None
        self.step_cut_off = 200
        self.adder = 1.000000001     # added to elu(raw sigma) so sigma stays positive
        self.elu = torch.nn.ELU()
        self.action_space = action_space
        # NOTE(review): the device is pinned to CPU and forward() moves the
        # observations there; revisit if the policy is ever placed on a GPU.
        self.device = torch.device('cpu')
        self.num_actions = action_space.shape[0]
        self.initializer = torch.nn.init.xavier_normal_
        self.activation_fn = model_config['fcnet_activation']
        self.z_dim = model_config['custom_model_config'].get('z_dim', 5)
        self.num_gaussians = model_config['custom_model_config'].get('num_gaussians', 3)
        # Normal(mean, std) distribution with one independent entry per z dimension.
        self.distribution = Normal(torch.full((self.z_dim,), self.mean),
                                   torch.full((self.z_dim,), self.std))

        # Critic head: num_gaussians means, raw sigmas and alpha logits.
        self.critic_network = TorchFC(obs_space, action_space, self.num_gaussians*3,
                                      model_config, name + "_critic")
        # Actor head: two outputs per action dimension.
        self.actor_network = TorchFC(obs_space, action_space, action_space.shape[0]*2,
                                     model_config, name + "_actor")
        self.enn_wrapper = ENNWrapper(base_network = self.critic_network,
                                      z_dim = self.z_dim, activation_name = self.activation_fn,
                                      enn_layer = enn_layer, hidden_layer = hidden_layer0)
        self.to(self.device)
        self.enn_wrapper.to(self.device)

    def _split_gmm_output(self, output):
        """Split a critic output into (means, sigmas, alpha_logits).

        ``output`` is sliced along dim 1 into three chunks of ``num_gaussians``
        columns. Sigmas are ``elu(raw) + self.adder``; the alpha logits are
        returned unnormalised.
        """
        k = self.num_gaussians
        means = output[:, :k]
        sigmas = self.elu(output[:, k:k*2]) + self.adder
        alpha_logits = output[:, k*2:]
        return means, sigmas, alpha_logits

    @OverrideToImplementCustomLogic
    def forward(self, input_dict, state, seq_lens):
        """Run the actor and critic, caching the mixture for ``value_function``.

        Returns:
            ``(raw_action_logits, state)`` with ``state`` passed through
            unchanged. As side effects, stores the mixture parameters in
            ``self._u`` / ``self._sigmas`` / ``self._alphas``, stores the
            wrapper's second output in ``self.enn_out``, and increments
            ``self.step_number``.
        """
        obs = input_dict['obs_flat'].float().to(self.device)
        # Actor forward pass.
        raw_action_logits, _ = self.actor_network(input_dict, state, seq_lens)
        # The critic is evaluated through the ENN wrapper. Per the original
        # author's note the critic output is detached inside the wrapper so
        # that only the ENN is updated through it - TODO confirm in ENNWrapper.
        critic_output, self.enn_out = self.enn_wrapper(obs)
        means, sigmas, alpha_logits = self._split_gmm_output(critic_output)
        alphas = torch.nn.functional.softmax(alpha_logits, dim=-1)
        self._u, self._sigmas, self._alphas = means, sigmas, alphas
        self.step_number += 1

        return raw_action_logits, state

    @OverrideToImplementCustomLogic
    def value_function(self):
        """Return the value estimate for the batch seen by the last ``forward``.

        The estimate is the mixture mean (sum over components of
        ``mean * alpha``) plus ``self.enn_out``. It requires ``forward`` to
        have been called first, and caches ``self.critic_value`` and
        ``self.final_critic_value``.
        """
        # Expected value of the mixture: means weighted by the alphas.
        multiply = self._u * self._alphas
        self.critic_value = torch.sum(multiply.to(self.device), dim = 1)
        # Add the ENN term to the critic value.
        self.final_critic_value = self.critic_value.to(self.device) + self.enn_out
        return self.final_critic_value.to(self.device)

    def predict_gmm_params(self, observation):
        """Evaluate the critic network directly (bypassing the ENN wrapper).

        Returns:
            ``(means, sigmas, alphas)`` where ``alphas`` are the unnormalised
            logits; callers apply ``softmax`` / ``log_softmax`` themselves.
        """
        output, _ = self.critic_network({"obs": observation}, [], None)
        return self._split_gmm_output(output)

    def compute_log_likelihood(self, td_targets, mu_pred, sigma_pred, alphas_pred):
        """Return the per-sample negative log-likelihood of ``td_targets``.

        The likelihood is that of a Gaussian mixture with component means
        ``mu_pred``, standard deviations ``sigma_pred`` and mixture logits
        ``alphas_pred``. Despite the method name, the returned value is
        negated (``-logsumexp``), so lower is better.
        """
        # Broadcast each target against every mixture component.
        residual = td_targets.unsqueeze(1) - mu_pred
        sigma_clamped = torch.clamp(sigma_pred, 1e-9, None)

        # A plain Python constant avoids allocating a CPU tensor on every call
        # and is independent of the device the inputs live on.
        log_2_pi = math.log(2 * math.pi)

        # Per-component Gaussian log-density, floored at -1e9.
        logp = torch.clamp(-torch.log(sigma_clamped) - .5 * log_2_pi - torch.square(residual) / (2*torch.square(sigma_clamped)), -1e9, None)
        loga = torch.nn.functional.log_softmax(alphas_pred, dim=-1)

        return -torch.logsumexp(logp + loga, dim=-1)

    @OverrideToImplementCustomLogic
    def custom_loss(self, policy_loss, sample_batch):
        """Add the mixture NLL loss and the ENN loss to every policy loss term.

        Args:
            policy_loss: iterable of loss tensors produced by the policy.
            sample_batch: batch providing CUR_OBS, NEXT_OBS, REWARDS and DONES.

        Returns:
            A list with ``loss + nll_loss + enn_loss`` for each input loss.
        """
        gamma = self.gamma
        cur_obs = sample_batch[SampleBatch.CUR_OBS]
        next_obs = sample_batch[SampleBatch.NEXT_OBS]
        rewards = sample_batch[SampleBatch.REWARDS]
        dones = sample_batch[SampleBatch.DONES]

        mu_pred, sigma_pred, w_pred = self.predict_gmm_params(cur_obs)

        # The bootstrap target carries no gradient, so skip building the
        # autograd graph for the next-state pass altogether.
        with torch.no_grad():
            mu_target, _, w_target = self.predict_gmm_params(next_obs)
            w_target = torch.nn.functional.softmax(w_target, dim = -1)
            next_state_value = torch.sum(mu_target * w_target, dim = 1)

        # One-step TD target; bootstrapping is masked out where dones is set.
        td_targets = rewards + gamma * next_state_value * (1 - dones.float())

        # NOTE(review): self.final_critic_value is whatever the most recent
        # value_function() call cached - confirm it matches this batch.
        enn_loss = self.enn_wrapper.enn_loss(next_obs = next_obs, rewards = rewards,
                                             dones = dones, current_critic_value = self.final_critic_value,
                                             next_critic_value = next_state_value, gamma = gamma)

        log_likelihood = self.compute_log_likelihood(td_targets, mu_pred, sigma_pred, w_pred)
        # Bound the per-sample NLL before averaging.
        log_likelihood = torch.clamp(log_likelihood, -10, 80)
        nll_loss = torch.mean(log_likelihood)

        total_loss = [loss + (nll_loss + enn_loss) for loss in policy_loss]

        # Log occasionally; step_number is advanced by forward().
        if self.step_number % 1_000 == 0:
            print(f"policy loss: {policy_loss} enn loss: {enn_loss} nll loss: {nll_loss}")

        return total_loss


# Register the model class under the name that the PPO config's
# 'custom_model' entry refers to.
ModelCatalog.register_custom_model("custom_torch_model_mog", CustomTorchModelMOG)

In [None]:
%%time

# PPO configuration: custom mixture-of-Gaussians model on HalfCheetah-v4.
config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    num_sgd_iter = 30,
    lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
    vf_loss_coeff = 1.0,
    vf_clip_param = 15.0,
    clip_param = 0.3,
    grad_clip_by ='norm', 
    train_batch_size = 19_200, 
    sgd_minibatch_size = 4_096,
    grad_clip = 1.0,
    model = {'custom_model': 'custom_torch_model_mog', 'vf_share_layers': False, 
           'fcnet_hiddens': [2048,2048],'fcnet_activation': 'LeakyReLU'},
).environment(env='HalfCheetah-v4'
)

num_iterations = 1
results = []  # mean episode reward per training iteration

# Release the algorithm's workers and the Ray runtime even when building or
# training raises, so a failed run does not leave a live cluster behind.
try:
    algo = config.build()
    try:
        for i in range(num_iterations):
            result = algo.train()
            print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
            results.append(result['episode_reward_mean'])
    finally:
        algo.stop()
finally:
    ray.shutdown()
