In [None]:
import os
import ray
import torch
import heapq
import torch.nn as nn
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, normc_initializer

In [None]:
path = os.getcwd()
# RLlib's helper returns the torch / torch.nn modules (same objects as the imports above).
torch, nn = try_import_torch()
# ray.init() raises when Ray is already initialized (unless ignore_reinit_error=True),
# so guard it to keep this cell safe to re-run.
if not ray.is_initialized():
    ray.init()

In [None]:
class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    """PPO model: RLlib fully-connected actor plus a hand-built 3-layer critic.

    The actor is an RLlib ``FullyConnectedNetwork`` whose output is split into
    (means, log_stds). The critic is ``linear_in -> linear_middle -> linear_final``.
    On every forward call a decayed "contribution utility" is tracked for each
    hidden unit that feeds ``linear_middle``. A unit is reported (printed) once its
    utility has stayed at or below ``utility_limit`` for ``threshold`` consecutive
    forward calls. Nothing is reinitialized yet; reporting only.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # ---------------------------------------------------------------- networks
        # NOTE(review): critic_fcnet is never used in forward(); it is kept only so
        # existing checkpoints / state_dicts keep the same parameter names.
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        # Actor emits 2 * action_dim values (means and log-stds); assumes a Box action space.
        self.actor_fcnet = TorchFC(
            obs_space, action_space, action_space.shape[0] * 2, model_config, name + "_actor"
        )
        input_space = obs_space.shape[0]
        self.linear_in = nn.Linear(input_space, 128)
        self.linear_middle = nn.Linear(128, 128)
        self.linear_final = nn.Linear(128, 1)
        self.activation = nn.LeakyReLU()

        # ------------------------------------------------------- utility tracking
        self.step = 0
        # A unit must be low-utility for this many consecutive forward calls before
        # it is reported (author's note: ~2 iterations with 10000 / 2000 / 30).
        self.threshold = 300
        self.utility_limit = 0.001
        # neuron id -> steps at which it was flagged as low-utility
        self.track_ids = {i: [] for i in range(self.linear_middle.in_features)}
        # neuron id -> decayed contribution utility of linear_in's output units
        self.utility_middle_in = {i: 0 for i in range(self.linear_middle.in_features)}
        # currently unused
        self.utility_middle_out = {i: 0 for i in range(self.linear_middle.out_features)}
        self.decay_rate = 0.99

    def check_dict(self, step, ids, neuron_id):
        """Return True if ``neuron_id`` was flagged on each of the last ``threshold`` steps.

        Args:
            step: current step counter.
            ids: mapping neuron id -> list of steps at which it was flagged.
            neuron_id: the unit to check.

        Returns:
            True when every step in ``[step - threshold + 1, step]`` is present in
            ``ids[neuron_id]``; always False while ``step <= threshold``.
        """
        if step <= self.threshold:
            return False
        # The window ends at (and includes) the current step. check_neuron keeps only
        # the last `threshold` flags, so a window that excluded `step` would need the
        # entry `step - threshold`, which is already trimmed away and could never match.
        # TODO(author): the paper updates utilities/age after the loss is computed;
        # a callback (e.g. on_postprocess_trajectory) may be a better hook.
        flagged_steps = set(ids.get(neuron_id, ()))
        return all(s in flagged_steps for s in range(step - self.threshold + 1, step + 1))

    def check_neuron(self, neuron_utility, i):
        """Record a low-utility step for unit ``i`` and report whether it should be reset.

        Args:
            neuron_utility: current decayed utility of unit ``i``.
            i: index of the unit (an input feature of ``linear_middle``).

        Returns:
            True if ``|neuron_utility| <= utility_limit`` now and unit ``i`` was also
            flagged on each of the preceding ``threshold - 1`` steps.
        """
        if abs(neuron_utility) > self.utility_limit:
            # A unit that is not low-utility right now is never reported.
            return False
        history = self.track_ids.setdefault(i, [])
        history.append(self.step)
        # Only the last `threshold` flags matter; drop older ones in place.
        if len(history) > self.threshold:
            del history[:-self.threshold]
        return self.check_dict(self.step, self.track_ids, i)

    def get_utility(self, layer_output):
        """Update the decayed contribution utility of every unit feeding ``linear_middle``.

        For hidden unit ``i`` the instantaneous contribution is
        ``mean_over_batch(|h_i|) * sum_k |W[k, i]|``, where ``W`` is
        ``linear_middle.weight`` with shape ``(out_features, in_features)``. Column
        ``i`` therefore holds the weights *leaving* unit ``i``. Magnitudes are used
        so that positive and negative terms cannot cancel. The utility is an
        exponential moving average with factor ``decay_rate``.

        Args:
            layer_output: activations of ``linear_in`` after the nonlinearity;
                presumably shape (batch, linear_middle.in_features) — confirm.

        Returns:
            None. Updates ``utility_middle_in`` / ``track_ids`` and prints flagged units.
        """
        weight = self.linear_middle.weight.detach()
        outgoing_weight = weight.abs().sum(dim=0)
        mean_activation = layer_output.detach().abs().mean(dim=0)
        # A single host transfer instead of one .item() call (device sync) per weight.
        contributions = (mean_activation * outgoing_weight).tolist()

        for i, contribution in enumerate(contributions):
            self.utility_middle_in[i] = (
                self.decay_rate * self.utility_middle_in[i]
                + (1 - self.decay_rate) * contribution
            )
            if self.check_neuron(self.utility_middle_in[i], i):
                print(f"neuron {i} with utility: {self.utility_middle_in[i]} needs reinitialized")

        # Periodic diagnostic: which unit's utility is closest to zero?
        if self.step % 3_000 == 0:
            min_key = min(self.utility_middle_in, key=lambda k: abs(self.utility_middle_in[k]))
            min_value = self.utility_middle_in[min_key]
            print(f"Currently, neuron {min_key} has the utility closest to zero: {min_value}")
        return None

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Actor: split logits into means / log-stds and clamp each to a fixed range.
        logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        means_clamped = torch.clamp(means, -1, 1)
        log_stds_clamped = torch.clamp(log_stds, -10, 0)
        logits = torch.cat((means_clamped, log_stds_clamped), dim=-1)

        # Critic.
        obs_in = input_dict['obs_flat']
        critic_in_a = self.activation(self.linear_in(obs_in))
        critic_middle_a = self.activation(self.linear_middle(critic_in_a))
        # The value head must be linear: LeakyReLU (default slope 0.01) would shrink
        # negative value estimates and their gradients by 100x.
        self.value = self.linear_final(critic_middle_a)

        # Utility bookkeeping must not take part in autograd (it also produced NaNs).
        with torch.no_grad():
            self.get_utility(critic_in_a)

        # NOTE: counts every forward() call on this model instance. With
        # 10000 / 2000 * 30 = 150 SGD updates per iteration, the author observed 153
        # after the first iteration, attributed to dummy-batch initialization.
        self.step += 1
        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        # Value estimates computed by the most recent forward() call, shape (batch,).
        return self.value.squeeze(-1)
ModelCatalog.register_custom_model("SimpleCustomTorchModel", SimpleCustomTorchModel)

In [None]:
%%time

config = (
    PPOConfig()
    .training(
        gamma=0.99,
        lambda_=0.95,
        # kl_coeff=0.5,
        num_sgd_iter=30,
        lr_schedule=[[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
        vf_loss_coeff=1.0,
        vf_clip_param=15.0,
        clip_param=0.3,
        grad_clip_by='norm',
        grad_clip=1.0,
        train_batch_size=10_000,
        sgd_minibatch_size=2_000,
        model={
            'custom_model': 'SimpleCustomTorchModel',
            'vf_share_layers': False,
            'fcnet_hiddens': [128, 128],
            'fcnet_activation': 'LeakyReLU',
        },
    )
    .environment(env='HalfCheetah-v4')
    .rollouts(num_rollout_workers=20)
    .resources(num_gpus=1)
)

algo = config.build()

num_iterations = 100
results = []

# try/finally so rollout workers and the Ray runtime are released even if training
# raises or is interrupted from the notebook.
try:
    for i in range(num_iterations):
        result = algo.train()
        mean_reward = result['episode_reward_mean']
        print(f"Iteration: {i}, Mean Reward: {mean_reward}")
        results.append(mean_reward)
finally:
    algo.stop()
    ray.shutdown()
 