In [16]:
import theano
import theano.tensor as tt
import numpy as np
import itertools
from collections import OrderedDict

x = tt.tensor(name='x', dtype='floatX', broadcastable=(False, True, True))
y = tt.tensor(name='y', dtype='floatX', broadcastable=(True, False, True))
z = tt.tensor(name='z', dtype='floatX', broadcastable=(False, False, False))
full_shape = [x.shape[0], y.shape[1], z.shape[2]]
orig_vars = [x, y, z]

# Need to keep cache's keys in the same order as orig_vars
cache = OrderedDict()
for s in orig_vars:
    cache[s.name] = s
rep_cache = {i: tt.alloc(*itertools.chain([cache[i]], full_shape)) for i in cache.keys()}

# flat concatenation to use in Hessian
flat_params = tt.join(*itertools.chain([0], [rep_cache[i].flatten() for i in cache.keys()]))

# reconstructed variables to use to compute obj
flat_vars = tt.split(flat_params, len(orig_vars) * [np.multiply.reduce(full_shape)], n_splits=len(orig_vars))
new_vars = [i.reshape(full_shape) for i in flat_vars]

obj = (1./6)*new_vars[0]**3 + (1./6)*new_vars[1] ** 3 + new_vars[0]*new_vars[1] + (1./6)*new_vars[2]**3

grad = theano.grad(obj.sum(), flat_params)
hess, u = theano.scan(lambda i, gp, p: theano.grad(gp[i], p),
                      sequences=tt.mod(tt.arange(flat_params.shape[0]), len(orig_vars) * np.multiply.reduce(full_shape)),
                      non_sequences=[grad, flat_params])
# Not pretty, but it works
hess = hess.reshape((-1, np.multiply.reduce(full_shape))).sum(axis=-1)
hess = hess.reshape((len(orig_vars), np.multiply.reduce(full_shape), len(orig_vars))).transpose(1,0,2)
hess = hess.reshape(full_shape + [len(orig_vars), len(orig_vars)])

grad_reshaped = grad.reshape(full_shape + [len(orig_vars)])

ofunc = theano.function(list(cache.values()), obj, mode='FAST_RUN', on_unused_input='warn')
gfunc = theano.function(list(cache.values()), grad_reshaped, mode='FAST_RUN', on_unused_input='warn')
hfunc = theano.function(list(cache.values()), hess, mode='FAST_RUN', on_unused_input='warn')
x1 = np.arange(4, dtype=np.float).reshape(4,1,1)
y1 = np.arange(3, dtype=np.float).reshape(1,3,1)
z1 = np.arange(24, dtype=np.float).reshape(4,3,2)
ores = ofunc(x1, y1, z1)
gres = gfunc(x1, y1, z1)
hres = hfunc(x1, y1, z1)
print('obj: ', ores)
print('obj.shape: ', ores.shape)
print('grad: ', gres)
print('grad elements shape: ', gres.shape)
print('hess shape: ', hres.shape)
print('hres: ', hres)

('obj: ', array([[[  0.00000000e+00,   1.66666667e-01],
        [  1.50000000e+00,   4.66666667e+00],
        [  1.20000000e+01,   2.21666667e+01]],

       [[  3.61666667e+01,   5.73333333e+01],
        [  8.66666667e+01,   1.22833333e+02],
        [  1.70166667e+02,   2.25333333e+02]],

       [[  2.89333333e+02,   3.67500000e+02],
        [  4.60833333e+02,   5.66000000e+02],
        [  6.89333333e+02,   8.25500000e+02]],

       [[  9.76500000e+02,   1.14766667e+03],
        [  1.34100000e+03,   1.55116667e+03],
        [  1.78650000e+03,   2.03966667e+03]]]))
('obj.shape: ', (4, 3, 2))
('grad: ', array([[[[   0. ,    0. ,    1. ],
         [   1. ,    2. ,    2. ]],

        [[   0.5,    0.5,    1.5],
         [   1.5,    2.5,    2.5]],

        [[   2. ,    2. ,    3. ],
         [   3. ,    4. ,    4. ]]],


       [[[   4.5,    4.5,    5.5],
         [   5.5,    6.5,    6.5]],

        [[   0. ,    0. ,    0.5],
         [   0.5,    2. ,    2. ]],

        [[   1. ,    1. ,    

In [None]:
[0,1] * 2