Reference notebook: https://github.com/dedhiaparth98/spatial-transformer-network/blob/main/Visualizing-STN-MNIST.ipynb

In [None]:
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from scipy.io.matlab import loadmat
import tensorflow as tf

# tf.compat.v1.enable_eager_execution()
# tf.config.run_functions_eagerly(True)

import numpy as np
import datetime
import os

import HeadCT_motion_correction_PAR.motion_simulator.transformation as transform
%load_ext tensorboard

In [None]:
# One-time setup (disabled): download the affNIST training/validation batches archive
# and extract it under /mnt/mount_zc_NAS/MNIST/. Uncomment these lines to run it.
# !wget -c https://www.cs.toronto.edu/~tijmen/affNIST/32x/transformed/training_and_validation_batches.zip
# import zipfile
# zfile = os.path.join('/mnt/mount_zc_NAS/MNIST/training_and_validation_batches.zip')
# with zipfile.ZipFile(zfile,"r") as zip_ref:
#     zip_ref.extractall("/mnt/mount_zc_NAS/MNIST/")
# # ! unzip training_and_validation_batches.zip

In [None]:
# Training hyper-parameters: mini-batch size for the tf.data pipeline and number of epochs for model.fit.
batch_size, epochs = 256, 100

3D (slice-wise) version of the spatial transformer network (STN)

In [None]:
# Load one affNIST batch (.mat) and turn every 40x40 digit into a pseudo-3D volume
# made of identical slices, so the slice-wise 3D STN can be exercised on it.
image_path = '/mnt/mount_zc_NAS/MNIST/training_and_validation_batches/1.mat'
temp = loadmat(image_path)

x = temp['affNISTdata']['image'][0][0].reshape(40, 40, -1)
print(x.shape)

# make pseudo-3D image: repeat every 2D image along a new slice axis (axis 2).
# np.repeat replaces the per-image Python loop and keeps the source dtype instead of
# materialising a float64 zeros buffer of the same size.
n_slices = 3
x_3d = np.repeat(x[:, :, np.newaxis, :], n_slices, axis=2)  # (40, 40, n_slices, N)
print(x_3d.shape)
x = np.moveaxis(x_3d, -1, 0)      # (N, 40, 40, n_slices)
x = np.expand_dims(x, axis=-1)    # (N, 40, 40, n_slices, 1)
print(x.shape)
# Scale to [0, 1] in float32 (the usual Keras compute dtype); this halves the memory
# of the largest array compared with float64.
x = x.astype(np.float32) / 255.0

y = temp['affNISTdata']['label_int'][0][0]
y = np.moveaxis(y, -1, 0)
print(y.shape)
y = y.astype(np.int32)

In [None]:
# Hold out 10% of the samples for testing (fixed seed), then keep only the first
# 10,000 training samples.
X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.1, random_state=42)
X_train, y_train = X_train[:10000], y_train[:10000]
print(X_train.shape, X_test.shape)

In [None]:
# Training pipeline: pair images with labels, shuffle with a 5000-element buffer,
# then group into mini-batches of `batch_size`.
mnist_train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(5000)
    .batch(batch_size)
)

