goal
- conv architecture for bagnet using 1x1 convolution. 
    - verify receptive field computation etc.
- implement functions to draw bounding boxes

In [None]:
import os
# GPU / logging / allocator configuration. These are set before `jax` is
# imported further down, presumably because they are read at backend start-up.
# Order devices by PCI bus id and expose only device '3' to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Verbose allocator logging (bfc_allocator at vmodule level 1, min log level 0).
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# NOTE(review): per JAX's GPU-memory docs these disable up-front GPU memory
# preallocation and select the on-demand "platform" allocator -- confirm
# against the installed JAX version.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from functools import partial


import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze


import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches

from gpax import *
from jax_models import *


In [None]:
# Embed a (1, 1, h, w) ramp image into a zero canvas that is `patchsize - 1`
# pixels larger along each spatial axis, then display it.
patchsize = 10
h, w = 28, 28
image = np.arange(0, h*w).reshape((1, 1, h, w))
_, c, x, y = image.shape
padded_image = np.zeros((c, x + patchsize - 1, y + patchsize - 1))
# The image is placed (patchsize-1)//2 pixels from the top/left border.
# `jax.ops.index` / `jax.ops.index_update` were removed from JAX; a tuple of
# slices plus the `.at[...].set(...)` syntax is the supported equivalent.
pad_before = (patchsize - 1) // 2
ind = (slice(None),
       slice(pad_before, pad_before + x),
       slice(pad_before, pad_before + y))
padded_image = padded_image.at[ind].set(image[0])

# imshow expects channels last: (C, H, W) -> (H, W, C).
plt.imshow(padded_image.transpose((1, 2, 0)))
# image = padded_image[None].astype(np.float32)

In [None]:
import numpy as onp
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

class CNN(nn.Module):
    """Four-stage strided conv trunk (torch).

    Each stage is Conv2d(kernel 4, stride 2, padding 1) -> BatchNorm2d -> ReLU.
    The input has a single channel; stage ``i`` outputs ``n_filters * 2**i``
    channels. ``num_classes`` is accepted but not used by this class.
    """
    # rf: [46, 46]
    #
    def __init__(self,
                 num_classes=1,
                 n_filters=16):
        super().__init__()

        n_layers = 4
        # Channel count entering the first stage, then leaving each stage.
        widths = [1] + [n_filters*(2**i) for i in range(n_layers)]

        modules = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Conv2d(c_in, c_out,
                                     kernel_size=4, stride=2, padding=1))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.ReLU())
        self.conv_blocks = nn.Sequential(*modules)

    def forward(self, x, output_feat=True):
        """Run the conv stack; ``output_feat`` is accepted but ignored."""
        # (1,224,224) -> (128,14,14)
        return self.conv_blocks(x)


def cnn16(pretrained=False, **kwargs):
    """Build a `CNN` with ``n_filters`` forced to 16.

    Any ``n_filters`` passed in ``kwargs`` is overwritten. No pretrained
    weights exist, so ``pretrained=True`` raises ``ValueError``.
    """
    if not pretrained:
        kwargs['n_filters'] = 16
        return CNN(**kwargs)
    raise ValueError('No pretrained model for CNN')

