In [None]:
import time
import os
import numpy as np
import seaborn as sns
import numpy.random as rng
import matplotlib.pyplot as plt

from tensorflow.keras.optimizers import Adam
from keras.optimizers import *
from keras import backend as K
K.set_image_data_format('channels_last')

from ml_siamese import OneshotLoader

%matplotlib inline
%load_ext autoreload
%reload_ext autoreload

In [None]:
# ---- Data selection -------------------------------------------------------
# Spectrogram representations that can be selected as model input.
input_image_types = ["mel", "chroma", "mfcc", "sfft", "sfftchroma", "spectrogram"]
# To pick one of the representations above at random, use instead:
# spectrogram_type = input_image_types[rng.randint(0, len(input_image_types))]
spectrogram_type = "old"
# The trailing "" keeps the trailing path separator, e.g. "../data/images/old/".
data_path = os.path.join("..", "data", "images", spectrogram_type, "")

train_folder = "training"
val_folder = "test"

# Best-validation weights are checkpointed here by the training loop.
weights_path = os.path.join("../models", spectrogram_type + "_model_weights.h5")

# ---- Training schedule ----------------------------------------------------
evaluate_every = 100  # interval (iterations) for evaluating on one-shot tasks
loss_every = 10       # interval (iterations) for printing the training loss
batch_size = 2
n_iterations = 1000
N_way = 8             # how many classes per one-shot test task
n_val = 7             # how many one-shot tasks to validate on

# Seed numpy's global RNG so batch/task sampling is repeatable across runs.
# NOTE(review): assumes OneshotLoader samples via numpy.random — confirm;
# TensorFlow's own weight initialisation is not covered by this seed.
random_seed = 0
rng.seed(random_seed)

# ---- Model compilation ----------------------------------------------------
param_loss_function = "binary_crossentropy"
# `lr` is a deprecated alias that newer Keras releases reject; `learning_rate`
# is the supported keyword.
param_optimizer = Adam(learning_rate=0.00006)

# Allow a duplicate Intel OpenMP runtime (libiomp) to be loaded instead of
# aborting. NOTE(review): the OpenMP runtime reads this variable when it is
# initialised, so it is likely only effective if set before TensorFlow/numpy
# are imported — consider moving it to the top of the imports cell.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
loader = OneshotLoader(data_path)

In [None]:
# Sample 1000 initial weights from the loader's weight initializer and plot
# their distribution as a sanity check (expected: mean 0.0, std 1e-2).
weights = loader.initialize_weights((1000, 1))
weights_flat = np.asarray(weights).ravel()
# `sns.distplot` is deprecated (removed in recent seaborn); `histplot` with a
# KDE overlay and density scaling draws the equivalent figure.
ax = sns.histplot(weights_flat, kde=True, stat="density")
# Report the statistics actually observed rather than hard-coding them.
ax.set(
    title="Initialized weights: observed mean {:.3f}, std {:.4f}".format(
        weights_flat.mean(), weights_flat.std()
    ),
    xlabel="weight value",
)
plt.show()

In [None]:
# Sample 1000 initial biases from the loader's bias initializer and plot their
# distribution as a sanity check (expected: mean 0.5, std 1e-2).
# NOTE(review): confirm the expected values against OneshotLoader.initialize_bias;
# the title shows the observed sample statistics so it cannot contradict them.
bias = loader.initialize_bias((1000, 1))
bias_flat = np.asarray(bias).ravel()
# `sns.distplot` is deprecated (removed in recent seaborn); `histplot` with a
# KDE overlay and density scaling draws the equivalent figure.
ax = sns.histplot(bias_flat, kde=True, stat="density")
ax.set(
    title="Initialized biases: observed mean {:.3f}, std {:.4f}".format(
        bias_flat.mean(), bias_flat.std()
    ),
    xlabel="bias value",
)
plt.show()

In [None]:
# Build the siamese network for 100x100x3 (channels-last) inputs, compile it
# with the loss and optimizer from the configuration cell, and show its layers.
model = loader.get_model((100, 100, 3))
model.compile(
    loss=param_loss_function,
    optimizer=param_optimizer,
    metrics=["accuracy"],
)
model.summary()

In [None]:
# Peek at the loaded training data. Its shape is enough to sanity-check loading
# and avoids dumping the entire array into the notebook output.
X = loader.data["train"]
X.shape

