In [2]:
import ray
import time
from ray.util.state import get_actor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune, air
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from gymnasium.spaces import Box
from ray.rllib.examples.models.centralized_critic_models import YetAnotherTorchCentralizedCriticModel
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
import gymnasium as gym
import matplotlib.pyplot as plt
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
import numpy as np
from ray import tune
import math
from torch.distributions.normal import Normal
import pandas as pd
from functools import partial

In [3]:
# Obtain `torch` and `torch.nn` through RLlib's import helper; both names are used
# by the model definition below.
torch, nn = try_import_torch()
# `ignore_reinit_error=True` makes this cell safe to re-run in a live kernel:
# a plain second `ray.init()` raises because Ray is already initialised.
ray.init(ignore_reinit_error=True)

2024-02-21 18:54:29,730	INFO worker.py:1724 -- Started a local Ray instance.


0,1
Python version:,3.10.9
Ray version:,2.9.2


In [4]:
# Offset added to every ELU output that is used as a standard deviation.
# ELU is bounded below by -1, so `elu(x) + adder` is always strictly positive.
adder = 1.000001

# Number of Gaussian components in the critic's mixture; the critic network emits
# three values (mean, sigma, mixture logit) per component.
num_gaussians = 3

# File the model's diagnostic log is written to.
parquet_path = "results/logs/parquet_logs.parquet"

