In [None]:
import os
import shutil
import sys
sys.path.append('../datapipeline')
sys.path.append('../imgclsmob/tensorflow2')
sys.path.append('../akhelpers')


import autokeras as ak
from Resnet_AK import CustomResnetBlock

import random
random.seed(47)
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.python.keras.utils.data_utils import Sequence
from kerastuner.engine.hyperparameters import HyperParameters

from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

In [None]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Sanity-check the raw CIFAR-10 arrays. Expected output, in order:
#   (50000, 32, 32, 3)   train images
#   (50000, 1)           train labels
#   [[6], [9], [9]]      first three train labels
#   (10000, 32, 32, 3)   test images
#   (10000, 1)           test labels
for item in (x_train.shape, y_train.shape, y_train[:3], x_test.shape, y_test.shape):
    print(item)

In [None]:
# --- Experiment configuration ---------------------------------------------
IMAGE_SIZE = x_train.shape[1:]  # per-image shape, (32, 32, 3) for CIFAR-10
NUM_CHANNELS = (3,)  # NOTE(review): not referenced anywhere else in this notebook
NUM_CLASSES = 10
MAX_NETWORK_SEARCH_TRIALS = 3
# Batch size has to be 1 if IMAGE_SIZE is (None, None): ragged tensors are not supported yet.
BATCH_SIZE = 128
VAL_DS_SIZE = 1000     # samples kept from the shuffled *test* split (used as validation data)
TRAIN_DS_SIZE = 10000  # samples kept from the shuffled train split
EPOCHS = 10
MODEL_DIR = 'auto_model'
PROJECT_NAME = 'resnet_ak'
OVERITE_PROJECT = True  # (sic) passed to ak.AutoModel(overwrite=...)

# Seed every RNG the pipeline touches. `random` drives the shuffles below. NumPy and
# TensorFlow seeds make TF-side randomness (e.g. augmentation, weight init)
# reproducible, to the extent those ops honour the global seed. Seeding in this cell
# (not only in the imports cell) means re-running it resets all of them.
SEED = 47
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Restrict the search space: the tuner may only pick "adam" or "sgd" as the
# optimizer, starting from "adam".
hp = HyperParameters()
hp.Choice(name="optimizer", values=["adam", "sgd"], default="adam")

In [None]:
# Pair each image with its label, then shuffle the pairs in place so images and
# labels stay aligned. The test pairs are shuffled first, then the train pairs,
# both drawing from the seeded `random` module.
zipped_test, zipped_train = list(zip(x_test, y_test)), list(zip(x_train, y_train))
for pairs in (zipped_test, zipped_train):
    random.shuffle(pairs)

In [None]:
# Transpose the shuffled [(image, label), ...] lists into [(images...), (labels...)].
# NOTE: not idempotent -- re-running this cell on its own output transposes it back.
zipped_train, zipped_test = list(zip(*zipped_train)), list(zip(*zipped_test))

In [None]:
# Split the transposed data into separate image and label tuples.
sel_x_train, sel_y_train = zipped_train
sel_x_test, sel_y_test = zipped_test

In [None]:
# Keep only the first TRAIN_DS_SIZE / VAL_DS_SIZE shuffled samples to keep the search cheap.
sel_x_train, sel_y_train = sel_x_train[:TRAIN_DS_SIZE], sel_y_train[:TRAIN_DS_SIZE]
sel_x_test, sel_y_test = sel_x_test[:VAL_DS_SIZE], sel_y_test[:VAL_DS_SIZE]

In [None]:
# Flatten the per-sample (1,) label arrays (the (N, 1) layout of the raw CIFAR-10
# labels) into 1-D (N,) vectors for tf.one_hot.
# Use the actual sample count rather than TRAIN_DS_SIZE / VAL_DS_SIZE. When a
# requested size exceeds the split, the slices above return fewer items, and a
# hard-coded size would make np.reshape raise. Re-running this cell after one-hot
# encoding, shape (N, NUM_CLASSES), still fails loudly instead of silently flattening.
sel_y_train = np.reshape(sel_y_train, (len(sel_y_train),))
sel_y_test = np.reshape(sel_y_test, (len(sel_y_test),))

In [None]:
# One-hot encode the integer class ids into NUM_CLASSES-wide tensors.
# NOTE: not idempotent -- re-running this cell would try to one-hot the one-hot output.
sel_y_train, sel_y_test = (
    tf.one_hot(sel_y_train, depth=NUM_CLASSES),
    tf.one_hot(sel_y_test, depth=NUM_CLASSES),
)


In [None]:
# Materialise the one-hot tensors as NumPy arrays, matching the NumPy image arrays.
# NOTE: re-running fails, because NumPy arrays have no `.numpy()` method.
sel_y_train, sel_y_test = sel_y_train.numpy(), sel_y_test.numpy()

