In [1]:
import warnings
from time import time

import matplotlib.pyplot as plt
import numpy as np
from skopt import gp_minimize

from utils import load_dataset, cost, probas_pred, labels_pred, accuracy

In [2]:
# Sample count handed to load_dataset (presumably per class — confirm in utils).
n_samples = 50
neg_images, pos_images = load_dataset(n_samples)

# Stack negatives first, then positives, along the first axis.
images = np.concatenate([neg_images, pos_images], axis=0)

# Label 0 for every negative image and 1 for every positive one. The sizes are
# taken from the arrays actually returned, so the labels stay aligned with
# `images` even if load_dataset yields a different count than requested.
targets = np.concatenate(
    [np.zeros(len(neg_images)), np.ones(len(pos_images))], axis=0)

In [None]:
# Hand-picked two-element parameter sets used only to time one cost() call.
times, pulses = [0.4, 0.5], [0.2, 0.16]

# Bracket a single evaluation with wall-clock timestamps; the cost value
# itself is discarded, only t0 and t1 are kept for the next cell.
t0 = time()
cost(images, targets, times, pulses)
t1 = time()

In [None]:
# Elapsed wall-clock seconds between the two time() stamps taken around cost().
t1-t0

In [None]:
# Silence all warnings so the optimisation output stays readable.
# NOTE(review): this is a blanket, process-wide filter that also hides
# warnings raised by later cells — narrow it to the offending category
# once that is known.
warnings.filterwarnings('ignore')

# Fixed seed handed to gp_minimize so repeated runs give the same search.
RANDOM_STATE = 0


def get_score(param):
    """Objective for gp_minimize: evaluate ``cost`` for one flat parameter vector.

    Parameters
    ----------
    param : sequence of float
        Flat parameter vector. Its first half is passed to ``cost`` as
        ``times`` and its second half as ``pulses``.

    Returns
    -------
    The value returned by ``cost(images, targets, times, pulses)``, using the
    notebook-level ``images`` and ``targets``.
    """
    middle = len(param) // 2
    times = param[:middle]
    pulses = param[middle:]
    return cost(images, targets, times, pulses)


# Search space: three dimensions in [0.02, 2.0] followed by three in
# [0.02, 0.5]. get_score splits the vector in half, so the first three
# become `times` and the last three become `pulses`.
bounds = [(0.02, 2.)] * 3 + [(0.02, .5)] * 3

# NOTE(review): per the skopt documentation n_jobs only applies when
# acq_optimizer='lbfgs'; with 'sampling' it is presumably a no-op — confirm.
opt_result = gp_minimize(
    get_score,
    bounds,
    acq_func='LCB',
    n_calls=60,
    n_initial_points=30,
    acq_optimizer='sampling',
    n_points=5000,
    kappa=4,
    n_jobs=-1,
    random_state=RANDOM_STATE)

In [8]:
# Split the best parameter vector in half: the leading half is used as
# `times`, the trailing half as `pulses`.
middle = len(opt_result.x) // 2
times, pulses = opt_result.x[:middle], opt_result.x[middle:]

# Feed the output of probas_pred straight into labels_pred.
preds = labels_pred(probas_pred(images, times, pulses))

In [9]:
# Score the predicted labels against the targets (presumably the fraction of
# correct predictions — confirm in utils.accuracy).
# NOTE(review): targets are built as half zeros and half ones, so a score of
# 0.5 would be no better than always predicting one class.
accuracy(preds, targets)

0.5

In [10]:
# Objective value recorded for each gp_minimize evaluation of get_score.
opt_result.func_vals

array([0.74007647, 0.73789975, 0.72705463, 0.72929529, 0.74506476,
       0.73838496, 0.72611979, 0.74343145, 0.72559518, 0.72697321,
       0.74130231, 0.73498838, 0.72917275, 0.74323319, 0.74107601,
       0.72485773, 0.74883014, 0.73576748, 0.74255955, 0.7376751 ,
       0.73382555, 0.74237822, 0.73110586, 0.7278939 , 0.72441937,
       0.74011531, 0.74720144, 0.72520462, 0.73362927, 0.72450188,
       0.72389355, 0.74533769, 0.72679333, 0.72495484, 0.72405248,
       0.73575743, 0.72624295, 0.73000888, 0.78479842, 0.74221572,
       0.72490883, 0.72708609, 0.73090508, 0.72531271, 0.72538393,
       0.73351552, 0.72559763, 0.73522148, 0.72501044, 0.72447329,
       0.72742944, 0.7240879 , 0.72519716, 0.72340346, 0.72500191,
       0.7277984 , 0.72976101, 0.72539308, 0.7243163 , 0.73240758])

In [11]:
# Parameter vector at the minimum found by gp_minimize; the first half is
# read as `times` and the second half as `pulses` elsewhere in this notebook.
opt_result.x

[0.299475525071482, 1.1637895578078652, 0.495942336643321, 0.03883882301248237]