In [1]:
import itertools
import os
import sys
from multiprocessing import Pool

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

sys.path.append('../../src')
from nnclib.utils import reshape_weights
# TODO(vatai): from nnclib import get_config

def get_epsilons_dict(model_name, base=os.path.expanduser("~/tmp/nnc_weights")):
    """Compute the "epsilon" array of every saved weight file of a model.

    Every path matching ``<base>/<model_name>*`` is loaded with
    `np.load`.  The loaded array is sorted along axis 0 and divided by
    the norms taken along axis 0; the epsilons are the absolute
    differences between that array and its mean along axis 1.

    Args:
        model_name (str): Prefix of the file names to search for
            in `base`.

        base (str): Path of the directory holding the numpy
            arrays (the weights).

    Returns:
        dict: Maps the base name of each matching file to its
            array of epsilons, which has the same shape as the
            loaded weights.  Empty if no file matches.
    """
    import glob

    def _layer_epsilons(path):
        # Sort along axis 0, then scale by the axis-0 norms.
        matrix = np.load(path)
        scaled = np.sort(matrix, axis=0) / np.linalg.norm(matrix, axis=0)
        # keepdims keeps the reduced axis so the mean broadcasts
        # against `scaled`.
        centre = np.mean(scaled, axis=1, keepdims=True)
        return np.abs(scaled - centre)

    matches = glob.glob(os.path.join(base, model_name) + "*")
    return {os.path.basename(path): _layer_epsilons(path) for path in matches}


def _apply_fn_list(weights, fn_list):
    """Apply `fn_list` (function list) to `weights`.
    Note: the functions are applied from last to first.

    Args:
        weights (list): List of numpy matrices.
        
        fn_list (list): list of (numpy) functions with 
            `len(fn_list)` equal 2 or 3.  As a 
            conveninience if the elements are strings,
    Returns:
        float: The result of unctions applied to the 
            weights in revers order.
    
    If `fn_list = [np.min, np.average, np.max], then 
    the return vaule is the minimum of the average of 
    the maximum of weights of columns of the weights.
    """

    n = len(fn_list)
    proc_if_str = lambda s: eval('np.' + s) if isinstance(s, str) else s
    fn_list = list(map(proc_if_str, fn_list))
    if n == 3:
        f = lambda w: fn_list[2](w, axis=0)
        weights = list(map(f, weights))
    if n >= 2:
        weights = list(map(fn_list[1], weights))
    weights = list(map(np.ravel, weights))
    weights = np.concatenate(weights)
    return fn_list[0](weights)


def _measure(model_names, fn_lists):
    output = {}
    for model_name in model_names:
        output[model_name] = []
        data = get_epsilons_dict(model_name)
        data = data.values()
        partial = lambda t: _apply_fn_list(data, t)
        result = list(map(partial, fn_lists))
        output[model_name] = result
    return output


def _latexify(results, fn_lists):
    trans = {
        'min': '\\min',
        'max': '\\max',
        'average': '\\operatorname{avg}',
        'median': '\\operatorname{med}'
    }
    fn_lists = [
        " \\circ ".join(map(lambda s: trans[s], lst)) for lst in fn_lists 
    ]
    header = " & ".join(fn_lists)
    output = [header]
    for key, val in results.items():
        fmt = list(map(lambda x: " {:.4f} ".format(x), val))
        line = " & ".join([key] + fmt)
        output.append(line)
    return "\\\\\n".join(output)


def eps_table(model_names, fn_lists):
    """Build the LaTeX table of epsilon statistics.

    Args:
        model_names (iterable): Model names forwarded to `_measure`.

        fn_lists (list): Function lists, used both to compute the
            values (`_measure`) and to label the header
            (`_latexify`).

    Returns:
        str: The table text produced by `_latexify`.
    """
    return _latexify(_measure(model_names, fn_lists), fn_lists)


# Models whose weight files are analysed.
my_model_names = ['xception', 'vgg16', 'vgg19']

# Function lists that become the columns of the table; the first
# name in each list is the one applied last.
my_fns = [
    ['min'],
    ['max'],
    ['average'],
    ['average', 'min'],
    ['average', 'median', 'median'],
]

test_funs = [np.max, np.min, np.average, np.median]

# Every tuple of one, two or three functions drawn from `test_funs`.
all_fns = list(itertools.chain.from_iterable(
    itertools.product(test_funs, repeat=k) for k in range(1, 4)
))

# Manual check of a single model / function list:
# data = get_epsilons_dict(my_model_names[0])
# data = data.values()
# result = _apply_fn_list(data, my_fns[2])
# print(result)
etbl = eps_table(my_model_names, my_fns)
print(etbl)

\min & \max & \operatorname{avg} & \operatorname{avg} \circ \min & \operatorname{avg} \circ \operatorname{med} \circ \operatorname{med}\\
xception &  0.0000  &  0.9739  &  0.0031  &  0.0004  &  0.0752 \\
vgg16 &  0.0000  &  0.9371  &  0.0006  &  0.0001  &  0.0338 \\
vgg19 &  0.0000  &  0.9379  &  0.0007  &  0.0001  &  0.0287 
