In [1]:
import sys
sys.path.append('/home/siqiouyang/work/projects/state-spaces/')

In [18]:
from hydra import compose, initialize
from omegaconf import OmegaConf

import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

from src.models.sequence.model import SequenceModel
from src.tasks.encoders import PositionalEncoder

# Generate Dataset

In [3]:
class SequenceLabelDataset(Dataset):
    def __init__(self, inputs, labels):
        super(SequenceLabelDataset, self).__init__()
        assert len(inputs) == len(labels)

        self.inputs = inputs
        self.labels = labels

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

    def __len__(self):
        return len(self.inputs)

In [4]:
LENGTH = 10
DIM = 4
N_TRAIN = 200
N_VAL = 20
N_TEST = 20
BATCH_SIZE = 20

In [5]:
def get_dataset(n_sample):
    inputs = np.random.random((n_sample, LENGTH // 2, DIM))
    inputs = np.concatenate([inputs, inputs], axis=1).tolist()
    [np.random.shuffle(x) for x in inputs]
    inputs = torch.tensor(inputs)

    labels = []
    for x in inputs:
        label = []
        for i in range(len(x)):
            flag = True
            for j in range(i):
                if x[i].equal(x[j]):
                    flag = False
            label.append(0 if flag else 1)
        labels.append(label)
    labels = torch.tensor(labels)

    dataset = SequenceLabelDataset(inputs, labels)
    return dataset

In [6]:
train_ds = get_dataset(N_TRAIN)
val_ds = get_dataset(N_VAL)
test_ds = get_dataset(N_TEST)

In [7]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Build Model

In [8]:
class Classifier(nn.Module):
    def __init__(self, model):
        super(Classifier, self).__init__()
        self.model = model
        self.proj = nn.Linear(DIM, 2)

    def forward(self, x):
        x = self.model(x)
        if type(x) is tuple:
            x = x[0]
        return self.proj(x)

In [9]:
initialize(version_base=None, config_path='../configs/model/')

hydra.initialize()

In [10]:
cfg = compose(config_name='s4_dev')
OmegaConf.set_struct(cfg, False)
cfg.pop('_name_'), cfg.pop('encoder'), cfg.pop('decoder')
s4 = SequenceModel(**cfg)

CUDA extension for Cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%


In [11]:
class TransformerEncoder(nn.Module):
    def __init__(self, d_model, d_hidden, n_head, n_layer, batch_first=True, norm_first=True):
        super(TransformerEncoder, self).__init__()

        self.position_emb = PositionalEncoder(d_model=d_model, dropout=0., max_len=LENGTH)
        layer = nn.TransformerEncoderLayer(d_model, n_head, d_hidden, dropout=0., batch_first=batch_first, norm_first=norm_first)
        layer_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(layer, n_layer, norm=layer_norm)

    def forward(self, x):
        x = self.position_emb(x)
        return self.encoder(x)

In [12]:
transformer = TransformerEncoder(DIM, DIM * 4, 1, 1)

In [13]:
clas_s4 = Classifier(s4)
clas_xfm = Classifier(transformer)

In [14]:
model = clas_s4
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [16]:
running_loss = 0.
running_acc = 0.
n_epoch = 10
for _ in range(n_epoch):
    iterator = tqdm(train_dataloader)
    for inputs, labels in iterator:
        model.zero_grad()
        preds = model(inputs)
        loss = loss_fn(preds.transpose(1, 2), labels)
        loss.backward()

        print('\n', '=' * 10)
        for name, param in model.named_parameters():
            print(name, param.grad)
        print('\n', '=' * 10)

        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss = 0.1 * loss.item() + 0.9 * running_loss

        acc = (preds.argmax(dim=-1) == labels).float().mean()
        running_acc = 0.1 * acc.item() + 0.9 * running_acc

        # iterator.set_description('running loss {:.2f} running acc {:.2f}'.format(running_loss, running_acc))

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[  1.9214,   7.6651],
          [ -5.8614,   0.0721]],

         [[-10.6703,  -8.0080],
          [  4.7503,  -4.3997]],

         [[  3.1856,  -4.0136],
          [  0.8359,   4.6546]],

         [[ -0.4345,   0.8220],
          [  1.2137,  -0.2218]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([-9.5746, 17.7018,  2.0300,  0.4894])
model.layers.0.layer.kernel.kernel.B tensor([[[[ 8.6189,  4.0891],
          [ 0.5126, -4.1888]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[-15.1971,  -0.8256],
          [ 11.1049,   5.5160]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[2.2841, 5.1299]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[-2.3956, -3.5007]])
model.layers.0.layer.output_linear.0.weight tensor([[-0.0037,  0.0259,  0.0108,  0.0061],
        [ 0.0155, -0.0651, -0.0258, -0.0702],
        [-0.0142, -0.0128, -0.0141,  0.0070],
        [ 0.0038,  0.0352,  0.0203,  0.0400],
        [ 0.0027, -0.0055,  0.0006

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

  0%|          | 0/10 [00:00<?, ?it/s]


model.layers.0.layer.kernel.kernel.C tensor([[[[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]],

         [[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.log_dt tensor([nan, nan, nan, nan])
model.layers.0.layer.kernel.kernel.B tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.P tensor([[[[nan, nan],
          [nan, nan]]]])
model.layers.0.layer.kernel.kernel.inv_w_real tensor([[nan, nan]])
model.layers.0.layer.kernel.kernel.w_imag tensor([[nan, nan]])
model.layers.0.layer.output_linear.0.weight tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])
model.layers.0.layer.output_linear.0.bias tensor([nan, nan, nan, nan, nan, nan, nan, nan])
model.layers.0.norm.norm.weight 

In [74]:
preds = model(inputs)
loss = loss_fn(preds.transpose(1, 2), labels)

In [75]:
print(loss)

tensor(0.7652, grad_fn=<NllLoss2DBackward0>)
