In [None]:
import ray
import time
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray import tune, air
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from gymnasium.spaces import Box, Discrete
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import ray.rllib.algorithms.ppo as dqn
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
import gymnasium as gym
import matplotlib.pyplot as plt
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
import numpy as np
import pandas as pd
from ray import tune
import math
from torch.distributions.normal import Normal
from ray.tune.schedulers import ASHAScheduler
from ray.tune.registry import register_env
from torch.utils.tensorboard import SummaryWriter
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import os 
import random
import shutil
# Working directory captured when this cell runs (kept for later file I/O).
path = os.getcwd()

In [None]:
# RLlib's torch import helper: returns the torch module and torch.nn.
torch, nn = try_import_torch()
# ignore_reinit_error lets this cell be re-run without Ray raising
# "Maybe you called ray.init twice by accident?".
ray.init(ignore_reinit_error=True)

In [None]:
class SimpleContextualBandit(gym.Env):
    """One-step contextual bandit with a +/-1 context and three arms.

    Context -1.0 pays arms [-10, 0, 10]; context 1.0 pays [10, 0, -10].
    Every episode terminates after a single step. `info['regret']` is
    10 - reward, where 10 is the best payout available in either context.
    """

    def __init__(self, config=None):
        self.action_space = Discrete(3)
        # Observations are emitted as float32 so they are members of this Box;
        # the default Box dtype is float32 and float64 arrays fail Box.contains.
        self.observation_space = Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self.cur_context = None

    def reset(self, *, seed=None, options=None):
        # Seed gymnasium's per-env RNG so a given seed reproduces the episodes.
        super().reset(seed=seed)
        self.cur_context = float(self.np_random.choice([-1.0, 1.0]))
        return np.array([self.cur_context, -self.cur_context], dtype=np.float32), {}

    def step(self, action):
        rewards_for_context = {
            -1.0: [-10, 0, 10],
            1.0: [10, 0, -10],
        }
        reward = rewards_for_context[self.cur_context][action]
        return (
            np.array([-self.cur_context, self.cur_context], dtype=np.float32),
            reward,
            True,   # terminated: single-step episodes
            False,  # truncated
            {'regret': 10 - reward},
        )


register_env('SimpleContextualBandit', SimpleContextualBandit)

In [None]:
class BernoulliBandit(gym.Env):
    """Single-step Bernoulli bandit with `num_actions` arms and a constant observation.

    Arms 0 .. num_actions-2 pay Bernoulli(0.5). The last arm pays a 0/1 value
    drawn once per episode in reset(). `info['regret']` is 1 - reward.
    """

    def __init__(self, config=None):
        self.num_actions = 5
        self.action_space = Discrete(self.num_actions)
        # Explicit float32, shape-(1,) Box that matches the observations returned below
        # (the previous int64 observations were not members of the float32 space).
        self.observation_space = Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
        self.last_action_reward = None

    def reset(self, *, seed=None, options=None):
        # Seed gymnasium's per-env RNG so a given seed reproduces the episodes.
        super().reset(seed=seed)
        self.last_action_reward = int(self.np_random.integers(2))
        return np.zeros(1, dtype=np.float32), {}

    def step(self, action):
        assert self.action_space.contains(action), "Invalid action selection"
        if action < self.num_actions - 1:
            reward = int(self.np_random.binomial(1, 0.5))
        else:
            reward = self.last_action_reward

        # Both branches of the previous regret expression reduced to 1 - reward.
        # The per-step print was removed because every rollout worker flooded the log.
        return (
            np.zeros(1, dtype=np.float32),
            reward,
            True,   # terminated: single-step episodes
            False,  # truncated
            {'regret': 1 - reward},
        )


register_env('BernoulliBandit', BernoulliBandit)

