In [None]:
import os
import ray
import time
import math
import shutil
import numpy as np
import pandas as pd
from ray import tune
from MoG_module import CriticMoG
from ray.train.torch import TorchTrainer
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ENNWrapper_mog_auto_loss import ENNWrapper
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC

# Directory the notebook kernel was started from.
path = os.getcwd()

# Obtain the torch and torch.nn modules through RLlib's import helper.
torch, nn = try_import_torch()

# ignore_reinit_error makes this cell safe to re-run while a Ray runtime is
# already up; a bare ray.init() raises on the second call.
ray.init(ignore_reinit_error=True)

In [None]:
class CustomTorchModelMOG(TorchModelV2, nn.Module):
    """Custom RLlib Torch model with a fully connected actor and an ENN-wrapped MoG critic.

    The actor is RLlib's ``FullyConnectedNetwork``. The critic is a
    ``CriticMoG`` network wrapped in an ``ENNWrapper``; its loss is added to
    every policy loss term in :meth:`custom_loss`.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name,
                 gamma=0.99, detect_anomaly=True):
        """Build the actor and critic networks.

        Args:
            obs_space: Observation space; forwarded to the base class, the
                actor and the base critic.
            action_space: Action space; the actor is built with
                ``action_space.shape[0] * 2`` outputs.
            num_outputs: Output size forwarded to ``TorchModelV2``.
            model_config: Model config dict. ``model_config['fcnet_activation']``
                is used as the critic activation and the whole dict is passed
                to the actor.
            name: Model name; the actor is named ``name + "_actor"``.
            gamma: Discount factor handed to the critic's ``enn_loss``.
                Defaults to 0.99, the value that used to be hard-coded here.
            detect_anomaly: When True (default, matching the previous
                behaviour) enable ``torch.autograd`` anomaly detection. This
                is a process-wide switch; the PyTorch docs describe it as a
                debugging aid that slows execution, so pass False once the
                model is stable.

        NOTE(review): RLlib is expected to forward entries of the model
        config's ``custom_model_config`` dict as keyword arguments, which is
        how ``gamma`` / ``detect_anomaly`` would be set from a config --
        confirm against the installed Ray version.
        """
        # Initialise each base exactly once. A preceding super().__init__(...)
        # call would resolve to TorchModelV2.__init__ as well and run it twice.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        if detect_anomaly:
            torch.autograd.set_detect_anomaly(True)

        self.gamma = gamma
        self.step_number = 0  # incremented on every forward() call
        self.activation_fn = model_config['fcnet_activation']

        self.base_critic_network = CriticMoG(obs_space = obs_space, num_gaussians = 2,
                                             hidden_layer_dims = 2048, num_layers = 2,
                                             activation = self.activation_fn)
        self.actor_network = TorchFC(obs_space, action_space, action_space.shape[0]*2,
                                     model_config, name + "_actor")
        self.critic_network = ENNWrapper(base_network = self.base_critic_network, z_dim = 5, enn_layer = 50,
                                         activation = self.activation_fn)

    @OverrideToImplementCustomLogic
    def forward(self, input_dict, state, seq_lens):
        """Run actor and critic on the same inputs and return the actor output.

        The critic's raw output is stored on ``self.raw_critic_output``; the
        incoming ``state`` is returned unchanged.
        """
        # Actor forward pass.
        raw_action_logits, _ = self.actor_network(input_dict, state, seq_lens)
        # Critic forward pass through the ENN-wrapped MoG network.
        self.raw_critic_output, _ = self.critic_network(input_dict, state, seq_lens)
        self.step_number += 1

        return raw_action_logits, state

    @OverrideToImplementCustomLogic
    def value_function(self):
        """Return the value estimate produced by the wrapped critic network."""
        return self.critic_network.value_function()

    @OverrideToImplementCustomLogic
    def custom_loss(self, policy_loss, sample_batch):
        """Add the critic's ENN loss to every entry of ``policy_loss``.

        Args:
            policy_loss: Iterable of policy loss terms.
            sample_batch: Batch passed through to ``enn_loss``.

        Returns:
            A list with ``critic_loss`` added to each policy loss term.
        """
        critic_loss = self.critic_network.enn_loss(sample_batch = sample_batch, handle_loss = True,
                                                   gamma=self.gamma)
        total_loss = [loss + critic_loss for loss in policy_loss]
        return total_loss


# Register the model class under the string key that the PPO config's
# model['custom_model'] entry refers to.
ModelCatalog.register_custom_model("custom_torch_model_mog", CustomTorchModelMOG)

In [None]:
%%time

# PPO configuration for HalfCheetah-v4 using the registered custom MoG model.
config = (
    PPOConfig()
    .training(
        gamma = 0.99,
        lambda_ = 0.95,
        # kl_coeff = 0.5,
        num_sgd_iter = 30,
        lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
        vf_loss_coeff = 1.0,
        vf_clip_param = 15.0,
        clip_param = 0.3,
        grad_clip_by = 'norm',
        train_batch_size = 19_200,
        sgd_minibatch_size = 4_096,
        grad_clip = 1.0,
        model = {'custom_model': 'custom_torch_model_mog', 'vf_share_layers': False,
                 'fcnet_hiddens': [2048, 2048], 'fcnet_activation': 'LeakyReLU'},
    )
    .environment(env='HalfCheetah-v4')
    .rollouts(num_rollout_workers = 20)
    .resources(num_gpus = 1)
)

num_iterations = 200
results = []  # mean episode reward per training iteration

# Build and train inside try/finally so Ray is shut down (releasing the
# rollout workers and the GPU) even if training raises or the cell is
# interrupted part-way through.
try:
    algo = config.build()

    for i in range(num_iterations):
        result = algo.train()
        print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
        results.append(result['episode_reward_mean'])
finally:
    ray.shutdown()
