In [1]:
import random
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.func import vmap
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import math

In [2]:
# set seeds
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [3]:
class SketchySGD(Optimizer):
    """Implements SketchySGD. We assume that there is only one parameter group to optimize.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        rank (int): sketch rank
        rho (float): regularization
        lr (float): learning rate
        weight_decay (float): weight decay parameter
        hes_update_freq (int): how frequently we should update the Hessian approximation
        momentum (float): momentum parameter
        proportional (bool): option to maintain lr to rho ratio, even when lr decays
        chunk_size (int): number of Hessian-vector products to compute in parallel
                          if set to None, binary search will be used to find the maximally allowed value
    """
    def __init__(self, params, rank = 100, rho = 0.1, lr = 0.01, weight_decay = 0.0,
                 hes_update_freq = 100, momentum = 0.0, proportional = False, 
                 chunk_size = None, line_search_fn = None, verbose = False):
        # initialize the optimizer    
        defaults = dict(rank = rank, rho = rho, lr = lr, weight_decay = weight_decay, 
                        hes_update_freq = hes_update_freq, proportional = proportional,
                        chunk_size = chunk_size, momentum = momentum, line_search_fn = line_search_fn)
        self.rank = rank
        self.hes_update_freq = hes_update_freq
        self.proportional = proportional
        self.chunk_size = chunk_size
        self.ratio = rho / lr
        self.hes_iter = 0
        self.U = None
        self.S = None
        self.counter = 0
        self.momentum = momentum
        self.momentum_buffer = None
        self.line_search_fn = line_search_fn
        self.large_step_size_index = []
        self.verbose = verbose
        super(SketchySGD, self).__init__(params, defaults)
    
    def step(self, closure = None):
        loss = None
        grad_tuple = None
        if closure is not None:
            with torch.enable_grad():
                loss, grad_tuple = closure()

        # update Hessian approximation, if needed
        g = torch.cat([gradient.view(-1) for gradient in grad_tuple if gradient is not None])
        if self.hes_iter % self.hes_update_freq == 0:
            params = []

            for group in self.param_groups:
                for p in group['params']:
                    params.append(p)

            # update preconditioner
            self._update_preconditioner(params, g)

        g = g.detach()

        # update momentum buffer
        if self.momentum_buffer is None: 
            self.momentum_buffer = g
        else:
            self.momentum_buffer = self.momentum * self.momentum_buffer + g

        # one step update
        for group_idx, group in enumerate(self.param_groups):
            lr = group['lr']
            weight_decay = group['weight_decay']

            # Adjust rho to be proportional to lr, if necessary
            if self.proportional:
                rho = lr * self.ratio
            else:
                rho = group['rho']

            # compute gradient as a long vector
            # g = torch.cat([p.grad.view(-1) for p in group['params'] if p.grad is not None]) # only get gradients if they exist!
            # calculate the search direction by Nystrom sketch and solve
            UTg = torch.mv(self.U.t(), self.momentum_buffer) 
            g_new = torch.mv(self.U, (self.S + rho).reciprocal() * UTg) + self.momentum_buffer / rho - torch.mv(self.U, UTg) / rho
            
            # use backtracking line search to find an appropriate step-size
            step_size = lr
            direction = g_new.detach().neg()
            if self.line_search_fn is not None: 
                # possibly reduce the initial step size for the very first iteration
                if self.hes_iter == 0: 
                    step_size = min(1., 1. / g.abs().sum()) * lr
                # get a copy of current param values (as evaluating loss requires overwriting current param)
                current_params = self._clone_param(group_idx)
                # compute the dot product of gradient
                grad_dir_prod = g.dot(direction).item()
                # define the objective/loss evaluation function
                def obj_func(current_params, step_size, direction): 
                    # set new param values
                    self._add_grad(group_idx, step_size, direction)
                    # obtain objective/loss
                    obj_value, grad = closure()
                    flat_grad = torch.cat([gradient.view(-1) for gradient in grad if gradient is not None]).detach()
                    # revert back to the original/current param values
                    self._set_param(group_idx, current_params)

                    # return obj_value
                    return float(obj_value), flat_grad
                # search for the sufficient decrease step-size
                if self.line_search_fn == "backtracking": 
                    step_size = self._backtracking(obj_func, current_params, step_size, direction, loss.item(), grad_dir_prod, use_interpolation=True)
                elif self.line_search_fn == "strong_wolfe": 
                    step_size = self._strong_wolfe(obj_func, current_params, step_size, direction, loss.item(), g, grad_dir_prod)
                else: 
                    raise Exception(f'Line search function \"{self.line_search_fn}\" is not supported.')
            # store step-size in state dict
            self.state[group_idx]['step_size'] = step_size
            # update model parameters with either fixed or found step-size
            self._add_grad(group_idx, step_size, direction)
        
        self.hes_iter += 1

        return loss
    
    def _update_preconditioner(self, params, gradsH):
        p = gradsH.shape[0]
        # Generate test matrix (NOTE: This is transposed test matrix)
        Phi = (torch.randn(self.rank, p) / (p ** 0.5)).to(params[0].device)
        
        if self.chunk_size is None: 
            self._set_chunk_size(params, gradsH, Phi)

        # Calculate sketch (NOTE: This is transposed sketch)
        Y = self._hvp_vmap(gradsH, params)(Phi)

        # Calculate shift
        shift = torch.finfo(Y.dtype).eps * 10
        Y_shifted = Y + shift * Phi
        # Calculate Phi^T * H * Phi (w/ shift) for Cholesky
        choleskytarget = torch.mm(Y_shifted, Phi.t())
        # Perform Cholesky, if fails, do eigendecomposition
        # The new shift is the abs of smallest eigenvalue (negative) plus the original shift
        try:
            C = torch.linalg.cholesky(choleskytarget)
        except:
            # eigendecomposition, eigenvalues and eigenvector matrix
            eigs, eigvectors = torch.linalg.eigh(choleskytarget)
            shift = shift + torch.abs(torch.min(eigs))
            # add shift to eigenvalues
            eigs = eigs + shift
            # print(eigs)
            # put back the matrix for Cholesky by eigenvector * eigenvalues after shift * eigenvector^T 
            C = torch.linalg.cholesky(torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)))

        try: 
            B = torch.linalg.solve_triangular(C, Y_shifted, upper = False, left = True)
        # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211
        except: 
            B = torch.linalg.solve_triangular(C.to('cpu'), Y_shifted.to('cpu'), upper = False, left = True).to(C.device)
        _, S, UT = torch.linalg.svd(B, full_matrices = False) # B = V * S * U^T b/c we have been using transposed sketch
        self.U = UT.t()
        self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0))
        
        if self.verbose: 
            # print low-rank Hessian approximation (without rho)
            print(f'Hessian Approximation: {torch.mm(torch.mm(self.U, torch.diag(self.S)), self.U.t())}')

    def _hvp_vmap(self, grad_params, params):
        return vmap(lambda v: hvp(grad_params, params, v), in_dims = 0, chunk_size=self.chunk_size)

    """
    Helper function for finding the maximally allowed chunck_size. 

    INPUT: 
    - params: ...
    - gradsH: ...
    - Phi: ...
    - safety_margin: float; free / total GPU memory ratio -- if the free memory is lower than the margin, 
                     then this suggests actual chunk size should be set at a multiplicative factor of the found value
    - safety_margin_factor: integer; multiplicative factor to use when the free memory is low
    """
    def _set_chunk_size(self, params, gradsH, Phi, safety_margin=0.05, safety_margin_factor=0.95): 
        # start with the rank
        self.chunk_size = self.rank
        # set bounds for the search
        max_size = self.rank
        min_size = 1
        while(True): 
            # update lower bound if attempted computation was successful
            try: 
                self._hvp_vmap(gradsH, params)(Phi)
                min_size = self.chunk_size
                # search range has converged to a single point
                if max_size - min_size <= 1: 
                    # grab memory information
                    free_mem, total_mem = torch.cuda.mem_get_info()
                    if free_mem / total_mem < safety_margin: 
                        min_size = int(safety_margin_factor * min_size)
                    # create some safety margin (e.g. 95% of the found size)
                    self.chunk_size = max(1, min_size)
                    torch.cuda.empty_cache()
                    break
            # update upper bound if attempted computation ran out of memory
            except RuntimeError as e:
                if str(e).startswith('CUDA out of memory.') and self.chunk_size > 1:
                    max_size = self.chunk_size
                    torch.cuda.empty_cache()
                # terminate if other runtime error occurred or chunk_size = 1 still ran out of memory
                else: 
                    raise e
            # halve the search range
            self.chunk_size = int(0.5 * (min_size + max_size))
        # report final chunk size
        print(f'SketchySGD: chunk size has been set to {self.chunk_size}.')

    """
    Helper function for performing the backtracking line search (Armijo rule).

    INPUT: 
    - current_params: Tensor; long flattened vector of the model parameters
    - step_size: Tensor (of size 1); initial step-size (learning rate) for the step
    - direction: Tensor; long flattened vector of the update direction
    - obj_value: Tensor (of size 1); objective value (loss) at the current value of the parameters
    - grad_dir_prod: Tensor (of size 1); result of the dot product of the gradient of the objective function and direction
    - c1: float; constant used in the evaluation of the sufficient decrease condition
    - alpha: float; constant multiplicative factor used to decrease step-size after each unsuccessful search
    - max_ls: integer; maximum number of line searches
    OUTPUT: 
    - step_size: Tensor; resulting step-size from the search
    """
    def _backtracking(self, obj_func, current_params, step_size, direction, obj_value, grad_dir_prod, c1=1e-4, alpha=0.5, max_ls=30, use_interpolation=True): 
        # compute objective function with the initial step-size
        obj_value_new = obj_func(current_params, step_size, direction)[0]
        # start line search
        ls_iter = 0
        while ls_iter < max_ls:
            # evaluate the sufficient decrease condition
            if obj_value_new > (obj_value + c1 * step_size * grad_dir_prod): 
                # find step size using quadratic interpolation
                if use_interpolation: 
                    step_size = self._quadratic_interpolate(obj_value, obj_value_new, grad_dir_prod, step_size)
                # decrease step-size by a constant multiplicative factor alpha
                else: 
                    step_size = alpha * step_size        
                # compute objective function with the new step-size
                obj_value_new = obj_func(current_params, step_size, direction)[0]
                ls_iter += 1
            # otherwise the condition is satisfied
            else: 
                break

        return step_size

    """
    Helper function for making a copy of the params in the given param group. 

    INPUT: 
    - group_idx: integer; index of the param group
    OUTPUT: 
    - step_size: list
    """
    def _clone_param(self, group_idx):
        return [p.clone(memory_format=torch.contiguous_format) for p in self.param_groups[group_idx]['params']]

    """
    Helper function for updating param values in the given param group.
    Specifically, x <- x + step_size * (update + weight_decay + x)
    Here we use the weight decay, and it is not the same as L2 regularization.

    INPUT: 
    - group_idx: integer; index of the param group
    - step_size: float; step-size
    - update: list; long Tensor representing the update direction
    """
    def _add_grad(self, group_idx, step_size, update):
        weight_decay = self.param_groups[group_idx]['weight_decay']
        offset = 0
        for p in self.param_groups[group_idx]['params']:
            numel = p.numel()
            p.data.add_(update[offset:offset + numel].view_as(p) + weight_decay * p.data, alpha=step_size)
            offset += numel

    """
    Helper function for assigning value to params in the given param group.
    Specifically, x <- params_data.

    INPUT: 
    - group_idx: integer; index of the param group
    - params_data: list; long Tensor representing the value to assign
    """
    def _set_param(self, group_idx, params_data):
        for p, pdata in zip(self.param_groups[group_idx]['params'], params_data):
            p.data.copy_(pdata)
    
        """
    Helper function for performing quadratic interpolation. 
    Specifically, compute x_hat_min = (-g1 * (x ** 2)) / (2 * (f2 - f1 - g1 * x)) > 0
    If x_hat_min < x_lower_bound, then x_new = x_lower_bound
    If x_lower_bound <= x_hat_min <= x_upper_bound, then x_new = x_hat_min
    If x_hat_min > x_upper_bound, then x_new = x_upper_bound

    INPUT: 
    - f1: float; function value at initial x, f(x_0)
    - f2: float; function value at current x, f(x)
    - g1: float; gradient at initial x, grad(f(x_0))
    - x: float; current x
    - bounds: tuple of floats; lower and upper bounds of the new x
    OUTPUT: 
    - x_new: float; value of the x in the bound that minimizes the interpolation
    """
    def _quadratic_interpolate(self, f1, f2, g1, x, bounds=None): 
        if bounds is not None:
            x_lower_bound, x_upper_bound = bounds
        else: 
            x_lower_bound = 0.5 * x
            x_upper_bound = 0.95 * x

        x_hat_min = (-g1 * (x ** 2)) / (2 * (f2 - f1 - g1 * x))

        return min(max(x_hat_min, x_lower_bound), x_upper_bound)
    
    def _cubic_interpolate(self, f1, f2, g1, g2, x1, x2, bounds=None): 
        if bounds is not None:
            xmin_bound, xmax_bound = bounds
        else:
            xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

        d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
        d2_square = d1**2 - g1 * g2
            
        if d2_square >= 0:
            d2 = d2_square.sqrt()
            if x1 <= x2:
                min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
            else:
                min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
            return min(max(min_pos, xmin_bound), xmax_bound)
        else:
            return (xmin_bound + xmax_bound) / 2.

    def _strong_wolfe(self, obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=30): 
        d_norm = d.abs().max()
        g = g.clone(memory_format=torch.contiguous_format)
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals = 1
        gtd_new = g_new.dot(d)
        
        if math.isinf(f_new) or math.isnan(f_new) or torch.isinf(gtd_new) or torch.isnan(gtd_new): 
            self.large_step_size_index.append(self.hes_iter)
            t = t / g.abs().sum()
            f_new, g_new = obj_func(x, t, d)
            gtd_new = g_new.dot(d)

        t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
        done = False
        ls_iter = 0
        while ls_iter < max_ls:
            if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
                bracket = [t_prev, t]
                bracket_f = [f_prev, f_new]
                bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
                bracket_gtd = [gtd_prev, gtd_new]
                break

            if abs(gtd_new) <= -c2 * gtd:
                bracket = [t]
                bracket_f = [f_new]
                bracket_g = [g_new]
                done = True
                break

            if gtd_new >= 0:
                bracket = [t_prev, t]
                bracket_f = [f_prev, f_new]
                bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
                bracket_gtd = [gtd_prev, gtd_new]
                break

            min_step = t + 0.01 * (t - t_prev)
            max_step = t * 10
            tmp = t
            t = self._cubic_interpolate(
                t_prev,
                f_prev,
                gtd_prev,
                t,
                f_new,
                gtd_new,
                bounds=(min_step, max_step))

            t_prev = tmp
            f_prev = f_new
            g_prev = g_new.clone(memory_format=torch.contiguous_format)
            gtd_prev = gtd_new
            f_new, g_new = obj_func(x, t, d)
            ls_func_evals += 1
            gtd_new = g_new.dot(d)
            ls_iter += 1

        if ls_iter == max_ls:
            bracket = [0, t]
            bracket_f = [f, f_new]
            bracket_g = [g, g_new]

        insuf_progress = False
        low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
        while not done and ls_iter < max_ls:
            if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: break

            t = self._cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

            eps = 0.1 * (max(bracket) - min(bracket))
            if min(max(bracket) - t, t - min(bracket)) < eps:
                # interpolation close to boundary
                if insuf_progress or t >= max(bracket) or t <= min(bracket):
                    # evaluate at 0.1 away from boundary
                    if abs(t - max(bracket)) < abs(t - min(bracket)):
                        t = max(bracket) - eps
                    else:
                        t = min(bracket) + eps
                    insuf_progress = False
                else:
                    insuf_progress = True
            else:
                insuf_progress = False

            f_new, g_new = obj_func(x, t, d)
            ls_func_evals += 1
            gtd_new = g_new.dot(d)
            ls_iter += 1

            if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
                bracket[high_pos] = t
                bracket_f[high_pos] = f_new
                bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
                bracket_gtd[high_pos] = gtd_new
                low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
            else:
                if abs(gtd_new) <= -c2 * gtd:
                    done = True
                elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                    # old high becomes new low
                    bracket[high_pos] = bracket[low_pos]
                    bracket_f[high_pos] = bracket_f[low_pos]
                    bracket_g[high_pos] = bracket_g[low_pos]
                    bracket_gtd[high_pos] = bracket_gtd[low_pos]

                # new point becomes new low
                bracket[low_pos] = t
                bracket_f[low_pos] = f_new
                bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
                bracket_gtd[low_pos] = gtd_new

        t = bracket[low_pos]
        f_new = bracket_f[low_pos]
        g_new = bracket_g[low_pos]
        
        return t

