<a href="https://colab.research.google.com/github/stephenbeckr/randomized-algorithm-class/blob/master/Demos/demo08_higherAccuracyRegression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# High-accuracy sketched least-squares

Demo of the 
1. Iterative Hessian Sketch (IHS) cf. Pilanci and Wainwright; and of the 
2. preconditioned approaches (BLENDENPIK, LSRN)

These are two ways to obtain high-accuracy solutions to $\ell_2$ (least-squares) regression problems.

The goal is to approximate the solution of
$$  \min_{x} \| Ax-b \|_2^2 $$
where $A$ is $M \times N$ and we are assuming $M \gg N$.

Code: Stephen Becker, Oct 2021

References:
- "Iterative Hessian Sketch: Fast and Accurate Solution
Approximation for Constrained Least-Squares" (Pilanci, Wainwright; JMLR 2016
http://www.jmlr.org/papers/volume17/14-460/14-460.pdf )
- "Blendenpik: Supercharging LAPACK's Least-Squares Solver" (Avron et al. 2010, https://epubs.siam.org/doi/abs/10.1137/090767911); 
- "LSRN: A Parallel Iterative Solver for Strongly Over- or Underdetermined Systems" (Meng et al. 2014, https://epubs.siam.org/doi/abs/10.1137/120866580 )

In [None]:
import numpy as np
import numpy.linalg
from numpy.linalg import norm
from numpy.random import default_rng
rng = default_rng()
from matplotlib import pyplot as plt

import scipy.linalg

# Download sketching code
!wget -q https://raw.githubusercontent.com/stephenbeckr/randomized-algorithm-class/master/Code/sketch.py
import sketch as sk

Set up some problem data

In [None]:
M, N = int(14e4), int(5e2)
# M, N = int(8e4), int(5e2)

# A = G @ diag(logspace(0,3,N)) @ (H + 0.1 I), where G is M x N and H is N x N,
#   both with standard Gaussian entries. The diagonal factor scales the columns
#   by values from 1 to 1e3, so A is deliberately badly scaled.
A   = rng.standard_normal( (M,N) )@np.diag(np.logspace(0,3,N))@(
    rng.standard_normal((N,N) ) + 0.1*np.eye(N) )

x   = rng.standard_normal( (N,1) )   # "true" coefficient vector, stored as a column
b   = A@x
# Per-entry noise std is 0.3*||b||/sqrt(M), so the noise vector has norm roughly 0.3*||b||
b   += 0.3*norm(b)/np.sqrt(M)*rng.standard_normal( (M,1) ) # add noise
# (The larger the noise, the worse sketch-to-solve will perform )

#### Solve via standard direct solver, nothing randomized

In [None]:
print("Solving via classical dense method")
# Reference solution xLS from numpy's dense (non-randomized) least-squares solver.
# lstsq also returns the squared residual norm(s), the numerical rank, and the
# singular values of A.
%time xLS, residLS, rank, singVals = np.linalg.lstsq(A,b,rcond=None)

# Ratio of first to last singular value (assumes they are returned in
#   decreasing order, as LAPACK does — TODO confirm for other backends)
print(f'Condition number of A is {singVals[0]/singVals[-1]:.3e}')

AxLS = A@xLS
# print(f'Relative residual ||Ax-b||/||b|| is {norm(AxLS-b)/norm(b):.2f}')
# residLS holds the squared 2-norm of the residual for each column of b (one column here)
print(f'Relative residual ||Ax-b||/||b|| is {np.sqrt(residLS[0])/norm(b):.2f}')
# xLS and AxLS are the reference against which the error metrics below are measured
def errors(x):
  """Return three accuracy metrics for a candidate solution x of min ||Ax-b||.

  All three compare against the direct-solver reference xLS (with AxLS = A@xLS):
    err1 = ||Ax-b|| / ||A xLS - b|| - 1     relative excess in the objective
    err2 = ||x - xLS|| / ||xLS||            relative error in x
    err3 = ||Ax - A xLS|| / ||A xLS||       relative error in Ax (the IHS metric)

  Relies on the notebook globals A, b, xLS and AxLS.
  Returns the tuple (err1, err2, err3).
  """
  # Flatten to 1-D before subtracting: an (n,) minus an (n,1) array would
  # broadcast to an (n,n) matrix instead of giving a vector difference.
  flat_b, flat_AxLS = np.ravel(b), np.ravel(AxLS)
  flat_Ax = np.ravel(A@x)

  residual_ratio = norm(flat_Ax - flat_b) / norm(flat_AxLS - flat_b)
  relerr_x       = norm(np.ravel(x) - np.ravel(xLS)) / norm(xLS)
  relerr_Ax      = norm(flat_Ax - flat_AxLS) / norm(AxLS)
  return residual_ratio - 1, relerr_x, relerr_Ax

Solving via classical dense method
CPU times: user 13 s, sys: 539 ms, total: 13.5 s
Wall time: 7.36 s
Condition number of A is 2.811e+05
Relative residual ||Ax-b||/||b|| is 0.29


### Choose a sketch to use
We usually choose the FJLT, but a Gaussian sketch is fine if the problem isn't too big, and a CountSketch is preferable if the problem is huge.

In [None]:
%%time
m   = 40*N  # sketch size
print(f"m is {m}, M is {M}, N is {N}")

if M < 1e4 and False:
  # This runs out of memory if M is too large
  S   = sk.Gaussian( (m,M) )
  print('Using a Gaussian sketch')

elif False:
  # == Use a count-sketch:
  S   = sk.Count( (m,M) )
  print('Using a Count sketch')

else:
  # == ... or try a FJLT ...
  S   = sk.FJLT( (m,M) )
  print('Using a FJLT sketch')

SA  = S@A
Sb  = S@b
print(f'||Sb||/||b|| is {norm(Sb)/norm(b):.4f}')

m is 20000, M is 140000, N is 500
Using a FJLT sketch
||Sb||/||b|| is 1.0031
CPU times: user 3.43 s, sys: 210 ms, total: 3.64 s
Wall time: 3.65 s


In [None]:
def full_sketch(SA, Sb, cond=1e-12, columnVec=True):
  """Sketch-to-solve least squares.

  SA should be S@A and Sb should be S@b. Returns the minimizer of
  || S(Ax-b) ||_2, computed with SciPy's SVD-based least-squares driver
  ('gelsd'). Singular values below cond * (largest singular value) are
  treated as zero.

  If columnVec is True, the result is reshaped to a column of shape (n,1).
  Otherwise it follows Sb's shape convention: (n,1) for an (m,1) Sb and
  (n,) for an (m,) Sb.
  """
  solution, *_ = scipy.linalg.lstsq(SA, Sb, cond=cond, lapack_driver='gelsd')
  return solution.reshape(-1, 1) if columnVec else solution

def partial_sketch(SA, Atb, printOutput=False, solver=0, reg=0, columnVec=True):
  """ SA should be S@A and Atb should be A.T@b
  Solves min_x ||SAx||_2^2 - 2<x,A^T b>,
  i.e., x = ( (SA)^T SA )^{-1} A^T b

  Solver choices:
    solver=0  uses scipy.linalg.solve on (SA)^T(SA), which is fast
      but less accurate since it squares the condition number of SA,
      so it is recommended for all but the most ill-conditioned problems.
      Set reg>0 (e.g., reg=1e-10) to add a small amount of regularization,
      reg*||(SA)^T(SA)||_2 * I (i.e., relative to the largest eigenvalue).

    solver=1  uses a pivoted QR decomposition and is more appropriate when
      the matrix is ill-conditioned, but a bit slower.  `reg` has no effect

    solver=2  uses an unpivoted QR decomposition and is a bit faster than
      solver=1.  `reg` has no effect

  The QR solvers assume SA has at least as many rows as columns, so that
  its R factor is square.

  printOutput=True prints the relative residual of the normal equations
  (and, for solver=0 with reg>0, the norm and condition number of (SA)^T SA).

  Returns x with shape (n,1) if columnVec is True; otherwise x follows
  Atb's shape convention ((n,1) for an (n,1) Atb, (n,) for an (n,) Atb).

  Raises ValueError if `solver` is not 0, 1 or 2.
  """
  # Problem size comes from SA itself, not from a notebook-level global N
  n = SA.shape[1]

  if solver == 0:
    G = SA.T@SA
    if reg is None or reg==0:
      # == Basic normal-equations solve; fails if SA is badly ill-conditioned ==
      x = scipy.linalg.solve( G, Atb, assume_a='pos')
    else:
      # == Slightly better for ill-conditioned, still not good at all though ==
      normG = norm(G, ord=2)
      if printOutput:
        print(f"||G|| is {normG:.2e} and has condition number {np.linalg.cond(G):.2e}")
      # Add in a bit of regularization, scaled by ||G||_2:
      x = scipy.linalg.solve( G + reg*normG*np.eye(n), Atb, assume_a='pos')

  elif solver == 1:
    # == Avoid forming G: with SA[:,perm] = QR we have G = P R^T R P^T,
    #    so solve R^T y = Atb[perm], then R z = y, and un-permute z into x.
    R, perm = scipy.linalg.qr( SA, mode='r', pivoting=True )
    R = R[:n,:] # in mode='r', R is rectangular (rows x n), not square; 'economic' mode is slow.

    y = scipy.linalg.solve_triangular( R, Atb[perm], trans='T')
    x = np.zeros_like(y)
    x[perm] = scipy.linalg.solve_triangular( R, y, trans='N')

  elif solver == 2:
    # == Same as solver==1 but no pivoting; numpy's qr gives the thin
    #    (n x n) R directly, though it doesn't support pivoting.
    R = np.linalg.qr( SA, mode='r')
    y = scipy.linalg.solve_triangular( R, Atb, trans='T')
    x = scipy.linalg.solve_triangular( R, y, trans='N')

  else:
    raise ValueError(f"solver must be 0, 1 or 2, got {solver!r}")

  if printOutput:
    res = norm( SA.T@(SA@x) - Atb )/norm(Atb)
    print(f'Relative residual ||(SA)^T (SA)x - A^T b||/||A^T b|| is {res:.2e}')

  if columnVec:
    return np.reshape( x, (-1,1) ) # make sure it is (n,1) not (n,)
  else:
    # x has the same shape convention as Atb: if Atb is (n,1) then x is
    #   (n,1), and if Atb is (n,) then x is (n,)
    return x

# IHS (Iterative Hessian Sketch) demo
#### Start solving regression problems with the sketches

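The "full sketch" is the standard "sketch-to-solve" approach, which is our baseline method.  We don't expect it to be very accurate in terms of $\|\hat{x}-x_\text{LS}\|$ unless the data $b$ lies almost entirely in the column space of $A$. The "partial sketch" sketches only the Hessian: it solves $\min_x \|SAx\|_2^2 - 2\langle x, A^Tb\rangle$ using the exact $A^Tb$, and it is the building block of IHS.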

In [None]:
# Baseline 1: sketch-to-solve, min_x ||S(Ax-b)||, using the sketched data Sb
print(f'\nFull sketch')
%time xFull = full_sketch( SA, Sb )
err1, err2, err3 = errors(xFull)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )

# Baseline 2: partial sketch, which sketches only (SA)^T(SA) and uses the exact A^T b
print(f'\nPartial sketch')
%time xPartial = partial_sketch( SA, A.T@b, printOutput=True, solver=0)
err1, err2, err3 = errors(xPartial)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )


Full sketch
CPU times: user 1.55 s, sys: 133 ms, total: 1.68 s
Wall time: 962 ms

	Errors are 1.1e-02, 6.4e+00 and 4.4e-02

Partial sketch
Relative residual ||(SA)^T (SA)x - A^T b||/||A^T b|| is 1.59e-13
CPU times: user 491 ms, sys: 140 ms, total: 631 ms
Wall time: 339 ms

	Errors are 1.3e-01, 1.8e+01 and 1.6e-01


In [None]:
k   = 5  # NOTE(review): not read by IHS() below, which takes k as its own argument (default 5)

def IHS(k=5):
  """Run and report k iterations of the Iterative Hessian Sketch (IHS).

  The existing sketch SA (m rows) is split into k disjoint row blocks of
  mm = m//k rows. Iteration i uses only block i, rescaled by sqrt(m/mm), as
  its sketched Hessian. It solves the partial-sketch problem for the current
  residual bHat = b - A@xHat and adds that correction to xHat.
  If k does not divide m, the leftover m - k*mm rows of SA are unused.

  Each iteration prints:
    - the contraction factor ||A(xx - (xLS - xHat))|| / ||A(xLS - xHat)||,
      i.e. how much the error in A-norm shrank in that step;
    - the three metrics returned by `errors`.
  Nothing is returned.

  Relies on the notebook globals m, N, A, b, SA, xLS, partial_sketch and errors.
  Raises ValueError unless 1 <= k <= m (otherwise a block would be empty).
  """
  if not 1 <= k <= m:
    raise ValueError(f'k must satisfy 1 <= k <= m = {m}, got {k}')
  mm  = m // k
  xHat= np.zeros((N,1))
  bHat= b.copy()  # important!!! bHat is updated in place below; without a copy, b would be overwritten
  print(f'Iterative Hessian Sketch, dividing {m} total rows into {k} blocks of {mm}')
  for i in range(k):
    SA_block = np.sqrt(m/mm)*SA[i*mm:(i+1)*mm,:]
    xx   = partial_sketch( SA_block, A.T@bHat )
    Axx  = A@xx              # reused for both rho and the residual update
    Aerr = A@(xLS-xHat)      # current error, mapped through A
    rho  = norm( Axx-Aerr )/norm( Aerr )
    xHat += xx
    bHat -= Axx
    err1, err2, err3 = errors(xHat)
    print(f'  Iter {i+1:2d}, contraction factor {rho:.2f}, errors {err1:5.2e}, {err2:5.2e}, {err3:5.2e}')
  print('\n\n')

# Same total sketch budget m each time, split into k blocks: a larger k gives
#   more iterations, but each iteration's sketch has only m//k rows.
IHS(1)

IHS(5)

IHS(8)

IHS(10)

IHS(20)

Iterative Hessian Sketch, dividing 20000 total rows into 1 blocks of 20000
  Iter  1, contraction factor 0.16, errors 1.29e-01, 1.81e+01, 1.57e-01



Iterative Hessian Sketch, dividing 20000 total rows into 5 blocks of 4000
  Iter  1, contraction factor 0.47, errors 8.74e-01, 3.00e+01, 4.74e-01
  Iter  2, contraction factor 0.44, errors 2.15e-01, 7.34e+01, 2.07e-01
  Iter  3, contraction factor 0.46, errors 4.86e-02, 9.08e+00, 9.44e-02
  Iter  4, contraction factor 0.39, errors 7.58e-03, 8.74e+00, 3.69e-02
  Iter  5, contraction factor 0.45, errors 1.52e-03, 4.96e+00, 1.65e-02



Iterative Hessian Sketch, dividing 20000 total rows into 8 blocks of 2500
  Iter  1, contraction factor 0.69, errors 1.52e+00, 1.06e+02, 6.94e-01
  Iter  2, contraction factor 0.67, errors 8.42e-01, 2.79e+01, 4.63e-01
  Iter  3, contraction factor 0.59, errors 3.59e-01, 2.90e+01, 2.75e-01
  Iter  4, contraction factor 0.71, errors 1.95e-01, 5.53e+01, 1.96e-01
  Iter  5, contraction factor 0.63, errors 8.28e-02

### What happens if we re-use the same sketch in the iterative part?

Our theory no longer applies, since after the first iteration the right-hand side $\hat{b} = b - A\hat{x}$ is no longer fixed: it is a random variable that depends on the sketch $S$.

But maybe it will work??
- Actually, this idea (or a variant) is in [Faster Least Squares Optimization](https://arxiv.org/abs/1911.02675) by Lacotte and Pilanci, 2019
- See also the related conference paper [Optimal Randomized First-Order Methods for Least-Squares Problems](http://proceedings.mlr.press/v119/lacotte20a.html) by Lacotte and Pilanci, ICML 2020

In [None]:
k   = 10  # number of iterations for Iterative Hessian Sketch
xHat= np.zeros((N,1))
bHat= b.copy()  # current residual b - A@xHat (xHat starts at 0)
print('Iterative Hessian Sketch, RE-USING OLD SKETCHES!! This is off-label usage')
for i in range(k):
  xx   = partial_sketch( SA, A.T@bHat ) # full SA matrix, the same one every iteration
  Aerr = A@(xLS-xHat)                   # current error, mapped through A
  rho  = norm( A@xx-Aerr )/norm( Aerr )
  xHat += xx
  # Recompute the residual from scratch rather than updating it incrementally
  #   (bHat -= A@xx), so rounding errors don't accumulate across iterations.
  bHat = b - A@xHat
  err1, err2, err3 = errors(xHat)
  print(f'  Iter {i+1:2d}, contraction factor {rho:.2f}, errors {err1:5.2e}, {err2:5.2e}, {err3:5.2e}')

Iterative Hessian Sketch, RE-USING OLD SKETCHES!! This is off-label usage
  Iter  1, contraction factor 0.16, errors 1.29e-01, 1.81e+01, 1.57e-01
  Iter  2, contraction factor 0.24, errors 7.94e-03, 3.35e+00, 3.78e-02
  Iter  3, contraction factor 0.28, errors 6.37e-04, 1.01e+00, 1.07e-02
  Iter  4, contraction factor 0.31, errors 6.01e-05, 3.40e-01, 3.28e-03
  Iter  5, contraction factor 0.32, errors 6.24e-06, 1.25e-01, 1.06e-03
  Iter  6, contraction factor 0.33, errors 6.91e-07, 4.42e-02, 3.52e-04
  Iter  7, contraction factor 0.34, errors 7.97e-08, 1.66e-02, 1.20e-04
  Iter  8, contraction factor 0.34, errors 9.48e-09, 5.96e-03, 4.12e-05
  Iter  9, contraction factor 0.35, errors 1.15e-09, 2.19e-03, 1.44e-05
  Iter 10, contraction factor 0.35, errors 1.43e-10, 7.89e-04, 5.06e-06


# BLENDENPIK/LSRN Sketch-to-precondition

Let's start by using a standard linear solver for least squares, [`lsqr`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html)

In [None]:
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator, aslinearoperator

%time xHat, flag, iter, nrm = lsqr( A, b, show=True, iter_lim=int(1e2))[:4]

err1, err2, err3 = errors(xHat)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )

 
LSQR            Least-squares solution of  Ax = b
The matrix A has   140000 rows  and      500 cols
damp = 0.00000000000000e+00   calc_var =        0
atol = 1.00e-08                 conlim = 1.00e+08
btol = 1.00e-08               iter_lim =      100
 
   Itn      x[0]       r1norm     r2norm   Compatible    LS      Norm A   Cond A
     0  0.00000e+00   4.841e+07  4.841e+07    1.0e+00  1.2e-01
     1  2.63044e-01   2.663e+07  2.663e+07    5.5e-01  5.3e-01   7.1e+06  1.0e+00
     2  2.76601e-01   2.120e+07  2.120e+07    4.4e-01  2.6e-01   9.7e+06  2.2e+00
     3  3.78793e-01   1.844e+07  1.844e+07    3.8e-01  1.9e-01   1.1e+07  3.6e+00
     4  4.53280e-01   1.716e+07  1.716e+07    3.5e-01  1.2e-01   1.3e+07  5.2e+00
     5  4.84351e-01   1.626e+07  1.626e+07    3.4e-01  1.0e-01   1.5e+07  7.0e+00
     6  4.62640e-01   1.560e+07  1.560e+07    3.2e-01  7.2e-02   1.7e+07  9.1e+00
     7  4.91588e-01   1.516e+07  1.516e+07    3.1e-01  4.7e-02   1.8e+07  1.1e+01
     8  4.97109e-01   1.493e

Now let's precondition.  We use the `R` from the thin `QR` decomposition of the *sketched* matrix $SA$.

Then, we want to solve the system
$$
\min_z \| AR^{-1}z - b \|_2^2
$$
where we've done the change-of-variables $x=R^{-1}z$
so after solving the system for $z$, we do one final conversion back to $x$.

We need to give `scipy` a linear operator that can multiply $x\mapsto AR^{-1}x$, which is easy using the `LinearOperator` class.

In [None]:
# Preconditioner: upper-triangular R from a thin QR of the sketched matrix SA
#   (N x N, since SA has m = 40*N rows)
%time R = numpy.linalg.qr( SA, mode='r')
Rinv_f = lambda x : scipy.linalg.solve_triangular( R, x)             # x -> R^{-1} x
Rinv_t = lambda x : scipy.linalg.solve_triangular( R, x, trans='T')  # x -> R^{-T} x (adjoint)
Rinv = LinearOperator((N,N), matvec = Rinv_f, rmatvec = Rinv_t)

# Composite operator x -> A R^{-1} x; the product is applied lazily, never formed as a matrix
AR = aslinearoperator(A)@Rinv
AR.shape

CPU times: user 1.34 s, sys: 206 ms, total: 1.55 s
Wall time: 880 ms


(140000, 500)

### Solving the preconditioned system
Now we solve via `lsqr` and see if it converges more quickly

In [None]:
%time zHat, flag, iter, nrm = lsqr( AR, b, show=True, atol=1e-16,btol=1e-16, iter_lim=10)[:4]
xHat = Rinv_f(zHat)

err1, err2, err3 = errors(xHat)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )
print( f'\tLSQR took {iter} iterations')

 
LSQR            Least-squares solution of  Ax = b
The matrix A has   140000 rows  and      500 cols
damp = 0.00000000000000e+00   calc_var =        0
atol = 1.00e-16                 conlim = 1.00e+08
btol = 1.00e-16               iter_lim =       10
 
   Itn      x[0]       r1norm     r2norm   Compatible    LS      Norm A   Cond A
     0  0.00000e+00   4.841e+07  4.841e+07    1.0e+00  2.0e-08
     1  5.95586e+06   1.545e+07  1.545e+07    3.2e-01  4.4e-01   1.0e+00  1.0e+00
     2  5.88502e+06   1.392e+07  1.392e+07    2.9e-01  5.2e-02   1.5e+00  2.0e+00
     3  5.91503e+06   1.388e+07  1.388e+07    2.9e-01  5.8e-03   1.8e+00  3.0e+00
     4  5.92222e+06   1.388e+07  1.388e+07    2.9e-01  7.1e-04   2.1e+00  4.1e+00
     5  5.92247e+06   1.388e+07  1.388e+07    2.9e-01  9.3e-05   2.3e+00  5.1e+00
     6  5.92267e+06   1.388e+07  1.388e+07    2.9e-01  1.3e-05   2.5e+00  6.1e+00
     7  5.92271e+06   1.388e+07  1.388e+07    2.9e-01  1.7e-06   2.7e+00  7.1e+00
     8  5.92271e+06   1.388e

In [None]:
# Find the condition number of the preconditioned operator A R^{-1}.
# This forms the dense M x N matrix explicitly, so it may be slow...
AR_explicit = AR@np.eye(N)
cnd         = np.linalg.cond( AR_explicit )
# Doubled braces print a literal "{-1}"; a single {-1} would be evaluated as an expression
print(f'Condition number of AR^{{-1}} is {cnd:.2e}')

Condition number of AR^-1 is 1.34e+00


### Repeat using the Count Sketch
Let's see how fast this is, timing the whole pipeline (sketch, QR, and LSQR)

In [None]:
%%time 
S   = sk.Count( (m,M) )
R   = numpy.linalg.qr( S@A, mode='r')

Rinv_f = lambda x : scipy.linalg.solve_triangular( R, x)
Rinv_t = lambda x : scipy.linalg.solve_triangular( R, x, trans='T')
Rinv   = LinearOperator((N,N), matvec = Rinv_f, rmatvec = Rinv_t)
AR     = aslinearoperator(A)@Rinv


zHat, flag, iter, nrm = lsqr( AR, b, show=False,iter_lim=7)[:4]
xHat = Rinv_f(zHat)

err1, err2, err3 = errors(xHat)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )
print( f'\tLSQR took {iter} iterations')


	Errors are 4.2e-11, 7.0e-04 and 2.7e-06
	LSQR took 7 iterations
CPU times: user 3.16 s, sys: 176 ms, total: 3.33 s
Wall time: 1.92 s
