
Commit

Merge pull request #10 from ragavvenkatesan/Rotate-Layer-PR
Rotate layer pr
Ragav Venkatesan committed Feb 3, 2017
2 parents e874c70 + 13cbf13 commit 1517ded
Showing 4 changed files with 319 additions and 3 deletions.
4 changes: 2 additions & 2 deletions yann/layers/merge.py
@@ -66,9 +66,9 @@ def __init__ ( self,
elif type == 'concatenate':
self.output = T.concatenate([x[0],x[1]], axis = 1)
if len(input_shape[0]) == 2:
self.output_shape = (input_shape [0], input_shape[0][1] + input_shape[1][1])
self.output_shape = (input_shape [0][0], input_shape[0][1] + input_shape[1][1])
elif len(input_shape[1]) == 4:
self.output_shape = (input_shape [0], input_shape[0][1] + input_shape[1][1],
self.output_shape = (input_shape [0][0], input_shape[0][1] + input_shape[1][1],
input_shape[2], input_shape[3])
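            # Note: ``input_shape`` here is a list of the two incoming shapes, so the
            # mini-batch size is ``input_shape[0][0]``; the earlier ``input_shape[0]``
            # placed the whole shape tuple into ``output_shape``.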

def loss(self, type = None):
257 changes: 257 additions & 0 deletions yann/layers/transform.py
@@ -0,0 +1,257 @@
"""
This module rotates batches of images by angles normalized to [0,1].
In keeping with Lasagne's license, credit and conditions: parts of this file are
reproduced directly from the Lasagne code base.
LICENSE
=======
Copyright (c) 2014-2015 Lasagne contributors
Lasagne uses a shared copyright model: each contributor holds copyright over
their contributions to Lasagne. The project versioning records all such
contribution and copyright details.
By contributing to the Lasagne repository through pull-request, comment,
or otherwise, the contributor releases their content to the license and
copyright terms herein.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from abstract import layer, _activate, _dropout
import numpy
import numpy.random as nprnd
import theano
import theano.tensor as T

class rotate_layer (layer):
"""
    This is a rotate layer. It takes a layer and an angle (a rotation normalized to [0,1])
    as input and rotates the batch of images by that rotation parameter.
    Args:
        input: An input ``theano.tensor`` variable. Even ``theano.shared`` will work as long
               as it is in the shape ``(mini_batch_size, channels, height, width)``.
        verbose: similar to the rest of the toolbox.
        input_shape: ``(mini_batch_size, channels, height, width)``
        angle: rotation normalized to [0,1], corresponding to [0, 90] degrees.
        borrow: ``theano`` borrow, typically ``True``.
"""

def __init__ (self,
input,
input_shape,
id,
angle = None,
borrow = True,
verbose = 2 ):
super(rotate_layer,self).__init__(id = id, type = 'rotate', verbose = verbose)
if verbose >= 3:
print "... Creating rotate layer"

if len(input_shape) == 4:
if verbose >= 3:
print "... Creating the rotate layer"

if angle is None:
angle = nprnd.uniform(size = (input_shape[0],1), low = 0, high = 1)

            theta = T.stack([T.cos(angle[:,0] * 90. * numpy.pi / 180.).reshape([angle.shape[0],1]),
                            -T.sin(angle[:,0] * 90. * numpy.pi / 180.).reshape([angle.shape[0],1]),
                            T.zeros((input_shape[0],1),dtype='float32'),
                            T.sin(angle[:,0] * 90. * numpy.pi / 180.).reshape([angle.shape[0],1]),
                            T.cos(angle[:,0] * 90. * numpy.pi / 180.).reshape([angle.shape[0],1]),
                            T.zeros((input_shape[0],1),dtype='float32')], axis=1)
theta = theta.reshape((-1, 6))
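            # theta holds one 2x3 affine matrix per image in the mini-batch, flattened
            # row-wise as [[cos(a), -sin(a), 0], [sin(a), cos(a), 0]]: a pure rotation
            # (zero translation) about the centre of the normalized [-1, 1] sampling
            # grid, where a is the normalized angle mapped to [0, 90] degrees in radians.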

self.output = self._transform_affine(theta, input)
self.output_shape = input_shape
self.angle = angle

def _transform_affine(self, theta, input):
num_batch, num_channels, height, width = input.shape
theta = T.reshape(theta, (-1, 2, 3))

# grid of (x_t, y_t, 1)
out_height = T.cast(height, 'int64')
out_width = T.cast(width, 'int64')
grid = self._meshgrid(out_height, out_width)

# Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
T_g = T.dot(theta, grid)
x_s = T_g[:, 0]
y_s = T_g[:, 1]
x_s_flat = x_s.flatten()
y_s_flat = y_s.flatten()

# dimshuffle input to (bs, height, width, channels)
input_dim = input.dimshuffle(0, 2, 3, 1)
input_transformed = self._interpolate(
input_dim, x_s_flat, y_s_flat,
out_height, out_width)

output = T.reshape(
input_transformed, (num_batch, out_height, out_width, num_channels))
output = output.dimshuffle(0, 3, 1, 2) # dimshuffle to conv format
return output

def _interpolate(self, im, x, y, out_height, out_width):
# *_f are floats
num_batch, height, width, channels = im.shape
height_f = T.cast(height, theano.config.floatX)
width_f = T.cast(width, theano.config.floatX)

# clip coordinates to [-1, 1]
x = T.clip(x, -1, 1)
y = T.clip(y, -1, 1)

# scale coordinates from [-1, 1] to [0, width/height - 1]
x = (x + 1) / 2 * (width_f - 1)
y = (y + 1) / 2 * (height_f - 1)

# obtain indices of the 2x2 pixel neighborhood surrounding the coordinates;
# we need those in floatX for interpolation and in int64 for indexing. for
# indexing, we need to take care they do not extend past the image.
x0_f = T.floor(x)
y0_f = T.floor(y)
x1_f = x0_f + 1
y1_f = y0_f + 1
x0 = T.cast(x0_f, 'int64')
y0 = T.cast(y0_f, 'int64')
x1 = T.cast(T.minimum(x1_f, width_f - 1), 'int64')
y1 = T.cast(T.minimum(y1_f, height_f - 1), 'int64')

# The input is [num_batch, height, width, channels]. We do the lookup in
# the flattened input, i.e [num_batch*height*width, channels]. We need
# to offset all indices to match the flat version
dim2 = width
dim1 = width*height
base = T.repeat(
T.arange(num_batch, dtype='int64')*dim1, out_height*out_width)
base_y0 = base + y0*dim2
base_y1 = base + y1*dim2
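        # idx_a .. idx_d are flat indices of the four neighbouring pixels of each sample
        # point: (x0, y0), (x0, y1), (x1, y0) and (x1, y1) respectively.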
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1

# use indices to lookup pixels for all samples
im_flat = im.reshape((-1, channels))
Ia = im_flat[idx_a]
Ib = im_flat[idx_b]
Ic = im_flat[idx_c]
Id = im_flat[idx_d]

# calculate interpolated values
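        # wa .. wd are the bilinear weights: each corner pixel is weighted by the area of
        # the rectangle spanned between the sample point and the opposite corner.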
wa = ((x1_f-x) * (y1_f-y)).dimshuffle(0, 'x')
wb = ((x1_f-x) * (y-y0_f)).dimshuffle(0, 'x')
wc = ((x-x0_f) * (y1_f-y)).dimshuffle(0, 'x')
wd = ((x-x0_f) * (y-y0_f)).dimshuffle(0, 'x')
output = T.sum([wa*Ia, wb*Ib, wc*Ic, wd*Id], axis=0)
return output

def _linspace(self, start, stop, num):
        # Theano linspace. Behaves similarly to np.linspace
start = T.cast(start, theano.config.floatX)
stop = T.cast(stop, theano.config.floatX)
num = T.cast(num, theano.config.floatX)
step = (stop-start)/(num-1)
return T.arange(num, dtype=theano.config.floatX)*step+start

def _meshgrid(self, height, width):
# This function is the grid generator.
# It is equivalent to the following numpy code:
# x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
# np.linspace(-1, 1, height))
# ones = np.ones(np.prod(x_t.shape))
# grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
# It is implemented in Theano instead to support symbolic grid sizes.
# Note: If the image size is known at layer construction time, we could
# compute the meshgrid offline in numpy instead of doing it dynamically
# in Theano. However, it hardly affected performance when we tried.
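        # The returned grid has shape (3, height*width): row 0 holds the x coordinates,
        # row 1 the y coordinates and row 2 the homogeneous ones.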
x_t = T.dot(T.ones((height, 1)),
self._linspace(-1.0, 1.0, width).dimshuffle('x', 0))
y_t = T.dot(self._linspace(-1.0, 1.0, height).dimshuffle(0, 'x'),
T.ones((1, width)))

x_t_flat = x_t.reshape((1, -1))
y_t_flat = y_t.reshape((1, -1))
ones = T.ones_like(x_t_flat)
grid = T.concatenate([x_t_flat, y_t_flat, ones], axis=0)
return grid


class dropout_rotate_layer (rotate_layer):
"""
    This is the dropout version of the rotate layer. Called by the ``add_layer`` method in
    the network class.
    Args:
        input: An input ``theano.tensor`` variable. Even ``theano.shared`` will work as long
               as it is in the shape ``(mini_batch_size, channels, height, width)``.
        verbose: similar to the rest of the toolbox.
        input_shape: ``(mini_batch_size, channels, height, width)``
        angle: rotation normalized to [0,1], corresponding to [0, 90] degrees.
        rng: typically ``numpy.random``.
        borrow: ``theano`` borrow, typically ``True``.
        dropout_rate: ``0.5`` is the default.
    Notes:
        Use ``dropout_rotate_layer.output`` and ``dropout_rotate_layer.output_shape`` from
        this class.
"""
def __init__ (self,
input,
input_shape,
id,
rng = None,
dropout_rate = 0.5,
angle = None,
borrow = True,
verbose = 2):

if verbose >= 3:
print "... set up the dropout rotate layer"
if rng is None:
rng = numpy.random
        super(dropout_rotate_layer, self).__init__ (
                                        input = input,
                                        input_shape = input_shape,
                                        id = id,
                                        angle = angle,
                                        borrow = borrow,
                                        verbose = verbose
                                        )
if not dropout_rate == 0:
self.output = _dropout(rng = rng,
params = self.output,
dropout_rate = dropout_rate)
self.dropout_rate = dropout_rate
if verbose >=3:
print "... Dropped out"

if __name__ == '__main__':
pass
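As a usage sketch (not part of this commit): a minimal way to drive ``rotate_layer`` directly, assuming Theano and a yann checkout containing this commit under Python 2, which the toolbox targets. The shapes, the layer id and the constant angle below are illustrative.

    import numpy
    import theano
    import theano.tensor as T
    from yann.layers.transform import rotate_layer

    x = T.tensor4('x')                            # (mini_batch_size, channels, height, width)
    angle = numpy.full((16, 1), 0.25, dtype = theano.config.floatX)  # 0.25 -> 22.5 degrees
    layer = rotate_layer(input = x,
                         input_shape = (16, 1, 28, 28),
                         id = 'rot',
                         angle = angle,
                         verbose = 2)
    rotate = theano.function([x], layer.output)
    images = numpy.random.rand(16, 1, 28, 28).astype(theano.config.floatX)
    rotated = rotate(images)                      # same shape as ``images``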
4 changes: 3 additions & 1 deletion yann/modules/visualizer.py
@@ -316,8 +316,10 @@ def visualize_activities(self, layer_activities, epoch, index = 0, verbose = 2):
loc = self.root + '/activities/epoch_' + str(epoch)
if not os.path.exists(loc):
os.makedirs(loc)
for id, activity in layer_activities.iteritems():
imgs = activity(index)
if verbose >= 3:
print "... Visualizing Activities :: id = %s" % id
if len(imgs.shape) == 2:
if not os.path.exists(loc + '/layer_' + id):
os.makedirs(loc + '/layer_' + id)
57 changes: 57 additions & 0 deletions yann/network.py
@@ -140,6 +140,7 @@ def add_layer(self, type, verbose = 2, **kwargs):
'merge' or 'join' - a layer that merges two layers.
'flatten' - a layer that produces a flattened output of a block data.
'random' - a layer that produces random numbers.
            'rotate' - a layer that rotates the input images.
From now on everything is optional args..
id: <string> how to identify the layer by.
Default is just layer number that starts with ``0``.
@@ -184,6 +185,9 @@ def add_layer(self, type, verbose = 2, **kwargs):
coefficients.
error: ``merge`` layers take an option called ``'error'`` which can be None or others
which are methods in ``yann.core.errors``.
            angle: Takes a value in [0,1], corresponding to a rotation of up to 90 degrees.
                   Default is ``None``. If ``None`` is supplied, a random value is drawn for
                   each image from a uniform distribution between 0 and 1.
layer_type: If ``value`` supply, else it is default ``'discriminator'``
"""
@@ -206,6 +210,7 @@ def add_layer(self, type, verbose = 2, **kwargs):
type == 'flatten' or \
type == 'unflatten' or \
type == 'random' or \
type == 'rotate' or \
type == 'loss' or \
type == 'energy' or \
type == 'join':
@@ -265,6 +270,9 @@ def add_layer(self, type, verbose = 2, **kwargs):
elif type == 'random':
self._add_random_layer (id = id, options = kwargs, verbose = verbose)

elif type == 'rotate':
self._add_rotate_layer (id = id, options = kwargs, verbose = verbose)

else:
raise Exception('No layer called ' + type + ' exists in yann')

@@ -1295,6 +1303,55 @@ def _add_random_layer(self, id, options, verbose = 2):
options = options,
verbose =verbose)

def _add_rotate_layer(self, id, options, verbose = 2):
"""
This is an internal function. Use ``add_layer`` instead of this from outside the class.
Args:
options: Basically kwargs supplied to the add_layer function.
            angle: Value in [0,1], corresponding to a rotation of up to 90 degrees.
                   If ``None`` is specified, the angle is drawn per image from a uniform
                   distribution.
            verbose: similar to everywhere else in the toolbox.
"""
if verbose >=3:
print "... Adding a rotate layer"

if not 'origin' in options.keys():
if self.last_layer_created is None:
raise Exception("You can't create a fully connected layer without an" + \
" origin layer.")
if verbose >=3:
print "... origin layer is not supplied, assuming the last layer created is."
origin = self.last_layer_created
else:
origin = options ["origin"]

from yann.layers.transform import rotate_layer as rl
from yann.layers.transform import dropout_rotate_layer as drl

input = self.layers[origin].output
dropout_input = self.dropout_layers[origin].output
input_shape = self.layers[origin].output_shape

if 'angle' in options.keys():
angle = options['angle']
else:
angle = None

self.layers[id] = rl(
input = input,
input_shape = input_shape,
id = id,
angle = angle,
verbose =verbose)

self.dropout_layers[id] = drl (
input = dropout_input,
input_shape = input_shape,
id = id,
angle = angle,
verbose = verbose)

def _initialize_test_classifier(self, errors, verbose):
"""
Internal function that creates a test method for a classifier network
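As a usage sketch (not part of this commit): assuming a yann ``network`` object ``net`` that already has an input layer with id ``'data'``, the new layer type would be requested through ``add_layer`` as below; the ids are illustrative.

    # Omitting ``angle`` draws one random rotation per image; to pin the rotation,
    # pass a ``(mini_batch_size, 1)`` array of values in [0, 1] instead.
    net.add_layer(type = 'rotate',
                  origin = 'data',
                  id = 'rotated-data',
                  verbose = 2)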