def hvp(grad_params, params, v):
    Hv = torch.autograd.grad(grad_params, params, grad_outputs = v,
                              retain_graph = True)
    Hv = tuple(Hvi.detach() for Hvi in Hv)
    return torch.cat([Hvi.reshape(-1) for Hvi in Hv])

def group_product(xs, ys):
    return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)])

def normalize(v):
    s = torch.sqrt(group_product(v, v))
    v = [x / (s + 1e-6) for x in v]
    return v

In [4]:
class LSQ(torch.nn.Module):

    def __init__(self, n_features):
        super(LSQ, self).__init__()
        self.w = torch.nn.Linear(n_features, 1, bias=False)

    def forward(self, x):
        return self.w(x)

In [5]:
# define experiment parameters
n_train = 5000
n_test = 500
n_features = 100
n_iters = 50

weight = np.random.normal(size=n_features)

Xtrain = np.random.normal(size = (n_train, n_features))
ytrain = (Xtrain @ weight)[: , np.newaxis]

Xtest = np.sort(np.random.normal(size = (n_test, n_features)))
ytest = (Xtest @ weight)[: , np.newaxis]

print(f'True Hessian: {Xtrain.T @ Xtrain / n_train}')

True Hessian: [[ 0.98961967  0.00881748 -0.0118267  ... -0.01783493  0.00830881
   0.01661189]
 [ 0.00881748  1.01844719  0.0044552  ...  0.00497486 -0.00325788
  -0.01623477]
 [-0.0118267   0.0044552   1.03757619 ... -0.00380699  0.00220931
   0.00585819]
 ...
 [-0.01783493  0.00497486 -0.00380699 ...  1.02608195  0.00346571
  -0.0090292 ]
 [ 0.00830881 -0.00325788  0.00220931 ...  0.00346571  0.95896859
  -0.02342711]
 [ 0.01661189 -0.01623477  0.00585819 ... -0.0090292  -0.02342711
   0.99756846]]


