In [1]:
import tensorflow as tf
import numpy as np
from datetime import datetime   # date stamp the log directory
import json  # for saving and loading hyperparameters
import os, sys, re
import time

# Get logger that has already been created in config.py
import daiquiri
logger = daiquiri.getLogger(__name__)

import absl
import absl.logging as logging
gfile = tf.io.gfile
flags = absl.app.flags

In [2]:


# Dummy 'f' flag, needed for absl flags to work when this code runs inside
# Jupyter (see the issue linked below).
# NOTE(review): presumably this absorbs the "-f <connection file>" argument
# that the Jupyter kernel is started with -- confirm.
# https://github.com/tensorflow/tensorflow/issues/17702
flags.DEFINE_string('f', '', 'kernel')


#------------------------------------------------------------------------------
# HYPERPARAMETERS
#------------------------------------------------------------------------------
# The values below are defaults. load_or_save_hyperparams() replaces them with
# the values stored in a params.json file when FLAGS.params_path or 
# FLAGS.load_dir is given, except for flags set explicitly on the command line.

# set to 64 according to authors (https://openreview.net/forum?id=HJWLfGWRb)
flags.DEFINE_integer('batch_size', 64, 'batch size in total across all gpus') 
flags.DEFINE_integer('epoch', 2000, 'epoch')
flags.DEFINE_integer('iter_routing', 3, 'number of iterations')
flags.DEFINE_integer('num_gpus', 1, 'number of GPUs')
flags.DEFINE_float('epsilon', 1e-9, 'epsilon')
flags.DEFINE_float('lrn_rate', 3e-3, 'learning rate to use in Adam optimiser')
flags.DEFINE_float('val_prop', 0.1, 
                   'proportion of test dataset to use for validation')
flags.DEFINE_boolean('weight_reg', False, 
                     'train with regularization of weights')
flags.DEFINE_string('norm', 'norm2', 'norm type')
flags.DEFINE_integer('num_threads', 8, 
                     'number of parallel calls in the input pipeline')
# The dataset name is the key used by the dataset factory functions in this 
# file (get_dataset_path, get_num_classes, ...).
flags.DEFINE_string('dataset', 'smallNORB', 
                    '''dataset name: currently only "smallNORB" supported, feel
                    free to add your own''')
flags.DEFINE_float('final_lambda', 0.01, 'final lambda in EM routing')


#------------------------------------------------------------------------------
# ARCHITECTURE PARAMETERS
#------------------------------------------------------------------------------
# NOTE(review): A, B, C and D presumably follow the layer naming of the 
# "Matrix Capsules with EM Routing" paper -- confirm.
flags.DEFINE_integer('A', 64, 'number of channels in output from ReLU Conv1')
flags.DEFINE_integer('B', 8, 'number of capsules in output from PrimaryCaps')
flags.DEFINE_integer('C', 16, 'number of channels in output from ConvCaps1')
flags.DEFINE_integer('D', 16, 'number of channels in output from ConvCaps2')


#------------------------------------------------------------------------------
# ENVIRONMENT SETTINGS
#------------------------------------------------------------------------------
# 'mode' decides whether load_or_save_hyperparams() writes params.json (only 
# when it is 'train'). 'name' becomes part of the train directory name and 
# 'reset' clears that directory in setup_train_directories().
flags.DEFINE_string('mode', 'train', 'train, validate, or test')
flags.DEFINE_string('name', '', 'name of experiment in log directory')
flags.DEFINE_boolean('reset', False, 'clear the train or test log directory')
flags.DEFINE_string('debugger', None, 
                    '''set to host of TensorBoard debugger e.g. "dccxc180:8886 
                    or dccxl015:8770"''')
flags.DEFINE_boolean('profile', False, 
                     '''get runtime statistics to display inTensorboard e.g. 
                     compute time''')
# When 'load_dir' is set (and 'params_path' is not), hyperparameters are read 
# from <load_dir>/train/params/params.json by load_or_save_hyperparams().
flags.DEFINE_string('load_dir', None, 
                    '''directory containing train or test checkpoints to 
                    continue from''')
flags.DEFINE_string('ckpt_name', None, 
                    '''None to load the latest ckpt; all to load all ckpts in 
                      dir; name to load specific ckpt''')
# 'params_path' takes precedence over 'load_dir' in load_or_save_hyperparams().
flags.DEFINE_string('params_path', None, 'path to JSON containing parameters')

# Root directory under which 'logs/' (setup_train_directories) and 'data/' 
# (get_dataset_path) are looked up.
LOCAL_STORAGE = '../'
flags.DEFINE_string('storage', LOCAL_STORAGE, 
                    'directory where logs and data are stored')
flags.DEFINE_string('db_name', 'capsules_ex1', 
                    'Name of the DB for mongo for sacred')

# Handle to the global flag registry; the flag values are read through it.
# NOTE(review): this line only binds the registry object and does not parse 
# the command line -- confirm that the flags are parsed (e.g. by absl.app.run 
# or FLAGS(sys.argv)) before they are first read.
FLAGS = flags.FLAGS




In [3]:
#------------------------------------------------------------------------------
# DIRECTORIES
#------------------------------------------------------------------------------
def setup_train_directories():
  """Create the training log directory and its summary sub-directory.

  The training directory is built from the flags as
  <storage>/logs/<dataset>/<YYYYMMDD>_<name>/train. If FLAGS.reset is set and 
  the directory already exists, it is deleted (with all its contents) before 
  being recreated.

  Returns:
    train_dir: path of the training directory
    train_summary_dir: path of the 'summary' directory inside train_dir
  """

  # Set log directory
  date_stamp = datetime.now().strftime('%Y%m%d')
  save_dir = os.path.join(FLAGS.storage, 'logs/', FLAGS.dataset)
  train_dir = '{}/{}_{}/train'.format(save_dir, date_stamp, FLAGS.name)

  # Clear the train log directory. train_dir is a directory tree, so it has to
  # be removed recursively; gfile.remove only deletes a single file.
  if FLAGS.reset is True and gfile.exists(train_dir):
    gfile.rmtree(train_dir)

  # Create train directory
  if not gfile.exists(train_dir):
    gfile.makedirs(train_dir)

  # Create summary directory
  train_summary_dir = os.path.join(train_dir, 'summary')
  if not gfile.exists(train_summary_dir):
    gfile.makedirs(train_summary_dir)

  return train_dir, train_summary_dir


#------------------------------------------------------------------------------
# SETUP LOGGER
#------------------------------------------------------------------------------
def setup_logger(logger_dir, name="logger"):
  """Configure daiquiri to log at INFO level to the console and to a file.

  Args:
    logger_dir: directory in which the log file is created
    name: file name of the log file inside logger_dir

  Side effects:
    Sets the TZ environment variable of the process to 'Africa/Johannesburg'
    so that log timestamps use that time zone.
  """
  # The module-level name `logging` is absl.logging, whose INFO constant is
  # absl's own verbosity value (0). The standard library reads 0 as NOTSET, so
  # passing it to daiquiri would also let DEBUG records through. Use the
  # standard library level that daiquiri.setup expects.
  import logging as std_logging

  os.environ['TZ'] = 'Africa/Johannesburg'
  # time.tzset is only available on Unix
  if hasattr(time, 'tzset'):
    time.tzset()

  daiquiri_formatter = daiquiri.formatter.ColorFormatter(
      fmt= "%(asctime)s %(color)s%(levelname)s: %(message)s%(color_stop)s",
      datefmt="%Y-%m-%d %H:%M:%S")
  logger_path = os.path.join(logger_dir, name)
  daiquiri.setup(level=std_logging.INFO, outputs=(
      daiquiri.output.Stream(formatter=daiquiri_formatter),
      daiquiri.output.File(logger_path,formatter=daiquiri_formatter),
     ))
  # To access the logger from other files, just put this line at the top:
  # logger = daiquiri.getLogger(__name__)

  
#------------------------------------------------------------------------------
# LOAD OR SAVE HYPERPARAMETERS
#------------------------------------------------------------------------------
def load_or_save_hyperparams(train_dir=None):
  """Load flag values from a params.json file, or save the current ones.

  If FLAGS.params_path or FLAGS.load_dir is set, the flag values stored in the
  JSON file are copied into FLAGS. Flags that were given explicitly on the 
  command line keep their command line value. Otherwise, if FLAGS.mode is 
  'train', all current flag values are written to 
  <train_dir>/params/params.json.

  Args:
    train_dir: 
      training directory; required when the parameters are saved, unused when 
      they are loaded

  Raises:
    ValueError: if the parameters have to be saved but train_dir is None
  """

  # Load parameters from file
  # params_path is given in the case that run a new training using existing 
  # parameters
  # load_dir is given in the case of testing or continuing training 
  if FLAGS.params_path or FLAGS.load_dir:

    if FLAGS.params_path:
      params_path = os.path.abspath(FLAGS.params_path)
    else:
      params_path = os.path.abspath(
          os.path.join(FLAGS.load_dir, "train", "params", "params.json"))

    with open(params_path, 'r') as params_file:
      params = json.load(params_file)

    # Collect the names of the flags that were set on the command line. 
    # Arguments can look like "--name=value", "--name" or "-name"; anything 
    # that does not start with a dash (e.g. a value given after a space, or 
    # the connection file passed by Jupyter) is skipped instead of crashing.
    specified_flags = set()
    for arg in sys.argv[1:]:
      match = re.match(r'--?([^=]+)', arg)
      if match is None:
        continue
      flag_name = match.group(1)
      specified_flags.add(flag_name)
      # Boolean flags can be negated as "--no<name>"
      if flag_name.startswith('no'):
        specified_flags.add(flag_name[2:])

    for name, value in params.items():
      # ignore flags that were specifically set in command line
      if name in specified_flags:
        continue
      # A file written by another version of the code may contain flags that
      # are not defined here
      if name not in FLAGS:
        logger.warning(
            "Ignoring unknown parameter '{}' in {}".format(name, params_path))
        continue
      FLAGS[name].value = value
    logger.info("Loaded parameters from file: {}".format(params_path))

  # Save parameters to file
  elif FLAGS.mode == 'train': 
    if train_dir is None:
      raise ValueError("train_dir is required to save the parameters")
    params_dir_path = os.path.join(train_dir, "params")
    os.makedirs(params_dir_path, exist_ok=True)
    params_file_path = os.path.join(params_dir_path, "params.json")
    params = FLAGS.flag_values_dict()
    params_json = json.dumps(params, indent=4, separators=(',', ':'))
    with open(params_file_path, 'w') as params_file:
      params_file.write(params_json)
    logger.info("Parameters saved to file: {}".format(params_file_path))


#------------------------------------------------------------------------------
# FACTORIES FOR DATASET
#------------------------------------------------------------------------------
def get_dataset_path(dataset_name: str): 
  """Return the directory that holds the TFRecords of a dataset.

  The path is the dataset's relative location joined onto FLAGS.storage. 
  os.path.join is used so that the result is correct whether or not 
  FLAGS.storage ends with a path separator.

  Args:
    dataset_name: name of the dataset, e.g. 'smallNORB'
  Returns:
    path: path of the TFRecord directory
  Raises:
    KeyError: if dataset_name is not a supported dataset
  """
  options = {'smallNORB': 'data/smallNORB/tfrecord'}
  return os.path.join(FLAGS.storage, options[dataset_name])


def get_dataset_size_train(dataset_name: str):
  """Look up the number of training examples of a dataset.

  Args:
    dataset_name: name of the dataset, must be a key of the table below
  Returns:
    number of training examples
  Raises:
    KeyError: if dataset_name is not in the table
  """
  # NOTE(review): smallNORB is usually described as having 24,300 stereo pairs
  # per split, so 23400 may be a digit transposition -- confirm against the 
  # data pipeline before changing it.
  train_sizes = dict(mnist=55000,
                     smallNORB=23400 * 2,
                     fashion_mnist=55000,
                     cifar10=50000,
                     cifar100=50000)
  return train_sizes[dataset_name]


def get_dataset_size_test(dataset_name: str):
  """Look up the number of test examples of a dataset.

  Args:
    dataset_name: name of the dataset, must be a key of the table below
  Returns:
    number of test examples
  Raises:
    KeyError: if dataset_name is not in the table
  """
  # NOTE(review): smallNORB is usually described as having 24,300 stereo pairs
  # per split, so 23400 may be a digit transposition -- confirm.
  test_sizes = dict(mnist=10000,
                    smallNORB=23400 * 2,
                    fashion_mnist=10000,
                    cifar10=10000,
                    cifar100=10000)
  return test_sizes[dataset_name]


def get_dataset_size_validate(dataset_name: str):
  """Look up the number of validation examples of a dataset.

  Args:
    dataset_name: name of the dataset; only 'smallNORB' is in the table
  Returns:
    number of validation examples
  Raises:
    KeyError: if dataset_name is not in the table
  """
  # NOTE(review): same value as the smallNORB train and test sizes; 23400 may 
  # be a digit transposition of 24300 -- confirm.
  validate_sizes = dict(smallNORB=23400 * 2)
  return validate_sizes[dataset_name]


def get_num_classes(dataset_name: str):
  """Look up the number of classes of a dataset.

  Args:
    dataset_name: name of the dataset, must be a key of the table below
  Returns:
    number of classes
  Raises:
    KeyError: if dataset_name is not in the table
  """
  class_counts = dict(mnist=10,
                      smallNORB=5,
                      fashion_mnist=10,
                      cifar10=10,
                      cifar100=100)
  return class_counts[dataset_name]


# import data_pipeline_norb as data_norb
def get_create_inputs(dataset_name: str, mode="train"):
  """Return a zero-argument function that creates the inputs of a dataset.

  The dataset path is resolved when this function is called; the input 
  pipeline itself is only built when the returned function is called.

  Args:
    dataset_name: name of the dataset, e.g. 'smallNORB'
    mode: 
      'train' builds the training inputs, any other value builds the 
      non-training inputs
  Returns:
    function without arguments that calls create_inputs_norb(path, is_train)
  Raises:
    KeyError: if dataset_name is not a supported dataset
  """
  # NOTE(review): create_inputs_norb is not defined or imported in the visible
  # part of this file -- confirm it is in scope before the result is called.
  is_train = True if mode == "train" else False
  path = get_dataset_path(dataset_name)

  input_factories = {
      'smallNORB': lambda: create_inputs_norb(path, is_train),
  }
  return input_factories[dataset_name]


def get_dataset_architecture(dataset_name: str):
  """Return the function that builds the network for a dataset.

  Args:
    dataset_name: name of the dataset; only 'smallNORB' is supported
  Returns:
    the architecture builder registered for the dataset
  Raises:
    KeyError: if dataset_name is not a supported dataset
  """
  # NOTE(review): build_arch_smallnorb is not defined in the visible part of 
  # this file -- confirm it is in scope when this function is called.
  arch_builders = {'smallNORB': build_arch_smallnorb}
  return arch_builders[dataset_name]

In [4]:
def create_routing_map(child_space, k, s):
  """Build the binary child-to-parent routing map of a convolution.

  The rows of the map are the spatial positions of the capsules in the lower 
  layer (children) and the columns are the spatial positions of the capsules 
  in the higher layer (parents), both flattened in row-major order. An entry 
  is 1 if the child lies inside the kernel window of the parent, and 0 
  otherwise.

  Author:
    Ashley Gritzman 19/10/2018     
  Args: 
    child_space: spatial dimension of lower capsule layer
    k: kernel size
    s: stride    
  Returns:
    binmap: 
      A 2D numpy matrix containing mapping between children capsules along the 
      rows, and parent capsules along the columns.
      (child_space^2, parent_space^2)
      (7*7, 5*5)
  """

  parent_space = int((child_space - k)/s + 1)
  routing = np.zeros((child_space**2, parent_space**2))

  # Parent positions in row-major order, so that the position in this list is 
  # the flat parent index
  parent_cells = [(row, col) 
                  for row in range(parent_space) 
                  for col in range(parent_space)]

  for parent, (row, col) in enumerate(parent_cells):
    # Flat index of the top-left child inside this parent's kernel window
    window_start = row*s*child_space + col*s
    for kernel_row in range(k):
      # Each kernel row covers k consecutive children in one child row
      first_child = window_start + kernel_row*child_space
      routing[first_child:(first_child + k), parent] = 1
  return routing


