In [1]:
import gc
import glob
import os
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import SGD



ModuleNotFoundError: No module named 'tensorflow'

In [None]:
def load_img_data_colorectal_histology(img_size, seed=None):
    """Load TFDS ``colorectal_histology`` and split it into train/validation/test arrays.

    The single TFDS ``train`` split is divided with stratified sampling:
    70% train, then the remaining 30% is split 75/25 into validation
    (22.5% of all images) and test (7.5% of all images).

    Args:
        img_size: Target side length in pixels. Images whose height/width differ
            from ``(img_size, img_size)`` are resized with OpenCV; pass ``None``
            to keep the native resolution.
        seed: Optional ``random_state`` forwarded to both ``train_test_split``
            calls so the split is reproducible. ``None`` (default) keeps the
            previous, unseeded behaviour.

    Returns:
        ``(X_train, y_train), (X_val, y_val), (X_test, y_test)`` as NumPy arrays
        (images stacked along axis 0, labels as a 1-D array).
    """
    dataset = tfds.load('colorectal_histology', split='train', as_supervised=True)

    # Collect images and labels into separate lists. Stacking the raw
    # (image, label) tuples with np.vstack builds a ragged array from a
    # non-sequence iterable, which newer NumPy releases reject.
    images, labels = [], []
    for image, label in tfds.as_numpy(dataset):
        images.append(image)
        labels.append(label)
    x_data = np.stack(images)
    y_data = np.asarray(labels)

    if img_size is not None and tuple(x_data.shape[1:3]) != (img_size, img_size):
        # cv2.resize takes dsize as (width, height); both equal img_size here.
        x_data = np.stack([cv2.resize(img, (img_size, img_size)) for img in x_data])

    X_train, X_val, y_train, y_val = train_test_split(
        x_data, y_data, test_size=0.3, stratify=y_data, random_state=seed)
    X_val, X_test, y_val, y_test = train_test_split(
        X_val, y_val, test_size=0.25, stratify=y_val, random_state=seed)

    print('Number of training images: %i' %X_train.shape[0])
    print('Number of validation images: %i' %X_val.shape[0])
    print('Number of test images: %i' %X_test.shape[0])
    print('Image Size: ', X_train[0].shape)
    return (X_train, y_train), (X_val, y_val), (X_test, y_test)


def preprocessed_dataset(n_classes, img_size, batch_size=32, seed=None):
    """Build Keras data iterators for the colorectal histology train/val/test splits.

    Images are cast to float32 (no rescaling is applied) and labels are one-hot
    encoded. Only the training iterator applies augmentation (zoom, shifts,
    horizontal flips) and shuffles.

    Args:
        n_classes: Number of classes used for one-hot encoding.
        img_size: Image side length forwarded to ``load_img_data_colorectal_histology``.
        batch_size: Batch size of the training and validation iterators (default 32).
        seed: Optional seed for the training iterator's shuffling/augmentation.

    Returns:
        ``(train_set, valid_set, test_set)`` iterators from ``ImageDataGenerator.flow``.
        The test iterator uses batch size 1.
    """
    (x_train, y_train), (x_val, y_val), (x_test, y_test) = load_img_data_colorectal_histology(img_size)

    x_train, x_val, x_test = (x.astype('float32') for x in (x_train, x_val, x_test))
    y_train, y_val, y_test = (
        tf.keras.utils.to_categorical(y, num_classes=n_classes)
        for y in (y_train, y_val, y_test)
    )

    print('x_train: ', x_train.shape)
    print('y_train: ', y_train.shape)

    # NOTE(review): ImageDataGenerator is deprecated in recent TensorFlow/Keras
    # releases; consider tf.data + preprocessing layers when upgrading.
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        zoom_range=0.2,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
    )
    # Evaluation data is not augmented, so one plain generator serves both splits.
    eval_datagen = tf.keras.preprocessing.image.ImageDataGenerator()

    train_set = train_datagen.flow(x_train, y_train, batch_size=batch_size, seed=seed)
    # flow() shuffles by default; keep evaluation order fixed so results are
    # deterministic and predictions line up with x_val/x_test order.
    valid_set = eval_datagen.flow(x_val, y_val, batch_size=batch_size, shuffle=False)
    test_set = eval_datagen.flow(x_test, y_test, batch_size=1, shuffle=False)

    return train_set, valid_set, test_set

In [None]:
img_size = 150
n_classes = 8
epochs = 200  # passes over the (augmented) training iterator

train_set, valid_set, test_set = preprocessed_dataset(n_classes, img_size)

tf.keras.backend.clear_session()

model = None

#########################################
#########################################
#Add Model definition here!
#########################################
#########################################

# Fail with an actionable message instead of "'NoneType' object has no attribute 'compile'".
if model is None:
    raise NotImplementedError('Define `model` in the "Add Model definition here!" section before training.')

sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
# Explicit build so summary() can report shapes even if the model was defined
# without an input shape.
model.build(input_shape=(None, img_size, img_size, 3))
model.summary()

fit_history = model.fit(train_set,
                        validation_data=valid_set,
                        epochs=epochs,
                        verbose=1)

loss_val = fit_history.history['loss']
loss, acc = model.evaluate(test_set, verbose=0)

test_accuracy = 'Test set accuracy is: %f' % acc
print(test_accuracy)

# One timestamp shared by the CSV and the figure so a run's artefacts can be matched.
now = datetime.now().strftime("%Y%m%d-%H%M%S")

hist = pd.DataFrame(fit_history.history)
csv_file = now + '_training_curve.csv'
# pandas requires text handles passed to to_csv to be opened with newline=''
# (otherwise rows get doubled line endings on Windows). The test-accuracy
# summary line is appended after the per-epoch table.
with open(csv_file, mode='w', newline='') as f:
    hist.to_csv(f)
    f.write(test_accuracy)

# Accuracy (left) and loss (right) curves in the top row of a 2x2 grid.
# A fresh figure is created each run rather than reusing figure number 1.
fig = pyplot.figure(figsize=(15, 8))
for position, metric, title in ((221, 'accuracy', 'model accuracy'),
                                (222, 'loss', 'model loss')):
    ax = fig.add_subplot(position)
    ax.plot(fit_history.history[metric])
    ax.plot(fit_history.history['val_' + metric])
    ax.set_title(title)
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid'])

filename = now + '_loss_curve' + '.tif'
fig.savefig(filename)
pyplot.close(fig)