In [1]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import Model, Sequential
from keras.layers import Dense, Input, Conv2D, Dropout, Layer, ReLU, MaxPool2D, Flatten
from keras.metrics import CategoricalAccuracy, CategoricalCrossentropy
from itertools import product
import shutil

In [2]:
from keras.preprocessing.image import ImageDataGenerator

# Tensor dtypes used by the data pipeline and the sampler.
dtype = tf.float32
intDtype = tf.int32

# Batch size and square image sizes: inputSize is the side of the loaded
# images, gridSize the side of the sampling grid fed to the classifier.
batchSize = 100
inputSize = 49
gridSize = 49

# Scale pixels to [0, 1] and hold out roughly 20% of the images for validation.
# NOTE(review): the original author reported that a plain 0.2 "does not work"
# for a 20% split; 2.00075/10 presumably makes the split land on exactly
# 48000 / 12000 images (both multiples of batchSize). TODO confirm the reason.
trainDataGen = ImageDataGenerator(
    rescale=1.0/255.0,
    dtype=dtype,
    validation_split=2.00075/10,
)

In [7]:
# Directory the dataset archive is extracted into and that both generators read from.
# NOTE(review): hard-coded absolute path (looks Colab-specific); consider making it configurable.
dirName = '/content/affineMNIST'

In [8]:
# Extract the dataset archive into dirName; the archive format is inferred from the .zip extension.
shutil.unpack_archive(filename='/content/affineMNIST.zip', extract_dir=dirName)

In [9]:
# Training and validation iterators share every option except `subset`;
# the split fraction itself is configured on trainDataGen.
trainGenerator, testGenerator = (
    trainDataGen.flow_from_directory(
        dirName,
        target_size=(inputSize, inputSize),
        color_mode='grayscale',
        batch_size=batchSize,
        shuffle=True,
        seed=42,
        subset=subset,
    )
    for subset in ('training', 'validation')
)

Found 48000 images belonging to 10 classes.
Found 12000 images belonging to 10 classes.


In [None]:
# Localisation network: regresses the six affine offsets P1..P6 consumed by
# `sampler` (which adds 1 to P1 and P5, so an all-zero output is the identity).
# NOTE(review): the final Dense layer uses the default initialiser, so the
# initial warp is random rather than near the identity; spatial-transformer
# implementations usually start this layer close to zero. Consider a
# small-scale initialiser here.
LocNet = Sequential(name='LocalisationNetwork')
LocNet.add(Input(shape=(inputSize, inputSize, 1)))
LocNet.add(Conv2D(7, (3, 3), strides=(2, 2), padding='same', activation='relu', bias_initializer='glorot_uniform'))
LocNet.add(Conv2D(7, (5, 5), strides=(2, 2), padding='same', activation='relu', bias_initializer='glorot_uniform'))
LocNet.add(MaxPool2D())
LocNet.add(Flatten())
LocNet.add(Dense(30, activation='relu'))
LocNet.add(Dense(6))
LocNet.summary()

In [None]:
# Classifier applied to the gridSize x gridSize single-channel images produced
# by the sampler; outputs a softmax over the 10 classes.
CNN = Sequential(name='CNN')
CNN.add(Input(shape=(gridSize, gridSize, 1)))
CNN.add(Conv2D(7, (7, 7), strides=(2, 2), padding='same', activation='relu', bias_initializer='glorot_uniform'))
CNN.add(MaxPool2D())
CNN.add(Flatten())
CNN.add(Dense(10, activation='softmax'))
CNN.summary()

In [12]:
# Categorical cross-entropy between target and predicted class distributions,
# returned per sample; trainStep/testStep reduce it to a batch mean.
networkLoss = keras.losses.categorical_crossentropy

In [13]:
adamLearningRate = 5e-4

# Two independent Adam instances with the same learning rate: one for the
# classifier (CNN) and one for the localisation network, as used in trainStep.
cnnOptimiser, stnOptimiser = (
    keras.optimizers.Adam(learning_rate=adamLearningRate) for _ in range(2)
)