def kernel_tile(input, kernel, stride):
  """Tile the children poses/activations so that the children for each parent 
  occur in one axis.
  
  Author:
    Ashley Gritzman 19/10/2018
  Args: 
    input: 
      tensor of child poses or activations with a fully defined static shape
      poses (N, child_space, child_space, i, 4, 4) -> (64, 7, 7, 8, 4, 4)
      activations (N, child_space, child_space, i, 1) -> (64, 7, 7, 8, 1) 
    kernel: kernel size, the same along both spatial dimensions
    stride: stride, the same along both spatial dimensions
  Returns:
    tiled: 
      (N, parent_space, parent_space, kh*kw, i, 16 or 1)
      (64, 5, 5, 9, 8, 16 or 1)
    child_parent_matrix:
      A 2D numpy matrix containing mapping between children capsules along the 
      rows, and parent capsules along the columns.
      (child_space^2, parent_space^2)
      (7*7, 5*5)
  """
  
  input_shape = input.get_shape()
  
  # Check that dim 1 and 2 correspond to the spatial size, before the spatial 
  # size is used below
  assert input_shape[1] == input_shape[2]
  
  batch_size   = int(input_shape[0])
  spatial_size = int(input_shape[1])
  n_capsules   = int(input_shape[3])
  parent_spatial_size = int((spatial_size - kernel)/stride + 1)
  
  # Matrix showing which children map to which parent. Children are rows, 
  # parents are columns.
  child_parent_matrix = create_routing_map(spatial_size, kernel, stride)
  
  # Each row contains the children belonging to one parent
  child_to_parent_idx = group_children_by_parent(child_parent_matrix)
  
  # Spread out spatial dimension of children; the capsule and pose/activation
  # dimensions are flattened into the last axis
  flat_children = tf.reshape(
      input, [batch_size, spatial_size*spatial_size, -1])
  
  # Select which children go to each parent capsule
  # (N, child_space^2, F) -> (N, parent_space^2, kh*kw, F)
  tiled = tf.gather(flat_children, child_to_parent_idx, axis=1)
  
  # The target shape is given explicitly, so no squeeze is needed beforehand
  tiled = tf.reshape(
      tiled, 
      [batch_size, parent_spatial_size, parent_spatial_size, kernel*kernel, 
       n_capsules, -1])
  
  return tiled, child_parent_matrix


def compute_votes(poses_i, o, regularizer, tag=False):
  """Compute the votes of the child capsules for the parent capsules.

  Every 4x4 child pose of layer i is multiplied by a learned 4x4 transform 
  matrix, one per (child, parent) pair, which gives the vote of that child for 
  the pose of that parent in layer j.
  
  Author:
    Ashley Gritzman 19/10/2018
    
  Credit: 
    Suofei Zhang's implementation on GitHub, "Matrix-Capsules-EM-Tensorflow"
    https://github.com/www0wwwjs1/Matrix-Capsules-EM-Tensorflow
    
  Args: 
    poses_i: 
      poses in layer i tiled according to the kernel
      (N*OH*OW, kh*kw*i, 16)
      (64*5*5, 9*8, 16) 
    o: number of output capsules, also called "parent_caps"
    regularizer: regularizer passed on to the weight variable 'w'
    tag: not used by this function
    
  Returns:
    votes: 
      (N*OH*OW, kh*kw*i, o, 16)
      (64*5*5, 9*8, 32, 16)
  """
  # NOTE(review): `slim` is not imported in the visible part of this file -- 
  # confirm it is in scope.
  
  poses_shape = poses_i.get_shape()
  n_positions = int(poses_shape[0])  # N*OH*OW, e.g. 64*5*5
  n_children = int(poses_shape[1])   # kh*kw*i, e.g. 9*8
  
  # Unflatten the poses and add an axis for the parent capsules
  # (64*5*5, 9*8, 16) -> (64*5*5, 9*8, 1, 4, 4)
  child_poses = tf.reshape(
      poses_i, shape=[n_positions, n_children, 1, 4, 4])
  
  # The output of a capsule is the mean of a Gaussian and an activation, 
  # neither of which depends on the absolute values of w and the votes. 
  # Weights with a bigger stddev help numerical stability.
  w_init = tf.truncated_normal_initializer(mean=0.0, stddev=1.0)
  w = slim.model_variable(
      'w', 
      shape=[1, n_children, o, 4, 4], 
      dtype=tf.float32, 
      initializer=w_init,
      regularizer=regularizer)
  
  # Share the weights across all batch/spatial positions
  # (1, 9*8, 32, 4, 4) -> (64*5*5, 9*8, 32, 4, 4)
  w_tiled = tf.tile(w, [n_positions, 1, 1, 1, 1])
  
  # Repeat each child pose for every parent capsule
  # (64*5*5, 9*8, 1, 4, 4) -> (64*5*5, 9*8, 32, 4, 4)
  child_poses_tiled = tf.tile(child_poses, [1, 1, o, 1, 1])
  
  # Batched 4x4 matrix product
  # (64*5*5, 9*8, 32, 4, 4) x (64*5*5, 9*8, 32, 4, 4) 
  # -> (64*5*5, 9*8, 32, 4, 4)
  votes_4x4 = tf.matmul(child_poses_tiled, w_tiled)
  
  # Flatten the 4x4 votes
  # (64*5*5, 9*8, 32, 4, 4) -> (64*5*5, 9*8, 32, 16)
  return tf.reshape(votes_4x4, [n_positions, n_children, o, 16])


def group_children_by_parent(bin_routing_map):
  """Groups children capsules by parent capsule.
  
  Rearrange the bin_routing_map so that each row represents one parent 
  capsule, and the entries in the row are the indexes of the children capsules 
  that route to that parent capsule. The mapping is only along the spatial 
  dimension; each spatial position holds many capsules along the depth (e.g. 
  32). For example, if the lower layer is 7x7 in the spatial dimension, with a 
  kernel of 3 and a stride of 1, the higher layer is 5x5, and this function 
  tells which of the 7x7=49 child positions map to each of the 5x5=25 parent 
  positions. A child can belong to several parents: children in the corners 
  belong to one parent only, children towards the centre belong to up to 
  kernel*kernel parents (e.g. 9), depending on the stride.
  
  Author:
    Ashley Gritzman 19/10/2018
  Args: 
    bin_routing_map: 
      binary routing map with children as rows and parents as columns; every 
      parent must have the same number of children
  Returns:
    children_per_parents: 
      parents are rows, and the indexes in the row are which children belong 
      to that parent
  """
  
  # After the transpose the non-zero entries are visited parent by parent, 
  # with the children of each parent in ascending order. The second index 
  # array holds the child index of each entry.
  child_idx = np.nonzero(np.transpose(bin_routing_map))[1]
  n_parents = bin_routing_map.shape[1]
  
  return child_idx.reshape((n_parents, -1))


def init_rr(spatial_routing_matrix, child_caps, parent_caps):
  """Initialise routing weights.
  
  Initialise routing weights taking into account the spatial position of the 
  child capsules. Each child spreads a total weight of 1 evenly over all the 
  parent capsules it is connected to: children in the corners only go to one 
  parent position, while those in the middle can go to kernel*kernel 
  positions.
  
  Author:
    Ashley Gritzman 19/10/2018
    
  Args: 
    spatial_routing_matrix: 
      A 2D numpy matrix containing mapping between children capsules along the 
      rows, and parent capsules along the columns.
      (child_space^2, parent_space^2)
      (7*7, 5*5)
    child_caps: number of child capsules along depth dimension
    parent_caps: number of parent capsules along depth dimension
    
  Returns:
    rr_initial: 
      initial routing weights
      (1, parent_space, parent_space, kk, child_caps, parent_caps)
      (1, 5, 5, 9, 8, 32)
  """

  # Spatial dimensions of the child and parent layers
  child_space = int(np.sqrt(int(spatial_routing_matrix.shape[0])))
  parent_space = int(np.sqrt(int(spatial_routing_matrix.shape[1])))

  # Number of parent positions that each child is connected to
  fan_out = spatial_routing_matrix.sum(axis=1, keepdims=True)

  # Share of each child per connected parent capsule. If the stride causes the 
  # kernel not to fit, some children are "dropped", i.e. they have no parents; 
  # the 1e-9 prevents a division by zero for those.
  share = spatial_routing_matrix / (fan_out * parent_caps + 1e-9)

  # Keep only the connected entries, parent by parent. Both the values and the 
  # mask are transposed so that the entries come out grouped by parent, which 
  # is the order needed for the reshape.
  connected = spatial_routing_matrix.astype(bool)
  per_parent = share.T[connected.T].reshape((parent_space, parent_space, -1))

  # Add the batch axis in front and the two depth axes at the end, then copy 
  # the values across the depth axes
  # (5, 5, 9) -> (1, 5, 5, 9, 1, 1) -> (1, 5, 5, 9, 8, 32)
  rr_initial = np.tile(
      per_parent[np.newaxis, :, :, :, np.newaxis, np.newaxis],
      [1, 1, 1, 1, child_caps, parent_caps])
  
  # Check that the total of the routing weights equals the number of child 
  # capsules: child_space * child_space * child_caps (minus the dropped ones)
  n_dropped = np.sum(spatial_routing_matrix.sum(axis=1) < 1e-9)
  expected_total = (child_space*child_space - n_dropped) * child_caps
  assert np.abs(np.sum(rr_initial) - expected_total) < 1e-3
  
  return rr_initial


def to_sparse(probs, spatial_routing_matrix, sparse_filler=None):
  """Convert probs tensor to sparse along child_space dimension.
  
  Consider a probs tensor of shape (64, 6, 6, 3*3, 32, 16). 
  (batch_size, parent_space, parent_space, kernel*kernel, child_caps, 
  parent_caps)
  The tensor contains the probability of each child capsule belonging to a 
  particular parent capsule. We want to be able to sum the total probabilities 
  of a single child capsule over all the parent capsules. So the 3*3 condensed 
  spatial locations are converted into a sparse format across all child 
  spatial locations, e.g. 14*14. 
  
  Since we are working in log space, the zeros that come about when going to 
  the sparse format must be replaced with an approximation of log(0). The 
  'sparse_filler' option specifies the number used to fill. Note that entries 
  of probs that are exactly 0.0 cannot be told apart from the padding and are 
  replaced by the filler as well.
  
  Author:
    Ashley Gritzman 01/11/2018
    
  Args: 
    probs: 
      tensor of log probabilities of each child capsule belonging to a 
      particular parent capsule
      (batch_size, parent_space, parent_space, kernel*kernel, child_caps, 
      parent_caps)
      (64, 5, 5, 3*3, 32, 16)
    spatial_routing_matrix: 
      binary routing map with children as rows and parents as columns
    sparse_filler: 
      the number to use to fill in the sparse locations instead of zero; 
      None (the default) means log(1e-20). The default is resolved inside the 
      call rather than in the signature, so that no TensorFlow op is created 
      when the function is defined.
      
  Returns:
    sparse: 
      the sparse representation of the probs tensor in log space
      (batch_size, parent_space, parent_space, child_space*child_space, 
      child_caps, parent_caps)
      (64, 5, 5, 7*7, 32, 16)
  """
  
  if sparse_filler is None:
    sparse_filler = float(np.log(1e-20))
  
  # Get shapes of probs
  shape = probs.get_shape().as_list()
  batch_size = shape[0]
  parent_space = shape[1]
  kk = shape[3]
  child_caps = shape[4]
  parent_caps = shape[5]
  
  # Get spatial dimension of child and parent capsules (squared)
  child_space_2 = int(spatial_routing_matrix.shape[0])
  parent_space_2 = int(spatial_routing_matrix.shape[1])
  
  # Unroll the probs along the spatial dimension
  # e.g. (64, 6, 6, 3*3, 8, 32) -> (64, 6*6, 3*3, 8, 32)
  probs_unroll = tf.reshape(
      probs, 
      [batch_size, parent_space_2, kk, child_caps, parent_caps])
  
  # Each row contains the children belonging to one parent
  child_to_parent_idx = group_children_by_parent(spatial_routing_matrix)

  # Create an index mapping each capsule to the correct sparse location
  # Each element of the index must contain [batch_position, 
  # parent_space_position, child_sparse_position]
  # E.g. [63, 24, 48] maps image 63, parent space 24, sparse position 48
  index_shape = (batch_size, parent_space_2, kk)
  batch_idx = np.broadcast_to(
      np.arange(batch_size).reshape(-1, 1, 1), index_shape)
  parent_idx = np.broadcast_to(
      np.arange(parent_space_2).reshape(1, -1, 1), index_shape)
  child_sparse_idx = np.broadcast_to(
      child_to_parent_idx[np.newaxis, ...], index_shape)

  # Combine the 3 coordinates
  indices = np.stack((batch_idx, parent_idx, child_sparse_idx), axis=3)
  indices = tf.constant(indices)

  # Convert each spatial location to sparse
  sparse_shape = [batch_size, parent_space_2, child_space_2, child_caps, 
                  parent_caps]
  sparse = tf.scatter_nd(indices, probs_unroll, sparse_shape)
  
  # scatter_nd pads the output with zeros, but since we are operating
  # in log space, we need to replace 0 with the filler that stands for log(0)
  zeros_in_log = tf.ones_like(sparse, dtype=tf.float32) * sparse_filler
  sparse = tf.where(tf.equal(sparse, 0.0), zeros_in_log, sparse)
  
  # Reshape
  # (64, 5*5, 7*7, 8, 32) -> (64, 5, 5, 7*7, 8, 32)
  final_shape = [batch_size, parent_space, parent_space, child_space_2, 
                 child_caps, parent_caps]
  sparse = tf.reshape(sparse, final_shape)
  
  # Check the shape. The totals of dense and sparse are not compared, because 
  # they no longer match once the zeros have been replaced by the filler.
  assert sparse.get_shape().as_list() == final_shape
  
  return sparse
  
  
def normalise_across_parents(probs_sparse, spatial_routing_matrix):
  """Normalise across all parent capsules including spatial and depth.
  
  Consider a sparse matrix of probabilities (1, 5, 5, 49, 8, 32)  
  (batch_size, parent_space, parent_space, child_space*child_space, 
  child_caps, parent_caps)  
  For one child capsule, we need to normalise across all parent capsules that 
  receive output from that child. This includes the depth of the parent 
  capsules, and the spatial dimensions of the parent capsules. In the example 
  matrix of probabilities above this means normalising across axes [1, 2, 5], 
  i.e. [parent_space, parent_space, parent_caps]. 
  
  Author:
    Ashley Gritzman 05/11/2018
  Args: 
    probs_sparse: 
      the sparse representation of the probs matrix, not in log
      (batch_size, parent_space, parent_space, child_space*child_space, 
      child_caps, parent_caps) 
      (64, 5, 5, 49, 8, 32)
    spatial_routing_matrix:
      binary routing map with children as rows and parents as columns; not 
      used by the computation, kept so that the signature stays the same
             
  Returns:
    rr_updated: 
      probs_sparse divided by its sum over all parent capsules, same shape as 
      the input
      (batch_size, parent_space, parent_space, child_space*child_space, 
      child_caps, parent_caps) 
      (64, 5, 5, 49, 8, 32)
  """
  
  # e.g. (1, 5, 5, 49, 8, 32)
  # (batch_size, parent_space, parent_space, child_space*child_space, 
  # child_caps, parent_caps) 
  shape = probs_sparse.get_shape().as_list()
  batch_size = shape[0]
  parent_space = shape[1]
  child_space_2 = shape[3]  # squared
  child_caps = shape[4]
  parent_caps = shape[5]
  
  # Total over all parents of each child; the 1e-9 avoids a division by zero 
  # for children whose probabilities are all zero
  total_per_child = tf.reduce_sum(probs_sparse, axis=[1,2,5], keepdims=True)
  rr_updated = probs_sparse/(total_per_child + 1e-9)
  
  # Check the shape
  assert (rr_updated.get_shape().as_list() 
          == [batch_size, parent_space, parent_space, child_space_2, 
              child_caps, parent_caps])
  
  # Per image, the total of the routing weights should be close to the number 
  # of child capsules minus the dropped ones. A runtime assertion on this was 
  # disabled earlier, so the graph ops that only fed that assertion have been 
  # removed as well.
  
  return rr_updated


