In [2]:
import tensorflow as tf
import numpy as np
from datetime import datetime   # date stamp the log directory
import json  # for saving and loading hyperparameters
import os, sys, re
import time

import absl
import absl.logging as logging
from tf2_models.matrix_caps import *
from util.config_util import get_model_params, get_task_params, get_train_params
from tf2_models.trainer import Trainer
from util.models import MODELS 
from util.tasks import TASKS
from notebook_utils import *

import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd

import seaborn as sns; sns.set()

from tqdm import tqdm

# Short module aliases used for brevity: TensorFlow's `tf.io.gfile` module and
# the `flags` attribute of `absl.app`.
gfile = tf.io.gfile
flags = absl.app.flags

[nltk_data] Downloading package punkt to /home/dehghani/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
class CapsConfig(object):
  """Hyperparameter container handed to ``MatrixCaps``.

  Every attribute has a default; any keyword argument passed to the
  constructor overrides (or adds) the attribute of the same name, e.g.
  ``CapsConfig(output_dim=5, iter_routing=3)``.

  Attributes (defaults):
    output_dim: 10
    A, B, C, D: 64, 8, 16, 16 -- presumably the layer widths named A-D in the
      Matrix Capsules paper; confirm against ``MatrixCaps``.
    epsilon: 1e-9
    l2: 0.0000002
    final_lambda: 0.01
    iter_routing: 2
  """

  def __init__(self,
               **kwargs):
    self.output_dim = 10
    self.A = 64
    self.B = 8
    self.C = 16
    self.D = 16
    self.epsilon = 1e-9
    self.l2 = 0.0000002
    self.final_lambda = 0.01
    self.iter_routing = 2

    # Apply caller overrides last so they win over the defaults above.
    # Previously **kwargs was accepted but silently ignored.
    for name, value in kwargs.items():
      setattr(self, name, value)

In [4]:
def spread_loss(scores, y, global_step):
    """Spread loss.

    "In order to make the training less sensitive to the initialization and 
    hyper-parameters of the model, we use “spread loss” to directly maximize the 
    gap between the activation of the target class (a_t) and the activation of the 
    other classes. If the activation of a wrong class, a_i, is closer than the 
    margin, m, to a_t then it is penalized by the squared distance to the margin."

    See Hinton et al. "Matrix Capsules with EM Routing" equation (3).

    Author:
    Ashley Gritzman 19/10/2018  
    Credit:
    Adapted from Suofei Zhang's implementation on GitHub, "Matrix-Capsules-
    EM-Tensorflow"
    https://github.com/www0wwwjs1/Matrix-Capsules-EM-Tensorflow  
    Args: 
    scores: 
      scores for each class 
      (batch_size, num_class)
    y: 
      index of true class; integer or integer-valued float labels are accepted
      (batch_size,) or (batch_size, 1)
    global_step:
      current training step, used to anneal the margin from 0.2 towards 0.99;
      may be a Python number or an integer/float tensor or variable
      (scalar)
    Returns:
    loss: 
      mean loss for entire batch
      (scalar)
    """

    batch_size = tf.shape(scores)[0]
    num_class = tf.shape(scores)[-1]

    # margin = 0.2 + .79 * tf.sigmoid(tf.minimum(10.0, step / 50000.0 - 4))
    # where step is the training step. We trained with batch size of 64."
    # Cast first: an integer step tensor (e.g. an optimizer's iteration
    # counter) cannot be divided by a Python float directly.
    step = tf.cast(global_step, tf.float32)
    m_min = 0.2
    m_delta = 0.79
    m = m_min + m_delta * tf.sigmoid(tf.minimum(10.0, step / 50000.0 - 4))

    # tf.one_hot needs integer indices, and the reshapes below need a flat
    # (batch_size,) label vector, so normalise both dtype and shape here.
    labels = tf.reshape(tf.cast(y, tf.int32), [-1])
    y = tf.one_hot(labels, num_class, dtype=tf.float32)

    # Get the score of the target class
    # (batch, 1, num_class)
    scores = tf.reshape(scores, shape=[batch_size, 1, num_class])
    # (batch, num_class, 1)
    y = tf.expand_dims(y, axis=2)
    # (batch, 1, num_class) x (batch, num_class, 1) = (batch, 1, 1)
    at = tf.matmul(scores, y)

    # Compute spread loss, paper eq (3)
    loss = tf.math.square(tf.maximum(0., m - (at - scores)))

    # Sum losses over the non-target classes only: multiplying by (1 - y)
    # masks out the target class, e.g. loss * [1 0 1 1 1]
    # (batch, 1, num_class) x (batch, num_class, 1) = (batch, 1, 1)
    loss = tf.matmul(loss, 1. - y)

    # Compute mean over the batch
    loss = tf.reduce_mean(loss)

    return loss


In [None]:
import tensorflow_datasets as tfds

