## Setup:

`conda install -c conda-forge ipywidgets`

*OR*

`pip install ipywidgets` 

`jupyter nbextension enable --py widgetsnbextension`

*AND*

`conda install nb_conda_kernels`


In [2]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from importlib import reload
from mxnet_resnet import train_cifar10

In [24]:
from tqdm.notebook import tqdm

In [1]:
import matplotlib
matplotlib.use('Agg')

import argparse, time, logging

import numpy as np
import mxnet as mx

from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv.model_zoo import get_model
from gluoncv.utils import makedirs, TrainingHistory
from gluoncv.data import transforms as gcv_transforms

In [6]:
opt = argparse.Namespace()
opt.batch_size = 32
opt.num_gpus = 0
opt.model = 'cifar_resnet20_v1'
opt.num_workers = 4
opt.num_epochs = 3
opt.lr = 0.1
opt.momentum = 0.9
opt.wd = 0.0001
opt.lr_decay = 0.1
opt.lr_decay_period = 0
opt.lr_decay_epoch = '40,60'
opt.drop_rate = 0.0
opt.mode = 'hybrid'
opt.save_period = 10
opt.save_dir = 'params'
opt.resume_from = ''
opt.save_plot_dir = '.'
opt.save_on_quit = False
opt.zmg = 0.

In [16]:
reload(train_cifar10)
train_cifar10.main(opt)

INFO:root:Namespace(batch_size=32, drop_rate=0.0, lr=0.1, lr_decay=0.1, lr_decay_epoch='40,60', lr_decay_period=0, mode='hybrid', model='cifar_resnet20_v1', momentum=0.9, num_epochs=3, num_gpus=0, num_workers=4, resume_from='', save_dir='params', save_on_quit=False, save_period=10, save_plot_dir='.', wd=0.0001, zmg=0.0)


