# Supervised classifier parameter sweep on synthetic data

In [None]:
# Load IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before
# executing each cell.
%autoreload 2

In [None]:
import sys

# Make the parent directory importable; the project imports in this notebook
# (autoplan, scripts, grammars) are presumably resolved from there.
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate '..' entry to sys.path each time.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
from autoplan.trainer import ClassifierTrainer, option_combinations
from autoplan.dataset import PrelabeledDataset, build_synthetic_dataset
from autoplan.generator import ProgramGenerator
from autoplan.vis import plot_accuracy, plot_cm, plot_loss
from autoplan.token import TokenType, PyretTokenizer, OCamlTokenizer
from scripts.rainfall_ingest import ingest_dataset

from grammars.rainfall.ocaml import Program
from grammars.rainfall.labels import GeneralRainfallLabels

from tqdm import tqdm_notebook as tqdm
import pandas as pd
import torch
import os
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
# import seaborn as sns

# Device handed to ClassifierTrainer.crossval below. CPU is the default;
# swap in the commented-out line to run on CUDA device 0 instead.
device = torch.device('cpu')
# device = torch.device('cuda:0')
# '~/autoplan' expanded against the current user's home directory
# (presumably the repository checkout).
# NOTE(review): not referenced by any cell shown here -- confirm it is still needed.
REPO_DIR = os.path.expanduser('~/autoplan')

# Parameter Tuning

In [None]:
# Identifier of the dataset to ingest.
dataset_name = 'T1'
# Ingest the named dataset. Its vocab_index is passed to build_synthetic_dataset
# below (presumably so the synthetic data shares the same vocabulary -- confirm).
student_dataset = ingest_dataset(dataset_name)

In [None]:
# Generation options to sweep, as (option name, candidate values) pairs.
# Each combination is forwarded to build_synthetic_dataset as keyword
# arguments, so 'N' arrives as N=<value> (presumably the number of synthetic
# programs to generate -- confirm against build_synthetic_dataset).
# Larger sweep for a full-scale run:
#     ('N', [100, 300, 500, 1000, 5000, 10000])
sample_size = [
    ('N', [5, 10, 15, 20, 25, 30]),
]


def _synthesize(gen_opts):
    """Build one synthetic dataset from the keyword options in *gen_opts*.

    A fresh OCaml tokenizer and an adaptive program generator are created for
    every call; the vocabulary index is taken from ``student_dataset``.
    """
    return build_synthetic_dataset(
        GeneralRainfallLabels,
        tokenizer=OCamlTokenizer(),
        generator=ProgramGenerator(grammar=Program(), adaptive=True),
        vocab_index=student_dataset.vocab_index,
        **gen_opts)


# One (options, dataset) pair per combination of the swept options, so each
# dataset stays associated with the options that produced it.
datasets = [
    (gen_opts, _synthesize(gen_opts))
    for gen_opts in option_combinations(sample_size)
]

In [None]:
# Hyper-parameter grid as (option name, candidate values) pairs. The 'dataset'
# axis holds the (options, dataset) pairs built above; the remaining axes are
# forwarded to the trainer as model options.
model_options = [
    ('dataset', datasets),
    ('model', [nn.RNN, nn.GRU, nn.LSTM]),
    ('hidden_size', [32, 128, 512, 2048]),
    ('embedding_size', [32, 128, 512, 2048]),
]


def _cross_validate(combo):
    """Run cross-validation for a single option combination.

    ``combo['dataset']`` is an (options, dataset) pair; only the dataset
    (index 1) is handed to the trainer. Every other entry of *combo* is
    passed through unchanged as ``model_opts``.
    """
    return ClassifierTrainer.crossval(
        combo['dataset'][1],
        k=10,
        epochs=100,
        val_frac=0.33,
        device=device,
        model_opts={
            name: value
            for name, value in combo.items()
            if name != 'dataset'
        })


# (options, cross-validation result) for every combination in the grid,
# with a notebook progress bar over the sweep.
all_evals = [
    (combo, _cross_validate(combo))
    for combo in tqdm(option_combinations(model_options))
]

In [None]:
# Rank every evaluated configuration by the mean of its 'accuracy' values
# and display the ten best. Negating the sort key gives descending order.
sorted(
    (
        {'params': config, 'accuracy': np.mean(result['accuracy'])}
        for config, result in all_evals
    ),
    key=lambda row: -row['accuracy'],
)[:10]

In [None]:
# TODO: add a visualization of the sweep results

In [None]:
# TODO: add conclusions drawn from the sweep