In [1]:
import os
import tensorflow as tf
from util import constants
from util.config_util import get_model_params, get_task_params, get_train_params
from tf2_models.trainer import Trainer
from absl import app
from absl import flags
import numpy as np
from util.models import MODELS
from util.tasks import TASKS
from notebook_utils import *
import tensorflow_datasets as tfds
from tfds_data.aff_nist import AffNist
%matplotlib inline
import pandas as pd
import seaborn as sns; sns.set()

from tqdm import tqdm
from distill.repsim_util import get_reps

[nltk_data] Downloading package punkt to /home/dehghani/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
    # grab input dimensions
    B = tf.shape(input_fmap)[0]
    H = tf.shape(input_fmap)[1]
    W = tf.shape(input_fmap)[2]

    # reshape theta to (B, 2, 3)
    theta = tf.reshape(theta, [B, 2, 3])

    # generate grids of same size or upsample/downsample if specified
    if out_dims:
        out_H = out_dims[0]
        out_W = out_dims[1]
        batch_grids = affine_grid_generator(out_H, out_W, theta)
    else:
        batch_grids = affine_grid_generator(H, W, theta)

    x_s = batch_grids[:, 0, :, :]
    y_s = batch_grids[:, 1, :, :]

    # sample input with grid to get output
    out_fmap = bilinear_sampler(input_fmap, x_s, y_s)

    return out_fmap


def get_pixel_value(img, x, y):
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    indices = tf.stack([b, y, x], 3)

    return tf.gather_nd(img, indices)


def affine_grid_generator(height, width, theta):
    num_batch = tf.shape(theta)[0]

    # create normalized 2D grid
    x = tf.linspace(-1.0, 1.0, width)
    y = tf.linspace(-1.0, 1.0, height)
    x_t, y_t = tf.meshgrid(x, y)

    # flatten
    x_t_flat = tf.reshape(x_t, [-1])
    y_t_flat = tf.reshape(y_t, [-1])

    # reshape to [x_t, y_t , 1] - (homogeneous form)
    ones = tf.ones_like(x_t_flat)
    sampling_grid = tf.stack([x_t_flat, y_t_flat, ones])

    # repeat grid num_batch times
    sampling_grid = tf.expand_dims(sampling_grid, axis=0)
    sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))

    # cast to float32 (required for matmul)
    theta = tf.cast(theta, 'float32')
    sampling_grid = tf.cast(sampling_grid, 'float32')

    # transform the sampling grid - batch multiply
    batch_grids = tf.matmul(theta, sampling_grid)
    # batch grid has shape (num_batch, 2, H*W)

    # reshape to (num_batch, H, W, 2)
    batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])

    return batch_grids


