In [1]:
import ray
import time
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune, air
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from gymnasium.spaces import Box
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
import gymnasium as gym
import matplotlib.pyplot as plt
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
import numpy as np
import pandas as pd
from ray import tune
import math
from torch.distributions.normal import Normal
from ray.tune.schedulers import ASHAScheduler
from torch.utils.tensorboard import SummaryWriter
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import os 
import shutil
path = os.getcwd()

KeyboardInterrupt: 

In [2]:
torch, nn = try_import_torch()
ray.init()

2024-04-29 18:33:39,260	INFO worker.py:1724 -- Started a local Ray instance.


0,1
Python version:,3.10.13
Ray version:,2.9.2




In [3]:
def reference_distribution(mean, std, z_dim, batch_size):
    dimensions = (batch_size, z_dim)
    random_draw = torch.normal(mean, std, dimensions)
    return random_draw


class EpistemicNNModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(EpistemicNNModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        
        nn.Module.__init__(self)

        self.mean = 0
        self.std = 1
        self.z_dim = 8
        self.critic_output_dims = model_config['fcnet_hiddens'][-1]
        self.actor_output_dims = model_config['fcnet_hiddens'][-1]
        self.initializer = torch.nn.init.xavier_normal_
        self.activation_fn = model_config['fcnet_activation']
        self.enn_layer = 25
        self.action_outputs = action_space.shape[0]*2
        self.step_number = 0
        self.z_indices = None
        self.obs_space = obs_space
        self.action_space = action_space
        self.gamma = 0.95

        self.actor = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + "_actor")
        # self.actor_out = SlimFC(self.actor_output_dims, action_space.shape[0]*2, initializer=self.initializer, activation_fn=self.activation_fn)

        # self.value_branch = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.critic_in = SlimFC(obs_space.shape[0], model_config['fcnet_hiddens'][0], initializer=self.initializer, activation_fn=self.activation_fn)
        self.critic_1 = SlimFC(model_config['fcnet_hiddens'][1], model_config['fcnet_hiddens'][1], initializer=self.initializer, activation_fn=self.activation_fn)
        self.critic_out = SlimFC(self.critic_output_dims, 1, initializer=self.initializer, activation_fn=self.activation_fn)
        
        self.enn_learnable_1 = SlimFC(self.critic_output_dims + self.z_dim, self.enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_out = SlimFC(self.enn_layer, self.z_dim*1, initializer=self.initializer, activation_fn=None)

        self.prior_1 = SlimFC(self.critic_output_dims + self.z_dim, self.enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_out = SlimFC(self.enn_layer, 1, initializer=self.initializer, activation_fn=None)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        action_logits, _ = self.actor(input_dict, state, seq_lens)
        # action_logits_out = self.actor_out(action_logits)
        batch_size = action_logits.shape[0]

        self.z_indices = reference_distribution(self.mean, self.std, self.z_dim, batch_size)

        obs = input_dict["obs_flat"].float()
        
        self.step_number += 1
        if self.step_number < 10:
            #this step counter will count with the dummy loss initialization as well
            critic_0 = self.critic_in(obs)
            critic_features = self.critic_1(critic_0)
            self.critic_out_value = self.critic_out(critic_features)
            z_concat = torch.cat((critic_features, self.z_indices), 1)

            #make passes through ENN and prior to allow params / gradients to be initialzed in the ENN & prior networks
            enn_layer_1_out = self.enn_learnable_1(z_concat)
            enn_layer = self.enn_learnable_out(enn_layer_1_out)
            enn_layer_unsqeeze_out = torch.unsqueeze(enn_layer, -1)
            z_unsqueeze = torch.unsqueeze(self.z_indices, -1)
            enn_layer_bmm = torch.bmm(torch.transpose(enn_layer_unsqeeze_out, 1, 2), z_unsqueeze)
            enn_layer_out = enn_layer_bmm.squeeze(-1)
            
    
            prior_layer_1 = self.prior_1(z_concat)
            prior_layer_out = self.prior_out(prior_layer_1)
            
            
        else:
            #process critic output
            critic_0 = self.critic_in(obs)
            critic_features = self.critic_1(critic_0)
            value_out = self.critic_out(critic_features)
            z_concat = torch.cat((critic_features, self.z_indices), 1)

            #make passes through ENN and prior
            enn_layer_1_out = self.enn_learnable_1(z_concat)
            enn_layer = self.enn_learnable_out(enn_layer_1_out)
            enn_layer_unsqeeze_out = torch.unsqueeze(enn_layer, -1)
            z_unsqueeze = torch.unsqueeze(self.z_indices, -1)
            enn_layer_bmm = torch.bmm(torch.transpose(enn_layer_unsqeeze_out, 1, 2), z_unsqueeze)
            enn_layer_out = enn_layer_bmm.squeeze(-1)

            with torch.no_grad():
                prior_layer_1 = self.prior_1(z_concat)
                prior_layer_out = self.prior_out(prior_layer_1)
    
            # #add prior and ENN to the value network out
            self.critic_out_value = value_out + enn_layer_out + prior_layer_out
            
        return action_logits, state
        
    @override(TorchModelV2)
    def value_function(self):
        #for the first 5 iterations use the base network as the summed value
        #then return the value = base_value + enn layer out  + prior layer out
        return self.critic_out_value.squeeze(-1)

    @override(TorchModelV2)
    def custom_loss(self, policy_loss, sample_batch):

        cur_obs = sample_batch[SampleBatch.CUR_OBS]
        next_obs = sample_batch[SampleBatch.NEXT_OBS]
        rewards = sample_batch[SampleBatch.REWARDS]
        dones = sample_batch[SampleBatch.DONES]


        critic_0 = self.critic_in(cur_obs)
        critic_features = self.critic_1(critic_0)
        z_concat = torch.cat((critic_features, self.z_indices), 1)
        enn_layer_1_out = self.enn_learnable_1(z_concat)
        enn_layer = self.enn_learnable_out(enn_layer_1_out)
        enn_layer_unsqeeze_out = torch.unsqueeze(enn_layer, -1)
        z_unsqueeze = torch.unsqueeze(self.z_indices, -1)
        enn_layer_bmm = torch.bmm(torch.transpose(enn_layer_unsqeeze_out, 1, 2), z_unsqueeze)
        enn_layer_out = enn_layer_bmm.squeeze(-1)
        
        next_critic_0 = self.critic_in(next_obs)
        next_critic_features = self.critic_1(next_critic_0)
        next_z_concat = torch.cat((next_critic_features, self.z_indices), 1)
        enn_layer_1_out = self.enn_learnable_1(next_z_concat)
        enn_layer = self.enn_learnable_out(enn_layer_1_out)
        enn_layer_unsqeeze_out = torch.unsqueeze(enn_layer, -1)
        z_unsqueeze = torch.unsqueeze(self.z_indices, -1)
        enn_layer_bmm = torch.bmm(torch.transpose(enn_layer_unsqeeze_out, 1, 2), z_unsqueeze)
        enn_layer_out_next = enn_layer_bmm.squeeze(-1)
        
        td_target = rewards + self.gamma * enn_layer_out_next * (1 - dones.float())
        td_pred = enn_layer_out

        td_loss = torch.square(td_pred - td_target)
        td_loss = torch.mean(td_loss)

        total_loss = [loss + td_loss for loss in policy_loss]
        
        return total_loss
        


# Register the custom model to make it available to Ray/RLlib
ModelCatalog.register_custom_model("ENNModel", EpistemicNNModel)

In [None]:
config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    # kl_coeff = 1.0,
    num_sgd_iter = 30,
    lr_schedule = [[0, 0.0003], [5_000_000, 0.00025], [15_000_000, 0.00020], [30_000_000, 0.00015]],
    vf_loss_coeff = 1.0,
    #reintroduced clip param
    vf_clip_param = 15.0,
    clip_param = 0.3,
    grad_clip_by = 'norm', 
    train_batch_size = 65_500, 
    sgd_minibatch_size = 4_096,
    grad_clip = 1.0,
    model={
        'custom_model': 'ENNModel',
        'fcnet_hiddens': [512, 512],
        'fcnet_activation': 'LeakyReLU'
    }
).environment(env='HalfCheetah-v4'
).rollouts(
num_rollout_workers = 28,
# num_envs_per_worker = 4,
)
#.callbacks(MyCustomCallback
#)
algo = config.build()

num_iterations = 400
rewards = []

for i in range(num_iterations):
    result = algo.train()
    print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
    rewards.append(result['episode_reward_mean'])
    
ray.shutdown()


`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
  self.actor = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + "_actor")


Iteration: 0, Mean Reward: -357.3114158111543
Iteration: 1, Mean Reward: -345.5228033713146
Iteration: 2, Mean Reward: -328.80677669521975
Iteration: 3, Mean Reward: -322.05876979480513
Iteration: 4, Mean Reward: -301.8678314750078
Iteration: 5, Mean Reward: -285.8677336159631


In [5]:
policy_weights = algo.get_weights()

In [6]:
prior_weight_keys = [key for key in policy_weights['default_policy'] if 'prior' in key]

In [7]:
prior_weights = {key: policy_weights['default_policy'][key] for key in prior_weight_keys}

In [18]:
prior_weights['prior_1._model.0.weight'][1]

261

In [5]:
results_df_dim_change = pd.DataFrame(rewards)

In [8]:
results_df_dim_change.to_csv(path + '/dim_change.csv')