In [6]:
model = LSQ(n_features)

# specify optimizer
# optimizer = torch.optim.LBFGS(model.parameters(), line_search_fn='strong_wolfe', lr=1.0)
optimizer = SketchySGD(model.parameters(), lr=1.0, rank=100, rho=1e-3, chunk_size=5, hes_update_freq=1, momentum=0.0, line_search_fn='strong_wolfe', verbose=True)

loss_hist = []
step_size_hist = []

Xt = torch.tensor(Xtrain, dtype=torch.float)
yt = torch.tensor(ytrain, dtype=torch.float)

torch.nn.init.zeros_(model.w.weight)

loss_function = nn.MSELoss()

for i in range(n_iters):
    model.train()
    
    def closure(): 
        optimizer.zero_grad()
        output = model(Xt)
        loss = 0.5 * loss_function(output, yt)
        if isinstance(optimizer, SketchySGD): 
            grad_tuple = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            return loss, grad_tuple   
        loss.backward()
        return loss
    
    optimizer.step(closure)
    
    # record step size taken (if using linear search)
    cur_step_size = None
    if isinstance(optimizer, SketchySGD) and optimizer.state_dict()['param_groups'][0]['line_search_fn'] is not None: 
        cur_step_size = optimizer.state_dict()['state'][0]['step_size']
    if isinstance(optimizer, torch.optim.LBFGS) and optimizer.state_dict()['param_groups'][0]['line_search_fn'] is not None: 
        cur_step_size = optimizer.state_dict()['state'][0]['t']
    
    model.eval()
    output = model(Xt)
    loss = 0.5 * loss_function(output, yt).item()
    loss_hist.append(loss)
    if cur_step_size is not None: 
        step_size_hist.append(cur_step_size)

