# Setup

In [None]:
import os
import sys
import time

from google.colab import drive
from google.colab import files

# Mount Google Drive and define the project folders on it.
drive.mount('/content/gdrive/', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'Colab Notebooks/MS Thesis/PSGD_PyTorch/PSGD_MNIST/'
results_dir = base_dir + 'results3_MNIST_tanh/'  # .mat results and .html figures are written here
logs_dir = base_dir + 'log'
# Later cells write files into results_dir via plain string concatenation
# (scipy.io.savemat, fig.write_html); both fail if the folder does not exist yet.
os.makedirs(results_dir, exist_ok=True)
# Make modules stored next to the notebook (e.g. the psgd implementation) importable.
sys.path.append(base_dir)
import preconditioned_stochastic_gradient_descent as psgd 

Mounted at /content/gdrive/


In [None]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from hessian_torch import Hessian
import numpy as np
import time
import math

from tabulate import tabulate
import scipy.io
from sklearn import metrics
import plotly.express as px
from torchsummary import summary


# Functions

In [None]:
def plot_loss_metrics(xaxis,yaxis,title, x_label,y_label):
  """Plot one loss curve per optimizer on a log-scale y-axis and save it as HTML.

  Args:
    xaxis: indexable by the entries of the global `opts`, giving the x values
      of each curve; or None to plot every curve against its sample index.
    yaxis: indexable by the entries of `opts`, giving the y values of each curve.
    title: figure title; also the file name (plus ".html") under `results_dir`.
    x_label, y_label: axis titles.

  Reads the notebook globals `opts` (curve names, in plotting order),
  `colors` (one colour per entry of `opts`) and `results_dir`.
  """
  fig = go.Figure()
  for i, opt in enumerate(opts):
    # `is not None` instead of `!= None`: for array-like inputs `!=` compares
    # elementwise and makes the condition ambiguous. x=None means "use indices".
    x = xaxis[opt] if xaxis is not None else None
    fig.add_trace(go.Scatter(x=x, y=yaxis[opt], name=opt, mode='lines', line=dict(color=colors[i])))

  fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label, yaxis_type="log")
  fig.show()
  fig.write_html(results_dir + title + ".html")

def plot_acc_metrics(xaxis,yaxis,title, x_label,y_label):
  """Plot one accuracy curve per optimizer (y-axis fixed to [0.97, 1]) and save it as HTML.

  Args:
    xaxis: indexable by the entries of the global `opts`, giving the x values
      of each curve; or None to plot every curve against its sample index.
    yaxis: indexable by the entries of `opts`, giving the y values of each curve.
    title: figure title; also the file name (plus ".html") under `results_dir`.
    x_label, y_label: axis titles.

  Reads the notebook globals `opts` (curve names, in plotting order),
  `colors` (one colour per entry of `opts`) and `results_dir`.
  """
  fig = go.Figure()
  for i, opt in enumerate(opts):
    # `is not None` instead of `!= None`: for array-like inputs `!=` compares
    # elementwise and makes the condition ambiguous. x=None means "use indices".
    x = xaxis[opt] if xaxis is not None else None
    fig.add_trace(go.Scatter(x=x, y=yaxis[opt], name=opt, mode='lines', line=dict(color=colors[i])))

  # Zoom on the high-accuracy region where the optimizers differ.
  fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label, yaxis=dict(range=[0.97, 1]))
  fig.show()
  fig.write_html(results_dir + title + ".html")

def plot_spectrum(loader, model, model_parameters, num_classes):
  """Plot the Hessian eigenvalue density of `model`, estimated on `loader`.

  Draws three semilog-y curves: the full spectral density from
  Hessian.LanczosApproxSpec, a second Lanczos density computed after
  SubspaceIteration has stored `num_classes` eigenpairs on the Hessian object
  ("Bulk" -- presumably with those eigenpairs deflated; depends on
  hessian_torch), and those eigenvalues themselves ("Outliers"), drawn at
  height 1/total_params.

  Args:
    loader: data loader the Hessian is estimated on.
    model: model callable handed to hessian_torch.Hessian.
    model_parameters: list of weight tensors the Hessian is taken with respect to.
    num_classes: number of eigenpairs extracted by SubspaceIteration.
  """
  # numel() equals shape[0]*shape[1] for the 2-D weights used here and also
  # works for tensors of any rank.
  total_params = sum(w.numel() for w in model_parameters)
  # Estimate on the `loader` argument (not the global train_loader), so the
  # caller's choice of data is honoured.
  Hess = Hessian(loader=loader, model=model, model_parameters = model_parameters, hessian_type='Hessian')
  Hess_eigval, Hess_eigval_density = Hess.LanczosApproxSpec(init_poly_deg=64, poly_deg=256)
  Hess.vecs, Hess.vals, _ = Hess.SubspaceIteration(num_classes, iters=128)
  Hess_bulk_eigvals, Hess_bulk_eigval_density = Hess.LanczosApproxSpec(init_poly_deg=64, poly_deg=256)
  plt.figure()
  plt.semilogy(Hess_eigval, Hess_eigval_density, label='Hessian')
  plt.semilogy(Hess_bulk_eigvals, Hess_bulk_eigval_density, label='Bulk')
  plt.semilogy(Hess.vals, np.ones(len(Hess.vals)) / total_params, '*', label='Outliers')
  plt.legend()
  plt.xlabel('Eigenvalue')
  plt.ylabel('Density of Spectrum')


In [None]:
# plot_spectrum(train_loader, LeNet5, Ws, 10)

# Data Download

In [None]:
# Seed NumPy's global RNG; PyTorch is seeded separately in each experiment cell.
np.random.seed(0)

# Parameter settings: train batch size, test batch size, epochs per optimizer, GAP.
# NOTE(review): GAP's purpose is not evident from the training cells -- TODO confirm it is still needed.
BATCH_SIZE, test_BATCH_SIZE, EPOCHS, GAP = 64, 1000, 20, 100

In [None]:
# MNIST as tensors via ToTensor only (no mean/std normalisation).
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
# Evaluation does not need shuffling: test_loss() visits every batch exactly once,
# and a fixed order keeps evaluation deterministic. download=True makes this cell
# independent of the train-set download above.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=test_BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# Mini-batches per pass over each split (as NumPy floats); used to average the
# per-batch loss/accuracy sums into epoch means.
n_batches, n_test_batches = (
    np.ceil(len(loader.dataset) / batch)
    for loader, batch in ((train_loader, BATCH_SIZE), (test_loader, test_BATCH_SIZE))
)

In [None]:
# Use the first CUDA device when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # use cuda:1, cuda:2, ... to pick another GPU
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

