In [1]:
import autograd
from autograd import numpy as np

In [2]:
from collections import OrderedDict
import paragami

In [3]:
import time

# Benchmark fixtures: a NaN-filled destination buffer and a same-shaped
# array of random inputs.
a = np.full(shape=(100000,), fill_value=float('nan'))
b = np.random.random(size=a.shape)

def set_inplace(b, a):
    """Overwrite the contents of ``a`` with ``b + 2`` and return ``a``.

    The sum ``b + 2`` is evaluated first and the result is then copied into
    ``a`` through slice assignment, so the returned object is ``a`` itself.
    """
    shifted = b + 2
    a[:] = shifted
    return a

def set_notinplace(b):
    """Return the result of ``b + 2`` without writing into any buffer."""
    shifted = b + 2
    return shifted

# Untimed call before the measurements (presumably a warm-up).
set_inplace(b, a)

# time.perf_counter() is the clock meant for measuring short durations;
# time.time() follows the wall clock and can jump if it is adjusted.
# The loop variable is named `rep` rather than `iter` so the builtin iter()
# is not shadowed for the rest of the notebook.
t1 = time.perf_counter()
for rep in range(1000):
    set_inplace(b, a)
print('Inplace: ', time.perf_counter() - t1)

t1 = time.perf_counter()
for rep in range(1000):
    a = set_notinplace(b)
print('Not in place: ', time.perf_counter() - t1)

# TODO: the "in place" version is the slower one.  Presumably this is because
# set_inplace still evaluates `b + 2` into a temporary and then pays for an
# extra copy into `a`; confirm with %memit and id().

Inplace:  0.10280919075012207
Not in place:  0.03980541229248047


In [4]:


def check_dict_equal(dict1, dict2):
    """Assert that two dicts have the same keys and almost-equal values.

    Parameters
    ----------
    dict1, dict2 : dict
        Mappings whose values are compared key by key with
        ``numpy.testing.assert_array_almost_equal``.

    Raises
    ------
    AssertionError
        If the key sets differ, or if the values stored under any key are
        not almost equal.
    """
    # Imported locally because the name is not imported anywhere in the
    # visible notebook; without this the comparison raised NameError.
    from numpy.testing import assert_array_almost_equal

    assert dict1.keys() == dict2.keys()
    for key in dict1:
        assert_array_almost_equal(dict1[key], dict2[key])

# Register three numeric-array patterns under the keys 'a', 'b' and 'c'.
# All share the lower bound -1; shapes and upper bounds differ per key.
dict_pattern = paragami.OrderedDictPattern()
for key, shape, upper in (('a', (2, 3, 4), 2), ('b', (5, ), 10), ('c', (5, 2), 10)):
    dict_pattern[key] = paragami.NumericArrayPattern(shape, lb=-1, ub=upper)

# Draw a random value from the pattern and make a plain-dict copy of it.
dict_val = dict_pattern.random()
print(dict_val)
dict_val2 = dict(dict_val)
print(dict_val2)

# Flatten the original container and the plain-dict copy the same way.
flat_val1, flat_val2 = (
    dict_pattern.flatten(val, free=True) for val in (dict_val, dict_val2))

# Distance between the two flattened vectors (displayed as the cell output).
np.linalg.norm(flat_val1 - flat_val2)