HBox(children=(FloatProgress(value=0.0, description='Epochs', max=3.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=1562.0, style=ProgressStyle(description_wid…

INFO:root:Interrupted!






In [40]:
# main
batch_size = opt.batch_size
classes = 10

num_gpus = opt.num_gpus
batch_size *= max(1, num_gpus)
context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
num_workers = opt.num_workers

lr_decay = opt.lr_decay
lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

model_name = opt.model
if model_name.startswith('cifar_wideresnet'):
    kwargs = {'classes': classes,
            'drop_rate': opt.drop_rate}
else:
    kwargs = {'classes': classes}
net = get_model(model_name, **kwargs)
if opt.resume_from:
    net.load_parameters(opt.resume_from, ctx = context)
optimizer = 'nag'

save_period = opt.save_period
if opt.save_dir and save_period:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_period = 0

plot_path = opt.save_plot_dir

logging.basicConfig(level=logging.INFO)
logging.info(opt)

transform_train = transforms.Compose([
    gcv_transforms.RandomCrop(32, pad=4),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

INFO:root:Namespace(batch_size=32, drop_rate=0.0, lr=0.1, lr_decay=0.1, lr_decay_epoch='40,60', lr_decay_period=0, mode='hybrid', model='cifar_resnet20_v1', momentum=0.9, num_epochs=3, num_gpus=0, num_workers=4, resume_from='', save_dir='params', save_on_quit=False, save_period=10, save_plot_dir='.', wd=0.0001, zmg=0.0)


In [154]:
def test(ctx, val_data):
    """
    Evaluates net on test metric (Accuracy)
    :param ctx:
    :param val_data:
    :return: (name, value)
    """
    metric = mx.metric.Accuracy()
    for i, batch in enumerate(val_data):
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        outputs = [net(X) for X in data]
        metric.update(label, outputs)
    return metric.get()

def zmg_norm(ctx, mean_mult=10.):
    """
    Normalize the net's conv params' gradients to maintain a channel-wise zero mean
    :param ctx: mxnet context
    :param mean_mult:
    """
    if opt.zmg > 0.:
        mean_dims = (2,3)  # per-channel mean
        param_dict = net.collect_params()
        for param_name in param_dict:
            if '_conv' and '_weight' in param_name:  # only modify grads of conv weights
                param = param_dict[param_name]
                # Only manipulate gradients in the given context
                ctx_indices = [param.list_ctx().index(c) for c in ctx]
                for i, grad in enumerate(param._grad):
                    if i in ctx_indices:
                        param_mean_grad = grad.mean(axis=mean_dims, keepdims=True,
                                                    name='param_mean_grad')  # IMPORTANT
                        param._grad[i] = grad + (grad - param_mean_grad) * mean_mult


In [42]:
if opt.mode == 'hybrid':
    net.hybridize()

In [155]:
ctx = context
epochs = opt.num_epochs
# train
if isinstance(ctx, mx.Context):
    ctx = [ctx]
net.initialize(mx.init.Xavier(), ctx=ctx)

train_data = gluon.data.DataLoader(
    gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
    batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

val_data = gluon.data.DataLoader(
    gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
    batch_size=batch_size, shuffle=False, num_workers=num_workers)

trainer = gluon.Trainer(net.collect_params(), optimizer,
                        {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
metric = mx.metric.Accuracy()
train_metric = mx.metric.Accuracy()
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
train_history = TrainingHistory(['training-error', 'validation-error'])

iteration = 0
lr_decay_count = 0

best_val_score = 0

for epoch in tqdm(range(epochs), desc='Epochs'):
    try:
        tic = time.time()
        train_metric.reset()
        metric.reset()
        train_loss = 0
        num_batch = len(train_data)
        alpha = 1

        if epoch == lr_decay_epoch[lr_decay_count]:
            trainer.set_learning_rate(trainer.learning_rate*lr_decay)
            lr_decay_count += 1

        for i, batch in enumerate(tqdm(train_data, desc='Batches')):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

            with ag.record():
                output = [net(X) for X in data]
                loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
            for l in loss:
                l.backward()
            # Apply gradient modifications TODO: mod grads in custom Trainer?
            zmg_norm(ctx)
            trainer.step(batch_size)
            train_loss += sum([l.sum().asscalar() for l in loss])

            train_metric.update(label, output)
            name, acc = train_metric.get()
            iteration += 1

        train_loss /= batch_size * num_batch
        name, acc = train_metric.get()
        name, val_acc = test(ctx, val_data)
        train_history.update([1-acc, 1-val_acc])
        train_history.plot(save_path='%s/%s_history.png'%(plot_path, model_name))

        if val_acc > best_val_score:
            best_val_score = val_acc
            net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

        logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
            (epoch, acc, val_acc, train_loss, time.time()-tic))

        if save_period and save_dir and (epoch + 1) % save_period == 0:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))
    except KeyboardInterrupt as ki:
        if opt.save_on_quit:
            logging.info(f'Interrupted, saving model at epoch {epoch}')
            net.save_parameters('%s/cifar10-%s-%d-interrupt.params' % (save_dir, model_name, epoch))
        else:
            logging.info('Interrupted!')
        break

if save_period and save_dir:
    net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))

  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reinit)
  v.initialize(None, ctx, init, force_reinit=force_reini

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=3.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Batches', max=1562.0, style=ProgressStyle(description_wid…





TypeError: zmg_norm() missing 1 required positional argument: 'ctx'

In [150]:
param_dict = net.collect_params()

In [151]:
[p for p in param_dict if '_conv' and '_weight' in p]

['cifarresnetv117_conv0_weight',
 'cifarresnetv117_stage1_conv0_weight',
 'cifarresnetv117_stage1_conv1_weight',
 'cifarresnetv117_stage1_conv2_weight',
 'cifarresnetv117_stage1_conv3_weight',
 'cifarresnetv117_stage1_conv4_weight',
 'cifarresnetv117_stage1_conv5_weight',
 'cifarresnetv117_stage2_conv0_weight',
 'cifarresnetv117_stage2_conv1_weight',
 'cifarresnetv117_stage2_conv2_weight',
 'cifarresnetv117_stage2_conv3_weight',
 'cifarresnetv117_stage2_conv4_weight',
 'cifarresnetv117_stage2_conv5_weight',
 'cifarresnetv117_stage2_conv6_weight',
 'cifarresnetv117_stage3_conv0_weight',
 'cifarresnetv117_stage3_conv1_weight',
 'cifarresnetv117_stage3_conv2_weight',
 'cifarresnetv117_stage3_conv3_weight',
 'cifarresnetv117_stage3_conv4_weight',
 'cifarresnetv117_stage3_conv5_weight',
 'cifarresnetv117_stage3_conv6_weight',
 'cifarresnetv117_dense0_weight']

In [152]:
param = param_dict['cifarresnetv117_stage1_conv0_weight']

In [153]:
param._grad

[
 [[[[ 1.01734884e-03 -1.26682907e-01 -2.38644615e-01]
    [ 1.14951633e-01 -5.99930733e-02 -1.72402725e-01]
    [ 3.77090305e-01  1.33727521e-01 -5.43578453e-02]]
 
   [[-1.60185806e-03  1.08917989e-01  2.23440737e-01]
    [-7.09771737e-02  4.22134809e-02  1.33995742e-01]
    [-2.05956832e-01 -6.12029657e-02  4.25772443e-02]]
 
   [[ 1.24697760e-02 -1.27726316e-01 -2.96792269e-01]
    [ 1.86754942e-01 -1.30796768e-02 -1.95506215e-01]
    [ 4.65976387e-01  2.21411347e-01 -4.12230082e-02]]
 
   ...
 
   [[ 1.29863694e-01  9.16957855e-04 -4.00356650e-02]
    [ 2.26372749e-01  4.86166403e-02 -3.79877388e-02]
    [ 4.67543364e-01  2.33857989e-01  4.43990827e-02]]
 
   [[ 7.82072395e-02  1.43067598e-01  2.00291768e-01]
    [-5.10805324e-02  9.60586965e-02  2.03832209e-01]
    [-3.17974895e-01 -1.36952460e-01  7.17934147e-02]]
 
   [[-7.44201243e-04  1.24795035e-01  2.20464692e-01]
    [-6.02310002e-02  6.69940263e-02  1.36561349e-01]
    [-1.94393367e-01 -3.27775106e-02  5.21418452e-02]]]


In [108]:
param.list_ctx()[0]

cpu(0)

In [119]:
param.__dict__.keys()

dict_keys(['_var', '_data', '_grad', '_ctx_list', '_ctx_map', '_trainer', '_deferred_init', '_differentiable', '_allow_deferred_init', '_grad_req', '_shape', 'name', '_dtype', 'lr_mult', 'wd_mult', 'init', '_grad_stype', '_stype'])

In [148]:
ctx_indices = [ param.list_ctx().index(c) for c in ctx ]
mean_dims = (2,3)
mean_mult = 10
for i, grad in enumerate(param._grad):
    if i in ctx_indices:
        param_mean_grad = grad.mean(axis=mean_dims, keepdims=True, name='param_mean_grad')  # IMPORTANT
        param._grad[i] = grad + (grad - param_mean_grad) * mean_mult

In [149]:
param._grad

[
 [[[[ 1.01734884e-03 -1.26682907e-01 -2.38644615e-01]
    [ 1.14951633e-01 -5.99930733e-02 -1.72402725e-01]
    [ 3.77090305e-01  1.33727521e-01 -5.43578453e-02]]
 
   [[-1.60185806e-03  1.08917989e-01  2.23440737e-01]
    [-7.09771737e-02  4.22134809e-02  1.33995742e-01]
    [-2.05956832e-01 -6.12029657e-02  4.25772443e-02]]
 
   [[ 1.24697760e-02 -1.27726316e-01 -2.96792269e-01]
    [ 1.86754942e-01 -1.30796768e-02 -1.95506215e-01]
    [ 4.65976387e-01  2.21411347e-01 -4.12230082e-02]]
 
   ...
 
   [[ 1.29863694e-01  9.16957855e-04 -4.00356650e-02]
    [ 2.26372749e-01  4.86166403e-02 -3.79877388e-02]
    [ 4.67543364e-01  2.33857989e-01  4.43990827e-02]]
 
   [[ 7.82072395e-02  1.43067598e-01  2.00291768e-01]
    [-5.10805324e-02  9.60586965e-02  2.03832209e-01]
    [-3.17974895e-01 -1.36952460e-01  7.17934147e-02]]
 
   [[-7.44201243e-04  1.24795035e-01  2.20464692e-01]
    [-6.02310002e-02  6.69940263e-02  1.36561349e-01]
    [-1.94393367e-01 -3.27775106e-02  5.21418452e-02]]]


In [156]:
param

Parameter cifarresnetv117_stage1_conv0_weight (shape=(16, 16, 3, 3), dtype=<class 'numpy.float32'>)