In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import random

import torchvision
import torchvision.transforms as transforms

from src.utils.optim.schedulers import CosineWarmup
from src.models.sequence.rnns.rnn import RNN


import os
import argparse
from tqdm.auto import tqdm


import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myuqinzhou[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Use cuda if present
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device available now:', device)

Device available now: cpu


In [3]:
if device == 'cuda':
    cudnn.benchmark = True

## Defining functions

### Hyperparameters

In [4]:
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# Optimizer
parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
parser.add_argument('--weight_decay', default=0.05, type=float, help='Weight decay')


# Scheduler
# parser.add_argument('--patience', default=10, type=float, help='Patience for learning rate scheduler')
parser.add_argument('--epochs', default=10, type=float, help='Training epochs')


# Dataset
parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10'], type=str, help='Dataset')
parser.add_argument('--grayscale', action='store_true', help='Use grayscale CIFAR10')


# Dataloader
parser.add_argument('--num_workers', default=0, type=int, help='Number of workers to use for dataloader')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')


# Model
parser.add_argument('--n_layers', default=2, type=int, help='Number of layers')
parser.add_argument('--d_model', default=128, type=int, help='Model dimension')
parser.add_argument('--dropout', default=0.1, type=float, help='Dropout')
parser.add_argument('--prenorm', action='store_true', help='Prenorm')
parser.add_argument('--norm', default= 'LN', choices=['LN', 'BN'], help='Norm types')
parser.add_argument('--cell', default= 'rnn', type=str, help='RNN\'s cell')


# General
parser.add_argument('--resume', '-r', action='store_true', help='Resume from checkpoint')

# args = parser.parse_args()
args, unknown = parser.parse_known_args()

In [5]:
print(args)

Namespace(batch_size=64, cell='rnn', d_model=128, dataset='cifar10', dropout=0.1, epochs=10, grayscale=False, lr=0.001, n_layers=2, norm='LN', num_workers=0, prenorm=False, resume=False, weight_decay=0.05)


### Datasets

In [6]:
def split_train_val(train, val_split):
    train_len = int(len(train) * (1.0-val_split))
    train, val = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(42),
    )
    return train, val

In [7]:
if args.dataset == 'cifar10':
    if args.grayscale:
        transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
            transforms.Lambda(lambda x: x.view(1, 1024).t())
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            transforms.Lambda(lambda x: x.view(3, 1024).t())
        ])

    # S4 is trained on sequences with no data augmentation!
    transform_train = transform_test = transform

    trainset = torchvision.datasets.CIFAR10(
        root='./data/cifar/', train=True, download=True, transform=transform_train)
    trainset, _ = split_train_val(trainset, val_split=0.1)

    valset = torchvision.datasets.CIFAR10(
        root='./data/cifar/', train=True, download=True, transform=transform_test)
    _, valset = split_train_val(valset, val_split=0.1)

    testset = torchvision.datasets.CIFAR10(
        root='./data/cifar/', train=False, download=True, transform=transform_test)

    d_input = 3 if not args.grayscale else 1
    d_output = 10

elif args.dataset == 'mnist':

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(1, 784).t())
    ])
    transform_train = transform_test = transform

    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform_train)
    trainset, _ = split_train_val(trainset, val_split=0.1)

    valset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform_test)
    _, valset = split_train_val(valset, val_split=0.1)

    testset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform_test)

    d_input = 1
    d_output = 10
else: raise NotImplementedError

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Dataloaders
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)


In [9]:
# Taking a single batch of the images
images, labels = next(iter(trainloader))
print(images.size(), labels)

torch.Size([64, 1024, 3]) tensor([8, 2, 8, 8, 2, 5, 2, 4, 8, 9, 6, 7, 9, 5, 0, 3, 0, 8, 1, 3, 1, 1, 7, 8,
        2, 8, 6, 2, 2, 0, 5, 8, 7, 1, 4, 0, 4, 1, 0, 3, 8, 6, 9, 8, 5, 8, 4, 2,
        0, 3, 9, 2, 0, 2, 0, 6, 5, 4, 8, 4, 0, 6, 8, 9])


In [10]:
len(trainset)

45000

### Model architecture

