In [55]:
import torch
from torch import nn
from torch.nn import functional as F

In [56]:
# Generate the shared input X: a batch of 2 samples with 20 features each.
# Seed the global RNG first so that X (and every weight initialised after this
# cell) is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
X = torch.randn((2, 20), requires_grad=True)

In [57]:
# nn.Module for MLP
class MLP(nn.Module):
    """Two-layer perceptron: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs."""

    def __init__(self):
        super().__init__()
        # Hidden layer is created before the output layer.
        self.hidden = nn.Linear(in_features=20, out_features=256)
        self.out = nn.Linear(in_features=256, out_features=10)

    def forward(self, X):
        """Apply the hidden layer, a ReLU, then the output layer to X."""
        hidden_activation = F.relu(self.hidden(X))
        return self.out(hidden_activation)

# Instantiate the MLP and run one forward pass on X.
# .detach() separates the result from the autograd graph so it can be
# converted to a NumPy array for display.
net = MLP()
net(X).detach().numpy()

array([[ 0.15443158,  0.27320445,  0.18545823,  0.41052157,  0.08764493,
        -0.08156813,  0.26313844,  0.24579261, -0.03862309,  0.36557156],
       [-0.08770379,  0.18351898,  0.26165506,  0.04121573, -0.27300996,
         0.2092031 ,  0.07376558,  0.3835181 , -0.15643279,  0.24441603]],
      dtype=float32)

In [58]:
# Sequential block: a minimal re-implementation of nn.Sequential

class MySequential(nn.Module):
    """Minimal stand-in for nn.Sequential: applies its child blocks in order."""

    def __init__(self, *args):
        super().__init__()
        # Store every block in nn.Module's own `_modules` mapping, keyed by its
        # position as a string ("0", "1", ...).
        self._modules.update(
            (str(position), block) for position, block in enumerate(args)
        )

    def forward(self, X):
        """Feed X through each stored block in insertion order and return the result."""
        output = X
        for block in self._modules.values():
            output = block(output)
        return output

# Chain an MLP (20 -> 10) with an extra linear layer (10 -> 20) via MySequential
# and run one forward pass on X.
net = MySequential(MLP(), nn.Linear(10, 20))
net(X).detach().numpy()

array([[-0.21886319,  0.01656903, -0.10178545, -0.300775  , -0.02980975,
        -0.25826848,  0.00332622, -0.24216972, -0.35549825,  0.10793746,
         0.16377977, -0.22978947,  0.19712415,  0.06184904, -0.0749486 ,
        -0.25596613,  0.11900403,  0.38957718,  0.13625267,  0.24976714],
       [-0.20817   , -0.16096565, -0.09038559, -0.29355374, -0.2190541 ,
        -0.28728032,  0.08058178, -0.25969097, -0.11491592, -0.02254538,
         0.01490976, -0.21128692,  0.14241996,  0.06030833, -0.10960078,
        -0.23239025,  0.15991119,  0.3589933 ,  0.09074242,  0.23891999]],
      dtype=float32)

In [59]:
# FixedHiddenMLP (important!)
# Fixed (constant) hidden weights + one linear layer applied twice (parameter sharing)

class FixedHiddenMLP(nn.Module):
    """MLP with a constant random hidden weight and a linear layer used twice.

    The forward pass mixes three ideas:
      * ``rand_weight`` is a fixed random matrix that is never trained,
      * ``linear`` is applied twice, so both applications share parameters,
      * ordinary Python control flow (a ``while`` loop) runs inside ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Registered as a buffer rather than an nn.Parameter: the weight stays
        # constant (requires_grad=False, not returned by .parameters(), so no
        # optimizer updates it) while still following .to()/.cuda() and being
        # saved in state_dict() under the same "rand_weight" key.
        self.register_buffer("rand_weight", torch.randn(10, 20))
        self.linear = nn.Linear(20, 10)

    def forward(self, X):
        """Return a scalar tensor computed from X.

        Args:
            X: 2-D tensor whose last dimension is 20 (required by ``linear``).

        Returns:
            0-dim tensor: the sum of the rescaled activations.
        """
        X = self.linear(X)
        # Constant projection back to 20 features, shifted by 1 before the ReLU.
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # Same layer object as above: its parameters are shared between calls.
        X = self.linear(X)
        # Control flow: halve X until its L1 norm is at most 1.
        # The finiteness check stops the loop when the norm is inf (halving inf
        # never shrinks it); NaN already ends the loop because NaN > 1 is False.
        # Out-of-place division avoids mutating a tensor tracked by autograd.
        magnitude = X.abs().sum()
        while torch.isfinite(magnitude) and magnitude > 1:
            X = X / 2
            magnitude = X.abs().sum()
        return X.sum()

# Run FixedHiddenMLP on X; its forward pass ends with X.sum(), so the
# displayed result is a single scalar rather than a matrix.
net = FixedHiddenMLP()
net(X).detach().numpy()

array(0.35255978, dtype=float32)

In [60]:
# NestedMLP (also important!): blocks can be nested inside other blocks

class NestedMLP(nn.Module):
    """Block that nests an nn.Sequential (20 -> 64 -> 20) ahead of a 20 -> 10 layer."""

    def __init__(self):
        super().__init__()
        # Inner stack: Linear/ReLU twice, mapping 20 features back to 20.
        inner_layers = [
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*inner_layers)
        self.linear = nn.Linear(20, 10)

    def forward(self, X):
        """Run the inner stack, then the final linear layer followed by a ReLU."""
        features = self.net(X)
        projected = self.linear(features)
        return F.relu(projected)

# Mix custom and built-in blocks in one nn.Sequential:
# NestedMLP (20 -> 10), Linear (10 -> 20), ReLU, then FixedHiddenMLP,
# whose forward pass returns a single scalar.
net = nn.Sequential(NestedMLP(), nn.Linear(10, 20), nn.ReLU(), FixedHiddenMLP())

net(X).detach().numpy()

array(0.00620824, dtype=float32)