def softmax_across_parents(probs_sparse, spatial_routing_matrix):
  """Softmax across all parent capsules including spatial and depth.
  
  Consider a sparse matrix of probabilities (1, 5, 5, 49, 8, 32)  
  (batch_size, parent_space, parent_space, child_space*child_space, 
  child_caps, parent_caps)  
  For one child capsule, we need to normalise across all parent capsules that 
  receive output from that child. This includes the depth of the parent 
  capsules, and the spatial dimensions of the parent capsules. In the example 
  matrix of probabilities above this means normalising across axes [1, 2, 5], 
  i.e. [parent_space, parent_space, parent_caps]. The softmax is applied along 
  a single axis here, so the tensor is reshaped such that parent_space and 
  parent_caps are combined into one axis. 
  
  Author:
    Ashley Gritzman 05/11/2018
    
  Args: 
    probs_sparse: 
      the sparse representation of the probs matrix, in log
      (batch_size, parent_space, parent_space, child_space*child_space, 
      child_caps, parent_caps) 
      (64, 5, 5, 49, 8, 32)
    spatial_routing_matrix:
      binary routing map with children as rows and parents as columns; not 
      used by the computation, kept so that the signature stays the same
             
  Returns:
    rr_updated: 
      softmax across all parent capsules, same shape as input
      (batch_size, parent_space, parent_space, child_space*child_space, 
      child_caps, parent_caps) 
      (64, 5, 5, 49, 8, 32)
  """
  
  # e.g. (1, 5, 5, 49, 8, 32)
  # (batch_size, parent_space, parent_space, child_space*child_space, 
  # child_caps, parent_caps) 
  shape = probs_sparse.get_shape().as_list()
  batch_size = shape[0]
  parent_space = shape[1]
  child_space_2 = shape[3]  # squared
  child_caps = shape[4]
  parent_caps = shape[5]
  
  # Move parent space dimensions, and parent depth dimension to end
  # (1, 5, 5, 49, 8, 32) -> (1, 49, 8, 5, 5, 32)
  sparse = tf.transpose(probs_sparse, perm=[0,3,4,1,2,5])
  
  # Combine parent space and depth into one axis
  # (1, 49, 8, 5, 5, 32) -> (1, 49, 8, 5*5*32)
  sparse = tf.reshape(sparse, [batch_size, child_space_2, child_caps, -1])
  
  # Perform softmax across parent capsule dimension
  parent_softmax = tf.nn.softmax(sparse, axis=-1)
  
  # Uncombine parent space and depth
  # (1, 49, 8, 5*5*32) -> (1, 49, 8, 5, 5, 32)
  parent_softmax = tf.reshape(
    parent_softmax, 
    [batch_size, child_space_2, child_caps, parent_space, parent_space, 
     parent_caps])
  
  # Return to original order; this permutation is its own inverse
  # (1, 49, 8, 5, 5, 32) -> (1, 5, 5, 49, 8, 32)
  # Softmax across the parent capsules actually gives us the updated routing 
  # weights
  rr_updated = tf.transpose(parent_softmax, perm=[0,3,4,1,2,5])
  
  # Check the shape
  assert (rr_updated.get_shape().as_list() 
          == [batch_size, parent_space, parent_space, child_space_2, 
              child_caps, parent_caps])
  
  # Note: during convolution some child capsules may be dropped if the 
  # convolution doesn't fit nicely. The softmax also gives these capsules 
  # weights, so the total of the routing weights is 
  # child_space_2 * child_caps * batch_size. The dropped capsules are excluded 
  # again when converting back to dense. A runtime assertion on this total was 
  # disabled earlier, so the graph ops that only fed that assertion have been 
  # removed as well.
  
  return rr_updated   


def to_dense(sparse, spatial_routing_matrix):
  """Convert sparse back to dense along child_space dimension.
  
  Consider a sparse probs tensor of shape (64, 5, 5, 49, 8, 32).
  (batch_size, parent_space, parent_space, child_space*child_space, child_caps,
  parent_caps) 
  The tensor holds every child position at every parent position, including 
  the children that do not route to that parent. Only the children that are 
  connected to each parent are kept, which gives the dense representation:
  (64, 5, 5, 49, 8, 32) -> (64, 5, 5, 9, 8, 32)
  
  Author:
    Ashley Gritzman 05/11/2018
  Args: 
    sparse: 
      the sparse representation of the probs tensor
      (batch_size, parent_space, parent_space, child_space*child_space, 
      child_caps, parent_caps) 
      (64, 5, 5, 49, 8, 32)
    spatial_routing_matrix: 
      binary routing map with children as rows and parents as columns
      
  Returns:
    dense: 
      the dense representation of the probs tensor
      (batch_size, parent_space, parent_space, kk, child_caps, parent_caps) 
      (64, 5, 5, 9, 8, 32)
  """
  
  # Static shape of the sparse tensor
  sparse_shape = sparse.get_shape().as_list()
  batch_size, parent_space = sparse_shape[0], sparse_shape[1]
  child_space_2 = sparse_shape[3]  # squared
  child_caps, parent_caps = sparse_shape[4], sparse_shape[5]
  
  # Number of children per parent (kernel*kernel), counted from the first 
  # column of the spatial routing matrix
  kk = int(np.sum(spatial_routing_matrix[:,0]))
  dense_shape = [batch_size, parent_space, parent_space, kk, child_caps, 
                 parent_caps]
  
  # Unroll parent spatial dimensions
  # (64, 5, 5, 49, 8, 32) -> (64, 5*5, 49, 8, 32)
  flat_parents = tf.reshape(
      sparse, 
      [batch_size, parent_space*parent_space, child_space_2, child_caps, 
       parent_caps])
  
  # Mask with parents as rows and children as columns, matching axes 1 and 2 
  # of the unrolled tensor
  # (49, 25) -> (25, 49)
  parent_child_mask = tf.transpose(spatial_routing_matrix)
  
  # Keep the connected (parent, child) pairs only
  # (64, 5*5, 49, 8, 32) -> (64, 5*5*9, 8, 32)
  connected = tf.boolean_mask(flat_parents, parent_child_mask, axis=1)
  
  # Split the parent positions and the children per parent again
  # (64, 5*5*9, 8, 32) -> (64, 5, 5, 9, 8, 32)
  dense = tf.reshape(connected, dense_shape)    
  
  # Check the shape
  assert dense.get_shape().as_list() == dense_shape
      
  return dense  


def logits_one_vs_rest(logits, positive_class=0):
  """Collapse multi-class logits into a two-column "one vs rest" tensor.

  Used to prepare the logits of a multi-class classifier for binary
  classification: column 0 of the result holds the logit of the positive
  class, and column 1 holds the largest logit among all the other classes.

  Author:
    Ashley Gritzman 04/12/2018
  Args:
    logits:
      logits from multiple classes, indexed as [sample, class]
    positive_class:
      column index of the positive class. The remaining classes are gathered
      with the slices [:positive_class] and [positive_class+1:], so a negative
      index would not exclude the positive column correctly.
  Returns:
    one_vs_rest:
      logits of the positive class in column 0, and the maximum logit of the
      other classes in column 1
  """

  # Logit of the positive class as a column vector
  positive = tf.reshape(logits[:, positive_class], [-1, 1])

  # All columns to the left and to the right of the positive class
  left_of_positive = logits[:, :positive_class]
  right_of_positive = logits[:, (positive_class + 1):]
  others = tf.concat([left_of_positive, right_of_positive], axis=1)

  # Strongest competitor per sample, kept as a column so it can be concatenated
  best_other = tf.reduce_max(others, axis=1, keepdims=True)

  return tf.concat([positive, best_other], axis=1)

In [5]:
def em_routing(votes_ij, activations_i, batch_size, spatial_routing_matrix):
  """Run EM routing from the child capsules (i) to the parent capsules (j).

  Alternates m_step and e_step for FLAGS.iter_routing iterations. The last
  iteration stops after its m_step, because only the parent activations and
  means produced by that m_step are returned. FLAGS.iter_routing must be at
  least 1, otherwise no m_step runs and there is nothing to return.

  See Hinton et al. "Matrix Capsules with EM Routing" for detailed description
  of EM routing.

  Author:
    Ashley Gritzman 19/10/2018
  Definitions:
    N -> number of samples in batch
    OH -> output height
    OW -> output width
    kh -> kernel height
    kw -> kernel width
    kk -> kh * kw
    i -> number of input capsules, also called "child_caps"
    o -> number of output capsules, also called "parent_caps"
    child_space -> spatial dimensions of input capsule layer i
    parent_space -> spatial dimensions of output capsule layer j
    n_channels -> number of channels in pose matrix (usually 4x4=16)
  Args:
    votes_ij:
      votes from capsules in layer i to capsules in layer j
      For conv layer:
        (N*OH*OW, kh*kw*i, o, 4x4)
        (64*6*6, 9*8, 32, 16)
      For FC layer:
        The kernel dimensions are equal to the spatial dimensions of the input
        layer i, and the spatial dimensions of the output layer j are 1x1.
        (N*1*1, child_space*child_space*i, o, 4x4)
        (64, 4*4*16, 5, 16)
    activations_i:
      activations of capsules in layer i (L)
      (N*OH*OW, kh*kw*i, 1)
      (64*6*6, 9*8, 1)
    batch_size:
      number of samples N; used to recover OH and OW from the first dimension
      of votes_ij
    spatial_routing_matrix:
      routing map between child and parent positions, read here through
      numpy-style [:, 0] indexing and .shape. Its shape is taken to be
      (child_space^2, parent_space^2) and the sum of its first column is used
      as the kernel size kk.
  Returns:
    poses_j:
      poses of capsules in layer j (L+1)
      (N, OH, OW, o, 4x4)
      (64, 6, 6, 32, 16)
    activations_j:
      activations of capsules in layer j (L+1)
      (N, OH, OW, o, 1)
      (64, 6, 6, 32, 1)
  """

  #----- Dimensions -----#

  # Static shape of the votes: (N*OH*OW, kh*kw*i, o, n_channels)
  votes_shape = votes_ij.get_shape().as_list()
  N = batch_size

  # The parent grid is treated as square: OH = OW = sqrt(rows / N)
  OH = int(np.sqrt(int(votes_shape[0]) / N))
  OW = int(np.sqrt(int(votes_shape[0]) / N))
  kh_kw_i, o, n_channels = (int(dim) for dim in votes_shape[1:4])

  # Kernel size is the number of ones in a column of the routing matrix.
  # Computed here while the matrix is still a plain array rather than a tensor.
  kk = int(np.sum(spatial_routing_matrix[:, 0]))
  parent_caps = o
  child_caps = int(kh_kw_i / kk)

  # Columns of the routing matrix enumerate the parent positions
  parent_space = int(np.sqrt(spatial_routing_matrix.shape[1]))


  #----- Reshape Inputs -----#

  # conv: (N*OH*OW, kh*kw*i, o, 4x4) -> (N, OH, OW, kh*kw*i, o, 4x4)
  # FC: (N, child_space*child_space*i, o, 4x4)
  #     -> (N, 1, 1, child_space*child_space*i, output_classes, 4x4)
  votes_ij = tf.reshape(votes_ij, [N, OH, OW, kh_kw_i, o, n_channels])

  # (N*OH*OW, kh*kw*i, 1) -> (N, OH, OW, kh*kw*i, 1, 1), e.g.
  # (24, 6, 6, 288, 1, 1); the two trailing singleton axes broadcast against
  # the o and n_channels axes of the votes
  activations_i = tf.reshape(activations_i, [N, OH, OW, kh_kw_i, 1, 1])


  #----- Betas -----#

  # Earlier experiments, kept for reference only (none of this is executed):
  #
  # Initialization from Jonathan Hui [1]:
  #   beta_v_hui = tf.get_variable(
  #     name='beta_v', shape=[1, 1, 1, o], dtype=tf.float32,
  #     initializer=tf.contrib.layers.xavier_initializer())
  #   beta_a_hui = tf.get_variable(
  #     name='beta_a', shape=[1, 1, 1, o], dtype=tf.float32,
  #     initializer=tf.contrib.layers.xavier_initializer())
  #
  # AG 21/11/2018:
  # Tried to find std according to Hinton's comments on OpenReview
  # https://openreview.net/forum?id=HJWLfGWRb&noteId=r1lQjCAChm
  # Hinton: "We used truncated_normal_initializer and set the std so that at
  # the start of training half of the capsules in each layer are active and
  # half inactive (for the Primary Capsule layer where the activation is not
  # computed through routing we use different std for activation convolution
  # weights & for pose parameter convolution weights)."
  #
  # std beta_v seems to control the spread of activations. To try and achieve
  # what Hinton said about half active and half not active, I changed the std
  # values and checked the histogram/distributions in Tensorboard to try and
  # get a good spread across all values. I couldn't get this working nicely.
  #   beta_v_hui = slim.model_variable(
  #     name='beta_v', shape=[1, 1, 1, 1, o, 1], dtype=tf.float32,
  #     initializer=tf.truncated_normal_initializer(mean=0.0, stddev=10.0))
  #
  # A xavier-initialised beta_a (regularizer=None) was also tried in place of
  # the truncated normal used below.

  # AG 04/10/2018: using slim.variable to create instead of tf.get_variable so
  # that they get correctly placed on the CPU instead of GPU in the multi-gpu
  # version.
  # One beta per output capsule type: (1, 1, 1, 1, 32, 1), which broadcasts
  # over (N, OH, OW, i, o, n_channels)
  beta_shape = [1, 1, 1, 1, o, 1]
  beta_a = slim.model_variable(
      name='beta_a',
      shape=beta_shape,
      dtype=tf.float32,
      initializer=tf.truncated_normal_initializer(mean=-1000.0, stddev=500.0))
  beta_v = slim.model_variable(
      name='beta_v',
      shape=beta_shape,
      dtype=tf.float32,
      initializer=tf.contrib.layers.xavier_initializer(),
      regularizer=None)

  with tf.variable_scope("em_routing"):
    # Initial routing assignments
    # (1, parent_space, parent_space, kk, child_caps, parent_caps)
    # e.g. (1, 6, 6, 9, 8, 16)
    rr_init = utl.init_rr(spatial_routing_matrix, child_caps, parent_caps)

    # Merge the kernel and child-capsule axes and add a channel axis:
    # (1, 6, 6, 9, 8, 16) -> (1, 6, 6, 9*8, 16, 1)
    rr_init = np.reshape(
        rr_init,
        [1, parent_space, parent_space, kk*child_caps, parent_caps, 1])

    # From numpy to a TensorFlow constant
    rr = tf.constant(rr_init, dtype=tf.float32)

    num_iterations = FLAGS.iter_routing
    for it in range(num_iterations):
      # AG 17/09/2018: modified schedule for inverse_temperature (lambda) based
      # on Hinton's response to questions on OpenReview.net:
      # https://openreview.net/forum?id=HJWLfGWRb
      # "the formula we used for lambda is:
      # lambda = final_lambda * (1 - tf.pow(0.95, tf.cast(i + 1, tf.float32)))
      # where 'i' is the routing iteration (range is 0-2). Final_lambda is set
      # to 0.01."
      decay = tf.pow(0.95, tf.cast(it + 1, tf.float32))
      inverse_temperature = FLAGS.final_lambda * (1 - decay)

      # AG 26/06/2018: added var_j
      activations_j, mean_j, stdv_j, var_j = m_step(
          rr,
          votes_ij,
          activations_i,
          beta_v, beta_a,
          inverse_temperature=inverse_temperature)

      # The e-step only refreshes rr for the next m-step, so it is pointless
      # after the final m-step: the outputs come from that m-step alone.
      is_last_iteration = (it == num_iterations - 1)
      if not is_last_iteration:
        rr = e_step(votes_ij,
                    activations_j,
                    mean_j,
                    stdv_j,
                    var_j,
                    spatial_routing_matrix)

    # Drop the reduced child axis of mean_j:
    # (N, OH, OW, 1, o, 4x4) -> (N, OH, OW, o, 4x4), e.g. (24, 6, 6, 32, 16)
    poses_j = tf.squeeze(mean_j, axis=-3, name="poses")

    # Same for the activations:
    # (N, OH, OW, 1, o, 1) -> (N, OH, OW, o, 1), e.g. (24, 6, 6, 32, 1)
    activations_j = tf.squeeze(activations_j, axis=-3, name="activations")

  return poses_j, activations_j


