In [62]:
from __future__ import annotations

from collections.abc import Iterable
from itertools import pairwise, chain
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision
import tqdm.auto as tqdm
from torch import nn
import math
from torch.distributions.multivariate_normal import MultivariateNormal

%env KERAS_BACKEND=torch

import keras
from keras import layers
import keras_tuner

env: KERAS_BACKEND=torch


In [2]:
# Move to project root
from pathlib import Path
import os

if not Path("./src/kernels").is_dir():
    for parent_path in Path.cwd().parents:
        if (parent_path / "src/kernels").is_dir():
            os.chdir(parent_path)
            break
    else:
        raise FileNotFoundError("Can't find project root")

assert Path("./src/kernels").is_dir()

In [73]:
from src import kernels, convolutions
from src.models import lenet_like
from src import load_data

In [5]:
tuner = keras_tuner.GridSearch(
    hypermodel=lenet_like.lenet_like(img_channels=1, num_classes=10),
    objective="val_accuracy",
    max_trials=0,
    overwrite=False,
    directory="checkpoints",
    project_name="basic_kmnist",
)
tuner

Reloading Tuner from checkpoints/basic_kmnist/tuner0.json


<keras_tuner.src.tuners.gridsearch.GridSearch at 0x78f2e09d5730>

In [63]:
from typing import NamedTuple


class TrialWrapper(NamedTuple):
    trial: keras_tuner.engine.trial.Trial
    trial_id: str
    hyperparameters: keras_tuner.HyperParameters
    metrics: dict[str, float]
    score: float
    best_step: int
    loader: ModelLoader = None

    @classmethod
    def from_trial(cls, trial: keras_tuner.engine.trial.Trial, model_loader: ModelLoader = None):
        metrics = {name: hist.get_best_value() for name, hist in trial.metrics.metrics.items()}
        return cls(trial, trial.trial_id, trial.hyperparameters, metrics, trial.score, trial.best_step, model_loader)

    def load_model(self) -> keras.Model:
        assert self.loader and self.trial
        return self.loader.tuner.load_model(self.trial)

In [67]:
class ModelLoader:
    def __init__(self, project_name: str, hypermodel: Callable[[keras_tuner.HyperParameters], keras.Model],
                 tuner_kind='grid'):
        if tuner_kind == 'grid':
            tuner_cls = keras_tuner.GridSearch
        else:
            raise ValueError(f"Unknown {tuner_kind=}")
        self.tuner: keras_tuner.Tuner = tuner_cls(
            hypermodel=hypermodel,
            objective="val_accuracy",
            max_trials=0,
            overwrite=False,
            directory="checkpoints",
            project_name=project_name,
        )
        self.trials = self.tuner.oracle.get_best_trials(-1)
        self.all_params = set()

        for trial in self.trials:
            self.all_params.update(trial.hyperparameters.values)

    @staticmethod
    def nonconflicting(params_a: dict, params_b: dict):
        for key in set(params_a).intersection(params_b):
            if params_a[key] != params_b[key]:
                return False
        return True

    def find(self, params: dict) -> TrialWrapper:
        unknown_keys = set(params).difference(self.all_params)
        assert not unknown_keys, f"{unknown_keys=}"
        for trial in self.trials:
            if self.nonconflicting(params, trial.hyperparameters.values):
                return TrialWrapper.from_trial(trial, self)

        raise ValueError(f"Params {params=} are conflicting or were not tested")



In [68]:
loader = ModelLoader("basic_kmnist", lenet_like.lenet_like(1, 10))

Reloading Tuner from checkpoints/basic_kmnist/tuner0.json


In [78]:
qt = loader.find({'quadratic-pool-kernel': 'quadratic-multi'})
print(qt.metrics)
qm = qt.load_model()

{'accuracy': np.float64(0.996492067972819), 'loss': np.float64(0.014133953334142765), 'top3': np.float64(0.9997539718945821), 'val_accuracy': np.float64(0.9600555499394735), 'val_loss': np.float64(0.49338407317797345), 'val_top3': np.float64(0.9908333420753479)}


In [72]:
qm.summary()

In [74]:
kmn = load_data.k_mnist()
kmn

Dataset(x_train=(60000, 28, 28), x_test=(10000, 28, 28), y_train=(60000,), y_test=(10000,))

In [76]:
qm.evaluate(kmn.x_test, kmn.y_test, batch_size=2 ** 10, return_dict=True)

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 63ms/step - accuracy: 0.9007 - loss: 1.3936 - top3: 0.9703


{'accuracy': 0.9010000228881836,
 'loss': 1.4052506685256958,
 'top3': 0.9699000120162964}

In [79]:
qtiso = loader.find({'quadratic-pool-kernel': 'quadratic-iso'})
print(qtiso.metrics)
qmiso = qtiso.load_model()

{'accuracy': np.float64(0.9967222213745117), 'loss': np.float64(0.011933539683620134), 'top3': np.float64(0.9997936487197876), 'val_accuracy': np.float64(0.9691296418507894), 'val_loss': np.float64(0.1483689248561859), 'val_top3': np.float64(0.9940184950828552)}


  saveable.load_own_variables(weights_store.get(inner_path))


In [80]:
qmiso.evaluate(kmn.x_test, kmn.y_test, batch_size=2 ** 10, return_dict=True)

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 70ms/step - accuracy: 0.9267 - loss: 0.3576 - top3: 0.9827


{'accuracy': 0.9254999756813049,
 'loss': 0.3601137697696686,
 'top3': 0.9805999994277954}