# Analyzing statistics of Conv3x3 layers in pretrained models

### Setup:

`conda install -c conda-forge ipywidgets`

*OR*

`pip install ipywidgets` 

`jupyter nbextension enable --py widgetsnbextension`

*AND*

`conda install nb_conda_kernels`


In [61]:
import mxnet as mx
from mxnet import np
import gluoncv
import mxnet.ndarray as nd

In [None]:
# Optional: uncomment the two lines below to speed up MODE = 'np' when computing the pretrained statistics. Note that this breaks the baseline model.
# from mxnet import npx
# npx.set_np()

In [114]:
# Backend switch read by the statistics helpers below: 'np' routes the
# reductions through mxnet.np, any other value uses mxnet.ndarray (nd).
MODE = 'nd'

class Stats:
    """Mean, min, max and standard deviation of ``data`` reduced along ``axis``.

    The backend is selected by the module-level ``MODE``: ``'np'`` uses
    ``mxnet.np``; any other value uses ``mxnet.ndarray`` (``nd``).

    Args:
        data: array to reduce (an ``mxnet.np`` array in ``'np'`` mode,
            an ``NDArray`` otherwise).
        axis: axis (or axes) handed to the backend reductions.

    Attributes:
        mean, min, max, std: results of the corresponding reductions.
    """

    def __init__(self, data, axis=0):
        if MODE == 'np':
            self.mean = np.mean(data, axis)
            self.min = np.min(data, axis)
            self.max = np.max(data, axis)
            self.std = np.std(data, axis)
        else:
            self.mean = nd.mean(data, axis)
            self.min = nd.min(data, axis)
            self.max = nd.max(data, axis)
            # nd has no std operator, so build the population standard
            # deviation (square root of the mean squared deviation, no
            # Bessel correction) from primitives. keepdims=True keeps the
            # reduced axes so the mean can be broadcast against ``data``.
            centered = nd.broadcast_sub(
                data, nd.mean(data, axis=axis, keepdims=True))
            self.std = nd.sqrt(nd.mean(nd.square(centered), axis=axis))


class PassOnStats:
    """Collapse one or more stats objects into a single summary.

    ``mean`` is the mean of the inputs' means, ``min`` the minimum of their
    minima, ``max`` the maximum of their maxima, and ``std`` the *mean* of
    their standard deviations (not a pooled standard deviation).

    Args:
        stats: a single ``Stats`` instance, or an iterable of objects that
            expose ``mean``, ``min``, ``max`` and ``std`` attributes.
    """

    def __init__(self, stats):
        # A lone Stats object is treated like a one-element list.
        if isinstance(stats, Stats):
            stats = [stats]
        if MODE == 'np':
            self.mean = np.mean(np.array([s.mean for s in stats]))
            self.min = np.min(np.array([s.min for s in stats]))
            self.max = np.max(np.array([s.max for s in stats]))
            # pass on the mean of the std value from the original stat
            self.std = np.mean(np.array([s.std for s in stats]))
        else:
            self.mean = nd.mean(nd.stack(*[s.mean for s in stats]))
            self.min = nd.min(nd.stack(*[s.min for s in stats]))
            self.max = nd.max(nd.stack(*[s.max for s in stats]))
            # pass on the mean of the std value from the original stat
            self.std = nd.mean(nd.stack(*[s.std for s in stats]))
        
    
def create_channelwise_stats(param):
    """Summarize the per-kernel means of a 4-D weight parameter.

    Each kernel is first averaged over axes 2 and 3; the resulting array is
    then reduced along axis 0 by ``Stats``.

    Args:
        param: object whose ``data()`` returns the weight array
            (presumably laid out as (out, in, kH, kW) -- confirm).

    Returns:
        Stats: statistics of the kernel means along axis 0.
    """
    weights = param.data()
    if MODE != 'np':
        return Stats(nd.mean(weights, axis=(2, 3)), axis=0)
    # np mode needs the array converted to an mxnet.np ndarray first
    return Stats(np.mean(weights.as_np_ndarray(), axis=(2, 3)), axis=0)
    
    
class ModelParameterStats:
    """Select parameters of a model and compute statistics at several levels.

    On construction the model's parameters are filtered by name and shape,
    then statistics are collected per filter bank (``filter_stats``), per
    parameter (``param_stats``), per layer (``layer_stats``) and for the
    whole model (``model_stats``).

    Args:
        model: object exposing ``collect_params()``.
        **kwargs: forwarded to ``_collect_params`` (``name_filters`` and/or
            ``shape_filters``).
    """

    # dict of selected parameters keyed by name; filled by _collect_params
    params = None

    def __init__(self, model, **kwargs):
        self.model = model
        self._collect_params(**kwargs)
        self._collect_stats()

    def print_params(self):
        """Print the name and shape of every selected parameter."""
        # Iterate the stored selection; the previous code looked the name up
        # in an undefined local and raised NameError.
        for p in self.params.values():
            print(f'{p.name}:\n  {p.shape}')

    def _collect_params(self,
                        name_filters=(lambda name: 'conv' in name,
                                      lambda name: 'weight' in name),
                        shape_filters=(lambda shape: len(shape) == 4,
                                       lambda shape: shape[2:] == (3, 3))):
        """Store in ``self.params`` the parameters passing every filter.

        Args:
            name_filters: predicates on the parameter name; all must hold.
            shape_filters: predicates on the parameter shape; all must hold.
                They are evaluated in order and stop at the first failure.
        """
        all_params = self.model.collect_params()
        selected = {}
        for name in all_params:
            if not all(nf(name) for nf in name_filters):
                continue
            p = all_params[name]
            if all(sf(p.shape) for sf in shape_filters):
                selected[name] = p
        self.params = selected

    def _collect_stats(self, stats_func=create_channelwise_stats):
        """Run all collection stages, from the finest to the coarsest level."""
        self._collect_filter_stats(stats_func)
        self._collect_param_stats()
        self._collect_layer_stats()
        self._collect_model_stats()

    def _collect_filter_stats(self, stats_func):
        """Apply ``stats_func`` to each selected parameter, keyed by ``param.name``."""
        self.filter_stats = {
            param.name: stats_func(param) for param in self.params.values()
        }

    def _collect_param_stats(self):
        """Collapse each filter-bank statistic into one summary per parameter."""
        self.param_stats = {
            param_name: PassOnStats(fstats)
            for param_name, fstats in self.filter_stats.items()
        }

    def _collect_layer_stats(self):
        """Group parameter summaries by layer and summarize each group.

        The layer is the first ``'_'``-separated token of the parameter name
        containing ``'layers'``; parameters without such a token are left out.
        """
        from collections import defaultdict
        layer_param_stats = defaultdict(list)
        for param_name, pstats in self.param_stats.items():
            layer_name = next(
                (part for part in param_name.split('_') if 'layers' in part),
                None)
            if layer_name:
                layer_param_stats[layer_name].append(pstats)
        self.layer_stats = {
            layer: PassOnStats(stats_list)
            for layer, stats_list in layer_param_stats.items()
        }

    def _collect_model_stats(self):
        """Summarize all layer summaries into a single model-level statistic."""
        self.model_stats = PassOnStats(list(self.layer_stats.values()))


