In [8]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """
    Multilayer Perceptron (MLP) mapping a length-``n_ports`` vector to an
    ``n_ports x n_ports`` matrix.

    Architecture: n_ports -> 200 -> 300 -> 300 -> 200 -> n_ports**2, with a
    ReLU after every hidden layer and no activation on the output layer. The
    flat output is reshaped to ``[batch_size, n_ports, n_ports]``.

    Args:
        n_ports: size of the input vector and side length of the output
            matrix (default 5, i.e. the original hard-coded 5 -> 5x5 network).
        dtype: dtype of every layer's parameters (default ``torch.float64``).
            Inputs passed to ``forward`` must use the same dtype.
    """
    def __init__(self, n_ports=5, dtype=torch.float64):
        super().__init__()

        self.fc1 = nn.Linear(n_ports, 200, dtype=dtype)            # input layer
        self.fc2 = nn.Linear(200, 300, dtype=dtype)                # hidden layer
        self.fc3 = nn.Linear(300, 300, dtype=dtype)                # hidden layer
        self.fc4 = nn.Linear(300, 200, dtype=dtype)                # hidden layer
        self.fc5 = nn.Linear(200, n_ports * n_ports, dtype=dtype)  # output layer (no activation)

        self.init_weights()

    def init_weights(self):
        """
        He-style normal initialisation for every linear layer.

        Weights are drawn from N(0, std) with std = 1 / sqrt(fan_in / 2)
        (= sqrt(2 / fan_in)), and all biases are set to zero.
        """
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            fan_in = fc.weight.size(1)  # Linear weight is [out_features, in_features]
            nn.init.normal_(fc.weight, 0.0, 1 / math.sqrt(fan_in / 2))
            nn.init.constant_(fc.bias, 0.0)

    def forward(self, x):
        """
        Run the network.

        Args:
            x: tensor whose last dimension is ``n_ports`` and whose dtype
               matches the layer parameters.

        Returns:
            Tensor of shape ``[-1, n_ports, n_ports]``. Because ``reshape(-1, ...)``
            is used, any leading dimensions of ``x`` are merged into the first one.
        """
        # Side length of the output matrix; taken from the layer itself so it
        # always agrees with how the network was constructed.
        n_ports = self.fc1.in_features

        z = F.relu(self.fc1(x))
        z = F.relu(self.fc2(z))
        z = F.relu(self.fc3(z))
        z = F.relu(self.fc4(z))
        z = self.fc5(z)
        z = z.reshape(-1, n_ports, n_ports)  # reshape output to [batch_size, n_ports, n_ports]

        return z


def count_parameters(model):
    """
    Return the total number of trainable scalar values in ``model``.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


In [9]:
if __name__ == '__main__':
    # Smoke test: build the model, report its size, and push the first batch
    # of the training loader through it to check the output shape.
    from dataset import SwitchDataset  # project-local module
    net = MLP()
    print(net)
    print('Number of MLP parameters: {}'.format(count_parameters(net)))
    dataset = SwitchDataset()
    VOQs, Matchings = next(iter(dataset.train_loader))
    # Only the output shape is needed, so skip autograd bookkeeping.
    # .double() matches the model's default float64 parameters.
    with torch.no_grad():
        print('Size of model output:', net(VOQs.double()).size())

MLP(
  (fc1): Linear(in_features=5, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=300, bias=True)
  (fc3): Linear(in_features=300, out_features=300, bias=True)
  (fc4): Linear(in_features=300, out_features=200, bias=True)
  (fc5): Linear(in_features=200, out_features=25, bias=True)
)
Number of CNN parameters: 217025
Size of model output: torch.Size([20, 5, 5])


In [6]:
import torch

# Scratch tensor of shape (4, 5, 25) with values drawn uniformly from [0, 1).
# NOTE(review): no seed is set, so the printed values change on every run.
xx = torch.rand(4, 5, 25)

In [12]:
# Inspect the first (5, 25) slice of the scratch tensor.
print(xx[0, ...])

tensor([[0.6957, 0.0742, 0.4300, 0.4726, 0.2031, 0.7170, 0.5558, 0.7992, 0.8887,
         0.3692, 0.6846, 0.0411, 0.2311, 0.1864, 0.3102, 0.4288, 0.0867, 0.6728,
         0.9611, 0.1147, 0.7084, 0.6208, 0.4916, 0.9043, 0.8982],
        [0.5116, 0.6817, 0.6919, 0.6359, 0.2754, 0.2347, 0.4219, 0.1444, 0.6425,
         0.4778, 0.5828, 0.9268, 0.6429, 0.2447, 0.9827, 0.9873, 0.0449, 0.3861,
         0.2960, 0.2812, 0.2121, 0.4212, 0.8944, 0.5519, 0.7283],
        [0.0334, 0.7106, 0.3271, 0.4712, 0.2638, 0.0622, 0.5018, 0.3616, 0.8383,
         0.8122, 0.1504, 0.3883, 0.0217, 0.9826, 0.2547, 0.9379, 0.4903, 0.1151,
         0.6900, 0.2795, 0.9832, 0.9959, 0.7060, 0.6790, 0.9401],
        [0.0016, 0.8578, 0.2292, 0.2896, 0.0951, 0.9400, 0.2277, 0.9241, 0.1898,
         0.8321, 0.2379, 0.9686, 0.5273, 0.0713, 0.6181, 0.7810, 0.7830, 0.2592,
         0.5633, 0.7711, 0.1324, 0.8601, 0.8911, 0.8625, 0.2240],
        [0.0851, 0.0762, 0.3791, 0.3165, 0.2616, 0.3770, 0.8844, 0.1088, 0.5117,
       