### S4

In [None]:
import torch
from torch import nn
from torch import view_as_real
from torch.fft import ifft, irfft, rfft
from torch.nn import functional as F


def log_step_init(tensor, dt_min=0.001, dt_max=0.1):
    """Affinely rescale ``tensor`` into the range ``[log(dt_min), log(dt_max)]``.

    An input value of 0 maps to ``log(dt_min)`` and an input value of 1 maps
    to ``log(dt_max)``; other values are interpolated (or extrapolated)
    linearly in log space.

    Args:
        tensor: Values to rescale.
        dt_min: Step size whose logarithm is the image of 0.
        dt_max: Step size whose logarithm is the image of 1.

    Returns:
        ``tensor * (log(dt_max) - log(dt_min)) + log(dt_min)``.
    """
    log_lo = torch.log(torch.tensor(dt_min))
    log_hi = torch.log(torch.tensor(dt_max))
    return tensor * (log_hi - log_lo) + log_lo

def hippo(N):
    """Build the N x N lower-triangular HiPPO matrix (positive sign).

    With 1-based indices ``i, j`` in ``1..N`` the entries are:

    * ``sqrt(2i + 1) * sqrt(2j + 1)`` below the diagonal (``i > j``),
    * ``(2i + 1) - i = i + 1`` on the diagonal,
    * ``0`` above the diagonal.

    Args:
        N: State dimension.

    Returns:
        A real ``(N, N)`` float tensor.
    """
    idx = torch.arange(1, N + 1, dtype=torch.float)
    P = torch.sqrt(1 + 2 * idx)
    return torch.tril(torch.outer(P, P)) - torch.diag(idx)


def hippo_dplr(N):
    """Eigendecompose the negated HiPPO matrix after a rank-1 correction.

    The matrix that is decomposed is ``Ap = -hippo(N) + p p^T`` with
    ``p_i = sqrt(i + 1/2)`` (``i = 1..N``).  With this scaling ``p p^T`` is
    exactly half of the symmetric outer product ``sqrt(2i+1) * sqrt(2j+1)``
    used in ``hippo``, so ``Ap`` equals ``-1/2 * I`` plus a skew-symmetric
    matrix.  Such a matrix is normal, which is what makes it valid for
    callers to use ``V.conj().T`` as the inverse of ``V``.

    Args:
        N: State dimension.

    Returns:
        A tuple ``(lambda_, p, V)``: the complex eigenvalues of ``Ap``
        (shape ``(N,)``), the rank-1 vector ``p`` as complex64 in the
        original basis (shape ``(N,)``), and the eigenvector matrix
        (shape ``(N, N)``).
    """
    A = -1 * hippo(N)  # negative sign: the state matrix is -HiPPO

    # sqrt(i + 1/2) == sqrt(2i + 1) / sqrt(2); the previous 0.5 * sqrt(2i + 1)
    # only added a quarter of the symmetric part and left Ap non-normal.
    p = torch.sqrt(torch.arange(1, N + 1, dtype=torch.float32) + 0.5)
    p = p.to(torch.complex64)

    Ap = A.to(torch.complex64) + torch.outer(p, p)

    # eigenvalues, eigenvectors
    lambda_, V = torch.linalg.eig(Ap)

    return lambda_, p, V


def p_lambda(n):
    """Return the rank-1 vector expressed in the eigenbasis, plus the eigenvalues.

    Calls ``hippo_dplr(n)`` and left-multiplies its ``p`` by the conjugate
    transpose of the eigenvector matrix.

    Args:
        n: State dimension forwarded to ``hippo_dplr``.

    Returns:
        A two-element list ``[p_projected, eigenvalues]``.
    """
    eigvals, low_rank, eigvecs = hippo_dplr(n)
    projected = eigvecs.conj().T @ low_rank
    return [projected, eigvals]


