Commit

Fix bug in transformer model: Changed the examples to specify the output size (tensorflow#155)

* Changed the examples to specify the output size instead of the downsample_factor.

This is required by PR tensorflow#57

* Address flake8 errors.

* Update readme and parameter descriptions.
Evan Lezar authored and martinwicke committed May 31, 2016
1 parent 76f567d commit eec7938
Showing 4 changed files with 89 additions and 86 deletions.
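The heart of the change, in brief: transformer() no longer derives its output resolution from the input via downsample_factor; callers pass an explicit (height, width) tuple. A before/after sketch of the call site, assembled from the cluttered_mnist.py diff below:

    # Old API (removed by PR tensorflow#57):
    #     h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
    # New API: state the output resolution directly as (height, width).
    out_size = (40, 40)
    h_trans = transformer(x_tensor, h_fc_loc2, out_size)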
9 changes: 2 additions & 7 deletions transformer/README.md
@@ -28,13 +28,8 @@ transformer(U, theta, downsample_factor=1)
 theta: float
     The output of the
     localisation network should be [num_batch, 6].
-downsample_factor : float
-    A value of 1 will keep the original size of the image
-    Values larger than 1 will downsample the image.
-    Values below 1 will upsample the image
-    example image: height = 100, width = 200
-    downsample_factor = 2
-    output image will then be 50, 100
+out_size: tuple of two ints
+    The size of the output of the network
 
 #### Notes
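Read together with the parameter list above, a minimal call under the new API looks like this. A sketch only: the identity initialisation follows the Notes section of spatial_transformer.py, and num_batch is a hypothetical batch size.

    import numpy as np
    import tensorflow as tf
    from spatial_transformer import transformer

    num_batch = 3  # hypothetical
    U = tf.placeholder(tf.float32, [num_batch, 40, 40, 1])
    # theta must be [num_batch, 6]; each row here is the flattened identity
    # transform, standing in for the output of a localisation network
    identity = np.array([[1., 0., 0.], [0., 1., 0.]], dtype='float32').flatten()
    theta = tf.zeros([num_batch, 6]) + tf.Variable(initial_value=identity)
    h_trans = transformer(U, theta, out_size=(40, 40))  # -> [num_batch, 40, 40, 1]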
53 changes: 27 additions & 26 deletions transformer/cluttered_mnist.py
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
+# =============================================================================
 import tensorflow as tf
 from spatial_transformer import transformer
-from scipy import ndimage
 import numpy as np
-import matplotlib.pyplot as plt
-from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot
+from tf_utils import weight_variable, bias_variable, dense_to_one_hot
 
 # %% Load data
 mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
@@ -37,7 +35,7 @@
 # %% Graph representation of our network
 
 # %% Placeholders for 40x40 resolution
-x = tf.placeholder(tf.float32, [None, 1600])
+x = tf.placeholder(tf.float32, [None, 1600])
 y = tf.placeholder(tf.float32, [None, 10])
 
 # %% Since x is currently [batch, height*width], we need to reshape to a
@@ -48,13 +46,15 @@
 # dimension should not change size.
 x_tensor = tf.reshape(x, [-1, 40, 40, 1])
 
-# %% We'll setup the two-layer localisation network to figure out the parameters for an affine transformation of the input
+# %% We'll setup the two-layer localisation network to figure out the
+# %% parameters for an affine transformation of the input
 # %% Create variables for fully connected layer
 W_fc_loc1 = weight_variable([1600, 20])
 b_fc_loc1 = bias_variable([20])
 
 W_fc_loc2 = weight_variable([20, 6])
-initial = np.array([[1.,0, 0],[0,1.,0]]) # Use identity transformation as starting point
+# Use identity transformation as starting point
+initial = np.array([[1., 0, 0], [0, 1., 0]])
 initial = initial.astype('float32')
 initial = initial.flatten()
 b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
@@ -67,8 +67,10 @@
 # %% Second layer
 h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
 
-# %% We'll create a spatial transformer module to identify discriminative patches
-h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
+# %% We'll create a spatial transformer module to identify discriminative
+# %% patches
+out_size = (40, 40)
+h_trans = transformer(x_tensor, h_fc_loc2, out_size)
 
 # %% We'll setup the first convolutional layer
 # Weight matrix is [height x width x input_channels x output_channels]
@@ -140,33 +142,32 @@
 n_epochs = 500
 train_size = 10000
 
-indices = np.linspace(0,10000 - 1,iter_per_epoch)
+indices = np.linspace(0, 10000 - 1, iter_per_epoch)
 indices = indices.astype('int')
 
 for epoch_i in range(n_epochs):
     for iter_i in range(iter_per_epoch - 1):
-        batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
+        batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
         batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
 
         if iter_i % 10 == 0:
             loss = sess.run(cross_entropy,
-                            feed_dict={
-                                x: batch_xs,
-                                y: batch_ys,
-                                keep_prob: 1.0
-                            })
+                            feed_dict={
+                                x: batch_xs,
+                                y: batch_ys,
+                                keep_prob: 1.0
+                            })
             print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
 
         sess.run(optimizer, feed_dict={
            x: batch_xs, y: batch_ys, keep_prob: 0.8})
 
-
-    print('Accuracy: ' + str(sess.run(accuracy,
-                                      feed_dict={
-                                          x: X_valid,
-                                          y: Y_valid,
-                                          keep_prob: 1.0
-                                      })))
-    #theta = sess.run(h_fc_loc2, feed_dict={
+    print('Accuracy (%d): ' % epoch_i + str(sess.run(accuracy,
+                                                     feed_dict={
+                                                         x: X_valid,
+                                                         y: Y_valid,
+                                                         keep_prob: 1.0
+                                                     })))
+    # theta = sess.run(h_fc_loc2, feed_dict={
     #     x: batch_xs, keep_prob: 1.0})
-    #print(theta[0])
+    # print(theta[0])
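One orientation note for this file: the localisation network upstream of the transformer maps each flattened 40x40 input (1600 values) through a 20-unit hidden layer to the 6 affine parameters the transformer consumes. A NumPy shape check (an illustration only; biases, dropout and the exact activations are as in the file, not here):

    import numpy as np
    batch_xs = np.zeros((2, 1600), dtype='float32')   # two flattened 40x40 images
    W_fc_loc1 = np.zeros((1600, 20), dtype='float32')
    W_fc_loc2 = np.zeros((20, 6), dtype='float32')
    h_fc_loc1 = np.tanh(batch_xs.dot(W_fc_loc1))      # hidden localisation layer
    theta = np.tanh(h_fc_loc1.dot(W_fc_loc2))         # shape (2, 6), as transformer expects
    print(theta.shape)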
19 changes: 11 additions & 8 deletions transformer/example.py
@@ -17,42 +17,45 @@
 from scipy import ndimage
 import numpy as np
 import matplotlib.pyplot as plt
-from tf_utils import conv2d, linear, weight_variable, bias_variable
 
 # %% Create a batch of three images (1600 x 1200)
-# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
+# %% Image retrieved from:
+# %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
 im = ndimage.imread('cat.jpg')
 im = im / 255.
 im = im.reshape(1, 1200, 1600, 3)
 im = im.astype('float32')
 
+# %% Let the output size of the transformer be half the image size.
+out_size = (600, 800)
+
 # %% Simulate batch
 batch = np.append(im, im, axis=0)
 batch = np.append(batch, im, axis=0)
 num_batch = 3
 
 x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
-x = tf.cast(batch,'float32')
+x = tf.cast(batch, 'float32')
 
 # %% Create localisation network and convolutional layer
 with tf.variable_scope('spatial_transformer_0'):
 
     # %% Create a fully-connected layer with 6 output nodes
-    n_fc = 6
+    n_fc = 6
     W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
 
     # %% Zoom into the image
-    initial = np.array([[0.5,0, 0],[0,0.5,0]])
+    initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
     initial = initial.astype('float32')
     initial = initial.flatten()
 
     b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
-    h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1
-    h_trans = transformer(x, h_fc1, downsample_factor=2)
+    h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
+    h_trans = transformer(x, h_fc1, out_size)
 
 # %% Run session
 sess = tf.Session()
 sess.run(tf.initialize_all_variables())
 y = sess.run(h_trans, feed_dict={x: batch})
 
-# plt.imshow(y[0])
+# plt.imshow(y[0])
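A note on the zoom in this example: _meshgrid generates output coordinates in [-1, 1], and theta maps each output coordinate to a source coordinate, so the scale-0.5 transform samples only the central half of the cat image (a 2x zoom in), while out_size = (600, 800) halves the resolution. A NumPy sketch of the coordinate mapping:

    import numpy as np
    theta = np.array([[0.5, 0., 0.], [0., 0.5, 0.]])
    corner = np.array([1., 1., 1.])   # homogeneous output coordinate (x_t, y_t, 1)
    print(theta.dot(corner))          # [ 0.5  0.5]: the output corner reads from a
                                      # source point only halfway to the image corner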
94 changes: 49 additions & 45 deletions transformer/spatial_transformer.py
@@ -14,47 +14,49 @@
 # ==============================================================================
 import tensorflow as tf
 
+
 def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
     """Spatial Transformer Layer
     Implements a spatial transformer layer as described in [1]_.
     Based on [2]_ and edited by David Dao for Tensorflow.
     Parameters
     ----------
-    U : float
+    U : float
         The output of a convolutional net should have the
-        shape [num_batch, height, width, num_channels].
-    theta: float
+        shape [num_batch, height, width, num_channels].
+    theta: float
         The output of the
         localisation network should be [num_batch, 6].
-    out_size: tuple of two floats
-        The size of the output of the network
+    out_size: tuple of two ints
+        The size of the output of the network (height, width)
     References
     ----------
     .. [1] Spatial Transformer Networks
            Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
           Submitted on 5 Jun 2015
     .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
     Notes
     -----
     To initialize the network to the identity transform init
     ``theta`` to :
         identity = np.array([[1., 0., 0.],
-                            [0., 1., 0.]])
+                             [0., 1., 0.]])
         identity = identity.flatten()
         theta = tf.Variable(initial_value=identity)
     """
 
     def _repeat(x, n_repeats):
         with tf.variable_scope('_repeat'):
-            rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.pack([n_repeats,])),1),[1,0])
+            rep = tf.transpose(
+                tf.expand_dims(tf.ones(shape=tf.pack([n_repeats, ])), 1), [1, 0])
             rep = tf.cast(rep, 'int32')
-            x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
-            return tf.reshape(x,[-1])
+            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
+            return tf.reshape(x, [-1])
 
     def _interpolate(im, x, y, out_size):
         with tf.variable_scope('_interpolate'):
@@ -69,13 +71,13 @@ def _interpolate(im, x, y, out_size):
             height_f = tf.cast(height, 'float32')
             width_f = tf.cast(width, 'float32')
             out_height = out_size[0]
-            out_width = out_size[1]
+            out_width = out_size[1]
             zero = tf.zeros([], dtype='int32')
             max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
             max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
 
             # scale indices from [-1, 1] to [0, width/height]
-            x = (x + 1.0)*(width_f) / 2.0
+            x = (x + 1.0)*(width_f) / 2.0
             y = (y + 1.0)*(height_f) / 2.0
 
             # do sampling
@@ -98,8 +100,9 @@ def _interpolate(im, x, y, out_size):
             idx_c = base_y0 + x1
             idx_d = base_y1 + x1
 
-            # use indices to lookup pixels in the flat image and restore channels dim
-            im_flat = tf.reshape(im,tf.pack([-1, channels]))
+            # use indices to lookup pixels in the flat image and restore
+            # channels dim
+            im_flat = tf.reshape(im, tf.pack([-1, channels]))
             im_flat = tf.cast(im_flat, 'float32')
             Ia = tf.gather(im_flat, idx_a)
             Ib = tf.gather(im_flat, idx_b)
@@ -111,13 +114,13 @@ def _interpolate(im, x, y, out_size):
             x1_f = tf.cast(x1, 'float32')
             y0_f = tf.cast(y0, 'float32')
             y1_f = tf.cast(y1, 'float32')
-            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)),1)
-            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)),1)
-            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)),1)
-            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)),1)
+            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
+            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
+            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
+            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
             output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
             return output
 
     def _meshgrid(height, width):
         with tf.variable_scope('_meshgrid'):
             # This should be equivalent to:
             #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
             #                         np.linspace(-1, 1, height))
             #  ones = np.ones(np.prod(x_t.shape))
             #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
             x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
-                            tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width),1),[1,0]))
-            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height),1),
-                            tf.ones(shape=tf.pack([1, width])))
+                            tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
+            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
+                            tf.ones(shape=tf.pack([1, width])))
 
-            x_t_flat = tf.reshape(x_t,(1, -1))
-            y_t_flat = tf.reshape(y_t,(1, -1))
+            x_t_flat = tf.reshape(x_t, (1, -1))
+            y_t_flat = tf.reshape(y_t, (1, -1))
 
             ones = tf.ones_like(x_t_flat)
             grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
@@ -141,7 +144,7 @@ def _transform(theta, input_dim, out_size):
         with tf.variable_scope('_transform'):
             num_batch = tf.shape(input_dim)[0]
             height = tf.shape(input_dim)[1]
-            width = tf.shape(input_dim)[2]
+            width = tf.shape(input_dim)[2]
             num_channels = tf.shape(input_dim)[3]
             theta = tf.reshape(theta, (-1, 2, 3))
             theta = tf.cast(theta, 'float32')
@@ -150,37 +153,39 @@ def _transform(theta, input_dim, out_size):
             height_f = tf.cast(height, 'float32')
             width_f = tf.cast(width, 'float32')
             out_height = out_size[0]
-            out_width = out_size[1]
+            out_width = out_size[1]
             grid = _meshgrid(out_height, out_width)
-            grid = tf.expand_dims(grid,0)
-            grid = tf.reshape(grid,[-1])
-            grid = tf.tile(grid,tf.pack([num_batch]))
-            grid = tf.reshape(grid,tf.pack([num_batch, 3, -1]))
+            grid = tf.expand_dims(grid, 0)
+            grid = tf.reshape(grid, [-1])
+            grid = tf.tile(grid, tf.pack([num_batch]))
+            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))
 
             # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
             T_g = tf.batch_matmul(theta, grid)
-            x_s = tf.slice(T_g, [0,0,0], [-1,1,-1])
-            y_s = tf.slice(T_g, [0,1,0], [-1,1,-1])
-            x_s_flat = tf.reshape(x_s,[-1])
-            y_s_flat = tf.reshape(y_s,[-1])
+            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
+            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
+            x_s_flat = tf.reshape(x_s, [-1])
+            y_s_flat = tf.reshape(y_s, [-1])
 
             input_transformed = _interpolate(
-                input_dim, x_s_flat, y_s_flat,
-                out_size)
+                input_dim, x_s_flat, y_s_flat,
+                out_size)
 
-            output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
+            output = tf.reshape(
+                input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
             return output
 
     with tf.variable_scope(name):
         output = _transform(theta, U, out_size)
         return output
 
+
 def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
     """Batch Spatial Transformer Layer
     Parameters
     ----------
     U : float
         tensor of inputs [num_batch,height,width,num_channels]
     thetas : float
@@ -196,4 +201,3 @@ def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
     indices = [[i]*num_transforms for i in xrange(num_batch)]
     input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
     return transformer(input_repeated, thetas, out_size)
-
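Finally, on the bilinear sampling inside _interpolate: each output value mixes the four pixels surrounding the (generally fractional) source coordinate, each weighted by the area of the rectangle opposite it, which is what the wa/wb/wc/wd lines above compute. A standalone NumPy sketch of those weights (illustrative, not part of this commit):

    import numpy as np

    def bilinear_weights(x, y):
        # integer corner pixels around the fractional sample point (x, y)
        x0, y0 = np.floor(x), np.floor(y)
        x1, y1 = x0 + 1, y0 + 1
        wa = (x1 - x) * (y1 - y)   # weight of pixel (x0, y0)
        wb = (x1 - x) * (y - y0)   # weight of pixel (x0, y1)
        wc = (x - x0) * (y1 - y)   # weight of pixel (x1, y0)
        wd = (x - x0) * (y - y0)   # weight of pixel (x1, y1)
        return wa, wb, wc, wd      # the four weights always sum to 1.0

    print(bilinear_weights(10.25, 20.75))   # (0.1875, 0.5625, 0.0625, 0.1875)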
