Spatial transformer: (tensorflow#57)
* Modified the way the output size is specified.
* Added support for batches of inputs.
psycharo authored and martinwicke committed May 12, 2016
1 parent 8332400 commit bf60abf
Showing 1 changed file with 36 additions and 18 deletions.
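In call-site terms, the first change replaces the old downsample_factor argument with an explicit (out_height, out_width) pair. A rough before/after sketch, where U and theta stand in for the input tensor and localisation-network output described in the docstring below, and the 100x200 / factor-2 numbers are taken from the removed docstring example:

    # before this commit: output size derived from the input
    # (100x200 input with downsample_factor=2 -> 50x100 output)
    transformer(U, theta, downsample_factor=2)

    # after this commit: output size stated explicitly as (height, width)
    transformer(U, theta, out_size=(50, 100))
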
transformer/spatial_transformer.py: 36 additions & 18 deletions
@@ -14,7 +14,7 @@
# ==============================================================================
import tensorflow as tf

-def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
+def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
"""Spatial Transformer Layer
Implements a spatial transformer layer as described in [1]_.
@@ -28,14 +28,9 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
theta: float
The output of the
localisation network should be [num_batch, 6].
-    downsample_factor : float
-        A value of 1 will keep the original size of the image
-        Values larger than 1 will downsample the image.
-        Values below 1 will upsample the image
-        example image: height = 100, width = 200
-        downsample_factor = 2
-        output image will then be 50, 100
+    out_size: tuple of two floats
+        The size of the output of the network
References
----------
.. [1] Spatial Transformer Networks
@@ -61,7 +56,7 @@ def _repeat(x, n_repeats):
x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
return tf.reshape(x,[-1])

-    def _interpolate(im, x, y, downsample_factor):
+    def _interpolate(im, x, y, out_size):
with tf.variable_scope('_interpolate'):
# constants
num_batch = tf.shape(im)[0]
@@ -73,8 +68,8 @@ def _interpolate(im, x, y, downsample_factor):
y = tf.cast(y, 'float32')
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
-            out_height = tf.cast(height_f // downsample_factor, 'int32')
-            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            out_height = out_size[0]
+            out_width = out_size[1]
zero = tf.zeros([], dtype='int32')
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
@@ -142,7 +137,7 @@ def _meshgrid(height, width):
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
return grid

-    def _transform(theta, input_dim, downsample_factor):
+    def _transform(theta, input_dim, out_size):
with tf.variable_scope('_transform'):
num_batch = tf.shape(input_dim)[0]
height = tf.shape(input_dim)[1]
@@ -154,8 +149,8 @@ def _transform(theta, input_dim, downsample_factor):
# grid of (x_t, y_t, 1), eq (1) in ref [1]
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
-            out_height = tf.cast(height_f // downsample_factor, 'int32')
-            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            out_height = out_size[0]
+            out_width = out_size[1]
grid = _meshgrid(out_height, out_width)
grid = tf.expand_dims(grid,0)
grid = tf.reshape(grid,[-1])
@@ -171,11 +166,34 @@ def _transform(theta, input_dim, downsample_factor):

input_transformed = _interpolate(
input_dim, x_s_flat, y_s_flat,
-                downsample_factor)
+                out_size)

output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
return output

with tf.variable_scope(name):
-        output = _transform(theta, U, downsample_factor)
-        return output
+        output = _transform(theta, U, out_size)
+        return output

+def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
+    """Batch Spatial Transformer Layer
+    Parameters
+    ----------
+    U : float
+        tensor of inputs [num_batch,height,width,num_channels]
+    thetas : float
+        a set of transformations for each input [num_batch,num_transforms,6]
+    out_size : int
+        the size of the output [out_height,out_width]
+    Returns: float
+        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
+    """
+    with tf.variable_scope(name):
+        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
+        indices = [[i]*num_transforms for i in xrange(num_batch)]
+        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
+        return transformer(input_repeated, thetas, out_size)
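
A minimal usage sketch of the two entry points after this commit (not part of the commit itself). It assumes the file is importable as spatial_transformer, sticks to the TF 0.x-era graph API that the rest of the module uses (tf.pack, Python 2 xrange), and the shapes and identity transforms are purely illustrative:

    import tensorflow as tf
    from spatial_transformer import transformer, batch_transformer  # assumed module name

    num_batch, height, width, channels = 3, 100, 200, 1
    U = tf.placeholder(tf.float32, [num_batch, height, width, channels])

    # one identity affine transform per input image: theta is [num_batch, 6]
    identity = tf.constant([[1., 0., 0., 0., 1., 0.]])
    theta = tf.tile(identity, [num_batch, 1])

    # the output size is now passed explicitly; (50, 100) matches what
    # downsample_factor=2 used to produce for a 100x200 input
    h_trans = transformer(U, theta, out_size=(50, 100))

    # batch_transformer applies several transforms to each input:
    # thetas is [num_batch, num_transforms, 6] and the result is
    # [num_batch * num_transforms, out_height, out_width, channels]
    num_transforms = 4
    thetas = tf.tile(tf.reshape(identity, [1, 1, 6]), [num_batch, num_transforms, 1])
    h_batch = batch_transformer(U, thetas, (50, 100))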
