# Testing out the ID algorithm of Bhaskara et al.
Stephen Becker, March 6 2023

"Residual Based Sampling for Online Low Rank Approximation", by Aditya Bhaskara, Silvio Lattanzi, Sergei Vassilvitskii and Morteza Zadimoghaddam. FOCS, 2019

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
import h5py
from numpy.random import default_rng
import scipy.linalg as sli
import scipy.io as sio

%reload_ext autoreload
%autoreload 2
import rID

We'll make a Python "class" for stored columns for the ID.  You can add columns to the class (and their index), and the class will maintain a QR decomposition of the columns

This particular class has three "lengths" associated with it:
1. `k_max` is the maximum number of columns
2. `k_current` is the current number of columns that have been added to it
3. `k_previous` <= `k_current` is the number of columns in the "previous" set (using Bhaskara et al.'s notation). By default, projection is done with respect to the "previous" set
    - Calling the method `.mergePreviousWithCurrent()` will include *all* the columns into the "previous" set

In [198]:
class stored_columns:
    """Set of selected columns (and their indices) with an incrementally updated orthonormal basis.

    Three "lengths" are tracked:
      k_max       upper limit on the number of stored columns
      k_current   number of columns stored so far
      k_previous  (<= k_current) number of leading columns in the "previous" set;
                  `project` uses only these unless called with full=True

    Parameters
    ----------
    k   : maximum number of columns to store (default: unlimited)
    tol : relative tolerance used by `addColumn` to detect a column that is
          (numerically) zero or already in the span of the stored columns
    """

    def __init__(self, k = np.inf, tol = 1e-12):
        # NOTE: np.inf rather than np.Inf; the capitalized alias was removed in NumPy 2.0
        self.indices = []   # indices of the columns; of size k_current
        self.k_max   = k    # upper limit on # of columns
        self.tol     = tol  # relative threshold for rejecting dependent columns
        self.k_current  = 0
        self.k_previous = 0 # <= k_current
        self.columns = []   # the columns; becomes an (m, k_current) array once a column is added
        self.Q = []         # orthonormal version of the columns; same shape as self.columns

    def mergePreviousWithCurrent(self):
        """ Includes *all* currently stored columns in the "previous" set """
        self.k_previous = self.k_current

    def addColumn(self, newCol, newInd):
        """ Adds a column (and its index) to the set, updating the orthonormal basis Q.

        Nothing is added (and nothing is raised) if
          - the set already holds k_max columns, or
          - the column is zero / numerically in the span of the stored columns,
            since normalizing its residual would put NaN or noise into Q.
        """
        if self.k_current >= self.k_max:
            # do nothing. Should we give a warning??
            return

        col      = np.ravel(newCol)   # work with shape (m,), whether given (m,) or (m,1)
        col_norm = sli.norm(col)

        # Gram-Schmidt step against *all* stored columns: x <- x - QQ'x
        if self.k_current == 0:
            residual = col            # nothing to orthogonalize against yet
        else:
            residual = col - self.project(col, full=True)
        res_norm = sli.norm(residual)

        if res_norm <= self.tol * col_norm:
            # zero or linearly dependent column: would cause a division by (almost) zero below
            return

        q = np.reshape(residual / res_norm, (-1,1))
        if self.k_current == 0:
            self.columns = np.reshape(col, (-1,1))  # ensure size (m,1), not (m,)
            self.Q       = q
        else:
            self.Q       = np.column_stack( (self.Q, q) )
            self.columns = np.column_stack( (self.columns, col) )
        self.indices.append( newInd )
        self.k_current += 1

    # Todo: allow adding multiple columns at once: do X - project, then do QR
    # Todo: allow a re-orthogonalization routine to account for loss in precision

    def project(self, X, justQtX = False, full=False):
        """ Projects X into the range of chosen columns

        By default only the first k_previous columns are used; pass full=True to
        use all k_current columns.

        Returns Q*Q'*X, or just Q'*X if justQtX is True (saves some time if you
        only need the Euclidean/Frobenius norm of the projection).

        If no columns are available the projection is all zeros (same shape as X);
        with justQtX=True a single row of zeros is returned, so its norm is 0.
        """
        k = self.k_current if full else self.k_previous
        if k == 0:
            if justQtX:
                # one zero row per column of X (or a single zero if X is 1-D)
                return np.zeros( (1,) + np.shape(X)[1:] )
            return np.zeros_like(X)

        Qk  = self.Q[:,:k]
        QtX = Qk.T @ X
        if justQtX:
            return QtX  # for efficient norm computations
        return Qk @ QtX


#### Load dataset, make error metric

In [118]:
A    = rID.load_JHTDB_data(which_component="x",nsample=64,data_name="channel")
nrmX = sli.norm(A) # Frobenius norm

def relError(S):
    """Relative Frobenius-norm error of approximating A by its projection onto the columns in S.

    S must provide `project(A, justQtX=True)` returning Q'A for an orthonormal Q.
    Since ||A - QQ'A||_F^2 = ||A||_F^2 - ||Q'A||_F^2, only the norm of Q'A is needed.
    The projection uses S.project's default column set.
    """
    captured = sli.norm( S.project(A,justQtX=True) )/nrmX
    # Rounding can push captured**2 slightly above 1 when the columns capture
    # (nearly) all of A; clamp at 0 so the result is 0 rather than NaN.
    return np.sqrt( max(0.0, 1 - captured**2) )

# Another equivalent (slower) way to compute error:
# C, residues, rnk, singA = sli.lstsq( S.columns, A)
# print( np.sqrt(np.sum(residues))/nrmX ) # same as sli.norm( S.columns@C - A )/nrmX

#### Algorithm #3 from the Bhaskara et al. paper

In [203]:
def rID_BhaskaraAlgo3(A, k=10, xi=0.05, rng=None, SA = None):
    """
    Description: Randomized column selection for an ID using residual based sampling, in one pass over the columns.
    Follows Algorithm 3 of Bhaskara et al. "Residual Based Sampling for Online Low Rank Approximation" (FOCS, 2019)

    A is a m x n matrix, where we "stream" over the columns
    k is the number of columns to save (at most k columns are kept)
    xi is a target value of the Frobenius norm error (*absolute* error, not relative)
    rng is a numpy random Generator, an integer seed, or None (fresh, unpredictably seeded generator)

    Optional: SA is a l x n matrix, a sketched version of A, which we do assume we can store in memory
      (For now, we don't use that... later, when we start making *coefficients*, we will use that)

    Returns a `stored_columns` object holding the chosen columns, their indices and an
    orthonormal basis for them; all chosen columns are in its "previous" set on return.

    March 2023
    """
    # A default of `rng=default_rng()` would be created once, at definition time, and
    # then shared by every call. default_rng() returns an existing Generator unaltered,
    # and builds a new one from None or a seed.
    rng = default_rng(rng)

    _, n  = np.shape(A)
    scale = k/160./xi   # loop-invariant factor of the sampling probability
    sigma = 0.          # accumulated sampling probability since the last merge
    S     = stored_columns(k)

    for column_index in range(n):
        u = A[:,column_index]

        # residual of u w.r.t. the "previous" set of columns only
        u_perp = u - S.project(u)
        pu     = scale * sli.norm(u_perp)**2

        # keep the column with probability min(pu, 1)
        if rng.random() <= min(pu, 1):
            S.addColumn( u, column_index )
        if pu < 1:
            sigma += pu
        if pu >= 1 or sigma >= 1:
            # start a new phase: everything chosen so far is now projected against
            sigma = 0
            S.mergePreviousWithCurrent()
    S.mergePreviousWithCurrent()
    return S

#### Run their algorithm:

In [201]:
# Problem size: m rows, n columns (the columns are what we stream over)
m, n = A.shape
print(m, n)

# Select (at most) k columns with the residual-based sampling algorithm
k = 10
S = rID_BhaskaraAlgo3(A, xi=10, k=k)

# Which columns were selected?
print(S.indices)

# Error of the best approximation using the selected columns
# (computing it requires a second pass through the data)
print(f"Relative Frobenius norm error is {100*relError(S):.2f}%")

4096 1000
[0, 58, 123, 182, 276, 292, 322, 341, 405, 468]
Relative Frobenius norm error is 1.71%


#### Sanity check: choose columns at random
Is this better, worse, or about the same as the fancy algorithm?

Conclusion: about the same, or even slightly better

(re-run this several times, since the results are random)

In [202]:
rng = default_rng()
# Draw k *distinct* column indices. Generator.choice samples with replacement by
# default, and a repeated column has zero residual against the columns already
# stored, which breaks the normalization step of the QR update.
randInd = rng.choice( n, size=k, replace=False )
S_rand  = stored_columns(k)
for i in randInd:
    S_rand.addColumn( A[:,i], i )
# make project() (and hence relError) use all of the chosen columns
S_rand.mergePreviousWithCurrent()

# S.columns
print( S_rand.indices )

# Find optimal error (would require a second pass through data)
print(f"Relative Frobenius norm error is {100*relError(S_rand):.2f}%" )
# S.Q

[977, 259, 844, 358, 263, 62, 53, 319, 430, 334]
Relative Frobenius norm error is 1.76%
