# RSNA MICCAI Brain Tumor Radiogenomic Classification

[<img src="https://storage.googleapis.com/kaggle-competitions/kaggle/29653/logos/header.png?t=2021-07-07-17-26-56">](http://google.com.au/)

In this notebook I will try to classify the images using different EfficientNet models. To deal with the 3D data I will try several methods:
* Using the maximum over slices as the last layer
* Adding an attention layer
* Concatenating the neural network with a Gaussian process

## Importing necessary libraries

In [None]:
# Pin TensorFlow 2.6 (GPU build) and install the GP libraries; output is silenced.
!pip install tensorflow-gpu==2.6 &> /dev/null
!pip install gpflow &> /dev/null
# Little hack so that it works on GPU: overwrite the version string gpflow reports.
# NOTE(review): presumably this sidesteps a version check against the pinned TF build; confirm.
!echo "__version__ = '2.3.0'" > /opt/conda/lib/python3.7/site-packages/gpflow/versions.py
# Same but for CPU
# !cp -r /opt/conda/lib/python3.7/site-packages/tensorboard-2.7.0.dist-info/ /opt/conda/lib/python3.7/site-packages/tensorboard-2.4.1.dist-info/
!pip install gpflux &> /dev/null
# Same trick for gpflux: overwrite the version string it reports.
!echo "__version__ = '0.2.4'" > /opt/conda/lib/python3.7/site-packages/gpflux/version.py

In [None]:
# Basic libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Neural network libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping

# Reading images and creating video libraries
import cv2
from IPython.display import HTML
from base64 import b64encode
import matplotlib.animation as animation
import os

import SimpleITK as sitk

## Utility functions to visualize the images

Each patient's folder is displayed as a looping video of its images.

In [None]:
def play(filename):
    """Return an IPython ``HTML`` object that plays an MP4 file inline.

    The video bytes are embedded as a base64 data URI, so the notebook does not
    need to serve the file separately.

    Args:
        filename: path of the MP4 file to embed.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(filename, 'rb') as fh:
        video = fh.read()
    src = 'data:video/mp4;base64,' + b64encode(video).decode()
    html = '<video width=500 controls autoplay loop><source src="%s" type="video/mp4"></video>' % src
    return HTML(html)

def create_video(imgs, output='/kaggle/working/predicted.mp4', duration=30, subplot=True,
                frame_delay=200, view_names=None):
    """Render image stacks to an MP4 animation.

    Args:
        imgs: with ``subplot=False``, a stack of 2D frames indexed along axis 0;
            with ``subplot=True``, a dict mapping each of four series names to
            such a stack.
        output: path of the video file written by ``ArtistAnimation.save``.
        duration: number of frames to render; shorter stacks are looped (``k % len``).
        subplot: draw four series on a 2x2 grid instead of a single image.
        frame_delay: delay between frames, passed as ``interval`` (milliseconds).
        view_names: the four series names used as dict keys and panel titles in
            subplot mode; defaults to the module-level ``views`` list.
    """
    frames = []
    if not subplot:
        fig, ax = plt.subplots(figsize=(15, 10))
        n_frames = imgs.shape[0]
        for k in range(duration):
            frames.append([ax.imshow(imgs[k % n_frames], animated=True)])
    else:
        if view_names is None:
            view_names = views
        # Only one figure is created per call, so nothing is leaked.
        fig, ax = plt.subplots(2, 2, figsize=(10, 10))
        panels = [(ax[i, j], view_names[2 * i + j]) for i in range(2) for j in range(2)]
        # Titles are static, so set them once rather than once per frame.
        for axis, name in panels:
            axis.set_title(name)
        for k in range(duration):
            frame = []
            for axis, name in panels:
                stack = imgs[name]
                frame.append(axis.imshow(stack[k % stack.shape[0]], animated=True))
            frames.append(frame)
    # Close the figure so the notebook does not also show a static copy. The
    # animation is still saved from the figure object afterwards.
    plt.close(fig)

    ani = animation.ArtistAnimation(fig, frames, interval=frame_delay, blit=True, repeat_delay=1000)

    ani.save(output)

## Labels

In [None]:
# Competition files: training labels and the submission template (test IDs).
_competition_dir = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/'
target = pd.read_csv(_competition_dir + 'train_labels.csv')
preds = pd.read_csv(_competition_dir + 'sample_submission.csv')

## Read images utility function

In [None]:
# specify your image path
views = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
def load_imgs(idx, ignore_zeros=True, train=True):
    imgs = {}
    for view in views:
        save_ds = []
        if train:
            dir_path = os.walk(os.path.join(
            '../input/rsna-miccai-png/train/', idx, view
        ))
        else:
            dir_path = os.walk(os.path.join(
            '../input/rsna-miccai-png/test/', idx, view
        ))
        for path, subdirs, files in dir_path:
            for name in files:
                image_path = os.path.join(path, name) 
                ds = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
                save_ds.append(np.array(ds))
        if len(save_ds) == 0:
            save_ds = np.zeros((1,256,256))
        imgs[view] = np.array(save_ds)
    return imgs

Here we time loading the images of 32 patients. This will be the basis for choosing the batch size later on, so that each iteration does not take too long.

In [None]:
# %%time
# for i in range(32):
#     idx = str(target.BraTS21ID[i]).zfill(5)
#     imgs = load_imgs(idx)

Also, there are some folders without images. For those we simply define a zero-valued image so that the models work fine.

In [None]:
# Pathological one: patient 109 has series folders without images.
idx = f"{109:05d}"
imgs = load_imgs(idx)

## Example of Image Visualization

In [None]:
# Show the mean over slices of each of the four series on a 2x2 grid.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for i, j in [(r, c) for r in range(2) for c in range(2)]:
    view = views[2 * i + j]
    m = ax[i, j].imshow(imgs[view].mean(axis=0))
    ax[i, j].set_title(view)
plt.show()

## Example of Video Visualization

In [None]:
# Render the four series side by side, then embed the result.
# NOTE(review): create_video writes to /kaggle/working/predicted.mp4; the relative
# path below assumes the working directory is /kaggle/working.
create_video(imgs, subplot=True, duration=60, frame_delay=300)
play(filename='predicted.mp4')

### EfficientNet

In [None]:
# Third-party EfficientNet implementation, imported below as ``efficientnet.tfkeras``.
!pip install efficientnet &> /dev/null

## DataGenerator3D

Since the data is massive we need to use a data generator. In the preprocessing we resize all the images to the same dimensions and reduce the intensity bias using N4 bias field correction. Apart from that, the images are normalized to zero mean and unit variance.

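Note: although the EfficientNet models below expect 3 channels, the generator currently keeps only the T2w series (the last of the four views), so each sample has a single channel.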

In [None]:
from sklearn.preprocessing import StandardScaler

class DataGenerator3D(keras.utils.Sequence):
    """Keras ``Sequence`` that loads and preprocesses one MRI volume per patient.

    Each sample is built from the patient's T2w series only:
      * slices are resized to ``dim``;
      * N4 bias-field correction is applied;
      * values are standardized.
    The result has shape ``(*dim, 1)`` and is broadcast into the ``n_channels``
    axis of the batch.

    Args:
        list_IDs: patient IDs (e.g. a pandas Series of ``BraTS21ID`` values).
        labels: targets aligned positionally with ``list_IDs``. ``None`` means
            unlabelled (test) data, which also makes images load from the test folder.
        batch_size: number of patients per batch.
        dim: target volume size ``(rows, cols, slices)``.
        n_channels: size of the trailing channel axis of each batch.
        n_classes: unused; kept for interface compatibility.
        shuffle: reshuffle sample order at the end of each epoch (labelled data only).
        is_train: unused; the train/test mode is inferred from ``labels``.
    """
    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=1,
                 n_classes=2, shuffle=True, is_train=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.is_train = (labels is not None)
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch.

        Labelled data drops the last incomplete batch. Unlabelled (test) data keeps
        it, so every ID gets a prediction.
        """
        n = len(self.list_IDs)
        if self.is_train:
            return n // self.batch_size
        return int(np.ceil(n / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, y)`` for labelled data, or ``X`` otherwise."""
        # Select through the (possibly shuffled) permutation so shuffling takes effect.
        batch = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        ids = np.asarray(self.list_IDs)[batch]
        X = self.__data_generation(ids)
        if self.is_train:
            y = np.asarray(self.labels)[batch]
            return np.array(X), np.array(y, dtype='float64')
        return np.array(X)

    def on_epoch_end(self):
        """Rebuild the sample permutation, shuffling it for labelled data if requested."""
        self.indexes = np.arange(len(self.list_IDs))
        # Unlabelled batches must keep their input order so predictions line up
        # with the IDs they were generated from.
        if self.shuffle and self.is_train:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Build an array of shape ``(len(list_IDs_temp), *dim, n_channels)``."""
        # Sized to the actual batch, so a final partial batch holds no uninitialized rows.
        X = np.empty((len(list_IDs_temp), *self.dim, self.n_channels))
        for i, ID in enumerate(list_IDs_temp):
            X[i,] = self._load_volume(ID)
        return X

    def _load_volume(self, ID):
        """Load patient ``ID``'s T2w series and return it preprocessed, shape ``(*dim, 1)``."""
        imgs = load_imgs(str(ID).zfill(5), ignore_zeros=False, train=self.is_train)
        rows, cols, depth = self.dim
        # Work in float32: cv2 returns uint8 images, and writing the (float) N4
        # output back into a uint8 array would truncate or wrap the values.
        vol = imgs[views[3]].astype(np.float32)
        # Resize every slice in-plane to (rows, cols) -> (n_slices, rows, cols) ...
        vol = np.array([cv2.resize(s, dsize=(cols, rows), interpolation=cv2.INTER_LINEAR) for s in vol])
        # ... then resize each (cols, n_slices) plane to (cols, depth) -> (rows, cols, depth).
        planes = vol.transpose(1, 2, 0)
        vol = np.array([cv2.resize(planes[k], dsize=(depth, cols), interpolation=cv2.INTER_LINEAR)
                        for k in range(rows)])

        vol = self._n4_correct(vol)

        # Normalization. NOTE(review): StandardScaler standardizes each column of
        # every (cols, depth) plane, i.e. per (row, slice) over cols, rather than
        # per axial slice; confirm this is the intended normalization axis.
        scaler = StandardScaler()
        vol = np.array([scaler.fit_transform(plane) for plane in vol])

        # Trailing channel axis. The volume is already (rows, cols, depth), so no
        # transpose is needed; transposing and then reshaping would scramble voxels.
        return vol[..., np.newaxis]

    @staticmethod
    def _n4_correct(vol, n_levels=4, max_iter=100):
        """Apply N4 bias-field correction independently to each 2D plane ``vol[p]``, in place."""
        corrector = sitk.N4BiasFieldCorrectionImageFilter()
        corrector.SetMaximumNumberOfIterations([max_iter] * n_levels)
        for p in range(len(vol)):
            mask = vol[p] > 0.1
            if not mask.any():
                # Nothing to correct (e.g. the all-zero placeholder for a missing series).
                continue
            image = sitk.Cast(sitk.GetImageFromArray(vol[p]), sitk.sitkFloat32)
            mask_image = sitk.Cast(sitk.GetImageFromArray(mask.astype(np.uint8)), sitk.sitkUInt8)
            vol[p] = sitk.GetArrayFromImage(corrector.Execute(image, mask_image))
        return vol

### Usual train-validation split

In [None]:
# Work on a 288-patient slice of the labels (rows 288..575) to keep runs short.
target_red = target.iloc[288:576]

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 70/30 split of patient IDs and their MGMT labels.
X_train, X_val, y_train, y_val = train_test_split(
    target_red.BraTS21ID,
    target_red.MGMT_value,
    stratify=target_red.MGMT_value,
    test_size=0.3,
    random_state=0,
)

In [None]:
# Volume size (rows, cols, slices) and patients per batch shared by all generators.
dim = (32, 32, 34)
batch_size = 24
train_dataset, val_dataset = (
    DataGenerator3D(ids, lbls, batch_size=batch_size, dim=dim)
    for ids, lbls in ((X_train, y_train), (X_val, y_val))
)
# No labels: the generator reads the test folder and yields inputs only.
test_dataset = DataGenerator3D(preds.BraTS21ID, batch_size=batch_size, dim=dim)

### Model #1

Apply EfficientNet to each slice of the volume, then take the maximum over slices.

In [None]:
import efficientnet.tfkeras as efn

with tf.device('/gpu:0'):
    def effNet(inp):
        """Run a fresh EfficientNetB0 (no top, average pooling) on every slice.

        The input ``(batch, d0, d1, d2, c)`` is transposed so the slice axis follows
        the batch axis. Slices are then folded into the batch as
        ``(dim[0], dim[1], 3)`` images, giving one pooled feature vector per slice.
        """
        per_slice = tf.transpose(inp, perm=(0, 3, 1, 2, 4))
        per_slice = tf.reshape(per_slice, shape=(-1, dim[0], dim[1], 3))
        backbone = efn.EfficientNetB0(include_top=False, pooling='avg')
        return backbone(per_slice)

In [None]:
with tf.device('/gpu:0'):
    
    def build_model():
        """Per-slice EfficientNet score (sigmoid), then the maximum over slices."""
        inp = keras.Input(shape=(*dim,3))
        # One pooled feature vector per slice (slices folded into the batch axis).
        intermediate = effNet(inp)
        out = layers.Dense(1, activation='sigmoid')(intermediate)
        # Unfold back to (batch, slices, 1) and keep the highest slice score.
        out = tf.reshape(out, shape=(-1,dim[2],out.shape[1]))
        out = layers.maximum(tf.unstack(out, axis=1))
        return keras.Model(inputs=inp, outputs=out)
    
    model = build_model()
    
    earlyStopping = EarlyStopping(patience=2, min_delta=0.001, verbose=1)
    
    checkpoint_path = "training_0/cp.ckpt"
    # NOTE(review): checkpoint_dir is not used in this cell.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    # model.load_weights(checkpoint_path)

    # Create a callback that saves the model's weights
    cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_best_only=True,
                                                 save_weights_only=True,
                                                 verbose=1)
    
    model.compile(
        optimizer='adam', 
        loss='binary_crossentropy',
        metrics=[keras.metrics.AUC()]
        )
    
    #history = model.fit(train_dataset, validation_data=val_dataset,
    #                             epochs=10, callbacks=[earlyStopping, cp_callback])

### Model #2

The maximum is substituted by an attention layer.

In [None]:
with tf.device('/gpu:0'):
    # Hidden size of the attention scoring MLP.
    D = 128
    
    def attention(inp):
        """Score each slice embedding with a tanh MLP; softmax the scores over slices."""
        out = layers.Dense(D, activation='tanh')(inp)
        out = layers.Dense(1)(out)
        out = tf.reshape(out, shape=(-1,dim[2]))
        return tf.nn.softmax(out, axis=1)
    
    def build_model():
        """Attention-weighted average of per-slice EfficientNet features, then a sigmoid head."""
        inp = keras.Input(shape=(*dim,3))
        H = effNet(inp)
        A = attention(H)
        # (batch, slices, features) and (batch, 1, slices) -> weighted sum over slices.
        H = tf.reshape(H, shape=(-1,dim[2], H.shape[1]))
        A = tf.expand_dims(A, axis=1)
        intermediate = tf.linalg.matmul(A,H)
        intermediate = tf.squeeze(intermediate, axis=1)
        out = layers.Dense(1, activation='sigmoid')(intermediate)
        return keras.Model(inputs=inp, outputs=out)
    
    model = build_model()
    
    earlyStopping = EarlyStopping(patience=2, min_delta=0.001, verbose=1)
    
    checkpoint_path = "training_1/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Create a callback that saves the model's weights. Keep only the best epoch,
    # like the other models. Otherwise, with patience=2 the checkpoint would hold
    # weights from epochs after the best one.
    cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_best_only=True,
                                                 save_weights_only=True,
                                                 verbose=1)

    model.compile(
        optimizer='adam', 
        loss='binary_crossentropy',
        metrics=[keras.metrics.AUC()]
        )
    
    # history = model.fit(train_dataset, validation_data=val_dataset,
    #                              epochs=3, callbacks=[earlyStopping, cp_callback])

### Model #3

It is Model #1 plus a Gaussian process. I am using the GPflux library to do so.

Note: Install gpflow before loading any library, otherwise it won't work

#### Toy Example

In [None]:
import gpflow
import gpflux

from gpflow.config import default_float

# gpflow and Keras must agree on the float type, otherwise dtype errors appear from time to time.
for _set_float in (gpflow.config.set_default_float, tf.keras.backend.set_floatx):
    _set_float("float32")

In [None]:
# Toy regression data: a noisy sine wave on [-10, 10] (noise std 0.5).
X = np.linspace(start=-10, stop=10, num=1000)
Y = np.sin(X) + 0.5 * np.random.randn(1000)

In [None]:
# Standardize the inputs to zero mean and unit variance.
X = (X - X.mean()) / X.std()
# Number of training points (an int count, as GPLayer's ``num_data`` expects).
num_data = len(X)
input_dim = 1

In [None]:
num_data = len(X)
num_inducing = 10
output_dim = 1

kernel = gpflow.kernels.SquaredExponential()
# Spread the initial inducing locations over the (standardized) input range.
# If they all start at the same point, Kuu is near-singular and the inducing
# points are effectively a single one. They remain trainable; this only seeds them.
inducing_variable = gpflow.inducing_variables.InducingPoints(
    np.linspace(X.min(), X.max(), num_inducing).reshape(num_inducing, 1)
)
gp_layer = gpflux.layers.GPLayer(
    kernel, inducing_variable, num_data=num_data, num_latent_gps=1
)

In [None]:
# Gaussian observation noise, constructed with 0.1 as its initial variance argument.
likelihood = gpflow.likelihoods.Gaussian(0.1)

# So that Keras can track the likelihood variance, we need to provide the likelihood as part of a "dummy" layer:
likelihood_container = gpflux.layers.TrackableLayer()
likelihood_container.likelihood = likelihood

def build_model():
    """Toy model: a small MLP whose scalar output is fed to the GP layer's predictive mean."""
    # ``shape`` must be a tuple: ``(input_dim)`` is just the int ``input_dim``.
    inp = keras.Input(shape=(input_dim,))
    x = layers.Dense(100, activation="relu")(inp)
    x = layers.Dense(100, activation="relu")(x)
    x = layers.Dense(1, activation="linear")(x)
    x = gp_layer.predict(x)
    # These two operations are to make the covariance matrix appear in the layers output later on
    x = tf.concat([x[0], x[1]], axis=0)
    x, y = tf.split(x, 2,axis=0)
    out = likelihood_container(x)
    return keras.Model(inputs=inp, outputs=out)

model = build_model()
loss = gpflux.losses.LikelihoodLoss(likelihood)
loss = gpflux.losses.LikelihoodLoss(likelihood)

In [None]:
# Train the toy model silently, then plot its training-loss curve.
model.compile(optimizer="adam", loss=loss)
hist = model.fit(X, Y, verbose=0, epochs=100)
loss_history = hist.history["loss"]
plt.plot(loss_history)
plt.show()

In [None]:
from tensorflow.keras import backend as K
def get_all_outputs(model, input_data, learning_phase=1):
    """Evaluate every layer after the Input layer on ``input_data``.

    ``learning_phase`` is accepted but not used.
    """
    fetch = K.function([model.input], [lyr.output for lyr in model.layers[1:]])
    return fetch([input_data])

def get_layer_outputs(model, layer_name, input_data, learning_phase=1):
    """Evaluate the layers whose name contains ``layer_name`` on ``input_data``.

    ``learning_phase`` is accepted but not used.
    """
    wanted = [lyr.output for lyr in model.layers if layer_name in lyr.name]
    fetch = K.function([model.input], wanted)
    return fetch([input_data])

In [None]:
def plot(model, X, Y, ax=None):
    """Plot the data, the model's predictive mean and a +/- 2 std band.

    The latent variance is read from the layer(s) whose name contains
    ``'tf.linalg.adjoint_'``. The last layer's ``likelihood.variance`` is added to it.
    A new figure is created when ``ax`` is None.
    """
    if ax is None:
        _, ax = plt.subplots()

    margin, n_grid = 0.2, 100
    grid = np.linspace(X.min() - margin, X.max() + margin, n_grid).reshape(-1, 1)
    mean = model(grid).numpy().squeeze()
    latent_var = np.max(get_layer_outputs(model, 'tf.linalg.adjoint_', grid)[0], axis=1).squeeze()
    var = latent_var + model.layers[-1].likelihood.variance.numpy()
    grid = grid.squeeze()
    band = 2 * np.sqrt(var)

    ax.set_ylim(Y.min() - 0.5, Y.max() + 0.5)
    ax.plot(X, Y, "kx", alpha=0.5)
    ax.plot(grid, mean, "C1")

    ax.fill_between(grid, mean - band, mean + band, color="C1", alpha=0.3)


# plot(model, X, Y)

#### Real data

In [None]:
with tf.device('/gpu:0'):
    # NOTE(review): GPLayer's num_data should be the total number of GP inputs
    # in the dataset; 4 * dim[2] looks like a placeholder, so confirm.
    num_data = 4
    num_inducing = 4
    output_dim = 1

    kernel = gpflow.kernels.SquaredExponential()
    # The GP input is a sigmoid output (Dense(1, 'sigmoid') in build_model), so
    # seed the inducing points evenly over [0, 1]. If they all start at the same
    # location, Kuu is degenerate.
    inducing_variable = gpflow.inducing_variables.InducingPoints(
        np.linspace(0.0, 1.0, num_inducing).reshape(num_inducing, 1)
    )
    gp_layer = gpflux.layers.GPLayer(
        kernel, inducing_variable, num_data=num_data * dim[2], num_latent_gps=output_dim
    )
    
    likelihood = gpflow.likelihoods.Bernoulli()

    # So that Keras can track the likelihood variance, we need to provide the likelihood as part of a "dummy" layer:
    likelihood_container = gpflux.layers.TrackableLayer()
    likelihood_container.likelihood = likelihood

    def build_model():
        """Per-slice EfficientNet score -> GP predictive mean -> maximum over slices."""
        inp = keras.Input(shape=(*dim,3))
        latent = effNet(inp)
        latent = layers.Dense(1, activation='sigmoid')(latent)
        out = gp_layer.predict(latent)
        # Keep only the predictive mean.
        out = out[0]
        out = tf.reshape(out, shape=(-1,dim[2],out.shape[1]))
        out = layers.maximum(tf.unstack(out, axis=1))
        out = likelihood_container(out)
        return keras.Model(inputs=inp, outputs=out)
    
    model = build_model()
    
    earlyStopping = EarlyStopping(patience=2, min_delta=0.001, verbose=1)
    
    checkpoint_path = "training_2/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    # model.load_weights(checkpoint_path)

    # Create a callback that saves the model's weights
    cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_best_only=True,
                                                 save_weights_only=True,
                                                 verbose=1)
    
    loss = gpflux.losses.LikelihoodLoss(likelihood)
    # NOTE(review): the model outputs the unbounded GP mean, whereas keras AUC
    # assumes predictions in [0, 1].
    model.compile(
        optimizer='adam', 
        loss=loss,
        metrics=[keras.metrics.AUC()]
        )
    
    #history = model.fit(train_dataset, validation_data=val_dataset,
    #                             epochs=10, callbacks=[earlyStopping, cp_callback])

### Model #4

Gaussian process but using clustering to select the inducing points.

In [None]:
class ClusteringLayer(layers.Layer):
    """
    Clustering layer converts input sample (feature) to soft label.

    # Example
    ```
        model.add(ClusteringLayer(n_clusters=10))
    ```
    # Arguments
        n_clusters: number of clusters.
        weights: list of Numpy array with shape `(n_clusters, n_features)` which represents the initial cluster centers.
        alpha: degrees of freedom parameter in Student's t-distribution. Default to 1.0.
    # Input shape
        2D tensor with shape: `(n_samples, n_features)`.
    # Output shape
        2D tensor with shape: `(n_samples, n_clusters)`.
    """

    def __init__(self, n_clusters, weights=None, alpha=1.0, **kwargs):
        # Accept the legacy ``input_dim`` keyword by translating it to ``input_shape``.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(ClusteringLayer, self).__init__(**kwargs)
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.initial_weights = weights
        self.input_spec = layers.InputSpec(ndim=2)

    def build(self, input_shape):
        """Create the ``(n_clusters, n_features)`` cluster-centre weight."""
        assert len(input_shape) == 2
        input_dim = input_shape[1]
        self.input_spec = layers.InputSpec(dtype=K.floatx(), shape=(None, input_dim))
        self.clusters = self.add_weight(shape=(self.n_clusters, input_dim), initializer='glorot_uniform', name='clusters')
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def call(self, inputs, **kwargs):
        """ student t-distribution, as same as used in t-SNE algorithm.
                 q_ij = 1/(1+dist(x_i, µ_j)^2), then normalize it.
                 q_ij can be interpreted as the probability of assigning sample i to cluster j.
                 (i.e., a soft assignment)
        Arguments:
            inputs: the variable containing data, shape=(n_samples, n_features)
        Return:
            q: student's t-distribution, or soft labels for each sample. shape=(n_samples, n_clusters)
        """
        q = 1.0 / (1.0 + (K.sum(K.square(K.expand_dims(inputs, axis=1) - self.clusters), axis=2) / self.alpha))
        q **= (self.alpha + 1.0) / 2.0
        # Normalize so each sample's n_clusters values sum to 1.
        q = q / K.sum(q, axis=1, keepdims=True)
        return q

    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) == 2
        return input_shape[0], self.n_clusters

    def get_config(self):
        """Serializable config. ``alpha`` is included so ``from_config`` restores it."""
        config = {'n_clusters': self.n_clusters, 'alpha': self.alpha}
        base_config = super(ClusteringLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [None]:
with tf.device('/gpu:0'):
    # NOTE(review): GPLayer's num_data should be the total number of GP inputs
    # in the dataset; 4 * dim[2] looks like a placeholder, so confirm.
    num_data = 4
    num_inducing = num_data
    output_dim = 1

    kernel = gpflow.kernels.SquaredExponential()
    # The GP input is a sigmoid output, so seed the inducing points evenly over
    # [0, 1]. If they all start at the same location, Kuu is degenerate.
    inducing_variable = gpflow.inducing_variables.InducingPoints(
        np.linspace(0.0, 1.0, num_inducing).reshape(num_inducing, 1)
    )
    gp_layer = gpflux.layers.GPLayer(
        kernel, inducing_variable, num_data=num_data * dim[2], num_latent_gps=output_dim
    )
    
    likelihood = gpflow.likelihoods.Bernoulli()

    # So that Keras can track the likelihood variance, we need to provide the likelihood as part of a "dummy" layer:
    likelihood_container = gpflux.layers.TrackableLayer()
    likelihood_container.likelihood = likelihood

    def build_model():
        """EfficientNet slice features -> 2-cluster soft labels -> GP mean -> maximum over slices."""
        inp = keras.Input(shape=(*dim,3))
        latent = effNet(inp)
        latent = ClusteringLayer(2, name='clustering')(latent)
        latent = layers.Dense(1, activation='sigmoid')(latent)
        out = gp_layer.predict(latent)
        out = out[0]
        out = tf.reshape(out, shape=(-1,dim[2],out.shape[1]))
        out = layers.maximum(tf.unstack(out, axis=1))
        out = likelihood_container(out)
        return keras.Model(inputs=inp, outputs=out)
    
    model = build_model()
    
    earlyStopping = EarlyStopping(patience=2, min_delta=0.001, verbose=1)
    
    # Own directory: "training_2" is Model #3's checkpoint and would be overwritten.
    checkpoint_path = "training_3/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    # model.load_weights(checkpoint_path)

    # Create a callback that saves the model's weights
    cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_best_only=True,
                                                 save_weights_only=True,
                                                 verbose=1)
    
    loss = gpflux.losses.LikelihoodLoss(likelihood)
    model.compile(
        optimizer='adam', 
        loss=loss,
        metrics=[keras.metrics.AUC()]
        )
    
    #history = model.fit(train_dataset, validation_data=val_dataset,
    #                             epochs=10, callbacks=[earlyStopping, cp_callback])

## Simpler model

Let's try with a simplified version of VGG for the CNN.

In [None]:
def conv(inp, n, n_out):
    """Stack ``n`` BatchNorm -> Conv2D(5x5, 'same') -> Dropout(0.2) -> ReLU units with ``n_out`` filters."""
    out = inp
    for _ in range(n):
        out = layers.BatchNormalization()(out)
        out = layers.Conv2D(n_out, (5, 5), padding='same')(out)
        out = layers.Dropout(0.2)(out)
        out = layers.Activation('relu')(out)
    return out

def pool(inp):
    """2x2 max-pooling (stride defaults to the pool size, no padding)."""
    max_pool = layers.MaxPool2D(pool_size=(2, 2), strides=None, padding='valid')
    return max_pool(inp)

def dense(inp, n_neur, drop=0.2, act='relu'):
    """BatchNorm -> Dense(n_neur) -> Dropout(drop) -> activation ``act``."""
    hidden = layers.BatchNormalization()(inp)
    hidden = layers.Dense(n_neur)(hidden)
    hidden = layers.Dropout(drop)(hidden)
    return layers.Activation(act)(hidden)

def VGG(inp):
    """Small VGG-style encoder applied slice by slice; returns one linear score per slice.

    The slice axis (axis 3) is moved next to the batch axis. Slices are then
    folded into the batch as single-channel ``(dim[0], dim[1], 1)`` images.
    """
    per_slice = tf.transpose(inp, perm=(0, 3, 1, 2, 4))
    per_slice = tf.reshape(per_slice, shape=(-1, dim[0], dim[1], 1))
    feats = pool(conv(per_slice, 2, 16))
    feats = pool(conv(feats, 3, 32))
    flat = layers.Flatten()(feats)
    return dense(flat, 1, 0.4, act='linear')

In [None]:
with tf.device('/gpu:0'):
    # NOTE(review): GPLayer's num_data should be the total number of GP inputs
    # in the dataset; batch_size * dim[2] counts only one batch, so confirm.
    num_data = batch_size
    num_inducing = 10
    output_dim = 1

    kernel = gpflow.kernels.SquaredExponential()
    # Seed the inducing points at distinct locations; identical starting points
    # make Kuu degenerate. The GP input is a linear score from VGG (after
    # BatchNorm), so [-1, 1] is only a starting guess; the points are trainable.
    inducing_variable = gpflow.inducing_variables.InducingPoints(
        np.linspace(-1.0, 1.0, num_inducing).reshape(num_inducing, 1)
    )
    gp_layer = gpflux.layers.GPLayer(
        kernel, inducing_variable, num_data=num_data * dim[2], num_latent_gps=output_dim
    )
    
    likelihood = gpflow.likelihoods.Bernoulli()

    # So that Keras can track the likelihood variance, we need to provide the likelihood as part of a "dummy" layer:
    likelihood_container = gpflux.layers.TrackableLayer()
    likelihood_container.likelihood = likelihood

In [None]:
with tf.device('/gpu:0'):

    def build_model():
        """VGG slice scores -> GP predictive mean -> maximum over slices -> sigmoid probability."""
        inp = keras.Input(shape=(*dim,3))
        latent = VGG(inp)
        out = gp_layer.predict(latent)
        # Keep only the predictive mean.
        out = out[0]
        out = tf.reshape(out, shape=(-1,dim[2],out.shape[1]))
        out = layers.maximum(tf.unstack(out, axis=1))
        out = layers.Activation('sigmoid')(out)
        out = likelihood_container(out)
        return keras.Model(inputs=inp, outputs=out)
    
    model = build_model()
    
    earlyStopping = EarlyStopping(patience=2, min_delta=0.001, verbose=1)
    
    checkpoint_path = "training_4/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    # model.load_weights(checkpoint_path)

    # Create a callback that saves the model's weights
    cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_best_only=True,
                                                 save_weights_only=True,
                                                 verbose=1)
    
    # The model output is already a probability (sigmoid above). gpflow's
    # Bernoulli likelihood applies its own inverse link (probit by default) on
    # top of it, which would confine the probability used by the loss to
    # (Phi(0), Phi(1)) ~ (0.5, 0.84). Binary cross-entropy on the sigmoid output
    # is the matching Bernoulli log-loss.
    # NOTE(review): gp_layer.predict is used instead of calling the layer, so
    # confirm whether the GP's KL term / variables take part in training at all.
    loss = keras.losses.BinaryCrossentropy()
    model.compile(
        optimizer='adam', 
        loss=loss,
        metrics=[keras.metrics.AUC(), 'accuracy']
        )
    
    history = model.fit(train_dataset, validation_data=val_dataset,
                                 epochs=30, callbacks=[earlyStopping, cp_callback])