# 1. Initial Setup and Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#import numpy for number array handling and represent rgb image pixel values
import numpy as np
from PIL import Image

#Import and initialize WandB
# import wandb

#import tensorflow to use any tools needed for deep learning
import tensorflow as tf

# import the Keras API needed to implement deep learning techniques
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, BatchNormalization, Conv2D, MaxPool2D, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# from focal_loss import SparseCategoricalFocalLoss

#import libraries for visualization of data
import matplotlib.pyplot as plt

# render charts and figures inline, directly below the cell that produces them
%matplotlib inline

In [None]:
from conversion import ModelConverter

from model import MyModel
# from examples.wandb_tracker import WandBTracker, TrainTrackingCallback
from examples.mlflow_tracker import MLFlowTracker, MLFlowTrainTrackingCallback
from metrics import plot_loss, plot_accuracy, print_confusion_matrix
from utils import show_worst_preds, crop_resize_image

# 2. Load and Split Images with Data Preprocessing and Data Augmentation

In [None]:
# Dataset locations and run configuration.
# NOTE(review): train_path and valid_path point at the same directory. That is only safe if
# MyModel.load_dataset carves a validation subset out of it itself (e.g. a validation split);
# otherwise validation images overlap the training images — verify in model.py.
# (There is no separate test dataset path in this notebook.)
train_path = '../datasets/kaggle_dataset/images/'
valid_path = '../datasets/kaggle_dataset/images/'

BATCH_SIZE = 16
# Class names, in the order used for predictions/plots below; presumably they must match the
# class sub-folder names under the image directory — TODO confirm against load_dataset.
CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Experiment tracker for this run, passed into the classifier.
tracker = MLFlowTracker("trash-classification")

classifier = MyModel(CLASSES, BATCH_SIZE, tracker)
classifier.load_dataset(train_path, valid_path)

# Batch iterators owned by the classifier; reused by the visualization, evaluation and
# quantization cells further down.
train_batches = classifier.train_batches
valid_batches = classifier.valid_batches

# 3. Visualization of the images after Preprocessing

In [None]:
# Helper to eyeball a few (preprocessed) images side by side.
def plotImages(images, n_cols=6):
    """Show up to ``n_cols`` images in a single row.

    Args:
        images: iterable of image arrays that ``imshow`` accepts (e.g. the image part of a
            batch from ``train_batches``). Only the first ``n_cols`` are drawn.
        n_cols: number of subplot slots in the row (default 6, the previous fixed value).
    """
    # squeeze=False keeps `axes` a 2-D array even when n_cols == 1, so .ravel() always works.
    fig, axes = plt.subplots(1, n_cols, figsize=(20, 20), squeeze=False)
    axes = axes.ravel()
    for ax in axes:
        # Hide ticks/frames everywhere, including slots left empty when fewer images are given.
        ax.axis('off')
    for img, ax in zip(images, axes):
        # Clip before the uint8 cast so out-of-range values (e.g. negatives produced by
        # mean-subtraction preprocessing) saturate instead of wrapping around.
        ax.imshow(np.clip(img, 0, 255).astype(np.uint8))
    plt.tight_layout()
    plt.show()

In [None]:
# Pull one training batch (with whatever preprocessing load_dataset configured) and plot it.
# Note: next() advances the shared train_batches iterator.
imgs, labels = next(train_batches)
plotImages(imgs)

# 4. Building CNN Architecture

In [None]:
# Build the network; the architecture (and input size) is defined in MyModel.build_model.
classifier.build_model()

# 5. Compile the Built CNN Model

In [None]:
# Compile the built model; the optimizer and loss are presumably chosen inside
# MyModel.compile (no arguments are passed here) — see model.py.
# Disabled experiment: focal-loss alternatives (would need the focal_loss package,
# whose import is also commented out in the imports cell).
# myFocalLoss = SparseCategoricalFocalLoss(gamma=2)
# myFocalLoss = focal_loss(alpha=0.25)

classifier.compile()

# 6. Train the CNN model

In [None]:
# Initial training run (the base model is presumably frozen here; it is unfrozen
# for fine-tuning further down). Returns a Keras-style History object.
model_details = classifier.fit(epochs=30)

# Reference result from an earlier VGG16-based run: Epoch 18/18
# 143/143 - 35s - loss: 0.3366 - accuracy: 0.8814 - val_loss: 0.3784 - val_accuracy: 0.8645

In [None]:
# Keep the initial run's per-epoch losses. These names alias the lists inside
# model_details.history, and a later cell extends them in place.
loss = model_details.history['loss']
validation_loss = model_details.history['val_loss']

In [None]:
# Keep the initial run's per-epoch accuracies. These names alias the lists inside
# model_details.history, and a later cell extends them in place.
accuracy = model_details.history['accuracy']
validation_accuracy = model_details.history['val_accuracy']

# 7. Fine Tune the CNN model

In [None]:
# Unfreeze the convolutional base of the pre-trained model so fine-tuning can adapt
# its weights to this dataset.
classifier.base_model.trainable = True
# Keras only takes a change to `trainable` into account when the model is (re)compiled;
# without this, the following fine-tuning fit() would still train with the base frozen
# (unless MyModel.fit recompiles internally — not visible here).
# NOTE(review): fine-tuning normally uses a much lower learning rate than the initial
# run — check what MyModel.compile() configures.
classifier.compile()