Hessian Approximation: tensor([[ 0.9888,  0.0089, -0.0130,  ..., -0.0182,  0.0089,  0.0170],
        [ 0.0089,  1.0184,  0.0045,  ...,  0.0050, -0.0033, -0.0163],
        [-0.0130,  0.0045,  1.0359,  ..., -0.0043,  0.0031,  0.0064],
        ...,
        [-0.0182,  0.0050, -0.0043,  ...,  1.0259,  0.0037, -0.0088],
        [ 0.0089, -0.0033,  0.0031,  ...,  0.0037,  0.9585, -0.0237],
        [ 0.0170, -0.0163,  0.0064,  ..., -0.0088, -0.0237,  0.9974]])
Hessian Approximation: tensor([[ 0.9892,  0.0089, -0.0114,  ..., -0.0175,  0.0081,  0.0161],
        [ 0.0089,  1.0184,  0.0044,  ...,  0.0049, -0.0032, -0.0161],
        [-0.0114,  0.0044,  1.0372,  ..., -0.0041,  0.0024,  0.0063],
        ...,
        [-0.0175,  0.0049, -0.0041,  ...,  1.0258,  0.0036, -0.0086],
        [ 0.0081, -0.0032,  0.0024,  ...,  0.0036,  0.9589, -0.0237],
        [ 0.0161, -0.0161,  0.0063,  ..., -0.0086, -0.0237,  0.9969]])
