
SGD with Hessian-based momentum

A PyTorch implementation of SGDHess, an SGD-based algorithm that incorporates second-order information to speed up training.

Quick overview

Our implementation is based on the official PyTorch implementation of SGD. The most important modification is the addition of a Hessian-vector product (the hvp[i] term) that "corrects" the momentum of the gradient estimate. This correction allows the algorithm to exploit second-order information and make faster progress than plain SGD.

buf.add_(hvp[i]).add_(displacement, alpha=weight_decay).mul_(momentum).add_(d_p, alpha=1 - dampening)
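
For readability, the chained call above can be read as the equivalent sequence below. The variable names come from the optimizer's internal state: buf is the per-parameter momentum buffer, hvp[i] the Hessian-vector product for the i-th parameter, and d_p its current gradient; displacement is assumed here to be the parameter's most recent change.

buf.add_(hvp[i])                            # inject the Hessian-vector correction
buf.add_(displacement, alpha=weight_decay)  # apply the weight-decay term
buf.mul_(momentum)                          # decay the corrected buffer
buf.add_(d_p, alpha=1 - dampening)          # accumulate the new (dampened) gradient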

This Hessian-vector product is computed efficiently with PyTorch's automatic differentiation package.

hvp = torch.autograd.grad(outputs=grads, inputs=param, grad_outputs=vector)
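
For a self-contained illustration, the snippet below shows this double-differentiation pattern on a toy model; the model, data, and direction vector are placeholders rather than the repository's code.

import torch

model = torch.nn.Linear(10, 1)                    # toy model (placeholder)
x, y = torch.randn(4, 10), torch.randn(4, 1)      # toy data (placeholder)
loss = torch.nn.functional.mse_loss(model(x), y)

params = list(model.parameters())
# Keep the graph of the first derivatives so they can be differentiated again.
grads = torch.autograd.grad(loss, params, create_graph=True)
# Direction to multiply the Hessian by (e.g. the parameter displacement).
vector = [torch.randn_like(p) for p in params]
# Hessian-vector product: gradient of <grads, vector> with respect to params.
hvp = torch.autograd.grad(outputs=grads, inputs=params, grad_outputs=vector)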

Usage

Below is an example of instantiating SGDHess. We recommend running the optimizer without internal gradient clipping (setting its clip flag to False), since internal clipping can slow down the optimizer's progress.

from SGDHess import SGDHess
optimizer = SGDHess(net.parameters(), lr=0.05, momentum=0.9, clip=False)

To run the optimizer, we need to pass create_graph=True when calling loss.backward(). This flag tells autograd to construct the derivative graph, which allows higher-order derivatives to be computed.

loss.backward(create_graph=True)
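
Putting the pieces together, a training loop then looks like the sketch below, where net, train_loader, and criterion are assumed to already exist and the optimizer follows the standard torch.optim interface.

import torch
from SGDHess import SGDHess

optimizer = SGDHess(net.parameters(), lr=0.05, momentum=0.9, clip=False)

for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    # Keep the derivative graph so the optimizer can form
    # Hessian-vector products during its step.
    loss.backward(create_graph=True)
    optimizer.step()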

Experiments
