This module takes a closer look at multi-task learning: a single network with a shared base is trained on two related MNIST tasks at once.

In [None]:
# Multi-Task Learning on MNIST: digit classification + parity classification

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.datasets import mnist

# 1. Load and preprocess MNIST
(x_train, y_train_digit), (x_test, y_test_digit) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test  = x_test.astype("float32")  / 255.0

# Add channel dimension
x_train = np.expand_dims(x_train, -1)
x_test  = np.expand_dims(x_test, -1)

# Create parity labels: 0 if even, 1 if odd
y_train_parity = y_train_digit % 2
y_test_parity  = y_test_digit  % 2

# 2. Build a multi-task model
inputs = layers.Input(shape=(28, 28, 1))

# Shared base
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.MaxPooling2D()(x)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.MaxPooling2D()(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)

# Task‐1 head: digit classification (10 classes)
digit_output = layers.Dense(10, activation="softmax", name="digit_output")(x)

# Task‐2 head: parity classification (2 classes)
parity_output = layers.Dense(2, activation="softmax", name="parity_output")(x)

model = Model(inputs=inputs, outputs=[digit_output, parity_output])

# 3. Compile with two losses and optional loss weights
model.compile(
    optimizer="adam",
    # One loss per output head, keyed by the output layer's name
    loss={
        "digit_output": "sparse_categorical_crossentropy",
        "parity_output": "sparse_categorical_crossentropy",
    },
    loss_weights={
        "digit_output": 1.0,
        "parity_output": 0.5,  # parity contributes half as much as digit to the total loss
    },
    metrics={
        "digit_output": "accuracy",
        "parity_output": "accuracy",
    },
)

model.summary()

# 4. Train on both tasks simultaneously
history = model.fit(
    x_train,
    {"digit_output": y_train_digit, "parity_output": y_train_parity},
    validation_split=0.1,
    epochs=5,
    batch_size=128,
    verbose=2,
)

# 5. Evaluate on test data
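# return_dict=True returns metrics by name, so nothing depends on list positions
# (the order of the returned list differs between Keras versions).
results = model.evaluate(
    x_test,
    {"digit_output": y_test_digit, "parity_output": y_test_parity},
    verbose=0,
    return_dict=True,
)
print(f"\nTest digit accuracy : {results['digit_output_accuracy']:.4f}")
print(f"Test parity accuracy: {results['parity_output_accuracy']:.4f}")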


Epoch 1/5
422/422 - 8s - 20ms/step - digit_output_accuracy: 0.9369 - digit_output_loss: 0.2156 - loss: 0.2805 - parity_output_accuracy: 0.9486 - parity_output_loss: 0.1298 - val_digit_output_accuracy: 0.9815 - val_digit_output_loss: 0.0695 - val_loss: 0.0997 - val_parity_output_accuracy: 0.9797 - val_parity_output_loss: 0.0611
Epoch 2/5
422/422 - 7s - 17ms/step - digit_output_accuracy: 0.9814 - digit_output_loss: 0.0603 - loss: 0.0822 - parity_output_accuracy: 0.9844 - parity_output_loss: 0.0439 - val_digit_output_accuracy: 0.9880 - val_digit_output_loss: 0.0443 - val_loss: 0.0620 - val_parity_output_accuracy: 0.9883 - val_parity_output_loss: 0.0361
Epoch 3/5
422/422 - 6s - 15ms/step - digit_output_accuracy: 0.9870 - digit_output_loss: 0.0416 - loss: 0.0575 - parity_output_accuracy: 0.9892 - parity_output_loss: 0.0318 - val_digit_output_accuracy: 0.9898 - val_digit_output_loss: 0.0377 - val_loss: 0.0515 - val_parity_output_accuracy: 0.9897 - val_parity_output_loss: 0.0283
Epoch 4/5
422

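This simple example shows how a shared base (h(shared)) works. Because the shared layers are trained on the labels of both tasks, they receive more training signal than the base of a single-task model would, so they can learn more robust, general-purpose features. This can help prevent overfitting.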
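
The `loss` value in the training log is the weighted sum of the two task losses, using the `loss_weights` from the compile step. As a quick sanity check, the sketch below recomputes it from the per-task training losses copied from the log. The helper name `weighted_total` is made up for this illustration, and only the training columns of the three complete epochs are checked; small differences are rounding in the printed log.

```python
# Loss weights from the compile step above.
LOSS_WEIGHTS = {"digit_output": 1.0, "parity_output": 0.5}


def weighted_total(task_losses, weights=LOSS_WEIGHTS):
    """Return the weighted sum of per-task losses (hypothetical helper)."""
    return sum(weights[name] * value for name, value in task_losses.items())


# (digit_output_loss, parity_output_loss, reported loss) copied from the log.
logged_epochs = [
    (0.2156, 0.1298, 0.2805),
    (0.0603, 0.0439, 0.0822),
    (0.0416, 0.0318, 0.0575),
]

for epoch, (digit_loss, parity_loss, reported) in enumerate(logged_epochs, start=1):
    total = weighted_total(
        {"digit_output": digit_loss, "parity_output": parity_loss}
    )
    print(f"epoch {epoch}: recomputed {total:.4f}, logged {reported:.4f}")
```

Raising the parity weight would shift more of the gradient in the shared base toward the parity task; lowering it does the opposite.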

Results:
Both tasks learn quickly and reach high accuracy (about 99% validation accuracy on each task by epoch 3).
The shared representation supports multiple outputs: a single base suffices for both classification tasks, which is parameter-efficient.
Generalization is good, with little sign of overfitting: validation accuracy tracks training accuracy. Multi-task sharing acts as a regularizer, pushing the model to learn features that are useful for both tasks rather than overfitting to one.
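
To make the parameter-efficiency point concrete, the following sketch counts trainable parameters by hand for the architecture above, assuming the Keras defaults the code relies on ("valid" convolution padding and 2x2 max pooling). The comparison with two separate single-task models is a hypothetical baseline, not something trained in this notebook, and the helper names are made up for the illustration.

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_units, out_units):
    """Weights plus biases of a Dense layer."""
    return in_units * out_units + out_units


# Spatial size: 28 -> conv 3x3 -> 26 -> pool -> 13 -> conv 3x3 -> 11 -> pool -> 5
flat_units = 5 * 5 * 64

# Shared base: two conv layers and the 128-unit dense layer.
shared = (
    conv2d_params(3, 1, 32)
    + conv2d_params(3, 32, 64)
    + dense_params(flat_units, 128)
)
digit_head = dense_params(128, 10)
parity_head = dense_params(128, 2)

# One multi-task model versus two single-task models (hypothetical baseline),
# each of which would need its own copy of the base.
multi_task_total = shared + digit_head + parity_head
two_models_total = 2 * shared + digit_head + parity_head

print(f"shared base      : {shared:,}")
print(f"multi-task model : {multi_task_total:,}")
print(f"two single models: {two_models_total:,}")
```

Almost all parameters sit in the shared base, so the second head adds only a few hundred parameters, while a second standalone model would roughly double the total. The multi-task total can be compared against the output of `model.summary()`.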

Takeaway:
Multi-task learning here gives you a compact model that handles two related classification problems simultaneously. In the log above it reaches about 99% validation accuracy on MNIST digit recognition by epoch 3, and it solves the secondary parity task equally well without showing signs of extra overfitting.