In [None]:
import os
# Runtime configuration via environment variables. These are set before
# `jax` is imported further down in this cell.

# Enumerate CUDA devices by PCI bus id, then expose only device index 3.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Logging knobs for the TensorFlow/XLA C++ runtime.
# NOTE(review): presumably enables verbose logging for the BFC allocator;
# confirm that the leading '=' in the value is intended.
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# GPU memory behaviour of the XLA Python client: no up-front preallocation,
# and use the "platform" allocator.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from functools import partial

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze

from torchsummary import summary
# from torchvision.models import resnet18, resnet50
from bagnet import *

import matplotlib.pyplot as plt

from gpax import *
from jax_models import *
from resnet import *


from typing import Type, Any, Callable, Union, List, Optional


In [None]:

## Print the receptive field of every BagNet trunk variant.
## One line per (ResNet depth, number of 3x3 convs per stage) combination.

resnet_versions = ['18', '50']
num_conv3x3_per_stages = [
    [1,1,0,0], [1,1,1,0], [1,1,1,1],
    [2,1,1,1], [2,2,1,1], [2,2,2,1], [2,2,2,2],
]

print('Receptive Fields')

for resnet_version in resnet_versions:
    # The stage layout and block class depend only on the ResNet depth,
    # so they are chosen once per version rather than once per variant.
    if resnet_version == '18':
        stage_sizes, block_cls = [2,2,2,2], BagNetBlock
    elif resnet_version == '50':
        stage_sizes, block_cls = [3,4,6,3], BottleneckBagNetBlock

    for num_conv3x3_per_stage in num_conv3x3_per_stages:
        # Batch norm is disabled for the receptive-field computation.
        model_def = partial(
            BagNetTrunk,
            stage_sizes=stage_sizes,
            block_cls=block_cls,
            num_conv3x3_per_stage=num_conv3x3_per_stage,
            disable_bn=True,
        )
        # Only the first return value is printed; `gx` is kept for the
        # (commented-out) plotting code below.
        rf, gx, _ = compute_receptive_fields(model_def, (1, 224, 224, 1))
        print(f'[{resnet_version}] {num_conv3x3_per_stage}:\t {rf}')

#         from plt_utils import plt_scaled_colobar_ax
#         fig, axs = plt.subplots(1,2,figsize=(30,15))
#         ax = axs[0]
#         ax.hist(gx.flatten())

#         ax = axs[1]
#         im = gx.squeeze()
#         im = ( im-np.min(im) ) / np.ptp(im)
#         pltim = ax.imshow(im, vmin=0, vmax=1, cmap='bwr')
#         fig.colorbar(pltim, cax=plt_scaled_colobar_ax(ax))
#         ax.grid()


In [None]:
%load_ext autoreload
%autoreload 2

import os
# Same runtime configuration as the first cell, except that device index 2
# is selected here instead of 3.
# NOTE(review): an earlier cell already imports and runs JAX with
# CUDA_VISIBLE_DEVICES='3'. Changing the variable afterwards in the same
# kernel likely has no effect on an already-initialised backend - confirm,
# or restart the kernel before running from this cell.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# Logging knobs for the TensorFlow/XLA C++ runtime.
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# GPU memory behaviour of the XLA Python client: no up-front preallocation,
# and use the "platform" allocator.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


import jax
import jax.numpy as np
from jax import random, vjp, vmap

from gpax import *
from setup_convgp import *


def get_config_mnist():
    """Return the experiment config: the base config plus local overrides.

    The object returned by `get_config_base()` is deep-copied before being
    wrapped in an `ml_collections.ConfigDict`, and all overrides are applied
    to that wrapped copy.

    NOTE(review): the function name says MNIST, but `image_shape` is set to
    (32, 32, 3) - confirm which dataset this config is meant for.

    Returns:
        The `ml_collections.ConfigDict` with the overrides below applied.
    """
    import copy

    config = ml_collections.ConfigDict(copy.deepcopy(get_config_base()))

    # Input image and patch geometry.
    config.image_shape = (32, 32, 3)
    config.patch_shape = (10, 10)
#     config.patch_encoder = ''

    # Output dimensionality.
    config.output_dim = 10

    # Inducing points: how many, and how they are initialised.
    config.n_inducing = 40
    config.inducing_init_fn = 'random'

    # Transform type and location-kernel switch.
    config.T_type = 'transl'
    config.use_loc_kernel = True

    return config

config = get_config_mnist()


# `key` itself stays PRNGKey(0) so later cells that read it are unaffected.
# Each random consumer below gets its own subkey: JAX PRNG keys must not be
# reused, because identical keys produce identical (correlated) random
# streams.
key = random.PRNGKey(0)
data_key, model_key, init_key = random.split(key, 3)

# Dummy batch of 10 random images, shaped to match the configured image shape.
X_train = random.normal(data_key, (10, *config.image_shape))
(model_cls, k_cls, lik_cls,
 inducing_loc_cls, transform_cls) = get_model_cls(model_key, config, X_train)

model = model_cls()
params = model.get_init_params(model, init_key, X_shape=config.image_shape)
print(model)
# Last expression of the cell: displays the key structure of `params`.
pytree_keys(params)

# (40, 64) (2, 7, 7, 64)


In [None]:

# Scratch cell: a single random 10x10 one-channel input and a BagNet trunk
# instance. `key` comes from an earlier cell.
in_shape = (1,10,10,1)
x = random.normal(key, in_shape)

# m = CNNMnistTrunk()
# The trunk is only constructed here; initialising it is left commented out below.
m = BagNet18x19Trunk()


# params = m.init(key, x)
# m = m.bind(params)