In [None]:
# Fine-tuning run: train again, now with the base model unfrozen, for 10 epochs.
# model_details is rebound to this run's History; the initial run's curves were kept above.
model_details = classifier.fit(epochs=10)

# 8. Visualization of Accuracy and Loss on the Training and Validation Sets

In [None]:
# Append the fine-tuning losses to the initial-run lists so the plot shows one continuous curve.
# Not idempotent: re-running this cell appends the same epochs a second time.
loss.extend(model_details.history['loss'])
validation_loss.extend(model_details.history['val_loss'])

In [None]:
# Append the fine-tuning accuracies to the initial-run lists so the plot shows one continuous curve.
# Not idempotent: re-running this cell appends the same epochs a second time.
accuracy.extend(model_details.history['accuracy'])
validation_accuracy.extend(model_details.history['val_accuracy'])

In [None]:
# Plot training vs. validation loss over the initial and fine-tuning epochs combined.
plot_loss(loss, validation_loss)

In [None]:
# Plot training vs. validation accuracy over the initial and fine-tuning epochs combined.
plot_accuracy(accuracy, validation_accuracy)

# Finish tracker run

In [None]:
# Close the experiment-tracker run; training and fine-tuning are finished at this point.
classifier.tracker.finish_run()

# Confusion matrix

In [None]:
# Predict on the whole validation set.
# NOTE(review): the evaluation helpers below presumably pair these predictions with
# valid_batches' labels by position, which only holds if valid_batches is not shuffled —
# confirm in MyModel.load_dataset.
Y_pred = classifier.predict(valid_batches)

In [None]:
# Confusion matrix of the validation predictions, labelled with CLASSES (helper from metrics.py).
print_confusion_matrix(classifier.model, valid_batches, Y_pred, CLASSES)

## Print problematic cases

In [None]:
# Show the validation cases the model gets most wrong (helper from utils.py).
show_worst_preds(valid_batches, Y_pred, CLASSES)

## Convert Model to TFLite

In [None]:
# Wrap the trained Keras model in the project's converter; the export cells below
# each write one format from it.
converter = ModelConverter(classifier.model)


In [None]:
# Baseline TFLite export; see ModelConverter.to_tflite for the exact converter settings.
converter.to_tflite('../models/model.tflite')

In [None]:
# TFLite export presumably using float16 weight quantization (see ModelConverter.to_tflite_fp16).
converter.to_tflite_fp16('../models/model_fp16.tflite')

In [None]:
def representative_dataset(batches=None, num_samples=100):
    """Yield calibration samples for full-integer TFLite quantization.

    The TFLite converter calls this with no arguments and expects a generator whose items
    are lists holding one input array each. Samples are yielded one image at a time with a
    leading batch dimension of 1, matching the documented representative-dataset usage.

    Args:
        batches: iterable of ``(images, labels)`` batches. Defaults to the notebook's
            ``train_batches`` iterator, looked up when the generator starts.
        num_samples: maximum number of single images to yield (<= 0 yields nothing).

    Yields:
        ``[array]`` where ``array`` is float32 with shape ``(1, *image.shape)``.
    """
    if batches is None:
        batches = train_batches
    remaining = num_samples
    if remaining <= 0:
        return
    # Keras directory iterators loop forever, so stop on the sample count rather than
    # on exhaustion; the check right after each yield avoids pulling an extra batch.
    for images, _ in batches:
        for image in images:
            yield [np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)]
            remaining -= 1
            if remaining <= 0:
                return

# Quantized TFLite export, calibrated with the samples produced by representative_dataset.
converter.to_tflite_quantized('../models/model_int8.tflite', representative_dataset)

In [None]:
# Export a TensorFlow.js version of the model into ../models/js/.
converter.to_tfjs('../models/js/')

# Save and Test the Model

In [None]:
# Save the trained Keras model. `model` itself is never bound in this notebook
# (`from model import MyModel` binds only the class), so the model must be taken
# from the classifier that owns it.
classifier.model.save('../models/saved_model')

In [None]:
# Show the first image of the first validation batch
# (valid_batches[0] is presumably an (images, labels) pair, like the training batches).
img = valid_batches[0][0][0]
plotImages([img])

In [None]:
# Load one known image from disk and shape it as a single-image model input.
# (PIL's Image is already imported in the imports cell; no re-import needed.)
with Image.open('../datasets/kaggle_dataset/images/cardboard/cardboard10.jpg') as raw_img:
    # Crop/resize and convert to an array while the file is still open, so the handle is
    # closed deterministically when the block exits.
    img = np.asarray(crop_resize_image(raw_img), dtype=np.float32)

# No preprocess_input is applied: raw float32 pixel values are fed to the model
# (presumably the exported model performs its own input scaling — TODO confirm).
img = np.expand_dims(img, axis=0)
print(img.shape, img.dtype)

In [None]:
# Run the fp16 TFLite model on the single prepared image `img` and show its scores.
interpreter = tf.lite.Interpreter(model_path='../models/model_fp16.tflite')
interpreter.allocate_tensors()

# Get input and output tensor metadata (e.g. 'index', 'shape').
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed the image prepared in the previous cell (a real sample, not random data).
input_data = img
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# get_tensor() returns a copy of the tensor data; use tensor() for a view instead.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print(CLASSES)
# Name the top-scoring class; assumes output row 0 holds one score per entry of CLASSES.
print('Predicted class:', CLASSES[int(np.argmax(output_data[0]))])
