In [1]:
import theano, theano.tensor as T
from theano.tensor import constant as tconstant
import numpy as np
from theano import shared
from theano.ifelse import ifelse
from theano.compile.debugmode import DebugMode
from collections import OrderedDict
srng = theano.tensor.shared_randomstreams.RandomStreams(1234)
np_rng = np.random.RandomState(1234)

In [2]:
class Layer(object):
    """
    Base object for neural network layers.

    A layer has an input set of neurons, and
    a hidden activation. The activation, f, is a
    function applied to the affine transformation
    of x by the connection matrix W, and the bias
    vector b.

    > y = f ( W * x + b )

    """

    def __init__(self, input_size, hidden_size, activation, clip_gradients=False):
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.activation  = activation
        self.clip_gradients = clip_gradients
        self.is_recursive = False
        self.create_variables()

    def create_variables(self):
        """
        Create the connection matrix and the bias vector
        """
        self.linear_matrix        = create_shared(self.hidden_size, self.input_size, name="Layer.linear_matrix")
        self.bias_matrix          = create_shared(self.hidden_size, name="Layer.bias_matrix")

    def activate(self, x):
        """
        The hidden activation of the network
        """
        if self.clip_gradients is not False:
            x = clip_gradient(x, self.clip_gradients)

        if x.ndim > 1:
            return self.activation(
                T.dot(self.linear_matrix, x.T) + self.bias_matrix[:,None] ).T
        else:
            return self.activation(
                T.dot(self.linear_matrix, x) + self.bias_matrix )

    @property
    def params(self):
        return [self.linear_matrix, self.bias_matrix]

    @params.setter
    def params(self, param_list):
        self.linear_matrix.set_value(param_list[0].get_value())
        self.bias_matrix.set_value(param_list[1].get_value())
        
class MDLSTMLayer(Layer):
    """Multi-dimensional long short-term memory cell layer.

    The cell-states are explicitly passed on through a part of
    the input/output buffers (which should be connected correctly with IdentityConnections).

    The input consists of 4 parts, in the following order:
    - input gate
    - forget gates (1 per dim)
    - cell input
    - output gate
    - previous states (1 per dim)

    The output consists of two parts:
    - cell output
    - current statte

    """
    def __init__(self, input_size, hidden_size, shape, activation=T.tanh, clip_gradients=False):
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.activation  = activation
        self.clip_gradients = clip_gradients
        self.is_recursive = True
        self.create_variables()
        self.shape = shape
        
    def create_variables(self):
        """
        Create the different LSTM gates and
        their variables, along with the initial
        hidden state for the memory cells and
        the initial hidden activation.

        """
        # input gate for cells
        self.in_gate_layer     = Layer(self.input_size + 2*self.hidden_size, self.hidden_size, T.nnet.sigmoid, self.clip_gradients)
        # forget gate for cells
        self.forget_gate_x_layer = Layer(self.input_size + 2*self.hidden_size, self.hidden_size, T.nnet.sigmoid, self.clip_gradients)
        self.forget_gate_y_layer = Layer(self.input_size + 2*self.hidden_size, self.hidden_size, T.nnet.sigmoid, self.clip_gradients)
        # input modulation for cells
        self.in_cell_layer    = Layer(self.input_size + 2*self.hidden_size, self.hidden_size, T.tanh, self.clip_gradients)
        # output modulation
        self.out_gate_layer    = Layer(self.input_size + 2*self.hidden_size, self.hidden_size, T.nnet.sigmoid, self.clip_gradients)

        # keep these layers organized
        self.internal_layers = [self.in_gate_layer, self.forget_gate_x_layer, self.forget_gate_y_layer, self.in_cell_layer, self.out_gate_layer]

        # store the memory cells in first n spots, and store the current
        # output in the next n spots:
    @property
    def params(self):
        """
        Parameters given by the 4 gates and the
        initial hidden activation of this LSTM cell
        layer.
        """
        return [param for layer in self.internal_layers for param in layer.params]
    @params.setter
    def params(self, param_list):
        start = 0
        for layer in self.internal_layers:
            end = start + len(layer.params)
            layer.params = param_list[start:end]
            start = end
    def postprocess_activation(self, x, *args):
        if x.ndim > 1:
            return x[:, self.hidden_size:]
        else:
            return x[self.hidden_size:]


        
    def create_prediction(self, x_input):
        directions = tconstant(np.array([[-1,-1],[-1,1],[1,-1],[1,1]]))
        res,update = theano.scan(lambda direction, xipt: self.create_prediction_once(xipt, direction),
                    sequences = [directions],
                    outputs_info = None,
                    non_sequences = [x_input])
        return T.reshape(res,(4*self.hidden_size, self.shape[0], self.shape[1]))
    def create_prediction_once(self, x_input, direction):
        ''' input: 
                x_input: a tensor of size H*W*C
                shape: spatial dimension of x_input and x_output: (H,W)
                direction:(-1,-1): from top left
                          (-1, 1): from top right
                          ( 1,-1): from bottom left
                          ( 1, 1): from bottom right                            
            output:
                x_output: a tensor of size H*W*K
        '''
        #   e.g. for shape of (3,3), the scanning order of LSTM would be [0,1,2,3,4,5,6,7,8]
        scan_order = self.permsForSwiping(direction)
        
        # store previous cell states and hidden activations
