In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and echo the full path of every file found there.
for root, _, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Import Libraries

In [None]:
!pip install tensorflow==2.13.0 --quiet

In [None]:
import os
# Has to be set before tensorflow is imported to lower its native log verbosity.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
np.random.seed(42)
import pandas as pd

import matplotlib.pyplot as plt

import tensorflow as tf
# Seed TensorFlow as well: weight initialisation, dropout and shuffling draw
# from TensorFlow's generator, which np.random.seed does not control.
tf.random.set_seed(42)
from tensorflow.keras import layers, models, optimizers, losses, regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

from mlxtend.plotting import plot_confusion_matrix
from scikitplot.metrics import plot_roc_curve

from sklearn.metrics import roc_curve, auc, classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

# Load Data

In [None]:
# Fetch the MNIST train/test split through the Keras dataset loader.
(X_train, y_train), (X_test, y_test) = load_data()

In [None]:
# Add a trailing channel axis and switch to float32. Using -1 lets NumPy infer
# the number of samples instead of hard-coding the split sizes.
X_train = X_train.reshape((-1, 28, 28, 1)).astype("float32")
X_test = X_test.reshape((-1, 28, 28, 1)).astype("float32")

In [None]:
# Divide both image arrays by 255.
# NOTE: the arrays are rebound from themselves, so re-running this cell
# divides the values a second time.
X_train, X_test = X_train / 255, X_test / 255

In [None]:
# One-hot encode the digit labels. The class count is pinned to 10 so both
# splits always get 10 columns (matching the 10-capsule output head) instead of
# relying on the largest label value that happens to appear in each array.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# Model

In [None]:
class Length(layers.Layer):
    """Reduce the last axis of the input to its Euclidean norm.

    For an input of shape ``(..., d)`` the output has shape ``(...)`` and holds
    the L2 length of each ``d``-dimensional vector.
    """

    def call(self, inputs, **kwargs):
        # K.epsilon() keeps the gradient of the square root finite when a
        # vector is exactly zero (d/dx sqrt(x) is unbounded at x == 0).
        return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())

    def compute_output_shape(self, input_shape):
        """Return the input shape with its last axis dropped."""
        return input_shape[:-1]

In [None]:
class Mask(layers.Layer):
    """Zero out all capsules but one and flatten the result.

    Accepts either

    * ``[capsules, mask]`` - ``mask`` is multiplied onto the capsules
      (broadcast over the last capsule axis), or
    * ``capsules`` alone - the mask is a one-hot vector selecting the capsule
      with the largest L2 norm.

    The masked capsules are flattened to shape ``(batch, num_caps * caps_dim)``.
    """

    def call(self, inputs, **kwargs):
        # isinstance instead of an exact type check so a tuple of
        # (capsules, mask) is handled the same way as a list.
        if isinstance(inputs, (list, tuple)):
            inputs, mask = inputs
        else:
            lengths = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            mask = tf.one_hot(indices=tf.argmax(lengths, 1), depth=lengths.shape[1])
        inputs_masked = K.batch_flatten(inputs * tf.expand_dims(mask, -1))
        return inputs_masked

    def compute_output_shape(self, input_shape):
        """Return ``(None, num_caps * caps_dim)``, matching the flattened output of ``call``."""
        # With [capsules, mask] inputs the first entry is itself a shape;
        # with a single input it is the batch dimension.
        if isinstance(input_shape[0], (list, tuple, tf.TensorShape)):
            caps_shape = input_shape[0]
        else:
            caps_shape = input_shape
        num_caps, caps_dim = caps_shape[1], caps_shape[2]
        if num_caps is None or caps_dim is None:
            return tuple([None, None])
        return tuple([None, num_caps * caps_dim])

In [None]:
def squash(vectors, axis=-1):
    """Rescale vectors along ``axis`` so their length lies in [0, 1).

    Each vector ``v`` is multiplied by ``|v|^2 / (1 + |v|^2) / |v|``: the
    direction is kept, short vectors shrink towards zero and long vectors
    approach unit length.

    Args:
        vectors: tensor holding the vectors to squash.
        axis: axis along which the vector norm is taken.

    Returns:
        A tensor with the same shape as ``vectors``.
    """
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    # K.epsilon() under the square root avoids 0 / 0 = NaN for all-zero vectors.
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors

