In [None]:
# (Re)load IPython's autoreload extension and switch it to mode 2 so that
# edits to imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

import numpy as np
import os
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from subtle.dnn.generators import GeneratorUNet2D
from subtle.data_loaders import SliceLoader
import subtle.subtle_loss as suloss
import matplotlib.pyplot as plt

import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.integration.keras import TuneReporterCallback

# Start Ray for this kernel. `ignore_reinit_error=True` makes the cell safe to
# re-run: without it a second `ray.init()` in the same kernel raises because
# Ray is already initialised.
ray.init(ignore_reinit_error=True)

# Notebook-wide matplotlib defaults: grayscale colormap and a 12x10 figure.
plt.set_cmap('gray')
plt.rcParams['figure.figsize'] = (12, 10)

# Uncomment to restrict which GPU is visible to this kernel.
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def show_img(img, title='', axis=False, vmin=None, vmax=None):
    """Draw ``img`` with ``plt.imshow`` on the current matplotlib axes.

    Args:
        img: Image data passed straight through to ``plt.imshow``.
        title: Text shown above the image (rendered at font size 15).
        axis: Truthy to draw the axes, falsy (default) to hide them.
        vmin: Lower display limit forwarded to ``plt.imshow``. Left out of the
            call entirely when ``None``.
        vmax: Upper display limit forwarded to ``plt.imshow``. Left out of the
            call entirely when ``None``.
    """
    # Forward only the limits the caller actually supplied. The check is
    # against None (not truthiness) so that an explicit limit of 0 is honoured.
    imshow_args = {
        name: value
        for name, value in (('vmin', vmin), ('vmax', vmax))
        if value is not None
    }

    plt.axis('on' if axis else 'off')
    plt.imshow(img, **imshow_args)
    plt.title(title, fontsize=15)

In [None]:
# Directory containing the .h5 inputs. Kept in one place so the notebook can
# be pointed at a different location by editing a single constant.
DATA_DIR = '/home/srivathsa/projects/studies/gad/tiantan/preprocess/data'

# Input files handed to SliceLoader inside train().
fpaths_h5 = [
    os.path.join(DATA_DIR, '{}.h5'.format(case_id))
    for case_id in ('NO26', 'NO27')
]

def train(params, reporter):
    """Fit one GeneratorUNet2D configuration for a single tuning trial.

    Args:
        params: Mapping that must provide ``'l1_lambda'``, ``'batch_size'``
            and ``'lr_init'``. The SSIM weight of the mixed loss is taken as
            ``1 - params['l1_lambda']``.
        reporter: Object wrapped in a ``TuneReporterCallback`` and passed to
            ``fit`` as its only callback.

    Returns:
        None. Nothing is returned from the fit; results reach the caller only
        through the callback.
    """
    l1_weight = params['l1_lambda']
    mixed = suloss.mixed_loss(l1_lambda=l1_weight, ssim_lambda=1 - l1_weight)
    monitored = [suloss.l1_loss, suloss.ssim_loss, suloss.mse_loss]

    loader = SliceLoader(
        data_list=fpaths_h5,
        batch_size=params['batch_size'],
        shuffle=False,
        verbose=0,
        slices_per_input=7,
        resize=240,
        slice_axis=[0],
    )

    # NOTE(review): num_channel_input=14 looks like 2 x slices_per_input (7);
    # confirm against SliceLoader before changing either value.
    net = GeneratorUNet2D(
        num_channel_output=1,
        loss_function=mixed,
        metrics_monitor=monitored,
        verbose=0,
        lr_init=params['lr_init'],
        img_rows=240,
        img_cols=240,
        num_channel_input=14,
        compile_model=True,
    )
    net.load_weights()

    # Item 7 of the loader is used for fitting and item 14 for validation
    # (presumably batch indices -- verify against SliceLoader).
    train_X, train_Y = loader[7]
    val_X, val_Y = loader[14]

    net.model.fit(
        train_X,
        train_Y,
        validation_data=(val_X, val_Y),
        verbose=0,
        callbacks=[TuneReporterCallback(reporter)],
    )

In [None]:
# Earlier hyperopt-based search, kept for reference only (not executed):
# trials = Trials()
# space = {
#     'l1_lambda': hp.uniform('l1_lambda', 0, 1),
#     'ssim_lambda': hp.uniform('ssim_lambda', 0, 1),
#     'batch_size': hp.choice('batch_size', [4, 8, 12]),
#     'lr_init': hp.uniform('lr_init', 0.001, 0.1)
# }

# best_model = fmin(train, space, algo=tpe.suggest, max_evals=10, verbose=1, trials=trials)


# NOTE(review): `sched` is constructed here but is not passed to `tune.run`
# below, so it currently has no effect on the experiment. Confirm whether
# `scheduler=sched` was intended and whether "val_l1_loss" is actually present
# in the results reported by the trial before wiring it in.
sched = AsyncHyperBandScheduler(metric="val_l1_loss", mode="min")

# Each trial asks for one GPU and no CPU.
trial_resources = dict(cpu=0, gpu=1)

# Per-trial hyperparameters; every sampled entry is drawn from numpy's global
# RNG when the trial configuration is generated.
search_space = {
    "num_workers": 0,
    "l1_lambda": tune.sample_from(lambda spec: np.random.uniform(0, 1)),
    "batch_size": tune.sample_from(lambda spec: [4, 8, 12][np.random.randint(3)]),
    "lr_init": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)),
}

tune.run(
    train,
    name="exp",
    num_samples=2,
    resources_per_trial=trial_resources,
    config=search_space,
)