# Cell output: the accelerator name. get_device_name() raises on a runtime
# without CUDA, so only query it when a GPU was actually selected.
torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"

Running on the GPU


'Tesla K80'

# Model

In [None]:
# The original LeNet5 takes 32x32 inputs; MNIST images here are 28x28, which is
# why the third layer sees a 16*4*4 flattened feature map.
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
def initialize_weights(target_device=None):
    """Create the five LeNet5 weight matrices; the last row of each is its bias.

    Args:
        target_device: device to place the weights on. Defaults to the
            notebook-level `device` (backward compatible with the old no-arg call).

    Returns:
        List of 5 leaf tensors with requires_grad=True and shapes
        (1*5*5+1, 6), (6*5*5+1, 16), (16*4*4+1, 120), (120+1, 84), (84+1, 10),
        Xavier-uniform initialised.
    """
    dev = device if target_device is None else target_device
    shapes = [(1*5*5+1, 6), (6*5*5+1, 16), (16*4*4+1, 120), (120+1, 84), (84+1, 10)]
    Ws = []
    for shape in shapes:
        # The randn draw is fully overwritten by xavier_uniform_, but it is kept so
        # the CPU RNG stream -- and therefore every initialisation -- is identical
        # to the runs whose results were already saved.
        W = 0.1*torch.randn(shape)
        torch.nn.init.xavier_uniform_(W)
        # Move first, then enable grad: the weight is a leaf tensor on the target
        # device instead of a device copy hanging off a CPU leaf.
        Ws.append(W.to(dev).requires_grad_(True))
    return Ws

def LeNet5(x, return_all = False, weights=None):
    """LeNet5-style classifier with tanh activations for single-channel 28x28 inputs.

    Every weight matrix stores its bias as the last row: W[:-1] is the (flattened)
    kernel / dense matrix and W[-1] the bias vector.

    Args:
        x: input batch of shape (B, 1, 28, 28).
        return_all: if True, also return the list of intermediate activations
            [x, x1, x2, x3, x4, x5, x6, logits].
        weights: optional list of the 5 weight matrices. Defaults to the
            notebook-level `Ws`, so existing calls keep working.

    Returns:
        Log-probabilities of shape (B, 10), or (log_probs, activations) when
        return_all is True.
    """
    W1, W2, W3, W4, W5 = Ws if weights is None else weights
    x1 = torch.tanh(F.conv2d(x, W1[:-1].view(6,1,5,5), bias=W1[-1]))    # (B, 6, 24, 24)
    x2 = F.max_pool2d(x1, 2)                                              # (B, 6, 12, 12)
    x3 = torch.tanh(F.conv2d(x2, W2[:-1].view(16,6,5,5), bias=W2[-1]))  # (B, 16, 8, 8)
    x4 = F.max_pool2d(x3, 2)                                              # (B, 16, 4, 4)
    x5 = torch.tanh(x4.view(-1, 16*4*4).mm(W3[:-1]) + W3[-1])            # (B, 120)
    x6 = torch.tanh(x5.mm(W4[:-1]) + W4[-1])                             # (B, 84)
    y = x6.mm(W5[:-1]) + W5[-1]                                           # logits, (B, 10)
    if return_all:
        return F.log_softmax(y, dim = 1), [x, x1, x2, x3, x4, x5, x6, y]
    return F.log_softmax(y, dim=1)


# Loss and Evaluation Functions

In [None]:
def train_loss(data, target):
    """Mean NLL loss and accuracy of LeNet5 on one mini-batch.

    Returns:
        (loss, accuracy): `loss` is a scalar tensor still attached to the autograd
        graph; `accuracy` is the float32 fraction of argmax predictions equal to
        `target`.
    """
    log_probs = LeNet5(data)
    loss = F.nll_loss(log_probs, target)
    predicted = torch.max(log_probs, dim=1)[1]
    n_correct = torch.eq(predicted, target).sum(dtype=torch.float32)
    return loss, n_correct / predicted.size(0)

def test_loss():
    """Average NLL loss and accuracy of LeNet5 over the whole test set.

    Both metrics are the sum of per-batch means divided by the global
    `n_test_batches`; evaluation runs without building an autograd graph.

    Returns:
        (float, float): mean test loss, mean test accuracy.
    """
    total_loss, total_acc = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            log_probs = LeNet5(inputs)
            total_loss += F.nll_loss(log_probs, labels)
            predicted = torch.max(log_probs, dim=1)[1]
            total_acc += (predicted == labels).sum(dtype=torch.float32) / predicted.size(0)
    return total_loss.item() / n_test_batches, total_acc.item() / n_test_batches

def save_start_condition(trainlosslist, testlosslist,trainacclist, testacclist, timelist):
    """Record and print the metrics of the freshly initialised model as epoch 0.

    Appends, in place, the mean train loss/accuracy over one full pass of
    `train_loader`, the test loss/accuracy from test_loss(), and an elapsed
    time of 0 to the given lists, then prints a one-line summary.
    """
    trainloss = 0.0
    trainacc = 0.0
    # Pure evaluation: without no_grad, every batch's autograd graph (with its
    # saved activations) would stay referenced through the running `trainloss`
    # sum until the whole pass finishes.
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            loss, accuracy = train_loss(data, target)
            trainloss += loss
            trainacc += accuracy

    timelist.append(0)

    testloss, testacc = test_loss()

    trainlosslist.append(trainloss.item()/n_batches)
    trainacclist.append(trainacc.item()/n_batches)
    testlosslist.append(testloss)
    testacclist.append(testacc)
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
    .format(0, trainlosslist[-1], testlosslist[-1], trainacclist[-1], testacclist[-1],np.sum(timelist)))

# First Order

## Adam

In [None]:
# Adam baseline: first-order optimizer with a geometric per-epoch step-size decay.
torch.manual_seed(1)
Ws = initialize_weights()
m0 = [torch.zeros(W.shape).to(device) for W in Ws]  # first-moment estimates
v0 = [torch.zeros(W.shape).to(device) for W in Ws]  # second-moment estimates
step_size = 0.005
cnt = 0  # global step counter driving the warm-up of the averaging factors
n = 0
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        grads = grad(loss, Ws)
        # Detach before accumulating so the running sum does not chain every
        # batch's autograd graph together for the whole epoch.
        trainloss += loss.detach()
        trainacc += accuracy

        with torch.no_grad():
            # lmbd = min(cnt/(cnt+1), beta): an exact running average over the first
            # steps (acts as bias correction), then an EMA with factor beta.
            lmbd = min(cnt/(cnt+1), 0.9)
            m0 = [lmbd*old + (1.0-lmbd)*new for (old, new) in zip(m0, grads)]
            lmbd = min(cnt/(cnt+1), 0.999)
            v0 = [lmbd*old + (1.0-lmbd)*new*new for (old, new) in zip(v0, grads)]
            for i in range(len(Ws)):
                # eps sits inside the square root here, unlike the usual sqrt(v) + eps.
                Ws[i] -= step_size*(m0[i]/torch.sqrt(v0[i] + 1e-8))
            cnt = cnt + 1

    t1 = time.time() - t0
    times.append(t1)

    testloss, testacc = test_loss()

    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)
    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
    .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'adam.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 1; train loss: 0.15490940638950892; test loss: 0.07568362355232239, train_accuracy: 0.9532249466950959, test_accuracy:0.9756999969482422, time: 11.204861402511597
