In [3]:
# customize layer
import torch
from torch.nn import functional as f
from torch import nn

# simple example
class centeredLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, X):
        return X - X.mean()  # forward function are automatically called when class is initialized


layer = centeredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))  # call forward so we get [-2, -1, 0, 1, 2] results

net = nn.Sequential(nn.Linear(8, 128), centeredLayer())
y = net(torch.rand(16, 8))
print(y)
print(y.mean())  # bearly 0, which means centeredLayer works in sequential

tensor([-2., -1.,  0.,  1.,  2.])
tensor([[ 0.7157, -0.5380, -0.2954,  ..., -0.3217,  0.1765,  0.4156],
        [ 0.5206, -0.5326, -0.4781,  ..., -0.1824,  0.1855,  0.3475],
        [ 0.6036, -0.4129, -0.5439,  ..., -0.2046,  0.1718,  0.6398],
        ...,
        [ 0.5362, -0.7078, -0.5000,  ..., -0.3100,  0.2452,  0.4090],
        [ 0.4285, -0.3110, -0.5793,  ..., -0.2502,  0.3570,  0.6356],
        [ 0.7316, -0.4105, -0.3695,  ..., -0.2492,  0.3032,  0.7442]],
       grad_fn=<SubBackward0>)
tensor(4.6566e-10, grad_fn=<MeanBackward0>)


In [7]:
# customize parameters of layer
class myLinear(nn.Module):
    def __init__(self, in_units, units) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.rand(in_units, units))  # when we use nn.Parameter() function, 
                                                                 # pytorch automatically add them to module's parameters list
        self.bias = nn.Parameter(torch.rand(units))
    
    def forward(self, X):
        lin = torch.matmul(X, self.weight.data) + self.bias.data
        return f.relu(lin)
    
linear = myLinear(5, 3)
print(linear.weight)

# call forward() function indicatively
print(linear(torch.rand(2, 5)))

# validated in sequentail
net = nn.Sequential(myLinear(64, 8), myLinear(8, 1))
print(net(torch.rand(2, 64)))

Parameter containing:
tensor([[0.7911, 0.7050, 0.8046],
        [0.2476, 0.9851, 0.5565],
        [0.4053, 0.5551, 0.1865],
        [0.6024, 0.6224, 0.0948],
        [0.4930, 0.4969, 0.4747]], requires_grad=True)
tensor([[1.7427, 2.0683, 1.3792],
        [1.6893, 1.7784, 1.2554]])
tensor([[52.6111],
        [57.7240]])
