# TensorFlow for Image Classification

This notebook provides a complete workflow for **image classification** using **TensorFlow** and **Keras**. It uses the **Fashion-MNIST** dataset of 28×28 grayscale clothing images in 10 classes and covers data preparation with `tf.data`, building a small convolutional neural network (CNN), training, and **model checkpointing**.

## Setup

Import the required libraries and confirm the TensorFlow version.

In [None]:
#!pip install tensorflow
#!pip install tensorflow_datasets
#!pip install tensorflow_data_validation
#!pip install matplotlib

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

print(f"TensorFlow Version: {tf.__version__}")

## Exploring the Dataset

We load the **Fashion-MNIST** dataset using `tfds.load` and display a grid of sample images with `tfds.show_examples`.

In [None]:
ds, ds_info = tfds.load('fashion_mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, ds_info)


## Data Loading and Preprocessing with `tf.data`

We reload the data, splitting it into **training** and **testing** sets. The `as_supervised=True` argument makes the dataset yield `(image, label)` tuples, which is the standard format for supervised learning.

The core of efficient data handling is the `tf.data.Dataset` API. We'll define a function to **normalize** the image pixel values from the range $[0, 255]$ (`uint8`) to the range $[0, 1]$ (`float32`); keeping inputs on a small, consistent scale generally makes gradient-based training more stable.
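The normalization is plain element-wise arithmetic. Here is a NumPy sketch of the same computation that the `normalize_img` function below performs, applied to a few hand-picked pixel values (the values are illustrative, not taken from the dataset):

```python
import numpy as np

# A few example 8-bit pixel values (0 = black, 255 = white).
pixels = np.array([0, 51, 128, 255], dtype=np.uint8)

# Same arithmetic as normalize_img: cast to float32, then divide by 255.
normalized = pixels.astype(np.float32) / 255.0

print(normalized.dtype)  # float32
print(normalized)
```

Casting before dividing matters: it keeps the result in `float32` rather than relying on integer arithmetic.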

In [None]:
(ds_train, ds_test), ds_info = tfds.load(
    'fashion_mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

### Preprocessing Pipeline for Training Data

1.  **`map(normalize_img)`**: Applies the normalization function to every element. `num_parallel_calls=tf.data.AUTOTUNE` lets TensorFlow choose how many elements to process in parallel.
2.  **`cache()`**: Caches the normalized examples in memory during the first epoch, so later epochs skip loading and preprocessing.
3.  **`shuffle()`**: Shuffles the training examples. The buffer size equals the size of the training split, which gives a full shuffle, and because `shuffle()` comes after `cache()` the order differs in every epoch.
4.  **`batch(128)`**: Groups elements into batches of size 128.
5.  **`prefetch()`**: Prepares upcoming batches in the background while the model trains on the current one, so the GPU (or CPU) does not sit idle waiting for input.
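To see what the pipeline produces without downloading anything, the sketch below runs the same steps on a synthetic stand-in: 300 random `uint8` arrays shaped like Fashion-MNIST images (the data and the `ds_demo` name are assumptions made for illustration only).

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the real data: 300 random 28x28x1 uint8 "images".
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(300, 28, 28, 1), dtype=np.uint8)
labels = rng.integers(0, 10, size=(300,))

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255., label

ds_demo = tf.data.Dataset.from_tensor_slices((images, labels))
ds_demo = ds_demo.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_demo = ds_demo.cache()        # cache the normalized examples, before shuffling
ds_demo = ds_demo.shuffle(300)   # buffer = dataset size -> full reshuffle each epoch
ds_demo = ds_demo.batch(128)
ds_demo = ds_demo.prefetch(tf.data.AUTOTUNE)

batch_shapes = [tuple(x.shape.as_list()) for x, _ in ds_demo]
max_pixel = max(float(tf.reduce_max(x)) for x, _ in ds_demo)
print(batch_shapes)  # two full batches of 128, then the 44 leftover examples
print(max_pixel)     # at most 1.0 after normalization
```

Note that the final batch is smaller than 128; `batch()` keeps the remainder unless `drop_remainder=True` is passed.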

In [None]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)


### Preprocessing Pipeline for Test Data

The test data is similarly normalized, batched, cached, and prefetched. **Note**: We **do not** shuffle the test dataset.
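Because the test pipeline has no random step, caching after batching is safe: every pass over the data yields the same batches in the same order. A minimal check on a toy `tf.data.Dataset.range` pipeline, standing in for the real test split:

```python
import tensorflow as tf

# Ten toy "test" examples: batch/cache/prefetch with no shuffle step.
ds_eval = tf.data.Dataset.range(10).batch(4).cache().prefetch(tf.data.AUTOTUNE)

# Two full passes, like evaluating after two different epochs.
first_pass = [batch.numpy().tolist() for batch in ds_eval]
second_pass = [batch.numpy().tolist() for batch in ds_eval]
print(first_pass)                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(first_pass == second_pass)   # True
```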

In [None]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)


## Building and Training the Model

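We define a small convolutional neural network (CNN) with the Keras **Sequential** API: two convolution and max-pooling stages followed by a fully connected classifier. The final `Dense(10)` layer has no activation, so it outputs raw **logits** for the 10 classes.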
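It can help to trace the tensor shapes and parameter counts by hand. The pure-Python sketch below assumes the Keras defaults this model relies on (`'valid'` padding, stride 1 for the convolutions, pooling stride equal to the pool size); the totals should agree with what `model.summary()` reports. The helper `conv_out` is hypothetical, written for illustration.

```python
# Output size of a 'valid' (no padding) convolution or pooling window.
def conv_out(size, kernel, stride=1):
    return (size - kernel) // stride + 1

h = 28
h = conv_out(h, 3)              # Conv2D(32, 3x3)   -> 26
h = conv_out(h, 2, stride=2)    # MaxPooling2D(2x2) -> 13
h = conv_out(h, 3)              # Conv2D(64, 3x3)   -> 11
h = conv_out(h, 2, stride=2)    # MaxPooling2D(2x2) -> 5
flat = h * h * 64               # Flatten           -> 1600

# Trainable parameters: weights + biases for each layer that has them.
params = {
    "conv1": 3 * 3 * 1 * 32 + 32,
    "conv2": 3 * 3 * 32 * 64 + 64,
    "dense1": flat * 128 + 128,
    "dense2": 128 * 10 + 10,
}
total = sum(params.values())
print(h, flat, params, total)
```

Most of the parameters sit in the first `Dense` layer, which connects all 1600 flattened features to 128 units.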

### Compilation and Training

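* **Optimizer**: **Adam** with a learning rate of 0.001, an adaptive variant of gradient descent.
* **Loss**: **Sparse Categorical Crossentropy** is suitable for multi-class classification when the labels are integers (sparse) rather than one-hot vectors. We set `from_logits=True` because the final layer outputs raw logits.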
* **Metrics**: We track **Sparse Categorical Accuracy**.
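For a single example, sparse categorical cross-entropy with `from_logits=True` is the negative log of the softmax probability assigned to the true integer label; Keras then averages it over the batch by default. A NumPy sketch of that per-example value (the helper `sparse_ce_from_logits` and the logits are made up for illustration):

```python
import numpy as np

def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example with an integer label and raw logits."""
    # Numerically stable log-softmax: subtract the max logit first.
    shifted = logits - np.max(logits)
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    return -log_probs[label]

logits = np.array([2.0, 1.0, 0.1])  # raw scores for 3 classes (the model has 10)
label = 0                           # integer ("sparse") label, not one-hot

loss = sparse_ce_from_logits(logits, label)
probs = np.exp(logits) / np.exp(logits).sum()
print(loss, -np.log(probs[label]))  # the two values agree
```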

In [None]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

print("Starting Model Training...")
history = model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
print("Training Complete.")

## Model Checkpointing: Saving and Restoring Weights

Model **checkpointing** saves the model's weights (or the whole model) to disk during training. It protects progress against crashes and lets you resume training later or reload the best-performing weights.
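To keep only the best-performing weights rather than the latest ones, `ModelCheckpoint` can monitor a logged metric and save only when it improves. A sketch, assuming the validation metric is logged as `val_sparse_categorical_accuracy` (the name Keras derives from `SparseCategoricalAccuracy()` with validation data) and using a hypothetical path `model_best/best.weights.h5`:

```python
import tensorflow as tf

# Hypothetical path; recent Keras versions require weights-only files to end in ".weights.h5".
best_path = "model_best/best.weights.h5"

best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    monitor="val_sparse_categorical_accuracy",  # logged name of the validation metric
    mode="max",                                 # higher accuracy is better
    save_best_only=True,                        # only overwrite when the metric improves
    save_weights_only=True,
    verbose=1,
)
# Usage: model.fit(ds_train, epochs=..., validation_data=ds_test, callbacks=[best_cb])
```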

In [None]:
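import os

# Recent Keras versions require weights-only checkpoint files to end in ".weights.h5"
checkpoint_path = "model/cp.weights.h5"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)

# Create a callback that saves the model's weights at the end of every epoch
# (the same file is overwritten each time)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Continue training the model for 2 more epochs with the checkpoint callback
model.fit(ds_train,
          epochs=2,
          validation_data=ds_test,
          callbacks=[cp_callback])  # Pass callback to training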


### Evaluating an Untrained Model

To confirm that checkpointing works, we create a **new, untrained instance** of the same architecture and evaluate it. With 10 balanced classes, it should score close to chance level (about 10% accuracy).

In [None]:
# Create a basic model instance
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Evaluate the model
loss, acc = model.evaluate(ds_test, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))


### Restoring Weights and Re-evaluation

Now we load the weights from the saved checkpoint into this untrained instance and re-evaluate it. The accuracy should match the validation accuracy reported for the last epoch of the checkpointed training run.

In [None]:
# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(ds_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))


### Checkpointing with Epoch Information (Example 2: Multiple Checkpoints)

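By including `{epoch:04d}` in the checkpoint path, each epoch's checkpoint gets its own zero-padded file name (e.g. `cp-0002.keras`), so a training run does not overwrite its earlier checkpoints. This is useful for tracking performance over time or choosing the best model. Because `save_weights_only=False`, each file is a complete `.keras` model rather than just weights, and `save_freq='epoch'` writes one after every epoch.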
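An integer `save_freq` counts batches, not epochs, so saving every 5 epochs means converting epochs into steps. The arithmetic below assumes the 60,000-example training split and the batch size of 128 used above; `save_freq=save_every_5_epochs` could then be passed to `ModelCheckpoint` in place of `'epoch'`. It also shows how the zero-padded file name is formatted.

```python
import math

num_train_examples = 60_000   # size of the Fashion-MNIST training split
batch_size = 128

steps_per_epoch = math.ceil(num_train_examples / batch_size)  # last partial batch counts
save_every_5_epochs = 5 * steps_per_epoch                      # candidate save_freq value

# The {epoch:04d} placeholder is filled with a zero-padded epoch number.
path_pattern = "model_2/cp-{epoch:04d}.keras"
print(steps_per_epoch, save_every_5_epochs, path_pattern.format(epoch=5))
```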

In [None]:
# Include the epoch in the file name
checkpoint_path = "model_2/cp-{epoch:04d}.keras"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)

# Create a callback that saves the full model (not just weights) after every epoch.
# save_freq accepts 'epoch' or an integer number of batches.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=False,
    save_freq='epoch')

# Create a new model instance
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train the model with the new callback
model.fit(ds_train,
          epochs=2,
          callbacks=[cp_callback],
          validation_data=ds_test,
          verbose=1)


### Finding the Latest Checkpoint

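TensorFlow's `tf.train.latest_checkpoint` finds the most recent checkpoint in a directory, but only for TensorFlow-format checkpoints tracked by a `checkpoint` state file (e.g. those written by `tf.train.CheckpointManager`). It does not index Keras `.keras` or `.weights.h5` files, so we write a small helper that returns the most recently modified file matching a pattern. This is what we need to resume training.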
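Modification times can mislead if files are copied or touched after training. An alternative is to parse the epoch number out of the file name. A sketch, assuming file names follow the `cp-{epoch:04d}.keras` pattern used here; the helper `latest_by_epoch` is hypothetical, and the usage runs on empty placeholder files in a temporary directory.

```python
import os
import re
import tempfile

def latest_by_epoch(folder, prefix="cp-", suffix=".keras"):
    """Return the checkpoint path with the highest epoch number, or None.

    Assumes file names like 'cp-0003.keras', as produced by 'cp-{epoch:04d}.keras'.
    """
    pattern = re.compile(re.escape(prefix) + r"(\d+)" + re.escape(suffix) + r"$")
    best_epoch, best_path = -1, None
    for name in os.listdir(folder):
        match = pattern.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), os.path.join(folder, name)
    return best_path

# Usage with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (1, 2, 10):
        open(os.path.join(tmp, f"cp-{epoch:04d}.keras"), "w").close()
    demo_latest = latest_by_epoch(tmp)
    print(os.path.basename(demo_latest))  # cp-0010.keras
```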

In [None]:
import os
import glob

def get_latest_model(folder, pattern):
    """Return the most recently modified file in `folder` matching the glob `pattern`, or None."""
    # Collect all files matching the pattern, e.g. "*.keras"
    files = glob.glob(os.path.join(folder, pattern))
    if not files:
        return None

    # Pick the most recently modified file
    latest_file = max(files, key=os.path.getmtime)
    return latest_file


In [None]:
latest = get_latest_model("model_2", "*.keras")
print(f"Latest checkpoint found: {latest}")

### Loading the Latest Checkpoint

We load the weights from the latest checkpoint into a new model instance.

In [None]:
# Create a new model instance
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(ds_test, verbose=2)
print("Restored model from latest checkpoint, accuracy: {:5.2f}%".format(100 * acc))
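Because these checkpoints are complete `.keras` files, `tf.keras.models.load_model` is an alternative to rebuilding the architecture and calling `load_weights`: it restores the architecture, weights, and compile configuration (including optimizer state, when present) from the file alone. The sketch below demonstrates the round trip on a tiny stand-in model with random data (`model_small` and the temporary path are assumptions for illustration, not the notebook's CNN).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# A tiny stand-in model (not the notebook's CNN) to keep the sketch fast.
model_small = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])
model_small.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
x_demo = np.random.rand(8, 4).astype("float32")
y_demo = np.random.randint(0, 3, size=(8,))
model_small.fit(x_demo, y_demo, epochs=1, verbose=0)  # builds the optimizer state

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "small.keras")
    model_small.save(path)
    # Rebuilds the model from the file; no architecture code needs to be repeated.
    restored = tf.keras.models.load_model(path)

same_outputs = np.allclose(model_small.predict(x_demo, verbose=0),
                           restored.predict(x_demo, verbose=0))
print(same_outputs)  # True
```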


---

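### Resuming Training from the Latest Checkpoint

We continue training the restored model from the previous step for one more epoch, reusing the same checkpoint callback. Passing `initial_epoch=2` tells Keras that two epochs were already completed, so the epoch counter (and therefore the checkpoint file name) continues from there instead of restarting at 1.

In [None]:
# Resume training with the same checkpoint callback.
# initial_epoch=2 continues the epoch count from the earlier 2-epoch run, so this
# extra epoch is saved as cp-0003.keras instead of overwriting cp-0001.keras.
model.fit(ds_train,
          initial_epoch=2,
          epochs=3,
          callbacks=[cp_callback],
          validation_data=ds_test,
          verbose=1)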
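Keras restarts the epoch count at every `fit` call unless `initial_epoch` is given, and `ModelCheckpoint` fills `{epoch}` with the 0-based epoch index plus one. The pure-Python simulation below (the helper `files_written` is hypothetical and does not call Keras) shows which file names a per-epoch callback with the `cp-{epoch:04d}.keras` pattern would produce:

```python
# Simulates ModelCheckpoint's naming: 0-based epoch index i -> file for epoch i + 1.
path_pattern = "model_2/cp-{epoch:04d}.keras"

def files_written(initial_epoch, epochs):
    """File names a per-epoch checkpoint would write for fit(initial_epoch=..., epochs=...)."""
    return [path_pattern.format(epoch=i + 1) for i in range(initial_epoch, epochs)]

print(files_written(0, 2))  # first run: cp-0001, cp-0002
print(files_written(0, 1))  # resumed without initial_epoch: overwrites cp-0001
print(files_written(2, 3))  # resumed with initial_epoch=2: writes cp-0003
```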


## Making Predictions

Finally, we use the trained model to make a prediction on a single test image. The model expects a batch, so we add a batch dimension to turn one `(28, 28, 1)` image into a `(1, 28, 28, 1)` batch before calling `model.predict()`. The output is one row of 10 **logits** (raw scores). Taking `tf.argmax` over the logits gives the predicted class directly; `tf.nn.softmax` is only needed if you also want class probabilities.
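Softmax preserves the ordering of the logits, so the arg-max class is the same with or without it. A NumPy sketch with made-up logits (not real model output), plus the batch-dimension step on a blank image:

```python
import numpy as np

# Made-up logits for one image (10 classes), shaped like model.predict output.
logits = np.array([[-1.2, 0.3, 2.5, 0.0, -0.7, 1.1, -2.0, 0.4, 0.9, -0.1]])

# Softmax turns logits into probabilities that sum to 1 (max subtracted for stability).
shifted = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

# The arg-max class is identical for logits and probabilities.
print(np.argmax(logits[0]), np.argmax(probs[0]), probs[0].max())

# Adding a batch dimension: one (28, 28, 1) image -> a batch of shape (1, 28, 28, 1).
image = np.zeros((28, 28, 1), dtype=np.float32)
print(np.expand_dims(image, axis=0).shape)
```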

In [None]:
# Get one batch from the test dataset; ds_test yields (image, label) batches.
for image_batch, label_batch in ds_test.take(1):
    sample_image = image_batch.numpy()[4]  # Take the 5th image in the batch
    sample_label = label_batch.numpy()[4]  # Take the 5th label in the batch

# The model expects a batch, so add a batch dimension: (28, 28, 1) -> (1, 28, 28, 1)
sample_for_prediction = tf.expand_dims(sample_image, axis=0)

predictions = model.predict(sample_for_prediction)  # shape (1, 10): raw logits
predicted_class = tf.argmax(predictions[0]).numpy()
probabilities = tf.nn.softmax(predictions[0]).numpy()

print(f"True Label: {sample_label}")
print(f"Raw Logits: {predictions[0]}")
print(f"Predicted Class: {predicted_class} (probability {probabilities[predicted_class]:.2%})")