Epoch: 2; train loss: 0.056185740651860674; test loss: 0.05741769075393677, train_accuracy: 0.9834754797441365, test_accuracy:0.9822000503540039, time: 22.507003784179688
Epoch: 3; train loss: 0.030792767050932212; test loss: 0.045707696676254274, train_accuracy: 0.9902052238805971, test_accuracy:0.9861001014709473, time: 33.82865071296692
Epoch: 4; train loss: 0.01608825085768059; test loss: 0.03831199705600739, train_accuracy: 0.9953025053304904, test_accuracy:0.9881000518798828, time: 45.02947640419006
Epoch: 5; train loss: 0.008911412407848627; test loss: 0.036005797982215884, train_accuracy: 0.998101012793177, test_accuracy:0.9887001037597656, time: 56.22361326217

# PSGD Newton



## Full Kronecker

In [None]:
# PSGD with a dense Kronecker-factored preconditioner per weight matrix.
torch.manual_seed(1)
Ws = initialize_weights()
# One (left, right) factor pair per weight, both starting at the identity.
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
# Clip threshold on the preconditioned-gradient norm, scaled by sqrt(#parameters).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        # create_graph=True so the gradients can be differentiated once more below.
        grads = grad(loss, Ws, create_graph=True)

        # Detach so the running sum does not keep each batch's autograd graph alive.
        trainloss += loss.detach()
        trainacc += accuracy

        # Hessian-vector product along a random probe direction v.
        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)

        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: save_start_condition already printed "Epoch: 0" for the untrained model.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
     .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'Kron.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.15021819092317432; test loss: 0.06422612071037292, train_accuracy: 0.955123933901919, test_accuracy:0.9799000740051269, time: 22.566417455673218
Epoch: 1; train loss: 0.04365609907138068; test loss: 0.054288947582244874, train_accuracy: 0.986224013859275, test_accuracy:0.9824001312255859, time: 45.271703004837036
Epoch: 2; train loss: 0.01879835027113144; test loss: 0.039466971158981325, train_accuracy: 0.9943363539445629, test_accuracy:0.9881000518798828, time: 68.0515923500061
Epoch: 3; train loss: 0.006270146319098564; test loss: 0.03503313362598419, train_accuracy: 0.9983508795309168, test_accuracy:0.9890000343322753, time: 90.62533736228943
Epoch: 4; train loss: 0.0016348644106118663; test loss: 0.03515526950359345, train_accuracy: 0.9996668443496801, test_accuracy:0.9906000137329102, time: 112.930965900421

## SCAN

In [None]:
# PSGD with the SCAN preconditioner.
torch.manual_seed(1)
Ws = initialize_weights()
# Left factor: 2 x rows (a row of ones stacked on a row of zeros); right factor:
# 1 x cols row of ones -- presumably the packed layout psgd's *_scan functions expect.
Qs = [[torch.cat([torch.ones((1, W.shape[0])), torch.zeros((1, W.shape[0]))]).to(device),
       torch.ones((1, W.shape[1])).to(device)] for W in Ws]