In [None]:
class ENNDQNModel(TorchModelV2, nn.Module):
    """DQN Q-model with an epistemic-neural-network (ENN) correction.

    forward() returns per-action values

        Q(s, .) = base(s) + learnable(phi(s), z) + prior(phi(s), z)

    Here `base` is a two-hidden-layer MLP, and phi(s) is its last hidden layer,
    detached before it enters the ENN heads. `z` is a fresh index drawn from
    N(mean, std)^z_dim for every batch row.

    The prior heads receive gradients only during the first
    PRIOR_TRAINABLE_STEPS forward() calls of this instance. After that they
    are evaluated under torch.no_grad().

    custom_loss() adds a TD loss on the ENN part (learnable + prior) to every
    policy loss.
    """

    # Number of forward() calls during which the prior heads are still trained.
    # Author's note: with the current setup one training iteration is ~480
    # forward() calls.
    PRIOR_TRAINABLE_STEPS = 200

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # RLlib's pattern for custom torch models is to initialise both bases
        # explicitly. The former extra super().__init__ call only re-ran
        # TorchModelV2.__init__.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        # Attribute retained for compatibility. New tensors are placed on the
        # device of the incoming observations, so a CPU-resident model on a
        # CUDA-capable host no longer mixes devices.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        hidden_layer0 = model_config['fcnet_hiddens'][0]
        hidden_layer1 = model_config['fcnet_hiddens'][1]
        enn_layer = 15

        # ENN index distribution and bookkeeping.
        self.z_dim = 5
        self.mean = 0.0
        self.std = 1.0
        self.step_number = 0  # forward() calls made on this instance
        self.z_indices = None
        self.gamma = 0.99  # ENN TD discount; hard-coded, not read from the algorithm config
        self.initializer = torch.nn.init.xavier_normal_
        self.activation_fn = model_config['fcnet_activation']
        self.num_atoms = 1

        self.distribution = Normal(torch.full((self.z_dim,), self.mean), torch.full((self.z_dim,), self.std))

        # One output per discrete action. This was hard-coded to 4 (FrozenLake's
        # action count), which broke the 3- and 5-armed bandits in this notebook.
        self.action_space_size = action_space.n
        self.map_size = 8
        n_out = self.action_space_size * self.num_atoms

        self.action_space = action_space
        # Not used by forward(); kept so the parameter set / checkpoint keys stay unchanged.
        self.q_network = TorchFC(obs_space, action_space, n_out, model_config, name + "_q_testing")

        # Base Q-network. The output layer is linear; the former ReLU there
        # clamped every base Q-value to >= 0.
        self.qnetwork_in = SlimFC(obs_space.shape[0], hidden_layer0, initializer=self.initializer, activation_fn=self.activation_fn)
        self.qnetwork_1 = SlimFC(hidden_layer0, hidden_layer1, initializer=self.initializer, activation_fn=self.activation_fn)
        self.qnetwork_out = SlimFC(hidden_layer1, n_out, initializer=self.initializer, activation_fn=None)

        # Learnable ENN head on [z_k, phi(s)] rows (hidden_layer1 + 1 inputs).
        self.enn_learnable_in = SlimFC(hidden_layer1 + 1, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_1 = SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.enn_learnable_out = SlimFC(enn_layer, n_out, initializer=self.initializer, activation_fn=None)

        # Prior ENN head with the same architecture.
        self.prior_in = SlimFC(hidden_layer1 + 1, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_1 = SlimFC(enn_layer, enn_layer, initializer=self.initializer, activation_fn=self.activation_fn)
        self.prior_out = SlimFC(enn_layer, n_out, initializer=self.initializer, activation_fn=None)

    def _base_forward(self, obs):
        """Run the base MLP and return (last hidden layer (B, h1), Q-values (B, num_actions))."""
        base_1 = self.qnetwork_1(self.qnetwork_in(obs))
        return base_1, self.qnetwork_out(base_1)

    def _prior_heads(self, enn_input):
        """Apply the prior MLP: (B, z_dim, h1 + 1) -> (B, z_dim, num_actions)."""
        return self.prior_out(self.prior_1(self.prior_in(enn_input)))

    def _enn_forward(self, features, z, trainable_prior):
        """Evaluate the learnable and prior ENN heads for index z.

        Args:
            features: (B, h1) base features. They are detached here, so the
                ENN never back-propagates into the base network.
            z: (B, z_dim, 1) ENN index.
            trainable_prior: if False, the prior is evaluated under no_grad.

        Returns:
            (learnable_out, prior_out), each (B, num_actions). Each is its head
            output contracted with z: sum_k head[:, k, a] * z[:, k].
        """
        feats = features.detach().unsqueeze(1).expand(-1, self.z_dim, -1)
        enn_input = torch.cat((z, feats), dim=2)  # (B, z_dim, h1 + 1)

        learnable = self.enn_learnable_out(self.enn_learnable_1(self.enn_learnable_in(enn_input)))
        if trainable_prior:
            prior = self._prior_heads(enn_input)
        else:
            with torch.no_grad():
                prior = self._prior_heads(enn_input)

        # (B, num_actions, z_dim) @ (B, z_dim, 1) -> (B, num_actions)
        learnable_out = torch.bmm(torch.transpose(learnable, 1, 2), z).squeeze(-1)
        prior_out = torch.bmm(torch.transpose(prior, 1, 2), z).squeeze(-1)
        return learnable_out, prior_out

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Return per-action values base + learnable + prior, shape (B, num_actions)."""
        obs = input_dict["obs_flat"].float()
        batch_size = obs.shape[0]

        base_1, base_out = self._base_forward(obs)
        self.base_1_detached = torch.unsqueeze(base_1, 1).detach()
        # Greedy action and value of the base network; value_function() reads these.
        self.base_value, self.base_indice = torch.max(base_out, dim=-1, keepdim=True)

        self.z_indices = self.distribution.sample((batch_size,)).to(obs.device)
        self.z_unsqueeze = torch.unsqueeze(self.z_indices, -1)  # (B, z_dim, 1)

        learnable_out, prior_output = self._enn_forward(
            base_1, self.z_unsqueeze, self.step_number < self.PRIOR_TRAINABLE_STEPS
        )
        # ENN contributions at the base network's greedy action.
        self.learnable_value = torch.gather(learnable_out, 1, self.base_indice)
        self.prior_value = torch.gather(prior_output, 1, self.base_indice)

        self.action_values = base_out + learnable_out + prior_output
        self.step_number += 1
        return self.action_values, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the total value at the base network's greedy action from the last forward(), shape (B,)."""
        return (self.learnable_value + self.prior_value + self.base_value).reshape(-1)

    @override(TorchModelV2)
    def custom_loss(self, policy_loss, sample_batch):
        """Add a TD loss on the ENN part (learnable + prior) to every policy loss.

        The current ENN estimate is taken at the executed action
        (SampleBatch.ACTIONS). The bootstrap target is fully detached and uses
        the same z, evaluated at the base network's greedy action in the next
        state. When TERMINATEDS is present, only terminal transitions stop
        bootstrapping.

        Both values are recomputed from `sample_batch`, so the loss does not
        depend on which batch forward() happened to see last.
        """
        obs = sample_batch[SampleBatch.CUR_OBS].float()
        next_obs = sample_batch[SampleBatch.NEXT_OBS].float()
        actions = sample_batch[SampleBatch.ACTIONS].long().reshape(-1, 1)
        rewards = sample_batch[SampleBatch.REWARDS].float().reshape(-1)
        # gymnasium-era batches carry TERMINATEDS; fall back to the legacy DONES key.
        done_key = SampleBatch.TERMINATEDS if SampleBatch.TERMINATEDS in sample_batch else SampleBatch.DONES
        terminateds = sample_batch[done_key].float().reshape(-1)

        # Same z for the current estimate and its target, as in the original design.
        z = torch.unsqueeze(self.distribution.sample((obs.shape[0],)).to(obs.device), -1)
        trainable_prior = self.step_number < self.PRIOR_TRAINABLE_STEPS

        # Current ENN estimate for the executed action (base features carry no gradient).
        with torch.no_grad():
            features, _ = self._base_forward(obs)
        learnable_out, prior_out = self._enn_forward(features, z, trainable_prior)
        enn_out = torch.gather(learnable_out + prior_out, 1, actions).squeeze(-1)

        # Detached target at the base network's greedy *next* action. Previously
        # the current state's greedy index was reused here.
        with torch.no_grad():
            next_features, next_base_out = self._base_forward(next_obs)
            next_indice = torch.argmax(next_base_out, dim=-1, keepdim=True)
            next_learnable_out, next_prior_out = self._enn_forward(next_features, z, False)
            next_enn_out = torch.gather(next_learnable_out + next_prior_out, 1, next_indice).squeeze(-1)
            enn_target = rewards + self.gamma * next_enn_out * (1.0 - terminateds)

        # All terms are (B,). The old (B, 1) - (B,) mix broadcast to a (B, B) loss.
        enn_loss = torch.mean(torch.square(enn_out - enn_target))
        return [loss + enn_loss for loss in policy_loss]


ModelCatalog.register_custom_model("ENNDQNModel", ENNDQNModel)

In [None]:
# Sanity check of the training env: show its spaces, then release the env.
# NOTE(review): the DQN config below uses env='FrozenLake-v1' with default
# settings; confirm these match is_slippery=True / map_name='4x4' in your gymnasium version.
probe_env = gym.make('FrozenLake-v1', is_slippery=True, map_name='4x4')
probe_spaces = (probe_env.observation_space, probe_env.action_space)
probe_env.close()
probe_spaces

In [None]:
%%time
config = DQNConfig().training(
    num_atoms = 1,
    v_min = -10.0,
    v_max = 10.0,
    noisy = False,
    # sigma0 = 0.5,
    #sets initial weights for noisy nets
    dueling = False,
    double_q = False,
    n_step = 1, 
    hiddens = (),
    target_network_update_freq = 10,
    num_steps_sampled_before_learning_starts = 20,
    replay_buffer_config = {
        'capacity': 1000
    },
    #IMPORTANT: need hiddens = [] and dueling = False for parametric action spaces
    before_learn_on_batch = None,
    training_intensity = None,
    td_error_loss_fn = 'huber',
    lr = 0.0003,
    #td error loss is ignored if num_atoms > 1
    categorical_distribution_temperature = 1.0,
    optimizer = {
        'weight_decay': 0.01
    },
    #temperature in the range of [0,1] which affects evaluation
    model={
        'custom_model': 'ENNDQNModel',
        'no_final_linear': False,
        'fcnet_hiddens': [64,64],
        'fcnet_activation': 'relu',
        'vf_share_layers': False
    }
).environment(
    env='FrozenLake-v1',
).rollouts(
    num_rollout_workers = 28,
).resources(
    num_gpus = 1
)

algo = config.build()

num_iterations = 500
rewards = []

for i in range(num_iterations):
    result = algo.train()
    print(f"Iteration: {i}, Mean Reward: {result['episode_reward_mean']}")
    rewards.append(result['episode_reward_mean'])
    
ray.shutdown()