In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Flatten, Dense, Reshape, BatchNormalization, Activation
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import MeanSquaredError
from sklearn.metrics import roc_curve, auc, confusion_matrix
import seaborn as sns
from sklearn.model_selection import train_test_split

In [None]:
# Load the MNIST digits. Keras hands back two (images, labels) pairs: the
# official train split and the official test split.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split  # keep only the unpacked arrays in the namespace

In [None]:
# Pool the official train/test splits into one array (our own splits are made
# further below), scale 8-bit pixel values by 1/255 as float32, and append a
# trailing channel axis so the images match the Conv2D input shape (28, 28, 1).
x_train = np.concatenate((x_train, x_test)).astype('float32') / 255.0
x_train = x_train[..., np.newaxis]
y_train = np.concatenate((y_train, y_test))
# The official test split now lives inside x_train / y_train.
del x_test, y_test

In [None]:
# Sanity check: pooled image count must equal pooled label count.
print(*(arr.shape for arr in (x_train, y_train)))

In [None]:
# One digit class plays the role of the "anomaly"; every other digit is normal
# (background) data. Change ANOMALY_DIGIT to rerun the experiment on another class.
ANOMALY_DIGIT = 5
normal_digits = np.flatnonzero(y_train != ANOMALY_DIGIT)   # indices of background samples
anomaly_digits = np.flatnonzero(y_train == ANOMALY_DIGIT)  # indices of anomalous samples
x_normal = x_train[normal_digits]
x_anomalies = x_train[anomaly_digits]

In [None]:
# Half of the normal digits train the autoencoder; the other half (x_temp) is
# split evenly again, giving validation and test shares of 25% each overall.
# A fixed random_state keeps the splits reproducible across kernel restarts.
x_train_normal, x_temp = train_test_split(
    x_normal, random_state=42, test_size=0.5
)
x_val_normal, x_test_normal = train_test_split(
    x_temp, random_state=42, test_size=0.5
)

In [None]:
# Evaluation set for anomaly scoring: the validation normal digits (label 0)
# followed by every anomalous digit (label 1). Labels are integers so plot
# titles read "Label: 0"/"Label: 1"; roc_curve accepts them directly.
# Do NOT shuffle: the visualisation cells below rely on this order (indices
# 1-3 pick normal samples from the front, -1/-2 pick anomalies from the back).
# NOTE(review): x_val_normal is also the validation_data used for early
# stopping, so the ROC/AUC computed on this set is measured on data that
# influenced model selection, and x_test_normal is never used. Consider
# building this set from x_test_normal (plus a held-out share of the
# anomalies) for an unbiased estimate.
x_val = np.concatenate([x_val_normal, x_anomalies], axis=0)
y_val = np.repeat([0, 1], [len(x_val_normal), len(x_anomalies)])

In [None]:
# First line: the three normal-only splits; second line: mixed evaluation images and labels.
for shape_group in ((x_train_normal, x_val_normal, x_test_normal), (x_val, y_val)):
    print(*(arr.shape for arr in shape_group))

In [None]:
# Visualize a few evaluation samples. x_val holds the normal digits first and
# the anomalies last, so indices 1-3 are normal and -1/-2 are anomalies.
digits = [1, 2, 3, -1, -2]
# squeeze=False keeps `axes` 2-D even when only one index is requested.
fig, axes = plt.subplots(1, len(digits), figsize=(10, 3), squeeze=False)
for ax, idx in zip(axes[0], digits):
    ax.imshow(x_val[idx].squeeze(axis=-1), cmap='gray')  # drop the channel axis
    # int() so the title reads "Label: 0"/"Label: 1" even if y_val is float
    ax.set_title(f"Label: {int(y_val[idx])}")
    ax.axis('off')
plt.show()

In [None]:
def _conv_bn_relu(x, filters):
    """Apply a 3x3 'same'-padded Conv2D, then BatchNormalization, then ReLU.

    The convolution itself is linear, so the only non-linearity in the stage
    comes after normalization (Conv -> BN -> ReLU).

    Args:
        x: Keras tensor to transform.
        filters: number of convolution filters.

    Returns:
        The transformed Keras tensor.
    """
    x = Conv2D(filters, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)


def build_autoencoder(input_shape):
    """Build the convolutional autoencoder used for anomaly detection.

    Encoder: two Conv-BN-ReLU stages, each followed by 2x2 max-pooling (spatial
    size / 4), then a single-filter Conv-BN-ReLU giving one (H/4, W/4) map.
    Bottleneck: Flatten -> Dense 16 -> 8 -> 16 -> H/4*W/4 (all ReLU), reshaped
    back to (H/4, W/4, 1). Decoder: Conv-BN-ReLU stages with two 2x2
    upsamplings back to (H, W), then a sigmoid Conv2D with one filter per
    input channel, so outputs lie in (0, 1).

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width must be divisible by 4 so two pool/upsample rounds restore
            the original size, e.g. (28, 28, 1) for MNIST.

    Returns:
        An uncompiled keras ``Model`` mapping images to reconstructions of the
        same shape.

    Raises:
        ValueError: if ``input_shape`` is not 3-long, or height/width is not
            divisible by 4.
    """
    height, width, channels = input_shape
    if height % 4 or width % 4:
        raise ValueError(
            f"height and width must be divisible by 4, got input_shape={input_shape}"
        )
    code_h, code_w = height // 4, width // 4  # spatial size after two 2x2 poolings

    input_img = Input(shape=input_shape)

    # Encoder
    x = _conv_bn_relu(input_img, 16)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = _conv_bn_relu(x, 8)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = _conv_bn_relu(x, 1)

    # Dense bottleneck
    x = Flatten()(x)
    x = Dense(16, activation="relu")(x)
    x = Dense(8, activation="relu")(x)
    x = Dense(16, activation="relu")(x)
    x = Dense(code_h * code_w, activation="relu")(x)
    x = Reshape((code_h, code_w, 1))(x)

    # Decoder
    x = _conv_bn_relu(x, 4)
    x = UpSampling2D((2, 2))(x)
    x = _conv_bn_relu(x, 8)
    x = UpSampling2D((2, 2))(x)
    # Same Conv -> BN -> ReLU stage as the others: no activation on the Conv2D,
    # otherwise ReLU would be applied twice around the BatchNormalization.
    x = _conv_bn_relu(x, 16)
    x = Conv2D(channels, (3, 3), activation='sigmoid', padding='same')(x)

    autoencoder = Model(input_img, x)
    return autoencoder

