In [4]:
import tensorflow as tf
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt
from IPython.display import HTML

from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [16]:

IMAGE_SIZE = 256
CHANNELS = 3

training_data_generator = ImageDataGenerator(
    rotation_range=10,
    horizontal_flip=True,
    rescale=1./255
)

train_gen = training_data_generator.flow_from_directory(
    'dataset/train', 
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size = 32,
    class_mode='sparse',
    save_to_dir=None
)

Found 1507 images belonging to 3 classes.


split_folders --output dataset --ratio .7 .1 .2 -- ../Dataset/PlantVillage

In [17]:
validation_data_generator = ImageDataGenerator(
    rotation_range=10,
    horizontal_flip=True,
    rescale=1./255
)

validation_gen = validation_data_generator.flow_from_directory(
    'dataset/val', 
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size = 32,
    class_mode='sparse',
    save_to_dir=None
)

Found 215 images belonging to 3 classes.


In [18]:
test_data_generator = ImageDataGenerator(
    rotation_range=10,
    horizontal_flip=True,
    rescale=1./255
)

test_gen = test_data_generator.flow_from_directory(
    'dataset/test', 
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size = 32,
    class_mode='sparse',
    save_to_dir=None
)

Found 431 images belonging to 3 classes.


In [20]:
input_shape = (IMAGE_SIZE, IMAGE_SIZE, CHANNELS)
n_classes = 3

model = models.Sequential([
    layers.Resizing(IMAGE_SIZE, IMAGE_SIZE, input_shape=input_shape),
    layers.Rescaling(1.0 / 255),
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),

    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Flatten(),
    layers.Dense(n_classes, activation='softmax')
])

  super().__init__(**kwargs)


In [21]:
model.summary()

In [22]:
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['accuracy'])

In [23]:
history = model.fit(
train_gen, 
batch_size=32,
validation_data=validation_gen, 
verbose=1, 
steps_per_epoch=47,
validation_steps = 6,
epochs=20
)

  self._warn_if_super_not_called()


Epoch 1/20
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m92s[0m 2s/step - accuracy: 0.4569 - loss: 0.9437 - val_accuracy: 0.4583 - val_loss: 0.8840
Epoch 2/20
[1m 1/47[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m47s[0m 1s/step - accuracy: 0.4688 - loss: 1.0091



[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 72ms/step - accuracy: 0.4688 - loss: 1.0091 - val_accuracy: 0.4792 - val_loss: 0.8935
Epoch 3/20
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 2s/step - accuracy: 0.4705 - loss: 0.9059 - val_accuracy: 0.4844 - val_loss: 0.9143
Epoch 4/20
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 167ms/step - accuracy: 0.5938 - loss: 0.9379 - val_accuracy: 0.4583 - val_loss: 0.8970
Epoch 5/20
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 2s/step - accuracy: 0.4427 - loss: 0.9068 - val_accuracy: 0.4635 - val_loss: 0.9038
Epoch 6/20
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 132ms/step - accuracy: 0.4375 - loss: 0.8249 - val_accuracy: 0.4531 - val_loss: 0.9042
Epoch 7/20
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 1s/step - accuracy: 0.4603 - loss: 0.9109 - val_accuracy: 0.4479 - val_loss: 0.9165
Epoch 8/20
[1m47/47[0m [32m━━━━━━━━━━━━━━

In [24]:
score = model.evaluate(test_gen)

  self._warn_if_super_not_called()


[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 894ms/step - accuracy: 0.4640 - loss: 0.9070


In [25]:
score

[0.9070383310317993, 0.46403712034225464]

In [26]:
model.save('../potatoes.h5')

