In [2]:
from collections.abc import Iterable
from itertools import pairwise, chain
from typing import Callable

import matplotlib.pyplot as plt
import torch
import torchvision
import tqdm.auto as tqdm
from torch import nn

# Keras fixes its backend when it is first imported, so this must stay above `import keras`.
%env KERAS_BACKEND=torch

import keras
from keras import layers
import keras_tuner

# Let float32 matmuls use faster reduced-precision internals (e.g. TF32) where the GPU supports them.
torch.set_float32_matmul_precision('high')

env: KERAS_BACKEND=torch


In [3]:
from pathlib import Path
import os

# The notebook may be started from a sub-directory: walk up the ancestors of the current
# directory until one contains src/kernels, and make that the working directory so the
# `src` package imported in the next cells can be found.
_kernels_dir = Path("src/kernels")
if not _kernels_dir.is_dir():
    _project_root = next(
        (ancestor for ancestor in Path.cwd().parents if (ancestor / _kernels_dir).is_dir()),
        None,
    )
    if _project_root is None:
        raise FileNotFoundError("Can't find project root")
    os.chdir(_project_root)

assert _kernels_dir.is_dir()

In [4]:
from src import kernels, convolutions
from src.models import simple_lenet
from src import load_data

In [40]:
import importlib

# from src.kernels import quadratic as quad_kernels
import src.models.simple_lenet
import src.load_data

# Re-execute the edited project modules in place so the running kernel sees source
# changes without a restart (same order as before: model module first, then data loader).
for _stale_module in (src.models.simple_lenet, src.load_data):
    importlib.reload(_stale_module)
del _stale_module

# Rebind the short names used by the following cells.
from src.models import simple_lenet
from src import load_data
# kernels = importlib.reload(kernels)
# convolutions = importlib.reload(convolutions)

In [24]:
# Load the dataset through the project loader; the loader prints the normalisation
# statistics it applied, and the repr shows the array shape of each split.
k_mnist = load_data.k_mnist()
k_mnist

normalisation: tensor([[[48.8993]]]) tensor([[[88.8274]]])


Dataset(x_train=(60000, 1, 28, 28), x_test=(10000, 1, 28, 28), y_train=(60000,), y_test=(10000,))

In [41]:
# Single-channel LeNet with 10 output classes, using the anisotropic learned pooling,
# placed on the GPU for the benchmarks below.
test_model = simple_lenet.LeNet(
    img_channels=1,
    num_classes=10,
    pool_fn=simple_lenet.LENET_POOLING_FUNCTIONS["aniso"],
)
test_model = test_model.to('cuda')
test_model

LeNet(
  (net): Sequential(
    (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): GenericConv2D(
      (kernel): QuadraticKernelCholesky2D(
        20, 20, kernel_size=3
        (covs): LearnedCholesky2D(20, 20)
      )
      (conv): SelectConvFixedLazy()
    )
    (3): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): GenericConv2D(
      (kernel): QuadraticKernelCholesky2D(
        50, 50, kernel_size=3
        (covs): LearnedCholesky2D(50, 50)
      )
      (conv): SelectConvFixedLazy()
    )
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): LazyLinear(in_features=0, out_features=500, bias=True)
    (8): ReLU()
    (9): Linear(in_features=500, out_features=10, bias=True)
  )
)

In [42]:
# Check what the dataset split holds: it is a plain NumPy array, not a torch tensor.
type(k_mnist.x_train[:256])

numpy.ndarray

In [43]:
# Convert the torch LeNet into a Keras model via the project's `to_keras` helper. The dataset
# is passed in, presumably so the model can be built for its input shape (TODO confirm).
kmodel = test_model.to_keras(k_mnist)
kmodel

<Sequential name=sequential_4, built=True>

In [44]:
# Layer-by-layer summary (output shapes, parameter counts) of the converted Keras model.
kmodel.summary()

In [45]:
# The first 256 training images in two placements, for comparing input-transfer costs:
# page-locked (pinned) host memory and GPU memory.
_first_batch = torch.as_tensor(k_mnist.x_train[:256])
test_run_pinned = _first_batch.pin_memory()
test_run_cuda = _first_batch.cuda()
del _first_batch

In [48]:
# %timeit kmodel(test_run_pinned.cuda())
%timeit kmodel(test_run_cuda)

789 μs ± 908 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


1.13 ms ± 24.9 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

806 μs ± 491 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [47]:
# %timeit test_model(test_run_pinned.cuda())
%timeit test_model(test_run_cuda)

663 μs ± 43.7 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


941 μs ± 132 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

671 μs ± 76.2 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [14]:
%timeit kmodel(k_mnist.x_train[:256])
%timeit kmodel(test_run_pinned)