In [14]:
def sampler(input, batchTheta):
  """Bilinearly resample `input` on an affine-transformed regular grid.

  This is the grid generator + sampler of a spatial transformer. A
  gridSize x gridSize lattice of points in pixel units, centred on (0, 0), is
  mapped through the per-image matrix

      [[1 + P1, P2,     P3],
       [P4,     1 + P5, P6]]

  so an all-zero `batchTheta` is the identity. The result is shifted so that
  (0, 0) lands on the centre of the input image and read from `input` with
  bilinear interpolation. Sample points that fall outside the image read as zero.

  Args:
    input: image batch, indexed as (batch, height, width, channels).
    batchTheta: (batch, 6) affine offsets [P1, ..., P6], e.g. LocNet's output.

  Returns:
    Tensor of shape (batch, gridSize, gridSize, channels) in `dtype`.
  """
  # Use the real batch size rather than the global batchSize: a generator's
  # last batch may be shorter.
  numImages = tf.shape(input)[0]
  inputHeight, inputWidth, channels = (int(d) for d in input.shape[1:])

  # Per-image affine matrices, (batch, 2, 3). The homogeneous row [0, 0, 1] is
  # implicit, so the projective divide is always by exactly 1 and is omitted.
  identity = tf.constant([1, 0, 0, 0, 1, 0], dtype=dtype)
  thetas = tf.reshape(tf.cast(batchTheta, dtype) + identity, [-1, 2, 3])

  # Target grid in homogeneous coordinates centred on 0: (batch, 3, gridSize**2).
  halfSpan = (gridSize - 1) / 2
  axisPoints = tf.cast(tf.linspace(-halfSpan, halfSpan, gridSize), dtype)
  xTarget, yTarget = tf.meshgrid(axisPoints, axisPoints)
  targetGrid = tf.stack([tf.reshape(xTarget, [-1]),
                         tf.reshape(yTarget, [-1]),
                         tf.ones([gridSize * gridSize], dtype=dtype)], axis=0)
  targetGrid = tf.tile(targetGrid[tf.newaxis], tf.stack([numImages, 1, 1]))

  # Source coordinates. The transform works in centred units, so shift by the
  # image centre to obtain pixel indices; without this shift the identity
  # transform would only see the top-left quadrant of the input.
  sourceCoords = tf.matmul(thetas, targetGrid)  # (batch, 2, gridSize**2)
  Xs = tf.reshape(sourceCoords[:, 0, :], [-1, gridSize, gridSize]) + (inputWidth - 1) / 2
  Ys = tf.reshape(sourceCoords[:, 1, :], [-1, gridSize, gridSize]) + (inputHeight - 1) / 2

  # Corner pixels. The second corner is floor + 1 rather than ceil. This gives
  # the same interpolated value, but it keeps d(output)/d(coordinate) non-zero
  # at integer positions such as the identity transform.
  XFloor, YFloor = tf.floor(Xs), tf.floor(Ys)
  x0, y0 = tf.cast(XFloor, intDtype), tf.cast(YFloor, intDtype)
  x1, y1 = x0 + 1, y0 + 1

  # One row per pixel across the whole batch, plus a trailing all-zero row
  # that out-of-range samples are redirected to.
  flatPixels = tf.reshape(input, [-1, channels])
  paddedPixels = tf.concat([flatPixels, tf.zeros([1, channels], dtype=flatPixels.dtype)], axis=0)
  paddingIndex = tf.cast(numImages, intDtype) * inputHeight * inputWidth
  imageIndex = tf.reshape(tf.range(numImages, dtype=intDtype), [-1, 1, 1])

  def corner(yIdx, xIdx):
    """Pixels at (yIdx, xIdx) for every sample point; zeros where outside the image."""
    # The validity mask and the flat index are built from the SAME (y, x)
    # pair, so no sample can wrap into a neighbouring row/image or index out
    # of bounds.
    inside = (xIdx >= 0) & (xIdx < inputWidth) & (yIdx >= 0) & (yIdx < inputHeight)
    flatIndex = (imageIndex * inputHeight + yIdx) * inputWidth + xIdx
    flatIndex = tf.where(inside, flatIndex, paddingIndex)
    return tf.cast(tf.gather(paddedPixels, flatIndex), dtype)

  # Bilinear weights; these carry the gradient back to batchTheta.
  xWeight = tf.expand_dims(Xs - XFloor, -1)
  yWeight = tf.expand_dims(Ys - YFloor, -1)

  return (corner(y0, x0) * (1 - xWeight) * (1 - yWeight)
          + corner(y0, x1) * xWeight * (1 - yWeight)
          + corner(y1, x0) * (1 - xWeight) * yWeight
          + corner(y1, x1) * xWeight * yWeight)


