In [None]:
import os
# Runtime configuration. Set before `import jax` below so that XLA / the
# CUDA runtime see these values when the backend initialises.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # GPU indices follow PCI bus order (as in nvidia-smi)
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Verbose logging for the BFC allocator. TF_CPP_VMODULE is a comma-separated
# list of `module=level` entries; a leading '=' would name an empty module.
os.environ['TF_CPP_VMODULE'] = 'bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# Allocate GPU memory on demand instead of reserving a large fraction up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from functools import partial

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze

from torchsummary import summary
# from torchvision.models import resnet18, resnet50
from bagnet import *

import matplotlib.pyplot as plt

from gpax import *
from jax_models import *
from resnet import *


from typing import Type, Any, Callable, Union, List, Optional


In [None]:

## Computes receptive fields for bagnet trunks over a sweep of
## (ResNet depth, number of 3x3 convs per stage) configurations.

# ResNet depth -> (blocks per stage, residual block class).
resnet_configs = {
    '18': ([2, 2, 2, 2], BagNetBlock),
    '50': ([3, 4, 6, 3], BottleneckBagNetBlock),
}
resnet_versions = ['18', '50']
num_conv3x3_per_stages = [[1,1,0,0],[1,1,1,0],[1,1,1,1],
                          [2,1,1,1],[2,2,1,1],[2,2,2,1],[2,2,2,2]]

print('Receptive Fields')

for resnet_version in resnet_versions:
    # Dict lookup raises KeyError for an unsupported version instead of
    # silently reusing stage_sizes/block_cls from a previous iteration.
    stage_sizes, block_cls = resnet_configs[resnet_version]
    for num_conv3x3_per_stage in num_conv3x3_per_stages:
        model_def = partial(BagNetTrunk, stage_sizes=stage_sizes,
                            block_cls=block_cls,
                            num_conv3x3_per_stage=num_conv3x3_per_stage,
                            disable_bn=True)
        # Probe with a single 224x224 single-channel input
        # (assumed NHWC, as elsewhere in this notebook -- TODO confirm).
        rf, gx, _ = compute_receptive_fields(model_def, (1, 224, 224, 1))
        print(f'[{resnet_version}] {num_conv3x3_per_stage}:\t {rf}')

#         from plt_utils import plt_scaled_colobar_ax
#         fig, axs = plt.subplots(1,2,figsize=(30,15))
#         ax = axs[0]
#         ax.hist(gx.flatten())

#         ax = axs[1]
#         im = gx.squeeze()
#         im = ( im-np.min(im) ) / np.ptp(im)
#         pltim = ax.imshow(im, vmin=0, vmax=1, cmap='bwr')
#         fig.colorbar(pltim, cax=plt_scaled_colobar_ax(ax))
#         ax.grid()


In [None]:
# (display name, trunk class) pairs to compute patch start indices for.
model_defs = [
    # ('CNNMnistTrunk', CNNMnistTrunk),  # class is defined in a later cell
    ('BagNet18x11Trunk', BagNet18x11Trunk),
    ('BagNet18x19Trunk', BagNet18x19Trunk),
    ('BagNet18x35Trunk', BagNet18x35Trunk),
    ('BagNet18x47Trunk', BagNet18x47Trunk),
    ('BagNet18x63Trunk', BagNet18x63Trunk),
    ('BagNet18x95Trunk', BagNet18x95Trunk),
    ('BagNet50x11Trunk', BagNet50x11Trunk),
    ('BagNet50x19Trunk', BagNet50x19Trunk),
    ('BagNet50x35Trunk', BagNet50x35Trunk),
    ('BagNet50x47Trunk', BagNet50x47Trunk),
    ('BagNet50x63Trunk', BagNet50x63Trunk),
    ('BagNet50x95Trunk', BagNet50x95Trunk),
]

in_shape = (1, 224, 224, 1)