In [None]:
# Re-assemble the tuples of per-image (32, 32, 3) arrays into single (N, 32, 32, 3)
# arrays and report their shapes. np.stack's default axis is 0.
sel_x_train = np.stack(sel_x_train)
print(sel_x_train.shape)
sel_x_test = np.stack(sel_x_test)
print(sel_x_test.shape)

### Define Model

In [None]:
# Search graph: augmentation -> normalization -> custom ResNet block -> softmax head.
input_node = ak.Input()
# NOTE(review): in Keras' RandomRotation the factor is a fraction of 2*pi, so
# rotation_factor=0.3 allows up to roughly +/-108 degrees -- confirm that is intended.
output_node = ak.ImageAugmentation(
    translation_factor=0.2,
    rotation_factor=0.3,
    zoom_factor=0.2,
    contrast_factor=0.2,
    horizontal_flip=True,
    vertical_flip=False,
)(input_node)
output_node = ak.Normalization()(output_node)
# Project-local block imported from Resnet_AK; told the per-image input shape.
output_node = CustomResnetBlock(in_size=IMAGE_SIZE)(output_node)
output_node = ak.ClassificationHead(num_classes=NUM_CLASSES, multi_label=False)(output_node)

# With overwrite=True, any previous search stored under MODEL_DIR/PROJECT_NAME is discarded.
auto_model = ak.AutoModel(
    inputs=input_node,
    outputs=output_node,
    project_name=PROJECT_NAME,
    directory=MODEL_DIR,
    overwrite=OVERITE_PROJECT,
    max_trials=MAX_NETWORK_SEARCH_TRIALS,
    hyperparameters=hp,
)

### Train Model

In [None]:
class ClearTrialCheckpoints(tf.keras.callbacks.Callback):
    """Delete the ``checkpoints`` directory of every trial when training ends.

    Saves disk space during an AutoKeras search. ``on_train_end`` fires at the end
    of each trial's ``fit``, and every ``<project_dir>/<*trial*>/checkpoints``
    directory present at that moment is removed, including the one for the trial
    that just finished. Only tested with the 'greedy' tuner.

    NOTE(review): if the tuner later reloads a trial's weights from these
    checkpoints (e.g. in ``export_model()`` / ``get_best_models()``), deleting them
    may break that step. Confirm against the installed autokeras/kerastuner versions.

    Args:
        project_dir: Directory containing the trial sub-directories. When None
            (the default), ``os.path.join(MODEL_DIR, PROJECT_NAME)`` is used,
            looked up each time training ends.
    """

    def __init__(self, project_dir=None):
        super().__init__()
        self.project_dir = project_dir

    def on_train_end(self, logs=None):
        """Remove all trial checkpoint directories. ``logs`` is unused."""
        dir_to_look = self.project_dir or os.path.join(MODEL_DIR, PROJECT_NAME)
        # Nothing written yet means nothing to clean. This also avoids a
        # FileNotFoundError from os.listdir.
        if not os.path.isdir(dir_to_look):
            return
        for name in os.listdir(dir_to_look):
            trial_dir = os.path.join(dir_to_look, name)
            # Substring match: any sub-directory whose name contains 'trial'.
            if 'trial' not in name or not os.path.isdir(trial_dir):
                continue
            dir_of_concern = os.path.join(trial_dir, 'checkpoints')
            if os.path.isdir(dir_of_concern):
                print(dir_of_concern)
                shutil.rmtree(dir_of_concern)


In [None]:
# Per-trial callbacks:
#   - stop after 3 epochs without improvement of the monitored metric (Keras default: val_loss)
#   - lower the learning rate after a single stagnant epoch
#   - delete trial checkpoints to save disk space
callbacks = [
    EarlyStopping(patience=3),
    ReduceLROnPlateau(patience=1),
    ClearTrialCheckpoints(),
]

In [None]:
# Run the architecture search.
# NOTE(review): validation_data is a slice of the CIFAR-10 *test* split, so trial
# selection sees test data. Hold out part of the training set instead if an
# unbiased test score is needed.
auto_model.fit(
    sel_x_train,
    sel_y_train,
    validation_data=(sel_x_test, sel_y_test),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
)

### Show Model 

In [None]:
# Export the best pipeline found by the search as a Keras model and print its layers.
model = auto_model.export_model()
model.summary()

In [None]:
# Print the tuner's summary of the completed trials and their scores.
auto_model.tuner.results_summary()

In [None]:
# Show the winning trial's hyperparameters as a plain {name: value} dict.
# A KerasTuner `HyperParameters` object is not a Mapping, so calling dict() on it
# does not yield its values. The `.values` attribute holds the name -> value dict;
# it is copied here so the tuner's state cannot be mutated through the displayed dict.
best_hp = auto_model.tuner.get_best_hyperparameters(num_trials=1)[0]
dict(best_hp.values)

In [None]:
# Scratch cell (list-slicing experiment), unrelated to the model search. TODO: remove.
a = [[1] * 4, [2] * 4]

In [None]:
# Scratch: the first three entries of each row of `a`. TODO: remove.
list(row[:3] for row in a)