In [1]:
import os
import ray
import time
import math
import shutil
import numpy as np
import pandas as pd
from ray import tune
import gymnasium as gym
from ray import tune, air
import plotly.express as px
import matplotlib.pyplot as plt
from gymnasium.spaces import Box
import plotly.graph_objects as go
from ENNWrapper import ENNWrapper
from ray.train import ScalingConfig
import ray.rllib.algorithms.ppo as ppo
from ray.train.torch import TorchTrainer
from ray.rllib.models import ModelCatalog
from torch.distributions.normal import Normal
from ray.tune.schedulers import ASHAScheduler
from ray.rllib.algorithms.ppo import PPOConfig
from torch.utils.tensorboard import SummaryWriter
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

path = os.getcwd()
# RLlib helper returning the torch module and torch.nn.
torch, nn = try_import_torch()
# ignore_reinit_error lets this cell be re-run without restarting the kernel;
# a bare ray.init() raises if Ray is already initialised in this process.
ray.init(ignore_reinit_error=True)

2024-06-11 11:11:40,103	INFO worker.py:1724 -- Started a local Ray instance.


0,1
Python version:,3.10.9
Ray version:,2.9.2


In [2]:
class CustomTorchModelMOG(TorchModelV2, nn.Module):
    """RLlib torch model with separate actor/critic MLPs and an ENN-wrapped critic.

    - The actor is a ``TorchFC`` with ``2 * action_space.shape[0]`` outputs.
      These are assumed to be mean/log-std for a diagonal Gaussian action
      distribution (TODO confirm).
    - The critic is a ``TorchFC`` with a single output. ``ENNWrapper`` wraps it,
      and ``value_function`` returns the wrapper's output cached by ``forward``.
    - ``custom_loss`` adds a one-step TD loss for the critic, plus the wrapper's
      ENN loss, to every policy-loss term.

    Recognised ``model_config["custom_model_config"]`` keys:
        z_dim (int, default 5): forwarded to ``ENNWrapper``.
        enn_layer (int, default 50): forwarded to ``ENNWrapper``.
        gamma (float, default 0.99): discount for the TD target and ENN loss.
            Keep it in sync with the algorithm's ``gamma``.
        detect_anomaly (bool, default False): turn on
            ``torch.autograd.set_detect_anomaly`` (process-wide, slow; debug only).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Initialise both bases explicitly (RLlib's pattern for TorchModelV2 +
        # nn.Module subclasses). ModelV2 stores obs_space/action_space/etc.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        custom_config = model_config.get("custom_model_config") or {}

        if custom_config.get("detect_anomaly", False):
            # Global switch for the whole worker process; very slow, so opt-in only.
            torch.autograd.set_detect_anomaly(True)

        self.gamma = custom_config.get("gamma", 0.99)
        # Counts forward() calls; used to throttle loss logging in custom_loss().
        self.step_number = 0
        self.initializer = torch.nn.init.xavier_normal_
        self.activation_fn = model_config["fcnet_activation"]
        self.z_dim = custom_config.get("z_dim", 5)
        enn_layer = custom_config.get("enn_layer", 50)
        # Cached by forward(), consumed by value_function().
        self.critic_output = None

        self.critic_network = TorchFC(
            obs_space, action_space, 1, model_config, name + "_critic"
        )
        self.actor_network = TorchFC(
            obs_space, action_space, action_space.shape[0] * 2, model_config, name + "_actor"
        )
        self.enn_wrapper = ENNWrapper(
            base_network=self.critic_network,
            z_dim=self.z_dim,
            enn_layer=enn_layer,
            activation=self.activation_fn,
            initializer=self.initializer,
        )

    @OverrideToImplementCustomLogic
    def forward(self, input_dict, state, seq_lens):
        """Run the actor and cache the ENN-wrapped critic output.

        Returns:
            (actor outputs, state) — state is passed through unchanged.
        """
        raw_action_logits, _ = self.actor_network(input_dict, state, seq_lens)
        # NOTE(review): the original author states ENNWrapper detaches its output,
        # so the PPO value loss only updates the ENN — confirm in ENNWrapper.
        self.critic_output, _ = self.enn_wrapper(input_dict, state, seq_lens)
        self.step_number += 1
        return raw_action_logits, state

    @OverrideToImplementCustomLogic
    def value_function(self):
        """Return the critic value cached by the last forward() call, last dim squeezed."""
        assert self.critic_output is not None, "must call forward() first"
        return self.critic_output.squeeze(-1)

    @OverrideToImplementCustomLogic
    def custom_loss(self, policy_loss, sample_batch):
        """Add a critic TD loss and the ENN loss to each policy-loss term.

        Args:
            policy_loss: iterable of policy loss tensors.
            sample_batch: train batch. The keys read are CUR_OBS, NEXT_OBS,
                REWARDS and TERMINATEDS.

        Returns:
            A list with ``critic_loss + enn_loss`` added to each policy_loss term.
        """
        cur_obs = {"obs": sample_batch[SampleBatch.CUR_OBS]}
        next_obs = {"obs": sample_batch[SampleBatch.NEXT_OBS]}
        rewards = sample_batch[SampleBatch.REWARDS]
        # Only true terminations stop bootstrapping. Time-limit truncations must
        # still bootstrap from V(s'); the combined "dones" flag would zero them.
        # HalfCheetah-v4, for example, never terminates and only truncates.
        terminateds = sample_batch[SampleBatch.TERMINATEDS]

        # Same critic network that ENNWrapper wraps.
        current_value, _ = self.critic_network(cur_obs)
        current_value = current_value.squeeze(-1)
        with torch.no_grad():
            next_value, _ = self.critic_network(next_obs)
            next_value = next_value.squeeze(-1)
        # Match dtypes so mse_loss does not mix float32/float64.
        td_target = rewards.to(current_value.dtype) + self.gamma * next_value * (
            1.0 - terminateds.to(current_value.dtype)
        )
        critic_loss = torch.nn.functional.mse_loss(current_value, td_target)

        enn_loss = self.enn_wrapper.enn_loss(
            next_obs=next_obs["obs"], rewards=rewards, dones=terminateds, gamma=self.gamma
        )
        aux_loss = critic_loss + enn_loss
        total_loss = [loss + aux_loss for loss in policy_loss]

        if self.step_number % 1_000 == 0:
            policy_values = [float(loss) for loss in policy_loss]
            print(
                f"policy loss: {policy_values} enn loss: {float(enn_loss)} "
                f"critic loss: {float(critic_loss)}"
            )

        return total_loss


# Registered under the name referenced by 'custom_model' in the PPO config below.
ModelCatalog.register_custom_model("custom_torch_model_mog", CustomTorchModelMOG)

In [None]:
%%time

# PPO on HalfCheetah-v4 using the custom ENN-critic model registered above.
config = (
    PPOConfig()
    .training(
        gamma=0.99,
        lambda_=0.95,
        num_sgd_iter=30,
        # [timestep, lr] breakpoints. With num_iterations=200 and
        # train_batch_size=19_200 (~3.84M steps) only the first segment is reached.
        lr_schedule=[[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
        vf_loss_coeff=1.0,
        vf_clip_param=15.0,
        clip_param=0.3,
        grad_clip_by='norm',
        train_batch_size=19_200,
        sgd_minibatch_size=4_096,
        grad_clip=1.0,
        model={'custom_model': 'custom_torch_model_mog', 'vf_share_layers': False,
               'fcnet_hiddens': [2048, 2048], 'fcnet_activation': 'LeakyReLU'},
    )
    .environment(env='HalfCheetah-v4')
    .rollouts(num_rollout_workers=20)
    .resources(num_gpus=1)
)

algo = config.build()

num_iterations = 200
results = []  # mean episode reward per training iteration

try:
    for i in range(num_iterations):
        result = algo.train()
        mean_reward = result['episode_reward_mean']
        print(f"Iteration: {i}, Mean Reward: {mean_reward}")
        results.append(mean_reward)
finally:
    # Release the rollout workers and GPU even if training errors or is interrupted.
    algo.stop()
    ray.shutdown()


`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-06-11 11:11:56,687	INFO trainable.py:164 -- Trainable.setup took 15.318 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