In [11]:
# Dropout broke in PyTorch 1.11
if tuple(map(int, torch.__version__.split('.')[:2])) == (1, 11):
    print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
    dropout_fn = nn.Dropout
if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 12):
    dropout_fn = nn.Dropout1d
else:
    dropout_fn = nn.Dropout2d
print(dropout_fn)

<class 'torch.nn.modules.dropout.Dropout1d'>


In [12]:
class RNNbased(nn.Module):

    def __init__(
        self,
        d_input,
        d_output,
        cell='rnn',
        d_model=256,
        n_layers=2,
        dropout=0.2,
        prenorm=True,
    ):
        super().__init__()

        self.prenorm = prenorm

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB) (like embedding layer)
        self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 layers as residual blocks
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.FFNs = nn.ModuleList()

        for _ in range(n_layers):
            self.layers.append(
                RNN(d_input = d_model, d_model = d_model, cell = cell, return_output=True, transposed=False, dropout=0)
            )
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(dropout_fn(dropout))
            self.FFNs.append(nn.Sequential(nn.Linear(d_model, d_model*2), nn.GLU())                     
                                 )

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)
        for layer, norm, dropout, FFN in zip(self.layers, self.norms, self.dropouts, self.FFNs):
            # Each iteration of this loop will map (B, L, d_model) -> (B, L, d_model)
            z = x
            if self.prenorm:
                # Prenorm
                z = norm(z)

            # Apply recurrence: we ignore the state input and output
            # z, _ = layer(z)

            z, _ = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)
            
            # MLP +GLP
            z = FFN(z)

            z = dropout(z)

            # Residual connection
            x = z + x

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x

In [13]:
example = RNNbased(d_input=d_input, d_output=d_output, cell='rnn', d_model=args.d_model, n_layers=args.n_layers, dropout=args.dropout, prenorm=args.prenorm)
print(example)

RNNbased(
  (encoder): Linear(in_features=3, out_features=128, bias=True)
  (layers): ModuleList(
    (0): RNN(
      (cell): RNNCell(
        (W_hx): Linear(in_features=128, out_features=128, bias=False)
        (activate): Tanh()
        (W_hh): Linear(in_features=128, out_features=128, bias=False)
      )
    )
    (1): RNN(
      (cell): RNNCell(
        (W_hx): Linear(in_features=128, out_features=128, bias=False)
        (activate): Tanh()
        (W_hh): Linear(in_features=128, out_features=128, bias=False)
      )
    )
  )
  (norms): ModuleList(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (dropouts): ModuleList(
    (0): Dropout1d(p=0.1, inplace=False)
    (1): Dropout1d(p=0.1, inplace=False)
  )
  (FFNs): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): GLU(dim=-1)
    )
    (1): Sequential(
      (0): Linear(in_features=128, out_f

### Training

In [None]:
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

In [60]:
###############################################################################
# Everything after this point is standard PyTorch training!
###############################################################################

# Training
def train(model, optimizer, criterion):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(trainloader))
    print_every = 50
    
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        wandb.log({"Train/Batch loss": train_loss/(batch_idx+1)})

        if (batch_idx % print_every) == 0:
            print(f"Batch: {batch_idx}, Avg: {train_loss/(batch_idx+1)}")
        
        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )
    return train_loss/(batch_idx+1)