In [None]:
class Localization(tf.keras.layers.Layer):
    """Localization network of the spatial transformer.

    Regresses the six parameters of a 2x3 affine matrix ``theta`` from the input
    volume using two Conv3D/MaxPool3D stages (kernels and pools leave the last
    spatial axis untouched) followed by two Dense layers. The final Dense layer is
    initialised with zero weights and an identity bias [1, 0, 0, 0, 1, 0], so an
    untrained network outputs the identity transform.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer arguments (name, dtype, ...) to the base class.
        super(Localization, self).__init__(**kwargs)
        self.pool1 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 1))
        self.conv1 = tf.keras.layers.Conv3D(20, [3, 3, 1], activation='relu')
        self.pool2 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 1))
        self.conv2 = tf.keras.layers.Conv3D(20, [3, 3, 1], activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(20, activation='relu')
        self.fc2 = tf.keras.layers.Dense(
            6,
            activation=None,
            bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]),
            kernel_initializer='zeros',
        )
        # Created once here; previously a fresh Reshape layer was built on every call.
        self.to_matrix = tf.keras.layers.Reshape((2, 3))

    def build(self, input_shape):
        print("Building Localization Network with input shape:", input_shape)
        super(Localization, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        # call() returns theta reshaped to (batch, 2, 3), not the flat 6-vector.
        return [None, 2, 3]

    def call(self, inputs):
        """Return ``theta`` of shape (batch, 2, 3) for the batch ``inputs``."""
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        theta = self.fc2(x)
        theta = self.to_matrix(theta)
        return theta

In [None]:
class BilinearInterpolation(tf.keras.layers.Layer):
    """Spatial-transformer sampler: warps a batch of 2D images with affine matrices.

    Called as ``layer([images, theta])``:
      * ``images`` is indexed as (batch, in_height, in_width, channel);
      * ``theta`` holds 2x3 affine matrices mapping the normalised output grid
        (coordinates in [-1, 1]) to normalised input coordinates.

    The output has shape (batch, height, width, channel). Pixels are bilinearly
    interpolated, and samples falling outside the input contribute zero. The grid
    convention is "corners aligned": -1 and 1 map to the centres of the first and
    last pixels, so an identity ``theta`` with matching sizes reproduces the input.
    """

    def __init__(self, height=40, width=40, **kwargs):
        super(BilinearInterpolation, self).__init__(**kwargs)
        self.height = height
        self.width = width

    def compute_output_shape(self, input_shape):
        # input_shape is [images_shape, theta_shape]; the channel count is preserved.
        images_shape = input_shape[0]
        return [None, self.height, self.width, images_shape[-1]]

    def get_config(self):
        # Merge with the base Layer config (name, dtype, trainable, ...).
        config = super(BilinearInterpolation, self).get_config()
        config.update({
            'height': self.height,
            'width': self.width,
        })
        return config

    def build(self, input_shape):
        print("Building Bilinear Interpolation Layer with input shape:", input_shape)
        super(BilinearInterpolation, self).build(input_shape)

    def advance_indexing(self, inputs, x, y):
        """Gather ``inputs[b, y[b, i, j], x[b, i, j], :]`` for every batch index b.

        NumPy-style advanced indexing is not available in TensorFlow, so the batch
        index is tiled over the output grid and stacked with (y, x) for gather_nd.

        Args:
            inputs: images indexed as (batch, height, width, channel).
            x, y: int32 column/row indices of shape (batch, self.height, self.width),
                already within the valid range of ``inputs``.

        Returns:
            Tensor of shape (batch, self.height, self.width, channel).
        """
        batch_size = tf.shape(inputs)[0]
        batch_idx = tf.reshape(tf.range(0, batch_size), (batch_size, 1, 1))
        b = tf.tile(batch_idx, (1, self.height, self.width))  # (batch, height, width)
        indices = tf.stack([b, y, x], 3)  # (batch, height, width, 3)
        return tf.gather_nd(inputs, indices)

    def call(self, inputs):
        images, theta = inputs
        homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0])
        return self.interpolate(images, homogenous_coordinates, theta)

    def grid_generator(self, batch):
        """Return the normalised output grid in homogeneous form, shape (batch, 3, height*width)."""
        x = tf.linspace(-1.0, 1.0, self.width)
        y = tf.linspace(-1.0, 1.0, self.height)

        xx, yy = tf.meshgrid(x, y)  # each (height, width)
        xx = tf.reshape(xx, (-1,))
        yy = tf.reshape(yy, (-1,))
        homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)])  # (3, height*width)
        homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0)  # (1, 3, height*width)
        homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1])  # (batch, 3, height*width)
        return tf.cast(homogenous_coordinates, dtype=tf.float32)

    def _gather_corner(self, images, x_corner, y_corner, in_height, in_width):
        """Sample ``images`` at integer-valued float corners; out-of-range corners give zero."""
        in_w = tf.cast(in_width, tf.float32)
        in_h = tf.cast(in_height, tf.float32)
        inside = ((x_corner >= 0.0) & (x_corner <= in_w - 1.0)
                  & (y_corner >= 0.0) & (y_corner <= in_h - 1.0))
        # Clip in float before casting so extreme coordinates cannot overflow int32.
        x_idx = tf.cast(tf.clip_by_value(x_corner, 0.0, in_w - 1.0), tf.int32)
        y_idx = tf.cast(tf.clip_by_value(y_corner, 0.0, in_h - 1.0), tf.int32)
        values = self.advance_indexing(images, x_idx, y_idx)
        mask = tf.expand_dims(tf.cast(inside, values.dtype), axis=3)
        return values * mask

    def interpolate(self, images, homogenous_coordinates, theta):
        """Warp ``images`` by ``theta`` using bilinear interpolation."""
        image_shape = tf.shape(images)
        in_height, in_width = image_shape[1], image_shape[2]

        with tf.name_scope("Transformation"):
            # theta (batch, 2, 3) or (2, 3) @ grid (batch, 3, H*W) -> (batch, 2, H*W)
            transformed = tf.matmul(theta, homogenous_coordinates)
            transformed = tf.transpose(transformed, perm=[0, 2, 1])
            transformed = tf.reshape(transformed, [-1, self.height, self.width, 2])  # last axis: (x, y)

            x_transformed = transformed[:, :, :, 0]
            y_transformed = transformed[:, :, :, 1]

            # Map [-1, 1] to pixel coordinates [0, size - 1]: the exact inverse of the
            # linspace(-1, 1, size) grid. Scaling by size (not size - 1) stretched the
            # sampling grid by size / (size - 1).
            x = (x_transformed + 1.0) * (tf.cast(in_width, tf.float32) - 1.0) * 0.5
            y = (y_transformed + 1.0) * (tf.cast(in_height, tf.float32) - 1.0) * 0.5

        with tf.name_scope("Corners"):
            x0 = tf.math.floor(x)
            y0 = tf.math.floor(y)
            x1 = x0 + 1.0
            y1 = y0 + 1.0

        with tf.name_scope("Interpolation"):  # bilinear weights
            # Computed from the unclipped corners so they always sum to one. Clipping
            # both corners onto the same border pixel used to zero all four weights
            # on the last row/column.
            wa = tf.expand_dims((x1 - x) * (y1 - y), axis=3)
            wb = tf.expand_dims((x1 - x) * (y - y0), axis=3)
            wc = tf.expand_dims((x - x0) * (y1 - y), axis=3)
            wd = tf.expand_dims((x - x0) * (y - y0), axis=3)

        with tf.name_scope("AdvanceIndexing"):
            Ia = self._gather_corner(images, x0, y0, in_height, in_width)
            Ib = self._gather_corner(images, x0, y1, in_height, in_width)
            Ic = self._gather_corner(images, x1, y0, in_height, in_width)
            Id = self._gather_corner(images, x1, y1, in_height, in_width)

        return tf.math.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

In [None]:
# Toy sanity check: a batch of 2 volumes (40x40 pixels, 3 slices, 1 channel) holding a
# bright 10x10 square, rotated with the project's affine-transform helpers.
image = np.zeros((2, 40, 40, 3, 1))
image[:, 15:25, 15:25, :, :] = 1

t = [0, 0]            # translation: none
r = 20 / 180 * np.pi  # define rotation: 20 degrees, converted to radians

plane_shape = image[0, :, :, 0, 0].shape
translation, rotation, scale, transformation_matrix = transform.generate_transform_matrix(
    t, r, [1, 1], plane_shape)
transformation_matrix = transform.transform_full_matrix_offset_center(
    transformation_matrix, plane_shape)
transformation_matrix = transformation_matrix[:2]
img_t = transform.apply_affine_transform(image[0, :, :, 0, 0], transformation_matrix, 3)

plt.figure(figsize=(8, 4))
for panel, shown in enumerate((image[0, :, :, 0, 0], img_t), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(shown, 'gray')
print(transformation_matrix)

In [None]:
# Warp the toy volume slice by slice with the STN sampler using the same rotation.
bi = BilinearInterpolation(height=40, width=40)
# The sampler works in normalised [-1, 1] coordinates, so the pixel translation must be
# rescaled (pixels / 40 * 2). No centre offset is needed because the sampler's grid
# is already centred on the image.
tt = t
_, _, _, transformation_matrix_n = transform.generate_transform_matrix(
    [ttt / 40 * 2 for ttt in tt], r, [1, 1], image[0, :, :, 0, 0].shape)
transformation_matrix_n = transformation_matrix_n[0:2, :]
print(transformation_matrix_n)

# Move the slice axis to the front: (slice_num, batch_size, height, width, channel_num).
# np.moveaxis is NumPy's recommended replacement for the legacy np.rollaxis.
image_r = np.moveaxis(image, -2, 0)
print(image_r.shape)

# Warp every slice with the same 2x3 matrix, then move the slice axis back to the
# second-to-last position: (batch_size, height, width, slice_num, channel_num).
array = np.stack([bi([ii, transformation_matrix_n]) for ii in image_r], axis=0)
result = np.moveaxis(array, 0, -2)
print(result.shape)

In [None]:
# 2x2 comparison. Top row: input volume 0, slices 0 and 1. Bottom row: warped batch
# item 0 / slice 0 and warped batch item 1 / slice 1.
# NOTE(review): the last panel uses batch index 1 while the others use 0; confirm intent.
panels = [
    image[0, :, :, 0, 0],
    image[0, :, :, 1, 0],
    result[0, :, :, 0, 0],
    result[1, :, :, 1, 0],
]
for position, panel in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(panel, cmap='gray')

In [None]:
def get_model(input_shape):
    """Build the slice-wise 3D STN classifier.

    A Localization network predicts a 2x3 affine ``theta`` from the whole volume.
    The theta[0, 1] entry is forced to zero, and the resulting matrix is applied to
    every slice with BilinearInterpolation. A small Conv3D/Dense classifier then
    scores 10 classes.

    Args:
        input_shape: (height, width, slices, channels), e.g. (40, 40, 3, 1).

    Returns:
        A ``tf.keras.models.Model`` mapping the input volume to 10 softmax probabilities.
    """
    image = tf.keras.layers.Input(shape=input_shape)
    theta = Localization()(image)  # (batch, 2, 3)

    def entry(row, col):
        # One theta entry reshaped to (batch, 1, 1) so rows can be re-concatenated.
        return tf.reshape(theta[..., row, col], [-1, 1, 1])

    t1 = entry(0, 0)
    # theta[0, 1] is forced to zero (its regression was commented out originally).
    t2 = tf.zeros_like(t1)
    t3 = entry(0, 2)
    # Bug fix: this previously reshaped t2 (the zero tensor) and discarded theta[1, 0].
    t4 = entry(1, 0)
    t5 = entry(1, 1)
    t6 = entry(1, 2)
    first_row = tf.concat([t1, t2, t3], axis=2)
    second_row = tf.concat([t4, t5, t6], axis=2)
    theta_c = tf.concat([first_row, second_row], axis=1)  # (batch, 2, 3)
    print(theta_c.shape)

    # do STN for each slice: move the slice axis to the front, warp each
    # (batch, height, width, channel) slice with theta_c, then restore the layout.
    image_r = tf.transpose(image, [3, 0, 1, 2, 4])
    array = [BilinearInterpolation(height=input_shape[0], width=input_shape[1])([ii, theta_c])
             for ii in image_r]
    array = tf.stack(array, axis=0)
    x = tf.transpose(array, [1, 2, 3, 0, 4])  # (batch, height, width, slices, channel)

    x = tf.keras.layers.Conv3D(64, [9, 9, 1], activation='relu', name='first_conv3d')(x)
    x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 1))(x)
    x = tf.keras.layers.Conv3D(64, [7, 7, 1], activation='relu')(x)
    x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 1))(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    x = tf.keras.layers.Dense(10, activation='softmax')(x)

    return tf.keras.models.Model(inputs=image, outputs=x)

In [None]:
# Input volume: 40x40 pixels, 3 slices, 1 channel.
model = get_model(input_shape=(40, 40, 3, 1))

In [None]:
# Print the model architecture (layers, output shapes, parameter counts).
model.summary()

In [None]:
# get_model ends with Dense(10, activation='softmax'), so the outputs are already
# probabilities: the loss must use from_logits=False. With from_logits=True the loss
# would apply a second softmax and flatten the gradients.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy']
)

In [None]:
def schedular(epoch, lr):
    """Step-decay schedule for ``LearningRateScheduler``.

    Divides ``lr`` by 10 at every positive epoch that is a multiple of 20 and
    prints a notice. Any other epoch (including epoch 0) keeps the incoming rate.
    """
    if epoch <= 0 or epoch % 20 != 0:
        return lr
    print("Learning Rate Updated")
    return lr / 10

In [None]:
# TensorBoard writes to a per-run, timestamped directory; the LR callback applies `schedular`.
log_dir = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
learning_rate_callback = tf.keras.callbacks.LearningRateScheduler(schedular)

In [None]:
# Train on the shuffled, batched dataset with TensorBoard logging and LR step decay.
history = model.fit(
    mnist_train_ds,
    epochs=epochs,
    callbacks=[tensorboard_callback, learning_rate_callback],
)

In [None]:
# Save the trained weights in TensorFlow checkpoint format with prefix ./model/weights.
model.save_weights('./model/weights', save_format='tf')

In [None]:
# Inspect the layers that the next cell taps into.
# NOTE(review): 'tf.concat_18' looks like an auto-generated op-layer name whose numeric
# suffix depends on how many tf.concat calls were traced in this session (e.g. after
# re-running get_model); confirm it is the theta_c concat via model.summary().
print(model.layers[1], model.get_layer('first_conv3d'), model.get_layer('tf.concat_18'))

In [None]:
# Auxiliary model exposing intermediate tensors of `model`: the output of
# model.layers[1] (presumably the Localization layer, i.e. raw theta), the layer
# 'tf.concat_18' (presumably the reassembled theta_c) and 'tf.compat.v1.transpose_9'
# (presumably the warped volume).
# NOTE(review): these op-layer names are auto-generated and session-dependent; verify
# them with model.summary() before relying on this cell.
# stn = tf.keras.models.Model(inputs=model.inputs, outputs=[model.layers[1].output, model.layers[2].output])
stn = tf.keras.models.Model(inputs=model.inputs, outputs=[model.layers[1].output, model.get_layer('tf.concat_18').output, model.get_layer('tf.compat.v1.transpose_9').output])

In [None]:
# Run the auxiliary model on the test set. Despite its name, `prediction` is the third
# `stn` output (the tensor from 'tf.compat.v1.transpose_9'). It is plotted as an image
# later, so it holds image tensors, not class scores.
theta_original, theta_changed, prediction = stn.predict(X_test)
print(theta_original.shape, theta_original[0,...], theta_changed[0,...])
print(prediction.shape)

In [None]:
# Show one test digit next to its STN-warped version. The left title is the true label;
# the right title is the model's predicted class.
index = 525
# Slicing keeps the full (1, height, width, slices, channel) input shape the model was
# built with; the previous X_test[index, ..., 0] dropped the channel axis.
sample = X_test[index:index + 1]

plt.subplot(1, 2, 1)
plt.title(y_test[index])
plt.imshow(X_test[index, :, :, 0, 0], cmap='gray')

plt.subplot(1, 2, 2)
plt.title(np.argmax(model.predict(sample)))
plt.imshow(prediction[index, :, :, 0, 0], cmap='gray')

print("\n", theta_original[index, ...])