def cauchy_kernel(v, omega, lambda_):
    """Evaluate ``sum_j v_j / (omega - lambda_j)`` over the last axis.

    ``v`` is promoted to three dimensions before the division: a 1-D ``v``
    gains two leading singleton axes, a 2-D ``v`` gains a singleton axis in
    the middle, and any other rank is used as given.  Broadcasting between
    ``v``, ``omega`` and ``lambda_`` is left to PyTorch.

    Args:
        v: Numerator coefficients.
        omega: Evaluation points.
        lambda_: Poles subtracted from ``omega``.

    Returns:
        The quotient summed over its last dimension.
    """
    if v.ndim == 1:
        v = v[None, None, :]
    elif v.ndim == 2:
        v = v[:, None, :]
    terms = v / (omega - lambda_)
    return terms.sum(dim=-1)


def causal_convolution(u, K):
    """Convolve ``u`` with kernel ``K`` along the sequence axis using FFTs.

    Both operands are zero-padded by the sequence length before the
    transform, and only the first ``seq_length`` outputs are kept, so each
    output step depends only on the current and earlier inputs.

    Args:
        u: Input of shape ``[batch, seq_length, d_model]``.
        K: Kernel whose last axis is the time axis; it is multiplied against
            the spectrum of ``u`` after ``u``'s last two axes are swapped.

    Returns:
        A tensor shaped like ``u`` and cast back to ``u``'s type.
    """
    seq_len = u.shape[1]

    # Zero-pad the time axis of both operands by seq_len.
    u_padded = F.pad(u.float(), pad=(0, 0, 0, seq_len, 0, 0))
    k_padded = F.pad(K.float(), pad=(0, seq_len))

    # Move to the frequency domain; put time last for u so it lines up with K.
    u_freq = rfft(u_padded, dim=1).transpose(-2, -1)
    k_freq = rfft(k_padded, dim=-1)

    # Pointwise product in frequency == convolution in time.
    y = irfft(u_freq * k_freq)
    return y[..., :seq_len].transpose(-2, -1).type_as(u)

# compute frequencies
def f_omega(l_max, dtype=torch.complex64):
    """Return ``exp(2j * pi * k / l_max)`` for ``k = 0, 1, ..., l_max - 1``.

    Args:
        l_max: Number of points to generate.
        dtype: Type the integer indices are cast to before scaling.

    Returns:
        A 1-D tensor of length ``l_max``.
    """
    k = torch.arange(l_max).type(dtype)
    angular_step = 2j * torch.tensor(torch.pi) / l_max
    return torch.exp(k * angular_step)


