# In [3]:
import torch.optim as optim
import torch
import math

class Adam(optim.Optimizer):
    """Adam optimizer (Kingma & Ba) with optional coupled L2 weight decay.

    Args:
        params: iterable of parameters, or of dicts defining parameter groups.
        lr: learning rate; must be >= 0.
        betas: ``(beta1, beta2)`` decay rates for the running averages of the
            gradient and of the squared gradient; each must lie in ``[0, 1)``
            (a beta of 1 would make the bias correction divide by zero).
        eps: term added to the denominator for numerical stability; must be >= 0.
        weight_decay: L2 penalty coefficient; must be >= 0. Applied by adding
            ``weight_decay * param`` to the gradient (coupled decay, not AdamW).

    Raises:
        ValueError: if any hyperparameter is out of range.

    Note:
        The bias correction is folded into the step size and ``eps`` is added
        to ``sqrt(v_t)`` of the *uncorrected* second moment. Numerical results
        can therefore differ slightly from ``torch.optim.Adam``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with autograd enabled, so it may call
                ``backward()``.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # This method runs under no_grad, but the closure must build a
            # graph and backprop through it, so re-enable autograd for it only.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )

                if weight_decay != 0:
                    # Out-of-place so that param.grad itself is left untouched.
                    grad = grad.add(param, alpha=weight_decay)

                state = self.state[param]
                # Lazily initialise per-parameter state on the first update.
                if len(state) == 0:
                    state['step'] = 0
                    state['first_moment'] = torch.zeros_like(param)
                    state['second_moment'] = torch.zeros_like(param)

                first_moment, second_moment = state['first_moment'], state['second_moment']
                state['step'] += 1
                step = state['step']

                # m_t = beta1 * m_{t-1} + (1 - beta1) * g
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t = beta2 * v_{t-1} + (1 - beta2) * g * g
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Fold both bias corrections into a scalar step size, so no
                # corrected copies of m_t / v_t need to be materialised.
                step_size = lr * math.sqrt(bias_correction2) / bias_correction1

                denom = second_moment.sqrt().add_(eps)
                param.addcdiv_(first_moment, denom, value=-step_size)

        return loss
