In [None]:
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
import os
import numpy as np
import tensorflow_datasets as tfds
import warnings
# Silence ALL Python warnings for the rest of the kernel session to keep the
# notebook output tidy.
# NOTE(review): this also hides deprecation warnings from TF / TFDS / TF-Hub;
# consider narrowing the filter (category= / module=).
warnings.filterwarnings('ignore')

In [None]:
# Load all three Cassava splits as (image, label) tuples. `dataset` is a list
# in the same order as `split`: [train, test, validation].
dataset, info = tfds.load(
    name='cassava',
    split=['train', 'test', 'validation'],
    as_supervised=True,
    with_info=True,
)

In [None]:
# Preview a grid of raw examples. NOTE: despite its name, `train` holds the
# *test* split here. It is loaded without as_supervised, so each element is a
# feature dict rather than an (image, label) tuple.
train, info_train = tfds.load(name='cassava', with_info=True, split='test')
# The positional (info, ds) order is deprecated in current TFDS, whose
# signature is show_examples(ds, ds_info). Keyword arguments work with both
# the old and the new signatures.
tfds.show_examples(ds=train, ds_info=info_train)

In [None]:
def sc(image, label, image_size=(224, 224), num_classes=5):
  """Preprocess one (image, label) pair for the MobileNetV2 classifier.

  Casts the image to float32, divides by 255, resizes it and one-hot encodes
  the label. The defaults match the rest of this notebook (224x224 hub input,
  5-way Dense head), so `dataset.map(sc)` keeps working unchanged.

  Args:
    image: image tensor; presumably uint8 (H, W, 3) from tfds -- TODO confirm.
    label: integer class-index tensor.
    image_size: (height, width) the image is resized to.
    num_classes: depth of the one-hot label vector.

  Returns:
    Tuple of (resized float32 image, one-hot float32 label).
  """
  image = tf.cast(image, tf.float32) / 255.0
  image = tf.image.resize(image, list(image_size))
  return image, tf.one_hot(label, num_classes)

In [None]:
def get_dataset(batch_size = 32, shuffle_buffer=1000):
  """Build batched train/test/validation pipelines from the global `dataset`.

  `dataset` is the [train, test, validation] list returned by tfds.load.
  Every split is mapped through `sc`; only the training split is shuffled.

  Args:
    batch_size: number of examples per batch.
    shuffle_buffer: size of the training-set shuffle buffer.

  Returns:
    Tuple of (train, test, validation) batched tf.data.Dataset objects.
  """
  autotune = tf.data.experimental.AUTOTUNE

  def _pipeline(split, shuffle=False):
    # Preprocess in parallel; map() still yields elements in their original
    # order unless determinism is explicitly disabled.
    ds = split.map(sc, num_parallel_calls=autotune)
    if shuffle:
      ds = ds.shuffle(shuffle_buffer)
    # prefetch overlaps input preparation with model execution.
    return ds.batch(batch_size).prefetch(autotune)

  return (_pipeline(dataset[0], shuffle=True),
          _pipeline(dataset[1]),
          _pipeline(dataset[2]))

In [None]:
train_dataset, test_dataset, val_dataset = get_dataset()
# Dataset.cache() returns a NEW dataset. The previous bare `.cache()` calls
# discarded that result, so nothing was actually cached.
# The validation pipeline has no shuffle, so caching it is safe and avoids
# re-decoding and resizing every epoch.
val_dataset = val_dataset.cache()
# The training pipeline is deliberately NOT cached here. Caching after
# .shuffle() would replay the first epoch's order forever and disable
# per-epoch reshuffling.

In [None]:
# TF-Hub handle of the pretrained MobileNetV2 image feature-vector model; it is
# wrapped as a Keras layer in the next cell.
f_e = (
    "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
)

In [None]:
# Wrap the hub model as a Keras layer. input_shape fixes the per-image input to
# 224x224x3, matching the default resize performed by `sc`.
f_e_layer = hub.KerasLayer(
    f_e,
    input_shape=(224, 224, 3),
)

In [None]:
# Freeze the pretrained feature extractor so that only the classification head
# added after it is trained (transfer learning by feature extraction).
f_e_layer.trainable = False

In [None]:
# Frozen MobileNetV2 features -> dropout regularisation -> 5-way softmax head.
model = tf.keras.Sequential()
model.add(f_e_layer)
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(5, activation='softmax'))

model.summary()

In [None]:
# The model's final Dense layer already applies softmax, so the loss receives
# probabilities and must use from_logits=False. With from_logits=True the
# loss would apply a second softmax, which flattens the outputs and distorts
# both the reported loss and the gradients.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=['acc']
)

In [None]:
# Train the classification head for 30 epochs, evaluating on the validation
# split at the end of every epoch.
history = model.fit(
    x=train_dataset,
    epochs=30,
    validation_data=val_dataset,
)

In [None]:
# Held-out evaluation on the test split. Because compile() was given
# metrics=['acc'], this is a [loss, accuracy] list.
res = model.evaluate(x=test_dataset)

In [None]:
# Visual spot-check: show the first 10 test images with their true and
# predicted class names.
class_names = info.features["label"].names
for image, label in dataset[1].take(10):
  img_sc, _ = sc(image, label)
  # Add a batch axis; the model expects (batch, 224, 224, 3).
  img_batch = tf.expand_dims(img_sc, axis=0)
  # For a single small batch, calling the model directly avoids predict()'s
  # per-call overhead. training=False keeps Dropout in inference mode.
  pred = model(img_batch, training=False)
  plt.figure()
  plt.imshow(image)
  plt.show()
  print("Given: %s" % class_names[label.numpy()])
  print("Predicted: %s" % class_names[np.argmax(pred)])

In [None]:
# Confusion matrix over the WHOLE test split (rows = true class,
# columns = predicted class). Per-batch matrices are summed. Printing inside
# the loop would only show a partial matrix for each 200-image batch.
num_classes = 5
confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
for images, one_hot_labels in dataset[1].map(sc).batch(200):
  y_true = np.argmax(one_hot_labels, axis=1)
  y_pred = np.argmax(model.predict(images), axis=1)
  confusion += tf.math.confusion_matrix(
      labels=y_true, predictions=y_pred, num_classes=num_classes).numpy()
print(confusion)