In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms.v2 as v2
import timm
from functools import partial
import gc
import falkon

# Free whatever a previous run left behind: collect unreachable Python objects,
# release cached CUDA blocks, then collect once more.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Per-sample transform handed to the datasets: convert to a tensor image, then
# cast to float32 (scale=True asks ToDtype to rescale values for the new dtype).
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# Per-batch transform applied right before the backbone: expand to RGB and
# normalise each channel with the mean/std below (the usual ImageNet statistics).
tf = v2.Compose(
    [
        v2.RGB(),
        # v2.Resize(224),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# MNIST splits rooted at /data; only the training split requests a download.
mnist_train = datasets.MNIST('/data', train=True, download=True, transform=to_tensor)
mnist_test = datasets.MNIST('/data', train=False, transform=to_tensor)

# DataLoader factory with the shared settings baked in; use it as
# loader(dataset=...). shuffle=False and drop_last=False keep every sample,
# in dataset order, so images, labels and features stay aligned.
loader = partial(
    torch.utils.data.DataLoader,
    batch_size=16,
    shuffle=False,
    drop_last=False,
    num_workers=6,
    pin_memory=False,
)

# Pretrained ResNet-18 from timm, placed on the GPU and switched to eval mode
# because it is only used as a fixed feature extractor.
resnet18 = timm.create_model('resnet18', pretrained=True)
resnet18 = resnet18.to('cuda')
resnet18 = resnet18.eval()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Switch tensor sharing between DataLoader workers and this process to the
# 'file_system' strategy.
# NOTE(review): presumably needed because every image batch is kept alive in
# train_imgs below, which can exhaust file descriptors under the default
# strategy — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Send inputs to wherever the backbone's weights live instead of repeating a
# hard-coded device string.
model_device = next(resnet18.parameters()).device

train_feats = []
train_imgs = []
train_labels = []
# Feature extraction is inference only: no_grad stops autograd from recording a
# graph for every forward pass, which would otherwise cost memory and time
# before being thrown away.
with torch.no_grad():
    for i, (imgs, labs) in enumerate(loader(dataset=mnist_train)):
        # Keep the un-normalised images and the labels for the raw-pixel baseline.
        train_imgs.append(imgs)
        train_labels.append(labs)
        x = tf(imgs).to(model_device)
        feats = resnet18.forward_intermediates(x, indices=1, intermediates_only=True)
        # Only the first returned feature map is used; move it off the GPU
        # right away so device memory stays flat across the loop.
        train_feats.append(feats[0].cpu())
        if i % 500 == 0:
            print(i)

# One row per image. flatten(start_dim=1) always yields a 2-D matrix, whereas a
# bare squeeze() would also drop a batch or channel axis of size 1.
train_feats = torch.cat(train_feats, dim=0).flatten(start_dim=1)
train_x_imgs = torch.vstack(train_imgs)
train_x_imgs = train_x_imgs.reshape(train_x_imgs.shape[0], -1)
train_y = torch.cat(train_labels)
# Row k of the 10x10 identity matrix is the one-hot encoding of class k.
A = torch.eye(10, dtype=torch.float32)
train_y_onehot = A[train_y.to(torch.long), :]
print(train_x_imgs.shape, train_y_onehot.shape)

0
500
1000
1500
2000
2500
3000
3500
torch.Size([60000, 784]) torch.Size([60000, 10])


In [3]:
# Switch tensor sharing between DataLoader workers and this process to the
# 'file_system' strategy.
# NOTE(review): presumably needed because every image batch is kept alive in
# test_imgs below, which can exhaust file descriptors under the default
# strategy — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Send inputs to wherever the backbone's weights live instead of repeating a
# hard-coded device string.
model_device = next(resnet18.parameters()).device

test_feats = []
test_imgs = []
test_labels = []
# Feature extraction is inference only: no_grad stops autograd from recording a
# graph for every forward pass, which would otherwise cost memory and time
# before being thrown away.
with torch.no_grad():
    for i, (imgs, labs) in enumerate(loader(dataset=mnist_test)):
        # Keep the un-normalised images and the labels for the raw-pixel baseline.
        test_imgs.append(imgs)
        test_labels.append(labs)
        x = tf(imgs).to(model_device)
        feats = resnet18.forward_intermediates(x, indices=1, intermediates_only=True)
        # Only the first returned feature map is used; move it off the GPU
        # right away so device memory stays flat across the loop.
        test_feats.append(feats[0].cpu())
        if i % 500 == 0:
            print(i)

# One row per image. flatten(start_dim=1) always yields a 2-D matrix, whereas a
# bare squeeze() would also drop a batch or channel axis of size 1.
test_feats = torch.cat(test_feats, dim=0).flatten(start_dim=1)
test_x_imgs = torch.vstack(test_imgs)
test_x_imgs = test_x_imgs.reshape(test_x_imgs.shape[0], -1)
test_y = torch.cat(test_labels)
# Row k of the 10x10 identity matrix is the one-hot encoding of class k.
A = torch.eye(10, dtype=torch.float32)
test_y_onehot = A[test_y.to(torch.long), :]
print(test_x_imgs.shape, test_y_onehot.shape)

0
500
torch.Size([10000, 784]) torch.Size([10000, 10])


In [4]:
def classif_error(y_true, y_pred):
    """Fraction of rows whose arg-max class differs between two score matrices.

    Both arguments are reduced with an arg-max over dim 1, so each must have a
    class axis there (e.g. one-hot targets and per-class predictions).

    Returns:
        A 0-dim float32 tensor: the mean of the per-row mismatch indicator.
    """
    true_labels = y_true.argmax(dim=1).flatten()
    pred_labels = y_pred.argmax(dim=1).flatten()
    mismatches = torch.ne(true_labels, pred_labels)
    return mismatches.to(torch.float32).mean()

# Run the solver on the CPU.
options = falkon.FalkonOptions(use_cpu=True)
kernel = falkon.kernels.GaussianKernel(sigma=15)
# seed pins the estimator's random state so that Restart & Run All reproduces
# the same error figures; without it every run differs.
# classif_error is reported after every iteration (error_every=1).
flk = falkon.Falkon(kernel=kernel,
                    penalty=1e-8,
                    M=1000,
                    maxiter=10,
                    seed=0,
                    options=options,
                    error_every=1,
                    error_fn=classif_error)

In [5]:
# Fit the kernel model on the ResNet features, then score both splits.
_ = flk.fit(train_feats, train_y_onehot)

# classif_error is declared as (y_true, y_pred); pass the targets first so the
# calls match its signature and the way the fit loop's error_fn hook uses it.
train_pred_from_feats = flk.predict(train_feats)
print("Training error: %.2f%%" % (classif_error(train_y_onehot, train_pred_from_feats) * 100))

test_pred_from_feats = flk.predict(test_feats)
print("Test error: %.2f%%" % (classif_error(test_y_onehot, test_pred_from_feats) * 100))

Iteration   1 - Elapsed 0.84s - training error: tensor(0.2089)
Iteration   2 - Elapsed 1.18s - training error: tensor(0.1660)
Iteration   3 - Elapsed 1.66s - training error: tensor(0.1537)
Iteration   4 - Elapsed 2.02s - training error: tensor(0.1482)
Iteration   5 - Elapsed 2.37s - training error: tensor(0.1432)
Iteration   6 - Elapsed 2.71s - training error: tensor(0.1409)
Iteration   7 - Elapsed 3.07s - training error: tensor(0.1396)
Iteration   8 - Elapsed 3.40s - training error: tensor(0.1358)
Iteration   9 - Elapsed 3.81s - training error: tensor(0.1349)
Iteration  10 - Elapsed 4.60s - training error: tensor(0.1335)
Training error: 13.35%
Test error: 13.47%


In [6]:
# Baseline: refit the same estimator on the flattened raw pixels. This replaces
# the model fitted on ResNet features, so flk now refers to the pixel model.
_ = flk.fit(train_x_imgs, train_y_onehot)

# classif_error is declared as (y_true, y_pred); pass the targets first so the
# calls match its signature and the way the fit loop's error_fn hook uses it.
train_pred_from_imgs = flk.predict(train_x_imgs)
print("Training error: %.2f%%" % (classif_error(train_y_onehot, train_pred_from_imgs) * 100))

test_pred_from_imgs = flk.predict(test_x_imgs)
print("Test error: %.2f%%" % (classif_error(test_y_onehot, test_pred_from_imgs) * 100))

Iteration   1 - Elapsed 0.69s - training error: tensor(0.1103)
Iteration   2 - Elapsed 1.12s - training error: tensor(0.0694)
Iteration   3 - Elapsed 1.56s - training error: tensor(0.0570)
Iteration   4 - Elapsed 2.00s - training error: tensor(0.0512)
Iteration   5 - Elapsed 2.43s - training error: tensor(0.0481)
Iteration   6 - Elapsed 2.87s - training error: tensor(0.0462)
Iteration   7 - Elapsed 3.30s - training error: tensor(0.0443)
Iteration   8 - Elapsed 3.76s - training error: tensor(0.0431)
Iteration   9 - Elapsed 4.20s - training error: tensor(0.0428)
Iteration  10 - Elapsed 5.14s - training error: tensor(0.0421)
Training error: 4.21%
Test error: 4.37%
