skorch doctor: a tool to understand the net (#912)
A helper class to assist in understanding the neural net training.

The SkorchDoctor helper class allows users to wrap their neural net before training and then automatically collects useful data that helps to better understand what is going on during training and how to possibly improve it. The class automatically records the activations of each module as well as the gradients and updates of each learnable parameter, all of those for each training step. Once training is finished, the user can either directly take a look at the data, which is stored as an attribute on the helper class, or use one of the provided plotting functions (requires matplotlib) to plot distributions of the data. A usage sketch is shown at the end of this description.

Examples of conclusions that could be drawn from the data:

- The net is not powerful enough.
- The weights need better initialization or normalization.
- The optimizer needs adjusting.
- Gradient clipping is needed.

However, the helper class will not suggest any of those solutions itself; I don't think that's possible. It is only intended to help surface potential problems; it's up to the user to decide on a solution.

A notebook showing the usage of SkorchDoctor, once for a simple MLP and once for fine-tuning a BERT model, is provided:

https://github.com/skorch-dev/skorch/blob/skorch-doctor/notebooks/Skorch_Doctor.ipynb

Implementation

Because of the additional data being collected, a significant memory overhead is expected, depending on the use case. To keep this in check, a few measures are taken:

- The collected data is immediately pulled to numpy to avoid clobbering GPU memory.
- It is documented, and shown in the examples, that only a small amount of data and a low number of epochs should be used, since that's enough to understand most problems. Most notably, this helps with storing less data about activations.
- For parameter updates, only a single scalar per weight/bias is stored, indicating the relative magnitude of the update.
- The biggest overhead will most likely come from storing the gradients; it's unclear whether something can be done here without losing too much useful data.
- An option is provided to filter by layer/parameter name.

For storing activations, some heuristics are in place to deal with the output of the modules. The problem here is that modules can return arbitrary data from their forward call. A few assumptions are made: The output is passed to to_numpy, so it has to be either a torch tensor, a list, a tuple, or a mapping of torch tensors. If it's none of those, an error is raised.

---------

Co-authored-by: BenjaminBossan <b.bossan@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
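To illustrate the intended workflow, below is a minimal sketch: wrap the net, fit on a small dataset for a few epochs, then inspect the recorded data or plot it. The import path, the attribute names (`activation_logs_`, `gradient_logs_`), and the plotting method names are assumptions based on the description above, not confirmed API; the linked notebook shows the actual usage. The dataset is synthetic, just to make the snippet self-contained.

```python
import numpy as np
from torch import nn

from skorch import NeuralNetClassifier
from skorch.helper import SkorchDoctor  # assumed import path

# small synthetic dataset: a small amount of data and few epochs
# are enough to surface most training problems
X = np.random.randn(128, 20).astype(np.float32)
y = np.random.randint(0, 2, size=128)

module = nn.Sequential(
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 2),
    nn.LogSoftmax(dim=-1),
)
net = NeuralNetClassifier(module, max_epochs=3)

doctor = SkorchDoctor(net)  # wrap the net before training
doctor.fit(X, y)            # trains the net while recording data each step

# the recorded data is stored as attributes on the helper class
# (attribute names here are assumptions, see the notebook for the real ones)
activations = doctor.activation_logs_
gradients = doctor.gradient_logs_

# plotting helpers require matplotlib (method names are assumptions too)
doctor.plot_activations()
doctor.plot_gradients()
```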