def compute_RF_numerical(net,img_np, out_cnn_idx=None):
    """Numerically estimate the receptive field of a torch network.

    Every ``Conv2d`` weight is set to 1 (bias to 0) and every ``BatchNorm2d``
    gets weight 1 / bias 0 and is switched to eval mode. A unit gradient is
    then placed on the output element at batch 0, channel 0 and the middle
    index of every remaining axis, and back-propagated to the input. The
    extent of the non-zero input gradient is reported as the receptive field.

    NOTE: the parameters of ``net`` are overwritten in place.

    Args:
        net: torch module to probe.
        img_np: array-like input, converted to a float tensor; indexed as
            ``[0, 0]`` on the gradient, so at least (batch, channel, ...).
        out_cnn_idx: optional index applied to the network output before
            back-propagation (for networks returning a sequence).

    Returns:
        ``(RF, grad_np)`` where ``RF`` is a list with one extent (in pixels)
        per axis of ``grad_np`` and ``grad_np`` is the input gradient of the
        first image's first channel as a NumPy array.

    Raises:
        ValueError: if the input gradient is zero everywhere.
    """
    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.fill_(0)
        if isinstance(m, nn.BatchNorm2d):
            # With affine=False, weight and bias are None.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.fill_(0)
            m.eval()

    net.apply(weights_init)
    img_ = torch.from_numpy(onp.array(img_np)).float()
    img_.requires_grad = True
    out_cnn = net(img_)
    if out_cnn_idx is not None:
        out_cnn = out_cnn[out_cnn_idx]
    out_shape = out_cnn.size()
    print('out_shape: ', out_shape)

    # Batch and channel axes use index 0; every other axis uses its centre.
    center = tuple(0 if axis < 2 else int(size) // 2
                   for axis, size in enumerate(out_shape))
    grad = torch.zeros(out_shape)
    grad[center] = 1
    out_cnn.backward(gradient=grad)

    grad_np = img_.grad[0, 0].data.numpy()
    # grad_np is a NumPy array, so use NumPy (onp) rather than jax.numpy here.
    idx_nonzeros = onp.nonzero(grad_np)
    if any(idx.size == 0 for idx in idx_nonzeros):
        raise ValueError('input gradient is zero everywhere; '
                         'cannot determine the receptive field')
    RF = [int(idx.max() - idx.min() + 1) for idx in idx_nonzeros]

    return RF, grad_np



# Probe the torch CNN with an all-ones (1, 1, 224, 224) input.
# NOTE: compute_RF_numerical overwrites the model's parameters in place.
model = cnn16(); h, w = 224, 224

img_np = np.ones((1, 1, h, w))
rf, gx = compute_RF_numerical(model, img_np)
print(rf)

# Plot the input-gradient map returned above.
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(gx.squeeze(), cmap='Greys')

# Bare expression: presumably left so the notebook displays the module.
model

In [None]:
from bagnet import bagnet33, bagnet17, bagnet9

# Same receptive-field probe for bagnet33, using a 3-channel ramp image
# scaled by 1/(3*h*w).
model = bagnet33(); model.eval()
h, w = 224, 224
img_np = np.arange(3*h*w).reshape((1, 3, h, w))/(3*h*w)
# out_cnn_idx=1 selects element 1 of the model output before backprop
# (presumably bagnet returns a tuple -- confirm against the bagnet module).
rf, gx = compute_RF_numerical(model, img_np, out_cnn_idx=1)
print(rf)
print(gx.shape)

fig, axs = plt.subplots(1,2,figsize=(10,10))
# Left: the input image, moved to channels-last for imshow.
ax = axs[0]
ax.imshow(img_np[0].transpose((1,2,0)))

# Right: binarised gradient support (every non-zero entry set to 1 in place).
ax = axs[1]
gx[gx!=0] = 1
ax.imshow(gx.squeeze(), cmap='Greys')


In [None]:

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze


class CNNCxrTrunk(nn.Module):
    """Flax conv trunk: four 4x4 stride-2 convolutions, each followed by ReLU.

    Feature widths are 16, 32, 64 and 128, in that order.
    """

    @nn.compact
    def __call__(self, x):
        # (1, 224, 224, 1)
        make_conv = partial(nn.Conv, kernel_size=(4, 4), strides=(2, 2))
        for width in (16, 32, 64, 128):
            x = make_conv(features=width)(x)
            x = nn.relu(x)
        return x

    

# Receptive-field probe of the Flax trunk for a chosen output location.
# NOTE(review): the first `h, w` assignment is immediately overwritten by
# `h, w = 28,28`; only the latter takes effect.
model_def = CNNMnistTrunk; h, w = 10+2, 10+2;  h, w = 28,28
# model_def = CNNCxrTrunk; h, w = 46+2,46+2    # h, w = 224,224
in_shape = (1, h, w, 1)
# spike_loc = np.array([[1, 1], [1, -2], [-2, 1], [-2, -2]])
# NOTE(review): the first `spike_loc` below is dead; the second one wins.
spike_loc = np.array([[0,0], [0,6]])
spike_loc = np.array([[6,0]])
rf, gx, gy = compute_receptive_fields(model_def, in_shape, spike_loc=spike_loc)
# _, gx, _ = compute_receptive_fields(model_def, in_shape, spike_loc)
# gx = jax.ops.index_update(gx, gx!=0, 1)
# NOTE(review): this prints the centre of `gy` labelled as `spike_loc`, not
# the `spike_loc` actually passed above.
print(f'y.shape: {gy.shape} (spike_loc={(gy.shape[1]//2, gy.shape[2]//2)})')
print('rf: ', rf)