def m_step(rr, votes, activations_i, beta_v, beta_a, inverse_temperature):
  """The m-step of EM routing between input capsules (i) and output capsules
  (j).

  Computes, for every output capsule (j), the weighted mean and variance of
  the votes it receives (the Gaussian for its pose) and from those its
  activation.
  See Hinton et al. "Matrix Capsules with EM Routing" for detailed description
  of m-step.

  Author:
    Ashley Gritzman 19/10/2018

  Args:
    rr:
      assignment weights between capsules in layer i and layer j
      (N, OH, OW, kh*kw*i, o, 1)
      (64, 6, 6, 9*8, 16, 1)
    votes:
      votes from capsules in layer i to capsules in layer j
      For conv layer:
        (N, OH, OW, kh*kw*i, o, 4x4)
        (64, 6, 6, 9*8, 32, 16)
      For FC layer:
        The kernel dimensions are equal to the spatial dimensions of the input
        layer i, and the spatial dimensions of the output layer j are 1x1.
        (N, 1, 1, child_space*child_space*i, output_classes, 4x4)
        (64, 1, 1, 4*4*16, 5, 16)
    activations_i:
      activations of capsules in layer i (L), broadcast against rr
      (24, 6, 6, 288, 1, 1)
    beta_v:
      Trainable parameters in computing cost
      (1, 1, 1, 1, 32, 1)
    beta_a:
      Trainable parameters in computing next level activation
      (1, 1, 1, 1, 32, 1)
    inverse_temperature: lambda, increase over each iteration by the caller

  Returns:
    activations_j:
      activations of capsules in layer j (L+1)
      (N, OH, OW, 1, o, 1)
      (64, 6, 6, 1, 32, 1)
    mean_j:
      mean of each channel in capsules of layer j (L+1)
      (N, OH, OW, 1, o, n_channels)
      (24, 6, 6, 1, 32, 16)
    stdv_j:
      always None. The standard deviation is no longer computed (its sqrt
      seemed to cause NaN gradients); the slot is kept in the return tuple.
    var_j:
      variance of each channel in capsules of layer j (L+1), with 1e-9 added
      (N, OH, OW, 1, o, n_channels)
      (24, 6, 6, 1, 32, 16)
  """

  with tf.variable_scope("m_step"):

    # Assignment weights scaled by how active each child capsule is
    weighted_rr = rr * activations_i
    rr_prime = tf.identity(weighted_rr, name="rr_prime")

    # Total weight each parent receives: sum over the child axis (-3)
    rr_prime_sum = tf.reduce_sum(
        rr_prime, axis=-3, keepdims=True, name='rr_prime_sum')

    # AG 13/12/2018: normalise amount of information
    # The amount of information given to parent capsules is very different for
    # the final "class-caps" layer. Since all the spatial capsules give output
    # to just a few class caps, they receive a lot more information than the
    # convolutional layers. So in order for lambda and beta_v/beta_a settings
    # to apply to this layer, we must normalise the amount of information.
    # activ from convcaps1 to convcaps2 (64*5*5, 144, 16, 1) 144/16 = 9 info
    # (N*OH*OW, kh*kw*i, o, 1)
    # activ from convcaps2 to classcaps (64, 1, 1, 400, 5, 1) 400/5 = 80 info
    # (N, 1, 1, IH*IW*i, n_classes, 1)
    rr_prime_dims = rr_prime.get_shape().as_list()
    child_caps = float(rr_prime_dims[-3])
    parent_caps = float(rr_prime_dims[-2])
    layer_norm_factor = 100 / (child_caps / parent_caps)
    # Dividing rr_prime_sum by the child/parent ratio was tried instead and is
    # not used; the factor is applied to the cost below.

    # mean_j: (24, 6, 6, 1, 32, 16), weighted average of the votes
    weighted_votes = rr_prime * votes
    mean_j_numerator = tf.reduce_sum(
        weighted_votes, axis=-3, keepdims=True, name="mean_j_numerator")
    mean_j_denominator = rr_prime_sum + FLAGS.epsilon
    mean_j = tf.div(mean_j_numerator, mean_j_denominator, name="mean_j")

    #----- AG 26/06/2018 START -----#
    # Use variance instead of standard deviation, because the sqrt seems to
    # cause NaN gradients during backprop.
    # Suofei's original (reference [3] at top), not executed:
    #   stdv_j = tf.sqrt(
    #     tf.reduce_sum(
    #       rr_prime * tf.square(votes - mean_j), axis=-3, keepdims=True
    #     ) / rr_prime_sum,
    #     name="stdv_j")
    #   # cost_j_h: (24, 6, 6, 1, 32, 16)
    #   cost_j_h = (beta_v + tf.log(stdv_j + FLAGS.epsilon)) * rr_prime_sum
    squared_deviation = tf.square(votes - mean_j)
    var_j_numerator = tf.reduce_sum(
        rr_prime * squared_deviation,
        axis=-3,
        keepdims=True,
        name="var_j_numerator")
    var_j_denominator = rr_prime_sum + FLAGS.epsilon
    var_j = tf.div(var_j_numerator, var_j_denominator, name="var_j")

    # Keep the variance strictly positive so the log below (and the division
    # in the e-step) stay finite. tf.maximum(var_j, FLAGS.epsilon),
    # var_j + FLAGS.epsilon and var_j + 1e-5 were tried as alternatives.
    var_j = tf.identity(var_j + 1e-9, name="var_j_epsilon")

    # The standard deviation is deliberately not computed any more
    stdv_j = None

    # Per-channel cost, scaled by layer_norm_factor
    half_log_var = 0.5 * tf.math.log(var_j)
    cost_j_h = (beta_v + half_log_var) * rr_prime_sum * layer_norm_factor
    cost_j_h = tf.identity(cost_j_h, name="cost_j_h")
    # ----- END ----- #

    # cost_j: (24, 6, 6, 1, 32, 1), total cost over the pose channels
    cost_j = tf.reduce_sum(cost_j_h, axis=-1, keepdims=True, name="cost_j")

    # AG 17/09/2018: trying to remove normalisation. The previous form
    # standardised the cost across parents for numeric stability (yg: "It is
    # the relative variance between each channel determined which one should
    # activate."):
    #   cost_j_mean = tf.reduce_mean(cost_j, axis=-2, keepdims=True)
    #   cost_j_stdv = tf.sqrt(
    #     tf.reduce_sum(
    #       tf.square(cost_j - cost_j_mean), axis=-2, keepdims=True
    #     ) / cost_j.get_shape().as_list()[-2])
    #   activations_j_cost = beta_a + (cost_j_mean - cost_j) / (cost_j_stdv)
    # activations_j_cost: (24, 6, 6, 1, 32, 1)
    activations_j_cost = tf.identity(
        beta_a - cost_j, name="activations_j_cost")

    # (24, 6, 6, 1, 32, 1)
    activations_j = tf.sigmoid(
        inverse_temperature * activations_j_cost, name="sigmoid")

  # AG 26/06/2018: added var_j to return
  return activations_j, mean_j, stdv_j, var_j

  
# AG 26/06/2018: added var_j
def e_step(votes_ij, activations_j, mean_j, stdv_j, var_j, spatial_routing_matrix):
  """The e-step of EM routing between input capsules (i) and output capsules
  (j).

  Updates the assignment weights used in routing. The output capsules (j)
  compete for the input capsules (i): each vote is scored under the Gaussian
  of every parent, weighted by the parent's activation, and the scores are
  normalised across parents.
  See Hinton et al. "Matrix Capsules with EM Routing" for detailed description
  of e-step.

  Author:
    Ashley Gritzman 19/10/2018

  Args:
    votes_ij:
      votes from capsules in layer i to capsules in layer j
      For conv layer:
        (N, OH, OW, kh*kw*i, o, 4x4)
        (64, 6, 6, 9*8, 32, 16)
      For FC layer:
        The kernel dimensions are equal to the spatial dimensions of the input
        layer i, and the spatial dimensions of the output layer j are 1x1.
        (N, 1, 1, child_space*child_space*i, output_classes, 4x4)
        (64, 1, 1, 4*4*16, 5, 16)
    activations_j:
      activations of capsules in layer j (L+1)
      (N, OH, OW, 1, o, 1)
      (64, 6, 6, 1, 32, 1)
    mean_j:
      mean of each channel in capsules of layer j (L+1)
      (N, OH, OW, 1, o, n_channels)
      (24, 6, 6, 1, 32, 16)
    stdv_j:
      not used by this function (AG 26/06/2018: replaced by var_j); kept in
      the signature for existing callers
    var_j:
      variance of each channel in capsules of layer j (L+1)
      (N, OH, OW, 1, o, n_channels)
      (24, 6, 6, 1, 32, 16)
    spatial_routing_matrix:
      routing map between child and parent positions. The sum of its first
      column gives the kernel size kk, and it is passed on to utl.to_sparse,
      utl.softmax_across_parents and utl.to_dense.

  Returns:
    rr:
      assignment weights between capsules in layer i and layer j
      (N, OH, OW, kh*kw*i, o, 1)
      (64, 6, 6, 9*8, 16, 1)
  """

  with tf.variable_scope("e_step"):

    # AG 26/06/2018: changed stdv_j to var_j
    # Exponent of the Gaussian, summed over the pose channels
    squared_diff = tf.square(votes_ij - mean_j, name="num")
    o_p_unit0 = - tf.reduce_sum(
        squared_diff / (2 * var_j),
        axis=-1,
        keepdims=True,
        name="o_p_unit0")

    # Log of the Gaussian normalisation constant, summed over the channels
    log_norm = tf.math.log(2*np.pi * var_j)
    o_p_unit2 = -0.5 * tf.reduce_sum(
        log_norm,
        axis=-1,
        keepdims=True,
        name="o_p_unit2")

    # Log probability of each vote under each parent: (24, 6, 6, 288, 32, 1)
    o_p = o_p_unit0 + o_p_unit2
    log_activations = tf.math.log(activations_j + FLAGS.epsilon)
    zz = log_activations + o_p

    # AG 13/11/2018: New implementation of normalising across parents
    #----- Start -----#
    batch_size, parent_space, _, kh_kw_i, parent_caps = (
        zz.get_shape().as_list()[:5])
    kk = int(np.sum(spatial_routing_matrix[:, 0]))
    child_caps = int(kh_kw_i / kk)

    # Split kh*kw*i back into (kk, child_caps) and drop the trailing axis
    zz = tf.reshape(zz, [batch_size, parent_space, parent_space, kk,
                         child_caps, parent_caps])

    # An "un-log space" variant was tried and is not executed:
    #   with tf.variable_scope("to_sparse_unlog") as scope:
    #     zz_unlog = tf.exp(zz)
    #     # sparse_filler=1e-15 was also tried here
    #     zz_sparse_unlog = utl.to_sparse(
    #         zz_unlog, spatial_routing_matrix, sparse_filler=0.0)
    #     # maybe this value should be even lower 1e-15
    #     zz_sparse_log = tf.log(zz_sparse_unlog + 1e-15)
    #     zz_sparse = zz_sparse_log

    # In log space
    with tf.variable_scope("to_sparse_log"):
      # Fill the sparse matrix with the smallest value in zz (at least -100);
      # a fixed filler of -100 was tried as an alternative
      lowest_logit = tf.reduce_min(zz)
      sparse_filler = tf.minimum(lowest_logit, -100)
      zz_sparse = utl.to_sparse(
          zz,
          spatial_routing_matrix,
          sparse_filler=sparse_filler)

    with tf.variable_scope("softmax_across_parents"):
      rr_sparse = utl.softmax_across_parents(zz_sparse, spatial_routing_matrix)

    with tf.variable_scope("to_dense"):
      rr_dense = utl.to_dense(rr_sparse, spatial_routing_matrix)

    # Back to the layout expected by m_step: (N, OH, OW, kh*kw*i, o, 1)
    rr = tf.reshape(
        rr_dense,
        [batch_size, parent_space, parent_space, kh_kw_i, parent_caps, 1])
    #----- End -----#

    # AG 02/11/2018
    # In response to a question on OpenReview, Hinton et al. wrote the
    # following:
    # "The gradient flows through EM algorithm. We do not use stop gradient. A
    # routing of 3 is like a 3 layer network where the weights of layers are
    # shared."
    # https://openreview.net/forum?id=HJWLfGWRb&noteId=S1eo2P1I3Q

  return rr

In [6]:
def conv_caps(activation_in,
              pose_in,
              kernel,
              stride,
              ncaps_out,
              name='conv_caps',
              weights_regularizer=None):
  """Convolutional capsule layer.

  Tiles the child poses and activations into kernel-sized windows, turns the
  tiled poses into votes with utl.compute_votes, and routes them to the parent
  capsules with em.em_routing.

  "The routing procedure is used between each adjacent pair of capsule layers.
  For convolutional capsules, each capsule in layer L + 1 sends feedback only
  to capsules within its receptive field in layer L. Therefore each
  convolutional instance of a capsule in layer L receives at most kernel size
  X kernel size feedback from each capsule type in layer L + 1. The instances
  closer to the border of the image receive fewer feedbacks with corner ones
  receiving only one feedback per capsule type in layer L + 1."

  See Hinton et al. "Matrix Capsules with EM Routing" for detailed description
  of the convolutional capsule layer.

  Author:
    Ashley Gritzman 27/11/2018

  Args:
    activation_in:
      (batch_size, child_space, child_space, child_caps, 1)
      (64, 7, 7, 8, 1)
    pose_in:
      (batch_size, child_space, child_space, child_caps, 16)
      (64, 7, 7, 8, 16)
    kernel:
      side length of the square kernel; each parent position sees kernel**2
      child positions
    stride:
      step between neighbouring kernel windows; together with kernel it fixes
      parent_space = floor((child_space - kernel) / stride + 1)
    ncaps_out: depth dimension of parent capsules
    name: variable scope for the layer, also used as prefix in log messages
    weights_regularizer: passed through to utl.compute_votes

  Returns:
    activation_out:
      (batch_size, parent_space, parent_space, parent_caps, 1)
      (64, 5, 5, 32, 1)
    pose_out:
      (batch_size, parent_space, parent_space, parent_caps, 16)
      (64, 5, 5, 32, 16)
  """

  with tf.variable_scope(name):

    # Static sizes of the child layer and the resulting parent layer
    batch_size, child_space, _, child_caps = pose_in.get_shape().as_list()[:4]
    child_space_2 = int(child_space**2)
    parent_space = int(np.floor((child_space-kernel)/stride + 1))
    parent_space_2 = int(parent_space**2)
    parent_caps = ncaps_out
    kernel_2 = int(kernel**2)

    with tf.variable_scope('votes'):
      # Tile poses and activations into one window per parent position
      # (64, 7, 7, 8, 16)  -> (64, 5, 5, 9, 8, 16)
      pose_tiled, spatial_routing_matrix = utl.kernel_tile(
          pose_in,
          kernel=kernel,
          stride=stride)
      activation_tiled, _ = utl.kernel_tile(
          activation_in,
          kernel=kernel,
          stride=stride)

      # The routing matrix must map every child position to every parent one
      assert spatial_routing_matrix.shape == (child_space_2, parent_space_2)

      # Unroll along batch_size and parent_space_2
      # (64, 5, 5, 9, 8, 16) -> (64*5*5, 9*8, 16)
      unrolled_rows = batch_size * parent_space_2
      unrolled_caps = kernel_2 * child_caps
      pose_unroll = tf.reshape(
          pose_tiled,
          shape=[unrolled_rows, unrolled_caps, 16])
      activation_unroll = tf.reshape(
          activation_tiled,
          shape=[unrolled_rows, unrolled_caps, 1])

      # (64*5*5, 9*8, 16) -> (64*5*5, 9*8, 32, 16)
      votes = utl.compute_votes(
          pose_unroll,
          parent_caps,
          weights_regularizer,
          tag=True)
      logger.info(name + ' votes shape: {}'.format(votes.get_shape()))

    with tf.variable_scope('routing'):
      # in:  votes (64*5*5, 9*8, 32, 16), activations (64*5*5, 9*8, 1)
      # out: pose_out (N, OH, OW, o, 4x4), activation_out (N, OH, OW, o, 1)
      pose_out, activation_out = em.em_routing(
          votes,
          activation_unroll,
          batch_size,
          spatial_routing_matrix)

    logger.info(name + ' pose_out shape: {}'.format(pose_out.get_shape()))
    logger.info(name + ' activation_out shape: {}'
                .format(activation_out.get_shape()))

    tf.summary.histogram(name + "activation_out", activation_out)

  return activation_out, pose_out


