## Hyperparameter Search (Random Search)

In [None]:
# Build a synthetic slicing dataset, split it into train/dev/test, and replace
# the train-split labels with label-model predictions.
from metal.contrib.slicing.synthetics.geometric_synthetics import generate_dataset
from metal.utils import split_data
from metal.contrib.backends.snorkel_gm_wrapper import SnorkelLabelModel

NUM_TRIALS = 20  # NOTE(review): not used in this cell — confirm a later cell needs it
NUM_SLICES = 5  # forwarded to generate_dataset via Z_kwargs['num_slices']
K = 2  # first arg of generate_dataset; presumably the number of classes — confirm
M = 20  # presumably the number of labeling functions (also passed to SliceHatModel below) — confirm
N = 5000  # presumably the number of generated examples — confirm
seed = None  # with None the data is presumably regenerated differently on every run — confirm

Z_kwargs = {'num_slices': NUM_SLICES}
# Returns the label matrix L, features X, labels Y, slice assignments Z and,
# because return_targeting_lfs=True, the indices of the slice-targeting LFs.
L, X, Y, Z, targeting_lfs_idx = generate_dataset(K, M, N, 
                                                 Z_kwargs=Z_kwargs,
                                                 return_targeting_lfs=True,
                                                 seed=seed)

# 50/25/25 shuffled split. Index 0 is the training split and index 1 is the dev
# split (passed as dev_data to the tuner below); index 2 is presumably the test split.
Ls, Xs, Ys, Zs = split_data(L, X, Y, Z, splits=[0.5, 0.25, 0.25], shuffle=True)

# Fit the label model on the training label matrix only, then overwrite the
# training labels with its predicted probabilities (soft labels). The dev and
# test splits keep the labels produced by generate_dataset.
label_model = SnorkelLabelModel()
label_model.train_model(Ls[0])
Y_train = label_model.predict_proba(Ls[0])
Ys[0] = Y_train

In [None]:
# Layer widths for the base EndModel: 2 input features (assumes the geometric
# synthetics produce 2-D points — confirm), two hidden layers of width 10, and
# one output per class. The output width is tied to K so that changing K in the
# data-setup cell cannot silently mismatch the network head.
layer_out_dims = [2, 10, 10, K]

# Hyperparameter space for RandomSearchTuner. Presumably list values are sampled
# as discrete choices, {'range': ..., 'scale': 'log'} entries are sampled on a
# log scale, and scalar values are held fixed — confirm against the tuner docs.
search_space = {
    'n_epochs': [10, 20, 40],
    'lr': {'range': [0.001, 1], 'scale': 'log'},
    # l2 is held fixed at 0. To search it instead, replace the 0 with e.g.
    # {'range': [0.0001, 10], 'scale': 'log'}.
    'l2': 0,
    'slice_weight': [0.01, 0.05, 0.1, 0.2, 0.5],
    # batch_size is deliberately not searched: it is fixed by the DataLoader
    # built for training.
}

In [None]:
import copy

import torch
from torch.utils.data import DataLoader

from metal.tuners import RandomSearchTuner
from metal.contrib.slicing.online_dp import SliceHatModel
from metal.utils import SlicingDataset
from metal.end_model import EndModel

batch_size = 32
# The training label matrix is densified (.todense() suggests a sparse matrix)
# so it can be stored as a torch tensor alongside the features and soft labels.
L_train = torch.Tensor(Ls[0].todense())
dataset = SlicingDataset(Xs[0], L_train, Ys[0])
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Template base network. It is never trained directly: each trial receives its
# own deep copy (see FreshSliceHatModel).
end_model = EndModel(layer_out_dims, verbose=False)


class FreshSliceHatModel(SliceHatModel):
    """SliceHatModel that wraps a private deep copy of its base model.

    The same ``init_args`` list (and so the same ``end_model`` object) is
    handed to the tuner for every sampled configuration. Without the copy,
    every trial would share, and keep training, one set of base-network
    parameters. Later trials would then start from earlier trials' weights,
    biasing the comparison between configurations.
    """

    def __init__(self, base_model, *args, **kwargs):
        # Copy before SliceHatModel can wrap or modify the base network.
        super().__init__(copy.deepcopy(base_model), *args, **kwargs)


tuner = RandomSearchTuner(FreshSliceHatModel, log_dir='checkpoints')
# Try up to 10 sampled configurations, scoring each on the dev split.
_ = tuner.search(
    search_space,
    dev_data=(Xs[1], Ys[1]),
    max_search=10,
    init_args=[end_model, M],
    init_kwargs={'verbose': False},
    train_args=[train_loader],
    train_kwargs={'verbose': True, 'disable_prog_bar': True, 'print_every': 5},
    verbose=True,
)