# Load the smallNORB training split, turn each example into a
# (float32 image, 10-way one-hot label) pair, and batch two at a time.
orig = tfds.load('smallnorb', split="train")
dataset = orig.map(
    lambda ex: (tf.cast(ex['image'], tf.float32),
                tf.one_hot(ex['label_category'], depth=10)),
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
).batch(2)

# Pull a single batch to use as a probe input for the model cells below.
x, y = next(iter(dataset))

In [None]:
# Validation data must come from a held-out split: this cell previously loaded
# split="train", which made the "validation" set identical to the training set.
# smallNORB ships "train" and "test" splits, so the test split is used here.
valid_dataset = tfds.load('smallnorb', split="test")
# Same preprocessing as the training pipeline: float32 image, 10-way one-hot label.
valid_dataset = valid_dataset.map(map_func=lambda x: (tf.cast(x['image'], tf.float32), tf.one_hot(x['label_category'],depth=10)),
                                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
valid_dataset = valid_dataset.batch(2)

In [None]:
# Build the model and run one forward pass on the probe batch so its
# variables exist before compile()/summary().
model = MatrixCaps(CapsConfig())
outputs = model(x)
print(outputs.shape)

# Labels from the smallNORB pipeline are one-hot, so both the loss and the
# metric must be the non-sparse "Categorical" variants. The metric was
# previously SparseCategoricalAccuracy, which expects integer class indices.
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

model.summary()

In [None]:
# Sanity check: recompute the forward pass, print the prediction and label
# shapes, then evaluate the compiled loss on this one batch.
outputs = model(x)
for checked in (outputs, y):
    tf.print(checked.shape)
model.loss(y, outputs)

In [None]:
# Smoke-test training: fit for one epoch on the single probe batch
# (converted to NumPy arrays), with one log line per epoch.
model.fit(x.numpy(), y.numpy(), epochs=1, verbose=2)

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# The loaded data are NumPy arrays. Add a trailing channel axis to the images
# and cast everything to float32. No flattening or pixel scaling is applied.
x_train, x_test = (imgs[..., None].astype('float32') for imgs in (x_train, x_test))
y_train, y_test = (lbls.astype('float32') for lbls in (y_train, y_test))

# Hold out the last 10,000 training samples for validation.
x_train, x_val = x_train[:-10000], x_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]
tf.print(x_train.shape, y_train.shape)

In [5]:
# Checkpoint location and the task to look up in the TASKS registry.
chkpt_dir, task_name = '../tf_ckpts', 'mnist'

# Create the mirrored distribution strategy, then build the task inside its scope.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    task_cls = TASKS[task_name]
    task = task_cls(get_task_params(batch_size=512), data_dir='../data')

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3', '/job:localhost/replica:0/task:0/device:GPU:4', '/job:localhost/replica:0/task:0/device:GPU:5', '/job:localhost/replica:0/task:0/device:GPU:6', '/job:localhost/replica:0/task:0/device:GPU:7')


In [6]:
# Peek at the first training batch only, to check input/label shapes.
# Note: this rebinds the notebook-level names `x` and `y`.
for x,y in task.train_dataset:
    print(x.shape, y.shape)
    break

(512, 28, 28, 1) (512,)


In [7]:
with strategy.scope():
    # Instantiate the model and push one real batch through it so that its
    # variables are created before compile() and summary().
    model = MatrixCaps(CapsConfig())
    example_x, example_y = next(iter(task.train_dataset))
    outputs = model(example_x, training=True)
    print(outputs.shape)

    # Integer class labels -> sparse cross-entropy and sparse accuracy.
    # NOTE(review): from_logits=True assumes the model emits unnormalised
    # scores -- confirm against MatrixCaps.
    adam = tf.keras.optimizers.Adam(learning_rate=0.003)
    xent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=adam,
                  loss=xent,
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    model.summary()

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


(512, 10)
Model: "matrix_caps"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
batch_normalization (BatchNo multiple                  4         
_________________________________________________________________
conv2d (Conv2D)              multiple                  1664      
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  8320      
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  520       
_________________________________________________________________
conv_caps1 (ConvCaps)        multiple                  18464     
_________________________________________________________________
conv_caps2 (ConvCaps)        multiple                  36896     
_________________________________________________________________
class_caps (FcCaps)          multiple        

In [None]:
with strategy.scope():
    # Train for 10 epochs on the task's datasets, validating after each epoch;
    # step counts come from the task object.
    history = model.fit(
        task.train_dataset,
        validation_data=task.valid_dataset,
        steps_per_epoch=task.n_train_batches,
        validation_steps=task.n_valid_batches,
        epochs=10,
    )
    # Dump the per-epoch loss/metric values recorded by Keras.
    print('\nhistory dict:', history.history)

Epoch 1/10
INFO:tensorflow:batch_all_reduce: 17 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 17 all-reduces with algorithm = nccl, num_packs = 1


Epoch 2/10
Epoch 3/10
 12/117 [==>...........................] - ETA: 5:12 - loss: 1.8137 - sparse_categorical_accuracy: 0.9484