def fc_caps(activation_in, 
            pose_in, 
            ncaps_out, 
            name='class_caps', 
            weights_regularizer=None):
  """Fully connected capsule layer.
  
  "The last layer of convolutional capsules is connected to the final capsule 
  layer which has one capsule per output class." We call this layer 'fully 
  connected' because it fits these characteristics, although Hinton et al. do 
  not use this terminology in the paper.
  
  The same transformation is applied at every spatial location (the votes are 
  computed on poses unrolled over batch and space), coordinate addition is 
  applied to the votes, and all child capsules at all locations are then 
  routed together to the class capsules.
  
  See Hinton et al. "Matrix Capsules with EM Routing" for detailed description.
  
  Author:
    Ashley Gritzman 27/11/2018
    
  Args: 
    activation_in:
      (batch_size, child_space, child_space, child_caps, 1)
      (64, 7, 7, 8, 1) 
    pose_in:
      (batch_size, child_space, child_space, child_caps, 16)
      (64, 7, 7, 8, 16) 
    ncaps_out: number of class capsules
    name: variable scope for the layer
    weights_regularizer: passed through to utl.compute_votes
    
  Returns:
    activation_out: 
      score for each output class
      (batch_size, ncaps_out)
      (64, 5)
    pose_out:
      pose for each output class capsule
      (batch_size, ncaps_out, 16)
      (64, 5, 16)
  """
  
  with tf.variable_scope(name) as scope:
    
    # Get shapes
    shape = pose_in.get_shape().as_list()
    batch_size = shape[0]
    child_space = shape[1]
    child_caps = shape[3]

    with tf.variable_scope('v') as scope:
      # In the class_caps layer, we apply same multiplication to every spatial 
      # location, so we unroll along the batch and spatial dimensions
      # (64, 5, 5, 32, 16) -> (64*5*5, 32, 16)
      pose = tf.reshape(
          pose_in, 
          shape=[batch_size * child_space * child_space, child_caps, 16])
      activation = tf.reshape(
          activation_in, 
          shape=[batch_size * child_space * child_space, child_caps, 1], 
          name="activation")

      # (64*5*5, 32, 16) -> (64*5*5, 32, 5, 16)
      votes = utl.compute_votes(pose, ncaps_out, weights_regularizer)

      # (64*5*5, 32, 5, 16)
      assert (
        votes.get_shape() == 
        [batch_size * child_space * child_space, child_caps, ncaps_out, 16])
      logger.info('class_caps votes original shape: {}'
                  .format(votes.get_shape()))

    with tf.variable_scope('coord_add') as scope:
      # Restore the spatial axes so the coordinates can be added per location
      # (64*5*5, 32, 5, 16) -> (64, 5, 5, 32, 5, 16)
      votes = tf.reshape(
          votes, 
          [batch_size, child_space, child_space, child_caps, ncaps_out, 
           votes.shape[-1]])
      votes = coord_addition(votes)

    with tf.variable_scope('routing') as scope:
      # Flatten the votes:
      # Combine the spatial dimensions with the capsule dimension so that the 
      # layer appears as one spatial position with many capsules.
      # (64, 5, 5, 32, 5, 16) -> (64, 5*5*32, 5, 16)
      votes_flat = tf.reshape(
          votes, 
          shape=[batch_size, child_space * child_space * child_caps, 
                 ncaps_out, votes.shape[-1]])
      activation_flat = tf.reshape(
          activation, 
          shape=[batch_size, child_space * child_space * child_caps, 1])
      
      # A single child position routed to a single parent position
      spatial_routing_matrix = utl.create_routing_map(child_space=1, k=1, s=1)

      logger.info('class_caps votes in to routing shape: {}'
            .format(votes_flat.get_shape()))
      
      # pose_out: (N, 1, 1, ncaps_out, 16), activation_out: (N, 1, 1, ncaps_out, 1)
      pose_out, activation_out = em.em_routing(votes_flat, 
                           activation_flat, 
                           batch_size, 
                           spatial_routing_matrix)

    # Remove only the singleton spatial axes (and the trailing activation 
    # axis). Squeezing without an explicit axis would also drop the batch axis 
    # when batch_size == 1, or the class axis when ncaps_out == 1, so the 
    # result would no longer have the documented shape.
    # (N, 1, 1, ncaps_out, 1) -> (N, ncaps_out)
    activation_out = tf.squeeze(
        activation_out, axis=[1, 2, 4], name="activation_out")
    # (N, 1, 1, ncaps_out, 16) -> (N, ncaps_out, 16)
    pose_out = tf.squeeze(pose_out, axis=[1, 2], name="pose_out")

    logger.info('class_caps activation shape: {}'
                .format(activation_out.get_shape()))
    logger.info('class_caps pose shape: {}'.format(pose_out.get_shape()))

    tf.summary.histogram("activation_out", activation_out)
      
  return activation_out, pose_out

  
def coord_addition(votes):
  """Coordinate addition for connecting the last convolutional capsule layer
  to the final layer.

  "When connecting the last convolutional capsule layer to the final layer we
  do not want to throw away information about the location of the
  convolutional capsules but we also want to make use of the fact that all
  capsules of the same type are extracting the same entity at different
  positions. We therefore share the transformation matrices between different
  positions of the same capsule type and add the scaled coordinate (row,
  column) of the center of the receptive field of each capsule to the first
  two elements of the right-hand column of its vote matrix. We refer to this
  technique as Coordinate Addition. This should encourage the shared final
  transformations to produce values for those two elements that represent the
  fine position of the entity relative to the center of the capsule's
  receptive field."

  In Suofei's implementation, they add x and y coordinates as two new
  dimensions to the pose matrix i.e. from 16 to 18 dimensions. The paper seems
  to say that the values are added to existing dimensions, which is what is
  done here: the column coordinate is added to element 3 and the row
  coordinate to element 7 of the flattened vote.

  See Hinton et al. "Matrix Capsules with EM Routing" for detailed description
  of coordinate addition.

  Author:
    Ashley Gritzman 27/11/2018

  Credit:
    Based on Jonathan Hui's implementation:
    https://jhui.github.io/2017/11/14/Matrix-Capsules-with-EM-routing-
    Capsule-Network/

  Args:
    votes:
      (batch_size, child_space, child_space, child_caps, n_output_capsules, 16)
      (64, 5, 5, 32, 5, 16)

  Returns:
    votes:
      same size as input, with coordinate encoding added to first two elements
      of right hand column of vote matrix
      (batch_size, child_space, child_space, child_caps, n_output_capsules, 16)
      (64, 5, 5, 32, 5, 16)
  """

  # Spatial extent and vote length, read from the static shape
  votes_dims = votes.get_shape().as_list()
  height, width, dims = votes_dims[1], votes_dims[2], votes_dims[-1]

  # Scaled coordinates in (0, 1). These are not exactly the centres of the
  # receptive fields; they are spread evenly across the grid instead.
  col_coords = (np.arange(width) + 0.50)/float(width)
  row_coords = (np.arange(height) + 0.50)/float(height)

  # One offset per spatial position, broadcastable over batch and both capsule
  # axes: (1, height, width, 1, 1, dims), e.g. (1, 5, 5, 1, 1, 16)
  offset = np.zeros([1, height, width, 1, 1, dims])
  # Column coordinate varies along the width axis -> element 3
  offset[0, :, :, 0, 0, 3] = col_coords[np.newaxis, :]
  # Row coordinate varies along the height axis -> element 7
  offset[0, :, :, 0, 0, 7] = row_coords[:, np.newaxis]

  # From numpy to a TensorFlow constant
  offset = tf.constant(offset, dtype=tf.float32)

  return tf.add(votes, offset, name="votes_with_coord_add")

In [7]:
#------------------------------------------------------------------------------
# LOSS FUNCTIONS
#------------------------------------------------------------------------------
def spread_loss(scores, y):
  """Spread loss.

  "In order to make the training less sensitive to the initialization and
  hyper-parameters of the model, we use "spread loss" to directly maximize the
  gap between the activation of the target class (a_t) and the activation of
  the other classes. If the activation of a wrong class, a_i, is closer than
  the margin, m, to a_t then it is penalized by the squared distance to the
  margin."

  See Hinton et al. "Matrix Capsules with EM Routing" equation (3).

  The margin grows with the global training step, from m_min = 0.2 towards
  m_min + m_delta = 0.99.

  Author:
    Ashley Gritzman 19/10/2018
  Credit:
    Adapted from Suofei Zhang's implementation on GitHub, "Matrix-Capsules-
    EM-Tensorflow"
    https://github.com/www0wwwjs1/Matrix-Capsules-EM-Tensorflow
  Args:
    scores:
      scores for each class
      (batch_size, num_class)
    y:
      index of true class. The shapes used below ((64, 5, 1) after one_hot and
      expand_dims) imply one index per sample, i.e. shape (batch_size,).
  Returns:
    loss:
      mean loss for entire batch
      (scalar)
  """

  with tf.variable_scope('spread_loss'):
    batch_size = int(scores.get_shape()[0])

    # AG 17/09/2018: modified margin schedule based on response of authors to
    # questions on OpenReview.net:
    # https://openreview.net/forum?id=HJWLfGWRb
    # "The margin that we set is:
    # margin = 0.2 + .79 * tf.sigmoid(tf.minimum(10.0, step / 50000.0 - 4))
    # where step is the training step. We trained with batch size of 64."
    # NOTE(review): presumably a global step must already exist in the graph
    # for tf.train.get_global_step() to return one - confirm against caller.
    step = tf.to_float(tf.train.get_global_step())
    m_min, m_delta = 0.2, 0.79
    ramp = tf.sigmoid(tf.minimum(10.0, step / 50000.0 - 4))
    m = m_min + m_delta * ramp

    num_class = int(scores.get_shape()[-1])

    # One-hot mask of the target class: (64, 5)
    target_mask = tf.one_hot(y, num_class, dtype=tf.float32)

    # Scores as a row vector per sample: (64, 1, 5)
    scores = tf.reshape(scores, shape=[batch_size, 1, num_class])
    # Mask as a column vector per sample: (64, 5, 1)
    target_mask = tf.expand_dims(target_mask, axis=2)
    # Score of the target class: (64, 1, 5)*(64, 5, 1) = (64, 1, 1)
    at = tf.matmul(scores, target_mask)

    # Squared hinge on the gap to the target, paper eq (3): (64, 1, 5)
    penalty = tf.square(tf.maximum(0., m - (at - scores)))

    # Sum the penalties of the wrong classes only, by multiplying with the
    # inverted mask, e.g. penalty*[1 0 1 1 1]:
    # (64, 1, 5)*(64, 5, 1) = (64, 1, 1)
    per_sample = tf.matmul(penalty, 1. - target_mask)

    # Average over the batch
    loss = tf.reduce_mean(per_sample)

  return loss


def cross_ent_loss(logits, y):
  """Sparse softmax cross-entropy between logits and integer labels.

  Author:
    Ashley Gritzman 06/05/2019
  Args:
    logits:
      unnormalised score for each class
      (batch_size, num_class)
    y:
      index of the true class for each sample
      (batch_size, 1) per the original notes -- TODO confirm against callers
  Returns:
    loss:
      mean cross-entropy over the batch
      (scalar)
  """
  xent = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)

  # The mean is taken explicitly so the result is a scalar regardless of the
  # reduction performed by the call above.
  return tf.reduce_mean(xent)



def total_loss(scores, y):
  """Training objective: spread loss, optionally plus weight regularization.

  When FLAGS.weight_reg is set, every term registered in the
  REGULARIZATION_LOSSES graph collection is summed and added to the spread
  loss; otherwise the spread loss is returned on its own. Scalar summaries are
  emitted for each component that is used.

  Author:
    Ashley Gritzman 19/10/2018
  Credit:
    Adapted from Suofei Zhang's implementation on GitHub, "Matrix-Capsules-
    EM-Tensorflow"
    https://github.com/www0wwwjs1/Matrix-Capsules-EM-Tensorflow
  Args:
    scores:
      scores for each class
      (batch_size, num_class)
    y:
      index of true class
      (batch_size, 1)
  Returns:
    total_loss:
      mean total loss for entire batch
      (scalar)
  """

  with tf.variable_scope('total_loss'):
    sprd_loss = spread_loss(scores, y)

    if not FLAGS.weight_reg:
      # No regularization: the spread loss is the whole objective.
      tf.summary.scalar('spread_loss', sprd_loss)
      return sprd_loss

    # Regularization terms are collected from the graph, so every layer that
    # should contribute must have been built before this point.
    reg_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    reg_loss = tf.add_n(reg_terms)
    combined = sprd_loss + reg_loss
    tf.summary.scalar('spread_loss', sprd_loss)
    tf.summary.scalar('regularization_loss', reg_loss)

  return combined


