# *AI-SARAH* (Algorithm 2 in the main paper)

In [None]:
# Notebook setup: make the parent directory importable (presumably where the
# Sparse_Init package lives), import dependencies, and configure PyTorch.
import os
import sys

# Append the parent of the current working directory to the module search
# path so that `Sparse_Init` can be imported (assumes the notebook is launched
# from its own folder, one level below the package — TODO confirm).
path = os.getcwd()
parent_path = os.path.abspath(os.path.join(path, os.pardir))
sys.path.append(parent_path)

# NOTE(review): random, pprint, shutil, genfromtxt, nn, optim, F, Variable,
# itertools, lin, cProfile and pstats are not referenced in the visible cells,
# and `os` is imported twice. They are left in place; verify against the rest
# of the project before removing any of them.
import random

import pprint as pp
import numpy as np
import time
import os
import shutil
from numpy import genfromtxt
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

# Default floating-point dtype for newly created tensors is float64.
torch.set_default_dtype(torch.float64)
torch.set_num_threads(1) # cpu num: restrict PyTorch CPU intra-op parallelism to one thread

import itertools
import numpy.linalg  as lin

import cProfile, pstats

from collections import OrderedDict

# Project modules, imported with wildcards. SparseData and ConvexModel used in
# later cells presumably come from here; which module defines which name is
# not visible in this notebook. Order matters: later wildcard imports can
# shadow names from earlier ones.
from Sparse_Init.sparseinit import *    
from Sparse_Init.sparsedata import *
from Sparse_Init.sparsemodule_v2 import * # here, for AI-SARAH, user should use version 2 implementation
# Row-wise (per-sample) normalization of the CSR feature matrices in the data cell.
from sklearn.preprocessing import normalize

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Configuration

In [None]:
# Experiment configuration.
algo = 'ai_sarah'    # algorithm name (part of the log-folder path)
dname = 'news20'     # dataset name
BS = 64              # mini-batch size
StrongConvex = True  # True: add L2 regularization (strongly convex problem)
# Problem case label, used in the log-folder path and to pick the budget.
case = 'reg' if StrongConvex else 'non_reg'

# Load data - the datasets must first be downloaded from the LIBSVM website

In [None]:
# Data directory: download the LIBSVM dataset into this folder before
# executing this notebook.
datafolder = '../Data/'+dname+'/'
# Directory for the per-seed log files - optional
logfolder = '../Logs/'+dname+'/'+case+'/'+algo+'/'
# to compare with other algorithms that have all hyper-parameters, please use the following log folder
# logfolder = '../AllLogs/'+dname+'/'+case+'/'+algo+'/'

os.makedirs(logfolder, exist_ok=True)

# LIBSVM file names per dataset (files must be downloaded from the LIBSVM
# website). Datasets shipped as one file are loaded with `file=`; datasets
# shipped with a separate test split are loaded with `trfile=`/`tefile=`.
SINGLE_FILE_DATASETS = {
    'covtype': 'covtype.libsvm.binary.scale.bz2',
    'news20': 'news20.binary.bz2',
    'real-sim': 'real-sim.bz2',
}
TRAIN_TEST_DATASETS = {
    'ijcnn1': ('ijcnn1.bz2', 'ijcnn1.t.bz2'),
    'rcv1': ('rcv1_train.binary.bz2', 'rcv1_test.binary.bz2'),
}

# Choose the loader from the dataset name instead of by trial and error, so
# that genuine read errors (e.g. a missing file) are reported as they are
# rather than being swallowed by a catch-all fallback.
if dname in SINGLE_FILE_DATASETS:
    file = datafolder+SINGLE_FILE_DATASETS[dname]
    data = SparseData(dname,device,file=file)
    csr = data.read()
    # L2-normalize each row in place (sklearn defaults: norm='l2', axis=1);
    # assumes csr[0] is the samples-by-features CSR matrix — TODO confirm.
    normalize(csr[0],copy=False)
    data.load(_csr=csr)