In [115]:
# Model-zoo identifier used for both the pretrained network and the
# from-scratch baseline created later in the notebook.
model_name = 'resnet50_v1b'

In [116]:
# Fetch `model_name` from the GluonCV model zoo with its pretrained weights
# (pretrained=True).
net = gluoncv.model_zoo.get_model(model_name, pretrained=True)

In [117]:
# Collect statistics for the pretrained network using the default filters
# ('conv' and 'weight' in the name, 4-D shape with a 3x3 kernel).
stats = ModelParameterStats(net)

In [118]:
# Model-level mean (mean of the per-layer means) for the pretrained network.
stats.model_stats.mean


[-0.00077511]
<NDArray 1 @cpu(0)>

In [119]:
# Mean for a single pretrained parameter (first conv of layers1).
stats.param_stats['resnetv1b_layers1_conv1_weight'].mean


[0.00034619]
<NDArray 1 @cpu(0)>

In [120]:
# Minimum for the same pretrained parameter.
stats.param_stats['resnetv1b_layers1_conv1_weight'].min


[-0.08658515]
<NDArray 1 @cpu(0)>

In [121]:
# Maximum for the same pretrained parameter.
stats.param_stats['resnetv1b_layers1_conv1_weight'].max


[0.09441461]
<NDArray 1 @cpu(0)>

In [126]:
# Create the same architecture without pretrained weights (pretrained=False);
# it serves as the baseline for comparison with `net`.
baseline_net = gluoncv.model_zoo.get_model(model_name, pretrained=False)

In [127]:
# Switch the baseline network to hybridized execution before it is run.
baseline_net.hybridize()

In [128]:
from mxnet.gluon.data.vision import transforms
from gluoncv.data import transforms as gcv_transforms
from mxnet import gluon

# Training-time preprocessing applied to the images fed to the baseline:
# random 32x32 crop with 4 pixels of padding, random horizontal flip,
# conversion to tensor, then per-channel normalization.
# NOTE(review): the mean/std values look like the commonly quoted CIFAR-10
# channel statistics, while fake_train loads CIFAR100 -- confirm this is intended.
transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

def fake_train(epochs=1, ctx=None):
    """Initialize ``baseline_net`` and run one CIFAR100 batch through it.

    No training takes place: the network is initialized with Xavier, a single
    batch of two transformed training images is forwarded, and the function
    returns. The notebook uses this so that all parameters get initialized.

    Args:
        epochs: currently unused; kept so existing calls keep working.
        ctx: a single ``mx.Context`` or a list of contexts. Defaults to
            ``[mx.cpu()]``.

    Returns:
        None.
    """
    # Build the default per call instead of sharing one list across calls.
    if ctx is None:
        ctx = [mx.cpu()]
    elif isinstance(ctx, mx.Context):
        ctx = [ctx]
    baseline_net.initialize(mx.init.Xavier(), ctx=ctx)

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR100(train=True).transform_first(transform_train),
        batch_size=2)

    for batch in train_data:
        # Only the images are needed; labels are irrelevant for a bare
        # forward pass, so they are not loaded onto the devices.
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        for X in data:
            baseline_net(X)
        # Stop after the first batch -- one forward pass is all that is wanted.
        return

In [129]:
# A forward pass is needed so that all parameters of the baseline network
# are initialized before their statistics are read.
fake_train()

In [130]:
# Same statistics as for the pretrained model, computed on the freshly
# initialized baseline network for comparison.
baseline_stats = ModelParameterStats(baseline_net)

In [131]:
# Model-level mean (mean of the per-layer means) for the baseline network.
baseline_stats.model_stats.mean


[-6.0612447e-06]
<NDArray 1 @cpu(0)>

In [132]:
# Mean for a single baseline parameter (first conv of layers1).
baseline_stats.param_stats['resnetv1b_layers1_conv1_weight'].mean


[-0.00016074]
<NDArray 1 @cpu(0)>

In [133]:
# Minimum for the same baseline parameter.
baseline_stats.param_stats['resnetv1b_layers1_conv1_weight'].min


[-0.04811678]
<NDArray 1 @cpu(0)>

In [134]:
# Maximum for the same baseline parameter.
baseline_stats.param_stats['resnetv1b_layers1_conv1_weight'].max


[0.04507212]
<NDArray 1 @cpu(0)>