class CustomTorchModelMOG(TorchModelV2, nn.Module):
    """PPO model with a Gaussian actor head and a mixture-of-Gaussians (MoG) critic.

    Two independent fully connected networks are built from the same `model_config`:

    * the actor emits ``2 * action_dim`` values: action means followed by strictly
      positive sigmas (``elu(x) + adder``);
    * the critic emits ``3 * num_gaussians`` values: component means, raw sigmas and
      mixture logits, turned into ``(means, sigmas, alphas)`` by `_split_gmm_params`.

    The scalar value estimate is the mixture mean ``sum_i alpha_i * mu_i``.
    `custom_loss` adds the energy distance between the predicted mixture and a
    one-step distributional Bellman target to every policy loss term.

    Notes on the RLlib ModelV2 contract (kept from the original author's notes):

    * custom models override `forward`, not `__call__`; `__call__` unpacks the
      input dict (e.g. provides ``obs_flat``) before delegating to `forward`;
    * `state` is the list of state tensors and `seq_lens` a 1-D tensor of sequence
      lengths; neither is used by this feed-forward model beyond being passed on;
    * `forward` must run before `value_function`; calling `value_function` does not
      trigger another forward pass.

    Diagnostics are collected on every `calculate_delta` /
    `compute_energy_distance_mog` call and written to `parquet_path` every 1000
    log rows.

    TODO: implement the SR(lambda) method from GMAC https://arxiv.org/pdf/2105.11366.pdf
    """

    # Column order of the diagnostics table.
    _LOG_COLUMNS = ('iteration', 'energy_distance', 'first_term_logged',
                    'second_term_logged', 'erf')

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Initialise each base exactly once. The original additionally called
        # `super().__init__`, which resolves to TorchModelV2.__init__ as well and
        # therefore ran the same initialiser twice.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.action_space = action_space

        # Actor: one mean and one sigma per action dimension.
        self.actor_fcnet = TorchFC(obs_space, action_space, action_space.shape[0]*2,
                                   model_config, name + "_actor")
        # Critic: (mean, sigma, mixture logit) per Gaussian component.
        self.critic_fcnet = TorchFC(obs_space, action_space, num_gaussians*3,
                                    model_config, name + "_critic")

        # Running counter of log rows; also decides when the log is flushed to disk.
        self.log_step = 0
        # Log rows are accumulated as plain dicts and only materialised into a
        # DataFrame on demand (see `log_data`). Appending to a list is O(1), whereas
        # `pd.concat` per row copies the whole table each time.
        self._log_records = []

    @property
    def log_data(self):
        """Diagnostics collected so far, as a DataFrame with `_LOG_COLUMNS` columns."""
        return pd.DataFrame(self._log_records, columns=list(self._LOG_COLUMNS))

    @log_data.setter
    def log_data(self, frame):
        # Kept so code that assigns a DataFrame to `log_data` continues to work.
        self._log_records = frame.to_dict("records")

    def _to_positive(self, raw):
        """Map raw network outputs to strictly positive values via ``elu(raw) + adder``."""
        return torch.nn.ELU()(raw) + adder

    def _split_gmm_params(self, value_output):
        """Split the critic output into mixture parameters.

        Args:
            value_output: critic output whose last dimension holds, in order,
                `num_gaussians` means, `num_gaussians` raw sigmas and
                `num_gaussians` mixture logits.

        Returns:
            Tuple ``(means, sigmas, alphas)``; sigmas are strictly positive and
            alphas are softmax-normalised over the component dimension.
        """
        k = num_gaussians
        means = value_output[:, :k]
        sigmas = self._to_positive(value_output[:, k:k*2])
        alphas = torch.nn.functional.softmax(value_output[:, k*2:], dim=-1)
        return means, sigmas, alphas

    @OverrideToImplementCustomLogic
    def forward(self, input_dict, state, seq_lens):
        """Run actor and critic; return action-distribution inputs and the state.

        The critic's mixture parameters are cached on ``self._u``, ``self._sigmas``
        and ``self._alphas`` for the subsequent `value_function` call.
        """
        # Actor forward pass.
        raw_action_logits, _ = self.actor_fcnet(input_dict, state, seq_lens)

        # First half of the actor output: means; second half: raw sigmas.
        num_actions = self.action_space.shape[0]
        policy_means = raw_action_logits[:, :num_actions]
        policy_sigmas = self._to_positive(raw_action_logits[:, num_actions:])
        # NOTE(review): RLlib's default diagonal-Gaussian action distribution is
        # believed to interpret the second half of its inputs as log-std, not std.
        # Confirm that passing a positive sigma here is intended.
        action_logits = torch.cat([policy_means, policy_sigmas], dim = -1)

        # Critic forward pass; cache the mixture for `value_function`.
        value_output, _ = self.critic_fcnet(input_dict, state, seq_lens)
        self._u, self._sigmas, self._alphas = self._split_gmm_params(value_output)

        return action_logits, state

    @OverrideToImplementCustomLogic
    def value_function(self):
        """Return the mixture mean ``sum_i alpha_i * mu_i`` from the last `forward`."""
        return torch.sum(self._u * self._alphas, dim = 1)

    def predict_gmm_params(self, cur_obs):
        """Predict mixture parameters for `cur_obs` with the current critic weights.

        Returns:
            Tuple ``(means, sigmas, alphas)`` as produced by `_split_gmm_params`.
        """
        value_output, _ = self.critic_fcnet({'obs': cur_obs}, [], None)
        return self._split_gmm_params(value_output)

    def generate_target_gmm_params(self, rewards, next_states, dones, gamma=0.99):
        """Build the one-step distributional Bellman target mixture.

        For a next-state mixture with components ``N(mu_i, sigma_i^2)`` the return
        ``r + gamma * Z(s')`` has components ``N(r + gamma * mu_i, (gamma * sigma_i)^2)``
        with unchanged weights. The original code computed
        ``mu_i + gamma * r * (1 - done)``, i.e. it discounted and masked the reward
        instead of the bootstrapped mean; that is corrected here.

        Args:
            rewards: per-sample rewards (unsqueezed on the last axis to broadcast
                over components).
            next_states: next observations fed to the critic.
            dones: per-sample termination flags; bootstrapping is masked where set.
            gamma: discount factor; assumed > 0 so target sigmas stay positive.

        Returns:
            Tuple ``(mu_target, sigma_target, w_target)``, all detached from the graph.
        """
        mu_next, sigma_next, w_next = self.predict_gmm_params(next_states)

        not_done = 1 - dones.unsqueeze(-1).float()
        mu_target = rewards.unsqueeze(-1) + gamma * not_done * mu_next
        # Sigma is deliberately not masked by `not_done`: a zero sigma would make
        # the target-vs-target term in `calculate_delta` divide by zero.
        sigma_target = gamma * sigma_next
        w_target = w_next

        return mu_target.clone().detach(), sigma_target.clone().detach(), w_target.clone().detach()

    def compute_energy_distance_mog(self, mu_u, sigma_u, w_u, mu_v, sigma_v, w_v):
        """
        Compute the energy distance between two Gaussian Mixture Models (GMMs)
        analytically, including both internal dispersion terms:
        ``2 * delta(U, V) - delta(U, U) - delta(V, V)``.

        Returns a per-sample tensor and appends its mean to the diagnostics log.
        """
        N = mu_u.size(0)
        N0 = mu_v.size(0)

        # NOTE(review): N and N0 are sizes of dimension 0 (the batch dimension as
        # used by `calculate_delta`), not component counts, and the mixture weights
        # are already softmax-normalised. This division therefore scales the loss by
        # 1 / batch_size^2. Left unchanged because removing it changes the loss
        # magnitude by that factor; confirm whether it is intended.
        delta_U_V = self.calculate_delta(mu_u, sigma_u, w_u, mu_v, sigma_v, w_v) / (N*N0)
        delta_U_U0 = self.calculate_delta(mu_u, sigma_u, w_u, mu_u, sigma_u, w_u) / (N**2)
        delta_V_V0 = self.calculate_delta(mu_v, sigma_v, w_v, mu_v, sigma_v, w_v) / (N0**2)

        energy_distance = 2 * delta_U_V - delta_U_U0 - delta_V_V0

        self.log_to_dataframe(energy_distance = energy_distance, first_term_logged = None,
                              second_term_logged = None, erf = None)

        return energy_distance

    def log_to_dataframe(self, energy_distance, first_term_logged, second_term_logged , erf):
        """Append one diagnostics row; each argument is a tensor or None (logged as NaN).

        The log is flushed to `parquet_path` whenever `log_step` is a multiple of
        1000 (including the very first row), then `log_step` is incremented.
        """
        def mean_or_nan(tensor):
            return tensor.mean().item() if tensor is not None else np.nan

        self._log_records.append({
            'iteration': self.log_step,
            'energy_distance': mean_or_nan(energy_distance),
            'first_term_logged': mean_or_nan(first_term_logged),
            'second_term_logged': mean_or_nan(second_term_logged),
            'erf': mean_or_nan(erf),
        })

        if self.log_step % 1000 == 0:
            self.save_to_parquet()
        self.log_step += 1

    def save_to_parquet(self):
        """Write the full diagnostics log to `parquet_path`, creating its directory."""
        import os

        # `to_parquet` does not create missing parent directories.
        log_dir = os.path.dirname(parquet_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.log_data.to_parquet(parquet_path)

    def calculate_delta(self, mu_1, sigma_1, w_1, mu_2, sigma_2, w_2):
        """
        Compute the weighted expected absolute difference between two GMMs.

        For every component pair (i, j) the difference of the two Gaussians is
        ``N(m, s^2)`` with ``m = mu_1[i] - mu_2[j]`` and
        ``s^2 = sigma_1[i]^2 + sigma_2[j]^2``, whose expected absolute value is

            ``s * sqrt(2 / pi) * exp(-m^2 / (2 s^2)) + m * erf(m / sqrt(2 s^2))``.

        The original second term divided by the variance instead of the standard
        deviation and used ``1 - 2 * erf(.)`` where ``1 - 2 * Phi(.)`` (equal to the
        ``erf`` above) is required; both are corrected here.

        Inputs are indexed as (batch, component). Returns a tensor of shape (batch,)
        and appends the means of both terms and of the erf factor to the log.
        """
        # Broadcast to (batch, components_1, components_2).
        diff_mu = mu_1.unsqueeze(2) - mu_2.unsqueeze(1)
        sum_vars = sigma_1.unsqueeze(2) ** 2 + sigma_2.unsqueeze(1) ** 2
        pair_weights = w_1.unsqueeze(2) * w_2.unsqueeze(1)

        first_term = math.sqrt(2 / math.pi) * torch.sqrt(sum_vars) * torch.exp((-diff_mu**2) / (2*sum_vars))

        erf = torch.special.erf(diff_mu / torch.sqrt(2 * sum_vars))
        second_term = diff_mu * erf

        E_Zij = first_term + second_term

        delta = torch.sum(pair_weights * E_Zij, dim=(1,2))

        self.log_to_dataframe(energy_distance = None, first_term_logged = first_term,
                              second_term_logged = second_term, erf = erf)

        return delta

    @OverrideToImplementCustomLogic
    def custom_loss(self, policy_loss, sample_batch):
        """Add the mean energy distance (predicted vs. target mixture) to each loss term.

        Args:
            policy_loss: iterable of loss tensors.
            sample_batch: batch providing current/next observations, rewards and dones.

        Returns:
            List with the (unscaled) mean energy distance added to every entry of
            `policy_loss`.
        """
        cur_obs = sample_batch[SampleBatch.CUR_OBS]
        next_states = sample_batch[SampleBatch.NEXT_OBS]
        rewards = sample_batch[SampleBatch.REWARDS]
        dones = sample_batch[SampleBatch.DONES]

        mu_pred, sigma_pred, w_pred = self.predict_gmm_params(cur_obs)
        mu_target, sigma_target, w_target = self.generate_target_gmm_params(
            rewards, next_states, dones, gamma=0.99)

        energy_distance = self.compute_energy_distance_mog(mu_pred, sigma_pred, w_pred,
                                                           mu_target, sigma_target, w_target)

        # No extra scaling coefficient is applied to the energy-distance term.
        custom_loss_component = torch.mean(energy_distance)

        return [loss + custom_loss_component for loss in policy_loss]

'''
Fixes:
(1)Target sigmas were not strictly positive: added the elu activation along with a positive constant
(2)Fixed the critic network's sigmas to be elu with an added constant instead of squares
(3)Added the cdf of the normal distribution (second term) to the loss function as per GMAC's paper
(3->)this fixed the negative distance values, which should never have occurred
(3->)because the expectation taken between the distributions is always positive
(4)Added internal dispersion back into the energy distance once the delta method was fixed with (3) above
(4->)This has an appropriate magnitude for the loss compared to the policy loss so no scaling is needed
(5)The result of the four main points above gives results of 1_000+ in the Cheetah-v4 env. after 100 iterations
(6)Updated file savings from csv logs to parquet logs for better efficiency and control
(7)Clamped the exp term -- after logging this term alone was 800_000_000_000 --
(8)Fixed policy action logits to be the correct dimensions and removed (7)
'''
# Register the model under the name that the PPO config's `custom_model` entry uses.
ModelCatalog.register_custom_model(
    "custom_torch_model_mog",
    CustomTorchModelMOG,
)


In [None]:
# PPO on HalfCheetah-v4 with the custom mixture-of-Gaussians model.
config = (
    PPOConfig()
    .training(
        gamma = 0.99,
        lambda_ = 0.95,
        # kl_coeff = 0.5,
        num_sgd_iter = 30,
        lr = 0.00025,
        vf_loss_coeff = 1.0,
        # vf_clip_param = 1.0,
        clip_param = 0.5,
        grad_clip_by = 'norm',
        train_batch_size = 19_200,
        sgd_minibatch_size = 4096,
        grad_clip = 1.0,
        model = {
            'custom_model': 'custom_torch_model_mog',
            'vf_share_layers': False,
            'fcnet_hiddens': [1024, 1024],
            'fcnet_activation': 'silu',
        },
    )
    .environment(env = 'HalfCheetah-v4')
    .rollouts(
        num_rollout_workers = 2,
        num_envs_per_worker = 4,
    )
)

num_iterations = 500

# Mean episode reward per training iteration.
rewards = []

# try/finally so the rollout workers and the local Ray instance are released even
# when training raises or the cell is interrupted.
try:
    algo = config.build()
    try:
        for i in range(num_iterations):
            result = algo.train()
            print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
            rewards.append(result['episode_reward_mean'])
    finally:
        algo.stop()
finally:
    ray.shutdown()
