In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Seed the global torch RNG so the observation below (and any network weights
# initialised afterwards) are reproducible on Restart Kernel -> Run All.
torch.manual_seed(0)

# A single random observation of shape (4, 4, 4) with a leading batch
# dimension added -> (1, 4, 4, 4), matching Othello_QNet(board_size=4) with
# its default in_channels=4 (BatchNorm2d in the net requires a 4-D input).
obs = torch.rand(4, 4, 4).unsqueeze(0)
print(obs, obs.shape)

tensor([[[[0.0089, 0.2716, 0.2336, 0.8858],
          [0.0911, 0.1410, 0.0812, 0.3638],
          [0.8331, 0.9065, 0.8579, 0.2625],
          [0.6742, 0.1075, 0.6947, 0.7362]],

         [[0.8590, 0.0500, 0.4145, 0.4060],
          [0.1767, 0.9980, 0.4415, 0.0838],
          [0.0332, 0.5095, 0.4402, 0.8077],
          [0.5429, 0.1697, 0.5759, 0.5666]],

         [[0.5888, 0.8027, 0.2435, 0.5854],
          [0.5964, 0.9139, 0.3052, 0.4030],
          [0.6395, 0.3709, 0.5470, 0.4967],
          [0.7824, 0.5559, 0.5852, 0.2932]],

         [[0.2872, 0.6416, 0.3181, 0.3236],
          [0.8485, 0.8219, 0.3015, 0.4674],
          [0.4075, 0.0832, 0.0299, 0.6567],
          [0.7302, 0.9547, 0.3857, 0.4873]]]]) torch.Size([1, 4, 4, 4])


In [3]:
class Othello_QNet(nn.Module):
    """Small convolutional policy/value network for a square Othello board.

    A shared trunk of three ``Conv2d(3x3) -> BatchNorm2d -> Tanh`` stages feeds
    two heads. The padding keeps the spatial size at ``board_size x board_size``.

    * Policy head: a 3x3 conv down to one channel, then a softmax over the
      ``board_size**2`` squares of each sample independently.
    * Value head: a bias-free linear layer on each sample's flattened trunk
      features.

    Args:
        board_size: Side length of the board (input height and width).
        in_channels: Number of input feature planes (default 4).
        hidden_channels: Number of channels in the conv trunk (default 8).
    """

    def __init__(self, board_size, in_channels=4, hidden_channels=8):
        super().__init__()
        self.board_size = board_size
        self.hidden_channels = hidden_channels
        self.in_channels = in_channels

        # Expected input shape: (batch, in_channels, board_size, board_size).
        # kernel_size=3, stride=1, padding=1 preserves the spatial size.
        self.f = nn.Sequential(
            nn.Conv2d(self.in_channels, self.hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.hidden_channels),
            nn.Tanh(),

            nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.hidden_channels),
            nn.Tanh(),

            nn.Conv2d(self.hidden_channels, self.hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.hidden_channels),
            nn.Tanh(),
        )

        self.Conv_Policy = nn.Conv2d(self.hidden_channels, 1, kernel_size=3, stride=1, padding=1)
        # Normalise over the flattened board squares of each sample (dim=1).
        # Normalising over dim=0 of a fully flattened tensor would mix samples
        # across the batch.
        self.Softmax = nn.Softmax(dim=1)

        self.FC_Value = nn.Linear(self.hidden_channels * self.board_size**2, 1, bias=False)

    def forward(self, x):
        """Compute a value estimate and move probabilities for a batch of boards.

        Args:
            x: Tensor of shape ``(batch, in_channels, board_size, board_size)``.

        Returns:
            value: Tensor of shape ``(batch,)``, with one scalar per sample.
            ac_probs: Tensor of shape ``(batch, 1, board_size, board_size)``.
                For each sample, its entries sum to 1.
        """
        features = self.f(x)
        batch = features.shape[0]

        # Policy head: flatten only the per-sample dims -> (batch, board_size**2).
        logits = self.Conv_Policy(features).flatten(start_dim=1)
        ac_probs = self.Softmax(logits).reshape(batch, 1, self.board_size, self.board_size)

        # Value head: keep the batch dim so FC_Value receives
        # hidden_channels * board_size**2 features per sample.
        # The result is (batch, 1); squeeze it to (batch,).
        value = self.FC_Value(features.flatten(start_dim=1)).squeeze(-1)
        return value, ac_probs

In [7]:
# Smoke test: push the single observation through a freshly initialised
# network and inspect both heads. The policy entries should add up to ~1.
QNet = Othello_QNet(board_size=4)
value, ac_probs = QNet(obs)

total_prob = ac_probs.sum().item()
print(value)
print(ac_probs, ac_probs.shape, total_prob)

tensor([0.0258], grad_fn=<SqueezeBackward3>)
tensor([[[[0.0617, 0.0763, 0.0550, 0.0479],
          [0.0544, 0.0526, 0.1033, 0.0980],
          [0.0648, 0.0446, 0.0480, 0.0639],
          [0.0456, 0.0493, 0.0744, 0.0602]]]], grad_fn=<ReshapeAliasBackward0>) torch.Size([1, 1, 4, 4]) 0.9999998211860657
