In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import CIFAR10, MNIST
from torch.utils.data import DataLoader

## Multi-Layer Perceptron
We use a simple fully connected MLP as a baseline to compare against the skip-connected networks. Each network consists of an input layer, an output layer and a list of hidden layers. All hidden layers have the same number of neurons.

Each different network model has the following common input arguments:
- *in_dim*: Length of the flattened input vector (28x28 = 784 for MNIST, 3x32x32 = 3072 for CIFAR-10)
- *out_dim*: Dimensions of final output vector/number of classes (10 for MNIST, CIFAR-10)
- *n_layers*: Number of hidden layers (the class definitions below default to 3). The total number of layers is two more: the input layer and the output layer
- *n_nodes*: Number of neurons in each layer, default=100

## Skip Connections
Skip connections are extra connections between layers that feed the output of one layer into the input of later layers, skipping the layers in between. Skip connections were introduced to solve degradation problems and enhance feature reusability. There are three widely used variants: ResNet, DenseNet and UNet. We have implemented ResNet and DenseNet on an MLP model below.

Further, for each variant, we construct 3 different architectures:
- OneSkip: We skip one intermediate layer for each layer in the network
- SourceSkip: We connect the input layer (or first hidden layer) to each layer in the network
- FullSkip: Each layer receives the outputs of all preceding layers through skip connections.

### ResNets
Residual Nets (ResNets) are built from residual blocks in which the skipped input is added element-wise to the input of a later layer. In our implementation, each layer acts as a residual block. The number of features in the hidden layers is kept uniform so that the vectors can be added directly.

### DenseNets
In DenseNets, the output of a layer is concatenated with the input to the next layer. Thus the input dimension of a layer grows with the number of concatenated outputs, whereas for ResNets it remains the same. Because of the concatenations, the features of previous layers are carried on to the next layers. Our implementation is structured like the ResNet one.

