In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from time import time

from torchinfo import summary

In [124]:
class LM(nn.Module):
    def __init__(self, num):
        super(LM, self).__init__()
        
        self.fcs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(120, 120),
                    nn.ReLU(),
                    nn.Linear(120, 120)
                ) for _ in range(num)
            ]
        )
        
    def forward(self, x):
        
        x = [fc(x) for fc in self.fcs]
        
        output = torch.stack(x, dim=1)
        
        # output = torch.sum(output, dim=1)
        
        return output
    
class SM(nn.Module):
    def __init__(self, num):
        super(SM, self).__init__()
        
        
        self.num = num
        
        self.fc1 = nn.Linear(120, 120)
        self.fc2 = nn.Linear(120, 120)
        
        self.w1 = nn.Parameter(
            torch.block_diag(
                *[self.fc1.weight] * num
            )
        )    
        self.b1 = nn.Parameter(
            torch.cat([self.fc1.bias] * num)
        )
        
        
        self.w2 = nn.Parameter(
            torch.block_diag(
                *[self.fc2.weight] * num
            )
        )                   
        self.b2 = nn.Parameter(
            torch.cat([self.fc2.bias] * num)
        )
        
    def forward(self, x):
        
        x = x.repeat(1, self.num)
        
        x = torch.matmul(x, self.w1) + self.b1
        x = torch.relu(x)
        x = torch.matmul(x, self.w2) + self.b2
        output = x.view(-1, self.num, 120)
        
        return output
        
        

In [125]:
epochs = 1000
layers = 10
inputs = torch.randn(100, 120)

In [126]:
lm = LM(num=layers)

print(summary(lm, input_data=inputs))

st = time()
for _ in range(epochs):
    outputs = lm(inputs)

print(time() - st)

Layer (type:depth-idx)                   Output Shape              Param #
LM                                       [100, 10, 120]            --
├─ModuleList: 1-1                        --                        --
│    └─Sequential: 2-1                   [100, 120]                --
│    │    └─Linear: 3-1                  [100, 120]                14,520
│    │    └─ReLU: 3-2                    [100, 120]                --
│    │    └─Linear: 3-3                  [100, 120]                14,520
│    └─Sequential: 2-2                   [100, 120]                --
│    │    └─Linear: 3-4                  [100, 120]                14,520
│    │    └─ReLU: 3-5                    [100, 120]                --
│    │    └─Linear: 3-6                  [100, 120]                14,520
│    └─Sequential: 2-3                   [100, 120]                --
│    │    └─Linear: 3-7                  [100, 120]                14,520
│    │    └─ReLU: 3-8                    [100, 120]              

In [127]:
inputs = torch.randn(100, 120)
sm = SM(num=layers)

print(summary(sm, input_data=inputs))


st = time()
for _ in range(1000):
    outputs = sm(inputs)

print(time() - st)

Layer (type:depth-idx)                   Output Shape              Param #
SM                                       [100, 10, 120]            2,911,440
Total params: 2,911,440
Trainable params: 2,911,440
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
Input size (MB): 0.05
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.05
1.2347021102905273


In [128]:
w_b = [(x.weight, x.bias) for x in [nn.Linear(120, 120) for _ in range(2)]]
W = torch.block_diag(*[w for w, _ in w_b])

W = torch.cat([w for w, _ in w_b], dim=0)

print(W.shape)

b = nn.Parameter(torch.cat([b for _, b in w_b]))

print(b.shape)

torch.Size([240, 120])
torch.Size([240])
