In [130]:
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [131]:
from IPython.display import display, HTML, Video
display(HTML("<style>.container { width:70% !important; }</style>"))

In [132]:
import matplotlib 
import matplotlib_inline
import matplotlib.pyplot as plt

import torch

import tqdm.autonotebook as tqdm

from model.s4 import S4
from model.s4_model import S4Model
from datasets import SequentialCIFAR10

In [133]:
%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

In [154]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
    
print(device)

cpu


In [172]:
batch_size, num_workers = 50, 0

In [173]:
ds_test = SequentialCIFAR10('/Users/nakhodnov/data/datasets/', train=False, download=False)
ds_train = SequentialCIFAR10('/Users/nakhodnov/data/datasets/', train=True, download=False)

dl_test = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=False)
dl_train = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)

In [174]:
images, labels = next(iter(dl_train))
images.shape

torch.Size([50, 1024, 3])

In [175]:
loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

model = S4Model(
    d_input=3,
    d_output=len(ds_train.data.classes),
    d_model=512,
    n_layers=6,
    dropout=0.1,
    prenorm=False,
    block_class=S4,
    block_kwargs={
        'bidirectional': True, 'postact': 'glu', 'tie_dropout': True,
#         'mode': 'diag', 'measure': 'diag-lin', 'disc': 'zoh', 'real_type': 'exp', 
        'n_ssm': 2
    },
    dropout_fn=torch.nn.Dropout1d
)

print(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.05)

S4Model(
  (encoder): Linear(in_features=3, out_features=512, bias=True)
  (s4_layers): ModuleList(
    (0): S4(
      (kernel): SSKernel(
        (kernel): SSKernelNPLR()
      )
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
    (1): S4(
      (kernel): SSKernel(
        (kernel): SSKernelNPLR()
      )
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
    (2): S4(
      (kernel): SSKernel(
        (kernel): SSKernelNPLR()
      )
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
    (3): S4(
      (kernel): S

In [184]:
def train(model, optimizer, loss_fn, dl):
    model.train()
    
    n_objects, total_loss, accuracy = 0, 0.0, 0
    for images, labels in tqdm.tqdm(dl_train, total=len(dl), leave=False):
        optimizer.zero_grad()

        images = images.to(device=device)
        labels = labels.to(device=device, dtype=torch.long)
        y = model(images)
        predictions = torch.argmax(y, dim=1)

        loss = loss_fn(y, labels)
        loss.backward()
        optimizer.step()
        n_objects += predictions.shape[0]
        total_loss += loss.item() * predictions.shape[0]
        accuracy += torch.sum(torch.eq(predictions, labels)).item()
        
        break
        
    return total_loss / n_objects, accuracy / n_objects
    
def test(model, loss_fn, dl):
    model.eval()
    with torch.no_grad():
        n_objects, total_loss, accuracy = 0, 0.0, 0
        for images, labels in tqdm.tqdm(dl_train, total=len(dl), leave=False):
            images = images.to(device=device)
            labels = labels.to(device=device, dtype=torch.long)
            y = model(images)
            predictions = torch.argmax(y, dim=1)

            loss = loss_fn(y, labels)

            n_objects += predictions.shape[0]
            total_loss += loss.item() * predictions.shape[0]
            accuracy += torch.sum(torch.eq(predictions, labels)).item()

            break
            
    return total_loss / n_objects, accuracy / n_objects

In [185]:
max_epochs = 1

all_losses_test, all_accuracies_test = [], []
all_losses_train, all_accuracies_train = [], []
for epoch in tqdm.tqdm(range(max_epochs), total=max_epochs):
    loss_train, accuracy_train = train(model, optimizer, loss_fn, dl_train)
    
    loss_train, accuracy_train = test(model, loss_fn, dl_train)
    loss_test, accuracy_test = test(model, loss_fn, dl_test)
    
    all_losses_test.append(loss_test)
    all_losses_train.append(loss_train)
    all_accuracies_test.append(accuracy_test)
    all_accuracies_train.append(accuracy_train)
    
    print('Loss: {0:.3f}/{1:.3f}. Accuracy: {2:.3f}/{3:.3f}'.format(
        all_losses_train[-1], all_losses_test[-1], all_accuracies_train[-1], all_accuracies_test[-1]
    ))

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 6.782/6.513. Accuracy: 0.060/0.100