In [8]:
def build_arch_smallnorb(input, is_train: bool, num_classes: int):
  """Build the capsule network graph used for smallNORB.

  Layer order: batch norm -> ReLU conv1 -> primary capsules -> conv caps 1 ->
  conv caps 2 -> class capsules. Channel/capsule counts come from FLAGS.A,
  FLAGS.B, FLAGS.C and FLAGS.D.

  Args:
    input:
      image batch. The batch and spatial dimensions must be statically known
      because they are converted with int(); the shape assertions below
      additionally require the feature maps to be square.
      Presumably (batch_size, height, width, channels) -- confirm with caller.
    is_train:
      forwarded as `trainable` to the slim conv layers and as both
      `is_training` and `trainable` to batch norm.
    num_classes:
      number of output capsules of the class capsule layer.
  Returns:
    dict with
      'scores': activations returned by lyr.fc_caps
      'pose_out': poses returned by lyr.fc_caps
  """
  
  logger.info('input shape: {}'.format(input.get_shape()))
  batch_size = int(input.get_shape()[0])
  spatial_size = int(input.get_shape()[1])

  # The conv weights keep slim.conv2d's default initializer (the explicit
  # truncated-normal weights initializer below was disabled by the author, who
  # notes that xavier initialization is needed for stability):
  # initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01)
  # Biases, however, are drawn from a truncated normal instead of being set to
  # a constant 0, again for stability.
  bias_initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01) 

  # AG 13/11/2018
  # In response to a question on OpenReview, Hinton et al. wrote the 
  # following:
  # "We use a weight decay loss with a small factor of .0000002 rather than 
  # the reconstruction loss."
  # https://openreview.net/forum?id=HJWLfGWRb&noteId=rJeQnSsE3X
  weights_regularizer = tf.contrib.layers.l2_regularizer(0.0000002)

  # Defaults applied to every slim.conv2d call inside this scope.
  # weights_initializer=initializer,
  with slim.arg_scope([slim.conv2d], 
    trainable = is_train, 
    biases_initializer = bias_initializer,
    weights_regularizer = weights_regularizer):
    
    #----- Batch Norm -----#
    # center=False: no learned offset is added to the normalized input.
    output = slim.batch_norm(
        input, 
        center=False, 
        is_training=is_train, 
        trainable=is_train)
    
    #----- Convolutional Layer 1 -----#
    # 5x5 convolution, stride 2, FLAGS.A output channels, ReLU.
    with tf.variable_scope('relu_conv1') as scope:
      output = slim.conv2d(output, 
      num_outputs=FLAGS.A, 
      kernel_size=[5, 5], 
      stride=2, 
      padding='SAME', 
      scope=scope, 
      activation_fn=tf.nn.relu)
      
      # Re-read the spatial size after the strided conv and check the static
      # shape is the expected square feature map.
      spatial_size = int(output.get_shape()[1])
      assert output.get_shape() == [batch_size, spatial_size, spatial_size, 
                                    FLAGS.A]
      logger.info('relu_conv1 output shape: {}'.format(output.get_shape()))
    
    #----- Primary Capsules -----#
    # Two parallel 1x1 convolutions over the conv1 features: one producing
    # FLAGS.B * 16 pose values (linear), one producing FLAGS.B activations
    # (sigmoid).
    with tf.variable_scope('primary_caps') as scope:
      pose = slim.conv2d(output, 
      num_outputs=FLAGS.B * 16, 
      kernel_size=[1, 1], 
      stride=1, 
      padding='VALID', 
      scope='pose', 
      activation_fn=None)
      activation = slim.conv2d(
          output, 
          num_outputs=FLAGS.B, 
          kernel_size=[1, 1], 
          stride=1, 
          padding='VALID', 
          scope='activation', 
          activation_fn=tf.nn.sigmoid)

      # Split the channel axis into (capsule, pose element) so each of the
      # FLAGS.B capsules carries 16 pose values and 1 activation.
      spatial_size = int(pose.get_shape()[1])
      pose = tf.reshape(pose, shape=[batch_size, spatial_size, spatial_size, 
                                     FLAGS.B, 16], name='pose')
      activation = tf.reshape(
          activation, 
          shape=[batch_size, spatial_size, spatial_size, FLAGS.B, 1], 
          name="activation")
      
      assert pose.get_shape() == [batch_size, spatial_size, spatial_size, 
                                  FLAGS.B, 16]
      assert activation.get_shape() == [batch_size, spatial_size, spatial_size,
                                        FLAGS.B, 1]
      logger.info('primary_caps pose shape: {}'.format(pose.get_shape()))
      logger.info('primary_caps activation shape {}'
                  .format(activation.get_shape()))
      
      tf.summary.histogram("activation", activation)
    
    #----- Conv Caps 1 -----#
    # 3x3 capsule convolution, stride 2, FLAGS.C output capsule types.
    # Shape notes left by the original author (not verified here; the real
    # values depend on the input size and FLAGS.B / FLAGS.C):
    # activation_in: (64, 7, 7, 8, 1) 
    # pose_in: (64, 7, 7, 16, 16) 
    # activation_out: (64, 5, 5, 32, 1)
    # pose_out: (64, 5, 5, 32, 16)
    activation, pose = lyr.conv_caps(
        activation_in = activation, 
        pose_in = pose, 
        kernel = 3, 
        stride = 2, 
        ncaps_out = FLAGS.C, 
        name = 'lyr.conv_caps1', 
        weights_regularizer = weights_regularizer)
    
    #----- Conv Caps 2 -----#
    # 3x3 capsule convolution, stride 1, FLAGS.D output capsule types. Its
    # inputs are the outputs of Conv Caps 1 above.
    # NOTE(review): the shape notes below are a copy of the Conv Caps 1 notes
    # and cannot all be right for this layer -- confirm against lyr.conv_caps.
    # activation_in: (64, 7, 7, 8, 1) 
    # pose_in: (64, 7, 7, 16, 1) 
    # activation_out: (64, 5, 5, 32, 1)
    # pose_out: (64, 5, 5, 32, 16)
    activation, pose = lyr.conv_caps(
        activation_in = activation, 
        pose_in = pose, 
        kernel = 3, 
        stride = 1, 
        ncaps_out = FLAGS.D, 
        name = 'lyr.conv_caps2', 
        weights_regularizer = weights_regularizer)
    
    #----- Class Caps -----#
    # One output capsule per class. Shape notes from the original author
    # (presumably for batch 64 and 5 classes -- TODO confirm):
    # activation_in: (64, 5, 5, 32, 1)
    # pose_in: (64, 5, 5, 32, 16)
    # activation_out: (64, 5)
    # pose_out: (64, 5, 16) 
    activation_out, pose_out = lyr.fc_caps(
        activation_in = activation,
        pose_in = pose,
        ncaps_out = num_classes,
        name = 'class_caps',
        weights_regularizer = weights_regularizer)
    
  return {'scores': activation_out, 'pose_out': pose_out}

In [9]:
def _parser(serialized_example):
  """Decode one serialized smallNORB record into an image and its metadata.

  The image is stored as raw float64 bytes and is returned as a float32
  tensor of shape (96, 96, 1) without any intensity scaling. All metadata
  fields are stored as int64 and returned as int32.

  Author:
    Ashley Gritzman 15/11/2018
  Args:
    serialized_example: serialized example from tfrecord
  Returns:
    img: image
    lab: label
    cat:
      category
      the instance in the category (0 to 9)
    elv:
      elevation
      the elevation (0 to 8, which mean cameras are 30,
      35,40,45,50,55,60,65,70 degrees from the horizontal respectively)
    azi:
      azimuth
      the azimuth (0,2,4,...,34, multiply by 10 to get the azimuth in
      degrees)
    lit:
      lighting
      the lighting condition (0 to 5)
  """

  feature_spec = {
      'img_raw': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
      'category': tf.FixedLenFeature([], tf.int64),
      'elevation': tf.FixedLenFeature([], tf.int64),
      'azimuth': tf.FixedLenFeature([], tf.int64),
      'lighting': tf.FixedLenFeature([], tf.int64),
  }
  parsed = tf.parse_single_example(serialized_example, features=feature_spec)

  # Raw bytes -> flat float64 vector -> (96, 96, 1) float32 image.
  # The pixel values are deliberately left unnormalized at this stage.
  pixels = tf.decode_raw(parsed['img_raw'], tf.float64)
  img = tf.cast(tf.reshape(pixels, [96, 96, 1]), tf.float32)

  # Scalar metadata, narrowed from int64 to int32 in a fixed order.
  metadata_keys = ('label', 'category', 'elevation', 'azimuth', 'lighting')
  lab, cat, elv, azi, lit = [
      tf.cast(parsed[key], tf.int32) for key in metadata_keys]

  return img, lab, cat, elv, azi, lit


def _train_preprocess(img, lab, cat, elv, azi, lit):
  """Augmenting preprocessing applied to each training image.

  Follows Hinton et al. (2018) "Matrix capsules with EM routing":
  "We downsample smallNORB to 48 × 48 pixels and normalize each image to have
  zero mean and unit variance. During training, we randomly crop 32 × 32
  patches and add random brightness and contrast to the cropped images."

  Steps, in order: divide by 255, resize to 48x48, per-image standardization,
  random 32x32 crop, random brightness (max_delta 2.0), random contrast
  (factor in [0.5, 1.5]).

  An earlier variant of this pipeline applied brightness (max_delta 32/255)
  and contrast before resizing and did not standardize.

  Author:
    Ashley Gritzman 15/11/2018
  Args:
    img: the only argument that is modified
    lab, cat, elv, azi, lit: passed through untouched
  Returns:
    img: processed image, (32, 32, 1)
    lab, cat, elv, azi, lit: passed through untouched
  """

  scaled = img / 255.
  downsampled = tf.image.resize_images(scaled, [48, 48])
  standardized = tf.image.per_image_standardization(downsampled)

  # Augmentation happens after standardization, on the cropped patch only.
  patch = tf.random_crop(standardized, [32, 32, 1])
  patch = tf.image.random_brightness(patch, max_delta=2.0)
  patch = tf.image.random_contrast(patch, lower=0.5, upper=1.5)

  return patch, lab, cat, elv, azi, lit


def _val_preprocess(img, lab, cat, elv, azi, lit):
  """Deterministic preprocessing applied to each validation/test image.

  Follows Hinton et al. (2018) "Matrix capsules with EM routing":
  "We downsample smallNORB to 48 × 48 pixels and normalize each image to have
  zero mean and unit variance. ... During test, we crop a 32 × 32 patch from
  the center of the image and achieve..."

  Steps, in order: divide by 255, resize to 48x48, per-image standardization,
  central 32x32 crop. An earlier variant skipped the standardization.

  Author:
    Ashley Gritzman 15/11/2018
  Args:
    img: the only argument that is modified
    lab, cat, elv, azi, lit: passed through untouched
  Returns:
    img: processed image, (32, 32, 1)
    lab, cat, elv, azi, lit: passed through untouched
  """

  scaled = img / 255.
  downsampled = tf.image.resize_images(scaled, [48, 48])
  standardized = tf.image.per_image_standardization(downsampled)

  # Offset 8 on both spatial axes centres a 32-pixel window in 48 pixels.
  centre_patch = tf.slice(standardized, begin=[8, 8, 0], size=[32, 32, 1])

  return centre_patch, lab, cat, elv, azi, lit
  

def input_fn(path, is_train: bool):
  """Input pipeline for smallNORB using tf.data.

  Reads every matching .tfrecords shard in `path`, parses and preprocesses the
  records, then shuffles, batches (dropping the last partial batch), repeats
  FLAGS.epoch times and prefetches one batch.

  Author:
    Ashley Gritzman 15/11/2018
  Args:
    path:
      directory holding the tfrecord shards. File names starting with "train"
      are used when is_train is True, names starting with "test" otherwise.
    is_train:
      selects both the shard set and the preprocessing function
      (_train_preprocess vs. _val_preprocess).
  Returns:
    dataset: tf.data.Dataset yielding batches of
      (img, lab, cat, elv, azi, lit)
  Raises:
    FileNotFoundError: if `path` contains no shard for the requested split.
  """

  if is_train:
    chunk_re = re.compile(r"train.*\.tfrecords")
  else:
    chunk_re = re.compile(r"test.*\.tfrecords")

  # os.listdir returns names in arbitrary order; sorting keeps the shard order
  # (and therefore the pre-shuffle record order) the same on every run.
  chunk_files = sorted(
      os.path.join(path, fname)
      for fname in os.listdir(path)
      if chunk_re.match(fname))

  print("path:", path)

  # Fail here with a clear message instead of building an empty dataset that
  # only errors once the first batch is requested.
  if not chunk_files:
    raise FileNotFoundError(
        "No tfrecords files matching '{}' found in '{}'".format(
            chunk_re.pattern, path))

  # 1. create the dataset
  dataset = tf.data.TFRecordDataset(chunk_files)

  # 2. map with the actual work (preprocessing, augmentation…) using multiple
  # parallel calls. Parsing uses a fixed 4 parallel calls; preprocessing uses
  # FLAGS.num_threads.
  dataset = dataset.map(_parser, num_parallel_calls=4)
  preprocess_fn = _train_preprocess if is_train else _val_preprocess
  dataset = dataset.map(preprocess_fn, num_parallel_calls=FLAGS.num_threads)

  # 3. shuffle (with a big enough buffer size)
  # In response to a question on OpenReview, Hinton et al. wrote the
  # following:
  # https://openreview.net/forum?id=HJWLfGWRb&noteId=rJgxonoNnm
  # "We did not have any special ordering of training batches and we random
  # shuffle. In terms of TF batch:
  # capacity=2000 + 3 * batch_size, ensures a minimum amount of shuffling of
  # examples. min_after_dequeue=2000."
  capacity = 2000 + 3 * FLAGS.batch_size
  dataset = dataset.shuffle(buffer_size=capacity)

  # 4. batch
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)

  # 5. repeat
  dataset = dataset.repeat(count=FLAGS.epoch)

  # 6. prefetch
  dataset = dataset.prefetch(1)

  return dataset


def create_inputs_norb(path, is_train: bool):
  """Build the smallNORB pipeline and return the tensors for the next batch.

  Author:
    Ashley Gritzman 15/11/2018
  Args:
    path: dataset directory, forwarded to input_fn
    is_train: forwarded to input_fn
  Returns:
    dict of batch tensors keyed by 'image', 'label', 'category',
    'elevation', 'azimuth' and 'lighting'
  """

  # A one-shot iterator needs no explicit initialization in the session.
  batch_iterator = input_fn(path, is_train).make_one_shot_iterator()
  img, lab, cat, elv, azi, lit = batch_iterator.get_next()

  return dict(
      image=img,
      label=lab,
      category=cat,
      elevation=elv,
      azimuth=azi,
      lighting=lit)


def plot_smallnorb(is_train=True, samples_per_class=5):
  """Plot examples from the smallNORB dataset.

  Execute this command in a Jupyter Notebook.

  One batch (FLAGS.batch_size images) is pulled from the input pipeline and a
  grid is drawn with one column per class and up to `samples_per_class` rows.
  Each image is labelled with its category, elevation, azimuth and lighting.
  If the batch holds fewer than `samples_per_class` images of a class, only
  the available ones are drawn; a class absent from the batch gets no column
  content.

  Author:
    Ashley Gritzman 18/04/2019
  Args:
    is_train: True for the training dataset, False for the test dataset
    samples_per_class: maximum number of sample images per class
  Returns:
    None
  """

  # Imported here so the plotting dependency is only needed when plotting.
  import matplotlib.pyplot as plt
  plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
  plt.rcParams['image.interpolation'] = 'nearest'
  plt.rcParams['image.cmap'] = 'gray'

  from config import get_dataset_path
  path = get_dataset_path("smallNORB")

  CLASSES = ['animal', 'human', 'airplane', 'truck', 'car']

  # Fetch a single batch of images and metadata as numpy arrays.
  input_dict = create_inputs_norb(path, is_train=is_train)
  with tf.Session() as sess:
    input_dict = sess.run(input_dict)

  img_bch = input_dict['image']
  lab_bch = input_dict['label']
  cat_bch = input_dict['category']
  elv_bch = input_dict['elevation']
  azi_bch = input_dict['azimuth']
  lit_bch = input_dict['lighting']

  num_classes = len(CLASSES)

  fig = plt.figure(figsize=(num_classes * 2, samples_per_class * 2))
  fig.suptitle("category, elevation, azimuth, lighting")
  for y, cls in enumerate(CLASSES):
    idxs = np.flatnonzero(lab_bch == y)

    # Sampling without replacement fails when more samples are requested than
    # exist, so never ask for more images than this class has in the batch.
    num_shown = min(samples_per_class, idxs.size)
    if num_shown == 0:
      continue
    idxs = np.random.choice(idxs, num_shown, replace=False)

    for i, idx in enumerate(idxs):
      # Subplots are numbered row-major: row i, column y.
      plt_idx = i * num_classes + y + 1
      plt.subplot(samples_per_class, num_classes, plt_idx)
      plt.imshow(np.squeeze(img_bch[idx]))
      plt.xticks([], [])
      plt.yticks([], [])
      plt.xlabel("{}, {}, {},{}".format(cat_bch[idx], elv_bch[idx],
                                        azi_bch[idx], lit_bch[idx]))

      # The class name heads each column.
      if i == 0:
        plt.title(cls)
  plt.show()

