In [48]:
import autograd.numpy as np
import autograd
from autograd.core import primitive
import copy

from collections import OrderedDict

from VariationalBayes import Parameters
from VariationalBayes.Parameters import \
    ScalarParam, VectorParam, ArrayParam, \
    PosDefMatrixParam, PosDefMatrixParamVector
from VariationalBayes.ParameterDictionary import ModelParamsDict
import scipy as sp


In [7]:
# Two k x k positive-definite matrix parameters (the second is twice the
# first), collected into one ModelParamsDict and flattened into a single
# unconstrained ("free") vector.
k = 2

mat = np.eye(k) + np.full((k, k), 0.2)
vp_mat1 = PosDefMatrixParam('mat1', k, val=mat)
vp_mat2 = PosDefMatrixParam('mat2', k, val=2. * mat)

mp = ModelParamsDict()
mp.push_param(vp_mat1)
mp.push_param(vp_mat2)
print(mp)

free_vec = mp.get_free()


ModelParamsDict:
	mat1:
[[ 1.2  0.2]
 [ 0.2  1.2]]
	mat2:
[[ 2.4  0.4]
 [ 0.4  2.4]]


In [44]:
def get_param(mp, free_vec, par_name):
    """Set ``mp[par_name]`` from its slice of ``free_vec`` and return its value.

    Args:
        mp: Parameter container indexable by name and exposing
            ``free_indices_dict`` (name -> indices into ``free_vec``).
        free_vec: Full free-parameter vector.
        par_name: Name of the parameter to read.

    Returns:
        ``mp[par_name].get()`` after the parameter's free values were set.
        Note that ``mp[par_name]`` is mutated in place.
    """
    par = mp[par_name]
    par.set_free(free_vec[mp.free_indices_dict[par_name]])
    return par.get()

# Read each parameter back out of the free vector, then show its dense
# Jacobian w.r.t. the *whole* free vector (note the all-zero columns that
# belong to the other parameter).
for par_name in ('mat1', 'mat2'):
    print(get_param(mp, free_vec, par_name))

get_param_jac = autograd.jacobian(get_param, argnum=1)

for par_name in ('mat1', 'mat2'):
    print('--------')
    print(get_param_jac(mp, free_vec, par_name))



[[ 1.2  0.2]
 [ 0.2  1.2]]
[[ 2.4  0.4]
 [ 0.4  2.4]]
--------
[[[ 2.4         0.          0.          0.          0.          0.        ]
  [ 0.2         1.09544512  0.          0.          0.          0.        ]]

 [[ 0.2         1.09544512  0.          0.          0.          0.        ]
  [ 0.          0.36514837  2.33333333  0.          0.          0.        ]]]
--------
[[[ 0.          0.          0.          4.8         0.          0.        ]
  [ 0.          0.          0.          0.4         1.54919334  0.        ]]

 [[ 0.          0.          0.          0.4         1.54919334  0.        ]
  [ 0.          0.          0.          0.          0.51639778  4.66666667]]]


In [61]:
@primitive
def get_param_sparse(mp, free_vec, par_name):
    """Autograd-primitive twin of ``get_param``.

    Loads ``par_name``'s slice of ``free_vec`` into ``mp[par_name]`` and
    returns that parameter's value. Because the function is a ``primitive``,
    autograd does not trace through this body. Its gradient with respect to
    ``free_vec`` must be registered with ``defvjp``.
    """
    par = mp[par_name]
    par.set_free(free_vec[mp.free_indices_dict[par_name]])
    return par.get()

def get_free_vec(mp, free_vec, par_name):
    """Return the entries of ``free_vec`` that belong to parameter ``par_name``.

    The indices come from ``mp.free_indices_dict[par_name]``.
    """
    indices = mp.free_indices_dict[par_name]
    return free_vec[indices]

def set_free_and_get(free_vec_par, par):
    """Load ``free_vec_par`` into ``par`` as its free values; return ``par.get()``.

    ``par`` is mutated in place. This function is what gets differentiated
    to obtain a parameter's Jacobian w.r.t. its own free sub-vector.
    """
    par.set_free(free_vec_par)
    value = par.get()
    return value

def _make_par_jacobian(par):
    """Return a function computing d par.get() / d (par's free sub-vector).

    ``par`` is captured as this call's local, so the returned function keeps
    pointing at this parameter object. A lambda written inline against
    ``mp[...]`` would instead re-read the global ``mp`` on every call (late
    binding) and silently follow it if ``mp`` were rebound.
    """
    return autograd.jacobian(lambda free_sub_vec: set_free_and_get(free_sub_vec, par))


mat1_sub_vec = get_free_vec(mp, free_vec, 'mat1')
mat2_sub_vec = get_free_vec(mp, free_vec, 'mat2')
print(set_free_and_get(mat1_sub_vec, mp['mat1']))
print(set_free_and_get(mat2_sub_vec, mp['mat2']))

# Per-parameter Jacobians w.r.t. each parameter's own free sub-vector only.
# For a matrix parameter the result has shape (value shape) + (sub-vector
# length,), e.g. (2, 2, 3) for the 2x2 matrices here.
jac_dict = OrderedDict(
    (par_name, _make_par_jacobian(mp[par_name])) for par_name in ('mat1', 'mat2'))