Hessian Approximation: tensor([[ 0.9896,  0.0089, -0.0118,  ..., -0.0179,  0.0083,  0.

In [7]:
# compute test loss
with torch.no_grad():
    Xttest = torch.tensor(Xtest, dtype=torch.float)
    yttest = torch.tensor(ytest, dtype=torch.float)
    output = model(Xttest)
    loss =  0.5 * loss_function(output, yttest)
print(f'Test Loss: {loss}')

Test Loss: 4.632312328861632e-13


In [8]:
# make plot
set_matplotlib_formats('pdf')

fig = plt.figure()

ax1 = fig.add_subplot()
step_line = ax1.plot([float(i) for i in step_size_hist], label='step size', color='C4', alpha=0.85, marker='o', markersize=2, linestyle='solid', linewidth=0.35)
ax1.set_ylabel('step size')
ax1.set_xlabel('iteration')

ax2 = plt.twinx()
loss_line = ax2.semilogy(loss_hist, label='loss', alpha=0.5)
ax2.set_ylabel('loss')

lines = step_line + loss_line
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, loc='center left', bbox_to_anchor=(1.15, 0.5))
optimizer_name = "SketchySGD" if isinstance(optimizer, SketchySGD) else "L-BFGS"
ax1.set_title(f'Least Squares / {optimizer_name}')

Text(0.5, 1.0, 'Least Squares / SketchySGD')

<Figure size 432x288 with 2 Axes>

In [9]:
# print out loss history
loss_hist

[49.9944953918457,
 0.00011550551425898448,
 2.5517304935718244e-10,
 4.060598428681783e-13,
 3.969303184747641e-13,
 3.968911787763374e-13,
 3.9634615034422893e-13,
 3.954164469813226e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92796960322489e-13,
 3.92