# Setup

In [1]:
!pip install pkbar

Collecting pkbar
  Downloading https://files.pythonhosted.org/packages/95/8f/28e0a21b27f836a8903315050db17dd68e55bf477b6fde52d1c68da3c8a6/pkbar-0.5-py3-none-any.whl
Installing collected packages: pkbar
Successfully installed pkbar-0.5


In [2]:
from google.colab import drive
from google.colab import files
import sys
import time

drive.mount('/content/gdrive/', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'Colab Notebooks/MS Thesis/PSGD Paper/'
dpsgd_dir = base_dir + 'DPSGD/'
results_dir = base_dir + 'CNN/results_fashion/'
logs_dir = base_dir + 'log'
sys.path.append(base_dir)
sys.path.append(dpsgd_dir)

from densekron import dense_kron
from scankron import scan_kron
from rahkron import rah_kron
from scawkron import scaw_kron
from Shampoo import Shampoo
import preconditioned_stochastic_gradient_descent as psgd

Mounted at /content/gdrive/


In [3]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
import time
import math

import pkbar
from tqdm import tqdm
from tabulate import tabulate
import scipy.io
from sklearn import metrics
import plotly.express as px
from torchsummary import summary
import torch.nn as nn
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.get_device_name(0)


'Tesla P100-PCIE-16GB'

# Functions

In [4]:
def plot_loss_metrics(xaxis,yaxis,title, x_label,y_label):

    fig = go.Figure()
    i = 0
    if(xaxis != None):
        for opt in opts:
            fig.add_trace(go.Scatter(x = xaxis[opt], y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1
    else:
        for opt in opts:
            fig.add_trace(go.Scatter(y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1

    fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label, yaxis_type="log")
    fig.show()
    fig.write_html(results_dir + title + ".html")

def plot_acc_metrics(xaxis,yaxis,title, x_label,y_label):
 
    fig = go.Figure()
    i = 0
    if(xaxis != None):
        for opt in opts:
            fig.add_trace(go.Scatter(x = xaxis[opt], y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1
    else:
        for opt in opts:
            fig.add_trace(go.Scatter(y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1

    fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label, yaxis=dict(range=[0.97, 1]))
    fig.show()
    fig.write_html(results_dir + title + ".html")


def update_lambda(loss1, loss2, M, lambd, omega):
    
    r = abs(loss2 - loss1)/(M)
    if r > 3/4:
        lambd = lambd*omega
    elif r < 1/4:
        lambd = lambd / omega
    return lambd
    

In [5]:
np.random.seed(0)

# Parameter Settings
BATCH_SIZE = 64
test_BATCH_SIZE = 1000
EPOCHS = 20
GAP = 100

# Data Download

In [6]:
train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('./data', train=True, download=True,           
                       transform=transforms.Compose([                       
                               transforms.ToTensor()])),    
                        batch_size=BATCH_SIZE, shuffle=True, num_workers = 4, pin_memory = True)
test_loader = torch.utils.data.DataLoader(    
        datasets.FashionMNIST('./data', train=False, transform=transforms.Compose([
                       transforms.ToTensor()])),    
                        batch_size=test_BATCH_SIZE, shuffle=True, num_workers=4, pin_memory = True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!





This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



In [7]:
n_batches = len(train_loader)
n_test_batches = len(test_loader)

# Model

In [8]:
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

# Loss Function

In [9]:
def train_loss(data, target):
    y = model(data)
    loss = F.nll_loss(y, target)
    _, max_indices = torch.max(y, dim = 1)
    accuracy = (max_indices == target).sum(dtype=torch.float32)/max_indices.size(0)
    return loss, accuracy

def test_loss():
    loss = 0
    accuracy = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            y = model(data)
            loss += F.nll_loss(y, target)
            _, pred = torch.max(y, dim=1)
            accuracy += (pred == target).sum(dtype=torch.float32)/pred.size(0)
    return loss.item()/n_test_batches, accuracy.item()/n_test_batches
            
    return loss.item()/n_test_batches

def test_loss_K(model):
    model.eval()
    loss = 0
    accuracy = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            y = model(data)
            loss += F.nll_loss(y, target)
            _, pred = torch.max(y, dim=1)
            accuracy += (pred == target).sum(dtype=torch.float32)/pred.size(0)
    return loss.item()/n_test_batches, accuracy/n_test_batches


def save_start_condition(trainlosslist, testlosslist,trainacclist, testacclist, timelist):
    trainloss = 0.0
    trainacc = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
      data, target = data.to(device), target.to(device)
      loss, accuracy = train_loss(data, target)
      trainloss += loss
      trainacc += accuracy
      
  
    timelist.append(0)

    testloss, testacc = test_loss()

    trainlosslist.append(trainloss.item()/n_batches)
    trainacclist.append(trainacc.item()/n_batches)
    testlosslist.append(testloss)
    testacclist.append(testacc)
    print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
    .format(0, trainlosslist[-1], testlosslist[-1], trainacclist[-1], testacclist[-1],np.sum(timelist)))

# SGD

In [10]:
torch.manual_seed(1)

model = LeNet5().to(device)
lr0 = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=lr0)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
with torch.no_grad():
  save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    trainacc = 0.0
    n = 0
    model.train()
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        
        loss = F.nll_loss(output, target)
        _, max_ind = torch.max(output, dim = 1)
        accuracy = (max_ind == target).sum(dtype=torch.float32)/max_ind.size(0) 

        trainloss += loss
        trainacc += accuracy

        loss.backward()
        optimizer.step()

        kbar.update(n, values=[("loss", loss.item()), ("acc", accuracy.item())])
        n += 1
        
    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)
    
    
    # lr0 = 0.01**(1/9)*lr0
    optimizer.param_groups[0]['lr'] = lr0
    testloss, testacc = test_loss_K(model)

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    kbar.add(1, values=[("val_loss", testloss), ("val_acc", testacc)])
  
scipy.io.savemat(results_dir + 'sgd.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Epoch: 0; train loss: 2.3051812470848883; test loss: 2.305193328857422, train_accuracy: 0.0999966684434968, test_accuracy:0.10000001192092896, time: 0
Epoch: 1/20
Epoch: 2/20
Epoch: 3/20
Epoch: 4/20
Epoch: 5/20
Epoch: 6/20
Epoch: 7/20
Epoch: 8/20
Epoch: 9/20
Epoch: 10/20
Epoch: 11/20
Epoch: 12/20
Epoch: 13/20
Epoch: 14/20
Epoch: 15/20
Epoch: 16/20
Epoch: 17/20
Epoch: 18/20
Epoch: 19/20
Epoch: 20/20


# Adam

In [11]:
torch.manual_seed(1)

model = LeNet5().to(device)
lr0 = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr0)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
with torch.no_grad():
  save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    trainacc = 0.0
    n = 0
    model.train()
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        
        loss = F.nll_loss(output, target)
        _, max_ind = torch.max(output, dim = 1)
        accuracy = (max_ind == target).sum(dtype=torch.float32)/max_ind.size(0) 

        trainloss += loss
        trainacc += accuracy

        loss.backward()
        optimizer.step()

        kbar.update(n, values=[("loss", loss.item()), ("acc", accuracy.item())])
        n += 1
        
    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)
    
    
    # lr0 = 0.01**(1/9)*lr0
    optimizer.param_groups[0]['lr'] = lr0
    testloss, testacc = test_loss_K(model)

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    kbar.add(1, values=[("val_loss", testloss), ("val_acc", testacc)])
  
scipy.io.savemat(results_dir + 'adam.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Epoch: 0; train loss: 2.3051812470848883; test loss: 2.305193328857422, train_accuracy: 0.0999966684434968, test_accuracy:0.10000001192092896, time: 0
Epoch: 1/20
Epoch: 2/20
Epoch: 3/20
Epoch: 4/20
Epoch: 5/20
Epoch: 6/20
Epoch: 7/20
Epoch: 8/20
Epoch: 9/20
Epoch: 11/20
Epoch: 12/20
Epoch: 13/20
Epoch: 14/20
Epoch: 15/20
Epoch: 16/20
Epoch: 17/20
Epoch: 18/20
Epoch: 19/20
Epoch: 20/20


# PSGD

In [15]:
def training(model, epochs, step_size, grad_norm_clip_thr, lambd, omega, dpsgd, update_after, use_damping, file_name, T1 = 10):

    # Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
    Qs = [dpsgd.initialize_preconditioner(W) for W in model.parameters()]
    TrainLoss, TestLoss = [], []
    TrainAcc, TestAcc = [], []
    times = []
    with torch.no_grad():
        save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

    n = 0
    for epoch in range(epochs):
        n = 0
        kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
        trainloss = 0.0
        trainacc = 0.0
        t0 = time.time()
        for (data, target) in train_loader:

            data, target = data.to(device), target.to(device)
            loss, accuracy = train_loss(data, target)
            
            grads = grad(loss, model.parameters(), create_graph=True)
            trainloss += loss
            trainacc += accuracy
            if n % T1 == 0:
              v = [torch.randn(W.shape).to(device) for W in model.parameters()]
              Hv =  grad(grads, model.parameters(), v)
            
            with torch.no_grad():
                Qs = [dpsgd.update_preconditioner(q, dX, dG, step = 0.01) for (q, dX, dG) in zip(Qs, v, Hv)]
                pre_grads = [dpsgd.precondition_grads(q, g, lambd) for (q,g) in zip(Qs, grads)]
                grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
                step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)

                for (W,pG) in zip(model.parameters(), pre_grads):
                    W.data -= step_adjust*step_size*pG
                    
                if n % update_after == 0 and use_damping and lambd > 1e-10:
                    M = min([0.5*torch.dot(g.reshape(-1,), step_size*pg.reshape(-1,)) for (g, pg) in zip(grads, pre_grads)])
                    loss2 = F.nll_loss(model(data), target)
                    loss1 = loss
                    lambd = update_lambda(loss1, loss2, M,  lambd, omega)
                    
            kbar.update(n, values=[("loss", loss.item()), ("acc", accuracy.item())])    
            n = n + 1

        t1 = time.time() - t0
        times.append(t1)
        # step_size = 0.01**(1/9)*step_size
        TrainLoss.append(trainloss.item()/n_batches)
        TrainAcc.append(trainacc.item()/n_batches)
        testloss, testacc = test_loss()

        kbar.add(1, values=[("val_loss", testloss), ("val_acc", testacc)])
        
        TestLoss.append(testloss)

    scipy.io.savemat(results_dir + file_name, {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

## Dense Kron 

In [16]:
torch.manual_seed(1)
model = LeNet5().to(device)
use_damping = False
kron = dense_kron(model.parameters(), use_4D = False, use_1D = True, use_damping = use_damping, lambd = 150, eta = 1e-5)

grad_norm_clip_thr = 0.1*sum(W.numel() for W in model.parameters())**0.5
lambd = 1
omega = 0.5

training(model, 
         epochs = EPOCHS, 
         step_size = 0.1, 
         grad_norm_clip_thr = grad_norm_clip_thr,
         lambd = lambd,
         omega = omega, 
         dpsgd = kron,
         update_after = 1,
         use_damping = use_damping, 
         file_name = "kron.mat"
         )


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Epoch: 0; train loss: 2.3051812470848883; test loss: 2.305193328857422, train_accuracy: 0.0999966684434968, test_accuracy:0.10000001192092896, time: 0
Epoch: 1/20
Epoch: 2/20
Epoch: 3/20
Epoch: 4/20
Epoch: 5/20
Epoch: 6/20
Epoch: 7/20
Epoch: 8/20
Epoch: 9/20
Epoch: 10/20
Epoch: 11/20
Epoch: 12/20
Epoch: 13/20
Epoch: 14/20
Epoch: 15/20
Epoch: 16/20
Epoch: 17/20
Epoch: 18/20
Epoch: 19/20
Epoch: 20/20


## DAMPED DENSE KRON

In [None]:
torch.manual_seed(1)
model = LeNet5().to(device)
use_damping = True
kron = dense_kron(model.parameters(), use_4D = False, use_1D = True, use_damping = use_damping, lambd = 150, eta = 1e-5)

grad_norm_clip_thr = 0.1*sum(W.numel() for W in model.parameters())**0.5
lambd = 1
update_after = 5
omega = (19/20)**update_after

training(model, 
         epochs = EPOCHS, 
         step_size = 0.1, 
         grad_norm_clip_thr = grad_norm_clip_thr,
         lambd = lambd,
         omega = omega, 
         dpsgd = kron,
         update_after = update_after,
         use_damping = use_damping, 
         file_name = "kron_damped.mat",
         T1 = 1)


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



Epoch: 0; train loss: 2.3051812470848883; test loss: 2.305193328857422, train_accuracy: 0.0999966684434968, test_accuracy:0.10000001192092896, time: 0
Epoch: 1/20
Epoch: 2/20
Epoch: 3/20
Epoch: 4/20
Epoch: 5/20
Epoch: 6/20
Epoch: 7/20
Epoch: 8/20
Epoch: 9/20
Epoch: 10/20
Epoch: 11/20
Epoch: 12/20
Epoch: 13/20


In [None]:
def precond_kron(Ql, Qr, Pl, Pr, beta):
    P1 = Ql.t().mm(Ql)
    P2 = Qr.t().mm(Qr)
    pi = (torch.trace(P1)*P2.shape[0])/(torch.trace(P2)*P1.shape[0])
    IL = torch.eye(P1.shape[0]).to(device)
    IR = (torch.eye(P2.shape[0])).to(device)
    P1 = P1 + torch.sqrt((pi)*(eta + lambd**0.5)*IL)
    P2 = P2 + torch.sqrt((1/pi)*(eta + lambd**0.5)*IR)

    Pl = beta*Pl + (1-beta)*P1 
    Pr = beta*Pr + (1-beta)*P2 

    return [P1, P2, Pl, Pr]

def precond_kron2(Ql, Qr, Pl, Pr, beta):
    P1 = Ql.t().mm(Ql)
    P2 = Qr.t().mm(Qr)
    Pl = beta*Pl + (1-beta)*P1 
    Pr = beta*Pr + (1-beta)*P2 
    return [P1, P2, Pl, Pr]

def precond_grad_kron2(Pl, Pr, Grad):
    return Pl.mm(Grad).mm(Pr)

torch.manual_seed(1)
Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
Ps = [[torch.zeros(W.shape[0]).to(device), torch.zeros(W.shape[1]).to(device)] for W in Ws]
step_size = 0.05
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

lambd = 1
update_after = 1
omega = (19/20)**update_after
eta = 1e-5
beta = 0.9

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    n = 0
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
   
    for batch_idx, (data, target) in enumerate(train_loader):
      
        data, target = data.to(device), target.to(device)
        loss, accuracy = train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss
        trainacc += accuracy

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            beta = min(n/(n+1), 0.9)
            Ps = [precond_kron(q[0], q[1], p[0], p[1], beta) for (q, p) in zip(Qs, Ps)]
            pre_grads = [precond_grad_kron2(p[2], p[3], g) for (p, g) in zip(Ps, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]
            if n % 1 == 0 and lambd > 1e-10:
                M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
                loss2 = F.nll_loss(LeNet5(data), target)
                loss1 = loss
                lambd = update_lambda(loss1, loss2, M, lambd, omega)
        kbar.update(n, values=[("loss", loss.item()), ("acc", accuracy.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)

    testloss, testacc = test_loss()

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    kbar.add(1, values=[("val_loss", testloss), ("val_acc", testacc)])
    # step_size = 0.01**(1/9)*step_size
    # print('Epoch: {}; train loss: {}; test loss: {}, train_accuracy: {}, test_accuracy:{}, time: {}'\
    #  .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'kron_damped_mod.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

# KFAC

In [None]:
torch.manual_seed(1)
from kfac import KFAC

model = LeNet5().to(device)
preconditioner = KFAC(model, 0.001, alpha=0.05)
lr0 = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr0)
TrainLoss, TestLoss = [], []
TrainAcc, TestAcc = [], []
times = []
with torch.no_grad():
  save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    trainacc = 0.0
    n = 0
    model.train()
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        
        loss = F.nll_loss(output, target)
        _, max_ind = torch.max(output, dim = 1)
        accuracy = (max_ind == target).sum(dtype=torch.float32)/max_ind.size(0) 

        trainloss += loss
        trainacc += accuracy

        loss.backward()
        preconditioner.step()
        optimizer.step()

        kbar.update(n, values=[("loss", loss.item()), ("acc", accuracy.item())])
        n += 1
        
    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    TrainAcc.append(trainacc.item()/n_batches)
    
    
    # lr0 = 0.01**(1/9)*lr0
    optimizer.param_groups[0]['lr'] = lr0
    testloss, testacc = test_loss_K(model)

    TestLoss.append(testloss)
    TestAcc.append(testacc)
    kbar.add(1, values=[("val_loss", testloss), ("val_acc", testacc)])
   
scipy.io.savemat(results_dir + 'KFAC.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

Epoch: 0; train loss: 2.3051812470848883; test loss: 2.305193328857422, train_accuracy: 0.0999966684434968, test_accuracy:0.10000001192092896, time: 0
Epoch: 1/20
Epoch: 2/20
Epoch: 3/20
Epoch: 4/20
Epoch: 5/20
Epoch: 6/20
Epoch: 7/20
Epoch: 8/20
Epoch: 9/20
Epoch: 10/20
Epoch: 11/20
Epoch: 12/20
Epoch: 13/20
Epoch: 14/20
Epoch: 15/20
Epoch: 16/20
Epoch: 17/20
Epoch: 18/20
Epoch: 19/20
Epoch: 20/20


# Shampoo

In [None]:
from Shampoo import Shampoo
torch.manual_seed(1)
model = LeNet5().to(device)

shampoo = Shampoo(model.parameters(), use_1D = True, use_damping = False, epsilon = 0.1)
Qs = [shampoo.initialize_preconditioner(W) for W in model.parameters()]
grad_norm_clip_thr = 0.1*sum(W.numel() for W in model.parameters())**0.5
step_size = 0.1
TrainLoss, TestLoss = [], []
times = []
with torch.no_grad():
    save_start_condition(TrainLoss, TestLoss, TrainAcc, TestAcc, times)

n = 0
for epoch in range(EPOCHS):
    n = 0
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for (data, target) in train_loader:

        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)


        grads = grad(loss,model.parameters())

        trainloss += loss
        
        with torch.no_grad():
            Qs = [shampoo.update_preconditioner(q, g) for (q, g) in zip(Qs, grads)]
            pre_grads = [shampoo.precondition_grads(q, g) for (q,g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)

            for (W,pG) in zip(model.parameters(), pre_grads):
                W.data -= step_adjust*step_size*pG


        kbar.update(n, values=[("loss", loss.item())])    
        n = n + 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)

    testloss = test_loss()
    kbar.add(1, values=[("val_loss", testloss)])

    TestLoss.append(testloss)


scipy.io.savemat(results_dir + 'Shampoo.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'TrainAccuracy': TrainAcc,'TestAccuracy': TestAcc, 'Time':times})

# Comparison

In [None]:
# opts = ['adam','Kron','SCAN','SCAW', 'fisher_kron','fisher_SCAN','fisher_SCAW','Kron_test','Kron_test2','Kron_damped','fisher_kron_damped','KFAC', 'shampoo']
opts = ['SGD','adam','kron','kron_damped', 'Shampoo']

total_train_time = {}
opts_data = {}
times = {}
train_times = {}
test_times = {}
train_losses = {}
test_losses = {}
train_accs = {}
test_accs = {}


for opt in opts:
	opts_data[opt] = scipy.io.loadmat(results_dir+opt+'.mat')	

In [None]:
colors = ['#0000FF','#00FF00','#FF0000','#33F0FF','#FFA833','#FFF933','#000000','#33E0FF', '#FF33E6','#D433FF','#888A0B','#8A0B1E','#B498DF','#1B786D']
# colors = ['#0000FF','#00FF00','#FF0000','#33F0FF','#FFA833','#FFF933','#000000','#33E0FF','#FF33E6','#D433FF','#888A0B','#8A0B1E','#B498DF','#1B786D']

In [None]:
for opt in opts:
    # print(opt)
    data = opts_data[opt]
    times[opt] = data.get('Time')
    train_times[opt] = np.cumsum(times[opt])
    test_times[opt] = np.cumsum(times[opt])
    total_train_time[opt] = np.sum(times[opt])
    train_losses[opt] = data.get('TrainLoss').reshape(EPOCHS+1,)
    test_losses[opt] = data.get('TestLoss').reshape(EPOCHS+1,)


In [None]:
# plot train_losses vs Iterations
plot_loss_metrics(None,train_losses,'Train Loss vs EPOCHS', 'EPOCHS','Train Loss')
# plot test_losses vs Iterations
plot_loss_metrics(None,test_losses,'Test Loss vs EPOCHS', 'EPOCHS','Test Loss')
# # plot test_losses vs Iterations
plot_loss_metrics(train_times,train_losses,'Train Loss vs Time', 'Time','Train Loss')
# plot test_losses vs Iterations
plot_loss_metrics(test_times,test_losses,'Test Loss vs Time', 'Time','Test Loss')