In [None]:
import numpy as np
from scipy.io import loadmat, savemat

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras.callbacks import ModelCheckpoint

In [None]:
# Load the pre-split dataset (train/validation images and integer class labels).
datadict = loadmat('split_set.mat')

train_imgs = datadict['train_imgs']
val_imgs = datadict['val_imgs']

# loadmat always returns vectors as 2-D arrays. Indexing with [0] only works
# for a 1xN row vector (an Nx1 column vector would yield a single label);
# np.ravel flattens either layout to shape (N,).
train_target = np.ravel(datadict['train_target'])
val_target = np.ravel(datadict['val_target'])

# Fail fast with a clear message if images and labels do not line up.
assert len(train_target) == len(train_imgs), (
    f"train_target has {len(train_target)} labels but train_imgs has "
    f"{len(train_imgs)} images along axis 0")

In [None]:
# Convert the loaded arrays to tensors: float32 inputs for the conv layers and
# compact int8 class indices for the sparse categorical cross-entropy loss.
# NOTE(review): pixel values are passed through unscaled (the model has no
# Rescaling layer) — confirm train_imgs is already in the intended range.
X, y = (
    tf.cast(x=train_imgs, dtype=tf.float32),
    tf.cast(x=train_target, dtype=tf.int8),
)

In [None]:
# --- Configuration -----------------------------------------------------------
img_height = 256
img_width = 256
channels = 3
classes = 3

epochs = 50
n_models = 60            # number of independently initialised models to train
val_split = 0.15         # fraction of (X, y) that Keras holds out for validation
out_dir = 'Test_models'  # best checkpoints (.hdf5) and training histories (.mat)

c1 = 0.075  # strength of every augmentation op (rotation, translation, zoom, contrast)
c2 = 0.2    # dropout rate before the dense head


def make_data_augmentation():
    """Return a new random-augmentation pipeline for (img_height, img_width, channels) images.

    Applies a horizontal flip plus rotation, translation, zoom and contrast
    jitter, each scaled by ``c1``. A fresh instance is created per model so no
    layer outlives the ``clear_session()`` call at the end of each training run.
    """
    return Sequential([
        layers.RandomFlip("horizontal",
                          input_shape=(img_height, img_width, channels)),
        layers.RandomRotation(c1),
        layers.RandomTranslation(c1, c1),
        layers.RandomZoom(height_factor=(-c1, c1)),
        layers.RandomContrast(c1),
    ])


def build_model():
    """Build and compile one CNN classifier with its own augmentation front end.

    Three Conv2D/MaxPooling2D stages (8, 16, 32 filters), dropout ``c2``, a
    128-unit ReLU layer and a ``classes``-way linear output. The output layer
    has no activation, so the loss is configured with ``from_logits=True``.
    """
    model = Sequential([
        make_data_augmentation(),
        layers.Conv2D(8, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(16, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Dropout(c2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(classes),
    ])
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model


# Kept under its original name for any later cell that uses it; the training
# loop below gives every model its own fresh pipeline.
data_augmentation = make_data_augmentation()

# savemat (and, depending on the Keras version, ModelCheckpoint) cannot write
# into a missing directory; makedirs succeeds if it already exists.
tf.io.gfile.makedirs(out_dir)

# --- Train n_models independently initialised models -------------------------
for i in range(n_models):
    model = build_model()

    # Use the tf.keras callback so it matches the tf.keras model; the standalone
    # `keras` package may be a different Keras implementation.
    # Only the epoch with the best validation accuracy is written to disk.
    model_callback = tf.keras.callbacks.ModelCheckpoint(
        f'{out_dir}/model{i+1}.hdf5',
        monitor='val_accuracy',
        save_best_only=True)

    # NOTE(review): validation_split takes the *last* val_split fraction of X/y
    # before any shuffling, and the separately loaded val_imgs/val_target are
    # not used here. If split_set.mat stores samples grouped by class, this
    # validation subset is biased — confirm the ordering, or consider
    # validation_data=(val_imgs, val_target).
    history = model.fit(x=X,
                        y=y,
                        epochs=epochs,
                        validation_split=val_split,
                        callbacks=[model_callback])

    perfdict = history.history
    savemat(f'{out_dir}/model{i+1}.mat', perfdict)

    # `del model` alone does not release Keras' global state; clearing the
    # session keeps memory from growing across the n_models iterations.
    del model
    tf.keras.backend.clear_session()