In [None]:
import os

# Configure the CUDA caching allocator BEFORE importing torch or any project
# module that might touch CUDA: the allocator picks this setting up when it
# initialises, so assigning it after CUDA work has started has no effect.
# expandable_segments reduces fragmentation-driven OOMs across many runs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from utils import set_up_logging
from training import train, random_hparam_search
from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH

set_up_logging(LOGS_PATH)

# Prefer the first GPU when one is visible; otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
f"Using device {device}"

In [None]:
# Single training run with one fixed hyperparameter set. Disabled: the random
# hyperparameter search in the next cell is used instead. Uncomment to train
# just this configuration.
# NOTE(review): `hparams` is passed both positionally and unpacked via
# `**hparams` below; confirm `train`'s signature really expects both forms.
# hparams = {
#     "batch_size": 64,
#     "edit_count": 12,
#     "bin_count": 16,
#     "learning_rate": 0.001,
#     "scheduler_gamma": 0.9,
#     "num_epochs": 12,
#     "model_type": "SimpleCNN",
# }

# train(
#     hparams,
#     train_data_paths=TRAIN_DATA,
#     test_data_paths=TEST_DATA,
#     log_dir=RUNS_PATH,
#     max_duration=None,
#     use_tqdm=True,
#     device=device,
#     **hparams
# )

In [None]:
from scipy.stats import loguniform, uniform, randint
from models import MODELS, test_models


# Search space handed to `random_hparam_search`. Each value is either a list of
# discrete choices or a frozen `scipy.stats` distribution to sample from.
# (Presumably one dict per independent sub-space, sklearn `ParameterSampler`
# style — confirm in `random_hparam_search`.)
#
# NOTE: `scipy.stats.uniform` is parameterised as (loc, scale) and samples from
# [loc, loc + scale] — NOT (low, high). Keywords are used everywhere so the
# actual ranges are explicit.
hyperparameters = [
    {
        "batch_size": [32, 64, 128],
        "edit_count": [12],
        "bin_count": [16],
        "learning_rate": loguniform(5e-4, 5e-3),  # log-uniform on [5e-4, 5e-3]
        "scheduler_gamma": uniform(loc=0.8, scale=0.15),  # [0.80, 0.95]
        "num_epochs": [12],
        "elu_alpha": uniform(loc=0.5, scale=1.0),  # [0.5, 1.5]
        "leaky_relu_slope": uniform(loc=0, scale=0.03),  # [0, 0.03]
        "dropout_prob": uniform(loc=0, scale=0.1),  # [0, 0.1]
        # Per-layer channel widths; the list length is the network depth.
        "features": [
            [16, 32],
            [16, 32, 64],
            [16, 32, 64, 128],
            [32, 64],
            [32, 64, 128],
            [8, 16, 32],
            [8, 8, 8],
            [8, 8, 8, 8, 8],
            [8, 8, 8, 8, 8, 8, 8],
            [16, 16, 16, 16, 16],
            [16, 16, 16],
            [32, 32],
            [32, 32, 32],
            [32, 32, 32, 32],
            [64, 64],
            [64, 64, 64]
        ],
        "use_residual": [True, False],
        "kernel_size": [3, 5],
        "model_type": ["HistogramNet"],
        "use_instance_norm": [True, False],
        "use_elu": [True, False],
        # NOTE(review): both `leaky_relu_slope` (above) and `leaky_relu_alpha`
        # are sampled and look like the same knob under two names. Confirm
        # which one HistogramNet reads; the other only adds noise to the
        # logged search results.
        "leaky_relu_alpha": uniform(loc=0, scale=0.05),  # [0, 0.05]
    }
]

# Quick check of the model definitions before committing to a long search
# (`test_models` comes from the project's `models` module; presumably it fails
# fast if an architecture is broken).
test_models()

# Everything the random search needs, gathered in one place so the run's
# configuration is easy to inspect or tweak before launching.
search_config = {
    "hyperparameters": hyperparameters,
    "train_data_paths": TRAIN_DATA,
    "test_data_paths": TEST_DATA,
    "models_path": MODELS_PATH,
    "tensorboard_path": RUNS_PATH,
    "timeout_hours": 4,  # overall search budget, in hours
    "device": device,
}
random_hparam_search(**search_config)