In [1]:
import random
import torch
import torch.nn as nn
import numpy as np
from scipy.stats import ortho_group
import matplotlib.pyplot as plt

In [2]:
# Seed every random number generator used in this notebook (NumPy, Python's
# `random`, PyTorch and PyTorch's CUDA generator) with the same value so
# that runs are repeatable.
seed = 0
for seed_rng in (np.random.seed, random.seed,
                 torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(seed)

In [3]:
import torch
from torch.optim import Optimizer
from torch.func import vmap
from functools import reduce


def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5):
    f0 = f(x, 0, dx)
    f1 = f(x, t, dx)
    while f1 > f0 + alpha * t * gx.dot(dx):
        t *= beta
        f1 = f(x, t, dx)
    return t


def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x):
    z = U.T @ x
    z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z)
    return z


def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters):
    """Solve ``(H + mu * I) x = b`` by Nystrom-preconditioned conjugate gradient.

    Args:
        hess: Callable returning the Hessian-vector product for a vector.
        b: Right-hand side vector. It is detached from the autograd graph
            before use so the iterations below are not recorded.
        x: Initial guess (warm start). It is updated IN PLACE and the same
            tensor object is returned.
        mu: Damping added to the operator (``hess(v) + mu * v``).
        U: Matrix forwarded to ``_apply_nys_precond_inv``.
        S: Vector of approximate eigenvalues; ``S[r - 1]`` and
            ``1 / (S + mu)`` parameterize the preconditioner.
        r: Rank; selects ``S[r - 1]`` as ``lambd_r``.
        tol: Iteration stops once the residual 2-norm is ``<= tol``.
        max_iters: Maximum number of CG iterations.

    Returns:
        ``x`` after the final update.
    """
    lambd_r = S[r - 1]
    S_mu_inv = (S + mu) ** (-1)

    # The solution does not need to be differentiable; detaching keeps the
    # (possibly thousands of) iterations out of the autograd graph.
    b = b.detach()

    resid = b - (hess(x) + mu * x)
    z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
    p = z.clone()
    # resid . z is needed for both alpha and beta; keep it across iterations
    # instead of recomputing it.
    rTz = torch.dot(resid, z)

    i = 0

    while torch.norm(resid) > tol and i < max_iters:
        v = hess(p) + mu * p
        pTv = torch.dot(p, v)
        # A zero or non-finite curvature would turn alpha (and then x) into
        # inf/NaN; stop and keep the last finite iterate instead.
        if pTv == 0 or not torch.isfinite(pTv):
            print(
                f"Warning: PCG breakdown at iteration {i} (curvature = {pTv}). Stopping early.")
            break
        alpha = rTz / pTv
        x += alpha * p

        resid -= alpha * v
        z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
        rTz_next = torch.dot(resid, z)
        beta = rTz_next / rTz
        rTz = rTz_next

        p = z + beta * p

        if i % 100 == 0:
            print(
                f"PCG iteration {i} complete. Residual norm = {torch.norm(resid)}")
        i += 1

    if torch.norm(resid) > tol:
        print(
            f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}")

    return x