class S4Layer(nn.Module):
    """State-space layer that applies one learned long convolution per channel.

    Complex-valued parameters (``p``, ``lambda_``, ``B``, ``Ct``) are stored
    as real tensors with a trailing dimension of size 2 (via
    ``view_as_real``) and exposed as complex views through properties.

    Args:
        d_model: Number of channels; one kernel is produced per channel.
        n: State dimension passed to ``p_lambda``.
        l_max: Kernel length and number of evaluation points ``omega``.
        time_delta: Extra factor applied to the exponent of ``lambda_d`` in
            ``roots``.
    """

    def __init__(self, d_model, n, l_max, time_delta=0.1):
        super().__init__()
        self.d_model = d_model
        self.n = n
        self.l_max = l_max
        self.time_delta = time_delta

        # Shared (across channels) rank-1 vector and eigenvalues.
        p, lambda_ = p_lambda(n)
        p = p.to(torch.complex64)
        lambda_ = lambda_.to(torch.complex64)
        self._p = nn.Parameter(view_as_real(p))
        # Stored as (1, 1, n, 2) so the complex view broadcasts as (1, 1, n).
        self._lambda_ = nn.Parameter(view_as_real(lambda_).unsqueeze(0).unsqueeze(1))

        # Buffers: moved with the module and saved in the state dict, but
        # not trainable.
        self.register_buffer(
            "omega",
            tensor=f_omega(self.l_max, dtype=torch.complex64),
        )
        # Index map [0, l_max-1, l_max-2, ..., 1]: keeps element 0 in place
        # and reverses the rest of the IFFT output.
        self.register_buffer(
            "ifft_order",
            tensor=(-torch.arange(self.l_max)) % self.l_max,
        )

        # B starts as sqrt(2k + 1), k = 1..n, repeated for every channel.
        B_init = torch.sqrt(2 * torch.arange(1, n + 1, dtype=torch.float32) + 1.0)
        B_init = B_init.repeat(d_model, 1)
        self._B = nn.Parameter(
            view_as_real(B_init.to(torch.complex64))
        )
        self._Ct = nn.Parameter(
            view_as_real(
                nn.init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64))
            )
        )
        self.D = nn.Parameter(torch.ones(1, 1, d_model))
        self.log_step = nn.Parameter(log_step_init(torch.rand(d_model)))

    @property
    def p(self):
        """Complex view of the rank-1 vector, shape ``(n,)``."""
        return torch.view_as_complex(self._p)

    @property
    def lambda_(self):
        """Complex view of the eigenvalues, shape ``(1, 1, n)``."""
        return torch.view_as_complex(self._lambda_)

    @property
    def B(self):
        """Complex view of ``B``, shape ``(d_model, n)``."""
        return torch.view_as_complex(self._B)

    @property
    def Ct(self):
        """Complex view of ``Ct``, shape ``(d_model, n)``."""
        return torch.view_as_complex(self._Ct)

    def roots(self):
        """Evaluate the kernel's transfer function at every point in ``omega``.

        Four Cauchy-kernel sums ``k00, k01, k10, k11`` are formed from the
        pairings of ``(Ct.conj(), p.conj())`` with ``(B, p)`` and combined as
        ``k00 - k01 * (1 / (1 + k11)) * k10``.

        Returns:
            A complex tensor of shape ``(d_model, l_max)``.
        """
        a0 = self.Ct.conj()
        a1 = self.p.conj().unsqueeze(0)
        b0 = self.B
        b1 = self.p.unsqueeze(0)
        # Per-channel step size, reshaped to (d_model, 1, 1) for broadcasting.
        step = self.log_step.exp().unsqueeze(1).unsqueeze(2)

        # asynchronous discretization
        # NOTE(review): time_delta scales the exponent of lambda_d but not the
        # exponent inside delta -- confirm this asymmetry is intended.
        lambda_d = torch.exp(self.lambda_ * step * self.time_delta)
        delta = (torch.exp(self.lambda_ * step) - 1) / self.lambda_
        omega_z = self.omega.unsqueeze(0).unsqueeze(2)

        k00 = cauchy_kernel((a0 * b0).unsqueeze(1) * delta, omega_z, lambda_d)
        k01 = cauchy_kernel((a0 * b1).unsqueeze(1) * delta, omega_z, lambda_d)
        k10 = cauchy_kernel((a1 * b0).unsqueeze(1) * delta, omega_z, lambda_d)
        k11 = cauchy_kernel((a1 * b1).unsqueeze(1) * delta, omega_z, lambda_d)

        return k00 - k01 * (1.0 / (1.0 + k11)) * k10

    @property
    def K(self):
        """Real convolution kernel of shape ``(1, d_model, l_max)``.

        Recomputed from the current parameters on every access: the values
        from ``roots`` are inverse-FFT'd, re-ordered along the time axis with
        ``ifft_order`` and reduced to their real part.
        """
        at_roots = self.roots()
        out = ifft(at_roots, n=self.l_max, dim=-1)
        # One vectorized gather along the last axis instead of a Python loop
        # over channels.
        conv = out[..., self.ifft_order].real
        return conv.unsqueeze(0)

    def forward(self, u):
        """Apply the causal convolution with ``K`` and add the skip term ``D * u``.

        Args:
            u: Input of shape ``[batch, seq_length, d_model]``.

        Returns:
            A tensor with the same shape as ``u``.
        """
        return causal_convolution(u, K=self.K) + (self.D * u)


