## Import the necessary packages

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

from dataclasses import dataclass

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

from tqdm.auto import tqdm

from s4 import S4Block as S4

  from .autonotebook import tqdm as notebook_tqdm
CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.


## Set the hyperparameters

In [3]:
@dataclass
class HyperparametersConfig:
    seed: int = 0
    ## Data
    batch_size: int = 128
    num_workers: int = 4
    ## Model
    d_model: int = 128
    n_layers: int = 4
    dropout: float = 0.1
    prenorm: bool = True
    ## Optimization
    lr: float = 0.01
    weight_decay: float = 0.01
    epochs: int = 100

config = HyperparametersConfig()

## Prepare the dataloaders

In [4]:
def split_train_val(train, val_split):
    train_len = int(len(train) * (1.0-val_split))
    train, val = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(42),
    )
    return train, val

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(1, 784).t())
])
transform_train = transform_test = transform

trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_train)
trainset, _ = split_train_val(trainset, val_split=0.1)

valset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_test)
_, valset = split_train_val(valset, val_split=0.1)

testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform_test)

d_input = 1
d_output = 10

trainloader = torch.utils.data.DataLoader(
trainset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
valloader = torch.utils.data.DataLoader(
valset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)
testloader = torch.utils.data.DataLoader(
testset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)


## Prepare the models

In [6]:
rnn = nn.RNN(input_size=d_input, hidden_size=128, num_layers=1, batch_first=True)
lstm = nn.LSTM(input_size=d_input, hidden_size=128, num_layers=1, batch_first=True)

class S4Model(nn.Module):

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
    ):
        super().__init__()

        self.prenorm = prenorm

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB)
        self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S4(d_model, dropout=dropout, transposed=True, lr=min(0.001, config.lr))
            )
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(nn.Dropout1d(dropout))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)

        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L)

            z = x
            if self.prenorm:
                # Prenorm
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block: we ignore the state input and output
            z, _ = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x


s4 = S4Model(
    d_input=d_input,
    d_output=d_output,
    d_model=config.d_model,
    n_layers=config.n_layers,
    dropout=config.dropout,
    prenorm=config.prenorm,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
rnn = rnn.to(device)
lstm = lstm.to(device)
s4 = s4.to(device)

rnn_optimizer = optim.AdamW(rnn.parameters(), lr=config.lr, weight_decay=config.weight_decay)
lstm_optimizer = optim.AdamW(lstm.parameters(), lr=config.lr, weight_decay=config.weight_decay)
s4_optimizer = optim.AdamW(s4.parameters(), lr=config.lr, weight_decay=config.weight_decay)

## Training

In [7]:
def train(model, optimizer, criterion, trainloader):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )
    
    return train_loss/(batch_idx+1), 100.*correct/total

criterion = nn.CrossEntropyLoss()
loss, acc = train(s4, s4_optimizer, criterion, trainloader)

Batch Idx: (304/422) | Loss: 0.494 | Acc: 83.120% (32450/39040): : 305it [00:20, 15.66it/s]

## Compare RNN, LSTM, and S4 training curves. 