In [None]:
import matplotlib.pyplot as plt
import os
import seaborn as sns
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
from tensorflow_examples.lite.model_maker.core.task import image_preprocessing

from tflite_model_maker import image_classifier
from tflite_model_maker import ImageClassifierDataLoader
from tflite_model_maker.image_classifier import ModelSpec

In [None]:
# Pull the TFDS "cassava" plant-disease dataset as (image, label) pairs together
# with its DatasetInfo metadata, already split into train / validation / test.
tfds_name = 'cassava'
(ds_train, ds_validation, ds_test), ds_info = tfds.load(
    tfds_name,
    split=['train', 'validation', 'test'],
    as_supervised=True,
    with_info=True,
)
# Prefix used when naming exported TFLite artifacts.
TFLITE_NAME_PREFIX = tfds_name

In [None]:
# Model Maker wants each tf.data split wrapped together with its size and the
# class label names, so build loaders for the training and validation splits.
label_names = ds_info.features['label'].names

train_data, validation_data = (
    ImageClassifierDataLoader(split, split.cardinality(), label_names)
    for split in (ds_train, ds_validation)
)

# TF Hub feature-vector backbone to fine-tune; the other entries in the table
# below are alternative backbones that can be selected by name.
model_name = 'mobilenet_v3_large_100_224'

map_model_name = {
    'cropnet_cassava': 'https://tfhub.dev/google/cropnet/feature_vector/cassava_disease_V1/1',
    'cropnet_concat': 'https://tfhub.dev/google/cropnet/feature_vector/concat/1',
    'cropnet_imagenet': 'https://tfhub.dev/google/cropnet/feature_vector/imagenet/1',
    'mobilenet_v3_large_100_224': 'https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/feature_vector/5',
}

model_handle = map_model_name[model_name]
image_model_spec = ModelSpec(uri=model_handle)

In [None]:
# Fine-tune the selected TF Hub backbone on the cassava training split, passing
# the validation split so the model is also evaluated on it during training.
# WARNING: This takes almost 1.5 hours to run.

# Training hyperparameters, named here so they can be tuned in one place.
SEED = 42
BATCH_SIZE = 128
LEARNING_RATE = 0.03
EPOCHS = 5

# shuffle=True makes training stochastic; seed TensorFlow's global RNG so reruns
# are more repeatable (GPU kernels may still introduce nondeterminism).
tf.random.set_seed(SEED)

model = image_classifier.create(
    train_data,
    model_spec=image_model_spec,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    epochs=EPOCHS,
    shuffle=True,
    # Train the Hub backbone together with the classification head,
    # rather than only the head on top of frozen features.
    train_whole_model=True,
    validation_data=validation_data)

In [None]:
# Readable names for the short disease codes that the dataset uses as label names.
name_map = {
    'cmd': 'Mosaic Disease',
    'cbb': 'Bacterial Blight',
    'cgm': 'Green Mite',
    'cbsd': 'Brown Streak Disease',
    'healthy': 'Healthy',
    'unknown': 'Unknown',
}

# Show every dataset label code next to its readable name.
[(name_map[code], code) for code in label_names]

In [None]:
# Run the trained model over the standard cassava test split.
test_data = ImageClassifierDataLoader(ds_test, ds_test.cardinality(),
                                      label_names)

# Keep the full prediction list for later inspection, but only display a few
# rows: echoing one entry per test image floods the notebook output.
test_predictions = model.predict_top_k(test_data)
test_predictions[:5]

In [None]:
# Download the two test images from Amazon S3 and decode them into image arrays.
#
# Credentials are deliberately NOT written in the notebook. boto3 resolves them
# from its default credential chain, for example:
#   - the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment variables,
#   - ~/.aws/credentials,
#   - an attached IAM role.
# NOTE(review): the access key pair that used to be hardcoded in this cell is
# exposed in the notebook's history and should be revoked/rotated.
import boto3
import io
import matplotlib.image as mpimg

S3_REGION = 'us-west-2'
S3_BUCKET = 'planttest123'

s3 = boto3.resource('s3', region_name=S3_REGION)
bucket = s3.Bucket(S3_BUCKET)


def read_s3_jpeg(s3_bucket, key):
    """Download object ``key`` from ``s3_bucket`` into memory and decode it as a JPEG.

    Args:
        s3_bucket: a boto3 S3 ``Bucket`` resource.
        key: object key inside the bucket, e.g. ``'image_1.JPG'``.

    Returns:
        The decoded image array produced by ``matplotlib.image.imread``.
    """
    buffer = io.BytesIO()
    s3_bucket.Object(key).download_fileobj(buffer)
    # download_fileobj leaves the cursor at the end; rewind before decoding.
    buffer.seek(0)
    return mpimg.imread(buffer, format="JPEG")


img1 = read_s3_jpeg(bucket, 'image_1.JPG')
img2 = read_s3_jpeg(bucket, 'image_2.JPG')

In [None]:
# Wrap the images downloaded from S3 in a tf.data dataset so they go through
# the same Model Maker loader as the TFDS splits.
s3_images = [img1, img2]

# np.array can only stack images that share a single shape; fail early with a
# readable message instead of an inhomogeneous-shape / object-array error.
if len({image.shape for image in s3_images}) != 1:
    raise ValueError(
        'S3 images must all have the same shape to be stacked: '
        f'{[image.shape for image in s3_images]}')

input_1 = np.array(s3_images)
# Placeholder labels (class index 0), one per image. The true class of these
# images is unknown. The explicit dtype keeps it platform-independent.
# NOTE(review): assumes the labels are ignored by predict_top_k — confirm.
input_2 = np.zeros(len(s3_images), dtype=np.int64)

ds_test2 = tf.data.Dataset.from_tensor_slices((input_1, input_2))

test_data2 = ImageClassifierDataLoader(ds_test2, ds_test2.cardinality(),
                                       label_names)

In [None]:
# Show each S3 image followed by its top prediction as (disease name, score).
from PIL import Image as im

# For each image, predict_top_k gives a list of (label_code, score) pairs.
# Keep the first pair and translate the code into its readable disease name.
labels = [(name_map[top_k[0][0]], top_k[0][1])
          for top_k in model.predict_top_k(test_data2)]

# ds_test2 yields (image, placeholder_label) pairs, assumed to be in the same
# order as the predictions above. Pair them with zip instead of a manual index.
for (image, _placeholder_label), label in zip(ds_test2, labels):
    display(im.fromarray(image.numpy()))
    print(label)