1.14 ms ± 550 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.04 ms ± 179 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
# One epoch (Keras' default `epochs=1`) of training fed directly from the NumPy arrays.
# This keeps updating the same `kmodel` weights; nothing re-initialises the model.
%time kmodel.fit(k_mnist.x_train, k_mnist.y_train, batch_size=256, verbose=False)

CPU times: user 1.34 s, sys: 20.7 ms, total: 1.36 s
Wall time: 943 ms


<keras.src.callbacks.history.History at 0x766dfc2280e0>

In [49]:
# Copy the whole training split (inputs, labels) to the GPU once, so later fits can be fed
# GPU-resident tensors instead of transferring from host memory every batch.
xtrp, ytrp = (
    torch.as_tensor(split, device='cuda')
    for split in (k_mnist.x_train, k_mnist.y_train)
)

In [60]:
# Five epochs with batch size 256 from the GPU-resident tensors, matching `minitrain`'s
# defaults for a like-for-like timing. `kmodel` is not re-initialised, so training
# continues from the weights left by any previous fit.
%time hist = kmodel.fit(xtrp, ytrp, verbose=False, batch_size=256, epochs=5)

CPU times: user 5.54 s, sys: 109 ms, total: 5.65 s
Wall time: 4.19 s


In [59]:
def minitrain(epochs: int = 5):
    o = torch.optim.Adam(test_model.parameters())
    for _ in range(epochs):
        for xb, yb in k_mnist.train_loader(batch_size=256):
            o.zero_grad()
            res = test_model(xb.cuda())
            nn.functional.cross_entropy(res, yb.cuda()).backward()
            o.step()

# Time the hand-written loop with its defaults (5 epochs, batch 256) for comparison with the
# Keras fit. NOTE(review): the CPU time in the output is far above the wall time, meaning many
# host threads were busy. The data loader / intra-op threads are a likely cause; worth checking.
%time minitrain()

CPU times: user 2min 15s, sys: 40.5 ms, total: 2min 15s
Wall time: 5.91 s


In [107]:
# Import importlib here so this cell does not depend on an earlier cell having run.
import importlib

import src.kernels.utils
import src.kernels.quadratic
import src.kernels

# Reload the leaf modules first and the package __init__ last, so names the package
# presumably re-exports are re-bound to the freshly reloaded definitions. The loop form
# also avoids displaying the module repr, which exposed an absolute local path.
# NOTE(review): modules that already imported these classes (e.g. src.models.simple_lenet)
# and existing instances such as `test_model` keep the old definitions until reloaded/rebuilt.
for _kernel_module in (src.kernels.utils, src.kernels.quadratic, src.kernels):
    importlib.reload(_kernel_module)
del _kernel_module

<module 'src.kernels' from '/home/peter/Thesis/src/kernels/__init__.py'>

In [120]:
from src.kernels.utils import LearnedCholesky2D, LearnedSpectral2D
from src.kernels import QuadraticKernelCholesky2D, QuadraticKernelSpectral2D

# Fix the seed so the `init="normal"` draws are reproducible after a kernel restart. That
# covers the kernels' parameters, the shapes/timings and the state_dict inspected further down.
# NOTE(review): assumes init="normal" samples from torch's global RNG; confirm.
torch.manual_seed(0)

# Channel counts shared by every kernel below. The materialised kernels come out as
# (60, 50, k, k), presumably (out, in, k, k), i.e. the constructor takes (in, out).
IN_CHANNELS, OUT_CHANNELS = 50, 60

k1 = LearnedCholesky2D(IN_CHANNELS, OUT_CHANNELS, init="normal")
k2 = LearnedSpectral2D(IN_CHANNELS, OUT_CHANNELS, init="normal")
print(k1, k2)
# Naming: q1_* use the Cholesky covariance parametrisation, q2_* the spectral one;
# the suffix is the kernel size.
q1_3 = QuadraticKernelCholesky2D(IN_CHANNELS, OUT_CHANNELS, 3, init="normal")
q1_5 = QuadraticKernelCholesky2D(IN_CHANNELS, OUT_CHANNELS, 5, init="normal")
q1_10 = QuadraticKernelCholesky2D(IN_CHANNELS, OUT_CHANNELS, 10, init="normal")
q2_3 = QuadraticKernelSpectral2D(IN_CHANNELS, OUT_CHANNELS, 3, init="normal")
q2_5 = QuadraticKernelSpectral2D(IN_CHANNELS, OUT_CHANNELS, 5, init="normal")
q2_10 = QuadraticKernelSpectral2D(IN_CHANNELS, OUT_CHANNELS, 10, init="normal")
print(q1_3, q2_3, q1_5, q2_5)