class NysNewtonCG(Optimizer):
    """Damped Newton optimizer using Nystrom-preconditioned CG for the linear solve.

    Each :meth:`step` solves ``(H + mu * I) d = g`` with ``_nystrom_pcg``
    (warm-started from the previous direction) and moves the parameters along
    ``-d``, optionally with an Armijo backtracking line search.
    :meth:`update_preconditioner` must be called before the first step: it
    fills ``self.U`` and ``self.S``, which start as ``None``.

    Args:
        params: Iterable of parameters; only a single parameter group is
            supported.
        lr: Step size, or the initial trial step when line search is on.
        rank: Rank of the Nystrom approximation.
        mu: Damping added to the Hessian.
        chunk_size: ``chunk_size`` passed to ``vmap`` when sketching.
        cg_tol: Residual-norm tolerance for PCG.
        cg_max_iters: Maximum number of PCG iterations.
        line_search_fn: ``None`` or ``'armijo'``.
        verbose: If true, print the approximate eigenvalues after each
            preconditioner update.

    Raises:
        ValueError: If more than one parameter group is given or
            ``line_search_fn`` is neither ``None`` nor ``'armijo'``.
    """

    def __init__(self, params, lr=1.0, rank=10, mu=1e-4, chunk_size=1,
                 cg_tol=1e-16, cg_max_iters=1000, line_search_fn=None, verbose=False):
        defaults = dict(lr=lr, rank=rank, chunk_size=chunk_size, mu=mu, cg_tol=cg_tol,
                        cg_max_iters=cg_max_iters, line_search_fn=line_search_fn)
        self.rank = rank
        self.mu = mu
        self.chunk_size = chunk_size
        self.cg_tol = cg_tol
        self.cg_max_iters = cg_max_iters
        self.line_search_fn = line_search_fn
        self.verbose = verbose
        # Nystrom factors; populated by update_preconditioner().
        self.U = None
        self.S = None
        self.n_iters = 0
        super().__init__(params, defaults)

        if len(self.param_groups) > 1:
            raise ValueError(
                "NysNewtonCG doesn't currently support per-parameter options (parameter groups)")

        if self.line_search_fn is not None and self.line_search_fn != 'armijo':
            raise ValueError("NysNewtonCG only supports Armijo line search")

        self._params = self.param_groups[0]['params']
        self._params_list = list(self._params)
        self._numel_cache = None

    def step(self, closure=None):
        """Perform one damped Newton step.

        Args:
            closure: Callable returning ``(loss, grad_tuple)``, where
                ``grad_tuple`` holds the gradients of ``loss`` with respect
                to the parameters, computed with ``create_graph=True`` so
                they can be differentiated again for Hessian-vector products.

        Returns:
            Tuple ``(loss, g)`` with the loss from the first closure call and
            the flattened gradient (still attached to the autograd graph).

        Raises:
            ValueError: If ``closure`` is ``None``.
        """
        if closure is None:
            # The gradient is only available through the closure; fail with a
            # clear message instead of an UnboundLocalError further down.
            raise ValueError(
                "NysNewtonCG.step requires a closure returning (loss, grad_tuple)")

        if self.n_iters == 0:
            # Store the previous direction for warm starting PCG
            first_param = self._params[0]
            self.old_dir = torch.zeros(
                self._numel(), device=first_param.device, dtype=first_param.dtype)

        with torch.enable_grad():
            loss, grad_tuple = closure()

        g = torch.cat([grad.reshape(-1)
                      for grad in grad_tuple if grad is not None])
        # Graph-free copy for the linear solve and line search. `g` itself
        # must stay attached because _hvp differentiates through it.
        g_flat = g.detach()

        # one step update
        for group_idx, group in enumerate(self.param_groups):
            def hvp_temp(x):
                return self._hvp(g, self._params_list, x)

            # Calculate the Newton direction (old_dir is updated in place)
            d = _nystrom_pcg(hvp_temp, g_flat, self.old_dir,
                             self.mu, self.U, self.S, self.rank, self.cg_tol, self.cg_max_iters)

            # Store the previous direction for warm starting PCG
            self.old_dir = d

            # The step taken is -d, so descent requires d . g > 0
            if torch.dot(d, g_flat) <= 0:
                print("Warning: d is not a descent direction")

            if self.line_search_fn == 'armijo':
                x_init = self._clone_param()

                def obj_func(x, t, dx):
                    # Evaluate the loss at x + t * dx, then restore x.
                    self._add_grad(t, dx)
                    loss = float(closure()[0])
                    self._set_param(x)
                    return loss

                # Use -d for convention
                t = _armijo(obj_func, x_init, g_flat, -d, group['lr'])
            else:
                t = group['lr']

            self.state[group_idx]['t'] = t

            # update parameters
            # Every parameter owns a slice of d, so every parameter is
            # updated. Gradients arrive through the closure's grad_tuple,
            # not through p.grad, so skipping on `p.grad is None` would skip
            # every update when torch.autograd.grad is used.
            offset = 0
            for p in group['params']:
                count = p.numel()
                dp = d[offset:offset + count].view(p.shape)
                offset += count
                p.data.add_(-dp, alpha=t)

        self.n_iters += 1

        return loss, g

    def update_preconditioner(self, grad_tuple):
        """Rebuild the rank-``self.rank`` Nystrom approximation of the Hessian.

        Sketches the Hessian with an orthonormalized Gaussian test matrix and
        stores the approximate eigenvectors in ``self.U``, the approximate
        eigenvalues in ``self.S`` and the smallest of them in ``self.rho``.

        Args:
            grad_tuple: Gradients of the loss with respect to the parameters,
                computed with ``create_graph=True``.
        """
        gradsH = torch.cat([gradient.reshape(-1)
                           for gradient in grad_tuple if gradient is not None])

        p = gradsH.shape[0]
        # Generate test matrix (NOTE: This is transposed test matrix)
        Phi = torch.randn(
            (self.rank, p), device=self._params_list[0].device,
            dtype=gradsH.dtype) / (p ** 0.5)
        Phi = torch.linalg.qr(Phi.t(), mode='reduced')[0].t()

        Y = self._hvp_vmap(gradsH, self._params_list)(Phi)

        # Calculate shift
        shift = torch.finfo(Y.dtype).eps
        Y_shifted = Y + shift * Phi
        # Calculate Phi^T * H * Phi (w/ shift) for Cholesky
        choleskytarget = torch.mm(Y_shifted, Phi.t())
        # Perform Cholesky, if fails, do eigendecomposition
        # The new shift is the abs of smallest eigenvalue (negative) plus the original shift
        try:
            C = torch.linalg.cholesky(choleskytarget)
        except RuntimeError:
            # torch.linalg.LinAlgError subclasses RuntimeError; catching it
            # rather than everything lets KeyboardInterrupt etc. propagate.
            # eigendecomposition, eigenvalues and eigenvector matrix
            eigs, eigvectors = torch.linalg.eigh(choleskytarget)
            shift = shift + torch.abs(torch.min(eigs))
            # add shift to eigenvalues
            eigs = eigs + shift
            # put back the matrix for Cholesky by eigenvector * eigenvalues after shift * eigenvector^T
            # NOTE(review): Y_shifted still carries only the original shift
            # here -- confirm that this mismatch is intended.
            C = torch.linalg.cholesky(
                torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)))

        try:
            B = torch.linalg.solve_triangular(
                C, Y_shifted, upper=False, left=True)
        # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211
        except RuntimeError:
            B = torch.linalg.solve_triangular(C.to('cpu'), Y_shifted.to(
                'cpu'), upper=False, left=True).to(C.device)
        # B = V * S * U^T b/c we have been using transposed sketch
        _, S, UT = torch.linalg.svd(B, full_matrices=False)
        self.U = UT.t()
        # Remove the shift again and clip at zero.
        self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0))

        self.rho = self.S[-1]

        if self.verbose:
            print(f'Approximate eigenvalues = {self.S}')

    def _hvp_vmap(self, grad_params, params):
        """Return ``_hvp`` vectorized over the leading dimension of its vector
        argument, evaluated ``self.chunk_size`` rows at a time."""
        return vmap(lambda v: self._hvp(grad_params, params, v), in_dims=0, chunk_size=self.chunk_size)

    def _hvp(self, grad_params, params, v):
        """Return the flattened, detached Hessian-vector product.

        Differentiates ``grad_params`` with respect to ``params`` using ``v``
        as ``grad_outputs``. The graph is retained so further products can be
        taken from the same gradient.
        """
        Hv = torch.autograd.grad(grad_params, params, grad_outputs=v,
                                 retain_graph=True)
        Hv = tuple(Hvi.detach() for Hvi in Hv)
        return torch.cat([Hvi.reshape(-1) for Hvi in Hv])

    def _numel(self):
        """Return the total number of scalar parameters (cached after first use)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _add_grad(self, step_size, update):
        """Set every parameter to ``p + step_size * update_slice``.

        ``update`` is a flat vector laid out in parameter order.
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            # Avoid in-place operation by creating a new tensor
            p.data = p.data.add(
                update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        """Return contiguous copies of all parameters."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Point each parameter's data at the matching tensor in ``params_data``."""
        for p, pdata in zip(self._params, params_data):
            # Replace the .data attribute of the tensor
            p.data = pdata.data

In [4]:
class LSQ(torch.nn.Module):
    """Linear least-squares model: one bias-free linear layer with a single output."""

    def __init__(self, n_features):
        super().__init__()
        # No intercept term: the prediction is purely the weighted sum of
        # the n_features inputs.
        self.w = torch.nn.Linear(
            in_features=n_features, out_features=1, bias=False)

    def forward(self, x):
        """Return the linear prediction for the input batch ``x``."""
        prediction = self.w(x)
        return prediction

In [5]:
# define experiment parameters
# Builds a synthetic, ill-conditioned least-squares problem: a design matrix
# with prescribed, rapidly decaying singular values and noiseless targets.
# The generated data depend on the order of the random draws below.
n_train = 5000
n_test = 500  # NOTE: not used by any active code in this cell
n_features = 100
n_iters = 50

# Ground-truth weight vector used to generate the targets.
weight = np.random.normal(size=n_features)

# Earlier variant with an i.i.d. Gaussian design matrix (kept for reference):
# Xtrain = np.random.normal(size = (n_train, n_features))
# ytrain = (Xtrain @ weight)[: , np.newaxis]

# Xtest = np.sort(np.random.normal(size = (n_test, n_features)))
# ytest = (Xtest @ weight)[: , np.newaxis]

# Create a vector with polynomial decay starting at 1
# (entries 1, 1/4, 1/9, ...), then place it on the diagonal of a matrix.
decay = (np.arange(n_features) + 1) ** (-2.0)
decay = np.diag(decay)
# NOTE: this draws a full n_train x n_train orthogonal matrix and keeps only
# its first n_features columns, which is by far the most expensive line here.
U = ortho_group.rvs(n_train)[:, :n_features]
VT = ortho_group.rvs(n_features)
# X = U @ diag(decay) @ VT, so the diagonal of `decay` holds the singular
# values of X.
X = U @ decay @ VT

# X has exactly n_train rows, so this slice keeps all of them.
Xtrain = X[:n_train]
# Noiseless targets as a column vector of shape (n_train, 1).
ytrain = (Xtrain @ weight)[:, np.newaxis]

# Xtest = X[n_train:]
# ytest = (Xtest @ weight)[: , np.newaxis]

# NOTE(review): X^T X / n is the Hessian of 0.5 * mean squared error; for a
# plain mean-squared-error loss the Hessian is twice this -- confirm which
# convention is intended when comparing against the optimizer.
print(f'True Hessian: {Xtrain.T @ Xtrain / n_train}')

True Hessian: [[5.87207932e-08 5.30202325e-07 3.84452423e-08 ... 1.77979336e-07
  5.41539004e-07 9.73623657e-08]
 [5.30202325e-07 6.54229435e-06 3.50059491e-07 ... 2.13488754e-06
  6.16464522e-06 1.13735742e-06]
 [3.84452423e-08 3.50059491e-07 8.98778386e-08 ... 1.02080514e-07
  1.62413930e-07 1.57564297e-07]
 ...
 [1.77979336e-07 2.13488754e-06 1.02080514e-07 ... 7.74132079e-07
  2.26517543e-06 3.62492561e-07]
 [5.41539004e-07 6.16464522e-06 1.62413930e-07 ... 2.26517543e-06
  7.11691785e-06 8.17371842e-07]
 [9.73623657e-08 1.13735742e-06 1.57564297e-07 ... 3.62492561e-07
  8.17371842e-07 3.56961183e-07]]


In [6]:
# Build the model and the Nystrom Newton-CG optimizer under test.
model = LSQ(n_features)

optimizer = NysNewtonCG(
    model.parameters(),
    lr=1.0,
    rank=10,
    mu=0,
    cg_tol=1e-16,
    cg_max_iters=3000,
    line_search_fn='armijo',
)
# Number of optimizer steps between two preconditioner refreshes.
precond_update_freq = 20

# Per-iteration training loss and accepted line-search step size.
loss_hist = []
step_size_hist = []

Xt = torch.tensor(Xtrain, dtype=torch.float32)
yt = torch.tensor(ytrain, dtype=torch.float32)

# Start the optimization from the all-zero weight vector.
torch.nn.init.zeros_(model.w.weight)

loss_function = nn.MSELoss()


def closure():
    """Return the training loss and its gradients.

    The gradients are computed with ``create_graph=True`` so the optimizer
    can differentiate them again for Hessian-vector products.
    """
    optimizer.zero_grad()
    output = model(Xt)
    loss = loss_function(output, yt)
    grad_tuple = torch.autograd.grad(
        loss, model.parameters(), create_graph=True)
    return loss, grad_tuple


for i in range(n_iters):
    model.train()

    # Refresh the PCG preconditioner on the first iteration and then every
    # `precond_update_freq` iterations, using the current gradients.
    if i % precond_update_freq == 0:
        grad_tuple = closure()[1]
        optimizer.update_preconditioner(grad_tuple)

    # Take a step
    optimizer.step(closure)

    # Record the training loss after the step.
    model.eval()
    output = model(Xt)
    loss = loss_function(output, yt).item()
    loss_hist.append(loss)

    # The optimizer stores the step size it used under state[0]['t'].
    optimizer_state = optimizer.state_dict()['state']
    step_size_hist.append(optimizer_state[0]['t'])

    print(f'Iteration {i} | Loss: {loss}')
    # Debug aids for inspecting the preconditioner:
    # print(optimizer.U @ torch.diag(optimizer.S) @ optimizer.U.t())
    # print(optimizer.S)
    # print(optimizer.rho)

PCG iteration 0 complete. Residual norm = 0.00026859284844249487
PCG iteration 100 complete. Residual norm = 1.6308203498738294e-07


PCG iteration 200 complete. Residual norm = 1.747113032024572e-07
PCG iteration 300 complete. Residual norm = 3.3161501988843156e-08
PCG iteration 400 complete. Residual norm = 9.619902208157782e-09
PCG iteration 500 complete. Residual norm = 1.5657322549600394e-08
PCG iteration 600 complete. Residual norm = 1.5950593734714857e-08
PCG iteration 700 complete. Residual norm = 1.1943350308740719e-08
PCG iteration 800 complete. Residual norm = 4.711823820002792e-09
PCG iteration 900 complete. Residual norm = 5.792965551698614e-10
PCG iteration 1000 complete. Residual norm = 2.5352251409316295e-09
PCG iteration 1100 complete. Residual norm = 1.1603052962882998e-09
PCG iteration 1200 complete. Residual norm = 3.487970579030275e-09
PCG iteration 1300 complete. Residual norm = 3.8693689918822827e-10
PCG iteration 1400 complete. Residual norm = 1.1918374065444937e-09
PCG iteration 1500 complete. Residual norm = 2.864116166456654e-10
PCG iteration 1600 complete. Residual norm = 2.071150334570504