Part 1: # Model training (Lightweight classifier)

In [3]:
# Setup

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Paths
data_dir = '/content/Image_Dataset/TrashType_Image_Dataset/'

# Image preprocessing
img_size = (128, 128)
batch_size = 32

datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2
)

train_data = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    subset='training',
    class_mode='categorical'
)

val_data = datagen.flow_from_directory(
    data_dir,
    target_size=img_size,
    batch_size=batch_size,
    subset='validation',
    class_mode='categorical'
)

num_classes = train_data.num_classes

Found 2024 images belonging to 6 classes.
Found 503 images belonging to 6 classes.


In [4]:
# Model architecture (MobileNetV2)

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(128, 128, 3),
    include_top=False,
    weights='imagenet'
)

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_data, validation_data=val_data, epochs=10)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


  self._warn_if_super_not_called()


Epoch 1/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m157s[0m 2s/step - accuracy: 0.5783 - loss: 1.2326 - val_accuracy: 0.3459 - val_loss: 3.3856
Epoch 2/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 2s/step - accuracy: 0.8393 - loss: 0.4965 - val_accuracy: 0.3161 - val_loss: 6.3836
Epoch 3/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 2s/step - accuracy: 0.8822 - loss: 0.4034 - val_accuracy: 0.1988 - val_loss: 12.0115
Epoch 4/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 2s/step - accuracy: 0.9035 - loss: 0.2901 - val_accuracy: 0.3241 - val_loss: 5.2218
Epoch 5/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m122s[0m 2s/step - accuracy: 0.9078 - loss: 0.2674 - val_accuracy: 0.2008 - val_loss: 6.2778
Epoch 6/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 2s/step - accuracy: 0.9302 - loss: 0.2357 - val_accuracy: 0.2107 - val_loss: 9.0078
Epoch 7/10
[1m64/64[0m [32m━━━

<keras.src.callbacks.history.History at 0x7c3815a833e0>

Part 2: Convert to TensorFlow Lite


In [5]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('/content/model.tflite', 'wb') as f:
    f.write(tflite_model)

Saved artifact at '/tmp/tmp8k953avs'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name='keras_tensor_154')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  136580326893008: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326893584: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326896272: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326895888: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326894736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326896464: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326894928: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326897040: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326896656: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136580326894544: TensorSpec(shape=(), dtype=tf.resource, name=None)
  1365803268

Part 3: Test TFLite model


In [6]:
import os
from PIL import Image
import numpy as np

interpreter = tf.lite.Interpreter(model_path='/content/model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

test_folder = '/content/Image_Dataset/TrashType_Image_Dataset/'
class_names = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

results = []
for class_dir in os.listdir(test_folder):
    class_path = os.path.join(test_folder, class_dir)
    if not os.path.isdir(class_path):
        continue

    for filename in os.listdir(class_path):
        if filename.lower().endswith(('.jpg', '.png')):
            img_path = os.path.join(class_path, filename)
            try:
                image = Image.open(img_path).convert('RGB').resize((128, 128))
                image = np.array(image) / 255.0
                image = np.expand_dims(image, axis=0).astype(np.float32)

                interpreter.set_tensor(input_details[0]['index'], image)
                interpreter.invoke()
                output = interpreter.get_tensor(output_details[0]['index'])

                predicted_class = np.argmax(output)
                confidence = np.max(output)
                predicted_label = class_names[predicted_class]

                results.append({
                    'filename': filename,
                    'true_label': class_dir,
                    'predicted_label': predicted_label,
                    'confidence': round(confidence, 2)
                })

            except Exception as e:
                print(f"Error processing {img_path}: {e}")

# Display results
for r in results:
    print(f"{r['filename']} | True: {r['true_label']} → Predicted: {r['predicted_label']} ({r['confidence']})")

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


cardboard_148.jpg | True: cardboard → Predicted: metal (0.9900000095367432)
cardboard_061.jpg | True: cardboard → Predicted: metal (1.0)
cardboard_130.jpg | True: cardboard → Predicted: metal (1.0)
cardboard_386.jpg | True: cardboard → Predicted: metal (0.9900000095367432)
cardboard_323.jpg | True: cardboard → Predicted: metal (0.9900000095367432)
cardboard_204.jpg | True: cardboard → Predicted: metal (1.0)
cardboard_036.jpg | True: cardboard → Predicted: metal (1.0)
cardboard_154.jpg | True: cardboard → Predicted: metal (0.9700000286102295)
cardboard_022.jpg | True: cardboard → Predicted: metal (0.9800000190734863)
cardboard_017.jpg | True: cardboard → Predicted: cardboard (0.5400000214576721)
cardboard_383.jpg | True: cardboard → Predicted: trash (0.9599999785423279)
cardboard_060.jpg | True: cardboard → Predicted: metal (1.0)
cardboard_284.jpg | True: cardboard → Predicted: metal (0.9900000095367432)
cardboard_196.jpg | True: cardboard → Predicted: metal (1.0)
cardboard_322.jpg | Tr

Part 4: Accuracy metrics

In [7]:
from sklearn.metrics import classification_report, confusion_matrix

y_pred = model.predict(val_data)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = val_data.classes

print("Classification Report:")
print(classification_report(y_true, y_pred_classes, target_names=class_names))

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred_classes))

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 443ms/step
Classification Report:
              precision    recall  f1-score   support

   cardboard       0.39      0.09      0.14        80
       glass       0.00      0.00      0.00       100
       metal       0.16      0.71      0.26        82
       paper       0.00      0.00      0.00       118
     plastic       0.16      0.03      0.05        96
       trash       0.04      0.15      0.07        27

    accuracy                           0.14       503
   macro avg       0.12      0.16      0.09       503
weighted avg       0.12      0.14      0.08       503

Confusion Matrix:
[[ 7  0 54  1  3 15]
 [ 2  0 73  0  6 19]
 [ 2  0 58  1  2 19]
 [ 2  0 86  0  5 25]
 [ 3  0 80  0  3 10]
 [ 2  0 21  0  0  4]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
