<a href="https://colab.research.google.com/github/pulindu-seniya-silva/CapsNet-implementation/blob/main/CapsNet_implemention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

✅ Part 1: Data **Preparation**

In [14]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load and preprocess MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)


✅ Part 2: Convolution & Primary **Capsules**

In [15]:
from tensorflow.keras import layers

def squash(vectors, axis=-1):
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1.0 + s_squared_norm)
    return scale * vectors / tf.sqrt(s_squared_norm + 1e-8)

def ConvLayer(inputs):
    return layers.Conv2D(256, kernel_size=9, strides=1, activation='relu')(inputs)

def PrimaryCaps(inputs, dim_caps=8, n_channels=32, kernel_size=9, strides=2):
    conv = layers.Conv2D(filters=dim_caps * n_channels,
                         kernel_size=kernel_size,
                         strides=strides,
                         activation='relu')(inputs)
    capsules = layers.Reshape((-1, dim_caps))(conv)
    capsules = layers.Lambda(squash)(capsules)
    return capsules


✅ Part 3: Digit Capsules & Dynamic Routing

In [16]:
class DigitCaps(layers.Layer):
    def __init__(self, num_caps=10, dim_caps=16, routing_iters=3):
        super(DigitCaps, self).__init__()
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        self.routing_iters = routing_iters

    def build(self, input_shape):
        self.input_num_caps = input_shape[1]
        self.input_dim_caps = input_shape[2]
        self.W = self.add_weight(
            shape=[1, self.input_num_caps, self.num_caps, self.dim_caps, self.input_dim_caps],
            initializer='glorot_uniform',
            trainable=True
        )

    def call(self, u):
        batch_size = tf.shape(u)[0]
        u = tf.expand_dims(tf.expand_dims(u, 2), -1)  # [batch, input_caps, 1, dim, 1]
        W_tiled = tf.tile(self.W, [batch_size, 1, 1, 1, 1])
        u_hat = tf.matmul(W_tiled, u)  # [batch, input_caps, num_caps, dim_caps, 1]
        u_hat = tf.squeeze(u_hat, axis=-1)

        b = tf.zeros([batch_size, self.input_num_caps, self.num_caps])
        for i in range(self.routing_iters):
            c = tf.nn.softmax(b, axis=2)
            s = tf.reduce_sum(tf.expand_dims(c, -1) * u_hat, axis=1)
            v = squash(s)
            if i < self.routing_iters - 1:
                b += tf.reduce_sum(u_hat * tf.expand_dims(v, 1), axis=-1)
        return v


✅ Part 4: Model, Loss, Training & **Evaluation**

In [None]:
from tensorflow.keras import Input, Model

inputs = Input(shape=(28, 28, 1))
x = ConvLayer(inputs)
x = PrimaryCaps(x)
digit_caps = DigitCaps()(x)

# Capsule lengths for classification
output = layers.Lambda(lambda z: tf.sqrt(tf.reduce_sum(tf.square(z), axis=-1)))(digit_caps)

model = Model(inputs=inputs, outputs=output)

# Margin loss
def margin_loss(y_true, y_pred):
    m_plus = 0.9
    m_minus = 0.1
    lambda_ = 0.5
    L = y_true * tf.square(tf.maximum(0., m_plus - y_pred)) + \
        lambda_ * (1 - y_true) * tf.square(tf.maximum(0., y_pred - m_minus))
    return tf.reduce_mean(tf.reduce_sum(L, axis=1))

model.compile(optimizer='adam', loss=margin_loss, metrics=['accuracy'])

# Train model
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_data=(x_test, y_test))


Epoch 1/5
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3265s[0m 7s/step - accuracy: 0.7545 - loss: 0.2232 - val_accuracy: 0.9865 - val_loss: 0.0178
Epoch 2/5
[1m 34/469[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m46:52[0m 6s/step - accuracy: 0.9863 - loss: 0.0170