# MuZero Parameters

* 800 simulations per move to pick an action
* During the generation of experience in the board game domains, the same exploration scheme as the one described in AlphaZero is used.

* Actions are encoded spatially in planes of the same resolution as the hidden state.

## Network Architecture

* The prediction function uses the same architecture as AlphaZero: one or two convolutional layers that preserve the resolution but reduce the number of planes, followed by a fully connected layer to the size of the output.



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class Conv(nn.Module):
    '''
    Bias-free 2-D convolution (stride 1) with an optional batch-norm layer.

    Padding is ``kernel_size // 2``, so an odd ``kernel_size`` keeps the
    spatial resolution of the input.

    Args:
        filters0: number of input channels.
        filters1: number of output channels.
        kernel_size: side length of the square kernel.
        bn: when true, ``nn.BatchNorm2d(filters1)`` follows the convolution.
    '''
    def __init__(self, filters0, filters1, kernel_size, bn=False):
        super().__init__()
        convolution = nn.Conv2d(
            filters0,
            filters1,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=False,
        )
        stages = [convolution, nn.BatchNorm2d(filters1)] if bn else [convolution]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        '''Run ``x`` through the convolution and, if configured, batch norm.'''
        return self.conv(x)

class ResidualBlock(nn.Module):
    '''
    Residual block: Conv+BN -> ReLU -> Conv+BN, added to the input and
    passed through a final ReLU.

    Both convolutions are 3x3 and map ``filters`` channels to ``filters``
    channels, so the output has the same shape as the input.

    Args:
        filters: number of input and output channels.
    '''
    def __init__(self, filters):
        super().__init__()
        blocks = [Conv(filters, filters, 3, True),
                  nn.ReLU(),
                  Conv(filters, filters, 3, True)]
        self.conv = nn.Sequential(*blocks)

    def forward(self, x):
        '''Return ``relu(x + conv(x))``.'''
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        return F.relu(x + self.conv(x))

In [3]:
# Smoke test: build a residual block with 12 filters.
res = ResidualBlock(12)

In [4]:
# Random input batch: 10 samples, 12 channels, 64x64 each.
x = torch.randn(10,12,64,64)

In [5]:
# Forward pass through the residual block.
y = res(x)

In [6]:
# The residual block should leave the input shape unchanged.
y.shape

torch.Size([10, 12, 64, 64])

In [7]:
# Min over the element-wise equality mask acts like all(): it is True only
# if y is identical to x, i.e. only if the block changed nothing.
np.min(x.detach().numpy() == y.detach().numpy())

False

In [4]:
class Representation(nn.Module):
    '''
    Convert observation into hidden abstract state

    A Conv + batch-norm + ReLU stem followed by ``n_res_blocks`` residual
    blocks. Every convolution is 3x3 with stride 1 and padding 1, so the
    output keeps the spatial size of the input and has ``num_filters``
    channels.

    Args:
        input_shape: observation shape, indexed as (channels, height, width).
        num_filters: number of channels of the hidden state.
        n_res_blocks: number of residual blocks after the stem.
    '''
    def __init__(self, input_shape,num_filters,n_res_blocks):
        super().__init__()
        self.input_shape = input_shape
        self.board_size = self.input_shape[1] * self.input_shape[2]
        blocks = [Conv(self.input_shape[0], num_filters, 3, bn=True),
                  nn.ReLU()]
        blocks.extend(ResidualBlock(num_filters) for _ in range(n_res_blocks))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        '''Map a batch of observations to a batch of hidden states.'''
        return self.layers(x)

    def inference(self, x):
        '''
        Encode a single observation given as a NumPy array.

        Switches the module to eval mode (it stays in eval mode afterwards),
        adds a batch axis, runs the network without gradients and returns the
        hidden state as a NumPy array without the batch axis.
        '''
        self.eval()
        # NumPy defaults to float64 and always lives on the CPU; cast to the
        # model's dtype/device so the forward pass does not raise a mismatch.
        ref = next(self.parameters())
        obs = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        with torch.no_grad():
            rp = self(obs.unsqueeze(0))
        return rp.cpu().numpy()[0]

