# This script trains Gaussian naive Bayes classifiers and stores their predictions

In [1]:
import vae_tools.sanity
import vae_tools.viz
import vae_tools.callbacks
import vae_tools.loader
from vae_tools.mmvae import MmVae, ReconstructionLoss
from tensorflow.keras.optimizers import Adam, Nadam, RMSprop
vae_tools.sanity.check()
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Lambda, Layer
from tensorflow.keras.datasets import mnist
import numpy as np
from scipy.stats import norm
# Set the seed for reproducible results
import vae_tools.sampling
vae_tools.sampling.set_seed(0)
# resize the notebook if desired
#vae_tools.nb_tools.notebook_resize()
import matplotlib
import matplotlib.pyplot as plt
from itertools import product
import pandas as pd
import glob
import pickle


python version:  3.5.2
keras version: 2.2.4-tf
tensorflow version: 2.0.2
matplotlib uses:  module://ipykernel.pylab.backend_inline
No GPUs available


## Train the Bayes classifiers

In [None]:

def eval_bayes_classifier(z_train, z_test, y_train, y_test = None):
    """Fit a Gaussian naive Bayes classifier and predict three test sets.

    Args:
        z_train: Training features handed to ``GaussianNB.fit``.
        z_test: Mapping with the keys ``'z_test_ab'``, ``'z_test_a'`` and
            ``'z_test_b'``; each value is handed to ``GaussianNB.predict``.
        y_train: Training labels handed to ``GaussianNB.fit``.
        y_test: Optional ground-truth labels for the test sets. When given,
            the number of mislabeled points per test set is printed. It must
            expose ``.shape`` and support element-wise ``!=`` against the
            predictions (presumably a NumPy array -- confirm with caller).

    Returns:
        Tuple ``(y_pred_ab, y_pred_a, y_pred_b)`` with the predictions for
        ``z_test['z_test_ab']``, ``z_test['z_test_a']`` and
        ``z_test['z_test_b']`` respectively.
    """
    from sklearn.naive_bayes import GaussianNB
    # Train the GNB on the training data
    gnb = GaussianNB().fit(z_train, y_train)
    # Predict the test data using the GNB
    y_pred_ab = gnb.predict(z_test['z_test_ab'])
    y_pred_a = gnb.predict(z_test['z_test_a'])
    y_pred_b = gnb.predict(z_test['z_test_b'])

    # The report needs the labels, so it is only produced when they were
    # supplied; dereferencing a missing y_test would raise AttributeError.
    if y_test is not None:
        print("Test on %d points." %(y_test.shape[0]))
        print("mislabeled in ab : %d" % ((y_test != y_pred_ab).sum()))
        print("mislabeled in a  : %d" % ((y_test != y_pred_a).sum()))
        print("mislabeled in b  : %d" % ((y_test != y_pred_b).sum()))

    return y_pred_ab, y_pred_a, y_pred_b

# Get the models and predict all data
def predict(model_path, x_train = None, x_test = None):
    """Load the encoder models stored under ``model_path`` and encode the data.

    The models come from ``vae_tools.vae.GenericVae.load_model_powerset``
    (called for two modalities). The first three returned encoders are used
    as the encoder for input ``a``, for input ``b`` and for the joint input
    ``ab``, in that order (presumably the powerset ordering of the loader --
    confirm against vae_tools).

    Args:
        model_path: Path (prefix) forwarded to ``load_model_powerset``.
        x_train: Optional pair ``[x_a, x_b]`` of training inputs. The whole
            pair is fed to the joint encoder, the elements to the
            single-input encoders.
        x_test: Optional pair ``[x_a, x_b]`` of test inputs, handled like
            ``x_train``.

    Returns:
        Tuple ``(z_train, z_test)``. Each entry is ``None`` when the matching
        input was not given, otherwise a tuple ``(z_a, z_b, z_ab)`` of the
        encoder outputs.
    """
    num_models = 2
    model_enc, _ = vae_tools.vae.GenericVae.load_model_powerset(model_path, num_models)
    model_enc_a, model_enc_b, model_enc_ab = model_enc[0], model_enc[1], model_enc[2]
    z_train, z_test = None, None
    if x_train is not None:
        z_train_ab = model_enc_ab.predict(x_train)
        z_train_a = model_enc_a.predict(x_train[0])
        z_train_b = model_enc_b.predict(x_train[1])
        z_train = (z_train_a, z_train_b, z_train_ab)
    if x_test is not None:
        z_test_ab = model_enc_ab.predict(x_test)
        z_test_a = model_enc_a.predict(x_test[0])
        z_test_b = model_enc_b.predict(x_test[1])
        z_test = (z_test_a, z_test_b, z_test_ab)
    # Cleanup: drop every local reference to the loaded models before the
    # session is cleared. Deleting only a loop variable would leave the list
    # and the three named references alive.
    del model_enc, model_enc_a, model_enc_b, model_enc_ab, _
    tf.keras.backend.clear_session()
    return z_train, z_test

def run(seed = '0', dump_root = '/mnt/ssd_pcie/mmvae_mnist_split/', num_configs = 60):
    """Train and store naive Bayes predictions for every stored configuration.

    For each configuration index the mean encoders are loaded through
    ``predict``, the split-MNIST data is encoded, and one classifier is
    trained per training encoding (``ab``, ``a``, ``b``). Every classifier
    predicts all three test encodings; the nine prediction arrays are pickled
    to ``<dump_root><seed>/bayes_classifier_<idx>.p``.

    Args:
        seed: Name of the sub-directory of ``dump_root`` holding the models.
            Converted with ``str`` so that an integer seed works as well.
        dump_root: Directory containing one sub-directory per seed. Must end
            with a path separator because it is joined by plain string
            concatenation.
        num_configs: Number of hyper parameter configurations to process;
            indices ``0 .. num_configs - 1`` are used.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    import time
    # The test labels are loaded together with the rest but are not needed
    # here: only the predictions are stored.
    (x_train_a, x_train_b), (x_test_a, x_test_b), y_train, _y_test = vae_tools.loader.mnist_split(flatten = True, split = 'hor')
    dump_loc = dump_root + str(seed) + '/'
    # Process all hyper parameter configurations
    for idx in range(num_configs):
        start = time.time()
        z_train_mean, z_test_mean = predict(dump_loc + 'enc_mean_' + str(idx) + '_ab_', x_train = [x_train_a, x_train_b], x_test = [x_test_a, x_test_b])

        # predict() returns the encodings in the order (a, b, ab)
        z_train_a, z_train_b, z_train_ab = z_train_mean[0], z_train_mean[1], z_train_mean[2]
        z_test_a, z_test_b, z_test_ab = z_test_mean[0], z_test_mean[1], z_test_mean[2]

        z_test_all = {'z_test_ab': z_test_ab, 'z_test_a': z_test_a, 'z_test_b': z_test_b}

        # Store with scheme: <trained on dataset>_<validated on dataset>
        # The iteration order fixes the key order: ab_*, a_*, b_*.
        y_pred = {}
        for trained_on, z_train in (('ab', z_train_ab), ('a', z_train_a), ('b', z_train_b)):
            pred_ab, pred_a, pred_b = eval_bayes_classifier(z_train,
                                                            z_test_all,
                                                            y_train)
            y_pred[trained_on + '_ab'] = pred_ab
            y_pred[trained_on + '_a'] = pred_a
            y_pred[trained_on + '_b'] = pred_b
        end = time.time()
        print(end - start)

        fn = dump_loc + 'bayes_classifier_' + str(idx) + '.p'
        print("write " + fn)
        with open(fn, 'wb') as handle:
            pickle.dump(y_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)