In [15]:
def trainStep(U, label):
  """Run one joint optimisation step of LocNet -> sampler -> CNN on a batch.

  Args:
    U: batch of input images as yielded by the data generator.
    label: matching target class distributions (one-hot rows).

  Returns:
    (loss, predictedLabel, V): mean batch loss, the CNN's softmax output and
    the warped images that were classified.
  """
  # Gradients are only taken w.r.t. network weights, never the pixels, so a
  # plain tensor replaces the per-step trainable tf.Variable.
  images = tf.convert_to_tensor(U)

  # A single tape records the full forward pass. The localisation network is
  # reached through the CNN and the sampler, so one tape serves both networks.
  with tf.GradientTape() as tape:
    params = LocNet(images)
    V = sampler(images, params)
    predictedLabel = CNN(V)
    loss = tf.reduce_mean(networkLoss(label, predictedLabel))

  cnnWeights = CNN.trainable_weights
  stnWeights = LocNet.trainable_weights
  gradients = tape.gradient(loss, cnnWeights + stnWeights)
  numCnnWeights = len(cnnWeights)

  # Each network keeps its own optimiser (and hence its own Adam state).
  cnnOptimiser.apply_gradients(zip(gradients[:numCnnWeights], cnnWeights))
  stnOptimiser.apply_gradients(zip(gradients[numCnnWeights:], stnWeights))

  return loss, predictedLabel, V

In [16]:
def testStep(U, label):
  """Evaluate LocNet -> sampler -> CNN on a batch without updating any weights.

  Args:
    U: batch of input images as yielded by the data generator.
    label: matching target class distributions (one-hot rows).

  Returns:
    (valLoss, predictedLabel): mean batch loss and the CNN's softmax output.
  """
  # No gradients are needed, so wrap the batch as a tensor instead of
  # allocating a tf.Variable per batch (same approach as trainStep).
  images = tf.convert_to_tensor(U)
  V = sampler(images, LocNet(images))

  predictedLabel = CNN(V)
  valLoss = tf.reduce_mean(networkLoss(label, predictedLabel))

  return valLoss, predictedLabel

In [None]:
# Train until the epoch accuracy exceeds 90% on two consecutive epochs.
epochAccuracy = CategoricalAccuracy()
epochLoss = CategoricalCrossentropy()
epochCount = 1
highAccuracy = False

with tf.device('/gpu:0'):
  while True:
    print(f'Epoch - {epochCount}')
    for step in range(len(trainGenerator)):
      genU, genLabel = trainGenerator[step]
      loss, predictedLabel, V = trainStep(genU, genLabel)

      # Both metrics take one-hot targets and predicted probabilities; feeding
      # them argmax index vectors made accuracy and loss meaningless.
      epochAccuracy.update_state(genLabel, predictedLabel)
      epochLoss.update_state(genLabel, predictedLabel)

      if step % 100 == 1:
        print(f'Step : {step}, loss : {loss}')

    # Indexing the generator directly never triggers its end-of-epoch
    # reshuffle, so without this every epoch would see the same batch order.
    trainGenerator.on_epoch_end()

    acc = epochAccuracy.result().numpy()
    print(f'\nEpoch Accuracy : {acc*100} %')
    # The metric already averages over samples; no further normalisation needed.
    epoLoss = epochLoss.result().numpy()
    print(f'Epoch Loss : {epoLoss}\n')
    epochCount += 1
    epochAccuracy.reset_state()
    epochLoss.reset_state()

    if acc > 0.90:
      if highAccuracy:
        break
      highAccuracy = True
    else:
      highAccuracy = False

In [None]:
# Evaluate the trained transformer + classifier on the held-out validation subset.
epochTestAccuracy = CategoricalAccuracy()
epochTestLoss = CategoricalCrossentropy()
with tf.device('/gpu:0'):
  for batch in range(len(testGenerator)):
    genU, genLabel = testGenerator[batch]
    valLoss, predictedLabel = testStep(genU, genLabel)

    # Metrics take one-hot targets and predicted probabilities, not argmax indices.
    epochTestAccuracy.update_state(genLabel, predictedLabel)
    epochTestLoss.update_state(genLabel, predictedLabel)

  print(f'Test Accuracy {epochTestAccuracy.result().numpy()*100}')
  # Already a per-sample mean over the whole validation set.
  print(f'Test Loss {epochTestLoss.result().numpy()}')