In [None]:
import os
import ray
import math
import torch
import heapq
import torch.nn as nn
from ray.rllib.models import ModelCatalog
from torch.nn import functional as F, init
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, normc_initializer

In [None]:
# Working directory of the kernel; not referenced by the cells visible in this notebook chunk.
path = os.getcwd()
# Rebind torch / nn through RLlib's import helper (overrides the plain imports above).
torch, nn = try_import_torch()
# ignore_reinit_error=True keeps this cell re-runnable: without it a second execution in the
# same kernel raises because Ray is already initialized.
ray.init(ignore_reinit_error=True)

In [None]:
class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    '''
    Custom RLlib Torch model: an RLlib fully connected actor plus a hand-built critic MLP whose
    first hidden layer is monitored for low-utility neurons.

    Notes:
        -Actor: TorchFC emitting action_space.shape[0]*2 outputs (means and log-stds), both clamped
        -Critic: linear_in -> LeakyReLU -> linear_middle -> LeakyReLU -> linear_final -> LeakyReLU
        -Every forward() call updates a decayed running utility for each output neuron of linear_in;
            a neuron whose |utility| stays <= utility_limit for `threshold` consecutive calls has its
            incoming weights re-drawn and its outgoing weights zeroed
    '''

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # ---- networks ----
        # NOTE(review): critic_fcnet is constructed but forward() never calls it; it is kept so
        # the model's parameter set / state_dict layout stays unchanged.
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_fcnet = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config,
                                   name + "_actor")
        input_space = obs_space.shape[0]
        self.linear_in = torch.nn.Linear(input_space, 128)
        self.linear_middle = nn.Linear(128, 128)
        self.linear_final = nn.Linear(128, 1)
        self.activation = torch.nn.LeakyReLU()

        # ---- utility tracking ----
        self.step = 0            # number of forward() calls seen so far
        self.threshold = 300     # consecutive low-utility steps before a neuron is reinitialized
                                 # (author's note: ~2 iterations with 10000, 2000, 30)
        self.utility_limit = 0.1
        # neuron id -> list of consecutive steps at which |utility| <= utility_limit
        self.track_ids = {i: [] for i in range(self.linear_middle.in_features)}
        # neuron id -> decayed running utility of the linear_in output neurons
        self.utility_middle_in = {i: 0 for i in range(self.linear_middle.in_features)}
        # not read or written by any method of this class
        self.utility_middle_out = {i: 0 for i in range(self.linear_middle.out_features)}
        self.decay_rate = 0.99

    def check_dict(self, step, ids, neuron_id):
        '''
        Args:
            step: the current value of the step counter
            ids: dict mapping neuron id -> list of steps at which the neuron was below the limit
            neuron_id: id of the neuron interested in
        Returns:
            True if the neuron was recorded at every one of the last `self.threshold` steps,
            the current step included (i.e. it should be reinitialized); False otherwise
        Notes:
            -The window has to end at the current step: check_neuron() keeps only the newest
                `threshold` entries (current step included), so a window ending one step earlier
                could never be fully covered while the neuron is still below the limit
            -Membership is tested against a set so the check is linear in `threshold`
        '''
        # TODO (original author): the paper only updates utilities and age after the loss is
        # computed; a custom callback overriding on_postprocess_trajectory may be the place
        # to do this instead of forward().
        tracked_steps = ids.get(neuron_id, ())
        if len(tracked_steps) < self.threshold:
            return False
        window = range(step - self.threshold + 1, step + 1)
        return set(tracked_steps).issuperset(window)

    def check_neuron(self, neuron_utility, i):
        '''
        Args:
            neuron_utility: current running utility of neuron i
            i: id of the neuron to check
        Returns:
            True if neuron i has been at or below utility_limit (in absolute value) for
            `threshold` consecutive steps ending at the current one
        Notes:
            -While the neuron is below the limit, the current step is appended to its list in
                track_ids; the list is trimmed to the newest `threshold` entries so it stays bounded
            -As soon as the neuron is above the limit the streak is broken, so its list is
                cleared (older entries can never be part of a future consecutive window)
        '''
        if abs(neuron_utility) <= self.utility_limit:
            history = self.track_ids.setdefault(i, [])
            history.append(self.step)
            if len(history) > self.threshold:
                del history[:-self.threshold]
            return self.check_dict(self.step, self.track_ids, i)
        self.track_ids[i] = []
        return False

    def get_utility(self, layer_output):
        '''
        Args:
            layer_output: output of linear_in after the activation, indexed [batch, neuron]
        Notes:
            -For every output neuron i of linear_in, sums its outgoing weights, i.e. column i of
                linear_middle.weight (nn.Linear stores weight as [out_features, in_features])
            -Multiplies that sum by the batch-mean activation of neuron i and folds the product
                into the decayed running utility utility_middle_in[i]
            -Reinitializes neuron i when check_neuron() reports it has been below the limit for
                `threshold` consecutive steps, then clears its tracking list and utility so the
                fresh neuron gets a full `threshold` steps before it can be replaced again
            -Everything is computed in one vectorized pass with a single host transfer
        '''
        # NOTE(review): signed sums are kept from the original; the continual-backprop
        # formulation uses magnitudes (|activation| * sum |w|) -- confirm which is intended.
        if layer_output.shape[0] == 0:
            # an empty batch would produce NaN means and poison the running averages
            return
        mean_activation = layer_output.detach().mean(dim=0)
        outgoing_sum = self.linear_middle.weight.detach().sum(dim=0)
        contributions = (mean_activation * outgoing_sum).tolist()
        for i, contribution in enumerate(contributions):
            prev_utility = self.decay_rate * self.utility_middle_in[i]
            self.utility_middle_in[i] = prev_utility + (1 - self.decay_rate) * contribution
            if self.check_neuron(self.utility_middle_in[i], i):
                self.reinitialize_neuron(self.linear_in, self.linear_middle, i)
                self.track_ids[i] = []
                self.utility_middle_in[i] = 0.0
        return

    def reinitialize_neuron(self, input_layer, next_layer, i):
        '''
        Args:
            input_layer: the layer whose output neuron i is being replaced
            next_layer: the following layer; its weights coming from neuron i are set to zero
            i: neuron that needs initialization
        Notes:
            -Weights going into neuron i (row i of input_layer.weight) are re-drawn from U(0, 1)
            -Weights going out of neuron i (column i of next_layer.weight) are set to zero
            -The weights are edited in place, so this must run before they are used in a forward
                pass that will be backpropagated
            -NOTE(review): U(0, 1) differs from nn.Linear's default initialization, and the bias
                of neuron i is left untouched -- confirm both are intended
        '''
        with torch.no_grad():
            init.uniform_(input_layer.weight[i, :], a=0, b=1)
            print(f"reinitialized input to neuron {i}")
            next_layer.weight[:, i].fill_(0)
            print(f"set output weights from neuron {i} to zero")

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        '''
        Computes the clamped actor outputs and caches the critic value for value_function().

        Notes:
            -Means are clamped to [-1, 1] and log-stds to [-10, 0]
            -Utilities are updated (and neurons possibly reinitialized) on a no-grad probe of the
                first critic layer BEFORE the differentiable critic pass is built; editing the
                weights after they were used in the graph would make the following backward()
                fail autograd's in-place modification check
            -self.step counts every forward() call; the author observed 153 instead of 150 on the
                first iteration (10000 / 2000 = 5 minibatches * 30 num_sgd_iter) due to dummy
                batch initialization
            -NOTE(review): forward() may also be invoked outside SGD updates (e.g. sampling);
                gating the utility update on self.training may be desirable -- confirm
        '''
        logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        means_clamped = torch.clamp(means, -1, 1)
        log_stds_clamped = torch.clamp(log_stds, -10, 0)
        logits = torch.cat((means_clamped, log_stds_clamped), dim=-1)

        obs_in = input_dict['obs_flat']

        # no gradients may flow through the utility bookkeeping (it also produced NaNs otherwise)
        with torch.no_grad():
            hidden_probe = self.activation(self.linear_in(obs_in))
            self.get_utility(hidden_probe)

        # differentiable critic pass, built on the (possibly reinitialized) weights
        critic_in_a = self.activation(self.linear_in(obs_in))
        critic_middle_a = self.activation(self.linear_middle(critic_in_a))
        critic_final = self.linear_final(critic_middle_a)
        # NOTE(review): the value head output also goes through LeakyReLU -- confirm intended
        self.value = self.activation(critic_final)

        self.step += 1
        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        '''Returns the value cached by the latest forward() call with the last dim squeezed.'''
        return self.value.squeeze(-1)
