<a href="https://colab.research.google.com/github/chel310/Trash_classifier/blob/main/trash_classifier_MobileNetV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Let's start by cloning our repository.

We will use TensorFlow as our library for transfer learning. The code is based on the [TensorFlow tutorial](https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub) on the same topic.

In [None]:


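import time

import numpy as np
import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub
import tf_keras  # Keras 2 package used below for the model, optimizer, and callback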

: 

Let's load our training images for processing.

In [None]:
batch_size = 32
img_height = 224
img_width = 224
data_root = 'dataset'
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  str(data_root),
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

: 
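The cell above loads only the `training` subset. As a rough illustration of what `validation_split=0.2` with a fixed `seed` means, here is a minimal pure-Python sketch: shuffle the file list deterministically, then hold out the last 20%. The helper `split_files` and the file names are hypothetical, and this is a conceptual model rather than Keras's actual implementation.

```python
import random

def split_files(files, validation_split=0.2, seed=123):
    """Hypothetical helper: deterministic train/validation split of a file list."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)   # same seed -> same order on every run
    num_val = int(validation_split * len(shuffled))
    train = shuffled[:len(shuffled) - num_val]
    val = shuffled[len(shuffled) - num_val:]
    return train, val

# Made-up file names standing in for the images under `dataset/`.
files = [f"img_{i:03d}.jpg" for i in range(10)]
train_files, val_files = split_files(files)
print(len(train_files), len(val_files))
```

Reusing the same `seed` and `validation_split` in a second `image_dataset_from_directory` call with `subset="validation"` is what keeps the two subsets from overlapping.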

In [None]:
class_names = np.array(train_ds.class_names)
print('Class names for predictions:', class_names)

: 

TensorFlow Hub's convention for image models is to expect **float** inputs in the **[0, 1]** range. Use the **Rescaling layer** to achieve this.

In [None]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))


: 
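To make the effect of `Rescaling(1./255)` concrete, this small pure-Python sketch applies the same multiplication to a few 8-bit pixel values. The function `rescale` is a hypothetical stand-in for the layer, not part of TensorFlow.

```python
def rescale(pixels, scale=1.0 / 255):
    """Hypothetical stand-in for Rescaling(scale): multiply each value by `scale`."""
    return [p * scale for p in pixels]

raw = [0, 128, 255]        # 8-bit pixel intensities
scaled = rescale(raw)
print(scaled)              # 0 maps to 0.0 and 255 maps to 1.0
```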

Use buffered prefetching so that data can be read from disk without I/O becoming a bottleneck. `Dataset.cache` and `Dataset.prefetch` are two important methods to use when loading data.


In [None]:
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)

: 
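The idea behind prefetching can be sketched in plain Python: a background thread keeps a small buffer filled while the consumer works on the current item. The generator `prefetch` below is a hypothetical illustration of the concept, not how `tf.data` is implemented.

```python
import queue
import threading

def prefetch(iterable, buffer_size=2):
    """Hypothetical sketch: yield items while a background thread loads the next ones."""
    buf = queue.Queue(maxsize=buffer_size)
    done = object()                    # sentinel marking the end of the data

    def producer():
        for item in iterable:          # stands in for slow reads from disk
            buf.put(item)              # blocks while the buffer is full
        buf.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()               # blocks only if the buffer is empty
        if item is done:
            return
        yield item

batches = list(prefetch(range(5)))
print(batches)
```

The order of the items is preserved; only the waiting time overlaps with the consumer's work.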

In [None]:
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

: 

TensorFlow Hub also distributes models without the top classification layer. These can be used to easily do transfer learning.

In [None]:
feature_extractor_model = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model, input_shape=(224, 224, 3), trainable=False)

: 

In [None]:
num_classes = len(class_names)

model = tf_keras.Sequential([
  feature_extractor_layer,
  tf_keras.layers.Dense(num_classes, name = 'output')
])

model.summary()

: 
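As a quick sanity check on `model.summary()`, the only trainable part of this model is the `Dense` head, because the feature extractor is frozen. The sketch below assumes the MobileNetV2 feature vector has 1280 dimensions and, purely for illustration, six classes; `dense_params` is a hypothetical helper.

```python
def dense_params(input_dim, num_units):
    """Weights plus biases of a fully connected layer."""
    return input_dim * num_units + num_units

feature_dim = 1280           # assumed size of the MobileNetV2 feature vector
example_num_classes = 6      # illustrative only; the notebook uses len(class_names)
head_params = dense_params(feature_dim, example_num_classes)
print(head_params)           # 1280 * 6 + 6 = 7686
```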

In [None]:
image_batch.shape

: 

# Train the model

In [None]:
model.compile(
  optimizer=tf_keras.optimizers.Adam(),
  # Use the loss from tf_keras so it matches the package the model was built with.
  loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

: 
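Because the `output` layer has no activation, the model emits raw logits, which is why the loss is built with `from_logits=True`. The following pure-Python sketch shows the computation this implies for a single example: apply softmax to the logits, then take the negative log-probability of the true class. The function names and values are hypothetical, and the code is illustrative rather than the Keras implementation.

```python
import math

def softmax(logits):
    """Convert raw logits to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_categorical_crossentropy(label, logits):
    """Negative log-probability of the integer class `label`, starting from logits."""
    return -math.log(softmax(logits)[label])

logits = [2.0, 0.5, -1.0]    # made-up scores for three classes
probs = softmax(logits)
loss = sparse_categorical_crossentropy(0, logits)
print(probs, loss)
```

The loss is smallest when the true class has the largest logit, and it grows as the model assigns that class less probability.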

To visualize the training progress, use a custom callback to log the loss and accuracy of each batch individually, instead of the epoch average.

In [None]:
class CollectBatchStats(tf_keras.callbacks.Callback):
  """Records the loss and accuracy of every training batch."""

  def __init__(self):
    super().__init__()
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    # Reset metrics so the next batch is not averaged with earlier ones.
    self.model.reset_metrics()

batch_stats_callback = CollectBatchStats()

history = model.fit(train_ds, epochs=5,
                    callbacks=[batch_stats_callback])


: 

After even just a few training iterations, we can already see that the model is making progress on the task.

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)

: 

In [None]:
plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)

: 

To inspect the predictions, run the model on the image batch and map each predicted class ID to its class name:



In [None]:
predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]


: 
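The mapping from model outputs to labels can be illustrated without TensorFlow. In this minimal sketch the class names and logits are made up, and `argmax` is a hypothetical helper that mirrors what `np.argmax(..., axis=-1)` does for each row.

```python
def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)

example_class_names = ["cardboard", "glass", "metal"]   # hypothetical labels
example_logits = [                                      # made-up logits, one row per image
    [0.1, 2.3, -0.5],
    [1.7, 0.2, 0.9],
]
example_ids = [argmax(row) for row in example_logits]
example_labels = [example_class_names[i] for i in example_ids]
print(example_labels)   # ['glass', 'cardboard']
```

No softmax is needed for this step, since softmax preserves the ordering of the logits and therefore the position of the maximum.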

Plot the first 30 images of the batch with their predicted labels:

In [None]:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")


: 

## Save the Model

In [None]:
# Save the trained model to disk under a fixed path.
export_path = "saved_models/MobileNetV2"
model.save(export_path)

export_path

: 