In [None]:
# For each encoder `model_def`: how much padding is needed so that the padded
# patch, after going through the encoder, is the patch only at
# `spatial_coord` over feature space.
#
# Each entry is:
#   (model_def, pad_hw, spatial_coord, z_shape, receptive_field)

patch_response_info = [
    (CNNMnistTrunk,    (1,  1), (1, 1), (3, 3), 10),
    (BagNet18x11Trunk, (10, 0), (1, 1), (2, 2), 11),
    (BagNet18x19Trunk, (10, 0), (1, 1), (2, 2), 19),
    (BagNet18x35Trunk, (10, 0), (1, 1), (3, 3), 35),
    (BagNet18x47Trunk, (1, 10), (1, 1), (4, 4), 47), # not exactly match, but close ...
    (BagNet18x63Trunk, (9,  2), (2, 2), (5, 5), 63), # not exactly match, but close ...
    (BagNet18x95Trunk, (9,  2), (3, 3), (7, 7), 95), # not exactly match, but close ...
]
# Every entry has five fields, so all five must be unpacked; unpacking into
# four names raises "ValueError: too many values to unpack".
model_def, pad_hw, spatial_coord, z_shape, rf_len = patch_response_info[5]



In [None]:
# model_def = CNNMnistTrunk; in_shape = (1,12,12,1)

# Imported here because no visible cell imports `itertools` explicitly;
# relying on it leaking in through a star-import breaks on a fresh kernel.
import itertools

# Build the encoder without batch norm and size the input as
# receptive field + total padding along each spatial axis.
# `model_def`, `pad_hw` and `rf_len` come from the patch_response_info cell.
model_def = partial(model_def, disable_bn=True)
pad_len = sum(pad_hw)
in_shape = (1, rf_len+pad_len, rf_len+pad_len, 1)

print('in_shape: ', in_shape)


image_shape = in_shape[1:1+2]  # ndim=2
# Only the third return value is used, for its spatial shape. The receptive
# field is taken from the table (`rf_len`) rather than from the computed value.
_, _, gy = compute_receptive_fields(model_def, in_shape)
rf = np.array([rf_len, rf_len])
Py, Px = gy.shape[1:3]
P = Py*Px
# All (row, col) locations of the feature map in row-major order, shape (P, 2).
# Built from plain Python ranges so no device arrays are iterated.
spike_locs = np.array(list(itertools.product(range(Py), range(Px))),
                      dtype=np.int32)

x = np.ones(in_shape)
key = random.PRNGKey(0)
model = model_def()
params = model.init(key, x)
# Replace every initialised parameter with standard-normal noise.
# NOTE(review): the same `key` is used for `init` and for every leaf, so
# leaves with equal shapes receive identical values. Left unchanged because
# the threshold check in the next part of this cell depends on these values.
params = unfreeze(params)
params['params'] = jax.tree_util.tree_map(
    lambda w: random.normal(key, w.shape), params['params'])
params = freeze(params)

def f(x): return model.apply(params, x)
# `vjp_fn` maps an output cotangent (shaped like `y`) back to the input `x`.
y, vjp_fn = vjp(f, x)
print('z.shape: ', y.shape)

def construct_gy(spike_loc):
    """Pull a unit spike at one feature-map location back to the input.

    Builds an output cotangent shaped like the global `y` that is zero
    everywhere except at batch index 0 and spatial position
    (spike_loc[0], spike_loc[1]), where all trailing entries are set to 1,
    and passes it through the global `vjp_fn`.

    Args:
        spike_loc: length-2 (row, col) index into axes 1 and 2 of `y`.

    Returns:
        The first element of `vjp_fn(gy)`.

    Raises:
        ValueError: if `spike_loc` does not have exactly two entries.
    """
    if len(spike_loc) != 2:
        raise ValueError(f'len(spike_loc)={len(spike_loc)}')
    # `.at[...].set(...)` is the functional update API; it replaces
    # `jax.ops.index` / `jax.ops.index_update`, which were removed from JAX.
    gy = np.zeros(y.shape).at[0, spike_loc[0], spike_loc[1], ...].set(1)
    return vjp_fn(gy)[0]
# One input gradient per feature-map location, stacked along axis 0.
# (P, *image_shape) plus any trailing axes of the input - TODO confirm
gx = np.vstack([construct_gy(spike_loc) for spike_loc in spike_locs])

# For every gradient, record the (min, max) index along each axis of the
# entries that exceed 10% of that gradient's mean.
# NOTE(review): the comparison is on signed values, so negative gradient
# entries never count as "nonzero" - confirm this is intended.
ind = []
for gxp in gx:
    I = np.where(gxp > np.mean(gxp)*.1)
    ind.append([(np.min(idx), np.max(idx)) for idx in I])

# (P, 3, 2) -> keep only the first two axes (height, width)
ind = np.array(ind)[:, [0, 1]]
# (P, hi/wi, min/max)
if not np.all((ind[:, :, 1]-ind[:, :, 0]+1) <= rf):
    # Note patches on boundary have < receptive_field!
    # The two literals are concatenated, so the first must end with a space.
    raise ValueError('leaky gradient, `gx` has '
                     'more nonzero entries than possible')
# (P, hi/wi): first row/column of each gradient's support
ind_start = ind[:, :, 0]


print('grad_wrt_X: ', gx.shape)

# Flatten the feature-map coordinate (row=1, col=1) into a single index over
# the leading axis of `gx`.
# NOTE(review): the coordinate is hard-coded rather than taken from
# `spatial_coord` in the patch_response_info table - confirm which is intended.
ind_flat = np.ravel_multi_index(np.array([[1], [1]]), y.shape[1:3])[0]
print('ind_flat', ind_flat)

# Show the input gradient for that location; the red grid marks pixel ticks.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
ax.imshow(gx[ind_flat], cmap='Greys')
ax.grid(color='r')

