# Purpose

This notebook tests the implementation of the Ensemble, Flipout, and DUQ models.
- Question: Can the ensembles use dropout and dropconnect weights, or are they to be trained without either?

In [1]:
# Training setup per method:
# Ensemble: model = DeepEnsembleClassifier(lambda: build_standard_model(dropout_best_hps), num_estimators=10), 100 epochs
# DUQ: 200 epochs
# Flipout: 200 epochs

In [1]:
# Expose only GPU index 1 to this kernel. This runs before the cell that imports TensorFlow.
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [3]:
import sys, os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping
from models_bachelors import *
from file_functions import *
import tensorflow as tf
import keras_tuner as kt
from keras_uncertainty.models import StochasticClassifier

In [4]:
# --- Training configuration ---
# Upper bound on epochs per run; early stopping watches `val_loss` with a patience of 10.
n_epochs = 200
early_stopping = EarlyStopping(patience=10, monitor='val_loss')
# One id per subject (0 through 8).
subject_ids = list(range(9))

# --- Load data ---
# `dataset` supplies the model inputs and targets; `lockbox` holds the lockbox
# data that the later cells index by test subject id.
dataset = load('all_subjects_runs_no_bandpass')
loaded_inputs, loaded_targets = dataset['inputs'], dataset['targets']
lockbox = load('lockbox')['data']

In [6]:
# Load the tuned Flipout hyperparameters; `hp` is later passed to the model
# builders in the prediction cell. The recorded output shows a tuner state
# being reloaded from flipout/tuning/.
hp = load_tuned_flipout()

Reloading Tuner from flipout/tuning/flipout_flipout_classification_layer/tuner0.json


In [None]:
# Training loop
#
# Leave-one-subject-out training: for every test subject a separate model is
# trained on the remaining subjects and its weights are saved under
# `<method>/weights/weights_subj_<test_subject_id>`, which is the path the
# prediction cell loads from.
#
# NOTE(review): `methods` is not defined in any cell shown above; it has to be
# provided (earlier cell or star import) before this cell runs — confirm.
for method in methods:
    directory = f'{method}/weights'
    # Make sure the target directory exists before any weights are written.
    os.makedirs(directory, exist_ok=True)
    # This loop leaves one subject for testing (denoted by the number in the name of the weights file).
    # Then it combines all the subject trials such that shape is now (8 * 576, 22, 1125).
    # Then selects 10% of this as the validation set. Then it trains diff. model on each set of train subjects.
    for test_subject_id in subject_ids:
        train_ids = subject_ids[:]
        train_ids.remove(test_subject_id)       # Remove test subject id
        test_subj_lockbox = lockbox[test_subject_id]        # Get lockbox indexes (8, 57) for the test subject
        inputs = loaded_inputs[train_ids]           # Get train set inputs
        targets = loaded_targets[train_ids]         # Get train set targets
        inputs, targets = remove_lockbox(inputs, targets, test_subj_lockbox)    # Remove lockboxed set from train set
        # No random_state is passed, so the train/validation split differs between runs.
        X_train, X_val, Y_train, Y_val = train_test_split(inputs, targets, test_size=0.1)
        # NOTE(review): the prediction cell calls build_flipout_model(hp) while this
        # call passes X_train — confirm which argument the builder expects.
        model = build_flipout_model(X_train)
        # validation_data is passed as a tuple, the form documented by Keras.
        history = model.fit(X_train, Y_train, epochs=n_epochs, validation_data=(X_val, Y_val),
                        callbacks=[early_stopping])
        # Previously saved to `<method>/weights_subj_<id>`, which the prediction
        # cell never reads; the file now goes inside the weights directory.
        model.save_weights(f'{directory}/weights_subj_{test_subject_id}')

# Predicting

In [None]:

# Each iteration collects predictions for every method listed in `predictions`
# and writes the resulting dict to predictions/predictions_flipout.h5.
# Where each method's weights directory lives, when it is not `<method>/weights`.
weight_dir_overrides = {
    'standard': 'mcdropout/weights',
    'standard_dropconnect': 'mcdropconnect/weights',
}
# Model constructors per method; anything not listed falls back to the standard model.
# The lambdas defer name lookup, so hyperparameters of unused methods need not exist.
model_builders = {
    'mcdropout': lambda: build_dropout_model(dropout_best_hps),
    'mcdropconnect': lambda: build_dropconnect_model(dropconnect_best_hps),
    'standard_dropconnect': lambda: build_standard_model_dropconnect(dropconnect_best_hps),
    'duq': lambda: build_duq_model(hp),
    'flipout': lambda: build_flipout_model(hp),
}
build_default_model = lambda: build_standard_model(dropout_best_hps)
# Methods that are wrapped in StochasticClassifier and sampled repeatedly.
stochastic_methods = ('mcdropconnect', 'mcdropout', 'flipout')

for iteration in range(1):
    # Layout: predictions[method][split][field], with split in {'test', 'lockbox'}
    # and field in {'preds', 'labels'}. Further methods handled below (e.g. 'standard',
    # 'standard_dropconnect', 'duq') can be added as extra keys with the same layout.
    predictions = {
        'flipout': {
            split: {'preds': [], 'labels': []}
            for split in ('test', 'lockbox')
        }
    }

    for method, values in predictions.items():
        print(f'{method}')
        wts_directory = weight_dir_overrides.get(method, f'{method}/weights')

        # Iterate through test subjects
        for test_subject_id in range(9):
            print(f'test subject {test_subject_id}')
            train_subj_ids = [sid for sid in subject_ids if sid != test_subject_id]
            X_test, Y_true = loaded_inputs[test_subject_id], loaded_targets[test_subject_id]

            # The lockbox set is drawn from the training subjects, not the whole dataset:
            # per the original author's notes, lockbox has shape (9, 8, 57) and inputs
            # (9, 576, 22, 1125), where lockbox axis 0 is the test subject id and axis 1
            # the training subjects. get_lockbox_data therefore receives the 8 training
            # subjects' data together with lockbox[test_subject_id].
            X_lock, Y_lock = get_lockbox_data(
                loaded_inputs[train_subj_ids],
                loaded_targets[train_subj_ids],
                lockbox[test_subject_id],
            )

            # `checkpoint_path` is kept as an alias exactly as in the original cell.
            wts_path = checkpoint_path = f'{wts_directory}/weights_subj_{test_subject_id}'

            model = model_builders.get(method, build_default_model)()
            model.load_weights(wts_path).expect_partial()

            if method in stochastic_methods:
                # Draw 50 stochastic forward passes for both the test and lockbox sets.
                model = StochasticClassifier(model)
                Y_preds = model.predict_samples(X_test, num_samples=50)
                lockbox_Y_preds = model.predict_samples(X_lock, num_samples=50)
            else:
                # Deterministic models: a single forward pass per set.
                Y_preds = model.predict(X_test)
                lockbox_Y_preds = model.predict(X_lock)

            lockbox_Y_true = Y_lock
            collected = (
                ('test', Y_preds, Y_true),
                ('lockbox', lockbox_Y_preds, lockbox_Y_true),
            )
            for split, preds, labels in collected:
                values[split]['preds'].append(preds)
                values[split]['labels'].append(labels)

    # Turn the per-subject lists into arrays before writing them out.
    for method, values in predictions.items():
        for split in ('test', 'lockbox'):
            for field in ('preds', 'labels'):
                values[split][field] = np.array(values[split][field])

    dict2hdf5('predictions/predictions_flipout.h5', predictions)