In [57]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from itertools import product, chain
import importlib

import trainer, data, models
from data import gen_lin_data
from trainer import ModelingDataset, Trainer, Hyperparameters
from models import LinearRegression

In [12]:
# Reload project modules after editing them. Leaf modules are reloaded first so
# that a module importing from them picks up the fresh versions — assumed order:
# trainer imports from models/data (TODO confirm against trainer.py).
for module in (data, models, trainer):
    importlib.reload(module)

# `importlib.reload` re-executes a module in place but does NOT update names
# that were copied into this namespace with `from ... import ...`; re-bind them
# so the cells below use the reloaded classes/functions.
from data import gen_lin_data
from trainer import ModelingDataset, Trainer, Hyperparameters
from models import LinearRegression

<module 'models' from '/Users/ralph/projects/d2dl/models.py'>

In [13]:
# Seed first so the synthetic data (and every downstream number) is
# reproducible on Restart & Run All.
torch.manual_seed(0)

# Draw the ground-truth parameters ONCE and reuse them for both splits.
# Previously each split called torch.randn(3) separately, so the validation data
# came from a different linear function than the training data.
# NOTE(review): assumes gen_lin_data(weights, bias, num_samples) — confirm in data.py.
true_w = torch.randn(3)
true_b = 1.0
train_X, train_y = gen_lin_data(true_w, true_b, 8000)
val_X, val_y = gen_lin_data(true_w, true_b, 2000)

In [14]:
# Pair each split's feature tensor with its target tensor (train first, then val).
train_dataset, val_dataset = (
    TensorDataset(features, targets)
    for features, targets in ((train_X, train_y), (val_X, val_y))
)

In [75]:
def grid_search_params(hyperparam_grids: Hyperparameters):
    """Expand a hyperparameter grid into one ``Hyperparameters`` per combination.

    Args:
        hyperparam_grids: A ``Hyperparameters`` whose attributes (namespaces,
            e.g. ``opt`` / ``general`` in this notebook) are mappings from a
            hyperparameter name to the list of candidate values to try.

    Returns:
        A list with one freshly constructed ``Hyperparameters()`` for every
        element of the Cartesian product of all candidate lists, in
        ``itertools.product`` order (the last key varies fastest). Each chosen
        value is written into ``getattr(result, namespace)[name]``. A grid with
        no keys at all yields a single default ``Hyperparameters()``.
    """
    # Flatten {namespace: {name: candidates}} into ((namespace, name), candidates)
    # pairs so the product can be taken over a flat list of candidate lists.
    flattened = [
        ((namespace, name), candidates)
        for namespace, grid in hyperparam_grids.__dict__.items()
        for name, candidates in grid.items()
    ]
    addresses = [address for address, _ in flattened]
    candidate_lists = [candidates for _, candidates in flattened]

    combos = []
    for values in product(*candidate_lists):
        combo = Hyperparameters()
        # NOTE(review): this writes into the dicts of a default-constructed
        # Hyperparameters(); if its defaults were shared mutable dicts, every
        # combination would alias the same dict — confirm it builds fresh ones.
        for (namespace, name), value in zip(addresses, values):
            getattr(combo, namespace)[name] = value
        combos.append(combo)

    return combos

In [149]:
def grid_search(hyperparam_grid, trainer_provider):
    """Train one model per hyperparameter combination and pick the best.

    Args:
        hyperparam_grid: Sized sequence of hyperparameter sets (for example the
            list returned by ``grid_search_params``).
        trainer_provider: Callable mapping one hyperparameter set to an object
            with a ``train`` method. ``train(10)`` must return something
            indexable by ``"val"`` whose value supports ``.min()`` — presumably
            a pandas Series of per-epoch validation losses (TODO confirm
            against ``Trainer.train``).

    Returns:
        ``(index, loss)``: the position in ``hyperparam_grid`` of the
        combination with the lowest validation loss, and that loss. Ties go to
        the earliest combination. Runs whose best loss is NaN (for example
        diverged training) are ranked last instead of silently winning.

    Raises:
        ValueError: if ``hyperparam_grid`` is empty.
    """
    import math

    total = len(hyperparam_grid)
    if total == 0:
        raise ValueError("grid_search() needs at least one hyperparameter combination")

    best_losses = []
    for i, hyperparams in enumerate(hyperparam_grid):
        print(f"Progress: {i}/{total}")
        model_trainer = trainer_provider(hyperparams)

        # NOTE(review): the argument is fixed at 10, while this notebook's grid
        # sweeps general["num_epochs"] — confirm whether Trainer reads
        # num_epochs from its hyperparameters or whether it belongs here.
        train_res = model_trainer.train(10)
        # Keep only each run's best validation loss rather than every result.
        best_losses.append(train_res["val"].min())

    def _rank(idx):
        # NaN compares False with everything, so a NaN seen first would win a
        # plain min(); map it to +inf so it always loses.
        loss = best_losses[idx]
        return math.inf if math.isnan(loss) else loss

    best_idx = min(range(total), key=_rank)
    return best_idx, best_losses[best_idx]

In [148]:
# Candidate values per namespace; grid_search_params expands them into every
# combination (5 lr * 3 weight_decay * 3 batch_size * 5 num_epochs = 225 runs).
hyperparam_grid = grid_search_params(Hyperparameters(
    opt=dict(
        # Evenly spaced learning rates 1e-3 .. 5e-3 (step 1e-3).
        lr=[1e-3, 2e-3, 3e-3, 4e-3, 5e-3],
        weight_decay=[0, 0.5, 1],
    ),
    general=dict(
        batch_size=[10, 25, 50],
        num_epochs=list(range(10, 51, 10)),  # 10, 20, 30, 40, 50
    ),
))

In [150]:
def provider(hyperparams):
    """Build a fresh ``Trainer`` for one hyperparameter combination.

    Used as the ``trainer_provider`` argument of ``grid_search``. A new
    ``LinearRegression`` and ``ModelingDataset`` are created on every call, so
    runs do not share model state.

    ``nn.MSELoss`` and ``torch.optim.SGD`` are passed as classes, not
    instances — presumably ``Trainer`` instantiates them (the optimizer with
    the ``opt`` hyperparameters); TODO confirm against ``trainer.Trainer``.
    """
    return Trainer(
        model=LinearRegression(),
        dataset=ModelingDataset(train_dataset, val_dataset),
        loss=nn.MSELoss,
        opt=torch.optim.SGD,
        hyperparameters=hyperparams,
    )

In [None]:
# Run the sweep; `res` is (index into hyperparam_grid of the best combination,
# its lowest validation loss).
res = grid_search(
    hyperparam_grid,
    provider,
)