In [1]:
import os

import numpy as np
from RK_bin import *
import matplotlib.pyplot as plt
import matplotlib as mpl

from IPython.display import clear_output
%load_ext autoreload
%autoreload 2

In [2]:
# Backend switch to pgf is left disabled; only the rc settings below apply.
#mpl.use("pgf")

# Figure text settings: LaTeX-rendered text, serif family, 8pt, with
# pdflatex named as the TeX system for pgf output.
latex_rc = {
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'font.size': 8,
    'text.usetex': True,
    'pgf.rcfonts': False,
}
mpl.rcParams.update(latex_rc)

In [3]:
# Problem dimensions: the matrices built from U and V below are n-by-d.
d = 100
n = 1000000

# Seed the global NumPy RNG so the two Gaussian draws below are reproducible.
np.random.seed(0)

# Keep only the Q factor of each QR factorisation (R is discarded):
#   U: from an n-by-d Gaussian matrix (with the default reduced mode this is
#      n-by-d with orthonormal columns),
#   V: from a d-by-d Gaussian matrix (a d-by-d orthogonal matrix).
# The order of the two draws matters for reproducibility.
U,_ = np.linalg.qr(np.random.randn(n,d))
V,_ = np.linalg.qr(np.random.randn(d,d))

In [4]:
# Reset the global NumPy RNG. Nothing in this cell draws random numbers, so
# this only fixes the RNG state that the following cells start from.
np.random.seed(0)

# (κ, k_max) pairs. κ is the largest entry of every spectrum built from it;
# k_max is stored unchanged under 'k_max' (used later as the iteration count).
κ_kmax = [(100,200),(30,110),(6,45)]

grid = np.linspace(0,1,d)        # d evenly spaced points from 0 to 1
rev_index = d-np.arange(d)-1     # d-1, d-2, ..., 0

# Each family maps κ to a length-d spectrum Λ whose first entry is 1 and
# whose last entry is κ; the families differ in how the entries in between
# are distributed.
spectrum_families = [
    ('fast exponential', lambda κ: 1+grid*(κ-1)*.1**rev_index),
    ('slow exponential', lambda κ: 1+grid*(κ-1)*.8**rev_index),
    ('fast algebraic',   lambda κ: grid**2*(κ-1)+1),
    ('slow algebraic',   lambda κ: grid**1*(κ-1)+1),
]

# One experiment per (κ, family), grouped by κ and ordered within each group
# as in `spectrum_families`.
exprs = []
for κ,k_max in κ_kmax:
    exprs.extend(
        {'name': label, 'Λ': make_Λ(κ), 'k_max': k_max}
        for label,make_Λ in spectrum_families
    )

In [5]:
# Run every experiment in `exprs`: one HBM run, then `n_trials` minibatch-HBM
# runs for each batch-size scale in `B_scales`.
#
# Results (all indexed by experiment number i):
#   err_HBM[i]           : per-row 2-norm of (HBM output - x_opt)
#   err_HBM_mb_all[i][s] : (n_trials, k_max) array, row j holding the same
#                          quantity for trial j at scale B_scales[s]
#   params[i]            : tuple returned by get_params(A, c=1e3)
# NOTE(review): the rows of the HBM / minibatch_HBM outputs are presumably the
# successive iterates (k_max of them) -- confirm against RK_bin.
err_HBM = []
err_HBM_mb_all = []
params = []

n_trials = 100

# Fractions of the reference batch size B to test. Defined once here (it does
# not depend on the experiment); the plotting cell reads it as well.
B_scales = [1e0,1e-1,1e-2,1e-3]

for i,expr in enumerate(exprs):

    k_max = expr['k_max']

    # Build A = U Σ V with Σ = diag(sqrt(Λ)), and a consistent right-hand side
    # b = A x_opt so that x_opt is an exact solution.
    Σ = np.diag(np.sqrt(expr['Λ']))
    A = U@Σ@V
    x_opt = np.random.randn(d)
    b = A@x_opt

    params.append(get_params(A,c=1e3))
    λ,ℓ,L,κ,κ_,α,β,η,κC = params[i]

    # Reference batch size, capped at 1e6.
    # NOTE(review): the cap equals n, the number of rows of A -- confirm that
    # this is the intent rather than a coincidence.
    B = getB(None,α,β,λ,κC,mode='approx')
    if B>1e6:
        print('B too large')
        B=int(1e6)

    x_HBM = HBM(A,b,k_max,α,β)
    err_HBM.append(np.linalg.norm(x_HBM - x_opt[None,:],axis=1))

    err_HBM_mb = []
    for scale in B_scales:

        batch_size = int(B*scale)
        err_HBM_mb_expr = np.zeros((n_trials,k_max))

        # Restart the global RNG from seed 0 for every scale.
        np.random.seed(0)
        for j in range(n_trials):

            # Progress line; clear_output(wait=True) defers the clear until
            # the next output arrives, so the line is replaced in place.
            print(f'{i}: {scale}, {batch_size}, trial {j}/{n_trials}')
            clear_output(wait=True)

            x_HBM_mb = minibatch_HBM(A,b,k_max,α,β,batch_size,sampling='row_norm')
            err_HBM_mb_expr[j] = np.linalg.norm(x_HBM_mb - x_opt[None,:],axis=1)

        err_HBM_mb.append(err_HBM_mb_expr)

    err_HBM_mb_all.append(err_HBM_mb)

