# In [2]:
import torch
from torch.optim import Optimizer


class SGD(Optimizer):
    """Stochastic gradient descent with optional momentum and weight decay.

    For every parameter ``p`` with gradient ``g`` a step computes::

        g   = g + weight_decay * p                      (if weight_decay != 0)
        buf = g                                         (first step)
        buf = momentum * buf + (1 - dampening) * g      (later steps)
        d   = g + momentum * buf   if nesterov else buf
        p   = p - lr * d

    When ``momentum`` is 0 no buffer is kept and the update is plain
    ``p -= lr * g``.

    Args:
        params: Iterable of parameters or dicts defining parameter groups.
        lr: Learning rate; must be non-negative.
        momentum: Momentum factor; must be non-negative.
        dampening: Dampening applied to the gradient when it is accumulated
            into the momentum buffer (not applied on the first step, when
            the buffer is seeded with the gradient itself).
        weight_decay: L2 penalty coefficient; must be non-negative.
        nesterov: If true, use the Nesterov lookahead form of the update.

    Raises:
        ValueError: If ``lr``, ``momentum`` or ``weight_decay`` is negative.
    """

    def __init__(self, params, lr=1e-2, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradient tracking
                enabled so that it may call ``backward()``.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            dampening = group['dampening']
            weight_decay = group['weight_decay']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                if weight_decay != 0:
                    # Out-of-place so the stored p.grad is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                if momentum != 0:
                    state = self.state[p]
                    buf = state.get('momentum_buffer')
                    if buf is None:
                        # Seed the buffer with the (undampened) gradient.
                        buf = torch.clone(grad).detach()
                        state['momentum_buffer'] = buf
                    else:
                        buf.mul_(momentum).add_(grad, alpha=1 - dampening)

                    if nesterov:
                        grad = grad.add(buf, alpha=momentum)
                    else:
                        grad = buf

                p.add_(grad, alpha=-lr)

        return loss