#         hidden_mat = T.zeros([self.hidden_size, self.shape[0], self.shape[1]], dtype=theano.config.floatX)
#         cell_mat = T.zeros([self.hidden_size, self.shape[0], self.shape[1]], dtype=theano.config.floatX)
            
        CH = T.zeros([self.shape[0], 2*self.hidden_size])       
        
        xyv = tconstant(np.array([np.meshgrid(range(self.shape[0]),range(self.shape[1]))]))
        xyv_seq = T.transpose(T.reshape(xyv,(2,self.shape[0]*self.shape[1])))
        cell_order = xyv_seq[scan_order,:]
        
        # flatten input
        X_input_reshaped = T.transpose(T.reshape(T.transpose(x_input,[0,2,1]), (self.input_size,self.shape[0]*self.shape[1])))
        X_input_reshaped = X_input_reshaped[scan_order,:]
        
        res, updates = theano.scan(lambda x_inpt, cell_loc, ch_x, ch_y, dire:
                                        self.activate(x_inpt,T.concatenate([
                                        ifelse((T.lt(cell_loc[0]+dire[0],shared(0)) | T.ge(cell_loc[0]+dire[0],self.shape[0])),
                                        T.zeros(2*self.hidden_size),ch_x),
                                        ifelse((T.lt(cell_loc[1]+dire[1],shared(0)) | T.ge(cell_loc[1]+dire[1],self.shape[1])),
                                        T.zeros(2*self.hidden_size),ch_y)])),
                                        sequences = [X_input_reshaped,cell_order],
                                       
                                        outputs_info=[dict(initial=CH, taps=[-1, -self.shape[0]])],
                                        non_sequences = direction)
        hidden_mat = T.reshape(res[:,self.hidden_size:],(self.hidden_size,self.shape[0],self.shape[1]))
        return hidden_mat
  
    def permsForSwiping(self, direction):
        """ Given the spatial dimension of the input
            Return the correct permutations of blocks for all swiping direction.
        """
        identity = T.arange(T.prod(self.shape))    
        identity_flipped = T.reshape(identity,self.shape,ndim=2)[::-1,:].flatten()

        perms = ifelse(T.all(T.eq(direction,T.constant([-1,-1]))), identity,
                         ifelse(T.all(T.eq(direction,T.constant([1,1]))),identity[::-1],
                         ifelse(T.all(T.eq(direction,T.constant([-1,1]))),identity_flipped,identity_flipped[::-1])))
        return perms
    def activate(self, x, h):
        """
        The hidden activation, h, of the network, along
        with the new values for the memory cells, c,
        Both are concatenated as follows:

        >      y = f( x, past )

        Or more visibly, with past = [prev_c, prev_h]

        > [c, h] = f( x, [prev_c, prev_h] )
        
        Currently we don't have peephole connections.
        
        """
        
        if h.ndim > 1:
            prev_c_x = h[:, :self.hidden_size]
            #previous activations of the hidden layer in x direction
            prev_h_x = h[:, self.hidden_size:self.hidden_size+self.hidden_size]
            #previous memory cell values in y direction
            prev_c_y = h[:, self.hidden_size+self.hidden_size:2*self.hidden_size+self.hidden_size]
            #previous activations of the hidden layer in x direction
            prev_h_y = h[:, 2*self.hidden_size+self.hidden_size:]
        else:
            prev_c_x = h[:self.hidden_size]
            #previous activations of the hidden layer in x direction
            prev_h_x = h[self.hidden_size:2*self.hidden_size]
            #previous memory cell values in y direction
            prev_c_y = h[2*self.hidden_size:3*self.hidden_size]
            #previous activations of the hidden layer in x direction
            prev_h_y = h[3*self.hidden_size:]
            
        # input and previous hidden states in two directions constitute the actual
        # input to the LSTM:
        if h.ndim > 1:
            obs = T.concatenate([x, prev_h_x, prev_h_y], axis=1)
        else:
            obs = T.concatenate([x, prev_h_x, prev_h_y], axis=0)
            
        # input gate
        in_gate = self.in_gate_layer.activate(obs)
        # forget (in two directions)
        forget_gate_x = self.forget_gate_x_layer.activate(obs)
        forget_gate_y = self.forget_gate_y_layer.activate(obs)
        
        # compute cell input 
        in_cell = self.in_cell_layer.activate(obs)
        # new memory cells
        next_c = forget_gate_x * prev_c_x + forget_gate_y * prev_c_y +in_cell * in_gate
        
        # output gate
        out_gate = self.out_gate_layer.activate(obs)
        # new hidden states
        next_h = out_gate * T.tanh(next_c)     
        if h.ndim > 1:
            return T.concatenate([next_c, next_h], axis=1)
        else:
            return T.concatenate([next_c, next_h], axis=0)
        