class S4Block(nn.Module):
    """Residual block: LayerNorm -> S4Layer -> GELU -> Dropout -> Linear, plus skip.

    Args:
        d_model: Feature width of the input and output.
        n: State dimension forwarded to ``S4Layer``.
        l_max: Kernel length forwarded to ``S4Layer``.
        dropout: Drop probability applied after the activation.
    """

    def __init__(self, d_model, n, l_max, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.s4 = S4Layer(d_model, n=n, l_max=l_max)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, d_model)
        # Registered (so it appears in the state dict) but never called in
        # forward().
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return ``linear(dropout(gelu(s4(norm1(x))))) + x``."""
        hidden = self.s4(self.norm1(x))
        hidden = self.dropout(self.activation(hidden))
        return self.linear(hidden) + x


class S4Model(nn.Module):
    """Sequence classifier: linear encoder, a stack of S4 blocks, mean-pool, linear decoder.

    Args:
        d_input: Feature size of each input step.
        d_model: Hidden width used by the encoder, blocks and decoder input.
        d_output: Size of the output vector.
        n_blocks: Number of ``S4Block`` modules in the stack.
        n: State dimension forwarded to every block.
        l_max: Kernel length forwarded to every block.
        dropout: Drop probability forwarded to every block.
    """

    def __init__(self, d_input, d_model, d_output, n_blocks, n, l_max, dropout=0.0):
        super().__init__()
        self.d_input = d_input
        self.d_model = d_model
        self.d_output = d_output
        self.n_blocks = n_blocks

        self.encoder = nn.Linear(d_input, d_model)

        stack = []
        for _ in range(n_blocks):
            stack.append(S4Block(d_model=d_model, n=n, l_max=l_max, dropout=dropout))
        self.blocks = nn.ModuleList(stack)

        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, u):
        """Encode ``u``, run every block in order, average over dim 1, then decode."""
        hidden = self.encoder(u)
        for layer in self.blocks:
            hidden = layer(hidden)
        pooled = hidden.mean(dim=1)
        return self.decoder(pooled)
    

### Main

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import wandb
from datetime import datetime


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SECURITY: a Weights & Biases API key used to be hard-coded on this line.
# Secrets must not live in source; the key that was committed should be
# revoked.  Without an explicit key, wandb.login() picks up credentials from
# the WANDB_API_KEY environment variable or a previous `wandb login`, and
# otherwise prompts for one.
wandb.login()

def split_train_val(train, val_split):
    """Randomly split a dataset into train and validation subsets.

    The split uses a dedicated generator seeded with 42, so repeated calls
    on a dataset of the same length select the same indices.

    Args:
        train: Dataset (anything with ``len``) to split.
        val_split: Fraction of the samples assigned to the validation subset.

    Returns:
        ``(train_subset, val_subset)``; the train subset holds
        ``int(len(train) * (1.0 - val_split))`` samples and the validation
        subset holds the rest.
    """
    n_total = len(train)
    n_train = int(n_total * (1.0 - val_split))
    seeded = torch.Generator().manual_seed(42)
    train_part, val_part = torch.utils.data.random_split(
        train,
        (n_train, n_total - n_train),
        generator=seeded,
    )
    return train_part, val_part

# Hyperparameters are defined once and reused below so that the model, the
# data loaders, the optimizer and the config logged to wandb cannot drift
# apart (the logged d_model previously disagreed with the one actually used).
batch_size = 64
learning_rate = 0.001
weight_decay = 0.01
d_model = 512
n_blocks = 6
dropout = 0.2

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    # Flatten each (3, H, W) image into a sequence of pixels: (H * W, 3).
    transforms.Lambda(lambda x: x.reshape(3, -1).t())
])

# One 90/10 split of the official training set; split_train_val is seeded,
# so the partition is the same on every run.
full_trainset = torchvision.datasets.CIFAR10(
    root='./data/cifar/', train=True, download=True, transform=transform)
trainset, valset = split_train_val(full_trainset, val_split=0.1)

testset = torchvision.datasets.CIFAR10(
    root='./data/cifar/', train=False, download=True, transform=transform)

d_input = 3
d_output = 10

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

model = S4Model(
    d_input=d_input,
    d_model=d_model,
    d_output=d_output,
    n_blocks=n_blocks,
    n=64,
    l_max=1024,
    dropout=dropout,
)

model.to(device)

num_epochs = 8
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# The scheduler is stepped once per epoch, so one cosine period spans the run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


wandb.init(
    project="saidl-s4",
    name=f"s4_run_{datetime.now().strftime('%d%m_%H%M')}",
    config={
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "epochs": num_epochs,
        "batch_size": batch_size,
        "model_config": {
            "d_model": d_model,
            "n_blocks": n_blocks,
            "dropout": dropout
        },
        "discretization": "async"
    }
)

def train_epoch(model, trainloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``trainloader``.

    Per-batch loss, running loss and running accuracy are logged to wandb
    after every optimizer step and shown on the progress bar.

    The running loss is weighted by batch size, so a short final batch does
    not count as much as a full one.  This assumes ``criterion`` returns the
    mean loss over the batch (true for ``nn.CrossEntropyLoss`` with its
    default reduction) -- confirm if a different reduction is used.

    Args:
        model: Module to train; switched to training mode.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; ``(0.0, 0.0)`` when
        the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses seen so far
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(trainloader, desc='training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        batch_loss = loss.item()
        batch_size = labels.size(0)
        loss_sum += batch_loss * batch_size
        total += batch_size

        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()

        # max(total, 1) only guards against a degenerate zero-sized batch.
        running_loss = loss_sum / max(total, 1)
        current_acc = correct / max(total, 1)

        wandb.log({
            "train/batch_loss": batch_loss,
            "train/running_loss": running_loss,
            "train/running_acc": current_acc
        })

        pbar.set_postfix({
            'loss': f'{running_loss:.4f}',
            'acc': f'{current_acc:.2f}'
        })

    accuracy = correct / total if total else 0.0
    return running_loss, accuracy

def evaluate(model, dataloader, criterion, device):
    """Compute mean loss and accuracy of ``model`` over ``dataloader``.

    Runs in evaluation mode with gradients disabled.  The loss is weighted
    by batch size, so a short final batch does not count as much as a full
    one.  This assumes ``criterion`` returns the mean loss over the batch
    (true for ``nn.CrossEntropyLoss`` with its default reduction) -- confirm
    if a different reduction is used.

    Args:
        model: Module to evaluate; switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; ``(0.0, 0.0)`` when
        the loader yields no batches.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses seen so far
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='evaluating')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size

            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()

            # max(total, 1) only guards against a degenerate zero-sized batch.
            running_loss = loss_sum / max(total, 1)
            current_acc = correct / max(total, 1)

            pbar.set_postfix({
                'loss': f'{running_loss:.4f}',
                'acc': f'{current_acc:.2f}'
            })

    accuracy = correct / total if total else 0.0
    return running_loss, accuracy

# Start below any achievable accuracy so the first epoch always writes a
# checkpoint; otherwise a run whose validation accuracy never rises above 0
# would reach torch.load() below with no file to load.
best_val_acc = float("-inf")
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, valloader, criterion, device)
    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    wandb.log({
        "train/epoch_loss": train_loss,
        "train/epoch_acc": train_acc,
        "val/loss": val_loss,
        "val/acc": val_acc,
        "epoch": epoch
    })

    print(f'epoch: {epoch+1}/{num_epochs}')
    print(f'train loss: {train_loss:.4f} | train acc: {train_acc:.2f}')
    print(f'val loss: {val_loss:.4f} | val acc: {val_acc:.2f}')

    # Keep only the weights with the best validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

    print('-' * 50)

# Report test metrics for the best validation checkpoint, not the last epoch.
# map_location keeps the load working if the checkpoint was written on a
# different device than the one currently in use.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
test_loss, test_acc = evaluate(model, testloader, criterion, device)
wandb.log({
    "test/loss": test_loss,
    "test/acc": test_acc
})
print(f'test loss: {test_loss:.4f} | test acc: {test_acc:.2f}')

wandb.finish()