# Collect per-architecture results so they survive the loop; `start_ind` and
# `rf` on their own only hold the values of the last model.
arch_to_startind = {}
arch_to_rf = {}
for name, model_def in model_defs:
    if name.startswith('BagNet'):
        # BagNet trunks are probed with batch norm disabled.
        model_def = partial(model_def, disable_bn=True)
    start_ind, rf = compute_receptive_fields_start_ind_extrap(model_def, in_shape)
    arch_to_startind[name] = start_ind
    arch_to_rf[name] = rf
    print(f'{name}: receptive field {rf}')
    
    
    

In [None]:

class CNNMnistTrunk(nn.Module):
    """Small two-stage CNN feature extractor ("trunk") for MNIST-sized images.

    Each stage is a 3x3 conv + ReLU followed by 2x2 average pooling with
    stride 2.

    in_shape: (1, 28, 28, 1)
    receptive field: (10, 10)
    """

    @nn.compact
    def __call__(self, x):
        # (N, H, W, C)
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # (N, Px, Py, L)
        return x

    def gz(self, x):
        """Encode a single receptive-field patch into one feature vector.

        Zero-pads `x` by one pixel on each spatial side, runs the trunk, and
        returns the features at the centre of the resulting 3x3 feature map.

        Args:
            x: (N, H, W, C) array, or (N, H, W) in which case a channel axis
               is appended. H, W must make the padded input map to a 3x3
               feature grid (presumably 10x10, the receptive field).

        Returns:
            Centre features `[:, 1, 1, :]`, squeezed (so shape (L,) for N == 1).

        Raises:
            ValueError: if the trunk output is not a 3x3 feature map.
        """
        if x.ndim == 3:
            x = x[..., np.newaxis]  # add color channel
        x = np.pad(x, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)),
                   mode='constant',
                   constant_values=0)
        x = self.__call__(x)
        if x.shape[1:3] != (3, 3):
            raise ValueError(
                'CNNMnistTrunk.gz(x) does not have correct shape: '
                f'expected a (3, 3) feature map, got {x.shape[1:3]}')
        x = x[:, 1, 1, :].squeeze()
        return x

    @property
    def receptive_field(self):
        return 10

    @classmethod
    def get_start_ind(cls, image_shape):
        """Return the (row, col) start index of every 10x10 patch.

        Args:
            image_shape: image shape without batch axis, e.g. (28, 28, 1).
                Any sequence is accepted.

        Returns:
            Integer array of shape (num_patches, 2), row-major over patches.
            Entries can be negative: border patches overhang the image.
            Sizes other than 14x14, 28x28 and 32x32 are computed with
            `compute_receptive_fields_start_ind_extrap`.
        """
        image_shape = tuple(image_shape)
        # Per-axis start offsets for the precomputed sizes; the 2-D start
        # indices are their Cartesian product in row-major order.
        known_offsets = {
            (14, 14): [0, 1, 5],
            (28, 28): list(range(-3, 22, 4)),  # -3, 1, 5, ..., 21
            (32, 32): list(range(-3, 26, 4)),  # -3, 1, 5, ..., 25
        }
        offsets = known_offsets.get(image_shape[:2])
        if offsets is None:
            start_ind, _ = compute_receptive_fields_start_ind_extrap(
                cls, (1,) + image_shape)
            return start_ind
        return np.array([[r, c] for r in offsets for c in offsets])

    @classmethod
    def get_XL(cls, image_shape=(28, 28, 1)):
        """Return one (scale, translation) row per patch of the image.

        Args:
            image_shape: image shape without batch axis; the spatial part must
                be (14, 14), (28, 28) or (32, 32).

        Returns:
            Array with columns `[scal, transl]` stacked per patch.

        Raises:
            ValueError: for any other spatial size.
        """
        image_shape = tuple(image_shape)
        if image_shape[:2] not in [(14, 14), (28, 28), (32, 32)]:
            raise ValueError(
                f'CNNMnistTrunk.get_XL() invalid image shape {image_shape}')
        patch_size = (10, 10)  # equals the receptive field
        start_ind = cls.get_start_ind(image_shape)
        scal, transl = startind_to_scal_transl(
            image_shape[:2], patch_size, start_ind)
        # The same scale is repeated once per translation (patch).
        scal = np.repeat(scal[np.newaxis, ...], len(transl), axis=0)
        XL = np.column_stack([scal, transl])
        return XL
    
    
    