In [10]:
def main(args):
  """Run training and validation.

  1. Build graphs
      1.1 Training graph to run on multiple GPUs
      1.2 Validation graph to run on multiple GPUs
  2. Configure sessions
      2.1 Train
      2.2 Validate
  3. Main loop
      3.1 Train
      3.2 Write summary
      3.3 Save model
      3.4 Validate model

  The train and validation models live in separate graphs and sessions. The
  validation session receives weights only by restoring the latest checkpoint
  written from the train session.

  Author:
    Ashley Gritzman
  Args:
    args: unused in this function.
  Returns:
    Does not return normally; the function ends with sys.exit().
  """

  # Set reproducible random seed
  #tf.set_random_seed(1234)

  # Directories
  train_dir, train_summary_dir = setup_train_directories()

  # Logger
  setup_logger(logger_dir=train_dir, name="logger_train.txt")

  # Hyperparameters
  load_or_save_hyperparams(train_dir)

  # Get dataset hyperparameters
  logger.info('Using dataset: {}'.format(FLAGS.dataset))
  dataset_size_train  = get_dataset_size_train(FLAGS.dataset)
  dataset_size_val  = get_dataset_size_validate(FLAGS.dataset)
  build_arch      = get_dataset_architecture(FLAGS.dataset)
  num_classes     = get_num_classes(FLAGS.dataset)
  create_inputs_train = get_create_inputs(FLAGS.dataset, mode="train")
  create_inputs_val   = get_create_inputs(FLAGS.dataset, mode="validate")


  #****************************************************************************
  # 1. BUILD GRAPHS
  #****************************************************************************

  #----------------------------------------------------------------------------
  # GRAPH - TRAIN
  #----------------------------------------------------------------------------
  logger.info('BUILD TRAIN GRAPH')
  g_train = tf.Graph()
  with g_train.as_default(), tf.device('/cpu:0'):

    # Get global_step
    global_step = tf.train.get_or_create_global_step()

    # Get batches per epoch
    num_batches_per_epoch = int(dataset_size_train / FLAGS.batch_size)

    # In response to a question on OpenReview, Hinton et al. wrote the
    # following:
    # "We use an exponential decay with learning rate: 3e-3, decay_steps:
    # 20000, decay rate: 0.96."
    # https://openreview.net/forum?id=HJWLfGWRb&noteId=ryxTPFDe2X
    lrn_rate = tf.train.exponential_decay(learning_rate = FLAGS.lrn_rate,
                        global_step = global_step,
                        decay_steps = 20000,
                        decay_rate = 0.96)
    tf.summary.scalar('learning_rate', lrn_rate)
    opt = tf.train.AdamOptimizer(learning_rate=lrn_rate)

    # Get batch from data queue. Batch size is FLAGS.batch_size, which is then
    # divided across multiple GPUs
    input_dict = create_inputs_train()
    batch_x = input_dict['image']
    batch_labels = input_dict['label']

    # AG 03/10/2018: Split batch for multi gpu implementation
    # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
    # See: https://github.com/naturomics/CapsNet-Tensorflow/blob/master/
    # dist_version/distributed_train.py
    splits_x = tf.split(
        axis=0,
        num_or_size_splits=FLAGS.num_gpus,
        value=batch_x)
    splits_labels = tf.split(
        axis=0,
        num_or_size_splits=FLAGS.num_gpus,
        value=batch_labels)


    #--------------------------------------------------------------------------
    # MULTI GPU - TRAIN
    #--------------------------------------------------------------------------
    # Calculate the gradients for each model tower
    tower_grads = []
    tower_losses = []
    tower_logits = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i) as scope:
          logger.info('TOWER %d' % i)
          # Variables created through slim.variable are pinned to the CPU so
          # that all towers share them.
          with slim.arg_scope([slim.variable], device='/cpu:0'):
            loss, logits = tower_fn(
                build_arch,
                splits_x[i],
                splits_labels[i],
                scope,
                num_classes,
                reuse_variables=reuse_variables,
                is_train=True)

          # Don't reuse variable for first GPU, but do reuse for others
          reuse_variables = True

          # Compute gradients for one GPU
          grads = opt.compute_gradients(loss)

          # Keep track of the gradients across all towers.
          tower_grads.append(grads)

          # Keep track of losses and logits across for each tower
          tower_logits.append(logits)
          tower_losses.append(loss)

          # Loss for each tower
          tf.summary.scalar("loss", loss)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grad = average_gradients(tower_grads)

    # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-
    # gradients-in-tensorflow-when-updating
    # Every tower's loss is checked, not only the loss of the last tower built.
    grad_check = ([tf.check_numerics(g, message='Gradient NaN Found!')
                   for g, _ in grad if g is not None]
                  + [tf.check_numerics(tower_loss, message='Loss NaN Found')
                     for tower_loss in tower_losses])

    # Apply the gradients to adjust the shared variables. The update only runs
    # once the numeric checks and the UPDATE_OPS collection have been executed.
    with tf.control_dependencies(grad_check):
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = opt.apply_gradients(grad, global_step=global_step)

    # Calculate mean loss
    loss = tf.reduce_mean(tower_losses)

    # Calculate accuracy
    logits = tf.concat(tower_logits, axis=0)
    acc = met.accuracy(logits, batch_labels)

    # Prepare predictions and one-hot labels
    probs = tf.nn.softmax(logits=logits)
    labels_oh = tf.one_hot(batch_labels, num_classes)

    # Group metrics together
    # See: https://cs230-stanford.github.io/tensorflow-model.html
    trn_metrics = {'loss' : loss,
             'labels' : batch_labels,
             'labels_oh' : labels_oh,
             'logits' : logits,
             'probs' : probs,
             'acc' : acc,
             }

    # Reset and read operations for streaming metrics go here
    trn_reset = {}
    trn_read = {}

    # Logging
    tf.summary.scalar('trn_loss', loss)
    tf.summary.scalar('trn_acc', acc)

    # Set Saver
    # AG 26/09/2018: Save all variables including Adam so that we can continue
    # training from where we left off
    # max_to_keep=None should keep all checkpoints
    saver_train = tf.train.Saver(tf.global_variables(), max_to_keep=None)

    # Display number of parameters
    train_params = np.sum([np.prod(v.get_shape().as_list())
              for v in tf.trainable_variables()]).astype(np.int32)
    logger.info('Trainable Parameters: {}'.format(train_params))

    # Set summary op
    trn_summary = tf.summary.merge_all()


  #----------------------------------------------------------------------------
  # GRAPH - VALIDATION
  #----------------------------------------------------------------------------
  logger.info('BUILD VALIDATION GRAPH')
  g_val = tf.Graph()
  with g_val.as_default():
    # Get global_step (created so the variable exists in this graph too)
    global_step = tf.train.get_or_create_global_step()

    # Only a FLAGS.val_prop fraction of the validation set is evaluated.
    num_batches_val = int(dataset_size_val / FLAGS.batch_size * FLAGS.val_prop)

    # Get data
    input_dict = create_inputs_val()
    batch_x = input_dict['image']
    batch_labels = input_dict['label']

    # AG 10/12/2018: Split batch for multi gpu implementation
    # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
    # See: https://github.com/naturomics/CapsNet-
    # Tensorflow/blob/master/dist_version/distributed_train.py
    splits_x = tf.split(
        axis=0,
        num_or_size_splits=FLAGS.num_gpus,
        value=batch_x)
    splits_labels = tf.split(
        axis=0,
        num_or_size_splits=FLAGS.num_gpus,
        value=batch_labels)


    #--------------------------------------------------------------------------
    # MULTI GPU - VALIDATE
    #--------------------------------------------------------------------------
    # Calculate the logits for each model tower
    tower_logits = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i) as scope:
          with slim.arg_scope([slim.variable], device='/cpu:0'):
            loss, logits = tower_fn(
                build_arch,
                splits_x[i],
                splits_labels[i],
                scope,
                num_classes,
                reuse_variables=reuse_variables,
                is_train=False)

          # Don't reuse variable for first GPU, but do reuse for others
          reuse_variables = True

          # Keep track of logits for each tower
          tower_logits.append(logits)

          # Logits histogram for each tower
          tf.summary.histogram("val_logits", logits)

    # Combine logits from all towers
    logits = tf.concat(tower_logits, axis=0)

    # Calculate metrics
    val_loss = mod.spread_loss(logits, batch_labels)
    val_acc = met.accuracy(logits, batch_labels)

    # Prepare predictions and one-hot labels
    val_probs = tf.nn.softmax(logits=logits)
    val_labels_oh = tf.one_hot(batch_labels, num_classes)

    # Group metrics together
    # See: https://cs230-stanford.github.io/tensorflow-model.html
    val_metrics = {'loss' : val_loss,
                   'labels' : batch_labels,
                   'labels_oh' : val_labels_oh,
                   'logits' : logits,
                   'probs' : val_probs,
                   'acc' : val_acc,
                   }

    # Reset and read operations for streaming metrics go here
    val_reset = {}
    val_read = {}

    tf.summary.scalar("val_loss", val_loss)
    tf.summary.scalar("val_acc", val_acc)

    # Saver for the validation graph. It has its own name so it can never be
    # confused with the saver that belongs to the train graph.
    saver_val = tf.train.Saver(max_to_keep=None)

    # Set summary op
    val_summary = tf.summary.merge_all()


  #****************************************************************************
  # 2. SESSIONS
  #****************************************************************************

  #----- SESSION TRAIN -----#
  # Session settings
  sess_train = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False),
                          graph=g_train)

  # Debugger
  # AG 05/06/2018: Debugging using either command line or TensorBoard
  if FLAGS.debugger is not None:
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    sess_train = tf_debug.TensorBoardDebugWrapperSession(sess_train,
                                                         FLAGS.debugger)

  with g_train.as_default():
    sess_train.run([tf.global_variables_initializer(),
                    tf.local_variables_initializer()])

    # Restore previous checkpoint
    # AG 26/09/2018: where should this go???
    if FLAGS.load_dir is not None:
      load_dir_checkpoint = os.path.join(FLAGS.load_dir, "train", "checkpoint")
      prev_step = load_training(saver_train, sess_train, load_dir_checkpoint)
    else:
      prev_step = 0

  # Create summary writer, and write the train graph
  summary_writer = tf.summary.FileWriter(train_summary_dir,
                                         graph=sess_train.graph)


  #----- SESSION VALIDATION -----#
  sess_val = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False),
                        graph=g_val)
  with g_val.as_default():
    sess_val.run([tf.local_variables_initializer(),
                  tf.global_variables_initializer()])


  #****************************************************************************
  # 3. MAIN LOOP
  #****************************************************************************
  SUMMARY_FREQ = 100
  SAVE_MODEL_FREQ = num_batches_per_epoch # 500
  VAL_FREQ = num_batches_per_epoch # 500
  PROFILE_FREQ = 5

  # Defined once up front: both the save and the validate branches read it.
  train_checkpoint_dir = train_dir + '/checkpoint'

  for step in range(prev_step, FLAGS.epoch * num_batches_per_epoch + 1):
    epoch_decimal = step/num_batches_per_epoch
    epoch = int(np.floor(epoch_decimal))

    # TF queue would pop batch until no file
    try:
      # TRAIN
      with g_train.as_default():

        # With profiling
        if (FLAGS.profile is True) and ((step % PROFILE_FREQ) == 0):
          logger.info("Train with Profiling")
          run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
          run_metadata = tf.RunMetadata()
        # Without profiling
        else:
          run_options = None
          run_metadata = None

        # Reset streaming metrics
        # NOTE(review): the divisor is a float, so this only fires on steps
        # where the float remainder is exactly 1 -- confirm the intended
        # cadence before adding real ops to trn_reset.
        if step % (num_batches_per_epoch/4) == 1:
          logger.info("Reset streaming metrics")
          sess_train.run([trn_reset])

        # MAIN RUN
        tic = time.time()
        train_op_v, trn_metrics_v, trn_summary_v = sess_train.run(
            [train_op, trn_metrics, trn_summary],
            options=run_options,
            run_metadata=run_metadata)
        toc = time.time()

        # Read streaming metrics
        trn_read_v = sess_train.run(trn_read)

        # Write summary for profiling
        if run_options is not None:
          summary_writer.add_run_metadata(
              run_metadata, 'step{:d}'.format(step))

        # Logging
        logger.info('TRN'
              + ' e-{:d}'.format(epoch)
              + ' stp-{:d}'.format(step)
              + ' {:.2f}s'.format(toc - tic)
              + ' loss: {:.4f}'.format(trn_metrics_v['loss'])
              + ' acc: {:.2f}%'.format(trn_metrics_v['acc']*100)
               )

    except KeyboardInterrupt:
      # Flush pending summaries before shutting down.
      summary_writer.close()
      sess_train.close()
      sess_val.close()
      sys.exit()

    except tf.errors.InvalidArgumentError as e:
      # Raised by the check_numerics ops; the step is skipped.
      logger.warning('%d iteration contains NaN gradients. Discard.' % step)
      logger.error(str(e))
      continue

    else:
      # WRITE SUMMARY
      if (step % SUMMARY_FREQ) == 0:
        logger.info("Write Train Summary")
        with g_train.as_default():
          # Summaries from graph
          summary_writer.add_summary(trn_summary_v, step)

      # SAVE MODEL
      # Fires 100 steps after each multiple of SAVE_MODEL_FREQ, so it never
      # fires when num_batches_per_epoch <= 100.
      if (step % SAVE_MODEL_FREQ) == 100:
        logger.info("Save Model")
        with g_train.as_default():
          if not os.path.exists(train_checkpoint_dir):
            os.makedirs(train_checkpoint_dir)

          # Save ckpt from train session
          ckpt_path = os.path.join(train_checkpoint_dir, 'model.ckpt')
          saver_train.save(sess_train, ckpt_path, global_step=step)

      # VALIDATE MODEL
      # Same offset as saving, so a fresh checkpoint exists when this runs.
      if (step % VAL_FREQ) == 100:
        #----- Validation -----#
        with g_val.as_default():
          logger.info("Start Validation")

          # Restore ckpt to val session
          latest_ckpt = tf.train.latest_checkpoint(train_checkpoint_dir)
          saver_val.restore(sess_val, latest_ckpt)

          # Checkpoint number is the suffix after the last '-' in the path
          ckpt_num = re.split('-', latest_ckpt)[-1]

          # Reset accumulators
          accuracy_sum = 0
          loss_sum = 0
          sess_val.run(val_reset)

          for i in range(num_batches_val):
            val_metrics_v, val_summary_str_v = sess_val.run(
                [val_metrics, val_summary])

            # Update
            accuracy_sum += val_metrics_v['acc']
            loss_sum += val_metrics_v['loss']

            # Read
            val_read_v = sess_val.run(val_read)

            # Logging
            logger.info('VAL ckpt-{}'.format(ckpt_num)
                        + ' bch-{:d}'.format(i)
                        + ' cum_acc: {:.2f}%'.format(accuracy_sum/(i+1)*100)
                        + ' cum_loss: {:.4f}'.format(loss_sum/(i+1))
                       )

          # Average across batches
          ave_acc = accuracy_sum / num_batches_val
          ave_loss = loss_sum / num_batches_val

          logger.info('VAL ckpt-{}'.format(ckpt_num)
                      + ' avg_acc: {:.2f}%'.format(ave_acc*100)
                      + ' avg_loss: {:.4f}'.format(ave_loss)
                     )

          # The averaged validation metrics are written to the same event
          # file as the training summaries.
          logger.info("Write Val Summary")
          summary_val = tf.Summary()
          summary_val.value.add(tag="val_acc", simple_value=ave_acc)
          summary_val.value.add(tag="val_loss", simple_value=ave_loss)
          summary_writer.add_summary(summary_val, step)

  # Close (main loop)
  summary_writer.close()
  sess_train.close()
  sess_val.close()
  sys.exit()

  