In [50]:
class Prediction(nn.Module):
    '''
    Dual Network
    Policy and value prediction from hidden abstract state

    A Conv + batch-norm + ReLU stem and ``n_res_blocks`` residual blocks are
    followed by an average pool over the whole board, giving one feature per
    filter. Two linear heads read those features: a softmax policy head and
    a tanh value head.

    Args:
        input_shape: hidden-state shape, indexed as (channels, height, width).
        num_filters: number of channels in the convolutional trunk.
        policy_size: number of entries in the policy output.
        n_res_blocks: number of residual blocks after the stem.
    '''
    def __init__(self, 
                 input_shape,
                 num_filters,
                 policy_size=10,
                 n_res_blocks=2):
        
        super().__init__()
        self.input_shape = input_shape
        self.board_size = self.input_shape[1] * self.input_shape[2]
        
        # Main convolutional block
        blocks = [Conv(self.input_shape[0], num_filters, 3, bn=True), nn.ReLU()]
        blocks.extend(ResidualBlock(num_filters) for _ in range(n_res_blocks))
        # Pool over the full (height, width) so each filter yields one value.
        blocks.append(nn.AvgPool2d(kernel_size=tuple(self.input_shape[1:])))
        self.conv_layers = nn.Sequential(*blocks)

        # Policy head: probabilities over the policy entries of each sample.
        self.policy_head = nn.Sequential(nn.Linear(num_filters, policy_size),
                                         nn.Softmax(dim=1))

        # Value head: one scalar per sample, squashed by tanh.
        self.value_head = nn.Sequential(nn.Linear(num_filters, 1),
                                        nn.Tanh())

    def forward(self, s):
        '''
        Return ``(policy, value)`` for a batch of hidden states.

        ``policy`` has shape (batch, policy_size) and ``value`` has shape
        (batch, 1) when the spatial size of ``s`` matches ``input_shape``.
        '''
        # Flatten from dim 1 instead of squeeze(): squeeze() would also drop
        # the batch axis for a batch of one, which breaks inference().
        h = torch.flatten(self.conv_layers(s), start_dim=1)
        return self.policy_head(h), self.value_head(h)

    def inference(self, s):
        '''
        Predict policy and value for a single hidden state (NumPy array).

        Switches the module to eval mode (it stays in eval mode afterwards).
        Returns the policy as a 1-D NumPy array and the value as a scalar.
        '''
        self.eval()
        # NumPy defaults to float64 and always lives on the CPU; cast to the
        # model's dtype/device so the forward pass does not raise a mismatch.
        ref = next(self.parameters())
        state = torch.as_tensor(s, dtype=ref.dtype, device=ref.device)
        with torch.no_grad():
            p, v = self(state.unsqueeze(0))
        return p.cpu().numpy()[0], v.cpu().numpy()[0][0]


In [19]:
class Dynamics(nn.Module):
    '''
    Transition of hidden abstract state 

    The state and action planes are concatenated along the channel axis and
    passed through a Conv + batch-norm + ReLU stem followed by
    ``n_res_blocks`` residual blocks. The output keeps the spatial size of
    the input and has ``num_filters`` channels.

    NOTE(review): to feed the output back in as the next state,
    ``num_filters`` presumably has to equal ``state_shape[0]`` - confirm
    with the caller.

    Args:
        state_shape: hidden-state shape; only ``state_shape[0]`` (channels)
            is used to size the first convolution.
        act_shape: action-plane shape; only ``act_shape[0]`` (channels) is
            used to size the first convolution.
        num_filters: number of output channels.
        n_res_blocks: number of residual blocks after the stem.
    '''
    def __init__(self, 
                 state_shape,
                 act_shape,
                 num_filters,
                 n_res_blocks=2):
        super().__init__()
        
        self.state_shape = state_shape
        self.act_shape = act_shape
        
        blocks = [Conv(state_shape[0]+act_shape[0], num_filters, 3, bn=True),
                  nn.ReLU()]
        blocks.extend(ResidualBlock(num_filters) for _ in range(n_res_blocks))
        self.layers = nn.Sequential(*blocks)

    def forward(self, s, a):
        '''
        Return the next hidden state for batched state ``s`` and action ``a``.

        ``s`` and ``a`` are concatenated on dim 1, so they must agree on
        every other dimension.
        '''
        h = torch.cat([s, a], dim=1)
        return self.layers(h)

    def inference(self, s, a):
        '''
        Compute the next hidden state for one state/action pair (NumPy arrays).

        Switches the module to eval mode (it stays in eval mode afterwards)
        and returns the result as a NumPy array without the batch axis.
        '''
        self.eval()
        # NumPy defaults to float64 and always lives on the CPU; cast to the
        # model's dtype/device so the forward pass does not raise a mismatch.
        ref = next(self.parameters())
        state = torch.as_tensor(s, dtype=ref.dtype, device=ref.device)
        action = torch.as_tensor(a, dtype=ref.dtype, device=ref.device)
        with torch.no_grad():
            next_state = self(state.unsqueeze(0), action.unsqueeze(0))
        return next_state.cpu().numpy()[0]


In [55]:
# Random observation batch: 33 samples, each of shape (12, 64, 64).
x = torch.randn(33,12,64,64)

In [56]:
# Representation network sized from one observation: 20 filters, 2 residual blocks.
rep = Representation(x.shape[1:],20,2)

In [57]:
# Encode the observations into hidden states.
s = rep(x)

In [58]:
# Expect 20 channels (num_filters) at the same 64x64 resolution.
s.shape

torch.Size([33, 20, 64, 64])

In [59]:
# Dynamics network with 40 filters; the action planes are declared with the
# same shape as the hidden state.
dyn = Dynamics(s.shape[1:],s.shape[1:],40)

In [60]:
# Shape check only: the hidden state is reused as a stand-in action tensor.
s_ = dyn(s,s)

In [61]:
# The output has 40 channels (num_filters), not the 20 of the input state.
s_.shape

torch.Size([33, 40, 64, 64])

In [62]:
# Prediction network over the hidden state: 20 filters, default policy size.
pre = Prediction(s.shape[1:],20)

In [63]:
# Policy and value for every hidden state in the batch.
p,v = pre(s)

In [64]:
# Expect one policy vector and one scalar value per sample.
p.shape,v.shape

(torch.Size([33, 10]), torch.Size([33, 1]))