In this notebook we test mixup on the MNIST dataset with a CNN.
Because we had problems using Horovod and the GPU, the experiments are
run with smaller networks on the CPU.

Import the necessary packages.

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense,Conv2D,Flatten,BatchNormalization,Dropout
from ray import tune
from ray.tune import CLIReporter
from sklearn.metrics import confusion_matrix
#from sparkdl import HorovodRunner
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import shutil
import os


# Fixes the issue "AttributeError: 'ConsoleBuffer has no attribute 'fileno'"
# The notebook's stdout object has no fileno() method, so install a stub that
# returns False (which compares equal to 0).
# NOTE(review): presumably some library (Ray / TF logging) calls
# sys.stdout.fileno() and only needs the call to succeed -- confirm. Note that
# descriptor 0 is stdin, so anything that actually writes to the returned
# descriptor would target the wrong stream.
import sys
sys.stdout.fileno = lambda: False

  

  

A data generator class that performs mixup on the loaded data.

In [None]:
class MixupImageDataGenerator_from_tensor(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields mixup batches from in-memory arrays.

    Each epoch uses two independent random permutations of the samples.
    Batch ``i`` of the first permutation is blended with batch ``i`` of the
    second, using one weight per sample drawn from ``Beta(alpha, alpha)``::

        x = lam * x1 + (1 - lam) * x2
        y = lam * y1 + (1 - lam) * y2

    The permutations are redrawn by ``on_epoch_end`` (and ``reset_index``).
    ``__getitem__`` therefore depends only on ``index``, as the ``Sequence``
    contract requires, and every sample appears once per epoch as a first
    partner. This includes the final, possibly smaller, batch.

    Args:
        X: Inputs; the first axis indexes samples. Converted with
            ``np.asarray``.
        Y: Labels aligned with ``X`` (e.g. one-hot vectors). Converted with
            ``np.asarray``.
        batch_size: Samples per batch; the last batch may be smaller.
        alpha: Beta-distribution parameter controlling mixup strength.
        subset: Unused; accepted for backward compatibility.

    Raises:
        ValueError: If ``X`` and ``Y`` contain different numbers of samples.
    """

    def __init__(self, X, Y, batch_size, alpha=0.2, subset=None):
        super().__init__()
        self.X = np.asarray(X)
        self.Y = np.asarray(Y)
        if len(self.X) != len(self.Y):
            raise ValueError("X and Y must contain the same number of samples")
        self.batch_size = batch_size
        self.alpha = alpha
        # Retained for backward compatibility; batches are addressed by index.
        self.batch_index = 0
        # Number of samples.
        self.n = len(self.X)
        self.reset_index()

    def __len__(self):
        """Return the number of batches per epoch, counting a partial last batch."""
        return (self.n + self.batch_size - 1) // self.batch_size

    @staticmethod
    def _mix_weights(lam, like):
        """Reshape per-sample weights ``lam`` to broadcast against ``like``.

        Floating-point data keeps its dtype (e.g. float32 stays float32).
        """
        weights = lam.reshape((-1,) + (1,) * (like.ndim - 1))
        if np.issubdtype(like.dtype, np.floating):
            weights = weights.astype(like.dtype)
        return weights

    def __getitem__(self, index):
        """Return the mixup batch ``(X, y)`` at position ``index`` of this epoch."""
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches")
        start = index * self.batch_size
        stop = min(start + self.batch_size, self.n)
        idx1 = self._perm1[start:stop]
        idx2 = self._perm2[start:stop]
        X1, y1 = self.X[idx1], self.Y[idx1]
        X2, y2 = self.X[idx2], self.Y[idx2]

        # One mixing coefficient per sample, shared by its input and label.
        lam = np.random.beta(self.alpha, self.alpha, len(idx1))
        X_l = self._mix_weights(lam, X1)
        y_l = self._mix_weights(lam, y1)

        X = X1 * X_l + X2 * (1 - X_l)
        y = y1 * y_l + y2 * (1 - y_l)
        return X, y

    def reset_index(self):
        """Draw fresh, independent sample orders for both mixing partners."""
        self.batch_index = 0
        self._perm1 = np.random.permutation(self.n)
        self._perm2 = np.random.permutation(self.n)

    def on_epoch_end(self):
        """Reshuffle after each epoch (called by Keras) so pairs change."""
        self.reset_index()

  

  

Two helper functions: one creates the model from the hyperparameters
`number_conv` and `number_dense`, and the other creates the data loaders
needed for training and validation.

In [None]:
"""
creates the CNN with number_conv convolutional layers followed by number_dense dense layers. THe model is compiled with a SGD optimizer and a categorical crossentropy loss.
"""
def create_model(number_conv,number_dense):
    model = Sequential()
    model.add(Conv2D(24,kernel_size = 3, activation='relu',padding="same", input_shape=(img_height, img_width,channels)))
    model.add(BatchNormalization())
    for s in range(1,number_conv):
        model.add(Conv2D(24+12*s,kernel_size = 3,padding="same", activation = 'relu'))
        model.add(BatchNormalization())
    model.add(Flatten())
    model.add(Dropout(0.4))
    for s in range(number_dense):
        model.add(Dense(units=num_classes, activation='relu'))
        model.add(Dropout(0.4))
    model.add(BatchNormalization())
    model.add(Dense(num_classes,activation= "softmax"))
    model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
    return model


"""
A method that gives us the different dataloaders that we need for training and validation.

train_mix_loader: A data loader that will give us mixes data for training
train_loader: A data loader that gives us the unmixed training data
val_mixed_loader: A data loader that gives us the mixed validation data
val_loader: A data loader with the unmixed validation data

"""
        
def get_mnist_dataloaders():
  (trainX,trainY),(testX,testY) = tf.keras.datasets.mnist.load_data()
  trainX,testX = tf.cast(trainX,tf.float32),tf.cast(testX,tf.float32)
  trainX,testX = tf.expand_dims(trainX, 3),tf.expand_dims(testX, 3)
  trainY_oh,testY_oh = tf.one_hot(trainY,10),tf.one_hot(testY,10)
  trainY_oh,testY_oh = tf.cast(trainY_oh,tf.float32).numpy(),tf.cast(testY_oh,tf.float32).numpy()
  trainX,testX = trainX.numpy()/255 * 2 - 2,testX.numpy()/255 * 2 - 2

  
  train_loader_mix = MixupImageDataGenerator_from_tensor(trainX,trainY_oh,batch_size)
  train_loader = tf.data.Dataset.from_tensor_slices((trainX,trainY_oh)).batch(batch_size)
  test_loader_mix = MixupImageDataGenerator_from_tensor(testX,testY_oh,batch_size)
  test_loader = tf.data.Dataset.from_tensor_slices((trainX,trainY_oh)).batch(batch_size)
  
  return train_loader_mix,train_loader,test_loader_mix,test_loader
  

  

  

The function that builds and trains the model for a single Ray Tune trial and reports its metrics.

In [None]:
def training_function(config, checkpoint_dir=None):
    """Train one model for a single Ray Tune trial and report its metrics.

    Args:
        config: Trial hyperparameters with keys ``"number_conv"``,
            ``"number_dense"`` and ``"train_with_mixed_data"``.
        checkpoint_dir: Part of Ray Tune's function-trainable signature;
            not used.

    Reports to Ray Tune the metrics from the epoch with the highest
    validation accuracy:

    * ``mean_loss``: training accuracy. The name is misleading, but the key
      is kept for backward compatibility with existing result consumers.
    * ``train_accuracy``: the same training accuracy under an accurate name.
    * ``val_mix_accuracy``: accuracy on the mixup validation loader.
    """
    number_conv = config["number_conv"]
    number_dense = config["number_dense"]
    train_with_mixed_data = config["train_with_mixed_data"]

    # Loaders: mixed train, unmixed train, mixed validation, unmixed validation.
    # The unmixed validation loader is currently unused.
    train_mix_loader, train_loader, val_mix_loader, _ = get_mnist_dataloaders()

    model = create_model(number_conv, number_dense)

    # Stop when validation accuracy has not improved by at least 0.01 for
    # 10 epochs, then roll back to the best epoch's weights.
    # NOTE(review): consider monitoring val_loss instead of val_accuracy.
    callbacks = [tf.keras.callbacks.EarlyStopping(
        patience=10, monitor="val_accuracy", min_delta=0.01, restore_best_weights=True)]

    # Both variants are validated on mixup data, so val_mix_accuracy is
    # comparable across trials. Model.fit accepts both a Sequence and a
    # tf.data.Dataset; fit_generator is deprecated.
    train_data = train_mix_loader if train_with_mixed_data else train_loader
    history = model.fit(train_data, validation_data=val_mix_loader,
                        callbacks=callbacks, verbose=0, epochs=200)

    val_history = history.history["val_accuracy"]
    best_epoch = int(np.argmax(val_history))
    train_acc = history.history["accuracy"][best_epoch]
    val_acc = val_history[best_epoch]

    tune.report(mean_loss=train_acc, train_accuracy=train_acc, val_mix_accuracy=val_acc)


  

  

The global hyperparameters

In [None]:
"""
The global parameters for training.
"""

img_height,img_width,channels = 28,28,1
batch_size = 50
#train_data_dir,test_data_dir = "/dbfs/FileStore/tables/Group20/seg_train/seg_train","/dbfs/FileStore/tables/Group20/seg_test/seg_test"
alpha = 0.2
num_classes = 10
#train_with_mixed_data = True



In [None]:
# Show at most 10 trial rows in the progress table, and add the metric that
# training_function reports alongside the default columns.
reporter = CLIReporter(max_progress_rows=10)
reporter.add_metric_column("val_mix_accuracy")

# Full grid: 3 conv depths x 3 dense depths x {mixup, no mixup} = 18 trials.
analysis = tune.run(
    training_function,
    config={
        "number_conv": tune.grid_search(list(range(2, 5))),
        "number_dense": tune.grid_search(list(range(0, 3))),
        "train_with_mixed_data": tune.grid_search([True, False]),
    },
    local_dir='ray_results',
    progress_reporter=reporter,
)

# Select on validation accuracy. The "mean_loss" key holds *training*
# accuracy, so selecting on it would reward overfitting rather than
# generalization.
print("Best config: ", analysis.get_best_config(
    metric="val_mix_accuracy", mode="max"))

# Get a dataframe for analyzing trial results.
df = analysis.results_df


  

>     2021-01-11 19:47:45,175	INFO services.py:1173 -- View the Ray dashboard at http://127.0.0.1:8265
>     == Status ==
>     Memory usage on this node: 5.5/24.9 GiB
>     Using FIFO scheduling algorithm.
>     Resources requested: 1/4 CPUs, 0/0 GPUs, 0.0/11.87 GiB heap, 0.0/4.05 GiB objects
>     Result logdir: /databricks/driver/ray_results/training_function_2021-01-11_19-47-47
>     Number of trials: 1/18 (1 RUNNING)
>     +-------------------------------+----------+-------+---------------+----------------+-------------------------+
>     | Trial name                    | status   | loc   |   number_conv |   number_dense | train_with_mixed_data   |
>     |-------------------------------+----------+-------+---------------+----------------+-------------------------|
>     | training_function_e11a2_00000 | RUNNING  |       |             2 |              0 | True                    |
>     +-------------------------------+----------+-------+---------------+----------------+-------------------------+
>
>
>     (pid=1806) 2021-01-11 19:47:48.051367: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>     (pid=1806) 2021-01-11 19:47:48.051468: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>     (pid=1805) 2021-01-11 19:47:48.133442: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>     (pid=1805) 2021-01-11 19:47:48.133538: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>     (pid=1803) 2021-01-11 19:47:48.123712: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>     (pid=1803) 2021-01-11 19:47:48.123818: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>     (pid=1804) 2021-01-11 19:47:48.145212: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>     (pid=1804) 2021-01-11 19:47:48.145295: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>     (pid=1806) Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
>     (pid=1805) Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
>     (pid=1804) Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
>     (pid=1806)     8192/11490434 [..............................] - ETA: 0s
>     (pid=1803) Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
>     (pid=1806)   139264/11490434 [..............................] - ETA: 4s
>     (pid=1806)   303104/11490434 [..............................] - ETA: 3s
>     (pid=1805)     8192/11490434 [..............................] - ETA: 0s
>     (pid=1806)   499712/11490434 [>.............................] - ETA: 3s
>     (pid=1805)   147456/11490434 [..............................] - ETA: 3s
>     (pid=1806)   688128/11490434 [>.............................] - ETA: 3s
>     (pid=1805)   319488/11490434 [..............................] - ETA: 3s
>     (pid=1803)     8192/11490434 [..............................] - ETA: 0s
>     (pid=1806)   876544/11490434 [=>............................] - ETA: 3s
>     (pid=1805)   516096/11490434 [>.............................] - ETA: 3s
>     (pid=1804)     8192/11490434 [..............................] - ETA: 0s
>     (pid=1803)   163840/11490434 [..............................] - ETA: 3s
>     (pid=1806)  1064960/11490434 [=>............................] - ETA: 3s
>     (pid=1805)   712704/11490434 [>.............................] - ETA: 3s
>     (pid=1804)   139264/11490434 [..............................] - ETA: 4s
>     (pid=1803)   344064/11490434 [..............................] - ETA: 3s
>     (pid=1806)  1261568/11490434 [==>...........................] - ETA: 2s
>     (pid=1805)   901120/11490434 [=>............................] - ETA: 3s
>     (pid=1804)   303104/11490434 [..............................] - ETA: 3s
>     (pid=1803)   557056/11490434 [>.............................] - ETA: 2s
>     (pid=1806)  1458176/11490434 [==>...........................] - ETA: 2s
>     (pid=1805)  1081344/11490434 [=>............................] - ETA: 2s
>     (pid=1804)   491520/11490434 [>.............................] - ETA: 3s
>     (pid=1803)   770048/11490434 [=>............................] - ETA: 2s  983040/11490434 [=>............................] - ETA: 2s
>     (pid=1806)  1654784/11490434 [===>..........................] - ETA: 2s
>     (pid=1805)  1277952/11490434 [==>...........................] - ETA: 2s
>     (pid=1804)   663552/11490434 [>.............................] - ETA: 3s
>     (pid=1803)  1179648/11490434 [==>...........................] - ETA: 2s
>     (pid=1806)  1851392/11490434 [===>..........................] - ETA: 2s
>     (pid=1805)  1474560/11490434 [==>...........................] - ETA: 2s
>     (pid=1804)   860160/11490434 [=>............................] - ETA: 3s
>     (pid=1803)  1392640/11490434 [==>...........................] - ETA: 2s
>     (pid=1806)  2048000/11490434 [====>.........................] - ETA: 2s
>     (pid=1805)  1671168/11490434 [===>..........................] - ETA: 2s
>     (pid=1804)  1032192/11490434 [=>............................] - ETA: 3s
>     (pid=1803)  1605632/11490434 [===>..........................] - ETA: 2s
>     (pid=1806)  2244608/11490434 [====>.........................] - ETA: 2s
>     (pid=1805)  1867776/11490434 [===>..........................] - ETA: 2s
>     (pid=1804)  1212416/11490434 [==>...........................] - ETA: 3s
>     (pid=1803)  1818624/11490434 [===>..........................] - ETA: 2s
>     (pid=1806)  2441216/11490434 [=====>........................] - ETA: 2s
>     (pid=1805)  2064384/11490434 [====>.........................] - ETA: 2s
>     (pid=1804)  1392640/11490434 [==>...........................] - ETA: 2s
>     (pid=1803)  2031616/11490434 [====>.........................] - ETA: 2s
>     (pid=1806)  2637824/11490434 [=====>........................] - ETA: 2s
>     (pid=1805)  2260992/11490434 [====>.........................] - ETA: 2s
>     (pid=1804)  1589248/11490434 [===>..........................] - ETA: 2s
>     (pid=1803)  2244608/11490434 [====>.........................] - ETA: 2s
>     (pid=1805)  2457600/11490434 [=====>........................] - ETA: 2s
>     (pid=1806)  2834432/11490434 [======>.......................] - ETA: 2s
>     (pid=1805)  2654208/11490434 [=====>........................] - ETA: 2s
>     (pid=1804)  1769472/11490434 [===>..........................] - ETA: 2s
>     (pid=1803)  2457600/11490434 [=====>........................] - ETA: 2s
>     (pid=1806)  3031040/11490434 [======>.......................] - ETA: 2s
>     (pid=1805)  2850816/11490434 [======>.......................] - ETA: 2s
>     (pid=1804)  1966080/11490434 [====>.........................] - ETA: 2s
>     (pid=1803)  2670592/11490434 [=====>........................] - ETA: 2s
>     (pid=1806)  3227648/11490434 [=======>......................] - ETA: 2s
>     (pid=1805)  3047424/11490434 [======>.......................] - ETA: 2s
>     (pid=1804)  2146304/11490434 [====>.........................] - ETA: 2s
>     (pid=1803)  2883584/11490434 [======>.......................] - ETA: 2s
>     (pid=1806)  3407872/11490434 [=======>......................] - ETA: 2s
>     (pid=1805)  3260416/11490434 [=======>......................] - ETA: 2s
>     (pid=1804)  2342912/11490434 [=====>........................] - ETA: 2s
>     (pid=1803)  3096576/11490434 [=======>......................] - ETA: 2s
>     (pid=1806)  3604480/11490434 [========>.....................] - ETA: 2s
>     (pid=1805)  3457024/11490434 [========>.....................] - ETA: 2s
>     (pid=1804)  2539520/11490434 [=====>........................] - ETA: 2s
>     (pid=1803)  3309568/11490434 [=======>......................] - ETA: 2s
>     (pid=1806)  3801088/11490434 [========>.....................] - ETA: 2s
>     (pid=1805)  3653632/11490434 [========>.....................] - ETA: 2s
>     (pid=1804)  2719744/11490434 [======>.......................] - ETA: 2s
>     (pid=1803)  3538944/11490434 [========>.....................] - ETA: 1s
>     (pid=1806)  3997696/11490434 [=========>....................] - ETA: 2s
>     (pid=1805)  3850240/11490434 [=========>....................] - ETA: 2s
>     (pid=1804)  2916352/11490434 [======>.......................] - ETA: 2s
>     (pid=1803)  3751936/11490434 [========>.....................] - ETA: 1s
>     (pid=1806)  4194304/11490434 [=========>....................] - ETA: 1s
>     (pid=1805)  4046848/11490434 [=========>....................] - ETA: 1s
>     (pid=1804)  3096576/11490434 [=======>......................] - ETA: 2s 3293184/11490434 [=======>......................] - ETA: 2s
>     (pid=1803)  3964928/11490434 [=========>....................] - ETA: 1s
>     (pid=1806)  4390912/11490434 [==========>...................] - ETA: 1s
>     (pid=1805)  4243456/11490434 [==========>...................] - ETA: 1s
>     (pid=1804)  3473408/11490434 [========>.....................] - ETA: 2s
>     (pid=1803)  4177920/11490434 [=========>....................] - ETA: 1s
>     (pid=1806)  4587520/11490434 [==========>...................] - ETA: 1s
>     (pid=1805)  4440064/11490434 [==========>...................] - ETA: 1s
>     (pid=1804)  3670016/11490434 [========>.....................] - ETA: 2s
>     (pid=1803)  4390912/11490434 [==========>...................] - ETA: 1s
>     (pid=1806)  4784128/11490434 [===========>..................] - ETA: 1s
>     (pid=1805)  4653056/11490434 [===========>..................] - ETA: 1s
>     (pid=1804)  3850240/11490434 [=========>....................] - ETA: 2s
>     (pid=1803)  4620288/11490434 [===========>..................] - ETA: 1s
>     (pid=1806)  4980736/11490434 [============>.................] - ETA: 1s
>     (pid=1805)  4866048/11490434 [===========>..................] - ETA: 1s
>     (pid=1804)  4046848/11490434 [=========>....................] - ETA: 2s
>     (pid=1803)  4833280/11490434 [===========>..................] - ETA: 1s
>     (pid=1806)  5177344/11490434 [============>.................] - ETA: 1s
>     (pid=1805)  5062656/11490434 [============>.................] - ETA: 1s
>     (pid=1804)  4227072/11490434 [==========>...................] - ETA: 2s
>     (pid=1803)  5046272/11490434 [============>.................] - ETA: 1s
>     (pid=1806)  5373952/11490434 [=============>................] - ETA: 1s
>     (pid=1805)  5259264/11490434 [============>.................] - ETA: 1s
>     (pid=1804)  4423680/11490434 [==========>...................] - ETA: 1s
>     (pid=1803)  5259264/11490434 [============>.................] - ETA: 1s
>     (pid=1806)  5570560/11490434 [=============>................] - ETA: 1s
>     (pid=1805)  5455872/11490434 [=============>................] - ETA: 1s
>     (pid=1804)  4603904/11490434 [===========>..................] - ETA: 1s
>     (pid=1803)  5488640/11490434 [=============>................] - ETA: 1s
>     (pid=1806)  5767168/11490434 [==============>...............] - ETA: 1s
>     (pid=1805)  5652480/11490434 [=============>................] - ETA: 1s
>     (pid=1804)  4800512/11490434 [===========>..................] - ETA: 1s
>     (pid=1803)  5701632/11490434 [=============>................] - ETA: 1s
>     (pid=1806)  5963776/11490434 [==============>...............] - ETA: 1s
>     (pid=1805)  5865472/11490434 [==============>...............] - ETA: 1s
>     (pid=1804)  4997120/11490434 [============>.................] - ETA: 1s
>     (pid=1803)  5914624/11490434 [==============>...............] - ETA: 1s
>     (pid=1806)  6160384/11490434 [===============>..............] - ETA: 1s
>     (pid=1805)  6078464/11490434 [==============>...............] - ETA: 1s
>     (pid=1804)  5193728/11490434 [============>.................] - ETA: 1s
>     (pid=1803)  6127616/11490434 [==============>...............] - ETA: 1s
>     (pid=1806)  6356992/11490434 [===============>..............] - ETA: 1s
>     (pid=1804)  5390336/11490434 [=============>................] - ETA: 1s
>     (pid=1803)  6340608/11490434 [===============>..............] - ETA: 1s
>     (pid=1806)  6553600/11490434 [================>.............] - ETA: 1s
>     (pid=1805)  6291456/11490434 [===============>..............] - ETA: 1s
>     (pid=1804)  5586944/11490434 [=============>................] - ETA: 1s
>     (pid=1803)  6553600/11490434 [================>.............] - ETA: 1s
>     (pid=1806)  6750208/11490434 [================>.............] - ETA: 1s
>     (pid=1805)  6504448/11490434 [===============>..............] - ETA: 1s
>     (pid=1804)  5767168/11490434 [==============>...............] - ETA: 1s
>     (pid=1803)  6766592/11490434 [================>.............] - ETA: 1s
>     (pid=1806)  6946816/11490434 [=================>............] - ETA: 1s
>     (pid=1805)  6701056/11490434 [================>.............] - ETA: 1s
>     (pid=1804)  5963776/11490434 [==============>...............] - ETA: 1s
>     (pid=1803)  6979584/11490434 [=================>............] - ETA: 1s
>     (pid=1806)  7143424/11490434 [=================>............] - ETA: 1s
>     (pid=1805)  6897664/11490434 [=================>............] - ETA: 1s 7094272/11490434 [=================>............] - ETA: 1s
>     (pid=1804)  6144000/11490434 [===============>..............] - ETA: 1s
>     (pid=1803)  7192576/11490434 [=================>............] - ETA: 1s
>     (pid=1806)  7340032/11490434 [==================>...........] - ETA: 1s
>     (pid=1805)  7290880/11490434 [==================>...........] - ETA: 1s
>     (pid=1804)  6340608/11490434 [===============>..............] - ETA: 1s
>     (pid=1803)  7405568/11490434 [==================>...........] - ETA: 0s
>     (pid=1806)  7536640/11490434 [==================>...........] - ETA: 1s
>     (pid=1805)  7487488/11490434 [==================>...........] - ETA: 1s
>     (pid=1804)  6520832/11490434 [================>.............] - ETA: 1s
>     (pid=1803)  7618560/11490434 [==================>...........] - ETA: 0s
>
>     *** WARNING: skipped 102225 bytes of output ***
>
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     ... 8 more trials not shown (8 TERMINATED)
>
>
>     Result for training_function_e11a2_00013:
>       date: 2021-01-12_09-57-10
>       done: true
>       experiment_id: 1d6933bc08ab4445ad5da9e5a347d211
>       experiment_tag: 13_number_conv=3,number_dense=1,train_with_mixed_data=False
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.8584166765213013
>       neg_mean_loss: -0.8584166765213013
>       node_ip: 10.149.234.234
>       pid: 16642
>       time_since_restore: 15591.14414525032
>       time_this_iter_s: 15591.14414525032
>       time_total_s: 15591.14414525032
>       timestamp: 1610445430
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00013
>       val_mix_accuracy: 0.9580000042915344
>       
>     2021-01-12 09:57:11,151	WARNING worker.py:1034 -- The actor or task with ID ffffffffffffffff9a0c190a01000000 cannot be scheduled right now. It requires {CPU: 1.000000} for placement, but this node only has remaining {CPU: 1.000000}, {node:10.149.234.234: 1.000000}, {object_store_memory: 4.052734 GiB}, {memory: 11.865234 GiB}. In total there are 0 pending tasks and 1 pending actors on this node. This is likely due to all cluster resources being claimed by actors. To resolve the issue, consider creating fewer actors or increase the resources available to this Ray cluster. You can ignore this message if this Ray cluster is expected to auto-scale.
>     (pid=24653) 2021-01-12 09:57:17.567536: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>     (pid=24653) 2021-01-12 09:57:17.567633: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
>     (pid=24653) 2021-01-12 09:57:29.520020: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
>     (pid=24653) 2021-01-12 09:57:29.520411: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
>     (pid=24653) 2021-01-12 09:57:29.520461: W tensorflow/stream_executor/cuda/cuda_driver.cc:326] failed call to cuInit: UNKNOWN ERROR (303)
>     (pid=24653) 2021-01-12 09:57:29.532198: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (0111-153043-gazer659-10-149-234-234): /proc/driver/nvidia/version does not exist
>     (pid=24653) 2021-01-12 09:57:29.533132: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
>     (pid=24653) /databricks/python/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:1844: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.
>     (pid=24653)   warnings.warn('`Model.fit_generator` is deprecated and '
>     (pid=24653) 2021-01-12 09:57:48.439418: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
>     (pid=24653) 2021-01-12 09:57:48.515396: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2300075000 Hz
>     Result for training_function_e11a2_00011:
>       date: 2021-01-12_10-48-28
>       done: false
>       experiment_id: 499d8408750c4c3581f91ab7db5f1511
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.9863666892051697
>       neg_mean_loss: -0.9863666892051697
>       node_ip: 10.149.234.234
>       pid: 14928
>       time_since_restore: 22859.48212838173
>       time_this_iter_s: 22859.48212838173
>       time_total_s: 22859.48212838173
>       timestamp: 1610448508
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00011
>       val_mix_accuracy: 0.9587000012397766
>       
>     == Status ==
>     Memory usage on this node: 18.5/24.9 GiB
>     Using FIFO scheduling algorithm.
>     Resources requested: 4/4 CPUs, 0/0 GPUs, 0.0/11.87 GiB heap, 0.0/4.05 GiB objects
>     Result logdir: /databricks/driver/ray_results/training_function_2021-01-11_19-47-47
>     Number of trials: 18/18 (4 RUNNING, 14 TERMINATED)
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     | Trial name                    | status     | loc                  |   number_conv |   number_dense | train_with_mixed_data   |     loss |   iter |   total time (s) |   val_mix_accuracy |
>     |-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------|
>     | training_function_e11a2_00011 | RUNNING    | 10.149.234.234:14928 |             4 |              0 | False                   | 0.986367 |      1 |         22859.5  |             0.9587 |
>     | training_function_e11a2_00014 | RUNNING    |                      |             4 |              1 | False                   |          |        |                  |                    |
>     | training_function_e11a2_00016 | RUNNING    |                      |             3 |              2 | False                   |          |        |                  |                    |
>     | training_function_e11a2_00017 | RUNNING    |                      |             4 |              2 | False                   |          |        |                  |                    |
>     | training_function_e11a2_00000 | TERMINATED |                      |             2 |              0 | True                    | 0.91115  |      1 |          8726.33 |             0.9044 |
>     | training_function_e11a2_00001 | TERMINATED |                      |             3 |              0 | True                    | 0.919217 |      1 |         20095.5  |             0.9445 |
>     | training_function_e11a2_00002 | TERMINATED |                      |             4 |              0 | True                    | 0.877833 |      1 |         20895.4  |             0.914  |
>     | training_function_e11a2_00003 | TERMINATED |                      |             2 |              1 | True                    | 0.789233 |      1 |          8132.31 |             0.9602 |
>     | training_function_e11a2_00004 | TERMINATED |                      |             3 |              1 | True                    | 0.7919   |      1 |         10600.6  |             0.9568 |
>     | training_function_e11a2_00005 | TERMINATED |                      |             4 |              1 | True                    | 0.790583 |      1 |         18313.7  |             0.9623 |
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     ... 8 more trials not shown (8 TERMINATED)
>
>
>     Result for training_function_e11a2_00011:
>       date: 2021-01-12_10-48-28
>       done: true
>       experiment_id: 499d8408750c4c3581f91ab7db5f1511
>       experiment_tag: 11_number_conv=4,number_dense=0,train_with_mixed_data=False
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.9863666892051697
>       neg_mean_loss: -0.9863666892051697
>       node_ip: 10.149.234.234
>       pid: 14928
>       time_since_restore: 22859.48212838173
>       time_this_iter_s: 22859.48212838173
>       time_total_s: 22859.48212838173
>       timestamp: 1610448508
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00011
>       val_mix_accuracy: 0.9587000012397766
>       
>     Result for training_function_e11a2_00016:
>       date: 2021-01-12_13-09-55
>       done: false
>       experiment_id: 60661505e32e4be6a41027634e3d9254
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.6158000230789185
>       neg_mean_loss: -0.6158000230789185
>       node_ip: 10.149.234.234
>       pid: 22764
>       time_since_restore: 15145.674983501434
>       time_this_iter_s: 15145.674983501434
>       time_total_s: 15145.674983501434
>       timestamp: 1610456995
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00016
>       val_mix_accuracy: 0.8560000061988831
>       
>     == Status ==
>     Memory usage on this node: 20.0/24.9 GiB
>     Using FIFO scheduling algorithm.
>     Resources requested: 3/4 CPUs, 0/0 GPUs, 0.0/11.87 GiB heap, 0.0/4.05 GiB objects
>     Result logdir: /databricks/driver/ray_results/training_function_2021-01-11_19-47-47
>     Number of trials: 18/18 (3 RUNNING, 15 TERMINATED)
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     | Trial name                    | status     | loc                  |   number_conv |   number_dense | train_with_mixed_data   |     loss |   iter |   total time (s) |   val_mix_accuracy |
>     |-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------|
>     | training_function_e11a2_00014 | RUNNING    |                      |             4 |              1 | False                   |          |        |                  |                    |
>     | training_function_e11a2_00016 | RUNNING    | 10.149.234.234:22764 |             3 |              2 | False                   | 0.6158   |      1 |         15145.7  |             0.856  |
>     | training_function_e11a2_00017 | RUNNING    |                      |             4 |              2 | False                   |          |        |                  |                    |
>     | training_function_e11a2_00000 | TERMINATED |                      |             2 |              0 | True                    | 0.91115  |      1 |          8726.33 |             0.9044 |
>     | training_function_e11a2_00001 | TERMINATED |                      |             3 |              0 | True                    | 0.919217 |      1 |         20095.5  |             0.9445 |
>     | training_function_e11a2_00002 | TERMINATED |                      |             4 |              0 | True                    | 0.877833 |      1 |         20895.4  |             0.914  |
>     | training_function_e11a2_00003 | TERMINATED |                      |             2 |              1 | True                    | 0.789233 |      1 |          8132.31 |             0.9602 |
>     | training_function_e11a2_00004 | TERMINATED |                      |             3 |              1 | True                    | 0.7919   |      1 |         10600.6  |             0.9568 |
>     | training_function_e11a2_00005 | TERMINATED |                      |             4 |              1 | True                    | 0.790583 |      1 |         18313.7  |             0.9623 |
>     | training_function_e11a2_00006 | TERMINATED |                      |             2 |              2 | True                    | 0.549283 |      1 |          7750.54 |             0.8406 |
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     ... 8 more trials not shown (8 TERMINATED)
>
>
>     Result for training_function_e11a2_00016:
>       date: 2021-01-12_13-09-55
>       done: true
>       experiment_id: 60661505e32e4be6a41027634e3d9254
>       experiment_tag: 16_number_conv=3,number_dense=2,train_with_mixed_data=False
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.6158000230789185
>       neg_mean_loss: -0.6158000230789185
>       node_ip: 10.149.234.234
>       pid: 22764
>       time_since_restore: 15145.674983501434
>       time_this_iter_s: 15145.674983501434
>       time_total_s: 15145.674983501434
>       timestamp: 1610456995
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00016
>       val_mix_accuracy: 0.8560000061988831
>       
>     Result for training_function_e11a2_00014:
>       date: 2021-01-12_14-01-25
>       done: false
>       experiment_id: 3efac11caf12479394e9b0448a2b97f7
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.857866644859314
>       neg_mean_loss: -0.857866644859314
>       node_ip: 10.149.234.234
>       pid: 17271
>       time_since_restore: 28856.51887536049
>       time_this_iter_s: 28856.51887536049
>       time_total_s: 28856.51887536049
>       timestamp: 1610460085
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00014
>       val_mix_accuracy: 0.9544000029563904
>       
>     == Status ==
>     Memory usage on this node: 18.7/24.9 GiB
>     Using FIFO scheduling algorithm.
>     Resources requested: 2/4 CPUs, 0/0 GPUs, 0.0/11.87 GiB heap, 0.0/4.05 GiB objects
>     Result logdir: /databricks/driver/ray_results/training_function_2021-01-11_19-47-47
>     Number of trials: 18/18 (2 RUNNING, 16 TERMINATED)
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     | Trial name                    | status     | loc                  |   number_conv |   number_dense | train_with_mixed_data   |     loss |   iter |   total time (s) |   val_mix_accuracy |
>     |-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------|
>     | training_function_e11a2_00014 | RUNNING    | 10.149.234.234:17271 |             4 |              1 | False                   | 0.857867 |      1 |         28856.5  |             0.9544 |
>     | training_function_e11a2_00017 | RUNNING    |                      |             4 |              2 | False                   |          |        |                  |                    |
>     | training_function_e11a2_00000 | TERMINATED |                      |             2 |              0 | True                    | 0.91115  |      1 |          8726.33 |             0.9044 |
>     | training_function_e11a2_00001 | TERMINATED |                      |             3 |              0 | True                    | 0.919217 |      1 |         20095.5  |             0.9445 |
>     | training_function_e11a2_00002 | TERMINATED |                      |             4 |              0 | True                    | 0.877833 |      1 |         20895.4  |             0.914  |
>     | training_function_e11a2_00003 | TERMINATED |                      |             2 |              1 | True                    | 0.789233 |      1 |          8132.31 |             0.9602 |
>     | training_function_e11a2_00004 | TERMINATED |                      |             3 |              1 | True                    | 0.7919   |      1 |         10600.6  |             0.9568 |
>     | training_function_e11a2_00005 | TERMINATED |                      |             4 |              1 | True                    | 0.790583 |      1 |         18313.7  |             0.9623 |
>     | training_function_e11a2_00006 | TERMINATED |                      |             2 |              2 | True                    | 0.549283 |      1 |          7750.54 |             0.8406 |
>     | training_function_e11a2_00007 | TERMINATED |                      |             3 |              2 | True                    | 0.60175  |      1 |         11056.9  |             0.8676 |
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     ... 8 more trials not shown (8 TERMINATED)
>
>
>     Result for training_function_e11a2_00014:
>       date: 2021-01-12_14-01-25
>       done: true
>       experiment_id: 3efac11caf12479394e9b0448a2b97f7
>       experiment_tag: 14_number_conv=4,number_dense=1,train_with_mixed_data=False
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.857866644859314
>       neg_mean_loss: -0.857866644859314
>       node_ip: 10.149.234.234
>       pid: 17271
>       time_since_restore: 28856.51887536049
>       time_this_iter_s: 28856.51887536049
>       time_total_s: 28856.51887536049
>       timestamp: 1610460085
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00014
>       val_mix_accuracy: 0.9544000029563904
>       
>     Result for training_function_e11a2_00017:
>       date: 2021-01-12_15-06-58
>       done: false
>       experiment_id: 0a699ecb32e34c56a9dda2e7625465d9
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.6455333232879639
>       neg_mean_loss: -0.6455333232879639
>       node_ip: 10.149.234.234
>       pid: 24653
>       time_since_restore: 18571.127383708954
>       time_this_iter_s: 18571.127383708954
>       time_total_s: 18571.127383708954
>       timestamp: 1610464018
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00017
>       val_mix_accuracy: 0.8565000295639038
>       
>     == Status ==
>     Memory usage on this node: 15.0/24.9 GiB
>     Using FIFO scheduling algorithm.
>     Resources requested: 1/4 CPUs, 0/0 GPUs, 0.0/11.87 GiB heap, 0.0/4.05 GiB objects
>     Result logdir: /databricks/driver/ray_results/training_function_2021-01-11_19-47-47
>     Number of trials: 18/18 (1 RUNNING, 17 TERMINATED)
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     | Trial name                    | status     | loc                  |   number_conv |   number_dense | train_with_mixed_data   |     loss |   iter |   total time (s) |   val_mix_accuracy |
>     |-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------|
>     | training_function_e11a2_00017 | RUNNING    | 10.149.234.234:24653 |             4 |              2 | False                   | 0.645533 |      1 |         18571.1  |             0.8565 |
>     | training_function_e11a2_00000 | TERMINATED |                      |             2 |              0 | True                    | 0.91115  |      1 |          8726.33 |             0.9044 |
>     | training_function_e11a2_00001 | TERMINATED |                      |             3 |              0 | True                    | 0.919217 |      1 |         20095.5  |             0.9445 |
>     | training_function_e11a2_00002 | TERMINATED |                      |             4 |              0 | True                    | 0.877833 |      1 |         20895.4  |             0.914  |
>     | training_function_e11a2_00003 | TERMINATED |                      |             2 |              1 | True                    | 0.789233 |      1 |          8132.31 |             0.9602 |
>     | training_function_e11a2_00004 | TERMINATED |                      |             3 |              1 | True                    | 0.7919   |      1 |         10600.6  |             0.9568 |
>     | training_function_e11a2_00005 | TERMINATED |                      |             4 |              1 | True                    | 0.790583 |      1 |         18313.7  |             0.9623 |
>     | training_function_e11a2_00006 | TERMINATED |                      |             2 |              2 | True                    | 0.549283 |      1 |          7750.54 |             0.8406 |
>     | training_function_e11a2_00007 | TERMINATED |                      |             3 |              2 | True                    | 0.60175  |      1 |         11056.9  |             0.8676 |
>     | training_function_e11a2_00008 | TERMINATED |                      |             4 |              2 | True                    | 0.110917 |      1 |         14448.1  |             0.1165 |
>     +-------------------------------+------------+----------------------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     ... 8 more trials not shown (8 TERMINATED)
>
>
>     Result for training_function_e11a2_00017:
>       date: 2021-01-12_15-06-58
>       done: true
>       experiment_id: 0a699ecb32e34c56a9dda2e7625465d9
>       experiment_tag: 17_number_conv=4,number_dense=2,train_with_mixed_data=False
>       hostname: 0111-153043-gazer659-10-149-234-234
>       iterations_since_restore: 1
>       mean_loss: 0.6455333232879639
>       neg_mean_loss: -0.6455333232879639
>       node_ip: 10.149.234.234
>       pid: 24653
>       time_since_restore: 18571.127383708954
>       time_this_iter_s: 18571.127383708954
>       time_total_s: 18571.127383708954
>       timestamp: 1610464018
>       timesteps_since_restore: 0
>       training_iteration: 1
>       trial_id: e11a2_00017
>       val_mix_accuracy: 0.8565000295639038
>       
>     == Status ==
>     Memory usage on this node: 14.6/24.9 GiB
>     Using FIFO scheduling algorithm.
>     Resources requested: 0/4 CPUs, 0/0 GPUs, 0.0/11.87 GiB heap, 0.0/4.05 GiB objects
>     Result logdir: /databricks/driver/ray_results/training_function_2021-01-11_19-47-47
>     Number of trials: 18/18 (18 TERMINATED)
>     +-------------------------------+------------+-------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>     | Trial name                    | status     | loc   |   number_conv |   number_dense | train_with_mixed_data   |     loss |   iter |   total time (s) |   val_mix_accuracy |
>     |-------------------------------+------------+-------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------|
>     | training_function_e11a2_00000 | TERMINATED |       |             2 |              0 | True                    | 0.91115  |      1 |          8726.33 |             0.9044 |
>     | training_function_e11a2_00001 | TERMINATED |       |             3 |              0 | True                    | 0.919217 |      1 |         20095.5  |             0.9445 |
>     | training_function_e11a2_00002 | TERMINATED |       |             4 |              0 | True                    | 0.877833 |      1 |         20895.4  |             0.914  |
>     | training_function_e11a2_00003 | TERMINATED |       |             2 |              1 | True                    | 0.789233 |      1 |          8132.31 |             0.9602 |
>     | training_function_e11a2_00004 | TERMINATED |       |             3 |              1 | True                    | 0.7919   |      1 |         10600.6  |             0.9568 |
>     | training_function_e11a2_00005 | TERMINATED |       |             4 |              1 | True                    | 0.790583 |      1 |         18313.7  |             0.9623 |
>     | training_function_e11a2_00006 | TERMINATED |       |             2 |              2 | True                    | 0.549283 |      1 |          7750.54 |             0.8406 |
>     | training_function_e11a2_00007 | TERMINATED |       |             3 |              2 | True                    | 0.60175  |      1 |         11056.9  |             0.8676 |
>     | training_function_e11a2_00008 | TERMINATED |       |             4 |              2 | True                    | 0.110917 |      1 |         14448.1  |             0.1165 |
>     | training_function_e11a2_00009 | TERMINATED |       |             2 |              0 | False                   | 0.9881   |      1 |          8279.36 |             0.9543 |
>     | training_function_e11a2_00010 | TERMINATED |       |             3 |              0 | False                   | 0.98315  |      1 |          9679.01 |             0.9392 |
>     | training_function_e11a2_00011 | TERMINATED |       |             4 |              0 | False                   | 0.986367 |      1 |         22859.5  |             0.9587 |
>     | training_function_e11a2_00012 | TERMINATED |       |             2 |              1 | False                   | 0.835117 |      1 |          5523.47 |             0.9474 |
>     | training_function_e11a2_00013 | TERMINATED |       |             3 |              1 | False                   | 0.858417 |      1 |         15591.1  |             0.958  |
>     | training_function_e11a2_00014 | TERMINATED |       |             4 |              1 | False                   | 0.857867 |      1 |         28856.5  |             0.9544 |
>     | training_function_e11a2_00015 | TERMINATED |       |             2 |              2 | False                   | 0.63225  |      1 |          7013.64 |             0.862  |
>     | training_function_e11a2_00016 | TERMINATED |       |             3 |              2 | False                   | 0.6158   |      1 |         15145.7  |             0.856  |
>     | training_function_e11a2_00017 | TERMINATED |       |             4 |              2 | False                   | 0.645533 |      1 |         18571.1  |             0.8565 |
>     +-------------------------------+------------+-------+---------------+----------------+-------------------------+----------+--------+------------------+--------------------+
>
>
>     2021-01-12 15:06:58,847	INFO tune.py:448 -- Total run time: 69554.22 seconds (69551.66 seconds for the tuning loop).
>     Best config:  {'number_conv': 2, 'number_dense': 0, 'train_with_mixed_data': False}

In [None]:
# Show the DataFrame `df` as this cell's output. `df` is created in an earlier
# cell outside this view — presumably the Ray Tune trial results; confirm.
df

In [None]:
def get_correct_labels(num_classes=10, samples_per_class=1000):
    """Return ground-truth labels for a test set stored grouped by class.

    Builds the flat integer array
    ``[0]*samples_per_class + [1]*samples_per_class + ... + [num_classes-1]*samples_per_class``.
    With the defaults this is 10,000 labels, 1,000 per class, in class order.

    Args:
        num_classes: Number of distinct classes (labels ``0 .. num_classes-1``).
        samples_per_class: How many consecutive samples each class occupies.

    Returns:
        1-D NumPy integer array of length ``num_classes * samples_per_class``.

    NOTE(review): this only lines up with the model's predictions if the test
    loader yields exactly ``samples_per_class`` samples per class, grouped by
    class, in class order, without shuffling. The original comment referred to
    CIFAR-10, but this notebook evaluates on MNIST, whose test split does not
    have equal class counts. Confirm against get_data_loaders(for_training=False).
    """
    return np.repeat(np.arange(num_classes), samples_per_class)


# Get best model: evaluate one tuned configuration on the test set and show its
# confusion matrix.
# number_conv=4 / number_dense=1 had the highest val_mix_accuracy (0.9623) in the
# tuning results table above. That differs from the "Best config" line Tune printed.
# NOTE(review): create_model(config) is not followed by any training here. Unless
# create_model restores trained weights, this evaluates an untrained network — confirm.
config = {"number_conv": 4, "number_dense": 1}

model = create_model(config)
test_loader = get_data_loaders(for_training=False)

# Model.predict accepts generators / keras Sequences directly;
# predict_generator is deprecated.
pred = model.predict(test_loader)

# Convert per-class scores into predicted class indices.
# Assumes pred has shape (n_samples, n_classes) — TODO confirm the output layer.
y_pred = np.argmax(pred, axis=1)
y_true = get_correct_labels()

# Fail with a clear message instead of an opaque sklearn error if the loader and
# the hard-coded ground truth disagree on the number of samples.
if len(y_pred) != len(y_true):
    raise ValueError(
        f"Got {len(y_pred)} predictions but {len(y_true)} ground-truth labels; "
        "check that the test loader and get_correct_labels() agree."
    )

confusion_matrix(y_true, y_pred)