LearnedCholesky2D(50, 60) LearnedSpectral2D(50, 60)
QuadraticKernelCholesky2D(
  50, 60, kernel_size=3
  (covs): LearnedCholesky2D(50, 60)
) QuadraticKernelSpectral2D(
  50, 60, kernel_size=3
  (covs): LearnedSpectral2D(50, 60)
) QuadraticKernelCholesky2D(
  50, 60, kernel_size=5
  (covs): LearnedCholesky2D(50, 60)
) QuadraticKernelSpectral2D(
  50, 60, kernel_size=5
  (covs): LearnedSpectral2D(50, 60)
)


In [109]:
# Both covariance parametrisations should give one 2x2 matrix per channel pair.
for _materialise in (k1.cholesky, k2.inverse_cov):
    print(_materialise().shape)
del _materialise

torch.Size([60, 50, 2, 2])
torch.Size([60, 50, 2, 2])


In [110]:
# Materialise each eager (not yet compiled) kernel and show its shape.
for _kernel in (q1_3, q2_3, q1_5, q2_5):
    print(_kernel().shape)
del _kernel

torch.Size([60, 50, 3, 3])
torch.Size([60, 50, 3, 3])
torch.Size([60, 50, 5, 5])
torch.Size([60, 50, 5, 5])


In [121]:
# Compile every kernel module in place (fullgraph=True: graph breaks are an error), then run
# each once. Printing the shapes confirms the compiled versions still match the eager ones.
_quad_kernels = (q1_3, q2_3, q1_5, q2_5, q1_10, q2_10)
for _kernel in _quad_kernels:
    _kernel.compile(fullgraph=True)
for _kernel in _quad_kernels:
    print(_kernel().shape)
del _kernel, _quad_kernels

torch.Size([60, 50, 3, 3])
torch.Size([60, 50, 3, 3])
torch.Size([60, 50, 5, 5])
torch.Size([60, 50, 5, 5])
torch.Size([60, 50, 10, 10])
torch.Size([60, 50, 10, 10])


  check(


In [122]:
# Second call of each compiled kernel: it should hit the compiled graph and print the same shapes.
for _kernel in (q1_3, q2_3, q1_5, q2_5, q1_10, q2_10):
    print(_kernel().shape)
del _kernel

torch.Size([60, 50, 3, 3])
torch.Size([60, 50, 3, 3])
torch.Size([60, 50, 5, 5])
torch.Size([60, 50, 5, 5])
torch.Size([60, 50, 10, 10])
torch.Size([60, 50, 10, 10])


In [123]:
q1_3.cpu(), q2_3.cpu(), q1_5.cpu(), q2_5.cpu(), q1_10.cpu(), q2_10.cpu()
%timeit q1_3()
%timeit q2_3()
%timeit q1_5()
%timeit q2_5()
%timeit q1_10()
%timeit q2_10()

826 μs ± 2.56 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
117 μs ± 280 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
607 μs ± 350 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
175 μs ± 256 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
1.05 ms ± 315 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
168 μs ± 240 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [126]:
q1_3.cuda(), q2_3.cuda(), q1_5.cuda(), q2_5.cuda(), q1_10.cuda(), q2_10.cuda()
%timeit q1_3()
%timeit q2_3()
%timeit q1_5()
%timeit q2_5()
%timeit q1_10()
%timeit q2_10()

70.1 μs ± 64.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
74.8 μs ± 168 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
70.3 μs ± 71.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
74.9 μs ± 143 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
69.6 μs ± 75.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
75.5 μs ± 71.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [128]:
# Inspect the registered tensors of the 5x5 spectral kernel: `pos_grid` holds the flattened
# grid of integer (row, col) offsets, and `covs.*` holds the covariance parameters.
# Prints every value, so the output is long.
q2_5.state_dict()

OrderedDict([('pos_grid',
              tensor([[-2., -2.],
                      [-2., -1.],
                      [-2.,  0.],
                      [-2.,  1.],
                      [-2.,  2.],
                      [-1., -2.],
                      [-1., -1.],
                      [-1.,  0.],
                      [-1.,  1.],
                      [-1.,  2.],
                      [ 0., -2.],
                      [ 0., -1.],
                      [ 0.,  0.],
                      [ 0.,  1.],
                      [ 0.,  2.],
                      [ 1., -2.],
                      [ 1., -1.],
                      [ 1.,  0.],
                      [ 1.,  1.],
                      [ 1.,  2.],
                      [ 2., -2.],
                      [ 2., -1.],
                      [ 2.,  0.],
                      [ 2.,  1.],
                      [ 2.,  2.]], device='cuda:0')),
             ('covs.log_stds',
              tensor([[[ 0.5216,  1.1183],
                       [ 0.829