def eval(model, optimizer, criterion, epoch, dataloader, checkpoint=False):
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()


            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    # Save checkpoint.
    if checkpoint:
        acc = 100.*correct/total
        if acc > best_acc:
            state = {
                'model': model.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt.pth')
            best_acc = acc

        return acc

Loss

## Wandb

In [61]:
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [62]:
# Launch 2 simulated experiments
total_runs = 1
for run in range(total_runs):
    wandb.init(
        project="Master thesis", 
        # Model + Cell 
        name=f"RNN_Vanilla_{run}", 
        # Track hyperparameters and run metadata
        config= args)
    

    # Model
    print('==> Building model...')
    model = RNNbased(d_input=d_input, d_output=d_output, cell='rnn', d_model=args.d_model, n_layers=args.n_layers, dropout=args.dropout, prenorm=args.prenorm)
    model = model.to(device)
    
    ### defining optimizer + scheduler
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay= args.weight_decay)
    scheduler = CosineWarmup(optimizer, T_max = args.epochs, eta_min= 1e-7, warmup_step= int(args.epochs * 0.1) + 1) 
    
    
    pbar = tqdm(range(start_epoch, args.epochs))
    for epoch in pbar:
        if epoch == 0:
            pbar.set_description('Epoch: %d' % (epoch))
        else:

            pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))

        print('==> Training...')
        epoch_loss = train(model = model, optimizer = optimizer, criterion = nn.CrossEntropyLoss())
        wandb.log({"Train/Epoch Loss": epoch_loss})

        print('==> Validating...')
        val_acc = eval(model = model, optimizer = optimizer, criterion = nn.CrossEntropyLoss(), epoch = epoch, dataloader = valloader, checkpoint=True)
        wandb.log({"Val/Val acc": val_acc})
        
        scheduler.step()
        wandb.log({"lr": scheduler.get_last_lr()})
        # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")

    
    print('==> Testing...')
    test_acc = eval(model = model, optimizer = optimizer, criterion = nn.CrossEntropyLoss(), epoch = epoch, dataloader = testloader, checkpoint=True)
    wandb.log({"Test/Test acc": test_acc})
    
    # Mark the run as finished
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Epoch Loss,▁
loss,█▆▆▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Epoch Loss,1.95263
loss,1.95263


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016671029850022022, max=1.0…

==> Building model...


Epoch: 0:   0%|          | 0/10 [00:00<?, ?it/s]

==> Training...




Batch: 0, Avg: 2.386662006378174




Batch: 50, Avg: 2.1338492281296673




Batch: 100, Avg: 2.077360971139209




Batch: 150, Avg: 2.047399572978746




Batch: 200, Avg: 2.0294170860034315




Batch: 250, Avg: 2.014967956865926




Batch: 300, Avg: 2.007842195786511




Batch: 350, Avg: 1.995730560389679




Batch: 400, Avg: 1.988370094810638




Batch: 450, Avg: 1.9810454914152227




Batch: 500, Avg: 1.973240368380518




Batch: 550, Avg: 1.967568049413539




Batch: 600, Avg: 1.961034630578687




Batch: 650, Avg: 1.95381942688961




Batch: 700, Avg: 1.9499512577533042


Batch Idx: (703/704) | Loss: 1.950 | Acc: 27.473% (12363/45000): : 704it [21:26,  1.83s/it]


==> Validating...


Batch Idx: (78/79) | Loss: 1.874 | Acc: 31.820% (1591/5000): : 79it [00:52,  1.52it/s]
Epoch: 1 | Val acc: 31.820:  10%|█         | 1/10 [22:18<3:20:47, 1338.59s/it]

Epoch 0 learning rate: [0.001]
==> Training...




Batch: 0, Avg: 1.8622061014175415




Batch: 50, Avg: 1.9005255792655198




Batch: 100, Avg: 1.907343590613639




Batch: 150, Avg: 1.9073582469232824




Batch: 200, Avg: 1.8999568342569455




Batch: 250, Avg: 1.892735735828658




Batch: 300, Avg: 1.8851417398135923




Batch: 350, Avg: 1.8784299331512886




Batch: 400, Avg: 1.8730035363290078




Batch: 450, Avg: 1.8644415124292648




Batch: 500, Avg: 1.860593170700911




Batch: 550, Avg: 1.8556452507115968




Batch: 600, Avg: 1.8482611792258137




Batch: 650, Avg: 1.84307869989568




Batch: 700, Avg: 1.8371489903386071


Batch Idx: (703/704) | Loss: 1.836 | Acc: 32.498% (14624/45000): : 704it [21:21,  1.82s/it]


==> Validating...


Batch Idx: (78/79) | Loss: 1.779 | Acc: 34.600% (1730/5000): : 79it [00:51,  1.53it/s]
Epoch: 2 | Val acc: 34.600:  20%|██        | 2/10 [44:31<2:58:03, 1335.47s/it]

Epoch 1 learning rate: [0.001]
==> Training...




Batch: 0, Avg: 1.8959286212921143




Batch: 50, Avg: 1.758659138399012




Batch: 100, Avg: 1.7471088843770546




Batch: 150, Avg: 1.7409561564590756




Batch: 200, Avg: 1.7457035803676244




