# Neural Network Design

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

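The primary building block of a neural network is the layer. Each layer has two primary components: \
&emsp;1. A transformation (code) \
&emsp;2. A collection of weights (data) \
In PyTorch, a network is defined by subclassing `nn.Module`: the layers (which hold the weights) are created in `__init__()`, and the transformation is implemented in `forward()`.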

In [4]:
class MyNetwork(nn.Module):
    """LeNet-style CNN: two conv -> ReLU -> max-pool stages, then three linear layers.

    The layer sizes imply single-channel input images whose height/width shrink to
    4x4 after the two conv(5x5) + max-pool(2x2, stride 2) stages, e.g. 28x28:
    28 -> 24 -> 12 -> 8 -> 4. The output is an (N, 10) tensor of raw scores
    (logits); no softmax is applied.
    """

    # Per-sample feature count after the second conv/pool stage: 12 maps of 4x4.
    _FLAT_FEATURES = 12 * 4 * 4

    def __init__(self) -> None:
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        # Fully connected head: 192 -> 120 -> 60 -> 10.
        self.fc1 = nn.Linear(in_features=self._FLAT_FEATURES, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Compute class logits for a batch of images.

        Args:
            t: input batch; assumed shape (N, 1, H, W) where H and W reduce to 4
                after both conv/pool stages (e.g. 28x28) -- TODO confirm against the data.

        Returns:
            Tensor of shape (N, 10) holding unnormalized logits.

        Raises:
            RuntimeError: from fc1 when the per-sample feature count is not 12*4*4,
                i.e. the input's spatial size does not match the layer dimensions.
        """
        # (1) Input layer: identity, the input tensor is used as-is.

        # (2) Hidden conv layer 1
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (3) Hidden conv layer 2
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (4) Hidden linear layer 1
        # Flatten per sample while keeping the batch dimension. A reshape(-1, 12*4*4)
        # would infer the batch size and silently regroup values across samples when
        # the input has the wrong spatial size; flatten(start_dim=1) makes fc1 raise.
        t = t.flatten(start_dim=1)
        t = self.fc1(t)
        t = F.relu(t)

        # (5) Hidden linear layer 2
        t = self.fc2(t)
        t = F.relu(t)

        # (6) Output linear layer: return raw logits. Apply F.softmax(t, dim=1) to the
        # result only when probabilities are needed; F.cross_entropy expects logits.
        t = self.out(t)
        return t

In [11]:
# Fix the RNG seed so the randomly initialised weights displayed below are
# reproducible across Restart & Run All.
torch.manual_seed(0)
network = MyNetwork()
# Learnable weight tensor of the first conv layer, shown as the cell output.
network.conv1.weight

Parameter containing:
tensor([[[[ 0.1412, -0.1912,  0.1668, -0.0718, -0.1771],
          [-0.1692,  0.1525, -0.0434,  0.1038, -0.0565],
          [-0.0057,  0.0980,  0.1387,  0.1667, -0.0320],
          [ 0.1335,  0.0163,  0.1401,  0.0451,  0.1253],
          [ 0.0575,  0.1632,  0.0292, -0.0307, -0.0491]]],


        [[[ 0.0284,  0.1486,  0.1245, -0.0309,  0.1594],
          [-0.0302,  0.0900,  0.0618, -0.1294,  0.0770],
          [ 0.0203, -0.1393, -0.0509,  0.0674, -0.0833],
          [-0.1548, -0.1058,  0.1695, -0.0165, -0.0631],
          [-0.1371,  0.1641,  0.0935,  0.0427,  0.0682]]],


        [[[-0.0788,  0.1711,  0.1997, -0.0375,  0.1642],
          [-0.1564, -0.1217,  0.1886,  0.0548,  0.1874],
          [-0.0519,  0.0481,  0.1918, -0.1016, -0.0442],
          [-0.1188,  0.0250,  0.1310, -0.0717,  0.0584],
          [-0.1253,  0.0177, -0.0755,  0.1924, -0.0635]]],


        [[[ 0.0116,  0.0811, -0.0974, -0.1791, -0.1839],
          [-0.1851, -0.0678, -0.0182,  0.0743, -0.1235