- Parameter/Buffer fix: by using `nics_fix_pt.nn_fix.<original name>_fix` modules; parameters are the data to be trained, buffers are the data to be persisted but not treated as parameters.
- Activation fix: by using the `ActivationFix` module.
- Data fix vs. Gradient fix: by supplying the `nf_fix_params`/`nf_fix_params_grad` keyword arguments when constructing a `nics_fix_pt.nn_fix.<original name>_fix` or `ActivationFix` module.
NOTE: When constructing a fixed-point module, the dictionary you pass in as the `nf_fix_params` argument is used directly, so if you pass the same configuration dict to two modules, these two modules will share the same configuration. If you want each module to be quantized independently, construct a different configuration dict for each module (see the sketch below).
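For illustration, here is a minimal sketch of constructing two fixed layers with independent configuration dicts. The per-datum `"bitwidth"`/`"method"`/`"scale"` structure is described in "How to config the fixed-point behavior" below; the helper function, the method code, and the tensor dtypes are assumptions, so check `examples/mnist/train_mnist.py` for the values the library actually expects.

```python
import torch
import nics_fix_pt.nn_fix as nnf

def make_fix_cfg(names, bitwidth=8):
    # Illustrative helper (not part of the library): build a fresh configuration
    # dict with one entry per parameter/buffer name; each entry holds the
    # "bitwidth"/"method"/"scale" tensors read and updated by the quantization code.
    return {
        name: {
            "bitwidth": torch.tensor([bitwidth]),
            "method": torch.tensor([1]),   # assumed code for "quantize with automatically searched scale"
            "scale": torch.tensor([0.0]),
        }
        for name in names
    }

# Each module gets its own dict, so the two layers are quantized independently.
conv1 = nnf.Conv2d_fix(3, 16, kernel_size=3,
                       nf_fix_params=make_fix_cfg(["weight", "bias"]))
# Optionally also quantize the gradients by passing nf_fix_params_grad.
conv2 = nnf.Conv2d_fix(16, 32, kernel_size=3,
                       nf_fix_params=make_fix_cfg(["weight", "bias"]),
                       nf_fix_params_grad=make_fix_cfg(["weight", "bias"]))
```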
The underlying quantization of each datum is implemented in `nics_fix_pt.quant._do_quantitize`; the code is easy to understand. If you need a plug-in quantize method different from this default one, or a configurable behavior of `_do_quantitize` (e.g. in a simulation of simultaneous computation on different types of devices), please contribute a pluggable and configurable `_do_quantitize` (I believe this will be a simple change), or contact me.
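For intuition only, the sketch below shows what a signed fixed-point quantization step with a given scale and bitwidth typically does; it is an illustration, not the library's `_do_quantitize` implementation.

```python
import torch

def quantize_sketch(tensor, scale=0, bitwidth=8):
    # Illustration only, NOT the library's _do_quantitize: round the values to
    # the nearest multiple of the fixed-point step implied by (scale, bitwidth),
    # then clamp them to the representable signed range.
    step = 2.0 ** (scale - bitwidth + 1)
    return torch.clamp(torch.round(tensor / step) * step,
                       -(2.0 ** scale), 2.0 ** scale - step)

print(quantize_sketch(torch.randn(4), scale=0, bitwidth=8))
```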
See `examples/mnist/train_mnist.py` for an MNIST fixed-point training example and demo usage of the utilities.
### How to make a module fixed-point
If you have implemented your own module type, e.g. a masked convolution module type `MaskedConv2d` (a convolution with a mask generated by pruning), and you want a fixed-point version of it, just do:
```python
from nics_fix_pt import register_fix_module
register_fix_module(MaskedConv2d)
```
Then you can construct the fixed-point module and use it as in the following code. All of its functionality stays the same, except that the weights (and optionally the gradients of the weights) are converted to fixed-point before each usage.
```python
import nics_fix_pt.nn_fix as nnf
masked_conv = nnf.MaskedConv2d_fix(...your parameters, including fixed-point configs...)
```
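For a more complete picture, here is a minimal end-to-end sketch; the `MaskedConv2d` definition and the configuration values are illustrative assumptions, not part of the library.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from nics_fix_pt import register_fix_module
import nics_fix_pt.nn_fix as nnf

class MaskedConv2d(nn.Conv2d):
    # Illustrative custom module (not part of the library): a convolution whose
    # weight is multiplied by a fixed binary mask, e.g. produced by pruning.
    def __init__(self, *args, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)

register_fix_module(MaskedConv2d)  # makes nnf.MaskedConv2d_fix available

# Quantize the weight and bias of this layer to 8 bits (config values are assumptions).
fix_cfg = {
    name: {"bitwidth": torch.tensor([8]),
           "method": torch.tensor([1]),
           "scale": torch.tensor([0.0])}
    for name in ["weight", "bias"]
}
masked_conv = nnf.MaskedConv2d_fix(3, 16, kernel_size=3, padding=1,
                                   nf_fix_params=fix_cfg)
out = masked_conv(torch.randn(1, 3, 32, 32))
```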
The internally registered fixed-point modules are already tested in our test examples and our other works:

- `torch.nn.Linear`
- `torch.nn.Conv2d`
- `torch.nn.BatchNorm1d`
- `torch.nn.BatchNorm2d`
NOTE: We expect most `torch.nn` modules to work well without special handling, so you can use the AutoRegister functionality, which will automatically register the non-registered `torch.nn` modules. However, as we have not tested all of these modules thoroughly, if you find that some module fails to work properly in some use case, please tell us or contribute special handling for it.
### How to inspect the floating-point precision/quantized data
Network parameters are saved as floating-point values in `module._parameters[<parameter or buffer name>]`, while `module.<parameter or buffer name>` is the fixed-point/quantized version. You can use either `module._parameters` or `<quantized tensor>.nfp_actual_data` (e.g. `masked_conv.weight.nfp_actual_data`) to access the floating-point precision datum.
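For example, reusing the `masked_conv` module constructed in the sketch above (the quantized attribute reflects the latest quantization, e.g. after a forward pass):

```python
quantized_w = masked_conv.weight                   # fixed-point/quantized tensor used by the module
float_w = masked_conv._parameters["weight"]        # underlying float-precision parameter
float_w_too = masked_conv.weight.nfp_actual_data   # the same float-precision data, via the attached attribute
```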
As `module._parameters` and `module._buffers` are used by `model.state_dict`, when you dump checkpoints onto the disk, the network parameters saved via `model.state_dict` are in floating-point precision:
- In your use cases (e.g. fixed-point hardware simulation), together with the saved floating-point precision parameters, you might need to dump and then load/modify the fix configurations of the variables using `model.get_fix_configs` and `model.load_fix_configs`. Check `examples/mnist/train_mnist.py` for an example.
- Or, if you want to directly dump the latest used version of the parameters (which might be a quantized tensor, depending on your latest configuration), use `nnf.FixTopModule.fix_state_dict(module)` for dumping instead. Both options are sketched after this list.
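A hedged sketch of both options, assuming `model` is an `nnf.FixTopModule`; the exact signatures of `get_fix_configs`/`load_fix_configs` may differ slightly, see `examples/mnist/train_mnist.py`:

```python
import torch
import nics_fix_pt.nn_fix as nnf

# Option 1: float-precision parameters plus the fix configurations.
torch.save(model.state_dict(), "model_float.pt")    # float-precision parameters
torch.save(model.get_fix_configs(), "fix_cfgs.pt")  # per-module fix configurations
# ... later ...
model.load_state_dict(torch.load("model_float.pt"))
model.load_fix_configs(torch.load("fix_cfgs.pt"))

# Option 2: directly dump the latest used (possibly quantized) parameters.
torch.save(nnf.FixTopModule.fix_state_dict(model), "model_fix.pt")
```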
### How to config the fixed-point behavior

Every fixed-point module needs a fixed-point configuration, and optionally a fixed-point configuration for the gradients.

The config for a module should be a dict: the keys are the parameter or buffer names, and each value is a dict of torch tensors (currently for `"bitwidth"`, `"method"`, and `"scale"`), which are modified in place by calls to `nics_fix_pt.quant.quantitize_cfg`.
### How to check/manage the configuration

- For each quantized datum, you can use `<quantized tensor>.data_cfg` or `<quantized tensor>.grad_cfg`, e.g. `masked_conv.weight.data_cfg` or `masked_conv.bias.data_cfg`.
- You can also use `FixTopModule.get_fix_configs(module)` to get the configs of multiple modules in one OrderedDict.

You can modify the config tensors in place to change the quantization behavior; see the sketch below.
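For example, reusing the `masked_conv` layer and a `FixTopModule`-based `model` from the sketches above (the in-place bitwidth edit is purely illustrative):

```python
import nics_fix_pt.nn_fix as nnf

# Per-datum access:
w_cfg = masked_conv.weight.data_cfg     # dict with the "bitwidth"/"method"/"scale" tensors
g_cfg = masked_conv.weight.grad_cfg     # gradient config, if one was supplied

# Network-wide access (one OrderedDict covering all fixed modules):
all_cfgs = nnf.FixTopModule.get_fix_configs(model)

# The config tensors are shared with the quantization code, so in-place edits
# take effect, e.g. switch this layer's weight to 4-bit quantization:
w_cfg["bitwidth"].fill_(4)
```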
- `FixTopModule`: dump/load fix configurations to/from file; print fix configs. `FixTopModule` is just a wrapper that gathers the config print/load/dump/setting utilities; these utilities also work with nested plain `nn.Module` intermediate containers, e.g. an `nn.Sequential` of fixed modules will work fine, so you do not need a subclass multi-inherited from `nn.Sequential` and `nnf.FixTopModule`! (See the sketch after this list.)
- AutoRegister: auto-register the corresponding fixed-point modules for modules in `torch.nn`: all not-yet-registered `torch.nn` modules are registered automatically by proxying to them. Example usage:

  ```python
  from nics_fix_pt import NAR as nnf_auto
  bilinear = nnf_auto.Bilinear_fix(...parameters...)
  ```
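A sketch of such a setup is shown below; the network architecture, the config helper, and the inclusion of the BatchNorm running statistics in the config are assumptions, not prescriptions from the library.

```python
import torch
import torch.nn as nn
import nics_fix_pt.nn_fix as nnf

def make_fix_cfg(names, bitwidth=8):
    # Same illustrative helper as in the first sketch above.
    return {n: {"bitwidth": torch.tensor([bitwidth]),
                "method": torch.tensor([1]),
                "scale": torch.tensor([0.0])} for n in names}

class TinyFixNet(nnf.FixTopModule):
    # Illustrative network: only the top module subclasses FixTopModule; the
    # fixed layers live inside a plain nn.Sequential intermediate container.
    def __init__(self):
        super(TinyFixNet, self).__init__()
        self.features = nn.Sequential(
            nnf.Conv2d_fix(3, 8, kernel_size=3, padding=1,
                           nf_fix_params=make_fix_cfg(["weight", "bias"])),
            nn.ReLU(),
            nnf.BatchNorm2d_fix(8, nf_fix_params=make_fix_cfg(
                ["weight", "bias", "running_mean", "running_var"])),
        )
        self.classifier = nnf.Linear_fix(8 * 32 * 32, 10,
                                         nf_fix_params=make_fix_cfg(["weight", "bias"]))

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.view(x.size(0), -1))

net = TinyFixNet()
out = net(torch.randn(2, 3, 32, 32))
cfgs = net.get_fix_configs()   # also gathers the configs of the modules inside nn.Sequential
```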
Tested with Python 2.7, 3.5, and 3.6.1+, and with PyTorch 0.4.1, 1.0.0, 1.1.0, and 1.4.0. Note that fixed-point simulation using DataParallel is currently not supported with PyTorch >= 1.5.0!
Run `python setup.py test` to run the pytest test cases.