# MuZero: Model-based RL (Part 2)

This notebook is part of a series on MuZero, a popular model-based reinforcement learning algorithm.

In part 1, we covered an overview of MuZero and how Monte Carlo Tree Search (MCTS) collects training samples through self-play. In this notebook, we look at the deep learning models used in MuZero.

### Review of three models

As we learned in part 1, MuZero uses three deep learning models to learn the dynamics of the environment as well as the optimal policy. They are:
- Representation model: $s^0 = h_\theta(o_t)$
    - Input: raw state of the current root
    - Output: latent state of the current root
- Dynamics model: $r^k, s^k = g_\theta(s^{k-1}, a^k)$
    - Input: latent state, action to take
    - Output: next latent state, expected immediate reward
- Prediction model: $p^k, v^k = f_\theta(s^k)$
    - Input: latent state
    - Output: policy at the input latent state, expected value at the input latent state

Here, $t$ is the index for the past and current steps, and $k$ is the index for the future steps.

While the dynamics model and the prediction model in the original MuZero paper are each trained to produce multiple outputs, we split these outputs across separate models to stabilize the training process. More specifically, the code below models each quantity with its own network, using five models in total:
- Representation model:
    - Input: raw state of the current root
    - Output: latent state of the current root
- Dynamics model:
    - Input: latent state, action to take
    - Output: next latent state
- Reward model:
    - Input: latent state, action to take
    - Output: expected immediate reward
- Value model:
    - Input: latent state
    - Output: expected value at the input latent state
- Policy model:
    - Input: latent state
    - Output: policy at the input latent state


The combination of the dynamics model and the reward model behaves like the dynamics model of the original MuZero paper. The combination of the value model and the policy model behaves like the prediction model of the original MuZero paper.
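
To make this correspondence concrete, here is a minimal sketch that composes five small stand-in networks into the functions $h$, $g$, and $f$ and unrolls them along an action sequence. The helper names (`mlp`, `h`, `g`, `f`, `unroll`), the layer sizes, and the scalar value head are assumptions made for illustration; the value network defined later in this notebook outputs a multi-dimensional vector instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
obs_dim, latent_dim, num_actions = 4, 4, 2  # illustrative sizes


def mlp(n_in, n_out):
    # Tiny stand-in network (hypothetical helper, not the notebook's classes)
    return nn.Sequential(nn.Linear(n_in, 16), nn.ReLU(), nn.Linear(16, n_out))


representation = mlp(obs_dim, latent_dim)
dynamics = mlp(latent_dim + num_actions, latent_dim)
reward = mlp(latent_dim + num_actions, 1)
value = mlp(latent_dim, 1)
policy = mlp(latent_dim, num_actions)


def h(observation):
    # Representation model: s^0 = h(o_t)
    return representation(observation)


def g(latent, action):
    # Dynamics + reward networks together play the role of g: (r^k, s^k)
    one_hot = F.one_hot(torch.tensor(action), num_classes=num_actions).float()
    x = torch.cat((latent, one_hot), dim=-1)
    return reward(x), dynamics(x)


def f(latent):
    # Policy + value networks together play the role of f: (p^k, v^k)
    return policy(latent), value(latent)


def unroll(observation, actions):
    # Encode the observation once, then step through the latent space
    latent = h(observation)
    outputs = []
    for action in actions:
        r, latent = g(latent, action)
        p, v = f(latent)
        outputs.append((r, p, v))
    return outputs


steps = unroll(torch.zeros(obs_dim), actions=[0, 1, 1])
print(len(steps), steps[0][1].shape)
```

Only the first step touches the raw observation; every later step works purely on latent states.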

MuZero learns all of these models at the same time. The loss function is defined as the sum of three errors:
- Policy loss: the error between the actions predicted by the policy $p^k_t$ and by the search policy $\pi_{t+k}$. 
- Value loss: the error between the value function $v^k_t$ and the value target, $z_{t+k}$
- Reward loss: the error between the predicted immediate reward $r^k_t$ and the observed immediate reward $u_{t+k}$
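
The three terms above can be sketched as follows. This is an illustration under assumptions, not the exact loss of the paper: it uses cross-entropy against the search policy for the policy term and mean squared error on scalar values and rewards, and the function name `muzero_loss` and the toy batch are hypothetical.

```python
import torch
import torch.nn.functional as F


def muzero_loss(policy_logits, search_policy, value_pred, value_target,
                reward_pred, reward_target):
    # Policy loss: cross-entropy between the search policy (soft target)
    # and the predicted policy
    log_probs = F.log_softmax(policy_logits, dim=-1)
    policy_loss = -(search_policy * log_probs).sum(dim=-1).mean()
    # Value and reward losses: mean squared error (an assumption for this sketch)
    value_loss = F.mse_loss(value_pred, value_target)
    reward_loss = F.mse_loss(reward_pred, reward_target)
    return policy_loss + value_loss + reward_loss


# Toy batch of three samples with two actions
policy_logits = torch.zeros(3, 2, requires_grad=True)
search_policy = torch.full((3, 2), 0.5)
loss = muzero_loss(
    policy_logits, search_policy,
    value_pred=torch.zeros(3), value_target=torch.ones(3),
    reward_pred=torch.ones(3), reward_target=torch.ones(3),
)
loss.backward()  # gradients flow back to the policy logits
print(loss.item())
```

Because the three terms are summed into one scalar, a single backward pass produces gradients for every network that contributed to it.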

With the sum of the three loss values, MuZero runs backpropagation and an optimizer step, as in typical deep learning model training.
Let's review each model one by one.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np


As in part 1, we assume the CartPole-v0 environment from Gymnasium. The environment has two possible actions, and each state is represented by a vector of four values (cart position, cart velocity, pole angle, and pole angular velocity).

In [None]:
state_shape = 4
action_size = 2


### Representation network

We first define a representation network. It receives the raw state of the current root node and returns its latent state, so the input shape is the state shape. In the simple fully connected architecture used in this notebook, the input is first mapped to the hidden neuron size, and the hidden activations are then mapped to the embedding size to produce the latent state. The hidden neuron size and the embedding size are hyperparameters.

In [None]:
class RepresentationNetwork(nn.Module):
    def __init__(self, input_size, hidden_neurons, embedding_size):
        super(RepresentationNetwork, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_neurons),
            nn.ReLU(),
            nn.Linear(hidden_neurons, embedding_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.layers(x)
    
hidden_neurons = 48
embedding_size = 4
rep_net = RepresentationNetwork(input_size=state_shape, hidden_neurons=hidden_neurons, embedding_size=embedding_size)
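
As a quick sanity check of the shapes described above, the sketch below rebuilds an equivalent stand-in network with the same layer sizes (so the snippet is self-contained) and passes a batch of random observations through it. The batch size of 8 is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

# Stand-in with the same layer sizes as the representation network above
net = nn.Sequential(
    nn.Linear(4, 48),   # state shape -> hidden neurons
    nn.ReLU(),
    nn.Linear(48, 4),   # hidden neurons -> embedding size
    nn.Tanh(),          # squashes every latent dimension into [-1, 1]
)

batch = torch.randn(8, 4)  # 8 observations, 4 values each
latent = net(batch)
num_params = sum(p.numel() for p in net.parameters())
print(latent.shape, num_params)
```

The final `Tanh` keeps every latent dimension bounded, which gives the downstream networks inputs on a fixed scale.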


### Dynamics network

The dynamics network has a similar architecture to the representation network; the difference is the input size. The dynamics function receives the latent state and the action to take as input. In this tutorial, we use one-hot encoding to represent the action. For example, moving the cart left is represented as [1,0], and moving it right is represented as [0,1]. We concatenate this two-dimensional vector with the latent state, so the input has a size of embedding size + action size.
The output is the next latent state reached by taking the input action at the input latent state.
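
The input construction described above can be sketched as follows. The helper name `concat_state_action` and the example latent values are hypothetical; the sketch assumes the action is given as an integer index.

```python
import torch
import torch.nn.functional as F


def concat_state_action(latent_state, action, num_actions):
    # One-hot encode the integer action: 0 -> [1, 0], 1 -> [0, 1]
    one_hot = F.one_hot(torch.tensor(action), num_classes=num_actions).float()
    # Append the encoded action to the latent state
    return torch.cat((latent_state, one_hot), dim=-1)


latent_state = torch.tensor([0.1, -0.2, 0.3, 0.0])
left = concat_state_action(latent_state, action=0, num_actions=2)
right = concat_state_action(latent_state, action=1, num_actions=2)
print(left.shape, left[4:], right[4:])
```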

In [None]:
class DynamicNetwork(nn.Module):
    def __init__(self, input_size, hidden_neurons, embedding_size):
        super(DynamicNetwork, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_neurons),
            nn.ReLU(),
            nn.Linear(hidden_neurons, embedding_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.layers(x)
    
dyn_net = DynamicNetwork(input_size=embedding_size+action_size, hidden_neurons=hidden_neurons, embedding_size=embedding_size)


### Reward network

The reward network receives the latent state and the action to take as input and returns the predicted immediate reward as output. In the CartPole environment, the agent receives a reward of +1 at every step while the pole is kept upright. Thus, the predicted immediate reward (output) is a scalar.

In [None]:
class RewardNetwork(nn.Module):
    def __init__(self, input_size, hidden_neurons):
        super(RewardNetwork, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_neurons),
            nn.ReLU(),
            nn.Linear(hidden_neurons, 1)
        )

    def forward(self, x):
        return self.layers(x)
    
rew_net = RewardNetwork(input_size=embedding_size+action_size, hidden_neurons=hidden_neurons)


### Value network

The value network receives the latent state and returns the predicted expected value at that state. Instead of outputting a scalar directly, MuZero outputs a multi-dimensional vector and then applies an invertible transformation to recover the predicted scalar value. For more detail, please check "Appendix F Network architecture" of [the MuZero paper](https://arxiv.org/pdf/1911.08265#page=14.33) and "Appendix A: Proposition A.2" of [this paper](https://arxiv.org/pdf/1805.11593). 
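
As a minimal sketch of that invertible transformation, the functions below implement the scaling $h(x) = \mathrm{sign}(x)\left(\sqrt{|x| + 1} - 1\right) + \epsilon x$ and its closed-form inverse on plain Python floats. The names `h` and `h_inverse` are chosen here for illustration, and $\epsilon = 0.001$ matches the constant used in the code of this notebook.

```python
import math

EPSILON = 0.001


def h(x):
    # Forward transform: compresses large magnitudes, roughly like a square root
    return math.copysign(1.0, x) * (math.sqrt(abs(x) + 1) - 1) + EPSILON * x


def h_inverse(y):
    # Closed-form inverse of h, recovering the value in its original scale
    inner = (math.sqrt(1 + 4 * EPSILON * (abs(y) + 1 + EPSILON)) - 1) / (2 * EPSILON)
    return math.copysign(1.0, y) * (inner ** 2 - 1)


for x in [-50.0, 0.0, 1.0, 200.0]:
    # Each value should survive the round trip up to floating-point error
    print(x, h(x), h_inverse(h(x)))
```

The forward transform shrinks the range the network has to cover, and the inverse maps a decoded network output back to the scale of actual returns.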

In [None]:
class ValueNetwork(nn.Module):
    def __init__(self, input_size, hidden_neurons, value_support_size):
        super(ValueNetwork, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_neurons),
            nn.ReLU(),
            nn.Linear(hidden_neurons, value_support_size)
        )

    def forward(self, x):
        return self.layers(x)
    
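import math  # needed for the support size computation below

def value_transform(value_support):
    """Convert the value network output (logits over the support) into a scalar value."""
    epsilon = 0.001
    # Softmax turns the logits into a probability distribution over the support
    probs = torch.nn.functional.softmax(value_support, dim=-1)
    # Expected value over the support indices 0, 1, ..., len(value_support) - 1
    value = np.dot(probs.detach().numpy(), np.arange(len(value_support)))
    # Apply the inverse scaling transform to get the value in its original scale
    value = np.sign(value) * (
            ((np.sqrt(1 + 4 * epsilon
                * (np.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon)) ** 2 - 1
    )
    return value
    
max_value = 200
value_support_size = math.ceil(math.sqrt(max_value)) + 1
val_net = ValueNetwork(input_size=embedding_size, hidden_neurons=hidden_neurons, value_support_size=value_support_size)
network_output = val_net(torch.Tensor([1,1,1,1])) # output from network (multi-dimensional)
predicted_value = value_transform(network_output) # scalar value after applying the transformation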
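
The code above decodes a distribution over the support into a scalar. For the opposite direction, which is needed to build a training target, the MuZero paper describes representing an already transformed scalar as a linear combination of its two adjacent support points. A minimal sketch of that encoding follows; the function names `scalar_to_support` and `support_to_scalar` are hypothetical, and the support is assumed to be the integer indices 0 to `support_size - 1`, as in the code above.

```python
import math
import torch


def scalar_to_support(x, support_size):
    # x is assumed to be already transformed; clamp it to the support range
    x = min(max(x, 0.0), support_size - 1)
    low = math.floor(x)
    high = min(low + 1, support_size - 1)
    weight_high = x - low
    target = torch.zeros(support_size)
    # Split the probability mass between the two neighboring support points
    target[low] += 1.0 - weight_high
    target[high] += weight_high
    return target


def support_to_scalar(probs):
    # Expected value over the support indices
    indices = torch.arange(len(probs), dtype=probs.dtype)
    return float((probs * indices).sum())


target = scalar_to_support(3.7, support_size=16)
print(target[3].item(), target[4].item(), support_to_scalar(target))
```

Taking the expectation of the encoded target gives back the original scalar, so the encoding loses no information inside the support range.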

### Policy network

Lastly, the policy network receives the latent state and returns the policy at that state. The raw output is a vector of logits, not probabilities. MuZero applies a softmax function to this output to get the probability of taking each action.

In [None]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, hidden_neurons, action_size):
        super(PolicyNetwork, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_neurons),
            nn.ReLU(),
            nn.Linear(hidden_neurons, action_size)
        )

    def forward(self, x):
        return self.layers(x)
    
pol_net = PolicyNetwork(input_size=embedding_size, hidden_neurons=hidden_neurons, action_size=action_size)
policy_logits = pol_net(torch.Tensor([1,1,1,1])) # output from network (one logit per action)
softmax_policy = torch.nn.functional.softmax(torch.squeeze(policy_logits), dim=-1) # probabilities that sum to 1
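
To illustrate the logits-to-probabilities step in isolation, the sketch below applies a softmax to hand-picked logits and stores the result per action. The logit values and the `priors` dictionary are made up for this example; a dictionary like this is one possible way to hand the probabilities to a tree search.

```python
import torch
import torch.nn.functional as F

policy_logits = torch.tensor([2.0, 0.0])  # made-up logits for (left, right)

# Softmax over the action dimension turns logits into probabilities
probs = F.softmax(policy_logits, dim=-1)

# One probability per action index
priors = {action: probs[action].item() for action in range(len(probs))}
print(priors)
```

Passing `dim` explicitly avoids relying on the deprecated implicit dimension choice of `softmax`.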


### Initial inference

In part 1, we skipped the details of two functions, initial_inference and recurrent_inference, which are used to run Monte Carlo Tree Search (MCTS). Now we are ready to cover them. We use the initial_inference function to expand the current root node. This function does the following:
- Use the representation network to get the latent representation of the current root node
- Use value network to get the expected value at the current latent state
- Use policy network to get the policy at the current latent state

In the implementation below, the InitialModel class integrates these three steps. In the initial_inference function, we create an InitialModel object and use it to return the transformed scalar value, the immediate reward (always set to 0 for the root state), the policy logits (before applying the softmax function), and the latent representation of the root state.


In [None]:
class InitialModel(nn.Module):
    def __init__(self, representation_network, value_network, policy_network):
        super(InitialModel, self).__init__()
        self.representation_network = representation_network
        self.value_network = value_network
        self.policy_network = policy_network

    def forward(self, state):
        hidden_representation = self.representation_network(state)
        value = self.value_network(hidden_representation)
        policy_logits = self.policy_network(hidden_representation)
        return hidden_representation, value, policy_logits


def initial_inference(state):
    # Reuse the networks defined above; creating new ones on every call
    # would start from freshly initialized (untrained) weights each time.
    initial_model = InitialModel(rep_net, val_net, pol_net)
    hidden_representation, value, policy_logits = initial_model(state)
    # The immediate reward is always 0 for the root state
    return value_transform(value), 0, policy_logits, hidden_representation


### Recurrent inference

Another function we used in MCTS is the recurrent_inference function. It runs the mental simulation during MCTS. This function does the following:
- Use the dynamics network to get the next latent state when taking the input action at the input state
- Use the reward network to get the immediate reward when taking the input action at the input state
- Use the value network to get the expected value at the next latent state
- Use the policy network to get the policy at the next latent state

In the implementation below, the RecurrentModel class integrates these four steps. In the recurrent_inference function, we create a RecurrentModel object and use it to return the transformed scalar value, the immediate reward, the policy logits (before applying the softmax function), and the latent representation of the next state.

In [None]:
class RecurrentModel(nn.Module):
    def __init__(self, dynamic_network, reward_network, value_network, policy_network):
        super(RecurrentModel, self).__init__()
        self.dynamic_network = dynamic_network
        self.reward_network = reward_network
        self.value_network = value_network
        self.policy_network = policy_network

    def forward(self, state_with_action):
        hidden_representation = self.dynamic_network(state_with_action)
        reward = self.reward_network(state_with_action)
        value = self.value_network(hidden_representation)
        policy_logits = self.policy_network(hidden_representation)
        return hidden_representation, reward, value, policy_logits

    
def hidden_state_with_action(hidden_state, action):
    """
    Merge hidden state and one hot encoded action
    """
    # One-hot encode the integer action, e.g. 0 -> [1, 0] and 1 -> [0, 1]
    one_hot_action = F.one_hot(torch.tensor(action), num_classes=action_size).float()
    return torch.cat((hidden_state, one_hot_action), dim=0)

def recurrent_inference(hidden_state, action):
    # Reuse the networks defined above; creating new ones on every call
    # would start from freshly initialized (untrained) weights each time.
    state_with_action = hidden_state_with_action(hidden_state, action)
    recurrent_model = RecurrentModel(dyn_net, rew_net, val_net, pol_net)
    hidden_representation, reward, value, policy_logits = recurrent_model(state_with_action)
    return value_transform(value), reward, policy_logits, hidden_representation
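
Both inference functions return their four results in the same order (value, reward, policy logits, latent state). One optional way to make that order harder to mix up is to wrap the results in a named tuple. The class name `NetworkOutput`, its field names, and the placeholder values below are assumptions for this sketch and are not used elsewhere in this notebook.

```python
from typing import NamedTuple

import torch


class NetworkOutput(NamedTuple):
    value: float
    reward: float
    policy_logits: torch.Tensor
    hidden_state: torch.Tensor


# Placeholder values standing in for the results of an inference call
output = NetworkOutput(
    value=1.5,
    reward=0.0,
    policy_logits=torch.zeros(2),
    hidden_state=torch.zeros(4),
)

# Fields can be read by name, and the object still unpacks like a plain tuple
value, reward, policy_logits, hidden_state = output
print(output.value, output.reward, hidden_state.shape)
```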


### Summary

In this notebook, we reviewed the deep neural networks used in MuZero. These networks are trained on the data collected with MCTS, the process we learned in part 1. In the next notebook, we combine part 1 and part 2 and add a few more elements to complete the MuZero framework.