elif dname in TRAIN_TEST_DATASETS:
    trname, tename = TRAIN_TEST_DATASETS[dname]
    trfile = datafolder+trname
    tefile = datafolder+tename
    data = SparseData(dname,device,trfile=trfile,tefile=tefile)
    train_csr, test_csr = data.read()
    # normalize train and test rows the same way, in place
    normalize(train_csr[0],copy=False)
    normalize(test_csr[0],copy=False)
    data.load(_trainCSR=train_csr,_testCSR=test_csr)
else:
    raise ValueError('unknown dataset %r; expected one of %s'
                     % (dname, sorted(list(SINGLE_FILE_DATASETS) + list(TRAIN_TEST_DATASETS))))
print(data)

In [None]:
# L2-penalty weight for the strongly convex problem: lambda = 1/n, where
# n = data.trSize is the number of training samples. Not defined otherwise.
if StrongConvex: lam = 1 / data.trSize

#### a. experiment setup 

In [None]:
# Random seeds 0..9: one independent run per seed.
SEED = list(range(10))

#### b. parameters 

In [None]:
# Running budget: maximum number of effective passes over the training set,
# per problem case and dataset.
TOTAL_EPOCHS = {
    'reg': {'rcv1': 30.0, 'ijcnn1': 20.0, 'news20': 40.0, 'covtype': 20.0, 'real-sim': 20.0},
    'non_reg': {'rcv1': 40.0, 'ijcnn1': 20.0, 'news20': 50.0, 'covtype': 20.0, 'real-sim': 30.0},
}
# Fail here with a clear message instead of leaving TotalEP undefined (or
# stale from an earlier kernel run) for an unsupported case/dataset.
try:
    TotalEP = TOTAL_EPOCHS[case][dname]
except KeyError:
    raise ValueError('no running budget defined for case=%r, dataset=%r' % (case, dname)) from None

# Number of mini-batches of size BS per pass over the training set.
perEpoch = data.trSize//BS
# The training loop computes `st % perEpoch`, which needs at least one batch.
if perEpoch == 0:
    raise ValueError('training set (%d samples) is smaller than the mini-batch size %d'
                     % (data.trSize, BS))

#### c. run 