# Smoke test: encode one random 10x10 single-channel patch with gz.
key = random.PRNGKey(0)
model_def = CNNMnistTrunk
in_shape = (1,10,10,1)
x = random.normal(key, in_shape)
m = model_def()
params = m.init(key, x)
m = m.bind(params)

z = m.gz(x)

print(x.shape, '->', z.shape)
# The last conv has 64 features; gz squeezes away the batch axis of size 1.
assert z.shape == (64,), f'unexpected gz output shape {z.shape}'



In [None]:
# Patch-encoder (`gz`) configurations. For each `model_def`: how much zero
# padding `pad_hw` = (before, after) a square patch of side `rf_len` needs so
# that, after the padded patch goes through the encoder, the features at
# `spatial_coord` of the feature map depend on the patch only.
#
# Tuple layout: (model_def, pad_hw, spatial_coord, rf_len)

gz_info = [
    (BagNet18x11Trunk, (10, 0), (1, 1), 11),
    (BagNet18x19Trunk, (10, 0), (1, 1), 19),
    (BagNet18x35Trunk, (10, 0), (1, 1), 35),
    (BagNet18x47Trunk, (10, 0), (2, 2), 47),
#     (BagNet18x63Trunk, (10, 0), (2, 2), 63),
]
gz_index = 3  # which configuration the following cells use
model_def, pad_hw, spatial_coord, rf_len = gz_info[gz_index]




In [None]:
# model_def = CNNMnistTrunk; in_shape = (1,12,12,1)


import itertools

# Probe the randomly weighted trunk: back-propagate a one-hot cotangent at each
# feature location and take the bounding box of the input pixels that receive
# gradient. That box is the location's receptive field / patch.
model_def = partial(model_def, disable_bn=True)
pad_len = np.sum(np.array(pad_hw)).item()
in_shape = (1, rf_len+pad_len, rf_len+pad_len, 1)

print('in_shape: ', in_shape)


image_shape = in_shape[1:1+2]  # ndim=2
rf, _, gy = compute_receptive_fields(model_def, in_shape)
# The expected receptive field from gz_info is used for the check below.
rf = np.array([rf_len, rf_len])
Py, Px = gy.shape[1:3]
P = Py*Px
# Every feature-map location (row, col), row-major: shape (P, 2).
spike_locs = list(itertools.product(np.arange(Py), np.arange(Px)))
spike_locs = np.array(spike_locs, dtype=np.int32)

x = np.ones(in_shape)
key = random.PRNGKey(0)
model = model_def()
params = model.init(key, x)
params = unfreeze(params)
# Re-draw every weight as 0.1 * N(0, 1) with an independent key per leaf;
# a single shared key would give all equally-shaped leaves identical weights.
leaves, treedef = jax.tree_util.tree_flatten(params['params'])
leaf_keys = random.split(key, len(leaves))
params['params'] = jax.tree_util.tree_unflatten(
    treedef,
    [1e-1*random.normal(k, w.shape) for k, w in zip(leaf_keys, leaves)])
params = freeze(params)

def f(x): return model.apply(params, x)
y, vjp_fn = vjp(f, x)
print('z.shape: ', y.shape)

def construct_gy(spike_loc):
    """Return the input gradient gx for feature location spike_loc = (r, c).

    Builds a cotangent gy that is 1 at y[0, r, c, :] (all channels) and 0
    elsewhere, and pulls it back through `vjp_fn`. The result has the shape
    of the input `x`.

    Raises:
        ValueError: if spike_loc does not have exactly two entries.
    """
    if len(spike_loc) != 2:
        raise ValueError(f'len(spike_loc)={len(spike_loc)}')
    # `.at[...].set` is the functional update that replaced jax.ops.index_update.
    gy = np.zeros(y.shape).at[0, spike_loc[0], spike_loc[1], ...].set(1)
    gx = vjp_fn(gy)[0]
    return gx

# (P, H, W, 1): one input gradient per feature location.
gx = np.vstack([construct_gy(spike_loc) for spike_loc in spike_locs])