# Register the class with RLlib's ModelCatalog under the string name "SimpleCustomTorchModel".
ModelCatalog.register_custom_model("SimpleCustomTorchModel", SimpleCustomTorchModel)

In [None]:
%%time

# PPO on HalfCheetah-v4 using the custom model registered as 'SimpleCustomTorchModel'.
# 10_000 train batch / 2_000 minibatch * 30 SGD iterations = 150 SGD updates per iteration.
config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    # kl_coeff = 0.5,
    num_sgd_iter = 30,
    lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
    vf_loss_coeff = 1.0,
    vf_clip_param = 15.0,
    clip_param = 0.3,
    grad_clip_by = 'norm',
    train_batch_size = 10_000,
    sgd_minibatch_size = 2_000,
    grad_clip = 1.0,
    model = {'custom_model': 'SimpleCustomTorchModel', 'vf_share_layers': False,
             'fcnet_hiddens': [128, 128], 'fcnet_activation': 'LeakyReLU'},
).environment(env='HalfCheetah-v4'
).rollouts(
    num_rollout_workers = 20,
).resources(num_gpus = 1
)

algo = config.build()

num_iterations = 100
results = []  # mean episode reward per training iteration

try:
    for i in range(num_iterations):
        result = algo.train()
        mean_reward = result['episode_reward_mean']
        print(f"Iteration: {i}, Mean Reward: {mean_reward}")
        results.append(mean_reward)
finally:
    # Always release the rollout workers and the GPU, even if training raises or the
    # cell is interrupted; otherwise they stay allocated in the Ray cluster.
    algo.stop()
    ray.shutdown()
 