## Import the necessary packages

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

from dataclasses import dataclass

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

# from tqdm.auto import tqdm
from tqdm.notebook import tqdm

from s4 import S4Block as S4

  from .autonotebook import tqdm as notebook_tqdm
CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.


## Set the hyperparameters

In [3]:
@dataclass
class HyperparametersConfig:
    seed: int = 0
    ## Data
    batch_size: int = 128
    num_workers: int = 4
    ## Model
    d_model: int = 128
    n_layers: int = 4
    dropout: float = 0.1
    prenorm: bool = True
    ## Optimization
    lr: float = 0.01
    weight_decay: float = 0.01
    epochs: int = 10

config = HyperparametersConfig()

## Prepare the dataloaders

In [4]:
def split_train_val(train, val_split):
    train_len = int(len(train) * (1.0-val_split))
    train, val = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(42),
    )
    return train, val

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(1, 784).t())
])
transform_train = transform_test = transform

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
trainset, _ = split_train_val(trainset, val_split=0.1)

valset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_test)
_, valset = split_train_val(valset, val_split=0.1)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

d_input = 1
d_output = 10

trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
valloader = torch.utils.data.DataLoader(valset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

## Prepare the models

In [6]:
class RNNModel(nn.Module):
    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=128,
        n_layers=1,
        dropout=0.1,
    ):
        super().__init__()

        self.rnn = nn.RNN(d_input, d_model, n_layers, batch_first=True, dropout=dropout)
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        x, _ = self.rnn(x)  # (B, L, d_input) -> (B, L, d_model)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x

class LSTMModel(nn.Module):
    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=128,
        n_layers=1,
        dropout=0.1,
    ):
        super().__init__()

        self.lstm = nn.LSTM(d_input, d_model, n_layers, batch_first=True, dropout=dropout)
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        x, _ = self.lstm(x)  # (B, L, d_input) -> (B, L, d_model)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x


class S4Model(nn.Module):
    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=128,
        n_layers=1,
        dropout=0.1,
    ):
        super().__init__()

        self.s4_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(S4(d_model, dropout=dropout, transposed=False, lr=min(0.001, config.lr)))
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        for layer in self.s4_layers:
            x, _ = self.s4(x)  # (B, L, d_input) -> (B, L, d_model)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x



In [7]:

rnn = RNNModel(d_input=d_input,d_output=d_output,d_model=config.d_model,n_layers=config.n_layers,dropout=config.dropout,)
lstm = LSTMModel(d_input=d_input,d_output=d_output,d_model=config.d_model,n_layers=config.n_layers,dropout=config.dropout,)
s4 = S4Model(d_input=d_input,d_output=d_output,d_model=config.d_model,n_layers=config.n_layers,dropout=config.dropout)

# print the number of parameters for each model
print("RNN  model has", sum(p.numel() for p in rnn.parameters()), "parameters")
print("LSTM model has", sum(p.numel() for p in lstm.parameters()), "parameters")
print("S4   model has", sum(p.numel() for p in s4.parameters()), "parameters")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
rnn = rnn.to(device)
lstm = lstm.to(device)
s4 = s4.to(device)

rnn_optimizer = optim.AdamW(rnn.parameters(), lr=config.lr, weight_decay=config.weight_decay)
lstm_optimizer = optim.AdamW(lstm.parameters(), lr=config.lr, weight_decay=config.weight_decay)
s4_optimizer = optim.AdamW(s4.parameters(), lr=config.lr, weight_decay=config.weight_decay)

RNN  model has 117130 parameters
LSTM model has 464650 parameters
S4   model has 265482 parameters


## Training

In [8]:
def train(model, optimizer, criterion, trainloader):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )
    
    return model, train_loss/(batch_idx+1), 100.*correct/total

def eval(model, dataloader):
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    return eval_loss/(batch_idx+1), 100.*correct/total


criterion = nn.CrossEntropyLoss()
rnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(rnn_optimizer, config.epochs)
pbar = tqdm(range(0, config.epochs))
rnn_train_acc = []
for epoch in pbar:
    if epoch == 0:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    rnn, train_loss, train_acc = train(rnn, rnn_optimizer, criterion, trainloader)
    val_loss, val_acc = eval(rnn, valloader)
    test_loss, test_acc = eval(rnn, testloader)
    rnn_scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")
    rnn_train_acc.append(train_acc)

Batch Idx: (421/422) | Loss: 2.303 | Acc: 11.200% (6048/54000): : 422it [00:09, 46.01it/s]
Batch Idx: (46/47) | Loss: 2.301 | Acc: 11.100% (666/6000): : 47it [00:00, 94.01it/s] 
Batch Idx: (78/79) | Loss: 2.302 | Acc: 11.350% (1135/10000): : 79it [00:00, 101.59it/s]
Batch Idx: (421/422) | Loss: 2.302 | Acc: 11.063% (5974/54000): : 422it [00:09, 42.90it/s]
Batch Idx: (46/47) | Loss: 2.301 | Acc: 11.100% (666/6000): : 47it [00:00, 84.86it/s] 
Batch Idx: (78/79) | Loss: 2.302 | Acc: 11.350% (1135/10000): : 79it [00:00, 106.50it/s]
Batch Idx: (421/422) | Loss: 2.302 | Acc: 11.183% (6039/54000): : 422it [00:09, 44.08it/s]
Batch Idx: (46/47) | Loss: 2.302 | Acc: 11.100% (666/6000): : 47it [00:00, 98.79it/s] 
Batch Idx: (78/79) | Loss: 2.301 | Acc: 11.350% (1135/10000): : 79it [00:00, 101.68it/s]
Batch Idx: (421/422) | Loss: 2.302 | Acc: 11.207% (6052/54000): : 422it [00:09, 45.21it/s]
Batch Idx: (46/47) | Loss: 2.302 | Acc: 11.100% (666/6000): : 47it [00:00, 96.69it/s] 
Batch Idx: (78/79) | 

In [9]:
criterion = nn.CrossEntropyLoss()
lstm_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lstm_optimizer, config.epochs)
pbar = tqdm(range(0, config.epochs))
lstm_train_acc = []
for epoch in pbar:
    if epoch == 0:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    lstm, train_loss, train_acc = train(lstm, lstm_optimizer, criterion, trainloader)
    val_loss, val_acc = eval(lstm, valloader)
    test_loss, test_acc = eval(lstm, testloader)
    lstm_scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")
    lstm_train_acc.append(train_acc)

NameError: name 'lstm_optimizer' is not defined

In [None]:

criterion = nn.CrossEntropyLoss()
s4_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(s4_optimizer, config.epochs)
pbar = tqdm(range(0, config.epochs))
s4_train_acc = []
for epoch in pbar:
    if epoch == 0:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    s4, train_loss, train_acc = train(s4, s4_optimizer, criterion, trainloader)
    val_loss, val_acc = eval(s4, valloader)
    test_loss, test_acc = eval(s4, testloader)
    s4_scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")
    s4_train_acc.append(train_acc)

## Compare RNN, LSTM, and S4 training curves. 

In [None]:

import matplotlib.pyplot as plt

plt.plot(rnn_train_acc, label='RNN')
plt.plot(lstm_train_acc, label='LSTM')
plt.plot(s4_train_acc, label='S4')

plt.xlabel('Epoch')
plt.ylabel('Train Accuracy')