In [None]:
# Train on random batches, periodically evaluating on one-shot tasks and
# checkpointing whenever validation accuracy matches or beats the best so far.
# Afterwards, the best checkpoint from this run is restored into `model`.
best = -1
saved_checkpoint = False
# save_weights fails if the target directory does not exist yet.
os.makedirs(os.path.dirname(weights_path) or ".", exist_ok=True)

print("Starting training process!")
print("-------------------------------------")
t_start = time.time()

# `+ 1` so exactly `n_iterations` batches are trained and the final
# iteration takes part in the loss/evaluation schedules.
for i in range(1, n_iterations + 1):
    inputs, targets = loader.get_batch(batch_size)
    loss = model.train_on_batch(inputs, targets)

    if i % loss_every == 0:
        print("Iteration {0}, loss: {1}".format(i, loss))

    if i % evaluate_every == 0:
        print("Time for {0} iterations: {1}".format(i, time.time() - t_start))
        val_acc = loader.test(model, N_way, n_val, verbose=True)
        if val_acc >= best:
            print("Current best: {0}, previous best: {1}".format(val_acc, best))
            print("Saving weights to: {0} \n".format(weights_path))
            model.save_weights(weights_path)
            saved_checkpoint = True
            best = val_acc

# Only restore when this run wrote a checkpoint; otherwise the file may be
# missing (crash) or stale from an earlier run.
if saved_checkpoint:
    model.load_weights(weights_path)
else:
    print("No checkpoint saved this run (n_iterations < evaluate_every); keeping final weights.")

In [None]:
def nearest_neighbour_correct(pairs, targets):
    """
    Return 1 if a nearest-neighbour baseline solves a one-shot task, else 0.

    The task is given as ``pairs = [left, right]``, where ``left[i]`` and
    ``right[i]`` are the two inputs of candidate pair ``i``. The entry of
    ``targets`` with the largest value marks the correct pair. The baseline
    predicts the pair whose two inputs are closest in Euclidean (L2) distance.

    Returns:
        1 if the closest pair is the target pair, 0 otherwise.
    """
    targets = np.asarray(targets)
    n_candidates = len(targets)
    # Flatten each candidate's input so inputs of any shape compare element-wise.
    # A float dtype keeps distances from being truncated when `targets` is int.
    left = np.asarray(pairs[0][:n_candidates], dtype=float).reshape(n_candidates, -1)
    right = np.asarray(pairs[1][:n_candidates], dtype=float).reshape(n_candidates, -1)
    # L2 distance is sqrt(sum((a - b)^2)), not sqrt(a^2 - b^2).
    l2_distances = np.sqrt(np.sum((left - right) ** 2, axis=1))
    return int(np.argmin(l2_distances) == np.argmax(targets))


def test_nn_accuracy(N_ways,n_trials,loader):
    """
    Returns accuracy of one shot
    """
    print("Evaluating nearest neighbour on {} unique {} way one-shot learning tasks ...".format(n_trials,N_ways))

    n_right = 0
    
    for i in range(n_trials):
        pairs,targets = loader.make_task(N_ways,"val")
        correct = nearest_neighbour_correct(pairs,targets)
        n_right += correct
    return 100.0 * n_right / n_trials

# Compare the siamese network (on the val and train splits) against the
# nearest-neighbour baseline and random guessing, for 1..N_way-way tasks.
ways = np.arange(1, N_way + 1)
resume = False  # NOTE(review): not used in this cell
trials = 100
val_accs, train_accs, nn_accs = [], [], []
for N in ways:
    val_accs.append(loader.test(model, N, trials, "val", verbose=True))
    train_accs.append(loader.test(model, N, trials, "train", verbose=True))
    nn_accs.append(test_nn_accuracy(N, trials, loader))

fig, ax = plt.subplots()
ax.plot(ways, val_accs, "m", label="Siamese (val)")
ax.plot(ways, train_accs, "y", label="Siamese (train)")
ax.plot(ways, nn_accs, "c", label="Nearest neighbour (val)")
# With N candidates, random guessing is correct 1/N of the time.
ax.plot(ways, 100.0 / ways, "r", label="Random guessing")
ax.set(
    title="One-shot accuracy vs. number of ways",
    xlabel="N-way",
    ylabel="Accuracy (%)",
)
ax.legend()
plt.show()
