
walkerning/nics_fix_pytorch


Fixed Point Training Simulation Framework on PyTorch

Core Functionality

  • Parameter/Buffer fix: use the nics_fix_pt.nn_fix.<original name>_fix modules. Parameters are the tensors being trained; buffers are tensors that are persisted (saved in the state dict) but are not considered parameters.
  • Activation fix: use the ActivationFix module.
  • Data fix vs. gradient fix: supply the nf_fix_params / nf_fix_params_grad keyword arguments when constructing a nics_fix_pt.nn_fix.<original name>_fix or ActivationFix module.
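To make the data/gradient split concrete, here is a toy sketch in plain Python. Floats stand in for tensors, and quantize, data_cfg, grad_cfg and sgd_step are hypothetical names, not nics_fix_pt APIs. It assumes a simple signed fixed-point scheme in which a value is rounded to a step of 2**(scale - bitwidth + 1) and saturated to [-2**scale, 2**scale - step]. The weight is quantized with one config before use, and its gradient with a separate config before the update.

```python
# Toy illustration only: plain floats stand in for tensors, and every name
# here is hypothetical (not part of nics_fix_pt).

def quantize(x, bitwidth, scale):
    """Assumed scheme: round to a step of 2**(scale - bitwidth + 1), then saturate."""
    step = 2.0 ** (scale - (bitwidth - 1))
    lo, hi = -(2.0 ** scale), 2.0 ** scale - step
    return min(max(round(x / step) * step, lo), hi)

# Separate configs for the data (forward) path and the gradient (backward) path,
# playing the roles of nf_fix_params and nf_fix_params_grad.
data_cfg = {"bitwidth": 4, "scale": 0}
grad_cfg = {"bitwidth": 8, "scale": -2}

def sgd_step(w_float, grad_float, lr=0.5):
    w_q = quantize(w_float, **data_cfg)      # weight is quantized before use
    g_q = quantize(grad_float, **grad_cfg)   # gradient is quantized before the update
    return w_q, w_float - lr * g_q           # the float master copy is what gets updated
```

For example, sgd_step(0.3, 0.1) uses the weight as 0.25 in the forward pass, while the stored float weight moves by half of the quantized gradient 0.099609375.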

NOTE: When constructing a fixed-point module, the dictionary you pass in as the nf_fix_params argument is used directly, not copied. If you pass the same configuration dict to two modules, the two modules will share the same configuration. If you want each module to be quantized independently, construct a separate configuration dict for each module.
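The sharing behavior in the NOTE is ordinary Python reference semantics. The sketch below uses a hypothetical config in the shape described later in this README (parameter name mapped to "bitwidth", "method", "scale"); plain ints stand in for the torch tensors the library actually stores, and the method value 1 is only a placeholder.

```python
import copy

# Hypothetical config: parameter name -> {"bitwidth", "method", "scale"}.
# Plain ints stand in for torch tensors; method=1 is a placeholder value.
base_cfg = {"weight": {"bitwidth": 8, "method": 1, "scale": 0},
            "bias": {"bitwidth": 8, "method": 1, "scale": 0}}

# Passing the same dict object to two modules: both hold a reference to it.
cfg_conv1 = base_cfg
cfg_conv2 = base_cfg
cfg_conv1["weight"]["bitwidth"] = 4   # also changes what the second module sees

# Independent configs: deep-copy so the nested dicts are not shared either.
cfg_conv3 = copy.deepcopy(base_cfg)
cfg_conv3["weight"]["bitwidth"] = 16  # base_cfg is unaffected
```

A shallow copy such as dict(base_cfg) would still share the inner per-parameter dicts, which is why copy.deepcopy (or building a fresh dict per module) is used here.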

The underlying quantization of each datum is implemented in nics_fix_pt.quant._do_quantitize, and the code is easy to follow. If you need a quantization method other than this default one, or a configurable _do_quantitize (e.g. to simulate simultaneous computation on different types of devices), please contribute a pluggable and configurable _do_quantitize (I believe this will be a simple change), or contact me.
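One possible design for the pluggable quantizer suggested above is a registry keyed by method name. This is a hypothetical sketch: QUANTIZERS, register_quantizer and do_quantize do not exist in nics_fix_pt, method names are strings rather than tensors, and the round-to-nearest/saturate grid is an assumed scheme, not necessarily what _do_quantitize does.

```python
import math

# Hypothetical registry of quantizers; none of these names exist in nics_fix_pt.
# Each quantizer maps (value, bitwidth, scale) -> quantized value.
QUANTIZERS = {}

def register_quantizer(name):
    def decorator(fn):
        QUANTIZERS[name] = fn
        return fn
    return decorator

def _grid(bitwidth, scale):
    # Assumed signed fixed point: step 2**(scale - bitwidth + 1),
    # representable range [-2**scale, 2**scale - step].
    step = 2.0 ** (scale - bitwidth + 1)
    return step, -(2.0 ** scale), 2.0 ** scale - step

@register_quantizer("round")
def round_nearest(x, bitwidth, scale):
    step, lo, hi = _grid(bitwidth, scale)
    return min(max(round(x / step) * step, lo), hi)

@register_quantizer("floor")
def floor_truncate(x, bitwidth, scale):
    step, lo, hi = _grid(bitwidth, scale)
    return min(max(math.floor(x / step) * step, lo), hi)

def do_quantize(values, cfg):
    """Dispatch on cfg["method"] instead of hard-coding one scheme."""
    fn = QUANTIZERS[cfg["method"]]
    return [fn(v, cfg["bitwidth"], cfg["scale"]) for v in values]
```

With this layout, adding a device-specific behavior means registering one more function rather than editing the dispatch code.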

Examples

See examples/mnist/train_mnist.py for an MNIST fixed-point training example and some demo usage of the utilities.

Usage Explained

How to make a module fixed-point

If you have implemented a new module type and want a fixed-point version of it, e.g. MaskedConv2d, a masked convolution module (a convolution with a mask generated by pruning), just do:

from nics_fix_pt import register_fix_module
register_fix_module(MaskedConv2d)

Then you can construct the fixed-point module and use it in the rest of your code. All of its functionality stays the same, except that the weights (and, optionally, the gradients of the weights) are converted to fixed-point before each use.

import nics_fix_pt.nn_fix as nnf
masked_conv = nnf.MaskedConv2d_fix(...your parameters, including fixed-point configs...)

The following built-in registered fixed-point modules have already been tested in our test examples and in our other works:

  • torch.nn.Linear
  • torch.nn.Conv2d
  • torch.nn.BatchNorm1d
  • torch.nn.BatchNorm2d

NOTE: We expect most torch.nn modules to work well without special handling, so you can use the AutoRegister functionality, which automatically registers torch.nn modules that are not yet registered. However, as we have not tested all these modules thoroughly, if you find a module that does not work correctly in some use case, please tell us or contribute special handling for it.

How to inspect the floating-point precision/quantized datum

Network parameters and buffers are saved as floats in module._parameters[<parameter name>] and module._buffers[<buffer name>], while module.<parameter or buffer name> is the fixed-point/quantized version. To access the floating-point precision datum, you can either use module._parameters (or module._buffers), or <quantized tensor>.nfp_actual_data (e.g. masked_conv.weights.nfp_actual_data).
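The storage split described above can be modeled with a toy class. ToyFixModule and QuantizedValue are hypothetical and do not reproduce nics_fix_pt's implementation; the point is only the access pattern: the float master copy lives in _parameters, attribute access returns a quantized copy, and that copy carries the float original in nfp_actual_data. The quantization grid is the same assumed scheme as in the other sketches.

```python
# Toy model of the storage split; ToyFixModule and QuantizedValue are
# hypothetical and are not nics_fix_pt's implementation.

class QuantizedValue(float):
    """A float subclass that remembers the unquantized value it came from."""
    nfp_actual_data = None

def quantize(x, bitwidth, scale):
    # Assumed grid: step 2**(scale - bitwidth + 1), saturated to [-2**scale, 2**scale - step].
    step = 2.0 ** (scale - bitwidth + 1)
    return min(max(round(x / step) * step, -(2.0 ** scale)), 2.0 ** scale - step)

class ToyFixModule:
    def __init__(self, weight, bitwidth=4, scale=0):
        self._parameters = {"weight": weight}   # float precision; what a state dict would save
        self.cfg = {"weight": {"bitwidth": bitwidth, "scale": scale}}

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for entries kept in _parameters.
        params = self.__dict__.get("_parameters", {})
        if name not in params:
            raise AttributeError(name)
        q = QuantizedValue(quantize(params[name], **self.cfg[name]))
        q.nfp_actual_data = params[name]         # keep a handle on the float original
        return q
```

Here ToyFixModule(0.3).weight evaluates to 0.25, while both _parameters["weight"] and weight.nfp_actual_data still give 0.3.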

As module._parameters and module._buffers are used by model.state_dict, the network parameters saved with model.state_dict when you dump checkpoints to disk are in floating-point precision:

  • In your use case (e.g. fixed-point hardware simulation), together with the saved floating-point precision parameters, you might need to dump and then load/modify the fixed-point configurations of the variables using model.get_fix_configs and model.load_fix_configs. Check examples/mnist/train_mnist.py for an example.
  • Or, if you want to directly dump the most recently used version of the parameters (which might be a quantized tensor, depending on your latest configuration), use nnf.FixTopModule.fix_state_dict(module) instead.
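The reason to dump configs alongside the float parameters is that the floats alone do not determine the quantized values. The sketch below uses a hypothetical JSON checkpoint layout (not the file format of nics_fix_pt's dump/load utilities) and the same assumed quantization grid as the other examples; save_checkpoint and load_quantized are illustrative names.

```python
import json
import os
import tempfile

# Hypothetical checkpoint layout: float parameters (like a state dict) plus the
# fixed-point configs needed to re-derive the quantized values later.

def quantize(x, bitwidth, scale):
    # Assumed grid: step 2**(scale - bitwidth + 1), saturated to [-2**scale, 2**scale - step].
    step = 2.0 ** (scale - bitwidth + 1)
    return min(max(round(x / step) * step, -(2.0 ** scale)), 2.0 ** scale - step)

def save_checkpoint(path, float_params, fix_cfgs):
    with open(path, "w") as f:
        json.dump({"params": float_params, "fix_cfgs": fix_cfgs}, f)

def load_quantized(path):
    """Reload the floats and re-apply each parameter's saved config."""
    with open(path) as f:
        ckpt = json.load(f)
    return {name: quantize(value, **ckpt["fix_cfgs"][name])
            for name, value in ckpt["params"].items()}

path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
save_checkpoint(path, {"fc.weight": 0.3}, {"fc.weight": {"bitwidth": 4, "scale": 0}})
```

Saving the same float 0.3 with a different bitwidth or scale would reload to a different quantized value, which is what the config dump guards against.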

How to configure the fixed-point behavior

Every fixed-point module needs a fixed-point configuration, and optionally a fixed-point configuration for the gradients.

The config for a module should be a dict whose keys are the parameter or buffer names and whose values are dicts of torch tensors (for the current "bitwidth", "method", and "scale"). These tensors are modified in place by calls to nics_fix_pt.quant.quantitize_cfg.
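As an illustration of a config being updated in place, here is a hypothetical sketch in the spirit of quantitize_cfg. It assumes an "auto" method that re-derives the scale from the data so that max|x| is at most 2**scale, and a "fixed" method that leaves the scale alone. Both the rule and the method strings are assumptions, and plain numbers stand in for the torch tensors the library stores.

```python
import math

# Hypothetical in-place config update; update_cfg_inplace, "auto" and "fixed"
# are illustrative assumptions, not nics_fix_pt APIs.

def update_cfg_inplace(cfg, values):
    if cfg["method"] == "auto":
        max_abs = max(abs(v) for v in values)
        if max_abs > 0:
            # smallest integer scale with max|x| <= 2**scale
            cfg["scale"] = math.ceil(math.log2(max_abs))
    # "fixed": keep the configured scale untouched
    return cfg

weight_cfg = {"bitwidth": 8, "method": "auto", "scale": 0}
update_cfg_inplace(weight_cfg, [0.3, -1.7, 2.5])   # weight_cfg["scale"] becomes 2
```

Because the dict is mutated rather than replaced, any module holding a reference to weight_cfg sees the new scale on its next use.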

How to check/manage the configuration

  1. For each quantized datum, you can use <quantized tensor>.data_cfg or <quantized tensor>.grad_cfg, e.g. masked_conv.weights.data_cfg or masked_conv.bias.data_cfg.
  2. You can also use FixTopModule.get_fix_configs(module) to get the configs of multiple modules in one OrderedDict.

You can modify the config tensors in place to change the behavior.
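The difference between editing a config in place and rebinding it can be shown with plain dicts. The nested layout (module name, then parameter name, then config dict) is a hypothetical stand-in for what a config collector might return; with the library's tensor values, the in-place equivalent would be something like param_cfg["bitwidth"].fill_(4) rather than item assignment.

```python
from collections import OrderedDict

# Hypothetical layout: module name -> parameter name -> config dict. The module
# is assumed to hold a reference to the same inner dicts, as the NOTE on shared
# configuration dicts implies.

def set_bitwidth_inplace(configs, bitwidth):
    """Mutate every leaf config in place so modules holding them see the change."""
    for module_cfg in configs.values():
        for param_cfg in module_cfg.values():
            param_cfg["bitwidth"] = bitwidth

conv1_cfg = {"weight": {"bitwidth": 8, "method": "auto", "scale": 0}}
configs = OrderedDict(conv1=conv1_cfg)   # what a module-level collector might return

set_bitwidth_inplace(configs, 4)         # in place: conv1_cfg now has bitwidth 4

configs["conv1"] = {"weight": {"bitwidth": 16, "method": "auto", "scale": 0}}
# rebinding the entry does NOT reach conv1_cfg, which still says 4
```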

Utilities

  • FixTopModule: dump/load fixed-point configurations to/from a file; print fixed-point configs.

    FixTopModule is just a wrapper that gathers the config print/load/dump/set utilities. These utilities work with ordinary nested nn.Module containers as intermediate modules, e.g. an nn.Sequential of fixed-point modules also works; you do not need a subclass that multi-inherits from both nn.Sequential and nnf.FixTopModule!

  • AutoRegister: automatically registers the corresponding fixed-point modules for all not-yet-registered modules in torch.nn, by proxying to the modules in torch.nn. Example usage:

    from nics_fix_pt import NAR as nnf_auto
    bilinear = nnf_auto.Bilinear_fix(...parameters...)
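Gathering configs through arbitrary nesting of ordinary containers is essentially a recursive walk that builds dotted names. The toy classes below are hypothetical stand-ins (Container for nn.Sequential or a plain nn.Module, FixLeaf for a fixed-point module); in real PyTorch the same traversal is available through nn.Module.named_modules(), which yields (dotted name, module) pairs.

```python
from collections import OrderedDict

# Toy stand-ins (hypothetical): Container plays nn.Sequential / a plain nn.Module,
# FixLeaf plays a fixed-point module carrying its own config dict.

class FixLeaf:
    def __init__(self, cfg):
        self.fix_cfg = cfg

class Container:
    def __init__(self, **children):
        self.children = OrderedDict(children)

def gather_fix_configs(module, prefix=""):
    """Walk any nesting of containers and collect leaf configs by dotted name."""
    found = OrderedDict()
    if isinstance(module, FixLeaf):
        found[prefix] = module.fix_cfg
    elif isinstance(module, Container):
        for name, child in module.children.items():
            child_prefix = prefix + "." + name if prefix else name
            found.update(gather_fix_configs(child, child_prefix))
    return found

model = Container(features=Container(conv=FixLeaf({"weight": {"bitwidth": 8}})),
                  fc=FixLeaf({"weight": {"bitwidth": 4}}))
```

Calling gather_fix_configs(model) returns an OrderedDict with the keys "features.conv" and "fc", and the intermediate container never needs to be a special subclass.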

Test cases

Tested with Python 2.7, 3.5, 3.6.1+.

PyTorch 0.4.1, 1.0.0, 1.1.0, 1.4.0. Note that fixed-point simulation using DataParallel is currently not supported with PyTorch >= 1.5.0!


Run python setup.py test to run the pytest test cases.
