# EuroSAT Image Classification — Full Notebook

This notebook demonstrates loading the EuroSAT RGB dataset (via TensorFlow Datasets), preprocessing, transfer-learning training (MobileNetV2), evaluation (accuracy, precision, recall, F1, confusion matrix, ROC AUC, top-3 accuracy), saving the model, prediction on a single image, and retraining with local `data/` folders.

**Notes:**
- This notebook assumes you have `tensorflow`, `tensorflow-datasets`, `scikit-learn`, and the `src/` package available (see `requirements.txt`).
- You can also replace the TFDS loading with local `data/train` and `data/test` folders if you prefer.

In [None]:
# Imports and environment checks
import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
    roc_auc_score,
    top_k_accuracy_score,
)

# Seed Python, NumPy and TensorFlow RNGs in one call (TF >= 2.7) so dataset
# shuffling, dropout masks and classifier-head initialisation are reproducible
# across Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

print("TensorFlow version:", tf.__version__)


## Load EuroSAT from TFDS (RGB)
We split the dataset's single `train` split into train/validation/test subsets (70/20/10). For quicker runs, reduce the training fraction.

In [None]:
# Load EuroSAT RGB through TFDS. Asking for the single 'train' split as a plain
# string returns the dataset itself rather than a one-element list to unpack.
ds_all, ds_info = tfds.load(
    'eurosat/rgb',
    split='train',
    as_supervised=True,
    with_info=True,
)
print(ds_info)


In [None]:
# Create train/val/test splits from the single 'train' split
total = ds_info.splits['train'].num_examples
print("Total examples:", total)

# Split percentages of the TFDS 'train' split (70/20/10 by default).
# Validation and test are always carved from the END of the split, so lowering
# TRAIN_PCT for quicker runs shrinks only the training data and keeps the
# evaluation data unchanged.
VAL_PCT = 20
TEST_PCT = 10
TRAIN_PCT = 100 - VAL_PCT - TEST_PCT  # lower this (e.g. 20) for quicker runs

val_start = 100 - VAL_PCT - TEST_PCT
test_start = 100 - TEST_PCT
assert 0 < TRAIN_PCT <= val_start, "TRAIN_PCT must not overlap the validation/test slices"

train_split = f'train[:{TRAIN_PCT}%]'
val_split = f'train[{val_start}%:{test_start}%]'
test_split = f'train[{test_start}%:]'

(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'eurosat/rgb',
    split=[train_split, val_split, test_split],
    with_info=True,
    as_supervised=True,
)
print(ds_info.features)


## Preprocessing: resizing and batching
We resize images to 224x224 and scale them with MobileNetV2's `preprocess_input` function.

In [None]:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# Input-pipeline configuration shared by every split.
IMG_SIZE = (224, 224)        # target (height, width) passed to tf.image.resize
BATCH_SIZE = 32
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data pick parallelism / prefetch depth

def preprocess(image, label):
    """Resize one example and scale it for MobileNetV2.

    Args:
        image: image tensor from the TFDS pipeline.
        label: class-index tensor.

    Returns:
        ``(image, label)``. The image is resized to ``IMG_SIZE`` and passed
        through MobileNetV2's ``preprocess_input``. The label is cast to int32.
    """
    resized = tf.image.resize(image, IMG_SIZE)
    return preprocess_input(resized), tf.cast(label, tf.int32)