fig, ax = plt.subplots(1,1,figsize=(10,10))
# log of the gradient map; entries equal to 0 become -inf.
ax.imshow(np.log(gx.squeeze()), cmap='Greys')
r = rf//2
xy = (h//2-r[1]-.5, w//2-r[0]-.5) # half-pixel
# NOTE(review): `rect` is constructed but never added to `ax` (no
# `ax.add_patch(rect)`), so nothing is drawn. It is also anchored on the
# image centre, which does not match the spike location used above.
rect = mpl_patches.Rectangle(xy, rf[1], rf[0],
    linewidth=1, edgecolor='r', facecolor='none')




In [None]:
from jax_models import *




def compute_receptive_fields_start_ind(model_def, in_shape):
    """Computes start indices for patches to get transformation 
            parameters (in start indices of patches)

        One vjp is evaluated per output location (all parameters set to 1);
        the bounding box of the significant input gradient gives that
        location's patch, and its minimum index per spatial axis is returned.

        ```
        g_cls = CNNMnistTrunk; image_shape = (28, 28, 1)
        ind_start, rf = compute_receptive_fields_start_ind(
            g_cls, (1, *image_shape))
        fig,ax = plt.subplots(1,1,figsize=(5,5))
        ax.imshow(np.zeros(image_shape), cmap='Greys', origin='upper')
        ax.scatter(ind_start[:,0], ind_start[:,1])
        ax.grid()
        ``` 

        Args:
            model_def: callable returning a Flax module (called with no args).
            in_shape: input shape ``(N, H, W, C)``.

        Returns:
            ``(ind_start, rf)``: ``ind_start`` has one row per output
            location with the minimum index along the first two axes of the
            per-sample gradient; ``rf`` is the receptive field returned by
            ``compute_receptive_fields``.

        Raises:
            ValueError: if ``in_shape`` is not 4-d, or if a patch extent
                exceeds ``rf``.
    """
    # Local import: `itertools` is not imported explicitly at file level.
    import itertools

    if len(in_shape) != 4:
        raise ValueError('`in_shape` has dims (N, H, W, C)')

    rf, _, gy = compute_receptive_fields(model_def, in_shape)
    Py, Px = gy.shape[1:3]
    # One (row, col) spike per spatial location of the model output.
    spike_locs = np.array(list(itertools.product(range(Py), range(Px))),
                          dtype=np.int32)

    x = np.ones(in_shape)
    model = model_def()
    params = model.init(random.PRNGKey(0), x)
    params = unfreeze(params)
    # All-ones weights (`jax.tree_map` is deprecated in favour of
    # `jax.tree_util.tree_map`).
    params['params'] = jax.tree_util.tree_map(lambda w: np.ones(w.shape),
                                              params['params'])
    params = freeze(params)

    def f(x): return model.apply(params, x, mutable=['batch_stats'])
    # `has_aux=True` keeps the mutable state out of the differentiated output,
    # so no cotangent has to be invented for it.
    y, vjp_fn, _ = vjp(f, x, has_aux=True)

    def construct_gy(spike_loc):
        if len(spike_loc) != 2:
            raise ValueError(f'len(spike_loc)={len(spike_loc)}')
        # Unit cotangent on every channel of one output location;
        # `.at[...].set` replaces the removed `jax.ops.index_update`.
        gy = np.zeros(y.shape).at[0, spike_loc[0], spike_loc[1], ...].set(1)
        return vjp_fn(gy)[0]
    # (P, *image_shape)  squeeze batch-dim
    gx = vmap(construct_gy)(spike_locs).squeeze(1)
    ind = []
    for gxp in gx:
        # Keep entries above 10% of the mean gradient of this sample.
        I = np.where(gxp > np.mean(gxp)*.1)
        ind.append([(np.min(idx), np.max(idx)) for idx in I])

    # (P, 3, 2) -> keep the two spatial axes
    ind = np.array(ind)[:, [0, 1]]
    # (P, hi/wi, min/max)
    if not np.all((ind[:, :, 1]-ind[:, :, 0]+1) <= rf):
        # Note patches on boundary have < receptive_field!
        raise ValueError('leaky gradient, `gx` has '
                         'more nonzero entries than possible')
    # (P, hi/wi)
    ind_start = ind[:, :, 0]

    return ind_start, rf



    
# ResNet-18-style trunk: ResNetTrunk with four stages of two ResNetBlocks each.
ResNet18Trunk = partial(ResNetTrunk, stage_sizes=[2,2,2,2],
                        block_cls=ResNetBlock)

model_def = ResNet18Trunk; h, w = 224, 224
# model_def = CNNMnistTrunk; h, w = 28,28


# Single-channel input of shape (N, H, W, C).
in_shape = (1, h, w, 1)
rf, gx, gy = compute_receptive_fields(model_def, in_shape)
# start_ind, rf = compute_receptive_fields_start_ind(model_def, in_shape)
print(rf)




In [None]:
# Scratch version of `compute_receptive_fields_start_ind`, run at top level on
# ResNet18Trunk with random-normal weights instead of all-ones weights.
import itertools

model_def = ResNet18Trunk; h, w = 512, 512
in_shape = (1, h, w, 1)




if len(in_shape) != 4:
    raise ValueError('`in_shape` has dims (N, H, W, C)')

image_shape = in_shape[1:1+2]  # ndim=2
rf, _, gy = compute_receptive_fields(model_def, in_shape)
Py, Px = gy.shape[1:3]
P = Py*Px
# One (row, col) spike per spatial location of the model output.
spike_locs = list(itertools.product(range(Py), range(Px)))
spike_locs = np.array(spike_locs, dtype=np.int32)

x = np.ones(in_shape)
model = model_def()
params = model.init(random.PRNGKey(0), x)
params = unfreeze(params)
# NOTE(review): every leaf is drawn with the same PRNGKey(1), so leaves of
# equal shape receive identical values.
# `jax.tree_map` is deprecated in favour of `jax.tree_util.tree_map`.
params['params'] = jax.tree_util.tree_map(
    lambda w: random.normal(random.PRNGKey(1), w.shape),
    params['params'])
# if 'batch_stats' in params:
#     params['batch_stats'] = jax.tree_map(lambda w: np.zeros(w.shape),
#                                     params['batch_stats'])
params = freeze(params)

def f(x): return model.apply(params, x, train=False, mutable=['batch_stats'])
# `has_aux=True` returns the mutable state separately instead of requiring a
# cotangent for it.
y, vjp_fn, state = vjp(f, x, has_aux=True)

def construct_gy(spike_loc):
    if len(spike_loc) != 2:
        raise ValueError(f'len(spike_loc)={len(spike_loc)}')
    # `.at[...].set` replaces the removed `jax.ops.index_update`.
    gy = np.zeros(y.shape).at[0, spike_loc[0], spike_loc[1], ...].set(1)
    gx = vjp_fn(gy)[0]
    return gx
# (P, *image_shape)  squeeze batch-dim
gx = vmap(construct_gy)(spike_locs).squeeze(1)
# ind = []
# for p in range(len(gx)):
#     gxp = gx[p]
#     I = np.where(gxp > np.mean(gxp)*.1)
#     ind.append([(np.min(idx), np.max(idx)) for idx in I])

# # (P, 3, 2)
# ind = np.array(ind)[:, [0, 1]]
# # (P, hi/wi, min/max)
# if not np.all((ind[:, :, 1]-ind[:, :, 0]+1) <= rf):
#     # Note patches on boundary have < receptive_field!
#     raise ValueError('leaky gradient, `gx` has'
#                      'more nonzero entries than possible')
# # (P, hi/wi)
# ind_start = ind[:, :, 0]


In [None]:
# Bare expression: presumably left so the notebook displays the parameter tree.
params

In [None]:

# Bounding indices of the significant gradient for the FIRST sample only:
# the loop exits via `break` after one iteration.
ind = []
for p in range(len(gx)):
    gxp = gx[p]
    # Keep entries above 10% of the mean gradient of this sample.
    I = np.where(gxp > np.mean(gxp)*.1)
    # (min, max) index along each axis of gxp.
    ind.append([(np.min(idx), np.max(idx)) for idx in I])
    break
    
    
ind

In [None]:
from plt_utils import *

# Three diagnostics for `gxp` (the per-sample gradient left over from the
# previous cell): value histogram, raw map, and thresholded mask.
fig, axs = plt.subplots(1,3,figsize=(30,10))


ax = axs[0]
ax.hist(gxp.flatten())

ax = axs[1]
pltim = ax.imshow(gxp.squeeze())
# plt_scaled_colobar_ax presumably comes from plt_utils -- not visible here.
fig.colorbar(pltim, cax=plt_scaled_colobar_ax(ax))


ax = axs[2]
# Same 10%-of-mean threshold as used when computing patch bounds.
pltim = ax.imshow((gxp>np.mean(gxp)*.1).astype(np.int32).squeeze())
fig.colorbar(pltim, cax=plt_scaled_colobar_ax(ax))



In [None]:
# plt.imshow(gx.squeeze())
# plt.imshow(gy[:,:,:,0].squeeze())
# NOTE(review): `new_model_state` is not assigned anywhere in the visible
# notebook; this cell relies on state from elsewhere.
new_model_state

In [None]:
import os
# Same GPU / logging / allocator configuration as at the top of the notebook,
# repeated here ahead of the jax imports below.
# Order devices by PCI bus id and expose only device '3' to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Verbose allocator logging (bfc_allocator at vmodule level 1, min log level 0).
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# NOTE(review): per JAX's GPU-memory docs these disable up-front GPU memory
# preallocation and select the on-demand "platform" allocator -- confirm
# against the installed JAX version.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from functools import partial

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze


import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches


import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze

from gpax import *

def compute_receptive_fields_start_ind_extrap(model_def, in_shape):
    """ Just 3 evaluations of `vjp`, rest extrapolated ...

        The patch start index along the width axis is measured for output
        locations (0, 0), (0, 1) and (0, 2). The difference between the first
        two is used as the border offset and between the last two as the
        step; the full grid of start indices is then generated from these
        two numbers, using the same offset and step for both axes.

        Args:
            model_def: callable returning a Flax module (called with no args).
            in_shape: input shape ``(N, H, W, C)``.

        Returns:
            ``(ind_start, rf)``: ``ind_start`` has one 2-entry row per output
            location (values may be negative, no clipping is applied);
            ``rf`` is the receptive field returned by
            ``compute_receptive_fields``.

        Raises:
            ValueError: if ``in_shape`` is not 4-d, or if a measured patch
                extent exceeds ``rf``.
    """
    # Local import: `itertools` is not imported explicitly at file level.
    import itertools

    if len(in_shape) != 4:
        raise ValueError('`in_shape` has dims (N, H, W, C)')

    rf, _, gy = compute_receptive_fields(model_def, in_shape)
    Py, Px = gy.shape[1:3]
    # Only the first three output locations of row 0 are probed.
    spike_locs = np.array(list(itertools.product([0], range(3))),
                          dtype=np.int32)

    x = np.ones(in_shape)
    model = model_def()
    params = model.init(random.PRNGKey(0), x)
    # All-ones variables (`jax.tree_map` is deprecated in favour of
    # `jax.tree_util.tree_map`).
    params = freeze(jax.tree_util.tree_map(lambda w: np.ones(w.shape),
                                           unfreeze(params)))

    def f(x): return model.apply(params, x)
    y, vjp_fn = vjp(f, x)

    def construct_gy(spike_loc):
        if len(spike_loc) != 2:
            raise ValueError(f'len(spike_loc)={len(spike_loc)}')
        # Unit cotangent on every channel of one output location;
        # `.at[...].set` replaces the removed `jax.ops.index_update`.
        gy = np.zeros(y.shape).at[0, spike_loc[0], spike_loc[1], ...].set(1)
        return vjp_fn(gy)[0]
    # (3, *image_shape)  squeeze batch-dim
    gx = vmap(construct_gy)(spike_locs).squeeze(1)
    ind = []
    for gxp in gx:
        # Keep entries above 10% of the mean gradient of this sample.
        I = np.where(gxp > np.mean(gxp)*.1)
        ind.append([(np.min(idx), np.max(idx)) for idx in I])

    # (3, 3, 2) -> keep the two spatial axes
    ind = np.array(ind)[:, [0, 1]]

    # (3, hi/wi, min/max)
    if not np.all((ind[:, :, 1]-ind[:, :, 0]+1) <= rf):
        # Note patches on boundary have < receptive_field!
        raise ValueError('leaky gradient, `gx` has '
                         'more nonzero entries than possible')
    # (3,) start index along the width axis
    ind_start = ind[:, 1, 0]
    offset_border = ind_start[1]-ind_start[0]
    step = ind_start[2]-ind_start[1]

    # Grid index -1 corresponds to the first output location, so that
    # location 1 starts at `offset_border` and later ones advance by `step`.
    ind_start = list(itertools.product(range(-1, Py-1),
                                       range(-1, Px-1)))
    ind_start = np.array(ind_start)*step + offset_border
#     ind_start = np.maximum(0, ind_start)

    return ind_start, rf


class CNNcxr(nn.Module):
    """Flax trunk of four (4x4, stride-2 conv -> ReLU) stages.

    Output feature counts per stage: 16, 32, 64, 128.
    """

    @nn.compact
    def __call__(self, x):
        def conv_relu(inputs, features):
            # One stage: strided 4x4 convolution followed by ReLU.
            layer = nn.Conv(features=features,
                            kernel_size=(4, 4), strides=(2, 2))
            return nn.relu(layer(inputs))

        # (1, 224, 224, 1)
        x = conv_relu(x, 16)
        x = conv_relu(x, 32)
        x = conv_relu(x, 64)
        x = conv_relu(x, 128)
        return x


# Plot the extrapolated patch start indices for the MNIST trunk.
g_cls = CNNMnistTrunk; image_shape = (28, 28, 1)
# g_cls = CNNMnistTrunk; image_shape = (14, 14, 1)
# g_cls = CNNcxr; image_shape = (224, 224, 1)
in_shape = (1, *image_shape)

ind_start, rf = compute_receptive_fields_start_ind_extrap(
    g_cls, (1, *image_shape))
# rf, gx, gy = compute_receptive_fields(g_cls, in_shape)
print(rf)
# NOTE(review): `gx` / `gy` are not assigned in this cell (the call that
# would set them is commented out above), so this prints whatever an earlier
# cell left behind.
print(gx.shape, gy.shape)

# Blank canvas of the image size with one scatter point per patch start.
fig,ax = plt.subplots(1,1,figsize=(5,5))
ax.imshow(np.zeros(image_shape), cmap='Greys', origin='upper')
ax.scatter(ind_start[:,0], ind_start[:,1])
ax.grid()
ax.set_title(rf)

print(ind_start.tolist())

In [None]:
# Same plot as the extrapolated variant, but using the exhaustive
# (one vjp per output location) start-index computation.
g_cls = CNNMnistTrunk; image_shape = (28, 28, 1)
ind_start, rf = compute_receptive_fields_start_ind(
    g_cls, (1, *image_shape))
# Blank canvas of the image size with one scatter point per patch start.
fig,ax = plt.subplots(1,1,figsize=(5,5))
ax.imshow(np.zeros(image_shape), cmap='Greys', origin='upper')
ax.scatter(ind_start[:,0], ind_start[:,1])
ax.grid()

In [None]:
# Probe the current `model_def` on a 46x46 single-channel input and show the
# binarised support of the input gradient for output location (1, 1).
in_shape = (1,46,46,1)
x = np.ones(in_shape)
model = model_def()
key = random.PRNGKey(0)
params = model.init(key, x)
print(model.apply(params, x).shape)

spike_loc = np.array([[1, 1]])
# Every other call in this notebook unpacks three values (rf, gx, gy) from
# compute_receptive_fields, so unpack three here as well.
_, gx, _ = compute_receptive_fields(model_def, in_shape, spike_loc)
# Set every non-zero entry to 1 (replaces the removed jax.ops.index_update).
gx = np.where(gx != 0, 1, gx)


fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.imshow(gx.squeeze(), cmap='Greys')
ax.grid()
# r = rf//2
# xy = (h//2-r[1]-.5, w//2-r[0]-.5) # half-pixel
# rect = mpl_patches.Rectangle(xy, rf[1], rf[0],
#     linewidth=1, edgecolor='r', facecolor='none')
# ax.add_patch(rect)


In [None]:
# Bare expression: presumably left so the notebook displays the shape of `gx`.
gx.shape