11: 0.001, 178, trial 99/100


In [6]:
# Persist the results and immediately read them back, so the plotting cell
# works from exactly what is on disk.
results_path = 'data/B_dependence.npy'

# Make sure the output directory exists before writing the results.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

# The two result lists are ragged (k_max differs between experiments).
# Passing the nested list straight to np.save relies on implicit ragged-array
# creation, which raises ValueError on NumPy >= 1.24. Store each list as one
# element of an explicit 1-D object array instead.
results = np.empty(2,dtype=object)
results[0] = err_HBM
results[1] = err_HBM_mb_all
np.save(results_path,results,allow_pickle=True)

# allow_pickle=True is required to read object arrays; only load files that
# were produced by this notebook (unpickling untrusted files can run code).
err_HBM,err_HBM_mb_all = np.load(results_path,allow_pickle=True)

In [7]:
# Per-curve styling; entry j of each list is used for the j-th batch-size
# scale in the plotting cell.
colors = [
    '#1f4c84',   # dark blue
    '#800000',   # maroon
    '#976100',   # brown
    '#007f00',   # green
]
line_styles = [
    '--',                        # dashed
    '-.',                        # dash-dot
    (0, (4, 2, 1, 1, 1, 2)),     # custom (offset, on/off sequence) dash pattern
    ':',                         # dotted
]

In [23]:
# Figure: normalised error against iteration, one panel per experiment.
# Reads `exprs`, `params`, `B_scales`, `err_HBM`, `err_HBM_mb_all`, `colors`
# and `line_styles` from the cells above and writes imgs/B_dependence.pdf.

fig,axs = plt.subplots(4,3,figsize=(8,7),sharex='col',sharey=True)
# Transpose before flattening so consecutive experiments run down a column
# (panels 0-3 in the first column, 4-7 in the second, 8-11 in the third).
axs = axs.T.flatten()

fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=.095,
                    wspace=.05, hspace=.3)

# The shaded band spans the σ and 1-σ quantiles over trials.
σ = .1

for i,expr in enumerate(exprs):

    k_max = expr['k_max']
    iters = np.arange(k_max)

    λ,ℓ,L,κ,κ_,α,β,η,κC = params[i]

    # Same reference batch size as in the experiment cell, including the same
    # integer cap, so the title never shows the float 1000000.0.
    B = getB(None,α,β,λ,κC,mode='approx')
    B = min(B,int(1e6))

    # To overlay the reference curve sqrt(2)*κC*sqrt(β)**k, plot
    # np.sqrt(2)*κC*np.sqrt(β)**iters here; it is not drawn by default.

    # HBM error, normalised by its first entry.
    axs[i].plot(iters,err_HBM[i]/err_HBM[i][0],label='HBM',
                color='k',ls='-',lw=1)

    for j,scale in enumerate(B_scales):

        # Normalise every trial by the first entry of trial 0.
        # NOTE(review): presumably all trials share the same starting error;
        # confirm against minibatch_HBM.
        rel_err = err_HBM_mb_all[i][j]/err_HBM_mb_all[i][j][0,0]
        lower,median,upper = np.quantile(rel_err,[σ,.5,1.-σ],axis=0)

        axs[i].plot(iters,median,
                    color=colors[j],ls=line_styles[j],label=f'$B={scale}B^*$')
        axs[i].fill_between(iters,lower,upper,alpha=.15,
                            color=colors[j],ls=line_styles[j])

    axs[i].set_title(f'$\\kappa = {κ:1.0f}$, $\\bar{{\\kappa}} = {κ_:1.2f}$, $B^*={B}$')

    # Only the bottom panel of each column gets an x-axis label.
    if i%4==3:
        axs[i].set_xlabel('iteration: $k$')

    axs[i].set_yscale('log')

# sharey=True propagates these limits to every panel.
axs[0].set_ylim(1e-16,1e2)

# Single legend, anchored below the bottom panel of the middle column.
axs[7].legend(loc='upper center', bbox_to_anchor=(.5,-.33), ncol=7)

# Make sure the output directory exists, then save and release the figure.
os.makedirs('imgs',exist_ok=True)
fig.savefig('imgs/B_dependence.pdf')
plt.close(fig)