In [None]:
# Run AI-SARAH once per random seed. Each run saves its results to
# `seed-<seed>.tar` in `logfolder`. A RUN-seed-<seed>/ marker directory exists
# while a seed is running and is replaced by DONE-seed-<seed>/ at the end;
# a seed whose marker or result file already exists is skipped.
for seed in SEED:
    run_status = logfolder+'RUN-seed-%s/'%seed
    done_status = logfolder+'DONE-seed-%s/'%seed
    savefile = logfolder+'seed-%s.tar'%seed

    # skip seeds that are already running or finished
    if os.path.exists(run_status) or os.path.exists(done_status) or os.path.exists(savefile):
        print(done_status)
        continue
    else:
        os.makedirs(run_status)
    print('======\nseed - %s\n======'%seed)

    # results
    HIST=[] # rows: [ep, loss, squared norm of full gradient, ComputeAccuracy value]
    STAT=[] # rows: [ep, outerT, innerT, elapsed time, 1 = outer-loop record / 0 = inner-loop record]
    ALPHA=[]
    alpha_max = np.inf # initial bound for first iteration - no upper bound
    alpha = 0.0 # initial alpha - no initial step-size
    TIME = time.time() # total run timer

    # initialize random stream
    np.random.seed(seed)
    torch.manual_seed(seed)

    # define one-layer model for a linear model with logistic regression;
    # `model` holds the new/trial iterate, `prev_net` the previous iterate
    if StrongConvex:
        # L2 regularized case
        model = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True)
        prev_net = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True)
    else:
        # un-regularized case
        model = ConvexModel(data.num_feature,data.num_label)
        prev_net = ConvexModel(data.num_feature,data.num_label)

    # push model to GPU for non-leaf variable
    model.to(device)
    prev_net.to(device)

    # for weights of features present in the test set but not in the training set, set them to ZERO
    if len(data.in_te_not_tr)>0:
        model.del_in_te_not_tr(data.in_te_not_tr)
        prev_net.del_in_te_not_tr(data.in_te_not_tr)

    # start both networks from the same point: copy model's weights into prev_net
    for wi,pi in zip(model.parameters(),prev_net.parameters()):
        with torch.no_grad():
            pi.set_(wi+0.0)

    allSamples = list(range(data.trSize))

    # initialize counters
    ep=0.0 # count effective passes
    innerT=0 # count inner iterations
    outerT=0 # count outer iterations
    odmT=0 # count one-dim-minimization iterations
    st=0 # mini-batch loop counter
    # initialize stopping flags
    converge=False
    fatal=False
    # epoch time - time for one epoch
    epoch_time = time.time()

    alpha = 0.0
    t=0

    outer_record=True # for print&save log purpose
    while ep<=TotalEP+1:

        if converge or fatal:
            break

        # full-data loss, gradient V (the initial v_0) and test metric at prev_net
        Loss, V = prev_net.LossGrad(data)
        Grad = np.sum([(gi.data**2).sum().item() for gi in V])
        Test = prev_net.ComputeAccuracy(data)

        if outer_record:
            timeT = time.time() - epoch_time
            if ep==0 and t==0:
                timeT=0.0
            HIST.append([ep,Loss,Grad,Test])
            STAT.append([ep,outerT,innerT,timeT,1])
            print('outer-ep: %.2f, alpha: %.4f, max: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\
                  %(ep,alpha,alpha_max,Loss,Grad,Test,timeT,t))
            epoch_time = time.time()

        normV0 = Grad
        normVold = Grad

        # the full gradient is counted as one effective pass
        ep+=1.0
        outerT+=data.trSize//BS

        if np.isnan(Loss) or np.isnan(Grad) or np.isnan(Test):
            fatal = True
        if Grad < 1e-15:
            converge=True

        # initialize inner loop
        t=0
        normV = np.inf
        outer_record=True
        # Set before the inner loop so the check after it never hits an
        # undefined (NameError) or stale value when the inner loop runs zero
        # iterations, e.g. when the starting point is already converged or NaN.
        inner_record=False
        # inner loop: stop once ||v_t||^2 < ||v_0||^2 / 32, or when the budget is spent
        while ep<=TotalEP+1:

            if fatal or converge or normV<normV0.item()/32.0:
                break

            # random mini-batch: reshuffle at the start of each pass; the last
            # batch of a pass takes all remaining samples
            st=st%perEpoch
            if st==0:
                np.random.shuffle(allSamples)
            if st==perEpoch-1:
                sample = allSamples[st*BS:]
            else:
                sample = allSamples[st*BS:(st+1)*BS]

            x_sample,y_sample = data.mb(sample)
            # compute sample grad at the previous iterate: g0
            _,g0 = prev_net.LossGrad(data,sample=sample)

            # initialize implicit step-size computation
            ti=-1
            alpha = 0.0 # starting point
            newtonD = 0.0
            # destruct computing graph
            for w in model.parameters():
                w.detach_()
            # no loop version - only one Newton iteration, starting from alpha = 0
            alpha = torch.tensor(alpha - newtonD,requires_grad=True).to(device)
            # set model to the trial point w(alpha) = w_prev - alpha*v, differentiable in alpha
            for wi,pi,vi in zip(model.parameters(),prev_net.parameters(),V):
                with torch.no_grad():
                    wi.set_(pi+0.0) # no gradient in operation
                wi.sub_(alpha*vi) # with gradient
            # compute sample grad at the trial point: g1
            # (second_order=True presumably keeps the graph for double backprop — confirm in LossGrad)
            _,g1 = model.LossGrad(data,sample=sample,second_order=True)
            Vtemp = [g1i - g0i + vi for g1i,g0i,vi in zip(g1,g0,V)]
            normVtemp = torch.stack([(vi**2).sum() for vi in Vtemp]).sum()
            # compute 1st/2nd derivative of ||v(alpha)||^2 with respect to alpha
            alphaGrad = torch.autograd.grad(normVtemp,alpha,create_graph=True)
            alphaHess = torch.autograd.grad(alphaGrad,alpha)
            # Newton direction; |Hessian| keeps the step against the gradient
            # even where the curvature is negative
            newtonD = 1.0/np.abs(alphaHess[0].item())*alphaGrad[0].item() # abs(hessian)
            GradHess = np.abs(newtonD)
            alpha = alpha.item()
            improvement = normVtemp.item()/normVold
            ti+=1
            odmT+=1
            alpha=alpha-newtonD

            # update alpha for current iterate, capped by the running upper bound
            alpha_newton = alpha
            alpha = min(alpha_newton,alpha_max)

            # update upper bound for next iterate: alpha_max = 1/delta, where
            # delta is an exponential moving average (weight 0.999) of 1/alpha_newton
            if ep==1.0 and t==0:
                # very first inner step of the run: initialize the average
                delta = 1.0/alpha_newton
                alpha_max = alpha_newton
            else:
                delta = 0.999*delta + 0.001*(1.0/alpha_newton)
                alpha_max = 1.0/delta

            # update iterate - w_t = w_{t-1} - alpha * v_{t-1}
            for wi,pi,vi in zip(model.parameters(),prev_net.parameters(),V):
                with torch.no_grad():
                    wi.set_(pi-alpha*vi)
            _,g1 = model.LossGrad(data,sample=sample)
            # update V_t = g1 - g0 + V_{t-1}, and its squared norm
            V = [g1i.data - g0i.data + vi.data for g1i,g0i,vi in zip(g1,g0,V)]
            normV = np.sum([(vi.data**2).sum().item() for vi in V])
            # update ratio
            improvement = normV/normVold

            # update for next inner iteration: prev_net <- model
            normVold = normV
            for wi,pi in zip(model.parameters(),prev_net.parameters()):
                with torch.no_grad():
                    pi.set_(wi+0.0)

            st+=1 # sample counter
            t+=1 # count inner iteration
            innerT+=1
            ep+=BS/data.trSize # count effective pass

            # log at t = 1, perEpoch+1, 2*perEpoch+1, ...
            inner_record=False
            if (t-1)%perEpoch==0:
                timeT = time.time()-epoch_time
                inner_record=True
                Lossprint, Vprint = prev_net.LossGrad(data)
                Gradprint = np.sum([(gi.data**2).sum().item() for gi in Vprint])
                Testprint = prev_net.ComputeAccuracy(data)

                HIST.append([ep,Lossprint,Gradprint,Testprint])
                STAT.append([ep,outerT,innerT,timeT,0])
                print('inner-ep: %.2f, alpha: %.4f, max: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\
                      %(ep,alpha,alpha_max,Lossprint,Gradprint,Testprint,timeT,t))

                epoch_time = time.time()

        # The inner log was taken at prev_net with the current ep, which is
        # exactly what the next outer iteration evaluates; skip that duplicate.
        if inner_record:
            outer_record=False

    TIME = time.time() - TIME # total running time per run

    RESULTS = {
        'parm': [BS,seed],
        'end': [Loss,Grad,Test,TIME,converge,fatal],
        'hist': HIST,
        'stat': STAT
    }
    torch.save(RESULTS,savefile)

    # update running status: replace the RUN marker by the DONE marker
    if os.path.exists(run_status):
        os.rmdir(run_status)
    if not os.path.exists(done_status):
        os.mkdir(done_status)

In [None]:
# All seeds have been processed: terminate with exit status 0.
exit(0)