In [None]:
# Instantiate the autoencoder for single-channel 28x28 MNIST images and print
# its layer-by-layer summary.
input_shape = (28, 28, 1)  # (height, width, channels)
autoencoder = build_autoencoder(input_shape=input_shape)
autoencoder.summary()

In [None]:
# Reconstruction objective: mean squared error between input and output
# pixels, optimised with Adam (default settings).
autoencoder.compile(loss=MeanSquaredError(), optimizer='adam')

In [None]:
# Halt training when val_loss stops improving (patience of 5 epochs) and roll
# the model back to the weights of its best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

In [None]:
# Train on normal digits only, with the input doubling as the target. The
# normal validation split supplies val_loss for early stopping.
history = autoencoder.fit(
    x=x_train_normal,
    y=x_train_normal,
    validation_data=(x_val_normal, x_val_normal),
    batch_size=64,
    epochs=20,
    shuffle=True,
    callbacks=[early_stopping],
)

In [None]:
# Learning curves: training vs validation reconstruction loss per epoch.
fig, ax = plt.subplots()
for history_key, curve_label in (('loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
    ax.plot(history.history[history_key], label=curve_label)
ax.set(title='Training History', xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Anomaly score per sample: mean squared pixel difference between each
# evaluation image and its reconstruction (averaged over H, W and channels).
x_out_val = autoencoder.predict(x_val)
mse_val = ((x_val - x_out_val) ** 2).mean(axis=(1, 2, 3))

In [None]:
# Input images for the samples whose reconstructions are shown in the next
# cell: indices 1-3 are normal digits (front of x_val), -1/-2 are anomalies (back).
digits = [1, 2, 3, -1, -2]
# squeeze=False keeps `axes` 2-D even when only one index is requested.
fig, axes = plt.subplots(1, len(digits), figsize=(10, 3), squeeze=False)
for ax, idx in zip(axes[0], digits):
    ax.imshow(x_val[idx].squeeze(axis=-1), cmap='gray')  # drop the channel axis
    # int() so the title reads "Label: 0"/"Label: 1" even if y_val is float
    ax.set_title(f"Label: {int(y_val[idx])}")
    ax.axis('off')
fig.suptitle("Input images")
fig.tight_layout()
plt.show()

In [None]:
# Autoencoder reconstructions of the same samples, annotated with each sample's
# anomaly score (reconstruction MSE). The detection approach assumes anomalies
# reconstruct worse, i.e. get higher MSE.
digits = [1, 2, 3, -1, -2]
# squeeze=False keeps `axes` 2-D even when only one index is requested.
fig, axes = plt.subplots(1, len(digits), figsize=(10, 3), squeeze=False)
for ax, idx in zip(axes[0], digits):
    ax.imshow(x_out_val[idx].squeeze(axis=-1), cmap='gray')  # drop the channel axis
    ax.set_title(f"Label: {int(y_val[idx])}\nMSE: {mse_val[idx]:.4f}")
    ax.axis('off')
fig.suptitle("Autoencoder reconstructions")
fig.tight_layout()
plt.show()

In [None]:
# Compare reconstruction-error distributions of normal digits vs anomalies.
# Both histograms use the same bin edges and are density-normalised, so their
# shapes are directly comparable even if the two classes differ in size
# (separate histplot calls would otherwise pick different bins and plot raw counts).
mse_normal = mse_val[y_val == 0]
mse_anomaly = mse_val[y_val == 1]
shared_bins = np.histogram_bin_edges(mse_val, bins='auto')

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(mse_normal, bins=shared_bins, stat='density', kde=True,
             color='blue', label='Normal', ax=ax)
sns.histplot(mse_anomaly, bins=shared_bins, stat='density', kde=True,
             color='red', label='Anomalies', ax=ax)
ax.set(title='MSE Distribution: Normal vs Anomalies',
       xlabel='Reconstruction MSE', ylabel='Density')
ax.legend()
plt.show()

In [None]:
# ROC curve using reconstruction MSE as the anomaly score (higher = more anomalous).
fpr, tpr, thresholds = roc_curve(y_val, mse_val)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance-level diagonal
ax.set(
    xlim=[0.0, 1.0],
    ylim=[0.0, 1.05],
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='Receiver Operating Characteristic',
)
ax.legend(loc="lower right")
plt.show()