Batch: 250, Avg: 1.7365603351972967




Batch: 300, Avg: 1.7300894830710072




Batch: 350, Avg: 1.7309612582551788




Batch: 400, Avg: 1.7280685797593838




Batch: 450, Avg: 1.7276611832980306




Batch: 500, Avg: 1.7246052664435076




Batch: 550, Avg: 1.7206792100155202




Batch: 600, Avg: 1.717090265128061




Batch: 650, Avg: 1.7115850159283241




Batch: 700, Avg: 1.7086538649489638


Batch Idx: (703/704) | Loss: 1.708 | Acc: 38.742% (17434/45000): : 704it [21:20,  1.82s/it]


==> Validating...


Batch Idx: (78/79) | Loss: 1.651 | Acc: 39.840% (1992/5000): : 79it [00:51,  1.53it/s]
Epoch: 3 | Val acc: 39.840:  30%|███       | 3/10 [1:06:44<2:35:37, 1333.98s/it]

Epoch 2 learning rate: [0.0009619435722790177]
==> Training...




Batch: 0, Avg: 1.691342830657959




Batch: 50, Avg: 1.6467516913133509




Batch: 100, Avg: 1.6393487654109993




Batch: 150, Avg: 1.632578011380126




Batch: 200, Avg: 1.6248396077559362




Batch: 250, Avg: 1.6205695392601043




Batch: 300, Avg: 1.6201894291215562




Batch: 350, Avg: 1.6196492832270784




Batch: 400, Avg: 1.6191505770433574




Batch: 450, Avg: 1.619033786251381




Batch: 500, Avg: 1.6157041728615522




Batch: 550, Avg: 1.6097775362364393




Batch: 600, Avg: 1.6070427249949704




Batch: 650, Avg: 1.6015327491335423




Batch: 700, Avg: 1.601384059338699


Batch Idx: (703/704) | Loss: 1.601 | Acc: 43.420% (19539/45000): : 704it [21:30,  1.83s/it]


==> Validating...


Batch Idx: (78/79) | Loss: 1.571 | Acc: 44.020% (2201/5000): : 79it [00:51,  1.52it/s]
Epoch: 4 | Val acc: 44.020:  40%|████      | 4/10 [1:29:06<2:13:44, 1337.40s/it]

Epoch 3 learning rate: [0.0008535680352542143]
==> Training...




Batch: 0, Avg: 1.4398150444030762




Batch: 50, Avg: 1.5498565038045247




Batch: 100, Avg: 1.5427079684663527




Batch: 150, Avg: 1.5454180611679886




Batch: 200, Avg: 1.5409212675853747




Batch: 250, Avg: 1.5416596410758943




Batch: 300, Avg: 1.5340597261226059




Batch: 350, Avg: 1.5340141275329808




Batch: 400, Avg: 1.5355402409584444




Batch: 450, Avg: 1.5387418592584108




Batch: 500, Avg: 1.5373080902232856




Batch: 550, Avg: 1.5376275334297638




Batch: 600, Avg: 1.5371720293000612




Batch: 650, Avg: 1.5362718314069757




Batch: 700, Avg: 1.5362770837996043


Batch Idx: (703/704) | Loss: 1.535 | Acc: 46.033% (20715/45000): : 704it [21:35,  1.84s/it]


==> Validating...


Batch Idx: (78/79) | Loss: 1.515 | Acc: 46.300% (2315/5000): : 79it [00:53,  1.48it/s]
Epoch: 5 | Val acc: 46.300:  50%|█████     | 5/10 [1:51:35<1:51:48, 1341.62s/it]

Epoch 4 learning rate: [0.0006913725820109266]
==> Training...




Batch: 0, Avg: 1.5475780963897705




Batch: 50, Avg: 1.5068494782728308




Batch: 100, Avg: 1.5085607483835504


Batch Idx: (107/704) | Loss: 1.509 | Acc: 46.774% (3233/6912): : 108it [03:22,  1.87s/it]
Epoch: 5 | Val acc: 46.300:  50%|█████     | 5/10 [1:54:57<1:54:57, 1379.57s/it]


KeyboardInterrupt: 

In [63]:
scheduler.get_last_lr()


[0.0006913725820109266]

## Hyperparameters