In [None]:
class CapsuleLayer(layers.Layer):
    """Capsule layer with iterative routing between input and output capsules.

    Takes a rank-3 input ``(batch, input_num_capsule, input_dim_vector)`` and
    returns ``(batch, num_capsule, dim_vector)``.

    Args:
        num_capsule: number of output capsules.
        dim_vector: dimension of each output capsule vector.
        num_routing: number of routing iterations.
        kernel_initializer: initializer for the transformation weights ``W``.
        bias_initializer: accepted and stored for backward compatibility only;
            the routing logits are re-created as zeros on every forward pass,
            so no bias weight is built.
    """

    def __init__(self, num_capsule, dim_vector, num_routing=3, kernel_initializer='glorot_uniform', bias_initializer="zeros", **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_vector = dim_vector
        self.num_routing = num_routing
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer

    def build(self, input_shape):
        """Create one transformation matrix per (output capsule, input capsule) pair."""
        self.input_num_capsule = input_shape[1]
        self.input_dim_vector = input_shape[2]
        self.W = self.add_weight(shape=[self.num_capsule,
                                        self.input_num_capsule,
                                        self.dim_vector,
                                        self.input_dim_vector],
                                 initializer=self.kernel_initializer, name='w')

        # No bias weight is registered: the routing logits are per-batch
        # temporaries that call() initialises to zero on every invocation.
        self.built = True

    def call(self, inputs, training=None):
        # (batch, in_caps, in_dim) -> (batch, 1, in_caps, in_dim, 1)
        inputs_expanded = tf.expand_dims(tf.expand_dims(inputs, 1), -1)

        # Broadcast matmul over the batch and output-capsule axes:
        #   (1, num_caps, in_caps, dim, in_dim) @ (batch, 1, in_caps, in_dim, 1)
        #   -> (batch, num_caps, in_caps, dim, 1)
        # Squeezing only the trailing axis keeps the batch axis intact even
        # when the batch size is 1 or statically unknown.
        input_hat = tf.squeeze(tf.matmul(tf.expand_dims(self.W, 0), inputs_expanded), axis=-1)

        # Routing logits live in a local variable; nothing is stored on the
        # layer, so repeated calls cannot leak state into each other.
        routing_logits = tf.zeros(shape=[tf.shape(inputs)[0], self.num_capsule, 1, self.input_num_capsule])
        for i in range(self.num_routing):
            # Coupling coefficients: softmax over the output-capsule axis.
            c = tf.nn.softmax(routing_logits, axis=1)
            # (batch, num_caps, 1, in_caps) @ (batch, num_caps, in_caps, dim)
            #   -> (batch, num_caps, 1, dim)
            output = squash(tf.matmul(c, input_hat))
            if i < self.num_routing - 1:
                # Agreement between current outputs and the predictions.
                routing_logits += tf.matmul(output, input_hat, transpose_b=True)

        # Drop only the singleton axis introduced by the weighted sum.
        return tf.squeeze(output, axis=2)

    def compute_output_shape(self, input_shape):
        """Return ``(None, num_capsule, dim_vector)``."""
        return tuple([None, self.num_capsule, self.dim_vector])

In [None]:
def margin_loss(y_true, y_pred):
    """Margin loss on per-class scores.

    Where ``y_true`` is 1 the score is penalised (quadratically) for falling
    below 0.9; where it is 0 the score is penalised for exceeding 0.1, with
    that term down-weighted by 0.5. Terms are summed over axis 1 and averaged
    over the remaining axis.
    """
    present_term = y_true * tf.square(tf.maximum(0., 0.9 - y_pred))
    absent_term = 0.5 * (1 - y_true) * tf.square(tf.maximum(0., y_pred - 0.1))
    per_sample = tf.reduce_sum(present_term + absent_term, 1)
    return tf.reduce_mean(per_sample)

In [None]:
def primary_capsule(input_layer, num_filters, kernel_size):
    """Apply a ReLU convolution followed by batch norm and a 1x1 max-pool.

    Args:
        input_layer: tensor fed into the convolution.
        num_filters: number of convolution filters.
        kernel_size: convolution kernel size.

    Returns:
        ``(conv, pooled)`` - the raw convolution output and the output after
        batch normalisation and the ``pool_size=(1, 1)`` max-pooling layer.
    """
    conv_layer = layers.Conv2D(
        filters=num_filters,
        kernel_size=kernel_size,
        activation="relu",
        padding="valid",
    )
    conv = conv_layer(input_layer)
    normalised = layers.BatchNormalization()(conv)
    pooled = layers.MaxPooling2D(pool_size=(1, 1))(normalised)
    return conv, pooled

In [None]:
# Image input. The batch size is fixed to 100 here, while the label input `y`
# below leaves it unspecified.
# NOTE(review): model.fit / model.predict in later cells use batch_size=200 - confirm
# the fixed batch size of 100 is still intended.
input_layer = layers.Input(shape=(28, 28, 1), batch_size=100)

# First 9x9 convolution stage (conv -> batch norm -> 1x1 max-pool); only the
# pooled branch (max_1) is used further down.
conv_1, max_1 = primary_capsule(input_layer, 256, (9, 9))

# Second 9x9 convolution (stride 2, no activation) whose features are regrouped
# into 8-dimensional vectors.
conv_2 = layers.Conv2D(filters=256, kernel_size=(9, 9), strides=2, activation=None, padding="valid")(max_1)
reshape_layer_1 = layers.Reshape([-1, 8])(conv_2)

# Squash each 8-dimensional vector along its last axis.
squash_layer = layers.Lambda(squash)(reshape_layer_1)

# 10 output capsules of dimension 16; the length of each capsule is the class score.
digitcaps = CapsuleLayer(num_capsule=10, dim_vector=16, num_routing=3)(squash_layer)
out_caps = Length(name="capsnet")(digitcaps)

# Second model input: a 10-element vector used to mask the capsules for the decoder.
y = layers.Input(shape=(10,))
masked_by_y = Mask()([digitcaps, y])
# Masked by the longest capsule instead of `y`; not connected to the model built below.
masked = Mask()(digitcaps)

# Decoder: dense layers mapping the masked capsules to a 28x28x1 output.
x_recon = layers.Dense(512, activation="relu")(masked_by_y)
x_recon = layers.Dropout(0.5)(x_recon)
x_recon = layers.Dense(1024, activation="relu")(x_recon)
x_recon = layers.Dropout(0.5)(x_recon)
x_recon = layers.Dense(784, activation="sigmoid")(x_recon)
reshape_layer_2 = layers.Reshape((28, 28, 1))(x_recon)

# Two inputs (image, mask vector) and two outputs (capsule lengths, decoder output).
model = models.Model(inputs=[input_layer, y], outputs=[out_caps, reshape_layer_2])

In [None]:
# Print the layer-by-layer summary of the model.
model.summary()

In [None]:
# Render the model graph, annotated with tensor shapes and layer names.
plot_model(model, show_shapes=True, show_layer_names=True, expand_nested=True)

# Train

In [None]:
# Both callbacks watch validation accuracy of the "capsnet" output.
# The learning-rate patience must be shorter than the early-stopping patience:
# with equal values training stops on the very epoch the learning rate would
# first be reduced, so the reduction never takes effect.
# restore_best_weights puts the best-scoring weights back after stopping
# instead of keeping those from the last (worse) epoch.
early_stopping = EarlyStopping(monitor='val_capsnet_accuracy', mode='max', patience=5, restore_best_weights=True)
lr_scheduler = ReduceLROnPlateau(monitor='val_capsnet_accuracy', mode='max', patience=2, factor=0.5)

In [None]:
# Total loss = margin loss on the capsule lengths + down-weighted MSE on the
# decoder output.
# NOTE(review): 'mse' averages the squared error over all pixels; if the 0.0005
# weight was chosen for a *summed* squared error it would need rescaling - confirm.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss=[margin_loss, 'mse'],
    loss_weights=[1.0, 0.0005],
    # Accuracy is only meaningful for the classification output; attaching it
    # to the image output as well would just add a meaningless metric column.
    metrics={'capsnet': 'accuracy'},
)

