In [40]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras import layers, initializers, models
import tensorflow.keras.backend as K
from tensorflow.keras.utils import to_categorical

In [29]:
import tensorflow as tf

In [30]:
class Length(layers.Layer):
    def call(self, inputs, **kwargs):
        return K.sqrt(K.sum(K.square(inputs), -1))
    
    def compute_output_shape(self, input_shape):
        return input_shape[:-1]

In [32]:
class Mask(layers.Layer):
    def call(self, inputs, **kwargs):
        if type(inputs) is list:
            assert len(inputs) == 2
            inputs, mask = inputs
        else:
            x = K.sqrt(K.sum(K.square(inputs), -1))
            mask = K.one_hot(indices = K.argmax(x, 1), num_classes = x.get_shape().as_list()[1])
            
        masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
        return masked
    
    def compute_output_shape(self, input_shape):
        if type(input_shape[0]) is tuple:
            return tuple([None, input_shape[0][1]*input_shape[0][2]])
        else:
            return tuple([None, input_shape[1]*input_shape[2]])

In [33]:
def squash(vectors, axis=-1):
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1+s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    
    return scale * vectors

In [83]:
class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsule, dim_capsule, num_routing=3, kernel_initializer='glorot_uniform', **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.num_routing = num_routing
        self.kernel_initializer = initializers.get(kernel_initializer)
        
    def build(self, input_shape):
        assert len(input_shape) >= 3
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]
        
        self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                       self.dim_capsule, self.input_dim_capsule],
                                initializer = self.kernel_initializer, name = 'W')
        self.built = True
        
    def call(self, inputs, training=None):
        print("inputs", inputs.shape)
        inputs_expand = K.expand_dims(inputs, 1)
        print("inputs expand", inputs_expand.shape)
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])
        print("inputs tiled", inputs_tiled.shape)
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems = inputs_tiled)
        print("inputs hat", inputs_hat.shape)
        mai
        b = K.zeros(shape = [100, self.num_capsule, self.input_num_capsule])
        print("b", b.shape)
        
        def body(i, b, outputs):
            c = tf.nn.softmax(b, axis=1)
            print("c", c.shape)
            outputs = squash(K.batch_dot(c, inputs_hat, [2,2]))
            print("outputs", outputs.shape)
            if i!=1:
                print(outputs.shape, "and", inputs_hat.shape)
                b = b + K.batch_dot(outputs, inputs_hat, [2,3])
                
            return [i-1, b, outputs]
        
        cond = lambda i, b, inputs_hat: i > 0
        loop_vars = [K.constant(self.num_routing), b, K.sum(inputs_hat, 2, keepdims=False)]
        shape_invariants = [tf.TensorShape([]),
                            tf.TensorShape([None, self.num_capsule, self.input_num_capsule]),
                            tf.TensorShape([None, self.num_capsule, self.dim_capsule])]
        
        _,_,outputs = tf.while_loop(cond, body, loop_vars, shape_invariants)
        
        inputs_hat_stopped = K.stop_gradient(inputs_hat)
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
        
        assert self.num_routing > 0
        
        for i in range(self.num_routing):
            c = tf.nn.softmax(b, axis=1)
            
            if i == self.num_routing - 1:
                outputs = squash(K.batch_dot(c, inputs_hat, [2,2]))
            else:
                outputs = squash(K.batch_dot(c, inputs_hat_stopped, [2,2]))
                b += K.batch_dot(outputs, inputs_hat_stopped, [2,3])
                
        return outputs
    
    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

In [84]:
def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    output = layers.Conv2D(filters = dim_capsule*n_channels, kernel_size=kernel_size, strides=strides,
                          padding=padding, name='primarycap_conv2d')(inputs)
    outputs = layers.Reshape(target_shape=[-1, dim_capsule], name='primarycap_reshape')(output)
    return layers.Lambda(squash, name='primarycap_squash')(outputs)

In [85]:
def CapsNet(input_shape, n_class, num_routing):
    x = layers.Input(shape = input_shape)
    
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, num_routing=num_routing,
                            name='digitcaps')(primarycaps)
    out_caps = Length(name='capsnet')(digitcaps)
    
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])
    masked = Mask()(digitcaps)
    
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16*n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))
    
    train_model = models.Model([x,y], [out_caps, decoder(masked_by_y)])
    eval_model = models.Model(x, [out_caps, decoder(masked)])
    
    return train_model, eval_model

In [86]:
train_data = pd.read_csv("../data/fashion-mnist_train.csv")
test_data = pd.read_csv("../data/fashion-mnist_test.csv")

y_train = train_data.label
X_train = train_data.drop(columns=['label'])

y_test = test_data.label
X_test = test_data.drop(columns=['label'])

X_train = np.reshape(np.array(X_train), (60000, 28, 28, 1))
X_train = X_train.astype('float32') / 255.
X_test = np.reshape(np.array(X_test), (10000, 28, 28, 1))
X_test = X_test.astype('float32') / 255.

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

print(X_train.shape)
print(y_train.shape)

(60000, 28, 28, 1)
(60000, 10)


In [87]:
model, eval_model = CapsNet(input_shape=X_train.shape[1:], n_class=len(np.unique(np.argmax(y_train, 1))),
                           num_routing=3)

inputs (None, 1152, 8)
inputs expand (None, 1, 1152, 8)
inputs tiled (None, 10, 1152, 8)
inputs hat (None, 10, 1152, 1152, 16)
b (100, 10, 1152)
c (None, 10, 1152)
outputs (None, 10, 10, 1152, 16)
(None, 10, 10, 1152, 16) and (None, 10, 1152, 1152, 16)


ValueError: Exception encountered when calling layer "digitcaps" (type CapsuleLayer).

in user code:

    File "/var/folders/wz/cnwy77v95fz65y4_p_vs2l140000gn/T/ipykernel_65449/2977921735.py", line 38, in body  *
        b = b + K.batch_dot(outputs, inputs_hat, [2,3])
    File "/Users/adityasingh/opt/anaconda3/lib/python3.9/site-packages/keras/src/backend.py", line 2584, in batch_dot
        raise ValueError(

    ValueError: Cannot do batch_dot on inputs with shapes (None, 10, 10, 1152, 16) and (None, 10, 1152, 1152, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 1152).


Call arguments received by layer "digitcaps" (type CapsuleLayer):
  • inputs=tf.Tensor(shape=(None, 1152, 8), dtype=float32)
  • training=None

In [66]:
x_train = x_train.reshape(-1,28,28,1).astype('float32') / 255.
y_train = to_categorical(y_test.astype('float32'))
print(x_train.shape)
print(y_train.shape)

(60000, 28, 28, 1)
(10000, 10)
