In [None]:
import torch 
from rllib.policy import NNPolicy, FelixPolicy
from rllib.value_function import NNValueFunction
from rllib.util import tensor_to_distribution
from rllib.util.neural_networks.utilities import random_tensor, zero_bias, init_head_weight

import matplotlib.pyplot as plt 
from matplotlib import rcParams

# If in your browser the figures are not nicely vizualized, change the following line. 
rcParams['font.size'] = 16

%load_ext autoreload
%autoreload 2

In [None]:
def plot_policy(policy, num_states=10, num_samples=1000):
    """Plot action histograms of ``policy`` at ``num_states`` different states.

    Row ``i`` of the figure corresponds to one state: the zero state for the
    first row and a ``random_tensor`` state for every other row. For each
    state the policy output is converted with ``tensor_to_distribution`` and
    ``num_samples`` actions are drawn:

    * left column: samples of the plain distribution, clipped to [-1, 1]
      with ``clamp`` (clipping, not truncation: any sample outside the
      interval lands exactly on a bound);
    * right column: samples of the ``tanh=True`` distribution.

    The second element of the policy output (``out[1]``) is printed for each
    state.

    Args:
        policy: object with a ``dim_state`` attribute whose call on a state
            returns something ``tensor_to_distribution`` accepts.
        num_states: number of states, i.e. rows in the figure.
        num_samples: number of actions sampled per histogram.
    """
    dim_state = policy.dim_state
    # squeeze=False keeps `ax` 2-D, so `ax[i, j]` also works when num_states == 1.
    fig, ax = plt.subplots(num_states, 2, figsize=(18, 10), dpi=80, facecolor='w',
                           edgecolor='k', sharex='col', squeeze=False)
    sample_shape = (num_samples,)
    for i in range(num_states):
        if i == 0:
            state = torch.zeros(dim_state)
        else:
            state = random_tensor(discrete=False, dim=dim_state, batch_size=None)

        # Plotting only: no autograd graph is needed.
        with torch.no_grad():
            out = policy(state)
            print(out[1])

            normal = tensor_to_distribution(out)
            tanh = tensor_to_distribution(out, tanh=True)
            clipped_actions = normal.sample(sample_shape).squeeze().clamp(-1, 1)
            tanh_actions = tanh.sample(sample_shape).squeeze()

        ax[i, 0].hist(clipped_actions.numpy(), density=True)
        ax[i, 1].hist(tanh_actions.numpy(), density=True)
        ax[i, 0].set_xlim([-1.1, 1.1])
        ax[i, 1].set_xlim([-1.1, 1.1])

    ax[0, 0].set_title('Clipped Normal')
    ax[0, 1].set_title('Tanh')
    ax[-1, 0].set_xlabel('Action')
    ax[-1, 1].set_xlabel('Action')

    plt.show()

def plot_value_function(value_function, num_samples=1000):
    """Plot a histogram of ``value_function`` outputs over random states.

    Args:
        value_function: object with a ``dim_state`` attribute; it is called
            once on a batch of ``num_samples`` states drawn with
            ``random_tensor``.
        num_samples: number of random states to evaluate.
    """
    dim_state = value_function.dim_state
    fig, ax = plt.subplots(1, 1, figsize=(18, 10), dpi=80, facecolor='w', edgecolor='k')
    state = random_tensor(discrete=False, dim=dim_state, batch_size=num_samples)
    # Evaluation only: skip building an autograd graph (so no .detach() is needed).
    with torch.no_grad():
        value = value_function(state)
    ax.hist(value.squeeze().numpy(), density=True)
    ax.set_xlabel('Value')
    ax.set_ylabel('Density')
    ax.set_title('Value function outputs on random states')

    plt.show()


# Policy Initialization

In [None]:
# FelixPolicy as constructed, with no re-initialization applied.
dim_state = 4
dim_action = 1
policy = FelixPolicy(dim_state, dim_action)
plot_policy(policy)

## NNPolicy with Default Initialization

In [None]:
# NNPolicy with biased_head=True and no re-initialization applied.
dim_state = 4
dim_action = 1
policy = NNPolicy(dim_state, dim_action, biased_head=True)
plot_policy(policy)

In [None]:
# NNPolicy with biased_head=False and no re-initialization applied.
# NOTE(review): presumably this builds the output head without a bias term — confirm in NNPolicy.
dim_state = 4
dim_action = 1
policy = NNPolicy(dim_state, dim_action, biased_head=False)
plot_policy(policy)

## NNPolicy with Zero Bias Initialization

In [None]:
# NNPolicy with zero_bias applied before plotting.
dim_state = 4
dim_action = 1
policy = NNPolicy(dim_state, dim_action, biased_head=True)
zero_bias(policy)
plot_policy(policy)

## NNPolicy with Default Head Initialization

In [None]:
# Only init_head_weight is applied here; zero_bias is deliberately not called.
dim_state, dim_action = 4, 1
policy = NNPolicy(dim_state, dim_action, biased_head=True)
init_head_weight(policy)
plot_policy(policy)

## NNPolicy with Zero Bias and Default Head Weight Initialization

In [None]:
# NNPolicy with both zero_bias and init_head_weight applied.
dim_state = 4
dim_action = 1
policy = NNPolicy(dim_state, dim_action, biased_head=True)
zero_bias(policy)
init_head_weight(policy)
plot_policy(policy)

## Effect of the Initial Standard Deviation

In [None]:
# initial_scale=0.1, with zero_bias and init_head_weight applied.
dim_state = 4
dim_action = 1
initial_scale = 0.1
policy = NNPolicy(dim_state, dim_action, initial_scale=initial_scale)
zero_bias(policy)
init_head_weight(policy)
plot_policy(policy)

In [None]:
# initial_scale=0.01, with zero_bias and init_head_weight applied.
dim_state = 4
dim_action = 1
initial_scale = 0.01
policy = NNPolicy(dim_state, dim_action, initial_scale=initial_scale)
zero_bias(policy)
init_head_weight(policy)
plot_policy(policy)

In [None]:
# initial_scale=0.5, with zero_bias and init_head_weight applied.
dim_state = 4
dim_action = 1
initial_scale = 0.5
policy = NNPolicy(dim_state, dim_action, initial_scale=initial_scale)
zero_bias(policy)
init_head_weight(policy)
plot_policy(policy)

# Value Functions

In [None]:
dim_state = 4
value_function = NNValueFunction(dim_state)
zero_bias(value_function)
init_head_weight(value_function)
# Re-draw the head bias uniformly in [center - half_width, center + half_width],
# presumably so initial value predictions sit near `bias_center` (see the histogram).
bias_center, bias_half_width = 2.0, 0.1
torch.nn.init.uniform_(value_function.nn.head.bias,
                       bias_center - bias_half_width, bias_center + bias_half_width)
plot_value_function(value_function)