In [40]:
import torch
import torch.nn.functional as F
from torch.nn import init
import math
import numpy as np
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot

from src.models.baselines.lstm import TorchLSTM
from src.models.baselines.gru import TorchGRU
from src.models.sequence.base import SequenceModule
# from src.models.sequence.modules.s4block import S4Block
# from src.dataloaders.audio import mu_law_decode, linear_decode

## Data

In [24]:
# Preprocessing pipeline: currently only converts images to tensors
# (normalization or data augmentation steps would be added to this list).
my_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Fetch both MNIST splits into the local 'data' directory, sharing one transform
mnist_train = torchvision.datasets.MNIST(
    'data', train=True, download=True, transform=my_transform,
)
mnist_test = torchvision.datasets.MNIST(
    'data', train=False, download=True, transform=my_transform,
)

In [34]:
# Print the dataset summary (size, root, split, transform) as a sanity check
print(mnist_train)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )


In [25]:
# Wrap the training set in a DataLoader so one mini-batch can be drawn for inspection
train_loader_example = torch.utils.data.DataLoader(mnist_train, batch_size=64)

# Pull only the first mini-batch out of the loader
example_batches = iter(train_loader_example)
images, labels = next(example_batches)
print('1. original images shape:', images.shape)

# Collapse the singleton channel axis: (B, 1, 28, 28) -> (B, 28, 28)
images = images.reshape(-1, 28, 28)
print('2. reshaped images shape:', images.shape, '\n')

1. original images shape: torch.Size([64, 1, 28, 28])
2. reshaped images shape: torch.Size([64, 28, 28]) 



## Model architecture

