In [None]:
import os
# Backend configuration. These variables are read when the CUDA/XLA runtime
# starts up, so they are set here, before `jax` is imported further down.

# Number GPUs by PCI bus id so that the index used below is stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Make only physical GPU 3 visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Per-module verbose logging for the BFC allocator. The expected format is
# `<module>=<level>`; the previous value '=bfc_allocator=1' carried a stray
# leading '=' (empty module name), so the request matched no module.
os.environ['TF_CPP_VMODULE'] = 'bfc_allocator=1'
# 0 = do not filter any C++ log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# Do not let JAX grab a large slice of GPU memory up front ...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ... and allocate/free on demand through the platform allocator instead.
# NOTE(review): with the platform allocator the BFC allocator may not be
# exercised at all, in which case the VMODULE setting above prints nothing.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from functools import partial

import jax
import jax.numpy as np
from jax import random, vjp, vmap
from flax import linen as nn
from flax.core import freeze, unfreeze

from torchsummary import summary
# from torchvision.models import resnet18, resnet50
from bagnet import *

import matplotlib.pyplot as plt

from gpax import *
from jax_models import *
from resnet import *

from typing import Type, Any, Callable, Union, List, Optional


In [None]:

## Receptive-field survey for the BagNet trunks
#
# For every backbone version ('18' / '50') and every per-stage count of 3x3
# convolutions, build a `BagNetTrunk` factory with `disable_bn=True` and print
# the first value returned by `compute_receptive_fields`.

resnet_versions = ['18', '50']
num_conv3x3_per_stages = [
    [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1],
    [2, 1, 1, 1], [2, 2, 1, 1], [2, 2, 2, 1], [2, 2, 2, 2],
]

# Backbone version -> (stage sizes, block class).
trunk_variants = {
    '18': ([2, 2, 2, 2], BagNetBlock),
    '50': ([3, 4, 6, 3], BottleneckBagNetBlock),
}

print('Receptive Fields')

for resnet_version in resnet_versions:
    # These two settings depend only on the backbone version, so they are
    # looked up once per version rather than once per conv configuration.
    stage_sizes, block_cls = trunk_variants[resnet_version]

    for num_conv3x3_per_stage in num_conv3x3_per_stages:
        trunk_kwargs = dict(
            stage_sizes=stage_sizes,
            block_cls=block_cls,
            num_conv3x3_per_stage=num_conv3x3_per_stage,
            disable_bn=True,
        )
        model_def = partial(BagNetTrunk, **trunk_kwargs)

        # Input shape (1, 224, 224, 1) -- presumably (batch, H, W, channels);
        # confirm against `compute_receptive_fields`. `gx` is only consumed by
        # the commented-out plotting code that follows this loop.
        rf, gx, _ = compute_receptive_fields(model_def, (1, 224, 224, 1))
        print(f'[{resnet_version}] {num_conv3x3_per_stage}:\t {rf}')

#         from plt_utils import plt_scaled_colobar_ax
#         fig, axs = plt.subplots(1,2,figsize=(30,15))
#         ax = axs[0]
#         ax.hist(gx.flatten())

#         ax = axs[1]
#         im = gx.squeeze()
#         im = ( im-np.min(im) ) / np.ptp(im)
#         pltim = ax.imshow(im, vmin=0, vmax=1, cmap='bwr')
#         fig.colorbar(pltim, cax=plt_scaled_colobar_ax(ax))
#         ax.grid()


In [None]:


    
# Build a BagNet18x11 trunk, initialise it on a random input, and summarise
# its parameters. Note that `np` in this notebook is `jax.numpy`.

key = random.PRNGKey(0)
# A JAX PRNG key should be consumed by a single random operation. Previously
# the same key fed both `random.normal` and `model.init`; derive independent
# subkeys instead (`key` itself is left unchanged for any later cells).
x_key, init_key = random.split(key)
x = random.normal(x_key, (224,224,3))
model = BagNet18x11Trunk()
params = model.init(init_key, x)


# Presumably the total parameter count of the trunk -- see
# `pytree_num_parameters` for the exact definition.
print(pytree_num_parameters(params['params']))

# Cell output: (key, shape, number of elements) for every entry whose key
# mentions a convolution. No backslash continuations are needed inside [].
[(k, v.shape, np.prod(np.asarray(v.shape)).item())
 for k, v in pytree_get_kvs(params['params']).items()
 if 'conv' in k or 'Conv' in k]