In [None]:
# Train the model. Inputs are (image, one-hot label); targets are
# (one-hot label, image), matching the two model outputs.
# NOTE(review): the test split is used as validation data, and the callbacks
# monitor a validation metric - the same split is also used in the Test section below.
history = model.fit(
    [X_train, y_train], 
    [y_train, X_train], 
    validation_data=([X_test, y_test], [y_test, X_test]),
    batch_size=200, 
    epochs=25, 
    callbacks=[early_stopping, lr_scheduler]
)

# Results

In [None]:
# Tabulate the per-epoch values recorded by fit() and preview the first rows.
history_df = pd.DataFrame(history.history)
history_df.head()

In [None]:
# Loss of the "capsnet" output per epoch, training vs. validation.
loss_fig, loss_ax = plt.subplots()
for series in ("capsnet_loss", "val_capsnet_loss"):
    loss_ax.plot(history.history[series])
loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Loss Curve")
loss_ax.legend(["train", "valid"])
plt.show()

In [None]:
# Accuracy of the "capsnet" output per epoch, training vs. validation.
acc_fig, acc_ax = plt.subplots()
for series in ("capsnet_accuracy", "val_capsnet_accuracy"):
    acc_ax.plot(history.history[series])
acc_ax.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy Curve")
acc_ax.legend(["train", "valid"])
plt.show()

# Test

In [None]:
# Number of images in the test split.
num_samples = X_test.shape[0]

In [None]:
# Draw 4000 distinct test indices (sampling without replacement).
random_indices = np.random.choice(num_samples, size=4000, replace=False)

In [None]:
# Run the sampled test images through the model. The second input supplies the
# true one-hot labels to the Mask layer, so the returned images are decoded from
# the capsule selected by the ground-truth label.
label_pred, image_pred = model.predict([X_test[random_indices], y_test[random_indices]], batch_size=200, verbose=0)

In [None]:
# Show the decoder output for the first 16 sampled test images together with
# the true and the predicted digit.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))

for i, ax in enumerate(axes.ravel()):
    # label_pred / image_pred are ordered like random_indices, so the true
    # label has to be looked up through random_indices[i], not at position i
    # of the full test set.
    label = np.argmax(y_test[random_indices[i]])
    pred = np.argmax(label_pred[i])

    ax.imshow(image_pred[i], cmap="gray")
    ax.set_title(f"True: {label}\nPred: {pred}")
    ax.axis('off')

plt.tight_layout()
plt.show()