step_size = 0.1
# Clip threshold on the preconditioned-gradient norm, scaled by sqrt(#parameters).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        # create_graph=True so the gradients can be differentiated once more below.
        grads = grad(loss, Ws, create_graph=True)

        # Detach so the running sum does not keep each batch's autograd graph alive.
        trainloss += loss.detach()
        trainacc += accuracy

        # Hessian-vector product along a random probe direction v.
        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [psgd.update_precond_scan(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_scan(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: save_start_condition already printed "Epoch: 0" for the untrained model.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
     .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'SCAN.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.2074590491841851; test loss: 0.11962230205535888, train_accuracy: 0.9391657782515992, test_accuracy:0.965300178527832, time: 25.051253080368042
Epoch: 1; train loss: 0.06907676582905783; test loss: 0.07586538791656494, train_accuracy: 0.9801938965884861, test_accuracy:0.9771000862121582, time: 50.34159708023071
Epoch: 2; train loss: 0.037583015620835554; test loss: 0.05006649494171143, train_accuracy: 0.9886727078891258, test_accuracy:0.9857000350952149, time: 75.63912510871887
Epoch: 3; train loss: 0.02047724408635707; test loss: 0.04014335870742798, train_accuracy: 0.9938865938166311, test_accuracy:0.9892000198364258, time: 100.80500745773315
Epoch: 4; train loss: 0.00944234465739366; test loss: 0.03946612477302551, train_accuracy: 0.9975346481876333, test_accuracy:0.9892001152038574, time: 125.95810270309448


## SCAW

In [None]:
# PSGD with the SCAW preconditioner.
torch.manual_seed(1)
Ws = initialize_weights()
# Left factor: dense rows x rows identity; right factor: 1 x cols row of ones
# (presumably a diagonal stored as a row vector, as psgd's *_scaw functions expect).
Qs = [[torch.eye(W.shape[0]).to(device), torch.ones((1, W.shape[1])).to(device)] for W in Ws]
step_size = 0.1
# Clip threshold on the preconditioned-gradient norm, scaled by sqrt(#parameters).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        # create_graph=True so the gradients can be differentiated once more below.
        grads = grad(loss, Ws, create_graph=True)

        # Detach so the running sum does not keep each batch's autograd graph alive.
        trainloss += loss.detach()
        trainacc += accuracy

        # Hessian-vector product along a random probe direction v.
        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [psgd.update_precond_scaw(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_scaw(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: save_start_condition already printed "Epoch: 0" for the untrained model.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
     .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'SCAW.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.16640677380917676; test loss: 0.1138519287109375, train_accuracy: 0.9513259594882729, test_accuracy:0.9672000885009766, time: 22.094501972198486
Epoch: 1; train loss: 0.05126433382664662; test loss: 0.05434907078742981, train_accuracy: 0.9844749466950959, test_accuracy:0.9836000442504883, time: 44.31244707107544
Epoch: 2; train loss: 0.025202521383126914; test loss: 0.038503986597061154, train_accuracy: 0.992204157782516, test_accuracy:0.9874000549316406, time: 66.550128698349
Epoch: 3; train loss: 0.008982271019583826; test loss: 0.043652334809303285, train_accuracy: 0.9975346481876333, test_accuracy:0.987600040435791, time: 88.96731495857239
Epoch: 4; train loss: 0.0029890730436931035; test loss: 0.03942024111747742, train_accuracy: 0.9992337420042644, test_accuracy:0.9895000457763672, time: 111.21108436584473

## Test: Kronecker preconditioner refreshed (10 passes) only every 100 steps

In [None]:
# Kronecker PSGD variant: the preconditioner is updated only on every 100th step
# of an epoch (10 refinement passes at a time) and reused unchanged otherwise.
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.5
# Clip threshold on the preconditioned-gradient norm, scaled by sqrt(#parameters).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    n = 0  # step counter, reset every epoch
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)
        grads = grad(loss, Ws, create_graph=True)

        # Detach so the running sum does not keep each batch's autograd graph alive.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        if n % 100 == 0:
            # NOTE(review): each pass replaces `grads` with its preconditioned version,
            # and the next Hv differentiates that. Together with the preconditioning
            # in the no_grad block below, steps where n % 100 == 0 apply the
            # preconditioner 11 times -- confirm this compounding is intended.
            for i in range(10):
                Hv = grad(grads, Ws, v, retain_graph=True)
                Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
                grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
        with torch.no_grad():
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]
        n = n + 1
    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # step_size = 0.01**(1/9)*step_size
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
     .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'Kron_test.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 1; train loss: 0.1664891192145439; test loss: 0.07757079601287842, train_accuracy: 0.9495102611940298, test_accuracy:0.9749001502990723, time: 14.266586542129517
Epoch: 2; train loss: 0.05692284346135186; test loss: 0.0622079074382782, train_accuracy: 0.9823427505330491, test_accuracy:0.9805000305175782, time: 28.23250102996826
Epoch: 3; train loss: 0.03430147313359958; test loss: 0.04931778609752655, train_accuracy: 0.9896388592750534, test_accuracy:0.9855000495910644, time: 41.72618269920349
Epoch: 4; train loss: 0.020378133127175922; test loss: 0.04580033421516418, train_accuracy: 0.9939032515991472, test_accuracy:0.9845001220703125, time: 55.622013330459595
Epoch: 5; train loss: 0.011430918280758075; test loss: 0.045965859293937684, train_accuracy: 0.9972681236673774, test_accuracy:0.9864001274108887, time: 69.47506356239319

## Test2 (two sequential preconditioners)

In [None]:
# Two Kronecker preconditioners applied one after the other to the gradient.
torch.manual_seed(1)
Ws = initialize_weights()
Qs1 = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws] #kronecker
Qs2 = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws] #kronecker
# Qs2 = [[torch.cat([torch.ones((1, W.shape[0])), torch.zeros((1, W.shape[0]))]).to(device), torch.ones((1, W.shape[1])).to(device)] for W in Ws] #scan
# Qs2 = [[torch.eye(W.shape[0]).to(device), torch.ones((1, W.shape[1])).to(device)] for W in Ws] # scaw


step_size = 0.05
# Clip threshold on the preconditioned-gradient norm, scaled by sqrt(#parameters).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    n = 0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)
        grads = grad(loss, Ws, create_graph=True)

        # Detach so the running sum does not keep each batch's autograd graph alive.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        # Nothing is differentiated after this call, so the graph need not be
        # retained (retain_graph=True only kept it alive until the next step).
        Hv = grad(grads, Ws, v)

        with torch.no_grad():
            # NOTE(review): both factor sets start at the identity and are fitted on the
            # same (v, Hv) pair; if update_precond_kron is deterministic, Qs1 == Qs2 at
            # every step and this amounts to applying one preconditioner twice.
            Qs1 = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs1, v, Hv)]
            Qs2 = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs2, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs1, grads)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs2, pre_grads)]

            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]
        n = n + 1
    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # step_size = 0.01**(1/9)*step_size
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
     .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'Kron_test2.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 1; train loss: 0.13485317596240337; test loss: 0.042934569716453555, train_accuracy: 0.9606376599147122, test_accuracy:0.9862000465393066, time: 27.73645567893982
Epoch: 2; train loss: 0.02577865250837574; test loss: 0.03742868900299072, train_accuracy: 0.9926372601279317, test_accuracy:0.9875, time: 57.035680294036865
Epoch: 3; train loss: 0.010024081923560038; test loss: 0.039589452743530276, train_accuracy: 0.9974846748400853, test_accuracy:0.987600040435791, time: 86.22909450531006
Epoch: 4; train loss: 0.005079065050397601; test loss: 0.04754831194877625, train_accuracy: 0.9987673240938166, test_accuracy:0.9875000953674317, time: 115.59606838226318
Epoch: 5; train loss: 0.001943984773875808; test loss: 0.0516740620136261, train_accuracy: 0.9994836087420043, test_accuracy:0.9879000663757325, time: 144.366694688797
Epoch: 6; 

In [None]:
# Kronecker PSGD with a Levenberg-Marquardt-style adaptive damping `lambd`; the
# Kronecker factors are refreshed (10 passes) only every 100 global steps.
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
# Clip threshold on the preconditioned-gradient norm, scaled by sqrt(#parameters).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

lambd = 150   # initial damping strength
omega = 0.5   # multiplicative adjustment factor used by update_lambda
n = 0         # global step counter (not reset per epoch)
for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        grads = grad(loss, Ws, create_graph=True)
        grads1 = grads  # raw gradients, kept for the predicted-reduction term M
        # Detach so the running sum does not keep each batch's autograd graph alive.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]

        if n % 100 == 0:
            for i in range(10):
                Hv = grad(grads, Ws, v, retain_graph=True)
                Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
                grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            pre_grads = grads
        else:
            # FIXME: the only bare `precond_grad_kron` in this notebook is the
            # Tikhonov-damped one defined in a LATER cell (it reads the globals
            # `eta` and `lambd`); on a fresh kernel this line raises NameError
            # unless that cell -- and one defining `eta` -- has run first.
            pre_grads = [precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
        with torch.no_grad():
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]
            if n % 1 == 0:  # always true: damping is adapted every step
                M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads1, pre_grads)])
                loss2 = F.nll_loss(LeNet5(data), target)  # loss on the same batch after the step
                loss1 = loss
                # Argument order must match update_lambda(loss1, loss2, M, lambd, omega);
                # M and loss2 were previously swapped.
                lambd = update_lambda(loss1, loss2, M, lambd, omega)
        n = n+1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
     .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