def create_shared(hidden_size, in_size=None, name=None):
    """
    Creates a shared matrix or vector
    using the given in_size and hidden_size.

    Inputs
    ------

    hidden_size int            : outer dimension of the
                              vector or matrix
    in_size  int (optional) : for a matrix, the inner
                              dimension.

    Outputs
    -------

    theano shared : the shared matrix, with random numbers in it

    """

    if in_size is None:
        return theano.shared(random_initialization((hidden_size, )), name=name)
    else:
        return theano.shared(random_initialization((hidden_size, in_size)), name=name)
    
def random_initialization(size):
    return (np_rng.standard_normal(size) * 1. / size[0]).astype(theano.config.floatX)


In [3]:
x = np.random.rand(3,4,6)
net = MDLSTMLayer(3, 2,(4,6))  
theano.config.compute_test_value = 'warn'
img = T.tensor3()
img.tag.test_value =  x
direction = T.vector()
direction.tag.test_value = [-1,-1]
f = theano.function(inputs=[img], outputs=[net.create_prediction(img)])
g = theano.function(inputs=[img,direction], outputs=[net.create_prediction_once(img,direction)])

In [4]:
[res1] = g(x,[-1,-1])
[res2] = f(x)
print res1,'\n',res2[0:2]

[[[ 0.09133504 -0.01384627  0.13693459 -0.01521263  0.18107321 -0.00380177]
  [ 0.26117384 -0.0066828   0.13086979 -0.04064852  0.2067028  -0.04933781]
  [ 0.23990984 -0.04504092  0.27120143 -0.06429556  0.24720783 -0.05795617]
  [ 0.50238457 -0.10902152  0.43283703 -0.12133762  0.49449936 -0.13653737]]

 [[ 0.39359728 -0.08791297  0.53469274 -0.17976823  0.43371281 -0.24976274]
  [ 0.70525102 -0.31407452  0.41415621 -0.10011591  0.33159066 -0.23904572]
  [ 0.5999414  -0.3490413   0.67881947 -0.40292604  0.38532258 -0.10804207]
  [ 0.27466529 -0.23058987  0.44999066 -0.4540891   0.64843316 -0.48363151]]] 
[[[ 0.09133504 -0.01384627  0.13693459 -0.01521263  0.18107321 -0.00380177]
  [ 0.26117384 -0.0066828   0.13086979 -0.04064852  0.2067028  -0.04933781]
  [ 0.23990984 -0.04504092  0.27120143 -0.06429556  0.24720783 -0.05795617]
  [ 0.50238457 -0.10902152  0.43283703 -0.12133762  0.49449936 -0.13653737]]

 [[ 0.39359728 -0.08791297  0.53469274 -0.17976823  0.43371281 -0.24976274]
  [ 0