In [1]:
import os
import ray
import torch
import heapq
import torch.nn as nn
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC

In [2]:
path = os.getcwd()
torch, nn = try_import_torch()
ray.init()

2024-08-30 18:29:13,838	INFO worker.py:1752 -- Started a local Ray instance.


0,1
Python version:,3.8.10
Ray version:,2.10.0


[36m(RolloutWorker pid=70985)[0m name and param: _value_branch_separate.0._model.0.weight and torch.Size([16, 17])
[36m(RolloutWorker pid=70985)[0m weight summation is: 0.8425783952698112
[36m(RolloutWorker pid=70985)[0m weight summation is: -1.6201322712004185
[36m(RolloutWorker pid=70985)[0m weight summation is: -0.6534757427871227
[36m(RolloutWorker pid=70985)[0m weight summation is: -0.5407587364315987
[36m(RolloutWorker pid=70985)[0m weight summation is: 0.022920441813766956
[36m(RolloutWorker pid=70985)[0m weight summation is: -0.5331066437065601
[36m(RolloutWorker pid=70985)[0m weight summation is: -0.5533240190707147
[36m(RolloutWorker pid=70985)[0m weight summation is: -0.9374191043898463
[36m(RolloutWorker pid=70985)[0m weight summation is: 0.3696219357661903
[36m(RolloutWorker pid=70985)[0m weight summation is: -1.3668912937864661
[36m(RolloutWorker pid=70985)[0m weight summation is: 1.0899480991065502
[36m(RolloutWorker pid=70985)[0m weight summati

In [3]:
class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_fcnet = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + 
                                   "_actor")

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        means_clamped = torch.clamp(means, -1, 1)
        log_stds_clamped = torch.clamp(log_stds, -10, 0)
        logits = torch.cat((means_clamped, log_stds_clamped), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

        # iterate through the weights for the value network (for now)
        # and gather the weight summations -- using .item() creates a flatten list we can now sum with
        all_hidden_weights = []
        for name, param in self.critic_fcnet.named_parameters():
            if 'value' in name and len(param.shape) == 2:
                print(f"name and param: {name} and {param.shape}")
                for i in range(param.shape[0]):
                    weight_summation = 0
                    for j in range(param.shape[1]):
                        weight = param[i, j].item()
                        # all_hidden_weights.append(abs(weight))
                        weight_summation += weight
                    # this weight length currently gives back 16 (l+1) neurons worth of lengths
                    # or -- it gives 17 weight connections for each of the (l+1) layer'd neurons
                    print(f"weight summation is: {weight_summation}")

        
        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        return self.value.squeeze(-1)

# register the custom model to make it available to Ray/RLlib
ModelCatalog.register_custom_model("SimpleCustomTorchModel", SimpleCustomTorchModel)

In [4]:
%%time

config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    # kl_coeff = 0.5,
    num_sgd_iter = 30,
    lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
    vf_loss_coeff = 1.0,
    vf_clip_param = 15.0,
    clip_param = 0.3,
    grad_clip_by ='norm', 
    train_batch_size = 1000, 
    sgd_minibatch_size = 200,
    grad_clip = 1.0,
    model = {'custom_model': 'SimpleCustomTorchModel', 'vf_share_layers': False, 
           'fcnet_hiddens': [16,16],'fcnet_activation': 'LeakyReLU'},
).environment(env='HalfCheetah-v4'
).rollouts(
num_rollout_workers = 20,
).resources(num_gpus = 1
)

algo = config.build()

num_iterations = 1
results = []

for i in range(num_iterations):
    result = algo.train()
    print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
    results.append(result['episode_reward_mean'])
    
ray.shutdown()
 

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))

KeyboardInterrupt