def prepare(ds, shuffle=False, shuffle_buffer=1024, seed=None):
    """Turn a raw ``(image, label)`` dataset into a batched, prefetched pipeline.

    When requested, shuffling happens BEFORE ``preprocess``. The shuffle buffer
    then holds the raw examples instead of the larger float32 tensors
    ``preprocess`` produces after resizing to ``IMG_SIZE``. This cuts the
    buffer's memory use without changing how examples are mixed.

    Args:
        ds: unbatched ``tf.data.Dataset`` of ``(image, label)`` pairs.
        shuffle: whether to shuffle examples (reshuffled on every epoch).
        shuffle_buffer: number of examples held in the shuffle buffer.
        seed: optional shuffle seed for a reproducible order.

    Returns:
        A dataset of preprocessed batches of ``BATCH_SIZE`` examples.
    """
    if shuffle:
        ds = ds.shuffle(shuffle_buffer, seed=seed)
    ds = ds.map(preprocess, num_parallel_calls=AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Only the training split is shuffled; validation and test keep a fixed order.
train_ds = prepare(ds_train, shuffle=True)
val_ds, test_ds = prepare(ds_val), prepare(ds_test)


## Inspect classes and a few sample images

In [None]:
class_names = ds_info.features['label'].names
print("Classes:", class_names)

# Show six raw training examples with their class names. The image is drawn at
# its native resolution; imshow handles scaling it up for display.
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for ax, (img, lbl) in zip(axes.flat, ds_train.take(6)):
    ax.imshow(tf.cast(img, tf.uint8).numpy())
    ax.set_title(class_names[int(lbl.numpy())])
    ax.axis('off')
fig.suptitle('Sample EuroSAT training images')
fig.tight_layout()
plt.show()


## Build transfer-learning model (MobileNetV2)

In [None]:
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2

def build_model(num_classes):
    """Build and compile a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained backbone is frozen and used as ONE nested layer.
    It can therefore be retrieved from ``model.layers`` for fine-tuning.
    Building from ``base.input``/``base.output`` would instead inline every
    backbone layer into the classifier. The backbone is called with
    ``training=False``, so its BatchNormalization layers stay in inference
    mode even after it is unfrozen.

    Args:
        num_classes: number of output classes for the softmax head.

    Returns:
        A compiled ``tf.keras.Model``. It takes images of shape
        ``(IMG_SIZE[0], IMG_SIZE[1], 3)`` scaled with ``preprocess_input`` and
        returns per-class probabilities.
    """
    input_shape = (IMG_SIZE[0], IMG_SIZE[1], 3)
    base = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    base.trainable = False

    inputs = layers.Input(shape=input_shape)
    x = base(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

num_classes = ds_info.features['label'].num_classes
model = build_model(num_classes)
model.summary()


## Train (first stage: frozen base)

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# The best model so far (by validation loss) is written here; create the folder first.
os.makedirs("models", exist_ok=True)
checkpoint_path = "models/model_latest.h5"

# This list is reused by the fine-tuning stage.
# - EarlyStopping: stop after 4 epochs without val_loss improvement, restoring the best weights.
# - ReduceLROnPlateau: halve the LR after 2 stagnant epochs.
# - ModelCheckpoint: save whenever val_loss improves.
early_stop = EarlyStopping(monitor='val_loss', patience=4, restore_best_weights=True)
lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
best_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True)
callbacks = [early_stop, lr_on_plateau, best_checkpoint]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=8,
    callbacks=callbacks,
)


## Fine-tune (unfreeze some of the base model)

In [None]:
# Unfreeze the top of the MobileNetV2 backbone for a second, low-learning-rate stage.
FINE_TUNE_LAST_N = 30  # number of backbone layers (counted from the top) to unfreeze

# Locate the backbone's layers without relying on a fixed index.
# - Nested backbone: it appears as a single tf.keras.Model inside `model.layers`.
# - Classifier built from `base.input`/`base.output`: the backbone's layers are
#   inlined into `model.layers`, and `model.layers[1]` is just the first
#   convolution, which has no `.layers` attribute.
nested_models = [layer for layer in model.layers if isinstance(layer, tf.keras.Model)]
if nested_models:
    base = nested_models[0]
    base.trainable = True
    base_layers = base.layers
else:
    # Inlined backbone: every layer before the GlobalAveragePooling2D head.
    head_start = next(
        i for i, layer in enumerate(model.layers)
        if isinstance(layer, tf.keras.layers.GlobalAveragePooling2D)
    )
    base_layers = model.layers[:head_start]

first_unfrozen = max(len(base_layers) - FINE_TUNE_LAST_N, 0)
for i, layer in enumerate(base_layers):
    # Keep BatchNormalization frozen. With trainable=False it also runs in
    # inference mode, so small fine-tuning batches do not overwrite the
    # pretrained moving statistics.
    layer.trainable = (
        i >= first_unfrozen
        and not isinstance(layer, tf.keras.layers.BatchNormalization)
    )

print("Trainable weight tensors after unfreezing:", len(model.trainable_weights))

# Recompile so the new trainable flags take effect, using a much smaller learning rate.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history_ft = model.fit(train_ds, validation_data=val_ds, epochs=6, callbacks=callbacks)