# scipy.io.savemat(results_dir + 'Kron_damped.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 1; train loss: 0.21988214358592084; test loss: 0.08379580974578857, train_accuracy: 0.9358009061833689, test_accuracy:0.9742001533508301, time: 19.442260026931763
Epoch: 2; train loss: 0.07304357109801855; test loss: 0.060818684101104734, train_accuracy: 0.9790778251599147, test_accuracy:0.9814000129699707, time: 38.619646310806274
Epoch: 3; train loss: 0.05614369764511011; test loss: 0.05423896312713623, train_accuracy: 0.9842750533049041, test_accuracy:0.9839000701904297, time: 58.1361289024353
Epoch: 4; train loss: 0.049935822802057654; test loss: 0.051707619428634645, train_accuracy: 0.9860407782515992, test_accuracy:0.9839000701904297, time: 77.45327568054199
Epoch: 5; train loss: 0.0469316348338178; test loss: 0.050432002544403075, train_accuracy: 0.9869402985074627, test_accuracy:0.9843001365661621, time: 96.7525458335876

## Tikhonov Damping Kronecker

In [None]:
_tiny = 1.2e-38 
 # pi = (torch.trace(Ql)*Qr.shape[0])/(torch.trace(Qr)*Ql.shape[0])
    # 


def precond_grad_kron(Ql, Qr, Grad):
    P1 = Ql.t().mm(Ql)
    P2 = Qr.t().mm(Qr)
    pi = (torch.trace(P1)*P2.shape[0])/(torch.trace(P2)*P1.shape[0])
    IL = torch.eye(P1.shape[0]).to(device)
    IR = (torch.eye(P2.shape[0])).to(device)
    P1 = P1 + torch.sqrt((pi)*(eta+lambd**0.5)*IL)
    P2 = P2 + torch.sqrt((1/pi)*(eta+lambd**0.5)*IR)

    return P1.mm(Grad).mm(P2)

def update_lambda(loss1, loss2, M, lambd, omega):
    """Levenberg-Marquardt style update of the damping strength ``lambd``.

    The reduction ratio ``r = (loss1 - loss2) / M`` compares the loss decrease actually
    achieved by a step with the predicted decrease ``M``. With ``omega < 1`` (as at the call
    sites in this notebook):
      * r > 3/4  (model trustworthy)            -> lambd * omega  (less damping)
      * r < 1/4  (poor step, incl. loss went up) -> lambd / omega  (more damping)
      * otherwise                                -> unchanged

    The sign matters: a loss *increase* gives r < 0 and must increase damping. Taking the
    absolute value of the loss change would treat such a step as a good one.

    Args:
        loss1: loss before the step.
        loss2: loss after the step, on the same batch.
        M: predicted loss reduction of the step (float or 0-d tensor).
        lambd: current damping strength.
        omega: multiplicative factor applied to ``lambd``.

    Returns:
        The updated ``lambd``. It is unchanged when ``M == 0``, because the ratio is undefined then.
    """
    if M == 0:
        return lambd
    r = (loss1 - loss2) / M
    if r > 3/4:
        lambd = lambd * omega
    elif r < 1/4:
        lambd = lambd / omega
    return lambd


In [None]:
# PSGD with a Kronecker-factored preconditioner fitted on Hessian-vector products, plus adaptive
# Tikhonov damping. precond_grad_kron damps both Kronecker factors using `eta` and `lambd`, and
# update_lambda adapts `lambd` after every step from the actual vs. predicted loss reduction.
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
# Clip the preconditioned step to a norm of 0.1*sqrt(total number of weights).
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

# Damping state. precond_grad_kron reads `eta` and `lambd` as notebook globals.
lambd = 1       # adaptive damping strength
omega = 0.001   # factor update_lambda uses to rescale lambd
eta = 1e-3      # fixed damping floor (the damping term is eta + sqrt(lambd))

for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        # create_graph=True so the gradients can be differentiated again (Hessian-vector products).
        grads = grad(loss, Ws, create_graph=True)

        # Accumulate a detached value. Summing the graph-attached `loss` would keep every batch's
        # autograd graph (retained because of create_graph=True) alive until the epoch ends.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

            # Adapt the damping after every step. M is the smallest per-layer predicted reduction
            # 0.5*step_size*<g, Pg>.
            # NOTE(review): M ignores step_adjust, i.e. it describes the unclipped step.
            M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
            loss2 = F.nll_loss(LeNet5(data), target)
            lambd = update_lambda(loss, loss2, M, lambd, omega)

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay the step size by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'Kron_damped.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 1; train loss: 0.14737161339473115; test loss: 0.0668679118156433, train_accuracy: 0.9569396321961621, test_accuracy:0.978700065612793, time: 28.99625849723816
Epoch: 2; train loss: 0.043970465914272804; test loss: 0.052971625328063966, train_accuracy: 0.9860907515991472, test_accuracy:0.98410005569458, time: 57.191593647003174
Epoch: 3; train loss: 0.018941096405484782; test loss: 0.042642155289649965, train_accuracy: 0.9940531716417911, test_accuracy:0.987700080871582, time: 85.8046703338623
Epoch: 4; train loss: 0.005727907233655072; test loss: 0.04124319553375244, train_accuracy: 0.9982675906183369, test_accuracy:0.9897001266479493, time: 114.23541712760925
Epoch: 5; train loss: 0.0012513507149621112; test loss: 0.038233903050422666, train_accuracy: 0.9996668443496801, test_accuracy:0.9908000946044921, time: 142.741577386856

In [None]:
# Diagnostic after a damped run: per-layer predicted loss reduction 0.5*step_size*<g, Pg> for the
# last batch (uses the current, already-decayed step_size), and the current damping `lambd`.
# Both values are displayed. In the original cell the bare `M` line was not the last expression,
# so it was never shown.
M = [0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)]
M, lambd

0.0

# PSGD Fisher

## Fisher Full Kronecker

In [None]:
# PSGD, Fisher variant with a full Kronecker preconditioner. No Hessian-vector product is
# computed: psgd.update_precond_kron receives (v, g + damping*v), where g is the minibatch
# gradient and v is a random probe, in place of (v, Hv).
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.005
damping = 0.0005
grad_norm_clip_thr = 1e10    # so large that step_adjust is effectively always 1 (no clipping)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)