In [19]:
class StackedRNN(SequenceModule):
    """
    Stack of single-layer RNNs (GRU or LSTM) with optional skip connections.

    Without skip connections, each RNN consumes the previous RNN's output and
    the model output is the output of the last RNN.

    With skip connections, every RNN after the first also sees the original
    input, and the (linearly projected) outputs of all RNNs are summed:
        Input (d_model) -> RNN_1 (d_hidden) -> Linear (d_hidden, d_hidden) -> Output
        [RNN_1, Input] (d_hidden + d_model) -> RNN_2 (d_hidden) -> Linear (d_hidden, d_hidden) -> += Output
        [RNN_2, Input] (d_hidden + d_model) -> RNN_3 (d_hidden) -> Linear (d_hidden, d_hidden) -> += Output
    ...

    Args:
        d_model: Feature size of the input (and of the output when
            ``output_linear`` is True).
        d_hidden: Hidden size of every RNN layer.
        n_layers: Number of stacked single-layer RNNs.
        learn_h0: Forwarded unchanged to each RNN constructor.
        rnn_type: ``'gru'`` or ``'lstm'``.
        skip_connections: Enable the skip-connection topology described above.
        weight_norm: If True, wrap the Linear layers with
            ``torch.nn.utils.weight_norm``.
        dropout: Dropout probability applied to the output of every RNN except
            the last one.
        output_linear: If True, append a final Linear(d_hidden, d_model).

    Raises:
        ValueError: If ``rnn_type`` is neither ``'lstm'`` nor ``'gru'``.
    """

    @property
    def d_output(self):
        """Feature size of the tensor returned by ``forward``."""
        return self.d_model if self.output_linear else self.d_hidden

    def __init__(
        self,
        d_model,
        d_hidden,
        n_layers,
        learn_h0=False,
        rnn_type='gru',
        skip_connections=False,
        weight_norm=False,
        dropout=0.0,
        output_linear=False,
    ):
        super().__init__()

        self.d_model = d_model
        self.d_hidden = d_hidden
        self.n_layers = n_layers
        self.learn_h0 = learn_h0
        self.skip_connections = skip_connections
        # Either the real weight-norm wrapper or a no-op, so call sites need no branching
        self.weight_norm = torch.nn.utils.weight_norm if weight_norm else lambda x: x

        self.output_linear = output_linear
        self.rnn_layers = torch.nn.ModuleList()
        self.lin_layers = torch.nn.ModuleList()
        self.dropout_layers = torch.nn.ModuleList()
        self.rnn_type = rnn_type

        if rnn_type == 'lstm':
            RNN = TorchLSTM
        elif rnn_type == 'gru':
            RNN = TorchGRU
        else:
            raise ValueError('rnn_type must be lstm or gru')

        for i in range(n_layers):
            # Build the per-layer modules one layer at a time.
            # The first RNN sees the raw input; later RNNs see the previous RNN's
            # output, concatenated with the raw input when skip connections are on.
            if i == 0:
                rnn_input_dim = d_model
            elif skip_connections:
                rnn_input_dim = d_model + d_hidden
            else:
                rnn_input_dim = d_hidden
            self.rnn_layers.append(
                RNN(d_model=rnn_input_dim, d_hidden=d_hidden, n_layers=1, learn_h0=learn_h0),
            )

            if skip_connections:
                self.lin_layers.append(self.weight_norm(torch.nn.Linear(d_hidden, d_hidden)))
            else:
                self.lin_layers.append(torch.nn.Identity())

            # If dropout, only apply to the outputs of RNNs that are not the last one (like torch's LSTM)
            if dropout > 0.0 and i < n_layers - 1:
                self.dropout_layers.append(torch.nn.Dropout(dropout))
            else:
                self.dropout_layers.append(torch.nn.Identity())

        if output_linear:
            self.output_layer = self.weight_norm(torch.nn.Linear(d_hidden, d_model))
        else:
            self.output_layer = torch.nn.Identity()

        # Apply weight norm to every Linear layer found inside the RNN layers.
        # weight_norm modifies the module in place and returns that same module,
        # so there is nothing to re-assign. Re-assigning via setattr with the
        # (possibly dotted) names from named_modules() would register a bogus
        # child and mutate the module dict while it is being iterated.
        for rnn in self.rnn_layers:
            for module in list(rnn.modules()):
                if isinstance(module, torch.nn.Linear):
                    self.weight_norm(module)

        # Use orthogonal initialization for W_hn if using GRU (weight_hh_l[0]).
        # Assumes weight_hh_l0 stacks (W_hr | W_hz | W_hn) along dim 0 as in
        # torch.nn.GRU, so the last d_hidden rows are W_hn -- TODO confirm for TorchGRU.
        if rnn_type == 'gru':
            for rnn in self.rnn_layers:
                torch.nn.init.orthogonal_(rnn.weight_hh_l0[2 * d_hidden:].data)

    def default_state(self, *batch_shape, device=None):
        """Create initial state for a batch of inputs.

        Returns a list with one entry per RNN layer, each produced by that
        layer's own ``default_state``.
        """
        return [
            rnn.default_state(*batch_shape, device=device)
            for rnn in self.rnn_layers
        ]

    def forward(self, inputs, *args, state=None, **kwargs):
        """Run the stacked RNNs over ``inputs``.

        Args:
            inputs: Input tensor whose last dimension is ``d_model``.
            *args: Ignored.
            state: Optional sequence with one previous state per RNN layer
                (e.g. the list returned by ``default_state`` or by a previous
                call). ``None`` lets every layer use its own default.
            **kwargs: Ignored.

        Returns:
            Tuple ``(out, next_states)``: ``out`` is the model output (the sum
            of the projected RNN outputs with skip connections, otherwise the
            last RNN's output, passed through ``output_layer``), and
            ``next_states`` is a list with the state returned by each RNN.

        Raises:
            ValueError: If ``state`` is given but does not hold exactly one
                entry per RNN layer.
        """
        n_rnns = len(self.rnn_layers)
        if state is None:
            prev_states = [None] * n_rnns
        else:
            prev_states = state
            # zip() below would silently skip layers on a length mismatch
            if len(prev_states) != n_rnns:
                raise ValueError(
                    f'state must contain one entry per RNN layer '
                    f'(expected {n_rnns}, got {len(prev_states)})'
                )

        outputs = inputs
        next_states = []
        out = 0.
        for rnn, prev_state, lin, dropout in zip(self.rnn_layers, prev_states, self.lin_layers, self.dropout_layers):
            # Run the RNN on the current inputs.
            #   prev_state:  state this layer starts from (None -> layer default)
            #   outputs:     this layer's output tensor
            #   layer_state: state returned by this layer, kept for the caller
            outputs, layer_state = rnn(outputs, prev_state)
            next_states.append(layer_state)

            outputs = dropout(outputs)
            z = lin(outputs)
            if self.skip_connections:
                # If skip connections, add the outputs of all the RNNs to the outputs
                out += z
                # Feed in the outputs of the previous RNN, and the original inputs to the next RNN
                outputs = torch.cat([outputs, inputs], dim=-1)
            else:
                out = z
                outputs = z

        out = self.output_layer(out)

        return out, next_states

In [54]:
# Scratch check: a list holding one None placeholder per layer, for 3 layers
prev_states = [None for _ in range(3)]
prev_states

[None, None, None]

In [47]:
# Create model instance
StackedRNN_example = StackedRNN(d_model=28, d_hidden=100, n_layers=6)
print(StackedRNN_example)

# Forward pass on the example batch. StackedRNN.forward returns a tuple
# (outputs, per-layer states), so `out` holds that tuple rather than log
# predictions. No extra keywords are passed: forward() ignores anything it
# receives in **kwargs.
out = StackedRNN_example(images)

StackedRNN(
  (rnn_layers): ModuleList(
    (0): TorchGRU(28, 100, batch_first=True)
    (1-5): 5 x TorchGRU(100, 100, batch_first=True)
  )
  (lin_layers): ModuleList(
    (0-5): 6 x Identity()
  )
  (dropout_layers): ModuleList(
    (0-5): 6 x Identity()
  )
  (output_layer): Identity()
)


In [48]:
# Count the trainable parameters: collect the element count of every
# parameter tensor that requires gradients, then add them up.
trainable_sizes = [
    weights.numel()
    for weights in StackedRNN_example.parameters()
    if weights.requires_grad
]
pytorch_total_params = sum(trainable_sizes)
print(pytorch_total_params)

342000


## Training 