A PyTorch implementation of SGDHess, an SGD-based algorithm that incorporates second-order information to accelerate training.
Our implementation is based on the official PyTorch implementation of SGD. The key modification is the addition of a Hessian-vector product term (`hvp[i]`) that "corrects" the momentum of the gradient estimate. This correction lets the algorithm exploit second-order information and make faster progress than plain SGD.
```python
buf.add_(hvp[i]).add_(displacement, alpha=weight_decay).mul_(momentum).add_(d_p, alpha=1 - dampening)
```
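The chained in-place calls above can be hard to parse. The following sketch spells out the same update with plain tensor arithmetic and checks that the two forms agree; the tensor values are made-up illustrative numbers, not values from the repo:

```python
import torch

# Hypothetical hyperparameter values for illustration only.
momentum, dampening, weight_decay = 0.9, 0.0, 1e-4

buf = torch.tensor([0.5, -0.2])          # momentum buffer
hvp = torch.tensor([0.1, 0.3])           # Hessian-vector product term
displacement = torch.tensor([1.0, 1.0])  # displacement used for weight decay
d_p = torch.tensor([0.2, -0.1])          # current gradient

# Spelled-out form of the update:
# buf <- momentum * (buf + hvp + weight_decay * displacement) + (1 - dampening) * d_p
expected = momentum * (buf + hvp + weight_decay * displacement) + (1 - dampening) * d_p

# Same result via the in-place chain used in the optimizer code.
buf.add_(hvp).add_(displacement, alpha=weight_decay).mul_(momentum).add_(d_p, alpha=1 - dampening)

assert torch.allclose(buf, expected)
```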
This Hessian-vector product is computed efficiently through PyTorch's automatic differentiation package.
```python
hvp = torch.autograd.grad(outputs=grads, inputs=param, grad_outputs=vector)
```
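For readers unfamiliar with double backward, here is a small self-contained sketch of the same technique on a toy function (the function and tensor values are our own illustration, not part of SGDHess):

```python
import torch

# Compute a Hessian-vector product H v for f(x) = sum(x**4),
# whose Hessian is diag(12 * x**2).
x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (x ** 4).sum()

# create_graph=True keeps the graph of the first backward pass alive,
# so the gradient itself can be differentiated again.
grads = torch.autograd.grad(loss, x, create_graph=True)[0]

v = torch.tensor([1.0, 1.0])
hvp = torch.autograd.grad(grads, x, grad_outputs=v)[0]

print(hvp)  # diag(12 * x**2) @ v = [12., 48.]
```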
Below is an example instantiation of SGDHess. We recommend running the optimizer without internal gradient clipping (set the `clip` flag to `False`), since internal clipping can slow the optimizer's progress.
```python
from SGDHess import SGDHess

optimizer = SGDHess(net.parameters(), lr=0.05, momentum=0.9, clip=False)
```
To run the optimizer, specify `create_graph=True` when calling `loss.backward()`. This flag tells autograd to construct a graph of the derivative computation, which allows higher-order derivatives to be computed.
```python
loss.backward(create_graph=True)
```
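Putting the pieces together, a training step looks like the sketch below. To keep the snippet self-contained and runnable, `torch.optim.SGD` stands in for SGDHess here; with SGDHess installed, only the optimizer line changes:

```python
import torch
import torch.nn as nn

# Stand-in for SGDHess so the sketch runs without the repo on the path.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()

for step in range(3):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    # create_graph=True retains the graph so an optimizer like SGDHess
    # can build Hessian-vector products from the stored gradients.
    loss.backward(create_graph=True)
    optimizer.step()
```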
- Here we include our modified scripts for our experiments. To run the experiments, fork the original repo (ImageNet - https://github.com/pytorch/examples/tree/master/imagenet; CIFAR-10 - https://github.com/amirgholami/adahessian/tree/master/image_classification; Fairseq - https://github.com/pytorch/fairseq), follow its instructions, and replace the files of the same names in the original repo with the files included here.