for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        grads = grad(loss, Ws)
        # Detach so the running sum does not chain every batch's autograd graph together.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grads  # Fisher variant: the gradient stands in for the Hessian-vector product
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg + damping*dw) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()
    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay the step size by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: "Epoch: 0" is the pre-training entry, so the first trained epoch is 1.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'fisher_kron.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.29857711954665844; test loss: 0.06439159512519836, train_accuracy: 0.9117137526652452, test_accuracy:0.9769001007080078, time: 17.96696448326111
Epoch: 1; train loss: 0.04710177927891583; test loss: 0.0662577509880066, train_accuracy: 0.9855743603411514, test_accuracy:0.9792000770568847, time: 35.82441282272339
Epoch: 2; train loss: 0.0250717195620669; test loss: 0.03892034888267517, train_accuracy: 0.9921875, test_accuracy:0.9879000663757325, time: 53.63473677635193
Epoch: 3; train loss: 0.013934312344614123; test loss: 0.04280449450016022, train_accuracy: 0.9954524253731343, test_accuracy:0.9885000228881836, time: 71.48122811317444
Epoch: 4; train loss: 0.006625388476894354; test loss: 0.04239860773086548, train_accuracy: 0.9979510927505331, test_accuracy:0.9883999824523926, time: 89.32672452926636
Epoch: 5; t

## Fisher SCAN

In [None]:
# PSGD, Fisher variant with the SCAN preconditioner. The left factor starts as a 2 x rows
# matrix (a row of ones stacked on a row of zeros) and the right factor as a 1 x cols row of
# ones. psgd.update_precond_scan receives (v, g + damping*v) in place of (v, Hv).
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.cat([torch.ones((1, W.shape[0])), torch.zeros((1, W.shape[0]))]).to(device),
       torch.ones((1, W.shape[1])).to(device)] for W in Ws]
step_size = 0.005
damping = 0.0005
grad_norm_clip_thr = 1e10    # so large that step_adjust is effectively always 1 (no clipping)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)


for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        grads = grad(loss, Ws)
        # Detach so the running sum does not chain every batch's autograd graph together.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grads  # Fisher variant: the gradient stands in for the Hessian-vector product
        with torch.no_grad():
            Qs = [psgd.update_precond_scan(q[0], q[1], dw, dg + damping*dw) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_scan(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()
    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay the step size by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: "Epoch: 0" is the pre-training entry, so the first trained epoch is 1.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'fisher_SCAN.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.34073627528859607; test loss: 0.1019331693649292, train_accuracy: 0.8989872068230277, test_accuracy:0.9692000389099121, time: 21.15277099609375
Epoch: 1; train loss: 0.07207591498076026; test loss: 0.0619676947593689, train_accuracy: 0.9783615405117271, test_accuracy:0.9795000076293945, time: 42.306801319122314
Epoch: 2; train loss: 0.046426915410739275; test loss: 0.060597938299179074, train_accuracy: 0.9859408315565032, test_accuracy:0.9821000099182129, time: 63.613298177719116
Epoch: 3; train loss: 0.029446372091134727; test loss: 0.04941503703594208, train_accuracy: 0.9908382196162047, test_accuracy:0.9847000122070313, time: 85.37881517410278
Epoch: 4; train loss: 0.016976590349730144; test loss: 0.04190981090068817, train_accuracy: 0.9948360874200426, test_accuracy:0.9883001327514649, time: 106.900672912597

## Fisher SCAW

In [None]:
# PSGD, Fisher variant with the SCAW preconditioner. The left factor starts as a rows x rows
# identity and the right factor as a 1 x cols row of ones. psgd.update_precond_scaw receives
# (v, g + damping*v) in place of (v, Hv).
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.ones((1, W.shape[1])).to(device)] for W in Ws]
step_size = 0.005
damping = 0.0005
grad_norm_clip_thr = 1e10    # so large that step_adjust is effectively always 1 (no clipping)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)


for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        grads = grad(loss, Ws)
        # Detach so the running sum does not chain every batch's autograd graph together.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grads  # Fisher variant: the gradient stands in for the Hessian-vector product
        with torch.no_grad():
            Qs = [psgd.update_precond_scaw(q[0], q[1], dw, dg + damping*dw) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_scaw(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()
    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay the step size by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: "Epoch: 0" is the pre-training entry, so the first trained epoch is 1.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'fisher_SCAW.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.31194726744694495; test loss: 0.08530771136283874, train_accuracy: 0.9076159381663113, test_accuracy:0.9737001419067383, time: 18.02068257331848
Epoch: 1; train loss: 0.05509298035839219; test loss: 0.0724665641784668, train_accuracy: 0.9831090085287847, test_accuracy:0.9769000053405762, time: 35.50975060462952
Epoch: 2; train loss: 0.030611717370527387; test loss: 0.04809662401676178, train_accuracy: 0.990221881663113, test_accuracy:0.9851000785827637, time: 52.87089276313782
Epoch: 3; train loss: 0.016857143149955442; test loss: 0.043058812618255615, train_accuracy: 0.9948860607675906, test_accuracy:0.9871000289916992, time: 70.34986805915833
Epoch: 4; train loss: 0.00823022307617578; test loss: 0.050789809226989745, train_accuracy: 0.9973680703624733, test_accuracy:0.9863001823425293, time: 87.80099892616272


## Fisher Kronecker with adaptive Tikhonov damping (Test 3)

In [None]:
# PSGD Fisher variant with a Kronecker preconditioner plus adaptive Tikhonov damping.
# The preconditioner is fitted on (v, g + damping*v) as in the Fisher Kron run. The gradients
# are preconditioned with the damped precond_grad_kron (which reads the globals `eta` and
# `lambd`), and `lambd` is adapted after every step by update_lambda.
torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.005
damping = 0.0005
grad_norm_clip_thr = 1e10    # so large that step_adjust is effectively always 1 (no clipping)

# Damping state. `eta` is set here explicitly rather than inherited from another cell, so the
# run does not depend on hidden kernel state.
lambd = 1       # adaptive damping strength
omega = 0.001   # factor update_lambda uses to rescale lambd
eta = 1e-3      # fixed damping floor (the damping term is eta + sqrt(lambd))

TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)


for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        grads = grad(loss, Ws)
        # Detach so the running sum does not chain every batch's autograd graph together.
        trainloss += loss.detach()
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grads  # Fisher variant: the gradient stands in for the Hessian-vector product
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg + damping*dw) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

            # Adapt the damping after every step. M is the smallest per-layer predicted reduction
            # 0.5*step_size*<g, Pg>.
            M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
            loss2 = F.nll_loss(LeNet5(data), target)
            lambd = update_lambda(loss, loss2, M, lambd, omega)

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()
    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Decay the step size by 100x every 9 epochs.
    step_size = 0.01**(1/9)*step_size
    # epoch+1: "Epoch: 0" is the pre-training entry, so the first trained epoch is 1.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'fisher_kron_damped.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.2955695764342351; test loss: 0.07428519725799561, train_accuracy: 0.9138626066098081, test_accuracy:0.976400089263916, time: 22.939367532730103
