<a href="https://colab.research.google.com/github/stephenbeckr/randomized-algorithm-class/blob/master/Demos/demo08_higherAccuracyRegression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# High-accuracy sketched least-squares

Demo of the 
1. Iterative Hessian Sketch (IHS) cf. Pilanci and Wainwright; and of the 
2. preconditioned approaches (BLENDENPIK, LSRN)

These are two methods for obtaining high-accuracy solutions to $\ell_2$ (least-squares) regression problems.

The goal is to approximate the solution of
$$  \min_{x} \| Ax-b \|_2^2 $$
where $A$ is $M \times N$ and we are assuming $M \gg N$.

Code: Stephen Becker, Oct 2021

References:
- "Iterative Hessian Sketch: Fast and Accurate Solution Approximation for Constrained Least-Squares" (Pilanci, Wainwright; JMLR 2016, http://www.jmlr.org/papers/volume17/14-460/14-460.pdf )
- "Blendenpik: Supercharging LAPACK's Least-Squares Solver" (Avron et al. 2010, https://epubs.siam.org/doi/abs/10.1137/090767911); 
- "LSRN: A Parallel Iterative Solver for Strongly Over- or Underdetermined Systems" (Meng et al. 2014, https://epubs.siam.org/doi/abs/10.1137/120866580 )

In [1]:
import numpy as np
import numpy.linalg
from numpy.linalg import norm
from numpy.random import default_rng
rng = default_rng()
from matplotlib import pyplot as plt

import scipy.linalg

# Download sketching code
!wget -q https://raw.githubusercontent.com/stephenbeckr/randomized-algorithm-class/master/Code/sketch.py
import sketch as sk

Set up some problem data

In [2]:
M, N = int(14e4), int(5e2)
# M, N = int(8e4), int(5e2)

# A = G1 @ diag(s) @ (G2 + 0.1 I), with G1 (M x N) and G2 (N x N) standard Gaussian
#   and column scalings s = logspace(0,3,N) running from 1 to 1e3, so A is ill-conditioned.
# Scaling the columns by broadcasting, G1 * s, gives exactly G1 @ np.diag(s) but avoids an
#   O(M N^2) dense multiply by a diagonal matrix. (The random draws happen in the same order.)
A   = ( rng.standard_normal( (M,N) ) * np.logspace(0,3,N) ) @ (
    rng.standard_normal( (N,N) ) + 0.1*np.eye(N) )

x   = rng.standard_normal( (N,1) )   # "true" coefficients
b   = A@x
b   += 0.3*norm(b)/np.sqrt(M)*rng.standard_normal( (M,1) ) # add noise
# (The larger the noise, the worse sketch-to-solve will perform )
# (The larger the noise, the worse sketch-to-solve will perform )

#### Solve via standard direct solver, nothing randomized

In [3]:
print("Solving via classical dense method")
%time xLS, residLS, rank, singVals = np.linalg.lstsq(A,b,rcond=None)

print(f'Condition number of A is {singVals[0]/singVals[-1]:.3e}')

AxLS = A@xLS
# print(f'Relative residual ||Ax-b||/||b|| is {norm(AxLS-b)/norm(b):.2f}')
print(f'Relative residual ||Ax-b||/||b|| is {np.sqrt(residLS[0])/norm(b):.2f}')
# Error metrics for a candidate solution, all measured against the direct solution xLS
def errors(x):
  """ Return (err1, err2, err3) comparing a candidate solution x against xLS:

    err1 = ||Ax - b|| / ||A xLS - b|| - 1   (relative excess in the objective value)
    err2 = ||x - xLS|| / ||xLS||            (relative error in x)
    err3 = ||Ax - A xLS|| / ||A xLS||       (the error used in the IHS analysis)

  Reads the notebook globals A, b, xLS and AxLS.
  """
  # Flatten to 1D first: subtracting an (n,) array from an (n,1) array makes numpy
  #   broadcast to a huge (n,n) array instead of subtracting elementwise.
  Ax       = np.ravel( A@x )
  bFlat    = np.ravel( b )
  AxLSFlat = np.ravel( AxLS )

  objErr = norm( Ax - bFlat ) / norm( AxLSFlat - bFlat ) - 1
  xErr   = norm( np.ravel(x) - np.ravel(xLS) ) / norm( xLS )
  fitErr = norm( Ax - AxLSFlat ) / norm( AxLS )
  return objErr, xErr, fitErr

Solving via classical dense method
CPU times: user 11.8 s, sys: 473 ms, total: 12.3 s
Wall time: 6.68 s
Condition number of A is 2.265e+07
Relative residual ||Ax-b||/||b|| is 0.29


#### Make some sketches

In [38]:
from scipy.sparse.linalg import LinearOperator
from scipy.linalg import clarkson_woodruff_transform
import scipy.fft
# Orthonormal DCT along axis 0, used inside the FJLT sketch.
#   With norm='ortho', scipy's DCT types 2 and 3 are inverses (hence transposes) of each
#   other; Matlab's dct uses type 2. It doesn't really matter which one the sketch uses.
# axis=0 is VERY IMPORTANT!! The default is axis=-1 (the last axis), which transforms
#   the rows instead and gives the wrong answer when applied to a matrix.
def dct( X ):
  """ Orthonormal type-3 DCT of X along axis 0 (each column is transformed separately). """
  return scipy.fft.dct( X, type=3, norm='ortho', axis=0 )

def MyElementwiseMultiply( d, X):
  """ Elementwise product d*X that treats a 1D operand as a column vector.

  Naively, numpy broadcasts an (n,) array against an (n,1) array into an (n,n)
  outer product. Here, when the two operands have different ndim, the 1D
  one (if d is 1D) or X (otherwise) is reshaped into a column first. So
  (n,)*(n,1), (n,1)*(n,) and (n,)*(n,k) all give row-wise scaling. When both
  operands have the same ndim this is just d*X.
  """
  if d.ndim != X.ndim:
    if d.ndim == 1:
      # scale each row of X by the matching entry of d
      return d.reshape(-1,1) * X
    return d * X.reshape(-1,1)
  return d * X

def MySubsample( X, ind):
  """ Return the rows `ind` of X, i.e. X[ind,:], also accepting a 1D X of size (n,).

  Raises ValueError if X is neither 1D nor 2D.
  """
  dims = X.ndim
  if dims not in (1, 2):
    raise ValueError("Expected 1D or 2D array")
  return X[ind] if dims == 1 else X[ind,:]


def FJLT(m, M, rng=np.random.default_rng() ):
  """ Fast Johnson-Lindenstrauss Transform S = sqrt(M/m) * P C D as an m x M LinearOperator.

    D is diagonal with random signs, C is the orthonormal DCT applied along axis 0
    (the notebook's `dct` function), and P keeps m of the M rows, chosen uniformly
    at random without replacement. The sqrt(M/m) factor makes E[S^T S] = I.

    Both S@X (matvec/matmat) and the adjoint S^T Y (rmatvec/rmatmat, hence S.T and
    S.H) are supported. X may be (M,) or have M rows; Y may be (m,) or have m rows.

    Note: the default `rng` is created once, when this function is defined, so calls
    that don't pass `rng` share (and advance) that same generator.
  """
  d   = np.sign( rng.standard_normal(size=M) ) # random +/-1 signs
  ind = rng.choice( M, size=m, replace=False, shuffle=False)
  scale = np.sqrt(M/m)

  def fjltMatMat(X):
    # S X = sqrt(M/m) * (rows `ind` of C (D X))
    return scale*MySubsample( dct( MyElementwiseMultiply(d,X)) , ind)

  def fjltAdjoint(Y):
    # S^T Y = sqrt(M/m) * D C^T P^T Y. P^T zero-fills the unsampled rows, and
    #   C^T = C^{-1} is the orthonormal type-2 DCT (the inverse of type 3).
    Z = np.zeros( (M,) + Y.shape[1:], dtype=np.result_type(Y, float) )
    Z[ind] = Y
    return scale*MyElementwiseMultiply( d, scipy.fft.dct( Z, norm='ortho', type=2, axis=0) )

  S   = LinearOperator( (m,M), matvec = fjltMatMat, matmat = fjltMatMat,
                        rmatvec = fjltAdjoint, rmatmat = fjltAdjoint )
  return S

### Choose a sketch to use
We usually choose the FJLT, but could instead use a Gaussian sketch (if the problem isn't too big) or a CountSketch (if the problem is huge).

In [6]:
%%time
m   = 40*N  # sketch size
print(f"m is {m}, M is {M}, N is {N}")

# Which sketch to use:
#   'FJLT'     -- the usual choice
#   'Gaussian' -- only used when M < 1e4, since it runs out of memory if M is too large
#                 (otherwise we fall back to the FJLT)
#   'Count'    -- CountSketch, if the problem is huge
#                 (scipy.linalg.clarkson_woodruff_transform is another CountSketch option)
# FJLT(m,M,rng), defined above, is a local FJLT implementation that could replace sk.FJLT.
sketchType = 'FJLT'

if sketchType == 'Gaussian' and M < 1e4:
  S   = sk.Gaussian( (m,M) )
  print('Using a Gaussian sketch')
elif sketchType == 'Count':
  S   = sk.Count( (m,M) )
  print('Using a Count sketch')
else:
  S   = sk.FJLT( (m,M) )
  print('Using a FJLT sketch')

# Apply the chosen sketch exactly once, to both A and b
SA  = S@A
Sb  = S@b
print(f'||Sb||/||b|| is {norm(Sb)/norm(b):.4f}')

m is 20000, M is 140000, N is 500
Using a FJLT sketch
||Sb||/||b|| is 0.9970
CPU times: user 3.26 s, sys: 335 ms, total: 3.6 s
Wall time: 3.59 s


In [7]:
def full_sketch( SA, Sb, cond=1e-12,columnVec = True):
  """ Sketch-to-solve: return argmin_x || SA x - Sb ||_2.

  With SA = S@A and Sb = S@b this solves min_x || S(Ax-b) ||_2, using
  scipy's LAPACK 'gelsd' least-squares driver with cutoff `cond`.

  If columnVec is True the result is reshaped to a column (n,1). Otherwise it
  follows the shape of Sb: (n,1) when Sb is (m,1), and (n,) when Sb is (m,).
  """
  solution, *_ = scipy.linalg.lstsq( SA, Sb, cond=cond, lapack_driver='gelsd' )
  return np.reshape( solution, (-1,1) ) if columnVec else solution

def partial_sketch(SA,Atb, printOutput=False, solver=0, reg=0,columnVec = True):
  """ Partial sketch (only the Hessian is sketched): SA should be S@A and Atb should be A.T@b

  Solves min_x ||SAx||_2^2 - 2<x,A^T b>,
  i.e., x = ( (SA)^T SA )^{-1} A^T b
  The number of unknowns n is taken from SA.shape[1].

  Solver choices:
    solver=0  uses scipy.linalg.solve on G = (SA)^T(SA). This is fast
      but less accurate, since forming G squares the condition number of SA,
      so it is recommended for all but the most ill-conditioned problems.
      Set reg>0 (e.g., reg=1e-10) to add a small amount of regularization,
      reg*||G||_2*I (i.e., relative to the largest singular value of G)

    solver=1  uses a pivoted QR decomposition and is more appropriate when
      the matrix is ill-conditioned, but a bit slower.  `reg` has no effect

    solver=2  uses an unpivoted QR decomposition and is a bit faster than
      solver=1.  `reg` has no effect

  If printOutput is True, prints the relative residual of the normal equations
  (and, for solver=0 with reg>0, ||G|| and its condition number).

  Returns x with shape (n,1) if columnVec is True. Otherwise x follows the shape
  convention of Atb: (n,1) if Atb is (n,1), and (n,) if Atb is (n,).
  Raises ValueError if `solver` is not 0, 1 or 2.
  """
  n = SA.shape[1]   # number of unknowns (don't rely on a global N)

  if solver == 0:
    G = SA.T@SA
    if reg is None or reg==0:
      # == Basic approach; fails if SA is ill-conditioned ==
      x = scipy.linalg.solve( G, Atb, assume_a='pos')
    else:
      # == Slightly better for ill-conditioned problems, still not good at all though ==
      normG = norm(G,ord=2)
      if printOutput:
        print(f"||G|| is {normG:.2e} and has condition number {np.linalg.cond(G):.2e}")
      # Add in a bit of regularization:
      x = scipy.linalg.solve( G + reg*normG*np.eye(n), Atb, assume_a='pos')
  elif solver == 1:
    # == Pivoted QR: SA[:,perm] = Q R, so G = (SA)^T SA gives the normal equations
    #   R^T R x[perm] = (A^T b)[perm], solved with two triangular solves.
    R, perm = scipy.linalg.qr( SA, mode='r', pivoting=True )
    R = R[:n,:] # Annoyingly, in mode='r', R is rectangular not square, but 'economic' mode is slow.

    y = scipy.linalg.solve_triangular( R, Atb[perm], trans='T')
    x = np.zeros_like(y)
    x[perm] = scipy.linalg.solve_triangular( R, y, trans='N')
  elif solver == 2:
    # == Same as solver==1 but no pivoting; numpy's qr (unlike scipy's) returns
    #  the thin n x n factor R directly in mode='r'
    R = np.linalg.qr( SA, mode='r')
    y = scipy.linalg.solve_triangular( R, Atb, trans='T')
    x = scipy.linalg.solve_triangular( R, y, trans='N')
  else:
    raise ValueError(f"solver must be 0, 1 or 2, got {solver!r}")

  if printOutput:
    res = norm( SA.T@(SA@x) - Atb )/norm(Atb)
    print(f'Relative residual ||(SA)^T (SA)x - A^T b||/||A^T b|| is {res:.2e}')

  if columnVec:
    return np.reshape( x, (-1,1) ) # make sure it is (n,1) not (n,)
  else:
    # same shape convention as Atb: (n,1) stays (n,1), (n,) stays (n,)
    return x

# IHS (Iterative Hessian Sketch) demo
#### Start solving regression problems with the sketches

The "full sketch" is the standard "sketch-to-solve" method, which is our baseline.  We don't expect it to be very accurate as measured by $\|\hat{x}-x_\text{LS}\|$ unless the data $b$ lies almost entirely in the column space of $A$.

In [9]:
print(f'\nFull sketch')
%time xFull = full_sketch( SA, Sb )
err1, err2, err3 = errors(xFull)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )

print(f'\nPartial sketch')
%time xPartial = partial_sketch( SA, A.T@b, printOutput=True, solver=0)
err1, err2, err3 = errors(xPartial)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )


Full sketch
CPU times: user 1.42 s, sys: 116 ms, total: 1.53 s
Wall time: 849 ms

	Errors are 1.1e-02, 1.6e+02 and 4.4e-02

Partial sketch
Relative residual ||(SA)^T (SA)x - A^T b||/||A^T b|| is 1.08e-11
CPU times: user 429 ms, sys: 138 ms, total: 567 ms
Wall time: 302 ms

	Errors are 1.3e-01, 1.1e+02 and 1.6e-01


In [10]:
k   = 5  # number of iterations for Iterative Hessian Sketch (IHS's own k parameter shadows this)

def IHS(k=5):
  """ Run k iterations of the Iterative Hessian Sketch and print the progress.

  The m rows of the existing sketch SA are split into k disjoint blocks of
  mm = m//k rows. Iteration i uses only block i, rescaled by sqrt(m/mm), so every
  iteration gets its own sketch. Each step solves the partial-sketch problem for
  the current residual bHat = b - A xHat and adds the correction to xHat.

  Reads the notebook globals m, N, SA, A, b, xLS and calls partial_sketch and
  errors. Returns None; the results are only printed.
  """
  mm      = m // k
  rescale = np.sqrt(m/mm)
  xHat    = np.zeros((N,1))
  bHat    = b.copy()  # important!!! bHat is updated in place, and b must not change
  print(f'Iterative Hessian Sketch, dividing {m} total rows into {k} blocks of {mm}')
  for i in range(k):
    first = i*mm
    SAblock = rescale*SA[first:first+mm,:]
    xx = partial_sketch( SAblock, A.T@bHat )
    # Contraction factor: how much of the remaining error A(xLS - xHat) this step leaves
    remaining = A@(xLS-xHat)
    rho = norm( A@xx-remaining )/norm( remaining )
    xHat += xx
    bHat -= A@xx
    err1, err2, err3 = errors(xHat)
    print(f'  Iter {i+1:2d}, contraction factor {rho:.2f}, errors {err1:5.2e}, {err2:5.2e}, {err3:5.2e}')
  print('\n\n')

# Sweep the number of blocks: more blocks means more iterations, but each one uses a
#   smaller sketch of m//k rows (too small a sketch can make the iteration diverge)
for nBlocks in (1, 5, 8, 10, 20):
  IHS(nBlocks)

Iterative Hessian Sketch, dividing 20000 total rows into 1 blocks of 20000
  Iter  1, contraction factor 0.16, errors 1.29e-01, 1.14e+02, 1.56e-01



Iterative Hessian Sketch, dividing 20000 total rows into 5 blocks of 4000
  Iter  1, contraction factor 0.43, errors 7.45e-01, 4.99e+02, 4.27e-01
  Iter  2, contraction factor 0.46, errors 2.00e-01, 4.41e+02, 1.98e-01
  Iter  3, contraction factor 0.45, errors 4.29e-02, 1.13e+02, 8.84e-02
  Iter  4, contraction factor 0.44, errors 8.62e-03, 1.34e+02, 3.93e-02
  Iter  5, contraction factor 0.46, errors 1.86e-03, 4.95e+00, 1.82e-02



Iterative Hessian Sketch, dividing 20000 total rows into 8 blocks of 2500
  Iter  1, contraction factor 0.66, errors 1.43e+00, 2.35e+02, 6.63e-01
  Iter  2, contraction factor 0.65, errors 7.57e-01, 1.21e+03, 4.31e-01
  Iter  3, contraction factor 0.67, errors 3.95e-01, 6.79e+02, 2.91e-01
  Iter  4, contraction factor 0.67, errors 1.92e-01, 4.35e+01, 1.94e-01
  Iter  5, contraction factor 0.62, errors 7.82e-02



  Iter  2, contraction factor 2.29, errors 1.48e+01, 3.72e+03, 4.70e+00




  Iter  3, contraction factor 2.55, errors 3.91e+01, 2.00e+04, 1.20e+01
  Iter  4, contraction factor 2.51, errors 9.94e+01, 3.04e+03, 3.00e+01




  Iter  5, contraction factor 2.24, errors 2.24e+02, 1.57e+05, 6.73e+01
  Iter  6, contraction factor 2.35, errors 5.29e+02, 2.60e+05, 1.58e+02




  Iter  7, contraction factor 1.99, errors 1.06e+03, 6.13e+05, 3.16e+02




  Iter  8, contraction factor 2.23, errors 2.36e+03, 4.77e+05, 7.05e+02
  Iter  9, contraction factor 2.18, errors 5.14e+03, 2.47e+06, 1.54e+03




  Iter 10, contraction factor 2.21, errors 1.14e+04, 5.37e+06, 3.40e+03




  Iter 11, contraction factor 2.13, errors 2.42e+04, 1.16e+07, 7.22e+03




  Iter 12, contraction factor 2.39, errors 5.78e+04, 3.64e+07, 1.73e+04
  Iter 13, contraction factor 2.14, errors 1.24e+05, 2.89e+07, 3.70e+04




  Iter 14, contraction factor 2.06, errors 2.55e+05, 8.42e+07, 7.61e+04




  Iter 15, contraction factor 2.33, errors 5.95e+05, 5.86e+07, 1.78e+05




  Iter 16, contraction factor 2.26, errors 1.34e+06, 2.50e+08, 4.01e+05
  Iter 17, contraction factor 2.52, errors 3.38e+06, 1.57e+09, 1.01e+06




  Iter 18, contraction factor 2.13, errors 7.21e+06, 1.96e+09, 2.15e+06
  Iter 19, contraction factor 2.56, errors 1.85e+07, 2.30e+09, 5.52e+06




  Iter 20, contraction factor 2.32, errors 4.28e+07, 2.16e+10, 1.28e+07





### What happens if we re-use the same sketch in the iterative part?

Our theory doesn't hold here: after the first iteration, the right-hand side $\hat{b} = b - A\hat{x}$ is no longer a fixed constant. It is a random variable that depends on the sketch $S$ we keep re-using.

But maybe it will work??

In [11]:
k   = 10  # number of iterations for Iterative Hessian Sketch
xHat= np.zeros((N,1))
bHat= b.copy()  # keep b itself untouched
print('Iterative Hessian Sketch, RE-USING OLD SKETCHES!! This is off-label usage')
for i in range(k):
  xx = partial_sketch( SA, A.T@bHat ) # full SA matrix, the same one every iteration
  # Contraction factor: how much of the remaining error A(xLS - xHat) this step leaves
  remaining = A@(xLS-xHat)
  rho = norm( A@xx-remaining )/norm( remaining )
  xHat += xx
  # Recompute the residual from scratch (rather than bHat -= A@xx) so that
  #   rounding errors don't accumulate across iterations
  bHat = b - A@xHat
  err1, err2, err3 = errors(xHat)
  print(f'  Iter {i+1:2d}, contraction factor {rho:.2f}, errors {err1:5.2e}, {err2:5.2e}, {err3:5.2e}')

Iterative Hessian Sketch, RE-USING OLD SKETCHES!! This is off-label usage
  Iter  1, contraction factor 0.16, errors 1.29e-01, 1.14e+02, 1.56e-01
  Iter  2, contraction factor 0.24, errors 7.91e-03, 7.09e+01, 3.76e-02
  Iter  3, contraction factor 0.29, errors 6.50e-04, 1.46e+01, 1.08e-02
  Iter  4, contraction factor 0.31, errors 6.33e-05, 6.70e+00, 3.36e-03
  Iter  5, contraction factor 0.33, errors 6.76e-06, 1.63e+00, 1.10e-03
  Iter  6, contraction factor 0.34, errors 7.62e-07, 6.04e-01, 3.69e-04
  Iter  7, contraction factor 0.34, errors 8.90e-08, 1.70e-01, 1.26e-04
  Iter  8, contraction factor 0.35, errors 1.06e-08, 5.84e-02, 4.36e-05
  Iter  9, contraction factor 0.35, errors 1.30e-09, 1.80e-02, 1.52e-05
  Iter 10, contraction factor 0.35, errors 1.60e-10, 6.06e-03, 5.34e-06


# BLENDENPIK/LSRN Sketch-to-precondition

In [16]:
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator, aslinearoperator, 

%time xHat, flag, iter, nrm = lsqr( A, b, show=True, iter_lim=int(1e2))[:4]

err1, err2, err3 = errors(xHat)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )

 
LSQR            Least-squares solution of  Ax = b
The matrix A has   140000 rows  and      500 cols
damp = 0.00000000000000e+00   calc_var =        0
atol = 1.00e-08                 conlim = 1.00e+08
btol = 1.00e-08               iter_lim =      100
 
   Itn      x[0]       r1norm     r2norm   Compatible    LS      Norm A   Cond A
     0  0.00000e+00   5.197e+07  5.197e+07    1.0e+00  1.1e-01
     1  3.96508e-01   3.234e+07  3.234e+07    6.2e-01  5.5e-01   7.4e+06  1.0e+00
     2  5.37958e-01   2.411e+07  2.411e+07    4.6e-01  3.0e-01   1.0e+07  2.3e+00
     3  4.81235e-01   2.043e+07  2.043e+07    3.9e-01  1.8e-01   1.2e+07  3.7e+00
     4  4.77483e-01   1.867e+07  1.867e+07    3.6e-01  1.1e-01   1.4e+07  5.2e+00
     5  4.93412e-01   1.753e+07  1.753e+07    3.4e-01  9.6e-02   1.5e+07  7.0e+00
     6  5.73973e-01   1.677e+07  1.677e+07    3.2e-01  6.6e-02   1.6e+07  9.0e+00
     7  6.12775e-01   1.632e+07  1.632e+07    3.1e-01  6.0e-02   1.7e+07  1.1e+01
     8  6.14310e-01   1.605e

In [18]:
# Report the accuracy and iteration count of the LSQR run above
err1, err2, err3 = errors(xHat)
print( '\n\tErrors are {:.1e}, {:.1e} and {:.1e}'.format(err1, err2, err3) )
print( '\tLSQR took {} iterations'.format(iter) )


	Errors are 1.0e-03, 1.0e+00 and 1.3e-02


In [28]:
# Sketch-to-precondition (BLENDENPIK/LSRN idea): factor the sketch SA = QR and use R as a
#   right preconditioner, i.e. run LSQR on the operator A R^{-1} instead of A.
R = numpy.linalg.qr( SA, mode='r')

def Rinv_f(x):
  """ Apply R^{-1}: solve the upper-triangular system R z = x. """
  return scipy.linalg.solve_triangular( R, x)

def Rinv_t(x):
  """ Apply R^{-T}: solve the transposed system R^T z = x. """
  return scipy.linalg.solve_triangular( R, x, trans='T')

Rinv = LinearOperator( (N,N), matvec=Rinv_f, rmatvec=Rinv_t )

AR = aslinearoperator(A) @ Rinv
AR.shape

(140000, 500)

In [33]:
%time zHat, flag, iter, nrm = lsqr( AR, b, show=True, atol=1e-16,btol=1e-16, iter_lim=int(1e2))[:4]
xHat = Rinv_f(zHat)

err1, err2, err3 = errors(xHat)
print( f'\n\tErrors are {err1:.1e}, {err2:.1e} and {err3:.1e}' )
print( f'\tLSQR took {iter} iterations')

 
LSQR            Least-squares solution of  Ax = b
The matrix A has   140000 rows  and      500 cols
damp = 0.00000000000000e+00   calc_var =        0
atol = 1.00e-16                 conlim = 1.00e+08
btol = 1.00e-16               iter_lim =      100
 
   Itn      x[0]       r1norm     r2norm   Compatible    LS      Norm A   Cond A
     0  0.00000e+00   5.197e+07  5.197e+07    1.0e+00  1.9e-08
     1  7.77406e+06   1.658e+07  1.658e+07    3.2e-01  4.4e-01   1.0e+00  1.0e+00
     2  8.01489e+06   1.491e+07  1.491e+07    2.9e-01  5.1e-02   1.5e+00  2.0e+00
     3  8.02932e+06   1.487e+07  1.487e+07    2.9e-01  6.1e-03   1.8e+00  3.0e+00
     4  8.03676e+06   1.487e+07  1.487e+07    2.9e-01  7.6e-04   2.1e+00  4.1e+00
     5  8.03799e+06   1.487e+07  1.487e+07    2.9e-01  1.0e-04   2.3e+00  5.1e+00
     6  8.03831e+06   1.487e+07  1.487e+07    2.9e-01  1.3e-05   2.5e+00  6.1e+00
     7  8.03833e+06   1.487e+07  1.487e+07    2.9e-01  1.8e-06   2.7e+00  7.1e+00
     8  8.03833e+06   1.487e