Authors: Zhewei Yao <https://github.com/yaozhewei>, Amir Gholami <http://amirgholami.org/>


This tutorial shows how to compute the Hessian information using (randomized) numerical linear algebra for both explicit Hessian (the matrix is given) as well as implicit Hessian (the matrix is ungiven).

We'll start by doing the necessary imports:

In [2]:
import numpy as np
import torch 
from torchvision import datasets, transforms
from utils import * # get the dataset
from pyhessian import hessian
from pyhessian.hessian_with_activation import hessian_with_activation # Hessian computation
from density_plot import get_esd_plot # ESD plot
from pytorchcv.model_provider import get_model as ptcv_get_model # model
from pyhessian.utils import group_product

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from HAWQ.bit_config import *
from HAWQ.utils import *

# enable cuda devices
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
model_arch = 'resnet18'
model_resume = '/home/thuako/PyHessian/HAWQ/save_model/resnet18/uniform18model_best.pth.tar'
save_path = '~/PyHessian/HAWQ/save_model/pyhessian/'
quant_scheme = 'uniform8'
bias_bit = 32
channel_wise = True
act_range_momentum = -1
weight_percentile = 0
act_percentile = 1
fix_BN = True
fix_BN_threshold = None
checkpoint_iter = 1
fixed_point_quantization = False

In [4]:

quantize_arch_dict = {'resnet50': q_resnet50, 'resnet50b': q_resnet50,
                      'resnet18': q_resnet18, 'resnet101': q_resnet101,
                      'inceptionv3': q_inceptionv3,
                      'mobilenetv2_w1': q_mobilenetv2_w1}

quantize_arch = quantize_arch_dict[model_arch]          
model = ptcv_get_model(model_arch, pretrained=False)
model = quantize_arch(model)

checkpoint = torch.load(model_resume)['state_dict']
model_key_list = list(model.state_dict().keys())
for key in model_key_list:
    if 'num_batches_tracked' in key: model_key_list.remove(key)
i = 0
modified_dict = {}
for key, value in checkpoint.items():
    if 'scaling_factor' in key: continue
    if 'num_batches_tracked' in key: continue
    if 'weight_integer' in key: continue
    if 'min' in key or 'max' in key: continue
    modified_key = model_key_list[i]
    modified_dict[modified_key] = value
    i += 1

bit_config = bit_config_dict["bit_config_" + model_arch + "_" + quant_scheme]
name_counter = 0

for name, m in model.named_modules():
    if name in bit_config.keys():
        name_counter += 1
        setattr(m, 'quant_mode', 'symmetric')
        setattr(m, 'bias_bit', bias_bit)
        setattr(m, 'quantize_bias', (bias_bit != 0))
        setattr(m, 'per_channel', channel_wise)
        setattr(m, 'act_percentile', act_percentile)
        setattr(m, 'act_range_momentum', act_range_momentum)
        setattr(m, 'weight_percentile', weight_percentile)
        setattr(m, 'fix_flag', False)
        setattr(m, 'fix_BN', fix_BN)
        setattr(m, 'fix_BN_threshold', fix_BN_threshold)
        setattr(m, 'training_BN_mode', fix_BN)
        setattr(m, 'checkpoint_iter_threshold', checkpoint_iter)
        setattr(m, 'save_path', save_path)
        setattr(m, 'fixed_point_quantization', fixed_point_quantization)

        if type(bit_config[name]) is tuple:
            bitwidth = bit_config[name][0]
            if bit_config[name][1] == 'hook':
                m.register_forward_hook(hook_fn_forward)
                global hook_keys
                hook_keys.append(name)
        else:
            bitwidth = bit_config[name]

        if hasattr(m, 'activation_bit'):
            setattr(m, 'activation_bit', bitwidth)
            if bitwidth == 4:
                setattr(m, 'quant_mode', 'asymmetric')
        else:
            setattr(m, 'weight_bit', bitwidth)


In [5]:
# for resnet18 
# set model and data set
dataset = "imagenet"
dataset_dir = "/ImageNet/dataset/imagenet"

# model = ptcv_get_model("resnet18", pretrained=True)
model.eval()
model = model.cuda()
criterion = torch.nn.CrossEntropyLoss()
# for name, module in model.named_modules():
#     print(name)


