Code associated with the paper Adapting Newton's Method to Neural Networks through a Summary of Higher-Order Derivatives (2023), P. Wolinski.
Link: https://arxiv.org/abs/2312.03885.
```python
import torch

import grnewt

# User-specific setup
data_loader = torch.utils.data.DataLoader(...)
model = MyModel(...)
loss_fn = lambda output, target: ...

# Create a specific data loader
hg_loader = torch.utils.data.DataLoader(...)

# Set some hyperparameters
damping = .1
damping_int = 10.

# Prepare the optimizer
full_loss = lambda x, target: loss_fn(model(x), target)
param_groups, name_groups = grnewt.partition.canonical(model)
optimizer = grnewt.NewtonSummary(param_groups, full_loss, hg_loader,
        damping = damping, dct_nesterov = {'use': True, 'damping_int': damping_int},
        period_hg = 10, mom_lrs = .5, remove_negative = True,
        momentum = .9, momentum_damp = .9)

# Optimization process
for epoch in range(10):
    for x, target in data_loader:
        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
```
Specific variables:

- `param_groups`: partition of the set of parameters; given parameters `t1`, `t2`, ..., one may set:
  - `param_groups = [{'params': [t1, t2]}, {'params': [t3]}, {'params': [t4, t5, t6]}]`,
  - or simply use a predefined partition: `param_groups, name_groups = grnewt.partition.canonical(model)`;
- `data_loader`: loader used to compute the gradient of the loss, which is then used to propose a descent direction $\mathbf{u}$;
- `hg_loader`: loader used to compute $\bar{\mathbf{H}}$, $\bar{\mathbf{g}}$ and $\bar{\mathbf{D}}$ in the direction $\mathbf{u}$, in order to obtain the vector of learning rates $\boldsymbol{\eta}$;
- `period_hg`: period of the updates of $\bar{\mathbf{H}}$, $\bar{\mathbf{g}}$, $\bar{\mathbf{D}}$ and $\boldsymbol{\eta}^*$;
- `damping`: damping, or global learning-rate factor $\lambda_1$;
- `damping_int`: internal damping $\lambda_{\mathrm{int}}$, the strength of the regularization of $\bar{\mathbf{H}}$ when using anisotropic Nesterov regularization;
- `mom_lrs`: factor of the moving average used to update $\boldsymbol{\eta}^*$, in order to smooth the trajectory of the computed learning rates;
- `remove_negative`: if `True`, negative learning rates are set to zero.
We consider a loss $\mathcal{L}(\boldsymbol{\theta})$, whose parameters $\boldsymbol{\theta}$ are partitioned into $S$ subsets $(\boldsymbol{\theta}_1, \dots, \boldsymbol{\theta}_S)$.

For a direction of descent $\mathbf{u} = (\mathbf{u}_1, \dots, \mathbf{u}_S)$, the training step is $\boldsymbol{\theta}_s \leftarrow \boldsymbol{\theta}_s - \eta_s \mathbf{u}_s$, where $\eta_s$ is the learning rate assigned to the subset $\boldsymbol{\theta}_s$.

Our goal is to build a vector $\boldsymbol{\eta}^* = (\eta_1^*, \dots, \eta_S^*)$ of per-subset learning rates minimizing the loss after the step.

The second-order Taylor approximation of the loss gives:

$$\mathcal{L}\big(\boldsymbol{\theta} - (\eta_1 \mathbf{u}_1, \dots, \eta_S \mathbf{u}_S)\big) \approx \mathcal{L}(\boldsymbol{\theta}) - \boldsymbol{\eta}^T \bar{\mathbf{g}} + \frac{1}{2} \boldsymbol{\eta}^T \bar{\mathbf{H}} \boldsymbol{\eta},$$

where:

$$\bar{\mathbf{g}}_s = \mathbf{u}_s^T \nabla_{\boldsymbol{\theta}_s} \mathcal{L}, \qquad \bar{\mathbf{H}}_{st} = \mathbf{u}_s^T \big[\nabla^2_{\boldsymbol{\theta}_s \boldsymbol{\theta}_t} \mathcal{L}\big] \mathbf{u}_t.$$

Therefore, the minimum of the second-order approximation is attained at $\boldsymbol{\eta}^* = \bar{\mathbf{H}}^{-1} \bar{\mathbf{g}}$.

Besides, in our method, we regularize $\bar{\mathbf{H}}$ with an anisotropic version of Nesterov's cubic regularization, whose strength is set by the internal damping $\lambda_{\mathrm{int}}$ and which involves $\bar{\mathbf{D}}$, a summary of the third-order derivatives of the loss.

See Section 3 of the paper for a formal definition of $\bar{\mathbf{H}}$, $\bar{\mathbf{g}}$, $\bar{\mathbf{D}}$ and of the regularization scheme.
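As a toy numerical illustration of the Newton step on the summarized quantities, the sketch below solves $\bar{\mathbf{H}} \boldsymbol{\eta}^* = \bar{\mathbf{g}}$ for $S = 2$ groups with made-up values, then applies the `remove_negative` clipping and the global damping factor $\lambda_1$. All values and helper names are hypothetical; this is not the library's API.

```python
def solve_2x2(H, g):
    """Solve H @ eta = g for a 2x2 matrix H via Cramer's rule."""
    (a, b), (c, d) = H
    det = a * d - b * c
    return [(d * g[0] - b * g[1]) / det, (a * g[1] - c * g[0]) / det]

# Toy summarized quantities for S = 2 parameter groups (made-up values):
Hbar = [[2.0, 1.0], [1.0, 3.0]]   # Hbar[s][t] = u_s^T (second derivative block) u_t
gbar = [3.0, 5.0]                 # gbar[s] = u_s^T grad_s(L)

eta_star = solve_2x2(Hbar, gbar)            # per-group Newton learning rates
eta_star = [max(e, 0.0) for e in eta_star]  # remove_negative: clip to >= 0
damping = 0.1                               # lambda_1, global factor
step_sizes = [damping * e for e in eta_star]
# eta_star is approximately [0.8, 1.4]
```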
Package `grnewt`:

- `partition`: tools for building a partition of the set of parameters:
  - `canonical`: creates a per-tensor partition,
  - `trivial`: creates a partition with only one subset, containing all the parameters,
  - `wb`: creates a partition with 3 subsets: the subset of weights, the subset of biases, and the subset of all remaining parameters;
- `hg.compute_Hg`: computes $\bar{\mathbf{H}}$, $\bar{\mathbf{g}}$ and $\bar{\mathbf{D}}$;
- `nesterov.nesterov_lrs`: computes $\boldsymbol{\eta}^*$ with the anisotropic Nesterov regularization scheme, using $\bar{\mathbf{H}}$, $\bar{\mathbf{g}}$ and $\bar{\mathbf{D}}$;
- `newton_summary.NewtonSummary`: class containing the main optimizer, implementing the optimization procedure described in the paper (see Appendix F);
- `util.fullbatch.fullbatch_gradient`: computes the full-batch gradient of the loss, given a model, a loss and a dataset; useful for proposing a direction of descent $\mathbf{u}$;
- `models`: sub-package containing usual models;
- `datasets`: sub-package containing usual datasets and dataset tools.