
Spatial transformer implementation may have a bug somewhere #193

Closed
seuqaj114 opened this issue Jun 13, 2016 · 10 comments
Assignees
Labels
stat:awaiting model gardener Waiting on input from TensorFlow model gardener

Comments


seuqaj114 commented Jun 13, 2016

So I was trying to plug this ST module into the write attention part of a DRAW model, and I just couldn't get it to work. After a grueling day of trying every parameter choice, I compared the output of an identity transform from the ST against a scipy interpolation zoom, and that's when I found something interesting.

The code I used is below, adapted from example.py.

from scipy import ndimage
import tensorflow as tf
from spatial_transformer import transformer
import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    return 1.0/(1.0+np.exp(-x))

# %% Create a 5x5 test image (arbitrary float values)
#im = np.load("../../datautils/templates/circle/circle.npy")
#im = im / 255.
im = np.array([-1.2053933, -1.1743802, -0.75044346, -0.74455976, -1.0506268,
 -0.91364104, -0.21054152, 0.1543106, 0.032554384, -0.52717745,
-0.66026419, -0.021319218, -0.060581781, -0.099243492, -0.26127103,
 -0.52252597, 0.1389422, -0.13638327, 0.033274196, -0.20344208,
  -0.53625256, 0.02523746, -0.076311894, 0.10775769, 0.20])
im = im.reshape(1, 5, 5, 1)
im = im.astype('float32')

# %% Let the output size of the transformer be 6 times the image size.
out_size = (30, 30)

# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3

x = tf.placeholder(tf.float32, [None, 5, 5, 1])

# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):

    # %% Create a fully-connected layer with 6 output nodes
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([5*5*1, n_fc]), name='W_fc1')

    # %% Zoom into the image
    initial = np.array([[1.0, 0, 0.0], [0, 1.0, 0.0]])
    initial = initial.astype('float32')
    initial = initial.flatten()

    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    h_fc1 = tf.matmul(tf.zeros([num_batch, 5*5*1]), W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, out_size)
    h_trans = tf.sigmoid(h_trans)

# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})

imgplot = plt.imshow(y[0].reshape(30,30))
imgplot.set_cmap('gray')
plt.savefig("myfig.png")

imgplot = plt.imshow(sigmoid(ndimage.interpolation.zoom(im.reshape(5,5),6.0)))
imgplot.set_cmap('gray')
plt.savefig("myfigorig.png")

The resulting images are the following.

Result of scipy zoom (correct): myfigorig.png

Result of spatial transformer: myfig.png

As you can see, the transformer output neglects the bottom and right edges, producing roughly a 23x23 version of the image when I asked for a 30x30 version. This can easily go unnoticed if the background of the image is itself black, which is why it took me so long to notice it.

Let me know what you think,
Miguel

@seuqaj114 (Author)

I found a problem in the code that may be causing this.

In the _interpolate function, the lines

# scale indices from [-1, 1] to [0, width/height]
x = (x + 1.0)*(width_f) / 2.0

actually cause the scaling to be done incorrectly: the indices shouldn't be scaled to [0, width] but to [0, width-1], because indices into the original image run from 0 to width-1. The same applies to the height.

Substituting the line above with

x = (x + 1.0)*(width_f-1.01) / 2.0

corrects the problem mentioned in the original post (using width_f-1.0 doesn't cut it; the subtracted constant has to be an epsilon larger than 1.0).

I have to investigate whether this fix breaks the case where out_size is smaller than the image size, but from these results I believe it fixes the case where out_size is bigger than the image size - it works for any affine transformation, not just identity as in my example. If it doesn't, should I make a pull request or should someone else check on this?
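To make the off-by-one concrete, here is a minimal numpy sketch (my own illustration, not code from the repo) comparing the two scalings for one 5-pixel axis upsampled to 30 samples, as in the example above:

```python
import numpy as np

# Illustrative sketch: one axis of the 5x5 image upsampled to 30 samples.
width = 5
out = 30
x = np.linspace(-1.0, 1.0, out)               # normalized output coordinates

src_current = (x + 1.0) * width / 2.0         # current code: maps [-1, 1] -> [0, 5]
src_proposed = (x + 1.0) * (width - 1) / 2.0  # proposed fix: maps [-1, 1] -> [0, 4]

# Bilinear interpolation reads pixels floor(src) and floor(src) + 1, so a
# sample only has both neighbours in range while src <= width - 1.
valid_current = int(np.sum(src_current <= width - 1))
valid_proposed = int(np.sum(src_proposed <= width - 1))
print(valid_current, valid_proposed)          # -> 24 30
```

With the current scaling only 24 of the 30 output samples fall inside the source image, which matches the ~23x23 filled region reported above; with the width-1 scaling all 30 samples are in range.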


prb12 commented Jun 21, 2016

@elezar You recently changed the code mentioned in the above message. Could you please take a look at this bug?

@prb12 prb12 added the triaged label Jun 21, 2016

elezar commented Jun 22, 2016

Thanks for pinging me on this @prb12. I will have a look, but if I recall correctly, the only changes I made to the code were to update the documentation for the function parameters and apply PEP8 formatting (or at least that was my intention; it could be that I inadvertently committed something I shouldn't have).


elezar commented Jun 22, 2016

I have confirmed that the only changes were to whitespace (and the formulation that is incorrect is also used in https://github.com/skaae/transformer_network/blob/master/transformerlayer.py#L103, from which this implementation is apparently taken).


elezar commented Jun 22, 2016

Regarding #193 (comment): I don't agree that 1.01 should be used here. That seems a little arbitrary, and an input value of 1 should return the value of the last pixel exactly!

It seems as if it has been corrected in https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/special.py#L473

Looking at the file history the change was introduced in this commit:
Lasagne/Lasagne@c8572b2


prb12 commented Jun 22, 2016

@elezar Who is the owner of this code? If it was copied from another repo, should the fix be imported?


elezar commented Jun 23, 2016

@prb12 It seems as if it was added by @daviddao in 41c52d6. It was later modified by @psycharo in bf60abf.

If I understand correctly, the code was adapted from https://github.com/skaae/transformer_network (which was subsequently moved into Lasagne as a layer). The original example also contains the bug.

The way I see it, there are two options:

  1. The example here should simply be updated (and kept updated)
  2. The "layer" should be implemented in TensorFlow itself, as it seems to have general applicability.


aselle commented Jul 27, 2016

@seuqaj114
Interpolate seems correct to me. If you have real-valued spatial coordinates in [-1, 1] x [-1, 1] and a 4x4 image, that image corresponds to 16 boxes. Let's take two examples:
[-1,-1], [-.5,-.5] --> pixel 0,0
[0.5,0.5], [1,1] --> pixel 3,3

The formula that achieves that mapping is

pixel = clamp( (xy + [1,1])/2 * [width,height], [0,0], [width-1,height-1])

e.g. ([1,1] + [1,1])/2 * [width,height] = [4,4], but then clamping makes it [3,3].
Scaling by width-1 uses the wrong mapping. You need to think about the boxes that the pixels represent to do these types of transforms correctly.

So the point is that the clamp is the part that is missing.
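The box-plus-clamp mapping described above can be sketched as follows (my own illustration; to_pixel is a hypothetical helper, not an API from the repo):

```python
import numpy as np

def to_pixel(coord, size):
    # Treat each pixel as a box of normalized width 2/size: scale [-1, 1]
    # to [0, size], floor to get the box index, then clamp so that
    # coord == 1.0 still lands on the last pixel rather than running off
    # the end of the image.
    idx = np.floor((coord + 1.0) / 2.0 * size)
    return int(np.clip(idx, 0, size - 1))

# A width-4 axis: coordinates near -1 land in pixel 0, coordinates near 1
# land in pixel 3, and exactly 1.0 is clamped back to pixel 3.
print([to_pixel(c, 4) for c in (-1.0, -0.75, 0.75, 1.0)])  # -> [0, 0, 3, 3]
```

Without the clamp, coord == 1.0 would floor to index 4, which is out of range for a 4-pixel axis; the clamp, not a width-1 rescaling, is what keeps the boundary sample valid.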

@aselle aselle removed the triaged label Jul 28, 2016
@michaelisard michaelisard added the stat:awaiting response Waiting on input from the contributor label Jul 28, 2016

girving commented Aug 8, 2016

@aselle Is there a bug here that should be fixed?

@aselle aselle added stat:awaiting model gardener Waiting on input from TensorFlow model gardener and removed stat:awaiting response Waiting on input from the contributor labels Aug 18, 2016
@aselle aselle self-assigned this Aug 18, 2016
mfolnovic added a commit to mfolnovic/models that referenced this issue Nov 17, 2016
@itsmeolivia itsmeolivia added the stat:awaiting model gardener Waiting on input from TensorFlow model gardener label Feb 7, 2018
@itsmeolivia (Contributor)

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!