def tower_fn(build_arch,
             x,
             y,
             scope,
             num_classes,
             is_train=True,
             reuse_variables=None):
  """Build one model replica (tower) and its loss.

  Author:
    Ashley Gritzman 27/11/2018

  Args:
    build_arch: callable that builds the network; called as
      build_arch(x, is_train, num_classes=num_classes) and expected to return
      a dict containing the key 'scores'
    x: split of batch_x allocated to particular GPU
    y: split of batch_y allocated to particular GPU
    scope: accepted for the caller's convenience; not read in this function
    num_classes: forwarded to build_arch
    is_train: forwarded to build_arch
    reuse_variables: passed as `reuse` to the variable scope; callers pass
      None for the first GPU and True for subsequent GPUs
  Returns:
    loss: mean loss across samples for one tower (scalar)
    scores:
      If the architecture is a capsule network, then the scores are the output
      activations of the class caps.
      If the architecture is the CNN baseline, then the scores are the logits of
      the final layer.
      (samples_per_tower, n_classes)
      (64/4=16, 5)
  """

  # Re-enter the current variable scope so that later towers share the
  # variables created by the first one.
  current_scope = tf.get_variable_scope()
  with tf.variable_scope(current_scope, reuse=reuse_variables):
    net_outputs = build_arch(x, is_train, num_classes=num_classes)
    scores = net_outputs['scores']

  # The loss is built outside the reuse scope.
  tower_loss = mod.total_loss(scores, y)

  return tower_loss, scores


def average_gradients(tower_grads):
  """Compute average gradients across all towers.
  
  Calculate the average gradient for each shared variable across all towers.
  Note that this function provides a synchronization point across all towers.
  
  Credit:
    https://github.com/naturomics/CapsNet-Tensorflow/blob/master/
    dist_version/distributed_train.py
  Args:
    tower_grads: 
      List with one entry per tower. Each entry is that tower's list of 
      (gradient, variable) tuples; the i-th tuple of every tower is assumed to
      refer to the same shared variable. A gradient may be None (for example 
      when an optimizer reports no gradient for a variable).
  Returns:
    average_grads:
      List of pairs of (gradient, variable) where the gradient has been 
      averaged across all towers that supplied one. If no tower supplied a 
      gradient for a variable, the pair is (None, variable). An empty 
      tower_grads gives an empty list.
  """
  
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))

    # Keep in mind that the Variables are redundant because they are shared
    # across towers. So .. we will just return the first tower's pointer to
    # the Variable.
    v = grad_and_vars[0][1]

    # Add a leading 'tower' dimension to every gradient so they can be 
    # concatenated and averaged below. None entries are skipped because 
    # tf.expand_dims cannot convert None to a tensor.
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]

    if not grads:
      # No tower produced a gradient for this variable: pass None through 
      # instead of crashing, so the caller can decide what to do with it.
      average_grads.append((None, v))
      continue

    # Average over the 'tower' dimension.
    grad = tf.concat(axis=0, values=grads)
    grad = tf.reduce_mean(grad, 0)

    average_grads.append((grad, v))
    
  return average_grads
          

def extract_step(path):
  """Parses the training step out of a Tensorflow checkpoint path.
  
  Credit:
    Sara Sabour
    https://github.com/Sarasra/models/blob/master/research/capsules/
    experiment.py
  Args:
    path: The checkpoint path returned by tf.train.get_checkpoint_state.
    The format is: {ckpnt_name}-{step}
  Returns:
    The step as an int: whatever follows the last '-' in the final path 
    component (the whole component if it contains no '-').
  """
  # Only the last path component matters; directories may contain '-' too.
  _, _, step_text = os.path.basename(path).rpartition('-')
  return int(step_text)


def load_training(saver, session, load_dir):
  """Restores the latest checkpoint found in load_dir into the session.
  
  Nothing is created or cleaned here: if load_dir is missing, or holds no 
  checkpoint that can be located, an IOError is raised.
  
  Author:
    Ashley Gritzman 26/09/2018
  Credit:
    Adapted from Sara Sabour
    https://github.com/Sarasra/models/blob/master/research/capsules/
    experiment.py
  Args:
    saver: An instance of tf.train.saver to load the model in to the session.
    session: An instance of tf.Session with the built-in model graph.
    load_dir: The directory which is used to load the latest checkpoint.
    
  Returns:
    The latest saved step, parsed from the restored checkpoint's path.
  Raises:
    IOError: load_dir does not exist, or it exists but no checkpoint state 
      with a checkpoint path could be read from it.
  """
  
  if not gfile.Exists(load_dir):
    raise IOError("AG: load_ckpt directory does not exist")

  ckpt = tf.train.get_checkpoint_state(load_dir)
  if not (ckpt and ckpt.model_checkpoint_path):
    # Adjacent string literals keep the message on one line; the previous 
    # triple-quoted literal embedded a newline and the source indentation.
    raise IOError("AG: load_ckpt directory exists but cannot find a valid "
                  "checkpoint to restore, consider using the reset flag")

  saver.restore(session, ckpt.model_checkpoint_path)
  prev_step = extract_step(ckpt.model_checkpoint_path)
  logger.info("Restored checkpoint")
    
  return prev_step


def find_checkpoint(load_dir, seen_step):
  """Looks for a checkpoint in load_dir whose step differs from seen_step.
  
  Credit:
    Sara Sabour
    https://github.com/Sarasra/models/blob/master/research/capsules/
    experiment.py
  Args:
    load_dir: The directory address to look for the training checkpoints.
    seen_step: Latest step which evaluation has been done on it.
  Returns:
    A (step, path) pair for the latest checkpoint recorded in load_dir. 
    If there is no checkpoint, or its step equals seen_step, returns 
    (-1, None).
  """
  ckpt_state = tf.train.get_checkpoint_state(load_dir)
  if not ckpt_state or not ckpt_state.model_checkpoint_path:
    return -1, None

  latest_path = ckpt_state.model_checkpoint_path
  latest_step = extract_step(latest_path)  # already an int
  if latest_step == seen_step:
    # Nothing new since the last evaluation.
    return -1, None

  return latest_step, latest_path
          

# if __name__ == "__main__":
#   absl.app.run(main)

In [11]:
# Parse the flag values from the process arguments
FLAGS(sys.argv)

# Directories used by this training run
train_dir, train_summary_dir = setup_train_directories()

# Logger (given the training directory and the log file name)
setup_logger(logger_dir=train_dir, name="logger_train.txt")

# Hyperparameters (the recorded run reports them being saved to params.json)
load_or_save_hyperparams(train_dir)

# Everything below is looked up from the dataset name in FLAGS.dataset
logger.info('Using dataset: {}'.format(FLAGS.dataset))
dataset_size_train = get_dataset_size_train(FLAGS.dataset)
dataset_size_val = get_dataset_size_validate(FLAGS.dataset)
build_arch = get_dataset_architecture(FLAGS.dataset)
num_classes = get_num_classes(FLAGS.dataset)
create_inputs_train = get_create_inputs(FLAGS.dataset, mode="train")
create_inputs_val = get_create_inputs(FLAGS.dataset, mode="validate")



2020-04-07 10:32:02 INFO: Parameters saved to file: ../logs/smallNORB/20200407_/train/params/params.json
2020-04-07 10:32:02 INFO: Using dataset: smallNORB


In [24]:
# Previous input pipeline, disabled and kept for reference: get a batch from 
# the data queue. Batch size is FLAGS.batch_size, which is then divided 
# across multiple GPUs.
# input_dict = create_inputs_train()
# batch_x = input_dict['image']
# batch_labels = input_dict['label']

# Current approach: load the training split through TensorFlow Datasets.
import tensorflow_datasets as tfds
dataset = tfds.load(name='smallnorb', split='train')


2020-04-07 10:40:31 INFO: Overwrite dataset info from restored data version.
2020-04-07 10:40:31 INFO: Reusing dataset smallnorb (/home/dehghani/tensorflow_datasets/smallnorb/2.0.0)
2020-04-07 10:40:31 INFO: Constructing tf.data.Dataset for split train, from /home/dehghani/tensorflow_datasets/smallnorb/2.0.0


In [32]:
# Build the network once outside of tower_fn and compute the loss on its 
# scores. The second positional argument is the is_train switch (False here).
# NOTE(review): `a` and `y` are not defined in any cell visible here - they 
# presumably come from an earlier, possibly deleted, cell; confirm before 
# relying on Restart & Run All.
# NOTE(review): the recorded run of this cell ended in an AttributeError.
output = build_arch(a['image'], False, num_classes=num_classes)
scores = output['scores']
loss = mod.total_loss(scores, y)

2020-04-07 10:42:46 INFO: input shape: (96, 96, 1)


AttributeError: module 'tensorflow' has no attribute 'truncated_normal_initializer'

In [43]:
# Take the image tensor as the network input and read sizes from its static 
# shape.
# NOTE(review): `a` is not defined in any cell visible here - confirm where it
# is created.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = a['image']
is_train = False
logger.info('input shape: {}'.format(input.get_shape()))
# NOTE(review): the recorded run logged the shape (96, 96, 1), which has only 
# three dimensions. With such an input, index 0 is not a batch dimension, so 
# `batch_size` would be 96 - TODO confirm that the input is batched.
batch_size = int(input.get_shape()[0])
spatial_size = int(input.get_shape()[1])

# Earlier version (disabled) used a truncated normal initializer for weights:
# initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01)
# Bias initializer drawn from a (non-truncated) normal distribution instead of
# the constant 0.
# NOTE(review): none of the layers created later in this cell receive this 
# object; the Conv layers are given bias_initializer='zeros'.
bias_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.01) 

# AG 13/11/2018
# In response to a question on OpenReview, Hinton et al. wrote the 
# following:
# "We use a weight decay loss with a small factor of .0000002 rather than 
# the reconstruction loss."
# https://openreview.net/forum?id=HJWLfGWRb&noteId=rJeQnSsE3X
weights_regularizer = tf.keras.regularizers.l2(0.0000002)

# Disabled leftover from the earlier version:
# weights_initializer=initializer,

#----- Batch Norm -----#
# Earlier functional form (disabled):
# output = tf.nn.batch_normalization(
#     input
# )

# Keras layer object; it is only created here, not applied to `input`.
batch_norm = tf.keras.layers.BatchNormalization()

#----- Convolutional Layer 1 -----#
# output = slim.conv2d(output, 
# num_outputs=FLAGS.A, 
# kernel_size=[5, 5], 
# stride=2, 
# padding='SAME', 
# scope=scope, 
# activation_fn=tf.nn.relu)

# spatial_size = int(output.get_shape()[1])
# assert output.get_shape() == [batch_size, spatial_size, spatial_size, 
#                             FLAGS.A]
# logger.info('relu_conv1 output shape: {}'.format(output.get_shape()))

# ReLU Conv1: 5x5 convolution with stride 2 producing FLAGS.A channels.
# The Keras class is spelled `Conv2D`; `tf.keras.layers.Conv2d` does not exist
# and raises AttributeError. The arguments that were passed as None 
# (activity_regularizer, kernel_constraint, bias_constraint) are the Keras 
# defaults and are therefore omitted.
# The layer is only constructed here; it is not applied to a tensor yet.
conv1 = tf.keras.layers.Conv2D(
    filters=FLAGS.A, 
    kernel_size=[5, 5], 
    strides=(2, 2), 
    padding='same', 
    activation=tf.nn.relu, 
    use_bias=True,
    kernel_initializer='glorot_uniform', 
    bias_initializer='zeros',
    kernel_regularizer=weights_regularizer, 
    bias_regularizer=weights_regularizer)

#----- Primary Capsules -----#
# with tf.variable_scope('primary_caps') as scope:
#   pose = slim.conv2d(output, 
#   num_outputs=FLAGS.B * 16, 
#   kernel_size=[1, 1], 
#   stride=1, 
#   padding='VALID', 
#   scope='pose', 
#   activation_fn=None)
#   activation = slim.conv2d(
#       output, 
#       num_outputs=FLAGS.B, 
#       kernel_size=[1, 1], 
#       stride=1, 
#       padding='VALID', 
#       scope='activation', 
#       activation_fn=tf.nn.sigmoid)

#   spatial_size = int(pose.get_shape()[1])
#   pose = tf.reshape(pose, shape=[batch_size, spatial_size, spatial_size, 
#                                  FLAGS.B, 16], name='pose')
#   activation = tf.reshape(
#       activation, 
#       shape=[batch_size, spatial_size, spatial_size, FLAGS.B, 1], 
#       name="activation")

#   assert pose.get_shape() == [batch_size, spatial_size, spatial_size, 
#                               FLAGS.B, 16]
#   assert activation.get_shape() == [batch_size, spatial_size, spatial_size,
#                                     FLAGS.B, 1]
#   logger.info('primary_caps pose shape: {}'.format(pose.get_shape()))
#   logger.info('primary_caps activation shape {}'
#               .format(activation.get_shape()))

#   tf.summary.histogram("activation", activation)

# Primary capsules are produced by two 1x1 convolutions: one for the poses 
# (FLAGS.B * 16 channels, no activation) and one for the activations (FLAGS.B 
# channels, sigmoid).
# The Keras class is spelled `Conv2D`; `tf.keras.layers.Conv2d` does not exist
# and raises AttributeError. The arguments that were passed as None 
# (activity_regularizer, kernel_constraint, bias_constraint) are the Keras 
# defaults and are therefore omitted.
# Both layers are only constructed here; they are not applied to a tensor yet.
pcaps_pos_conv = tf.keras.layers.Conv2D(
    filters=FLAGS.B * 16, 
    kernel_size=[1, 1], 
    strides=(1, 1), 
    padding='valid', 
    activation=None, 
    use_bias=True,
    kernel_initializer='glorot_uniform', 
    bias_initializer='zeros',
    kernel_regularizer=weights_regularizer, 
    bias_regularizer=weights_regularizer)
p_caps_activation_conv = tf.keras.layers.Conv2D(
    filters=FLAGS.B, 
    kernel_size=[1, 1], 
    strides=(1, 1), 
    padding='valid', 
    activation=tf.nn.sigmoid, 
    use_bias=True,
    kernel_initializer='glorot_uniform', 
    bias_initializer='zeros',
    kernel_regularizer=weights_regularizer, 
    bias_regularizer=weights_regularizer)
# NOTE(review): the shape annotations in the three sections below are the 
# original author's and are not checked by any code in this cell. The ones 
# for Conv Caps 2 repeat the Conv Caps 1 output shapes and look copy-pasted - 
# TODO confirm.

#----- Conv Caps 1 -----#
# activation_in: (64, 7, 7, 8, 1) 
# pose_in: (64, 7, 7, 16, 16) 
# activation_out: (64, 5, 5, 32, 1)
# pose_out: (64, 5, 5, 32, 16)
activation, pose = ConvCaps(kernel=3, 
                            stride=2, 
                            ncaps_out=FLAGS.C, 
                            name='conv_caps1', 
                            weights_regularizer=weights_regularizer)

#----- Conv Caps 2 -----#
# activation_in: (64, 7, 7, 8, 1) 
# pose_in: (64, 7, 7, 16, 1) 
# activation_out: (64, 5, 5, 32, 1)
# pose_out: (64, 5, 5, 32, 16)
# Rebinds `activation` and `pose`, replacing the Conv Caps 1 results.
activation, pose = ConvCaps(kernel=3, 
                            stride=1, 
                            ncaps_out=FLAGS.D, 
                            name='conv_caps2', 
                            weights_regularizer=weights_regularizer)

#----- Class Caps -----#
# activation_in: (64, 5, 5, 32, 1)
# pose_in: (64, 5, 5, 32, 16)
# activation_out: (64, 5)
# pose_out: (64, 5, 16) 
activation_out, pose_out = FcCaps(ncaps_out=num_classes,
                                  name='class_caps',
                                  weights_regularizer=weights_regularizer)

2020-04-07 11:12:43 INFO: input shape: (96, 96, 1)


AttributeError: module 'tensorflow' has no attribute 'variable_scope'