In [None]:
import pandas as pd
# import streamlit as st
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os

# --- Configuration -----------------------------------------------------------
IMG_HEIGHT = 50
IMG_WIDTH = 50
PATH = '../input/iris-computer-vision'
# Fraction of the images in each class sub-directory held out for validation.
VALIDATION_SPLIT = 0.2
SEED = 42

# Seed TensorFlow's global RNG so weight initialisation is repeatable across runs.
tf.random.set_seed(SEED)

# Both generators read the SAME directory, so they must share the same
# `validation_split`; `flow_from_directory(subset=...)` then picks disjoint
# file lists. Without a split, both iterators would yield every image and the
# "validation" metrics would just measure fit on the training data.
# Data augmentation (horizontal flips) is applied to the training subset only.
train_datagen = ImageDataGenerator(rescale=1./255.,
                                   horizontal_flip=True,
                                   validation_split=VALIDATION_SPLIT)

val_datagen = ImageDataGenerator(rescale=1./255.,
                                 validation_split=VALIDATION_SPLIT)

training_ds = train_datagen.flow_from_directory(
    PATH,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical',
    subset='training',
    seed=SEED,  # repeatable shuffle order for the training iterator
)

validation_ds = val_datagen.flow_from_directory(
    PATH,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical',
    subset='validation',
    shuffle=False,  # fixed order; validation metrics don't need shuffling
)

# `.classes` holds one integer label per *image*; the readable names live in
# `.class_indices` ({name: index}). Order them by index so that class_names[i]
# is the label for softmax output unit i.
class_names = sorted(training_ds.class_indices, key=training_ds.class_indices.get)

In [None]:
# Shape of a single image as yielded by the training generator; the same value
# is passed as `input_shape` to the first Conv2D layer of the model.
input_shape = training_ds.image_shape
input_shape

In [None]:
# Training hyper-parameters.
EPOCHS = 50
LEARNING_RATE = 1e-3

# Small CNN: two Conv2D/MaxPooling stages extract features, then a dense head
# classifies. The output width comes from the data iterator, so the model
# always matches the number of class sub-directories found under PATH.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=training_ds.image_shape),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(training_ds.num_classes, activation='softmax'),
])

# Save a checkpoint only when validation loss improves on the best seen so far.
# `{val_acc}` resolves in the filename because the metric below is named 'acc'.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    './model-{epoch:03d}-{val_acc:.3f}-{val_loss:.3f}.h5',
    save_best_only=True,
    monitor='val_loss',
    mode='min',
)

# `learning_rate` is the supported keyword: the old `lr` alias is deprecated in
# tf.keras (newer versions ignore it with a warning) and rejected by Keras 3.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='categorical_crossentropy',
    metrics=['acc'],
)

history = model.fit(
    training_ds,
    validation_data=validation_ds,
    epochs=EPOCHS,
    verbose=1,
    callbacks=[checkpoint],
)

In [None]:
import matplotlib.pyplot as plt
# Learning curves recorded by `model.fit`. The 'acc'/'val_acc' keys exist
# because the model was compiled with metrics=['acc'].
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Keras numbers epochs from 1 in its progress logs; label the x-axis the same way.
epochs = range(1, len(acc) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(epochs, acc, 'r', label='Training accuracy')
ax_acc.plot(epochs, val_acc, 'b', label='Validation accuracy')
ax_acc.set(title='Training and validation accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend(loc=0)

# Loss shown alongside accuracy: a rising validation loss against a falling
# training loss is the usual sign of over-fitting.
ax_loss.plot(epochs, loss, 'r', label='Training loss')
ax_loss.plot(epochs, val_loss, 'b', label='Validation loss')
ax_loss.set(title='Training and validation loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend(loc=0)

fig.tight_layout()
plt.show()