print("-----------------")
print(jac_dict['mat1'](mat1_sub_vec))
print("-----------------")
print(jac_dict['mat2'](mat2_sub_vec))
print("-----------------")


[[ 1.2  0.2]
 [ 0.2  1.2]]
[[ 2.4  0.4]
 [ 0.4  2.4]]
-----------------
[[[ 2.4         0.          0.        ]
  [ 0.2         1.09544512  0.        ]]

 [[ 0.2         1.09544512  0.        ]
  [ 0.          0.36514837  2.33333333]]]
-----------------
[[[ 4.8         0.          0.        ]
  [ 0.4         1.54919334  0.        ]]

 [[ 0.4         1.54919334  0.        ]
  [ 0.          0.51639778  4.66666667]]]
-----------------


In [98]:
# Check that weighting a (2, 2, n) array by a (2, 2) array broadcast over the
# trailing axis, then summing over the first two axes, equals an explicit
# per-slice loop. That contraction is what a matrix-parameter VJP computes.
rng = np.random.RandomState(0)  # seeded so the printed numbers are reproducible
foo = rng.random_sample((2, 2, 3))
bar = rng.random_sample((2, 2))

weighted = foo * np.expand_dims(bar, axis=2)
print(weighted.shape)
print(np.sum(weighted, (0, 1)))
print(np.sum(weighted, -2))
looped = np.sum([foo[:, :, i] * bar for i in range(foo.shape[2])], (1, 2))
print(looped)

assert np.allclose(np.sum(weighted, (0, 1)), looped)

(2, 2, 3)
[ 0.35932676  0.70286063  0.7597611 ]
[[ 0.30010286  0.60179243  0.53072853]
 [ 0.0592239   0.1010682   0.22903258]]
[ 0.35932676  0.70286063  0.7597611 ]


In [106]:
def get_param_sparse_vjp(g, ans, vs, gvs, mp, free_vec, par_name):
    """Vector-Jacobian product of ``get_param_sparse`` w.r.t. ``free_vec``.

    Only the entries of ``free_vec`` that belong to ``par_name`` influence the
    output. So the dense Jacobian is computed for that parameter's free
    sub-vector alone, contracted with ``g``, and scattered back into a zero
    vector the size of ``free_vec``.

    Args:
        g: Output cotangent, shaped like ``mp[par_name].get()``.
        ans, vs, gvs: Part of the ``defvjp`` calling convention; unused here.
        mp: Parameter container. ``mp[par_name]`` is re-set while the
            Jacobian is evaluated.
        free_vec: Full free-parameter vector.
        par_name: Name of the parameter that was read.

    Returns:
        Array shaped like ``free_vec``. It holds ``g`` contracted with the
        parameter's Jacobian on ``par_name``'s free indices and zeros
        elsewhere.
    """
    par = mp[par_name]
    free_indices = mp.free_indices_dict[par_name]

    # Jacobian of par.get() w.r.t. its own free sub-vector, with shape
    # g.shape + (len(sub-vector),). Building it from the ``mp`` passed in
    # (instead of a precomputed table keyed on a global ``mp``) keeps this
    # correct for any parameter of any container.
    par_jac_fun = autograd.jacobian(
        lambda free_sub_vec: set_free_and_get(free_sub_vec, par))
    jac = par_jac_fun(free_vec[free_indices])

    # Contract over *all* of g's axes so scalar-, vector- and matrix-valued
    # parameters work. The old expand_dims(axis=2)/sum((0, 1)) form only
    # handled 2-D parameters.
    par_vjp = np.tensordot(g, jac, axes=np.ndim(g))

    full_vjp = np.zeros(np.shape(free_vec))
    full_vjp[free_indices] = par_vjp
    return full_vjp

get_param_sparse.defvjp(get_param_sparse_vjp, argnum=1)

# Use the module-qualified name: a bare `jacobian` is not imported in the
# import cell, so the unqualified call raises NameError on a fresh kernel.
get_param_sparse_jac = autograd.jacobian(get_param_sparse, argnum=1)
print(get_param_sparse_jac(mp, free_vec, 'mat2') - get_param_jac(mp, free_vec, 'mat2'))

print('++++++++++++++')
print(get_param_sparse_jac(mp, free_vec, 'mat1'))
print('++++++++++++++')
print(get_param_jac(mp, free_vec, 'mat1'))

# The hand-written sparse VJP must reproduce autograd's dense Jacobian (up to
# floating point) for every parameter. Fail loudly instead of relying on
# eyeballing the printed difference.
for par_name in ('mat1', 'mat2'):
    assert np.allclose(get_param_sparse_jac(mp, free_vec, par_name),
                       get_param_jac(mp, free_vec, par_name)), par_name



[[[ 0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.]]]
++++++++++++++
[[[ 2.4         0.          0.          0.          0.          0.        ]
  [ 0.2         1.09544512  0.          0.          0.          0.        ]]

 [[ 0.2         1.09544512  0.          0.          0.          0.        ]
  [ 0.          0.36514837  2.33333333  0.          0.          0.        ]]]
++++++++++++++
[[[ 2.4         0.          0.          0.          0.          0.        ]
  [ 0.2         1.09544512  0.          0.          0.          0.        ]]

 [[ 0.2         1.09544512  0.          0.          0.          0.        ]
  [ 0.          0.36514837  2.33333333  0.          0.          0.        ]]]