## Evaluation on test set: Accuracy, Precision, Recall, F1, Confusion Matrix, ROC AUC, Top-3 accuracy

In [None]:
# Collect true labels and predicted probabilities for the whole test set.
# `predict_on_batch` avoids the per-call setup overhead (and progress-bar output)
# of calling `model.predict` inside a Python loop. It also keeps every batch's
# labels paired with its predictions.
label_batches = []
proba_batches = []
for images, labels in test_ds:
    label_batches.append(labels.numpy())
    proba_batches.append(np.asarray(model.predict_on_batch(images)))

y_true = np.concatenate(label_batches)
y_proba = np.concatenate(proba_batches)
y_pred = y_proba.argmax(axis=1)

# Pass all class indices explicitly so the reports still line up with
# `class_names` even if a class is absent from the test slice.
all_labels = list(range(num_classes))

print("Accuracy:", accuracy_score(y_true, y_pred))
print("\nClassification report:\n",
      classification_report(y_true, y_pred, labels=all_labels, target_names=class_names))

cm = confusion_matrix(y_true, y_pred, labels=all_labels)
fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(cm, cmap='Blues')
ax.set(
    xticks=np.arange(num_classes),
    yticks=np.arange(num_classes),
    xlabel='Predicted label',
    ylabel='True label',
    title='Test-set confusion matrix',
)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_yticklabels(class_names)
for (row, col), count in np.ndenumerate(cm):
    ax.text(col, row, count, ha='center', va='center', fontsize=8)
fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()


In [None]:
# Ranking metrics computed from the full probability matrix.
top3 = top_k_accuracy_score(y_true, y_proba, k=3, labels=range(num_classes))
print("Top-3 accuracy:", top3)

# One-vs-rest ROC AUC, macro-averaged over classes. sklearn requires multiclass
# scores whose rows sum to 1. Renormalizing guards against float rounding in the
# softmax outputs tripping that check.
proba_normed = y_proba / y_proba.sum(axis=1, keepdims=True)
roc_auc = roc_auc_score(
    y_true,
    proba_normed,
    multi_class='ovr',
    average='macro',
    labels=list(range(num_classes)),
)
print("ROC AUC (macro, one-vs-rest):", roc_auc)


## Saved model (written by the `ModelCheckpoint` callback during training)

In [None]:
# ModelCheckpoint(save_best_only=True) writes only when val_loss improves. The file
# therefore holds the lowest-val_loss model seen during training, which may not be
# identical to the in-memory `model` evaluated above. Confirm it was actually written.
assert os.path.exists(checkpoint_path), f"No checkpoint found at {checkpoint_path}"
print('Model saved at', checkpoint_path)

## Prediction demo: single image from TFDS or local file

In [None]:
# Pick one test example from TFDS and run it through the SAME `preprocess`
# function used by the training/evaluation pipelines, so resizing and scaling
# cannot drift from what the model was trained on.
for img, lbl in ds_test.take(1):
    test_img, true_lbl = preprocess(img, lbl)
    inp = tf.expand_dims(test_img, axis=0)  # add the batch dimension
    preds = model.predict(inp, verbose=0)
    top_idx = preds[0].argsort()[-3:][::-1]
    print('True label:', class_names[int(true_lbl)])
    print('Top-3 predictions:')
    for idx in top_idx:
        print(class_names[idx], preds[0][idx])


## Retraining with local `data/train` and `data/test` folders
If you have placed newly uploaded images into `data/train/<class>` and `data/test/<class>`, you can retrain using the `src.model.train_model` function.

In [None]:
# Example of triggering a retrain through the project's `src` package.
# NOTE(review): per the original note, `create_generators` is ImageDataGenerator-based;
# confirm against src/preprocessing.py.
from src.preprocessing import create_generators
from src.model import train_model

local_dirs = ('data/train', 'data/test')
if not all(os.path.exists(folder) for folder in local_dirs):
    print('No local data folders present at data/train and data/test. Place images under these folders to retrain locally.')
else:
    print('Local train/test folders found — running retrain for 3 epochs (demo)...')
    train_gen, val_gen = create_generators(*local_dirs, batch_size=16)
    new_model, new_history = train_model(
        train_gen,
        val_gen,
        out_path='models/model_latest_retrained.h5',
        epochs=3,
    )
    print('Retraining done, model saved to models/model_latest_retrained.h5')