In [6]:
# get dataset 
train_loader, test_loader = getData(name='imagenet', train_bs=40, train_length=0.00096, data_dir=dataset_dir)

for inputs, targets in train_loader:
    break;
print(f"train_loader size : {len(train_loader)}\n \
input size : {len(inputs)}") 

hessian_comp_resnet = hessian_with_activation(model, criterion, dataloader=train_loader, cuda=True)


train_loader size : 31
 input size : 40


In [7]:
weight_trace = hessian_comp_resnet.trace(maxIter=1,param_name='conv.weight' )
np.mean(weight_trace, axis=0)

trace had not been converge


array([-1330.01403809,  4634.32324219, -4970.03222656, -3347.58496094,
       -5367.63818359, -3505.81005859,  5201.21630859,   593.97424316,
        6201.81835938, -2027.17626953,   639.85101318, -3703.32641602,
        1171.79919434, 15643.91796875, -5805.25878906, -4618.35351562,
       26616.61914062, -3289.92871094,  5587.85009766, 14902.50292969])

In [12]:
for name, module in model.named_buffers():
    if 'weight' in name:
        print(dir(module))
        print(module.grad_fn)
        break
print(model)

['T', '__abs__', '__add__', '__and__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__class__', '__complex__', '__contains__', '__cuda_array_interface__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__div__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__idiv__', '__ifloordiv__', '__ilshift__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__or__', '__pow__', '__radd__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__rfloordiv__', '__rmul__', '__rpow__', '__rshift__', '__rsub__', '__rtruediv__', '__setattr__', '__setitem__', '

In [7]:
device = hessian_comp_resnet.device
num_data = 0  # count the number of datum points in the dataloader
hessian_comp_resnet.insert_hook_quant_module("quant_convbn")
for inputs, targets in hessian_comp_resnet.data:
    hessian_comp_resnet.model.zero_grad()

    hessian_comp_resnet.reset_reg_active()

    outputs = hessian_comp_resnet.model(inputs.to(device))
    loss = hessian_comp_resnet.criterion(outputs, targets.to(device))
    loss.backward(create_graph=True)
    break


stage1.unit1.quant_convbn1 module hooked
stage1.unit1.quant_convbn2 module hooked
stage1.unit2.quant_convbn1 module hooked
stage1.unit2.quant_convbn2 module hooked
stage2.unit1.quant_convbn1 module hooked
stage2.unit1.quant_convbn2 module hooked
stage2.unit2.quant_convbn1 module hooked
stage2.unit2.quant_convbn2 module hooked
stage3.unit1.quant_convbn1 module hooked
stage3.unit1.quant_convbn2 module hooked
stage3.unit2.quant_convbn1 module hooked
stage3.unit2.quant_convbn2 module hooked
stage4.unit1.quant_convbn1 module hooked
stage4.unit1.quant_convbn2 module hooked
stage4.unit2.quant_convbn1 module hooked
stage4.unit2.quant_convbn2 module hooked


In [11]:

for key in hessian_comp_resnet.activation_grads.keys():
    print(f'********** {key} **********')
    print((hessian_comp_resnet.activation_grads[key][0][0].grad_fn))
    print((hessian_comp_resnet.activation_grads[key][0][0].grad_fn.next_functions))
    print(hessian_comp_resnet.activations[key][0][0].grad_fn)

# for name, params in hessian_comp_resnet.model.named_modules():
#     if name[:-1].endswith('quant_convbn'):
#         print(f"{name}")

********** stage4.unit2.quant_convbn2 **********
<DivBackward0 object at 0x7fe185c18a90>
((<CloneBackward object at 0x7fe0e3b4ed60>, 0), (None, 0))
<MulBackward0 object at 0x7fe0e3b4ed60>
********** stage4.unit2.quant_convbn1 **********
<ThresholdBackwardBackward object at 0x7fe0e3b4e970>
((<DivBackward0 object at 0x7fe0e3b4e9a0>, 0), (<MulBackward0 object at 0x7fe0e3b4ed00>, 0))
<MulBackward0 object at 0x7fe0e3b4e9a0>
********** stage4.unit1.quant_convbn2 **********
<DivBackward0 object at 0x7fe0e3b4ebb0>
((<CloneBackward object at 0x7fe0e3b4ed90>, 0), (None, 0))
<MulBackward0 object at 0x7fe0e3b4ed90>
********** stage4.unit1.quant_convbn1 **********
<ThresholdBackwardBackward object at 0x7fe0e6d52b50>
((<DivBackward0 object at 0x7fe0e6d52520>, 0), (<MulBackward0 object at 0x7fe0e6d52d30>, 0))
<MulBackward0 object at 0x7fe0e6d52520>
********** stage3.unit2.quant_convbn2 **********
<DivBackward0 object at 0x7fe0e6d52b20>
((<CloneBackward object at 0x7fe0e6d52d00>, 0), (None, 0))
<MulBa

In [None]:
# hook and find activation trace
hessian_comp_resnet.insert_hook("conv")
hessian_comp_resnet.check_reg_hook_size()

In [None]:

act_trace_resnet = hessian_comp_resnet.trace_activ(maxIter=1, tol=1e-3)
np.mean(act_trace_resnet, axis=0)

In [7]:
pyhessian = np.mean(weight_trace, axis=0)[1:]
hawq = np.array([0.06857826, 0.03162379, 0.03298575, 0.01205663, 0.02222431, 0.00596336, 0.06931772, 0.00807129, 0.00372905, 0.00530698, 0.00209011, 0.00737569, 0.00210454, 0.00151197, 0.00158041,0.00078146, 0.00451841, 0.00098745, 0.00072944])

print(len(pyhessian), len(hawq))
print(pyhessian / hawq)

19 19
[ 6.28384714e+03 -1.28923168e+04 -2.09945935e+05  3.37263196e+05
  2.85121360e+05  1.05976722e+06  5.16448765e+03  6.01798115e+05
 -8.99145786e+04  8.87246902e+05 -4.33565557e+06 -1.54747175e+05
  1.27785208e+07 -7.81820771e+06  2.00849316e+06  2.67749379e+07
 -7.63025513e+05  1.86099652e+07  2.99477301e+07]


In [14]:
param_count = []

for name, param in model.named_parameters():
    if "conv.weight" in name:
        param_count.append(len(param.view(-1)))
(pyhessian / param_count[-1:]) 

array([ 0.00018265, -0.00017281, -0.00293529,  0.0017235 ,  0.00268581,
        0.00267867,  0.00015174,  0.00205879, -0.00014212,  0.00199577,
       -0.00384098, -0.00048377,  0.0113987 , -0.00501035,  0.00134542,
        0.00886855, -0.00146131,  0.00778894,  0.00925915])

In [5]:
hessian_comp_resnet.show_hessian_layer()

features.init_block.conv.conv.weight
features.stage1.unit1.body.conv1.conv.weight
features.stage1.unit1.body.conv2.conv.weight
features.stage1.unit2.body.conv1.conv.weight
features.stage1.unit2.body.conv2.conv.weight
features.stage2.unit1.body.conv1.conv.weight
features.stage2.unit1.body.conv2.conv.weight
features.stage2.unit1.identity_conv.conv.weight
features.stage2.unit2.body.conv1.conv.weight
features.stage2.unit2.body.conv2.conv.weight
features.stage3.unit1.body.conv1.conv.weight
features.stage3.unit1.body.conv2.conv.weight
features.stage3.unit1.identity_conv.conv.weight
features.stage3.unit2.body.conv1.conv.weight
features.stage3.unit2.body.conv2.conv.weight
features.stage4.unit1.body.conv1.conv.weight
features.stage4.unit1.body.conv2.conv.weight
features.stage4.unit1.identity_conv.conv.weight
features.stage4.unit2.body.conv1.conv.weight
features.stage4.unit2.body.conv2.conv.weight


In [5]:
count = {"conv.weight": 0, "bn.weight": 0, "bn.bias":  0, "features": 0}

weight_layer_list = []
for (i, (name, p)) in enumerate(hessian_comp_resnet.model.named_parameters()):
    print(f"{name} gradient is {p.grad_fn}")
    for key in count.keys():
        if key in name:
            count[key] += 1
                
        
for key in count.keys():
    print(f"the number of {key} : {count[key]}")

print(f"the number of hessian : {len(np.mean(weight_trace, axis=0))}")

features.init_block.conv.conv.weight gradient is None
features.init_block.conv.bn.weight gradient is None
features.init_block.conv.bn.bias gradient is None
features.stage1.unit1.body.conv1.conv.weight gradient is None
features.stage1.unit1.body.conv1.bn.weight gradient is None
features.stage1.unit1.body.conv1.bn.bias gradient is None
features.stage1.unit1.body.conv2.conv.weight gradient is None
features.stage1.unit1.body.conv2.bn.weight gradient is None
features.stage1.unit1.body.conv2.bn.bias gradient is None
features.stage1.unit2.body.conv1.conv.weight gradient is None
features.stage1.unit2.body.conv1.bn.weight gradient is None
features.stage1.unit2.body.conv1.bn.bias gradient is None
features.stage1.unit2.body.conv2.conv.weight gradient is None
features.stage1.unit2.body.conv2.bn.weight gradient is None
features.stage1.unit2.body.conv2.bn.bias gradient is None
features.stage2.unit1.body.conv1.conv.weight gradient is None
features.stage2.unit1.body.conv1.bn.weight gradient is None
fe

In [6]:
weight_trace = hessian_comp_resnet.trace(maxIter=50)
np.mean(weight_trace, axis=0)

trace had not been converge


array([ 2.71595064e+03,  8.71741765e+00,  2.62039937e+01,  3.64793158e+03,
        4.74658755e+00,  6.00919633e-01,  1.37105965e+03, -9.40865991e-01,
        5.33281157e+00,  1.64131633e+03, -5.24804253e-01,  1.09716395e+00,
        6.97771637e+02,  1.34698081e+00, -1.11872495e+00,  1.70895995e+03,
       -1.33548991e+00,  4.47255922e+00,  1.17322799e+03, -1.52370351e+00,
        2.89366024e+00,  4.39872609e+02, -1.25825181e+00,  1.79194921e-01,
        1.58911529e+03, -5.89528441e-03, -6.85732658e-01,  7.19000252e+02,
        2.33258068e+00,  5.12466332e-01,  1.98001098e+03,  2.98390957e+00,
        8.72554815e-02,  1.38419868e+03,  1.84380284e+00,  6.69318685e+00,
        3.47149133e+02,  5.48378468e-01, -1.47029647e+00,  1.60556912e+03,
       -1.89265554e-01,  6.95783268e-01,  1.07468449e+03,  1.29804430e+00,
        1.51972714e+00,  2.47274337e+03,  2.63763887e+00,  1.21341769e+00,
        2.22600354e+03,  4.98319585e-01, -6.70611072e-02,  7.62577498e+02,
        4.47491714e-01,  

In [None]:
# get the model 
model = ptcv_get_model("resnet20_cifar10", pretrained=True)
model.eval()
model = model.cuda()

# create loss function
criterion = torch.nn.CrossEntropyLoss()

# get dataset 
train_loader, test_loader = getData(train_bs=80, train_length=0.024)

# for illustrate, we only use one batch to do the tutorial
for inputs, targets in train_loader:
    break;
print(len(train_loader))    
print(len(inputs))

# change the model to eval mode to disable running stats upate


In [None]:
# create the hessian computation module
hessian_comp = hessian_with_activation(model, criterion, dataloader=train_loader, cuda=True)
hessian_comp.insert_hook("conv")

In [None]:
act_trace = hessian_comp.trace_activ(maxIter=100, tol=1e-3)
np.mean(act_trace, axis=0)

In 83th iteration, trace had been converge


array([12.90942681,  1.92099523,  0.22628432,  1.12419579,  0.1060052 ,
        0.93684972,  0.13285358,  0.12223492,  0.92358128,  0.14782701,
        1.02247512,  0.15254644,  1.81130452,  0.02953197,  0.2341992 ,
        1.69802872,  1.14589175,  1.44514605,  0.73692089,  0.51901171])

Note how different the loss landscape looks. In particular note that there is almost no change in the loss value (see the small scale of the y-axis). This is expected, since for a converged NN, many of the directions are typically degenarate (i.e. they are flat).