In [4]:
class MLP(nn.Module):
    """Plain fully connected network without skip connections.

    Layout: an input layer (``in_dim -> n_nodes``), ``n_layers`` hidden layers
    (``n_nodes -> n_nodes``) and an output layer (``n_nodes -> out_dim``).
    Every layer except the output layer is followed by a ReLU.

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: Number of hidden layers.
        n_nodes: Width of the input layer output and of every hidden layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        self.hidden_layers = nn.ModuleList([nn.Linear(n_nodes, n_nodes) for _ in range(n_layers)])
        self.out_layer = nn.Linear(n_nodes, out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs (e.g. after transpose/permute),
        # for which view raises a RuntimeError.
        x = x.reshape(x.shape[0], -1)

        x = F.relu(self.in_layer(x))
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return self.out_layer(x), 1

In [14]:
class OneSkipDN(nn.Module):
    """DenseNet-style MLP where each layer also sees the output from two steps back.

    Layout:
        * ``in_layer``: ``in_dim -> n_nodes``
        * ``next_layer``: ``in_dim + n_nodes -> n_nodes`` (raw input concatenated
          with the output of ``in_layer``)
        * ``n_layers - 1`` hidden layers: ``2 * n_nodes -> n_nodes`` (the two most
          recent layer outputs concatenated)
        * ``out_layer``: ``2 * n_nodes -> out_dim``

    ``next_layer`` is always created, so ``n_layers < 1`` behaves like
    ``n_layers == 1``.

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: ``next_layer`` plus ``n_layers - 1`` entries in ``hidden_layers``.
        n_nodes: Output width of every layer except the output layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        self.next_layer = nn.Linear(in_dim + n_nodes, n_nodes)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(2 * n_nodes, n_nodes) for _ in range(n_layers - 1)]
        )
        self.out_layer = nn.Linear(2 * n_nodes, out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs, for which view raises.
        x = x.reshape(x.shape[0], -1)

        older = F.relu(self.in_layer(x))
        newer = F.relu(self.next_layer(torch.cat((x, older), 1)))

        for layer in self.hidden_layers:
            # Inside the loop the concatenation order is (newer, older).
            # The right-hand side is evaluated before either name is rebound.
            older, newer = newer, F.relu(layer(torch.cat((newer, older), 1)))

        # The output layer uses the opposite order: (older, newer).
        return self.out_layer(torch.cat((older, newer), 1)), 1

In [13]:
class SourceSkipDN(nn.Module):
    """DenseNet-style MLP where the raw input is concatenated into every later layer.

    Layout:
        * ``in_layer``: ``in_dim -> n_nodes``
        * ``n_layers`` hidden layers: ``in_dim + n_nodes -> n_nodes``
        * ``out_layer``: ``in_dim + n_nodes -> out_dim``

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: Number of hidden layers.
        n_nodes: Output width of every layer except the output layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(in_dim + n_nodes, n_nodes) for _ in range(n_layers)]
        )
        self.out_layer = nn.Linear(in_dim + n_nodes, out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs, for which view raises.
        x = x.reshape(x.shape[0], -1)

        hidden = F.relu(self.in_layer(x))
        for layer in self.hidden_layers:
            # Concatenation order is (layer output, raw input).
            hidden = F.relu(layer(torch.cat((hidden, x), 1)))
        return self.out_layer(torch.cat((hidden, x), 1)), 1

In [22]:
class FullSkipDN(nn.Module):
    """DenseNet-style MLP where every layer sees the input and all earlier outputs.

    Layout:
        * ``in_layer``: ``in_dim -> n_nodes``
        * hidden layer ``i`` (``i = 1 .. n_layers``): ``in_dim + i * n_nodes -> n_nodes``
        * ``out_layer``: ``in_dim + (n_layers + 1) * n_nodes -> out_dim``

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: Number of hidden layers.
        n_nodes: Output width of every layer except the output layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        # The i-th hidden layer consumes the raw input plus the i outputs
        # produced before it, so its input width grows by n_nodes per layer.
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(in_dim + (i * n_nodes), n_nodes) for i in range(1, n_layers + 1)]
        )
        self.out_layer = nn.Linear(in_dim + ((n_layers + 1) * n_nodes), out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs, for which view raises.
        x = x.reshape(x.shape[0], -1)

        # Raw input first, then every layer output in the order it was produced.
        features = [x, F.relu(self.in_layer(x))]
        for layer in self.hidden_layers:
            features.append(F.relu(layer(torch.cat(features, 1))))
        return self.out_layer(torch.cat(features, 1)), 1

In [5]:
class OneSkipRN(nn.Module):
    """ResNet-style MLP where each layer's input is the sum of the two latest outputs.

    Layout:
        * ``in_layer``: ``in_dim -> n_nodes``
        * ``next_layer``: ``n_nodes -> n_nodes`` (fed by ``in_layer`` only)
        * ``n_layers - 1`` hidden layers: ``n_nodes -> n_nodes``
        * ``out_layer``: ``n_nodes -> out_dim``

    ``next_layer`` is always created, so ``n_layers < 1`` behaves like
    ``n_layers == 1``.

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: ``next_layer`` plus ``n_layers - 1`` entries in ``hidden_layers``.
        n_nodes: Output width of every layer except the output layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        self.next_layer = nn.Linear(n_nodes, n_nodes)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_nodes, n_nodes) for _ in range(n_layers - 1)]
        )
        self.out_layer = nn.Linear(n_nodes, out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs, for which view raises.
        x = x.reshape(x.shape[0], -1)

        older = F.relu(self.in_layer(x))
        newer = F.relu(self.next_layer(older))

        for layer in self.hidden_layers:
            # The right-hand side is evaluated before either name is rebound,
            # so the sum uses the two outputs produced before this layer.
            older, newer = newer, F.relu(layer(older + newer))

        return self.out_layer(older + newer), 1

In [7]:
class SourceSkipRN(nn.Module):
    """ResNet-style MLP where the first layer's output is added before every later layer.

    Layout:
        * ``in_layer``: ``in_dim -> n_nodes`` (its activation is the skip source)
        * ``next_layer``: ``n_nodes -> n_nodes`` (fed by the skip source only)
        * ``n_layers - 1`` hidden layers: ``n_nodes -> n_nodes``
        * ``out_layer``: ``n_nodes -> out_dim``

    ``next_layer`` is always created, so ``n_layers < 1`` behaves like
    ``n_layers == 1``.

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: ``next_layer`` plus ``n_layers - 1`` entries in ``hidden_layers``.
        n_nodes: Output width of every layer except the output layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        self.next_layer = nn.Linear(n_nodes, n_nodes)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_nodes, n_nodes) for _ in range(n_layers - 1)]
        )
        self.out_layer = nn.Linear(n_nodes, out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs, for which view raises.
        x = x.reshape(x.shape[0], -1)

        source = F.relu(self.in_layer(x))
        hidden = F.relu(self.next_layer(source))

        for layer in self.hidden_layers:
            hidden = F.relu(layer(hidden + source))

        return self.out_layer(source + hidden), 1

In [19]:
class FullSkipRN(nn.Module):
    """ResNet-style MLP where each layer's input is the sum of all earlier outputs.

    Layout:
        * ``in_layer``: ``in_dim -> n_nodes``
        * ``next_layer``: ``n_nodes -> n_nodes`` (fed by ``in_layer`` only)
        * ``n_layers - 1`` hidden layers: ``n_nodes -> n_nodes``
        * ``out_layer``: ``n_nodes -> out_dim``

    The raw input is not part of the sum; only layer activations are.
    ``next_layer`` is always created, so ``n_layers < 1`` behaves like
    ``n_layers == 1``.

    Args:
        in_dim: Number of features per sample after flattening.
        out_dim: Number of output features.
        n_layers: ``next_layer`` plus ``n_layers - 1`` entries in ``hidden_layers``.
        n_nodes: Output width of every layer except the output layer.
    """

    def __init__(self, in_dim, out_dim, n_layers=3, n_nodes=100):
        super().__init__()

        self.in_layer = nn.Linear(in_dim, n_nodes)
        self.next_layer = nn.Linear(n_nodes, n_nodes)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_nodes, n_nodes) for _ in range(n_layers - 1)]
        )
        self.out_layer = nn.Linear(n_nodes, out_dim)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``in_dim`` features.

        Returns:
            Tuple ``(y_pred, 1)``. ``y_pred`` has shape ``(batch, out_dim)``.
            The second element is always the constant ``1``; its meaning is
            not defined in this file (presumably a placeholder expected by
            the caller -- TODO confirm).
        """
        # reshape instead of view: identical result for contiguous tensors, but
        # also accepts non-contiguous inputs, for which view raises.
        x = x.reshape(x.shape[0], -1)

        first = F.relu(self.in_layer(x))
        outputs = [first, F.relu(self.next_layer(first))]

        for layer in self.hidden_layers:
            # Built-in sum adds the tensors element-wise, left to right.
            outputs.append(F.relu(layer(sum(outputs))))

        return self.out_layer(sum(outputs)), 1

In [15]:
# Two random (1, 3) tensors used in the following cells to compare the two
# ways a list of tensors can be merged.
a, b = torch.randn(1, 3), torch.randn(1, 3)
a, b

(tensor([[-0.5225,  0.1325, -1.3603]]), tensor([[0.3897, 0.4429, 1.6402]]))

In [18]:
# ResNet-style merge: the built-in sum adds the tensors in the list
# element-wise, so the result keeps the (1, 3) shape of each operand.
arr = [a,b]
sum(arr)

tensor([[-0.1328,  0.5753,  0.2799]])

In [20]:
# DenseNet-style merge: concatenating the same list along dim 1 joins the
# two (1, 3) tensors into a single (1, 6) tensor.
torch.cat(arr,1)

tensor([[-0.5225,  0.1325, -1.3603,  0.3897,  0.4429,  1.6402]])