Epoch: 1; train loss: 0.04860029240915262; test loss: 0.059427326917648314, train_accuracy: 0.9845582356076759, test_accuracy:0.9801000595092774, time: 45.69829273223877
Epoch: 2; train loss: 0.02596639214294043; test loss: 0.04786511957645416, train_accuracy: 0.992254131130064, test_accuracy:0.986299991607666, time: 68.63699674606323
Epoch: 3; train loss: 0.014240892202869407; test loss: 0.047960597276687625, train_accuracy: 0.9956689765458422, test_accuracy:0.986600112915039, time: 91.54815554618835
Epoch: 4; train loss: 0.00714082148537707; test loss: 0.046179157495498654, train_accuracy: 0.9977511993603412, test_accuracy:0.9874000549316406, time: 114.56241106987
Epoc

# Others

## KFAC

In [None]:
# K-FAC baseline setup: reseed, then import the module-based tooling this run needs.
torch.manual_seed(1)
# NOTE(review): Ws is not used by the K-FAC model below. Presumably it is kept so that the RNG
# state after seeding matches the other runs (assuming initialize_weights draws from torch's
# RNG). Confirm before removing, because dropping it would change LeNet5_K's initial weights.
Ws = initialize_weights()
# KFAC comes from the external `kfac` package (not shown in this notebook section).
from kfac import KFAC
import torch.nn as nn
import torch.optim as optim