def bilinear_sampler(img, x, y):
    """
    Performs bilinear sampling of the input images according to the
    normalized coordinates provided by the sampling grid. Note that
    the sampling is done identically for each channel of the input.
    To test if the function works properly, output image should be
    identical to input image when theta is initialized to identity
    transform.
    Input
    -----
    - img: batch of images in (B, H, W, C) layout.
    - grid: x, y which is the output of affine_grid_generator.
    Returns
    -------
    - out: interpolated images according to grids. Same size as grid.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # rescale x and y to [0, W-1/H-1]
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    x = 0.5 * ((x + 1.0) * tf.cast(max_x-1, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y-1, 'float32'))

    # grab 4 nearest corner points for each (x_i, y_i)
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    # clip to range [0, H-1/W-1] to not violate img boundaries
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)

    # get pixel value at corner coords
    Ia = get_pixel_value(img, x0, y0)
    Ib = get_pixel_value(img, x0, y1)
    Ic = get_pixel_value(img, x1, y0)
    Id = get_pixel_value(img, x1, y1)

    # recast as float for delta calculation
    x0 = tf.cast(x0, 'float32')
    x1 = tf.cast(x1, 'float32')
    y0 = tf.cast(y0, 'float32')
    y1 = tf.cast(y1, 'float32')

    # calculate deltas
    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    # add dimension for addition
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # compute output
    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])

    return out

In [30]:

class STN(tf.keras.layers.Layer):
    def __init__(self, *inputs, **kwargs):
        super(STN, self).__init__(*inputs, **kwargs)
        
        self.regularizer = tf.keras.regularizers.l1_l2(l1=0.00,
                                                       l2=0.000000002)
        self.create_layer()
    
    def create_layer(self):
        # params
        n_fc = 6
        
        def init_bias(shape, dtype=None):
            # identity transform
            initial = np.array([[1., 0, 0], [0, 1., 0]])
            initial = initial.astype('float32').flatten()
            return initial
        self.flat = tf.keras.layers.Flatten()
        self.localisation_net = tf.keras.layers.Dense(
        n_fc, activation=None, use_bias=True, kernel_initializer='zeros',
        bias_initializer=init_bias)

    def call(self, inputs, training=None, **kwargs):
        
        x = self.flat(inputs)
        h_fc1 = self.localisation_net(x)
        # spatial transformer layer
        h_trans = spatial_transformer_network(inputs, h_fc1)
        
        return h_trans
    

In [31]:
class ResnetBlock(tf.keras.layers.Layer):
  def __init__(self, filters, kernel_size, activation='relu',*inputs, **kwargs):
    super(ResnetBlock, self).__init__(*inputs, **kwargs)
    self.filters = filters
    self.kernel_size = kernel_size
    self.activation = activation
    self.regularizer = tf.keras.regularizers.l1_l2(l1=0.00,
                                                   l2=0.000000002)

    self.create_layer()



  def create_layer(self):
    self.conv1 = tf.keras.layers.Conv2D(self.filters, self.kernel_size,
                                        activation=self.activation,
                                        padding='same',
                                        kernel_regularizer=self.regularizer)
    self.batch_norm1 = tf.keras.layers.BatchNormalization()
    self.conv2 = tf.keras.layers.Conv2D(self.filters, self.kernel_size,
                                 activation=None,
                                 padding='same',
                                 kernel_regularizer=self.regularizer)
    self.batch_norm2 = tf.keras.layers.BatchNormalization()
    self.add = tf.keras.layers.Add()
    self.activation = tf.keras.layers.Activation('relu')

  def call(self, inputs, training=None, **kwargs):
    outputs = self.conv1(inputs, training=training, **kwargs)
    outputs = self.batch_norm1(outputs,training=training, **kwargs)
    outputs = self.conv2(outputs, training=training, **kwargs)
    outputs = self.batch_norm2(outputs,training=training, **kwargs)
    outputs = self.add([outputs, inputs],training=training, **kwargs)
    outputs = self.activation(outputs, training=training, **kwargs)

    return outputs




In [32]:
class Resnet(tf.keras.Model):
  def __init__(self, hparams, scope='resnet', *inputs, **kwargs):
    if 'cl_token' in kwargs:
      del kwargs['cl_token']
    super(Resnet, self).__init__(name=scope, *inputs, **kwargs)
    self.scope = scope
    self.hparams = hparams
    self.model_name = '_'.join([self.scope,
                                'h-' + str(self.hparams.hidden_dim),
                                'rd-' + str(self.hparams.num_res_net_blocks),
                                'hdrop-' + str(self.hparams.hidden_dropout_rate),
                                'indrop-' + str(self.hparams.input_dropout_rate)])

    self.regularizer = tf.keras.regularizers.l1_l2(l1=0.00,
                                                   l2=0.000000002)
    self.create_layers()
    self.rep_index = 1
    self.rep_layer = -1


  def create_layers(self):
    self.stn1 = STN()
    self.activation = tf.keras.layers.Activation('relu')

    self.conv1 = tf.keras.layers.Conv2D(self.hparams.filters[0], self.hparams.kernel_size[0],
                                  activation=None,
                                  kernel_regularizer=self.regularizer)
    self.batch_norm2 = tf.keras.layers.BatchNormalization()
    self.conv2 = tf.keras.layers.Conv2D(self.hparams.filters[1], self.hparams.kernel_size[1],
                                  activation=None,
                                  kernel_regularizer=self.regularizer)
    self.batch_norm3 = tf.keras.layers.BatchNormalization()
    self.pool2 = tf.keras.layers.MaxPooling2D(self.hparams.pool_size)

    self.resblocks = []
    for i in range(self.hparams.num_res_net_blocks):
      self.resblocks.append(ResnetBlock(self.hparams.filters[2], self.hparams.kernel_size[2]))

    self.conv4 = tf.keras.layers.Conv2D(self.hparams.filters[3], self.hparams.kernel_size[3],
                                        activation=None)
    self.batch_norm4 = tf.keras.layers.BatchNormalization()
    self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
    self.dense = tf.keras.layers.Dense(self.hparams.hidden_dim, activation='relu')
    self.dropout = tf.keras.layers.Dropout(self.hparams.hidden_dropout_rate)
    self.project = tf.keras.layers.Dense(self.hparams.output_dim, activation=None)

  def call(self, inputs, padding_symbol=None, training=None, **kwargs):
    x = self.stn1(inputs, training=training, **kwargs)
    x = self.conv1(x, training=training, **kwargs)
    x = self.batch_norm2(x, training=training, **kwargs)
    x = self.activation(x)
    x = self.dropout(x, training=training, **kwargs)

    x = self.conv2(x, training=training, **kwargs)
    x = self.batch_norm3(x, training=training, **kwargs)
    x = self.activation(x)
    x = self.dropout(x, training=training, **kwargs)

    x = self.pool2(x, training=training, **kwargs)
    for i in range(self.hparams.num_res_net_blocks):
      x = self.resblocks[i](x, training=training, **kwargs)
      x = self.dropout(x, training=training, **kwargs)

    x = self.conv4(x, training=training, **kwargs)
    x = self.batch_norm4(x, training=training, **kwargs)
    x = self.activation(x)
    x = self.dropout(x, training=training, **kwargs)

    x = self.avgpool(x, training=training, **kwargs)
    x = self.dense(x, training=training, **kwargs)
    x = self.dropout(x, training=training, **kwargs)
    outputs = self.project(x, training=training, **kwargs)

    return outputs



In [33]:
chkpt_dir='../tf_ckpts'
task_name='affnist'
task = TASKS[task_name](get_task_params(), data_dir='../data')

In [34]:
config={'exp_name':'test',
    'model_config':'rsnt_mnist1',
    'task_name':'affnist',
    'model_name':'resnet',
    'chkpt_dir':'../tf_ckpts',
    'learning_rate': 0.001
    }

task = TASKS[config['task_name']](get_task_params(batch_size=16), data_dir='../data')

hparams = get_model_params(task, config['model_name'], config['model_config'])
print(hparams)


model config: rsnt_mnist1
{'hidden_dim': 512, 'pool_size': 3, 'filters': [32, 32, 32, 32], 'kernel_size': [(3, 3), (3, 3), (3, 3), (3, 3)], 'hidden_dropout_rate': 0.2, 'input_dropout_rate': 0.0, 'num_res_net_blocks': 2}
<util.model_configs.ResnetConfig object at 0x7f2b30c72190>


In [35]:
model = Resnet(hparams)

In [36]:
for x,y in task.train_dataset:
    print(x.shape, y.shape)
    break
out = model(inputs=x, training=True)
print(out.shape)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=task.get_loss_fn(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

model.summary()

(16, 40, 40, 1) (16,)
(16, 10)
Model: "resnet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
stn_4 (STN)                  multiple                  9606      
_________________________________________________________________
activation_9 (Activation)    multiple                  0         
_________________________________________________________________
conv2d_21 (Conv2D)           multiple                  320       
_________________________________________________________________
batch_normalization_21 (Batc multiple                  128       
_________________________________________________________________
conv2d_22 (Conv2D)           multiple                  9248      
_________________________________________________________________
batch_normalization_22 (Batc multiple                  128       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 

In [37]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=task.get_loss_fn(),
              metrics=task.metrics())

history = model.fit(task.train_dataset,
                  epochs=20,
                  steps_per_epoch=task.n_train_batches,
                  validation_steps=task.n_valid_batches,
                  validation_data=task.valid_dataset,
                  verbose=2
                  )

Train for 3125 steps, validate for 20000 steps
Epoch 1/20
3125/3125 - 137s - loss: 1.5411 - classification_loss: 1.5411 - sparse_categorical_accuracy: 0.4521 - val_loss: 2.2272 - val_classification_loss: 2.2272 - val_sparse_categorical_accuracy: 0.2747
Epoch 2/20
3125/3125 - 133s - loss: 1.6724 - classification_loss: 1.6724 - sparse_categorical_accuracy: 0.4110 - val_loss: 3.8626 - val_classification_loss: 3.8626 - val_sparse_categorical_accuracy: 0.1329
Epoch 3/20
3125/3125 - 134s - loss: 1.7754 - classification_loss: 1.7754 - sparse_categorical_accuracy: 0.3712 - val_loss: 3.3037 - val_classification_loss: 3.3037 - val_sparse_categorical_accuracy: 0.2101
Epoch 4/20
3125/3125 - 133s - loss: 1.8801 - classification_loss: 1.8801 - sparse_categorical_accuracy: 0.3296 - val_loss: 2.7884 - val_classification_loss: 2.7884 - val_sparse_categorical_accuracy: 0.1742
Epoch 5/20
3125/3125 - 135s - loss: 1.9907 - classification_loss: 1.9907 - sparse_categorical_accuracy: 0.2747 - val_loss: 2.8237

KeyboardInterrupt: 