OrderedDict([('a', array([[[0.89358177, 0.89658284, 1.1121885 , 0.70436588],
        [0.83263673, 1.13825349, 0.90575051, 0.85797621],
        [0.89577067, 1.0274987 , 0.61030241, 0.77511769]],

       [[0.62469266, 1.17655873, 0.71461202, 1.08348495],
        [0.87832466, 1.10244331, 1.0606273 , 0.68987249],
        [1.15039823, 0.53624353, 0.90259403, 0.88781202]]])), ('b', array([6.67444852, 6.02320974, 4.63945515, 4.95775229, 4.87335276])), ('c', array([[6.33539152, 6.49540207],
       [5.32714664, 5.67921865],
       [5.74776878, 4.56019538],
       [4.51882782, 6.80849584],
       [6.55545194, 5.78203494]]))])
{'c': array([[6.33539152, 6.49540207],
       [5.32714664, 5.67921865],
       [5.74776878, 4.56019538],
       [4.51882782, 6.80849584],
       [6.55545194, 5.78203494]]), 'a': array([[[0.89358177, 0.89658284, 1.1121885 , 0.70436588],
        [0.83263673, 1.13825349, 0.90575051, 0.85797621],
        [0.89577067, 1.0274987 , 0.61030241, 0.77511769]],

       [[0.62469266, 1

0.0

In [5]:
# Draw a random value from dict_pattern; it is the input for test_function below.
param_dict = dict_pattern.random()

def test_function(param_dict):
    a = param_dict['a']
    b = param_dict['b']
    return np.mean(a ** 2) + np.mean(b ** 2)

# Evaluate on the random draw; as the cell's last expression the value is displayed.
test_function(param_dict)

35.14800648802951

In [11]:
class PatternedFunction:
    """Wrap a function so that one positional argument is passed flattened.

    When the wrapper is called, the argument at position ``argnum`` is folded
    with ``pattern.fold(..., free=free)`` and the wrapped function is invoked
    with the folded value in its place.  All other positional arguments and
    all keyword arguments are forwarded unchanged.

    Parameters
    ----------
    original_fun : callable
        The function to wrap.
    pattern : object
        Must provide ``fold(flat_val, free=...)``.
    free : bool
        Forwarded to ``pattern.fold`` as the ``free`` keyword.
    argnum : int, optional
        Position of the flattened argument among the positional arguments.
        Defaults to 0.
    """

    def __init__(self, original_fun, pattern, free, argnum=0):
        self._fun = original_fun
        self._argnum = argnum
        self.free = free
        self._pattern = pattern

    def __str__(self):
        return('Function: {}\nargnum: {}\nfree: {}\npattern: {}'.format(
            self._fun, self._argnum, self.free, self._pattern))

    def __call__(self, *args, **kwargs):
        """Fold the flattened argument and call the wrapped function."""
        flat_val = args[self._argnum]
        folded_val = self._pattern.fold(flat_val, free=self.free)
        # Forward everything after the patterned argument.  The trailing
        # slice must be open-ended: stopping at -1 silently dropped the last
        # positional argument.
        new_args = (
            args[:self._argnum] + (folded_val, ) + args[self._argnum + 1:])
        return self._fun(*new_args, **kwargs)

    
# Wrap test_function so that its argument is the free flattened vector
# described by dict_pattern.
patterned_test_function = PatternedFunction(
    test_function, dict_pattern, free=True)
print(patterned_test_function)

# The wrapper evaluated at the flattened value must agree with the original
# function evaluated at the folded value.
flat_param_dict = dict_pattern.flatten(param_dict, free=True)
assert np.abs(
    patterned_test_function(flat_param_dict) - test_function(param_dict)) < 1e-12

# The wrapper takes a flat vector, so autograd can differentiate it directly.
patterned_test_function_grad = autograd.grad(patterned_test_function)
patterned_test_function_grad(flat_param_dict)

# The Hessian is the cell's last expression, so it is what gets displayed.
patterned_test_function_hessian = autograd.hessian(patterned_test_function)
patterned_test_function_hessian(flat_param_dict)

Function: <function test_function at 0x7fe8a1a05f28>
argnum: 0
free: True
pattern: OrderedDict:
	[a] = Array (2, 3, 4) (lb=-1, ub=2)
	[b] = Array (5,) (lb=-1, ub=10)
	[c] = Array (5, 2) (lb=-1, ub=10)


array([[0.04656489, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.03302051, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.02695071, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [39]:
class FlattenedFunction:
    """Wrap a function so that selected positional arguments are passed flattened.

    When the wrapper is called, each positional argument listed in
    ``argnums`` is folded with the matching entry of ``patterns`` and the
    wrapped function is invoked with the folded values in their place.  All
    other positional arguments and all keyword arguments are forwarded
    unchanged.

    Parameters
    ----------
    original_fun : callable
        The function to wrap.
    patterns : object or sequence of objects
        One pattern per flattened argument; each must provide
        ``fold(flat_val, free=...)``.  A single pattern is treated as a
        length-one sequence.
    free : bool or sequence of bool
        Forwarded to each pattern's ``fold`` as the ``free`` keyword.  It is
        broadcast to the shape of ``patterns``.
    argnums : int or sequence of int, optional
        Positions of the flattened arguments; ``argnums[i]`` goes with
        ``patterns[i]``.  Defaults to ``0, 1, ..., len(patterns) - 1``.

    Raises
    ------
    ValueError
        If ``patterns`` or ``argnums`` is not one-dimensional, if their
        lengths differ, or if ``free`` does not match ``patterns``.
    """

    def __init__(self, original_fun, patterns, free, argnums=None):
        self._fun = original_fun
        # atleast_1d lets a single pattern and a list of patterns be handled
        # the same way below.
        self._patterns = np.atleast_1d(patterns)
        if argnums is None:
            argnums = np.arange(0, len(self._patterns))
        if len(self._patterns.shape) != 1:
            raise ValueError('patterns must be a 1d vector.')
        self._argnums = np.atleast_1d(argnums)
        # Order in which to visit the patterns so that argument positions are
        # processed from left to right in __call__.
        self._argnum_sort = np.argsort(self._argnums)
        self.free = np.broadcast_to(free, self._patterns.shape)

        self._validate_args()

    def _validate_args(self):
        """Check that argnums and free are consistent with patterns."""
        if len(self._argnums.shape) != 1:
            raise ValueError('argnums must be a 1d vector.')
        if len(self._argnums) != len(self._patterns):
            raise ValueError('argnums must be the same length as patterns.')
        if len(self.free.shape) != 1:
            raise ValueError('free must be a single boolean or a 1d vector of booleans.')
        if len(self.free) != len(self._patterns):
            raise ValueError('free must broadcast to the same shape as patterns.')

    def __str__(self):
        return('Function: {}\nargnums: {}\nfree: {}\npatterns: {}'.format(
            self._fun, self._argnums, self.free, self._patterns))

    def __call__(self, *args, **kwargs):
        """Fold the flattened arguments and call the wrapped function."""
        # Walk the patterned positions from left to right, copying the
        # untouched arguments in between and replacing each flattened
        # argument with its folded value.
        new_args = ()
        last_argnum = 0
        for i in self._argnum_sort:
            argnum = self._argnums[i]
            folded_val = self._patterns[i].fold(args[argnum], free=self.free[i])
            new_args += args[last_argnum:argnum] + (folded_val, )
            last_argnum = argnum + 1

        # Forward the positional arguments that follow the last patterned
        # one; without this they were silently dropped.
        new_args += args[last_argnum:]

        return self._fun(*new_args, **kwargs)


def test_function(x, a, y, b):
    return x ** 2 + y ** 2 + np.mean(a ** 2) + np.mean(b ** 2)


# Inputs for the four-argument test_function: two plain scalars and two
# arrays taken from param_dict, plus free flattened versions of the arrays.
x = 2
y = 3
a = param_dict['a']
b = param_dict['b']
a_flat = dict_pattern['a'].flatten(a, True)
b_flat = dict_pattern['b'].flatten(b, True)


# Arguments 1 and 3 (a and b) are passed flattened; x and y are passed as-is.
patterns = [dict_pattern['a'], dict_pattern['b']]
flat_test_function = FlattenedFunction(
    test_function, patterns, free=[True, True], argnums=[1, 3])

# Compare the absolute difference against the tolerance.  The comparison must
# sit outside the abs(): taking a norm of the boolean `diff < 1e-9` let any
# large negative difference pass.
assert np.abs(
    test_function(x, a, y, b) - flat_test_function(x, a_flat, y, b_flat)) < 1e-9

print(flat_test_function)




def test_function(param_dict):
    a = param_dict['a']
    b = param_dict['b']
    return np.mean(a ** 2) + np.mean(b ** 2)

# Wrap the single-argument test_function so that it takes the free flattened
# vector described by dict_pattern.
flat_test_function = FlattenedFunction(
    test_function, dict_pattern, free=True)
param_dict_flat = dict_pattern.flatten(param_dict, free=True)
# Compare the absolute difference against the tolerance.  The comparison must
# sit outside the abs(): taking a norm of the boolean `diff < 1e-9` let any
# large negative difference pass.
assert np.abs(
    test_function(param_dict) - flat_test_function(param_dict_flat)) < 1e-9



Function: <function test_function at 0x7fe8a14d3c80>
argnums: [1 3]
free: [ True  True]
patterns: [<paragami.numeric_array_patterns.NumericArrayPattern object at 0x7fe8a1a0ba20>
 <paragami.numeric_array_patterns.NumericArrayPattern object at 0x7fe8a1a0b8d0>]