class LeNet5_K(nn.Module):
    """LeNet-5 with tanh activations as an ``nn.Module``, for use with the KFAC preconditioner.

    Two conv(5x5) + max-pool(2) stages followed by three fully connected layers. The
    ``view(-1, 256)`` (16 channels x 4 x 4) implies 1 x 28 x 28 inputs.
    Returns per-class log-probabilities (``log_softmax`` over dim 1), as expected by ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # torch.tanh instead of the deprecated F.tanh (numerically identical).
        x = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        x = F.max_pool2d(torch.tanh(self.conv2(x)), 2)
        x = x.view(-1, 256)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

def test_loss_K(model):
    """Mean NLL loss and accuracy of ``model`` over ``test_loader``.

    Each batch is moved to the device of the model's parameters. This is the CPU in the K-FAC
    run, but it also works if the model is moved to a GPU.

    Returns:
        (loss, accuracy): both are sums of per-batch means divided by the global
        ``n_test_batches``. ``loss`` is a Python float; ``accuracy`` is a 0-d tensor.
    """
    model.eval()
    model_device = next(model.parameters()).device
    loss = 0.0
    accuracy = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(model_device), target.to(model_device)
            y = model(data)
            loss += F.nll_loss(y, target)
            _, pred = torch.max(y, dim=1)
            accuracy += (pred == target).sum(dtype=torch.float32)/pred.size(0)
    # float() also works when the loader yielded no batches (loss is still the float 0.0).
    return float(loss)/n_test_batches, accuracy/n_test_batches

# K-FAC baseline: plain SGD, with preconditioner.step() called between backward() and
# optimizer.step() (presumably rewriting the parameters' .grad in place; see the kfac package).
# NOTE(review): unlike the PSGD runs, nothing here is moved to `device`, so this run trains on the
# CPU and its wall-clock times are not directly comparable. To use the GPU, call model.to(device)
# and move each batch with .to(device).
model = LeNet5_K()
preconditioner = KFAC(model, 0.001, alpha=0.05)
lr0 = 0.01
optimizer = optim.SGD(model.parameters(), lr=lr0)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)
for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    model.train()
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)

        loss = F.nll_loss(output, target)
        _, max_ind = torch.max(output, dim = 1)
        accuracy = (max_ind == target).sum(dtype=torch.float32)/max_ind.size(0)

        # Detach so the running sum does not chain every batch's autograd graph together.
        trainloss += loss.detach()
        trainacc += accuracy

        loss.backward()
        preconditioner.step()
        optimizer.step()

    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    # Decay the learning rate by 100x every 9 epochs (same schedule as the PSGD runs).
    lr0 = 0.01**(1/9)*lr0
    optimizer.param_groups[0]['lr'] = lr0
    testloss, testacc = test_loss_K(model)

    TestLoss.append(testloss)
    TestAcc.append(testacc)

    # epoch+1: "Epoch: 0" is the pre-training entry, so the first trained epoch is 1.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'KFAC.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4626680874367004; test loss: 2.4614776611328124, train_accuracy: 0.07027918443496801, test_accuracy:0.07090001106262207, time: 0



nn.functional.tanh is deprecated. Use torch.tanh instead.



Epoch: 0; train loss: 0.110662456260307; test loss: 0.047950685024261475, train_accuracy: 0.9682169509594882, test_accuracy:0.9842000007629395, time: 38.568052530288696
Epoch: 1; train loss: 0.03361264373193672; test loss: 0.03781397938728333, train_accuracy: 0.9897887793176973, test_accuracy:0.987500011920929, time: 76.72868132591248
Epoch: 2; train loss: 0.015551872090744311; test loss: 0.03592078387737274, train_accuracy: 0.996202025586354, test_accuracy:0.9890000224113464, time: 115.49309253692627
Epoch: 3; train loss: 0.008015484698037348; test loss: 0.03478087186813354, train_accuracy: 0.9989339019189766, test_accuracy:0.9891999959945679, time: 153.78922843933105
Epoch: 4; train loss: 0.005590427658959492; test loss: 0.033990001678466795, train_accuracy: 0.9994502931769723, test_accuracy:0.9896000027656555, time: 192.4853377342224
Epoch: 5; train loss: 0.00468602821008483; test loss: 0.033995941281318665, train_accuracy: 0.9996335287846482, test_accuracy:0.9901999235153198, time:

## Shampoo

In [None]:
torch.manual_seed(1)
Ws = initialize_weights()

# Shampoo's left (rows x rows) and right (cols x cols) statistics start at 0.01*I, so the first
# inverse fourth roots are finite.
Qs = []
for W in Ws:
    Qs.append([
        0.01*torch.eye(W.shape[0]).to(device),
        0.01*torch.eye(W.shape[1]).to(device),
    ])

step_size = 0.5
# Clip the preconditioned step to a norm of 0.1*sqrt(total number of weights).
n_weights = sum(W.shape[0]*W.shape[1] for W in Ws)
grad_norm_clip_thr = 0.1*n_weights**0.5

TrainLoss, TestLoss, TrainAcc, TestAcc, times = [], [], [], [], []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

def matrix_power(matrix, power):
    """Raise a symmetric positive-definite matrix to a real ``power`` via SVD.

    For SPD inputs (such as the Shampoo statistics 0.01*I + sum G G^T), the SVD coincides with
    the eigendecomposition, so ``u @ diag(s**power) @ v.T`` is the matrix power.
    The SVD is computed on the CPU (the original author's speed choice), and the result is
    returned on the input's device. Previously it was always sent to CUDA, which failed on
    CPU-only machines.
    """
    device_in = matrix.device
    u, s, v = torch.svd(matrix.cpu())
    return (u @ s.pow(power).diag() @ v.t()).to(device_in)

# Shampoo: accumulate left (G G^T) and right (G^T G) gradient statistics per weight matrix and
# precondition with their inverse fourth roots.
for epoch in range(EPOCHS):
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)

        # First-order gradients are enough: Shampoo uses no Hessian-vector products, so
        # create_graph=True would only keep extra graph memory alive and cost time.
        grads = grad(loss, Ws)

        # Detach so the running sum does not chain every batch's autograd graph together.
        trainloss += loss.detach()
        trainacc += accuracy

        with torch.no_grad():
            Qs = [[q[0] + g.mm(g.t()), q[1] + (g.t()).mm(g)] for (q, g) in zip(Qs, grads)]
            inv_Qs = [[matrix_power(q[0], -1/4), matrix_power(q[1], -1/4)] for q in Qs]
            pre_grads = [q[0].mm(g).mm(q[1]) for (q, g) in zip(inv_Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    # Shampoo decays its step size by only 10x every 9 epochs (the PSGD runs use 0.01**(1/9)).
    step_size = 0.1**(1/9)*step_size
    # epoch+1: "Epoch: 0" is the pre-training entry, so the first trained epoch is 1.
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'
          .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1], np.sum(times)))

scipy.io.savemat(results_dir + 'shampoo.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.4627774041344614; test loss: 2.4614776611328124, train_accuracy: 0.07026252665245203, test_accuracy:0.07089999914169312, time: 0
Epoch: 0; train loss: 0.1179326754897388; test loss: 0.08105806112289429, train_accuracy: 0.9643690031982942, test_accuracy:0.9734999656677246, time: 33.79210376739502
Epoch: 1; train loss: 0.03233122673116005; test loss: 0.03650714457035065, train_accuracy: 0.9906549840085288, test_accuracy:0.9873000144958496, time: 67.76992774009705
Epoch: 2; train loss: 0.017851231703117712; test loss: 0.032599040865898134, train_accuracy: 0.9957855810234542, test_accuracy:0.9886000633239747, time: 102.03499507904053
Epoch: 3; train loss: 0.011769646520553621; test loss: 0.03319661021232605, train_accuracy: 0.9976512526652452, test_accuracy:0.9888999938964844, time: 136.12489485740662
Epoch: 4; train loss: 0.008877682024990318; test loss: 0.03210409581661224, train_accuracy: 0.9985007995735607, test_accuracy:0.9897000312805175, time: 170.11892533302

# Comparison

In [None]:
# Optimizer runs to compare. Each name matches a '<name>.mat' file written to results_dir by a
# training cell.
opts = ['adam', 'Kron', 'SCAN', 'SCAW',
        'fisher_kron', 'fisher_SCAN', 'fisher_SCAW',
        'Kron_test', 'Kron_test2', 'Kron_damped', 'fisher_kron_damped',
        'KFAC', 'shampoo']

# Per-optimizer containers, keyed by optimizer name.
total_train_time, opts_data, times = {}, {}, {}
train_times, test_times = {}, {}
train_losses, test_losses = {}, {}
train_accs, test_accs = {}, {}

for opt in opts:
    opts_data[opt] = scipy.io.loadmat(f'{results_dir}{opt}.mat')

In [None]:
# One line color per optimizer. The plot helpers take colors[i] for the i-th entry of `opts`,
# so this list must be at least as long as `opts`.
colors = [
    '#0000FF', '#00FF00', '#FF0000', '#33F0FF', '#FFA833',
    '#FFF933', '#000000', '#33E0FF', '#FF33E6', '#D433FF',
    '#888A0B', '#8A0B1E', '#B498DF', '#1B786D',
]

In [None]:
# Unpack each optimizer's saved record into flat per-epoch series. Each series has EPOCHS+1
# entries: the pre-training entry plus one per trained epoch.
for opt in opts:
    data = opts_data[opt]
    times[opt] = data.get('Time')
    # Cumulative wall-clock time is the x axis of both the train and the test "vs Time" plots
    # (two independent arrays, one per dict).
    train_times[opt] = np.cumsum(times[opt])
    test_times[opt] = np.cumsum(times[opt])
    total_train_time[opt] = np.sum(times[opt])
    for key, store in (('TrainLoss', train_losses),
                       ('TrainAccuracy', train_accs),
                       ('TestLoss', test_losses),
                       ('TestAccuracy', test_accs)):
        store[opt] = data.get(key).reshape(EPOCHS + 1)


In [None]:
# Loss curves (log-scale y axis) against the epoch index...
plot_loss_metrics(xaxis=None, yaxis=train_losses,
                  title='Train Loss vs EPOCHS', x_label='EPOCHS', y_label='Train Loss')
plot_loss_metrics(xaxis=None, yaxis=test_losses,
                  title='Test Loss vs EPOCHS', x_label='EPOCHS', y_label='Test Loss')
# ...and against cumulative wall-clock training time.
plot_loss_metrics(xaxis=train_times, yaxis=train_losses,
                  title='Train Loss vs Time', x_label='Time', y_label='Train Loss')
plot_loss_metrics(xaxis=test_times, yaxis=test_losses,
                  title='Test Loss vs Time', x_label='Time', y_label='Test Loss')

In [None]:
# Accuracy curves against the epoch index and against cumulative wall-clock time.
# These use plot_acc_metrics, the helper meant for accuracies (linear y axis zoomed to
# [0.97, 1]), rather than the log-scale plot_loss_metrics.
plot_acc_metrics(None, train_accs, 'Train Accuracy vs EPOCHS', 'EPOCHS', 'Train Accuracy')
plot_acc_metrics(None, test_accs, 'Test Accuracy vs EPOCHS', 'EPOCHS', 'Test Accuracy')
plot_acc_metrics(train_times, train_accs, 'Train Accuracy vs Time', 'Time', 'Train Accuracy')
plot_acc_metrics(test_times, test_accs, 'Test Accuracy vs Time', 'Time', 'Test Accuracy')