ind = []
for p in range(len(gx)):
    # Gradient support: entries whose magnitude exceeds 10% of the mean
    # magnitude. Gradients are signed, so threshold |gx|. A signed threshold
    # drops negative in-field entries, and with a negative mean it lets every
    # zero out-of-field pixel through.
    gxp = np.abs(gx[p])
    I = np.where(gxp > np.mean(gxp)*.1)
    # (min, max) index along each axis of gx[p], i.e. (H, W, C).
    ind.append([(np.min(idx), np.max(idx)) for idx in I])

# (P, 3, 2) -> keep the H and W axes.
ind = np.array(ind)[:, [0, 1]]
# (P, hi/wi, min/max)
if not np.all((ind[:, :, 1]-ind[:, :, 0]+1) <= rf):
    # Note patches on boundary have < receptive_field!
    raise ValueError('leaky gradient, `gx` has '
                     'more nonzero entries than possible')
# (P, hi/wi)
ind_start = ind[:, :, 0]


print('grad_wrt_X: ', gx.shape)
plt.imshow(gx[4], cmap='Greys')


In [None]:
# Start (row, col) input index of each feature location's patch, shape (P, 2).
ind_start

In [None]:

# Input gradient for feature location `spatial_coord` (row, col), which is
# (2, 2) for the selected configuration. spike_locs is row-major over the
# (Py, Px) feature map, so location (r, c) is entry r*Px + c of gx. Compute it
# rather than hard-coding an index; JAX clamps out-of-range indices silently.
r_coord, c_coord = int(spatial_coord[0]), int(spatial_coord[1])
if not (0 <= r_coord < Py and 0 <= c_coord < Px):
    raise ValueError(f'spatial_coord={spatial_coord} outside feature map ({Py}, {Px})')
p_coord = r_coord * Px + c_coord
print('grad_wrt_X: ', gx.shape)
plt.imshow(gx[p_coord], cmap='Greys')

In [None]:

# Padding check: if `pad_hw` isolates the patch, the features at
# `spatial_coord` must not depend on the padding value. Pad the same random
# patch with 0 and with 100 (before/after amounts from pad_hw) and compare.
x = random.normal(key, (rf_len, rf_len, 1))
xp1 = np.pad(x, pad_width=(pad_hw, pad_hw, (0, 0)), constant_values=0)
xp2 = np.pad(x, pad_width=(pad_hw, pad_hw, (0, 0)), constant_values=100)

m = model_def()
params = m.init(key, xp1)
m = m.bind(params)

# Inputs carry no batch axis here; presumably the outputs are (Py, Px, L),
# so z[spatial_coord] selects one location's features -- TODO confirm.
z1 = m(xp1)
z2 = m(xp2)

# Exact elementwise equality. A zero *sum* of differences could hide
# differences that cancel each other out.
pad_patch_correct = bool(np.array_equal(z1[spatial_coord], z2[spatial_coord]))
print(pad_patch_correct)



fig, axs = plt.subplots(1, 2, figsize=(20, 10))
axs[0].imshow(xp1, cmap='Greys')
axs[0].set_title('patch padded with 0')
axs[1].imshow(xp2, cmap='Greys')
axs[1].set_title('patch padded with 100')



In [None]:
# Round-trip check for persisting a {name: array} dict to a .npy file.
# np.save pickles a dict into a 0-d object array, so loading needs
# allow_pickle=True and `.item()` to recover the dict. Only load files you
# wrote yourself: unpickling can execute arbitrary code.
import numpy as onp

npy_filename = 'arch_to_startind.npy'

d = {'k1': np.ones(3, dtype=np.int32), 'k2': np.ones(2)}
# Store plain NumPy arrays so the file can be read back without JAX.
onp.save(npy_filename, {k: onp.asarray(v) for k, v in d.items()})
d_loaded = onp.load(npy_filename, allow_pickle=True).item()

assert d_loaded.keys() == d.keys()
assert all(onp.array_equal(d_loaded[k], onp.asarray(d[k])) for k in d)
d_loaded

