# Setup

In [13]:
torch.__version__

'1.8.1+cu101'

In [1]:
!pip install pkbar

Collecting pkbar
  Downloading https://files.pythonhosted.org/packages/95/8f/28e0a21b27f836a8903315050db17dd68e55bf477b6fde52d1c68da3c8a6/pkbar-0.5-py3-none-any.whl
Installing collected packages: pkbar
Successfully installed pkbar-0.5


In [18]:
# !git clone https://github.com/gpauloski/kfac_pytorch.git
# !cd kfac_pytorch
!pip install .

Processing /content/kfac_pytorch
Building wheels for collected packages: kfac-pytorch
  Building wheel for kfac-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for kfac-pytorch: filename=kfac_pytorch-0.3.1-cp37-none-any.whl size=37245 sha256=7f81a07f3d0526e585ea38d05f25f3e2dea342af458c707c36eea76fd64598fb
  Stored in directory: /tmp/pip-ephem-wheel-cache-0uqxbdfj/wheels/ac/dc/3c/85d9891f34779445f0bb1501db581dd504920ea516cd993e29
Successfully built kfac-pytorch
Installing collected packages: kfac-pytorch
Successfully installed kfac-pytorch-0.3.1


In [None]:
from google.colab import drive
from google.colab import files
import sys
import time

drive.mount('/content/gdrive/', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'Colab Notebooks/MS Thesis/PSGD Paper/'
results_dir = base_dir + 'RNN/results_add/'
logs_dir = base_dir + 'log'
sys.path.append(base_dir)
import preconditioned_stochastic_gradient_descent as psgd 


Mounted at /content/gdrive/


In [12]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
import time
import math

import pkbar
from tqdm import tqdm
from tabulate import tabulate
import scipy.io
from sklearn import metrics
import plotly.express as px
from torchsummary import summary
import torch.nn as nn
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.get_device_name(0)


'Tesla T4'

# Functions

In [19]:
def plot_loss_metrics(xaxis,yaxis,title, x_label,y_label):

    fig = go.Figure()
    i = 0
    if(xaxis != None):
        for opt in opts:
            fig.add_trace(go.Scatter(x = xaxis[opt], y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1
    else:
        for opt in opts:
            fig.add_trace(go.Scatter(y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1

    fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label, yaxis_type="log")
    fig.show()
    fig.write_html(results_dir + title + ".html")

def plot_acc_metrics(xaxis,yaxis,title, x_label,y_label):
 
    fig = go.Figure()
    i = 0
    if(xaxis != None):
        for opt in opts:
            fig.add_trace(go.Scatter(x = xaxis[opt], y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1
    else:
        for opt in opts:
            fig.add_trace(go.Scatter(y=yaxis[opt], name = opt, mode='lines', line = dict(color = colors[i])))
            i = i + 1

    fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label, yaxis=dict(range=[0.97, 1]))
    fig.show()
    fig.write_html(results_dir + title + ".html")


def update_lambda(loss1, loss2, M, lambd, omega):
    
    r = abs(loss2 - loss1)/(M)
    if r > 3/4:
        lambd = lambd*omega
    elif r < 1/4:
        lambd = lambd / omega
    return lambd
    

In [20]:
np.random.seed(0)

# Parameter Settings
BATCH_SIZE = 100
test_BATCH_SIZE = 100
EPOCHS = 50
GAP = 100

# Data Download

In [21]:
from sklearn.model_selection import train_test_split
batch_size, seq_len0 = 100, 30
dim_in, dim_hidden, dim_out = 2, 20, 1

# generate training data for the add problem
def get_dataset(batch_size=50000):
    seq_len = round(seq_len0 + 0.1*np.random.rand()*seq_len0)
    x = torch.zeros([batch_size, seq_len, dim_in])
    y = torch.zeros([batch_size, dim_out])
    for i in range(batch_size):
        x[i,:,0] = 2.0*torch.rand(seq_len) - 1.0
        while True:
            i1, i2 = list(np.floor(np.random.rand(2)*seq_len/2).astype(int))
            if i1 != i2:
                break
        x[i, i1, 1] = 1.0
        x[i, i2, 1] = 1.0
        y[i] = 0.5*(x[i,i1,0] + x[i,i2,0])
    # tranpose x to dimensions: sequence_length * batch_size * dimension_input 
#     return torch.transpose(x, 1,0).to(device), y.to(device)
    return x, y

X, y = get_dataset()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)


train_dataset = torch.utils.data.TensorDataset(torch.Tensor(X_train),torch.Tensor(y_train))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers=4)

test_dataset = torch.utils.data.TensorDataset(torch.Tensor(X_test),torch.Tensor(y_test))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = test_BATCH_SIZE, shuffle = True, num_workers=4)


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



In [22]:
n_batches = len(train_loader)
n_test_batches = len(test_loader)

In [None]:
X_train[0].T

tensor([[ 0.4446,  0.5660, -0.8962,  0.8201,  0.8791,  0.5808, -0.3638, -0.4248,
         -0.5227,  0.8594, -0.1990,  0.2132,  0.1217, -0.4132,  0.8183, -0.2639,
          0.9488,  0.5652, -0.1195, -0.7546, -0.6997, -0.7075, -0.4938, -0.0930,
         -0.4396, -0.7521, -0.7292, -0.4287,  0.9327, -0.6806,  0.7967, -0.4497],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])

In [None]:
0.5*(0.5775 + 0.9636)

0.7705500000000001

# Model

In [31]:
# generate a random orthogonal matrix for recurrent matrix initialization 

# initialize the RNN weights
def initialize_weights():
    W1_np = np.concatenate((np.random.normal(loc=0.0, scale=0.1, size=[dim_in, dim_hidden]),
                            get_rand_orth(dim_hidden),
                            np.zeros([1, dim_hidden])), axis=0)
    W2_np = np.concatenate((np.random.normal(loc=0.0, scale=0.1, size=[dim_hidden, dim_out]),
                            np.zeros([1, dim_out])), axis=0)
    W1 = torch.tensor(W1_np, dtype=torch.float, requires_grad=True).to(device)
    W2 = torch.tensor(W2_np, dtype=torch.float, requires_grad=True).to(device)
    Ws = [W1, W2]
    return Ws

def get_rand_orth( dim ):
    temp = np.random.normal(size=[dim, dim])
    q, _ = np.linalg.qr(temp)
    return q

def model(x):
    W1, W2 = Ws
    # print(x.shape, W1.shape, W2.shape)
    ones = torch.ones(batch_size, 1).to(device)
    h = torch.zeros(batch_size, dim_hidden).to(device)
    for xt in x.transpose(1,0):
        net_in = torch.tensor(xt, dtype=torch.float)
        h = torch.tanh( torch.cat((net_in, h, ones), dim=1).mm(W1) )
        # print(h.shape, xt.shape)
        
    # h_dropout = h*torch.bernoulli(0.9+0.1*torch.rand(h.shape).to(device))  
    
    net_out = torch.cat((h, ones), dim=1).mm(W2)
    return net_out

In [32]:
torch.manual_seed(1)
np.random.seed(0)
Ws = initialize_weights()

[w.shape for w in Ws]

[torch.Size([23, 20]), torch.Size([21, 1])]

# Loss Function

In [None]:
def train_loss(data, target):
    out = model(data)
    loss = F.mse_loss(out, target)
    return loss

def test_loss():
    loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            loss += F.mse_loss(out, target)
            
    return loss.item()/n_test_batches


def save_start_condition(trainlosslist, testlosslist, timelist):
    trainloss = 0.0

    for (data, target) in tqdm(train_loader, ncols = 80):
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        trainloss += loss

    timelist.append(0)

    testloss = test_loss()

    trainlosslist.append(trainloss.item()/n_batches)
    testlosslist.append(testloss)
    print('Epoch: {}; train loss: {}; test loss: {},  time: {}'.format(0, trainlosslist[-1], testlosslist[-1], np.sum(timelist)))

# SGD

In [None]:
torch.manual_seed(1)
np.random.seed(0)
Ws = initialize_weights()
step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.numel() for W in Ws)**0.5
TrainLoss, TestLoss = [], []
times = []

with torch.no_grad():
    save_start_condition(TrainLoss, TestLoss, times)
for epoch in range(EPOCHS):
    n = 0
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for (data, target) in train_loader:
        
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        trainloss += loss
        grads = grad(loss, Ws)  
        

        with torch.no_grad():
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)

            for (W,pG) in zip(Ws, grads):
                W -= step_adjust*step_size*pG
            kbar.update(n, values=[("loss", loss.item())])    
            n = n + 1
            
    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    
    
#     step_size = 0.01**(1/9)*step_size
    testloss = test_loss()
    kbar.add(1, values=[("val_loss", testloss)])

    TestLoss.append(testloss)
#     print('Epoch: {}; train loss: {}; test loss: {}, train_acc: {}, test_acc:{}, time: {}'\
#      .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'SGD.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

100%|████████████████████████████████████████| 450/450 [00:01<00:00, 263.33it/s]


Epoch: 0; train loss: 0.17355697631835937; test loss: 0.17691728591918945,  time: 0
Epoch: 1/50
Epoch: 2/50
Epoch: 3/50
Epoch: 4/50
Epoch: 5/50
Epoch: 6/50
Epoch: 7/50
Epoch: 8/50
Epoch: 9/50
Epoch: 10/50
Epoch: 11/50
Epoch: 12/50
Epoch: 13/50
Epoch: 14/50
Epoch: 15/50
Epoch: 16/50
Epoch: 17/50
Epoch: 18/50
Epoch: 19/50
Epoch: 20/50
Epoch: 21/50
Epoch: 22/50
Epoch: 23/50
Epoch: 24/50
Epoch: 25/50
Epoch: 26/50
Epoch: 27/50
Epoch: 28/50
Epoch: 29/50
Epoch: 30/50
Epoch: 31/50
Epoch: 32/50
Epoch: 33/50
Epoch: 34/50
Epoch: 35/50
Epoch: 36/50
Epoch: 37/50
Epoch: 38/50
Epoch: 39/50
Epoch: 40/50
Epoch: 41/50
Epoch: 42/50
Epoch: 43/50
Epoch: 44/50
Epoch: 45/50
Epoch: 46/50
Epoch: 47/50
Epoch: 48/50
Epoch: 49/50
Epoch: 50/50


# Adam

In [None]:
torch.manual_seed(1)
np.random.seed(0)
Ws = initialize_weights()
step_size = 0.001
m0 = [torch.zeros(W.shape).to(device) for W in Ws]
v0 = [torch.zeros(W.shape).to(device) for W in Ws]
TrainLoss, TestLoss = [], []
times = []

with torch.no_grad():
    save_start_condition(TrainLoss, TestLoss, times)
for epoch in range(EPOCHS):
    n = 0
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    trainacc = 0.0
    t0 = time.time()
    for (data, target) in train_loader:
        
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        trainloss += loss
        grads = grad(loss, Ws, create_graph=True)    

        with torch.no_grad():
            lmbd = min(n/(n+1), 0.9)
            m0 = [lmbd*old + (1.0-lmbd)*new for (old, new) in zip(m0, grads)]
            lmbd = min(n/(n+1), 0.999)
            v0 = [lmbd*old + (1.0-lmbd)*new*new for (old, new) in zip(v0, grads)]
            for (W,m,v) in zip(Ws, m0, v0):
                W -= step_size*(m/torch.sqrt(v + 1e-8))
            kbar.update(n, values=[("loss", loss.item())])    
            n = n + 1
            
    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    
    
#     step_size = 0.01**(1/9)*step_size
    testloss = test_loss()
    kbar.add(1, values=[("val_loss", testloss)])

    TestLoss.append(testloss)
#     print('Epoch: {}; train loss: {}; test loss: {}, train_acc: {}, test_acc:{}, time: {}'\
#      .format(epoch+1, TrainLoss[-1], TestLoss[-1], TrainAcc[-1], TestAcc[-1],np.sum(times)))

scipy.io.savemat(results_dir + 'adam.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

100%|████████████████████████████████████████| 450/450 [00:01<00:00, 264.44it/s]


Epoch: 0; train loss: 0.17355697631835937; test loss: 0.17691728591918945,  time: 0
Epoch: 1/50
Epoch: 2/50
Epoch: 3/50
Epoch: 4/50
Epoch: 5/50
Epoch: 6/50
Epoch: 7/50
Epoch: 8/50
Epoch: 9/50
Epoch: 10/50
Epoch: 11/50
Epoch: 12/50
Epoch: 13/50
Epoch: 14/50
Epoch: 15/50
Epoch: 16/50
Epoch: 17/50
Epoch: 18/50
Epoch: 19/50
Epoch: 20/50
Epoch: 21/50
Epoch: 22/50
Epoch: 23/50
Epoch: 24/50
Epoch: 25/50
Epoch: 26/50
Epoch: 27/50
Epoch: 28/50
Epoch: 29/50
Epoch: 30/50
Epoch: 31/50
Epoch: 32/50
Epoch: 33/50
Epoch: 34/50
Epoch: 35/50
Epoch: 36/50
Epoch: 37/50
Epoch: 38/50
Epoch: 39/50
Epoch: 40/50
Epoch: 41/50
Epoch: 42/50
Epoch: 43/50
Epoch: 44/50
Epoch: 45/50
Epoch: 46/50
Epoch: 47/50
Epoch: 48/50
Epoch: 49/50
Epoch: 50/50


# PSGD

In [None]:
torch.manual_seed(1)
np.random.seed(0)

Ws = initialize_weights()
# Qs = [[0.1*torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []

times = []
save_start_condition(TrainLoss, TestLoss, times)

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    n = 0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
      
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)
        
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]
        kbar.update(n, values=[("loss", loss.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    
    testloss = test_loss()

    TestLoss.append(testloss)
    kbar.add(1, values=[("val_loss", testloss)])
    # step_size = 0.01**(1/9)*step_size
    

scipy.io.savemat(results_dir + 'Kron.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

100%|████████████████████████████████████████| 450/450 [00:02<00:00, 213.30it/s]


Epoch: 0; train loss: 0.17355697631835937; test loss: 0.17691728591918945,  time: 0
Epoch: 1/50
Epoch: 2/50
Epoch: 3/50
Epoch: 4/50
Epoch: 5/50
Epoch: 6/50
Epoch: 7/50
Epoch: 8/50
Epoch: 9/50
Epoch: 10/50
Epoch: 11/50
Epoch: 12/50
Epoch: 13/50
Epoch: 14/50
Epoch: 15/50
  0/450 [..............................] - ETA: 0s - loss: 0.0000e+00

KeyboardInterrupt: ignored

# DPSGD APPROACH 1

In [None]:
_tiny = 1.2e-38 
 # pi = (torch.trace(Ql)*Qr.shape[0])/(torch.trace(Qr)*Ql.shape[0])
    # 

    
def precond_grad_kron(Ql, Qr, Grad):
    P1 = Ql.t().mm(Ql)
    P2 = Qr.t().mm(Qr)
    pi = (torch.trace(P1)*P2.shape[0])/(torch.trace(P2)*P1.shape[0])
    IL = torch.ones(P1.shape[0]).to(device)
    IR = (torch.ones(P2.shape[0])).to(device)
    P1 = P1 + torch.diag(torch.sqrt((pi)*(eta + lambd))*IL)
    P2 = P2 + torch.diag(torch.sqrt((1/pi)*(eta + lambd))*IR)

    return P1.mm(Grad).mm(P2)

def update_lambda(loss1, loss2, M, lambd, omega):
    
    r = abs(loss2 - loss1)/(M)
    # print(r, M, lambd)
    if r > 3/4:
      lambd = lambd*omega
    elif r < 1/4:
      lambd = lambd / omega
    return lambd


In [None]:
torch.manual_seed(1)
np.random.seed(0)

Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
times = []
save_start_condition(TrainLoss, TestLoss, times)

lambd = 1
update_after = 5
omega = (19/20)**update_after

eta = 1e-10
for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    n = 0
    t0 = time.time()
    
    
    for batch_idx, (data, target) in enumerate(train_loader):
      
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss
        if n % 1 == 0:
          v = [torch.randn(W.shape).to(device) for W in Ws]
          Hv = grad(grads, Ws, v)
        
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)

            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

            if n % update_after == 0 and lambd > 1e-10:
                M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
                # M = 0.5*sum([torch.sum(g*pg) for (g, pg) in zip(grads, pre_grads)])
                # M = 0.5*sum([torch.sum(-step_size * pg*psgd.precond_grad_kron(q[0], q[1], -step_size * pg)) \
                #              for (g, pg, q) in zip(grads, pre_grads, Qs)])
                loss2 = F.mse_loss(model(data), target)
                loss1 = loss
                lambd = update_lambda(loss1, loss2, M,  lambd, omega)

        kbar.update(n, values=[("loss", loss.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)

    testloss = test_loss()

    TestLoss.append(testloss)
    kbar.add(1, values=[("val_loss", testloss)])
    # step_size = 0.01**(1/9)*step_size
    

scipy.io.savemat(results_dir + 'Kron_damped.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})

# DPSGD APPROACH 2

In [None]:
torch.manual_seed(1)
np.random.seed(0)

Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
dQs = [[torch.ones(W.shape[0],1).to(device), torch.ones(1,W.shape[1]).to(device)] for W in Ws]

step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
times = []
save_start_condition(TrainLoss, TestLoss, times)

lambd = 1
update_after = 5
omega = (19/20)**update_after
# 
eta = 1e-5
for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    n = 0
    t0 = time.time()
    
    for batch_idx, (data, target) in enumerate(train_loader):
      
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss
        if n % 1 == 0:
          v = [torch.randn(W.shape).to(device) for W in Ws]
          Hv = grad(grads, Ws, v)
        
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            pre_grads = [psgd.precond_grad_kron(q[0], q[1], g) for (q, g) in zip(Qs, grads)]
            damp_grads = [((lambd+eta)**0.5)*g for g in grads]
            pre_grads = [pg+dg for (pg, dg) in zip(pre_grads, damp_grads)] 
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)

            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

            if n % update_after == 0 and lambd > 1e-10:
                M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
                # M = 0.5*sum([torch.sum(g*pg) for (g, pg) in zip(grads, pre_grads)])
                # M = 0.5*sum([torch.sum(-step_size * pg*psgd.precond_grad_kron(q[0], q[1], -step_size * pg)) \
                #              for (g, pg, q) in zip(grads, pre_grads, Qs)])
                loss2 = F.mse_loss(model(data), target)
                loss1 = loss
                lambd = update_lambda(loss1, loss2, M,  lambd, omega)

        kbar.update(n, values=[("loss", loss.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)

    testloss = test_loss()

    TestLoss.append(testloss)
    kbar.add(1, values=[("val_loss", testloss)])
    # step_size = 0.01**(1/9)*step_size
    

scipy.io.savemat(results_dir + 'Kron_damped2.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

100%|████████████████████████████████████████| 450/450 [00:02<00:00, 219.23it/s]


Epoch: 0; train loss: 0.17355697631835937; test loss: 0.17691728591918945,  time: 0
Epoch: 1/50
Epoch: 2/50
Epoch: 3/50
Epoch: 4/50
Epoch: 5/50
Epoch: 6/50
Epoch: 7/50
Epoch: 8/50
Epoch: 9/50
Epoch: 10/50
Epoch: 11/50
Epoch: 12/50
Epoch: 13/50
Epoch: 14/50
Epoch: 15/50
Epoch: 16/50
Epoch: 17/50
Epoch: 18/50
Epoch: 19/50
Epoch: 20/50
Epoch: 21/50
Epoch: 22/50
Epoch: 23/50
Epoch: 24/50
Epoch: 25/50
Epoch: 26/50
Epoch: 27/50
Epoch: 28/50
Epoch: 29/50
Epoch: 30/50
Epoch: 31/50
Epoch: 32/50
Epoch: 33/50
Epoch: 34/50
Epoch: 35/50
Epoch: 36/50
Epoch: 37/50
Epoch: 38/50
Epoch: 39/50
Epoch: 40/50
Epoch: 41/50
Epoch: 42/50
Epoch: 43/50
Epoch: 44/50
Epoch: 45/50
Epoch: 46/50
Epoch: 47/50
Epoch: 48/50
Epoch: 49/50
Epoch: 50/50


# DPSGD-M

In [None]:
def precond_kron(Ql, Qr, Pl, Pr, beta):
    P1 = Ql.t().mm(Ql)
    P2 = Qr.t().mm(Qr)
    pi = (torch.trace(P1)*P2.shape[0])/(torch.trace(P2)*P1.shape[0])
    IL = torch.ones(P1.shape[0]).to(device)
    IR = (torch.ones(P2.shape[0])).to(device)
    P1 = P1 + torch.diag(torch.sqrt((pi)*(eta + lambd))*IL)
    P2 = P2 + torch.diag(torch.sqrt((1/pi)*(eta + lambd))*IR)

    Pl = beta*Pl + (1-beta)*P1 
    Pr = beta*Pr + (1-beta)*P2 

    return [P1, P2, Pl, Pr]

def precond_kron2(Ql, Qr, Pl, Pr, beta):
    P1 = Ql.t().mm(Ql)
    P2 = Qr.t().mm(Qr)
    Pl = beta*Pl + (1-beta)*P1 
    Pr = beta*Pr + (1-beta)*P2 
    return [P1, P2, Pl, Pr]

def precond_grad_kron2(Pl, Pr, Grad):
    return Pl.mm(Grad).mm(Pr)

torch.manual_seed(1)
np.random.seed(0)

Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
Ps = [[torch.zeros(W.shape[0]).to(device), torch.zeros(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
times = []
save_start_condition(TrainLoss, TestLoss, times)

lambd = 1
update_after = 5
omega = (19/20)**update_after
eta = 1e-10
beta = 0.7

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    n = 0
    trainloss = 0.0
    t0 = time.time()
   
    for batch_idx, (data, target) in enumerate(train_loader):
      
        data, target = data.to(device), target.to(device)
        loss= train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss

        v = [torch.randn(W.shape).to(device) for W in Ws]
        Hv = grad(grads, Ws, v)
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            beta = min(n/(n+1), 0.7)
            Ps = [precond_kron(q[0], q[1], p[0], p[1], beta) for (q, p) in zip(Qs, Ps)]
            pre_grads = [precond_grad_kron2(p[2], p[3], g) for (p, g) in zip(Ps, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]
            if n % update_after == 0 and lambd > 1e-10:
              
                M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
                loss2 = F.mse_loss(model(data), target)
                loss1 = loss
                lambd = update_lambda(loss1, loss2, M, lambd, omega)
        kbar.update(n, values=[("loss", loss.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)

    testloss = test_loss()

    TestLoss.append(testloss)
    kbar.add(1, values=[("val_loss", testloss)])
    # step_size = 0.01**(1/9)*step_size
    

scipy.io.savemat(results_dir + 'mod_psgd1.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})

# dpsgd-m 2

In [None]:

torch.manual_seed(1)
np.random.seed(0)

Ws = initialize_weights()
Qs = [[torch.eye(W.shape[0]).to(device), torch.eye(W.shape[1]).to(device)] for W in Ws]
dQs = [[torch.ones(W.shape[0],1).to(device), torch.ones(1,W.shape[1]).to(device)] for W in Ws]

step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
times = []
save_start_condition(TrainLoss, TestLoss, times)

lambd = 1
update_after = 5
omega = (19/20)**update_after
beta = 0.7
eta = 1e-10
for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    n = 0
    t0 = time.time()
    
    for batch_idx, (data, target) in enumerate(train_loader):
      
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss
        if n % 1 == 0:
          v = [torch.randn(W.shape).to(device) for W in Ws]
          Hv = grad(grads, Ws, v)
        
        with torch.no_grad():
            Qs = [psgd.update_precond_kron(q[0], q[1], dw, dg) for (q, dw, dg) in zip(Qs, v, Hv)]
            beta = min(n/(n+1), 0.7)
            Ps = [precond_kron2(q[0], q[1], p[0], p[1], beta) for (q, p) in zip(Qs, Ps)]
            pre_grads = [precond_grad_kron2(p[2], p[3], g) for (p, g) in zip(Ps, grads)]
            
            damp_grads = [((lambd+eta)**0.5)*g for g in grads]
            pre_grads = [pg+dg for (pg, dg) in zip(pre_grads, damp_grads)] 
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)

            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

            if n % update_after == 0 and lambd > 1e-10:
                M = min([0.5*torch.dot(g.view(-1,), step_size*pg.view(-1,)) for (g, pg) in zip(grads, pre_grads)])
                # M = 0.5*sum([torch.sum(g*pg) for (g, pg) in zip(grads, pre_grads)])
                # M = 0.5*sum([torch.sum(-step_size * pg*psgd.precond_grad_kron(q[0], q[1], -step_size * pg)) \
                #              for (g, pg, q) in zip(grads, pre_grads, Qs)])
                loss2 = F.mse_loss(model(data), target)
                loss1 = loss
                lambd = update_lambda(loss1, loss2, M,  lambd, omega)

        kbar.update(n, values=[("loss", loss.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)

    testloss = test_loss()

    TestLoss.append(testloss)
    kbar.add(1, values=[("val_loss", testloss)])
    # step_size = 0.01**(1/9)*step_size
    

scipy.io.savemat(results_dir + 'mod_psgd2.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

100%|████████████████████████████████████████| 450/450 [00:02<00:00, 216.34it/s]


Epoch: 0; train loss: 0.17355697631835937; test loss: 0.17691728591918945,  time: 0
Epoch: 1/50
Epoch: 2/50
Epoch: 3/50
Epoch: 4/50
Epoch: 5/50
Epoch: 6/50
Epoch: 7/50
Epoch: 8/50
Epoch: 9/50
Epoch: 10/50
Epoch: 11/50
Epoch: 12/50
Epoch: 13/50
Epoch: 14/50
Epoch: 15/50
Epoch: 16/50
Epoch: 17/50
Epoch: 18/50
Epoch: 19/50
Epoch: 20/50
Epoch: 21/50
Epoch: 22/50
Epoch: 23/50
Epoch: 24/50
Epoch: 25/50
Epoch: 26/50
Epoch: 27/50
Epoch: 28/50
Epoch: 29/50
Epoch: 30/50
Epoch: 31/50
Epoch: 32/50
Epoch: 33/50
Epoch: 34/50
Epoch: 35/50
Epoch: 36/50
Epoch: 37/50
Epoch: 38/50
Epoch: 39/50
Epoch: 40/50
Epoch: 41/50
Epoch: 42/50
Epoch: 43/50
Epoch: 44/50
Epoch: 45/50
Epoch: 46/50
Epoch: 47/50
Epoch: 48/50
Epoch: 49/50
Epoch: 50/50


# KFAC

In [23]:
from kfac import KFAC
kfac.module.

In [35]:
from kfac import KFAC
import torch.nn as nn
import torch.optim as optim

class myModel_K(nn.Module):
    def __init__(self):
        super(myModel_K, self).__init__()
        self.lstm = nn.RNN(input_size = 20, hidden_size = 20, )
        self.dense = nn.Linear(20, 1)
        

    def forward(self, x):
        x = self.lstm(x)
        x = self.dense(x)
        return x

def test_loss_K(model):
    model.eval()
    loss = 0
    accuracy = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            # data = data.view(-1, 28*28*1)
            y = model(data)
            loss += F.mse_loss(y, data)

    return loss.item()/n_test_batches

In [36]:
[print(p.size()) for p in myModel_K().parameters()]

torch.Size([20, 20])
torch.Size([20, 20])
torch.Size([20])
torch.Size([20])
torch.Size([1, 20])
torch.Size([1])


[None, None, None, None, None, None]

In [None]:
torch.manual_seed(1)  
Ws = initialize_weights()
model = myModel_K()
model.to(device)
preconditioner = KFAC(model, 0.001, lr = 0.1, accumulate_data=True)
lr0 = 0.5
optimizer = optim.SGD(model.parameters(), lr=lr0)
TrainLoss, TestLoss = [], []

times = []
with torch.no_grad():
  save_start_condition(TrainLoss, TestLoss, times)

for epoch in range(EPOCHS):
    trainloss = 0.0
    n = 0
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    model.train()
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        
        loss = F.mse_loss(output, data)

        trainloss += loss
        
        loss.backward()
        preconditioner.step()
        optimizer.step()

        kbar.update(n, values=[("loss", loss.item())])
        n += 1
    t1 = time.time() - t0
    times.append(t1)

    TrainLoss.append(trainloss.item()/n_batches)
    
    # lr0 = 0.01**(1/9)*lr0
    optimizer.param_groups[0]['lr'] = lr0
    testloss = test_loss_K(model)
    kbar.add(1, values=[("val_loss", testloss)])

    TestLoss.append(testloss)
    
scipy.io.savemat(results_dir + 'KFAC.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss,  'Time':times})

# Shampoo

In [None]:
torch.manual_seed(1)
np.random.seed(0)

Ws = initialize_weights()
Qs = [[0.01*torch.eye(W.shape[0]).to(device), 0.01*torch.eye(W.shape[1]).to(device)] for W in Ws]
step_size = 0.1
grad_norm_clip_thr = 0.1*sum(W.shape[0]*W.shape[1] for W in Ws)**0.5
TrainLoss, TestLoss = [], []
times = []
save_start_condition(TrainLoss, TestLoss, times)

def matrix_power(matrix, power):
    # use CPU for svd for speed up
    matrix = matrix.cpu()
    u, s, v = torch.svd(matrix)
    return (u @ s.pow_(power).diag() @ v.t()).cuda()

for epoch in range(EPOCHS):
    kbar = pkbar.Kbar(target=n_batches, epoch=epoch, num_epochs=EPOCHS, width=30, always_stateful=False, interval = 1)
    trainloss = 0.0
    n = 0
    t0 = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        
      
        data, target = data.to(device), target.to(device)
        loss = train_loss(data, target)
        
        grads = grad(loss, Ws, create_graph=True)
        
        trainloss += loss
        
        with torch.no_grad():
            Qs = [[q[0] + g.mm(g.t()), q[1] + (g.t()).mm(g)] for (q, g) in zip(Qs, grads)]
            inv_Qs = [[matrix_power(q[0], -1/4), matrix_power(q[1], -1/4)]for q in Qs]
            pre_grads = [q[0].mm(g).mm(q[1]) for (q, g) in zip(inv_Qs, grads)]
            grad_norm = torch.sqrt(sum([torch.sum(g*g) for g in pre_grads]))
            step_adjust = min(grad_norm_clip_thr/(grad_norm + 1.2e-38), 1.0)
            for i in range(len(Ws)):
                Ws[i] -= step_adjust*step_size*pre_grads[i]

        kbar.update(n, values=[("loss", loss.item())])
        n += 1

    t1 = time.time() - t0
    times.append(t1)
    TrainLoss.append(trainloss.item()/n_batches)
    
    testloss = test_loss()

    TestLoss.append(testloss)
    kbar.add(1, values=[("val_loss", testloss)])
    # step_size = 0.1**(1/9)*step_size
   

scipy.io.savemat(results_dir + 'shampoo.mat', {'TrainLoss': TrainLoss, 'TestLoss': TestLoss, 'Time':times})


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

100%|████████████████████████████████████████| 450/450 [00:02<00:00, 221.92it/s]


Epoch: 0; train loss: 0.17355697631835937; test loss: 0.17691728591918945,  time: 0
Epoch: 1/50
Epoch: 2/50
Epoch: 3/50
Epoch: 4/50
Epoch: 5/50
Epoch: 6/50
Epoch: 7/50
Epoch: 8/50
Epoch: 9/50
Epoch: 10/50
Epoch: 11/50
Epoch: 12/50
Epoch: 13/50
Epoch: 14/50
Epoch: 15/50
Epoch: 16/50
Epoch: 17/50
Epoch: 18/50
Epoch: 19/50
Epoch: 20/50
Epoch: 21/50
Epoch: 22/50
Epoch: 23/50
Epoch: 24/50
Epoch: 25/50
Epoch: 26/50
Epoch: 27/50
Epoch: 28/50
Epoch: 29/50
Epoch: 30/50
Epoch: 31/50
Epoch: 32/50
Epoch: 33/50
Epoch: 34/50
Epoch: 35/50
Epoch: 36/50
Epoch: 37/50
Epoch: 38/50
Epoch: 39/50
Epoch: 40/50
Epoch: 41/50
Epoch: 42/50
Epoch: 43/50
Epoch: 44/50
Epoch: 45/50
Epoch: 46/50
Epoch: 47/50
Epoch: 48/50
Epoch: 49/50
Epoch: 50/50


# Comparison

In [None]:
# opts = ['adam','Kron','SCAN','SCAW', 'fisher_kron','fisher_SCAN','fisher_SCAW','Kron_test','Kron_test2','Kron_damped','fisher_kron_damped','KFAC', 'shampoo']
opts = ['SGD','adam','Kron','Kron_damped2','mod_psgd2', 'shampoo']

total_train_time = {}
opts_data = {}
times = {}
train_times = {}
test_times = {}
train_losses = {}
test_losses = {}
train_accs = {}
test_accs = {}


for opt in opts:
	opts_data[opt] = scipy.io.loadmat(results_dir+opt+'.mat')	

In [None]:
colors = ['#0000FF','#00FF00','#FF0000','#33F0FF','#FFA833','#FFF933','#000000','#33E0FF', '#FF33E6','#D433FF','#888A0B','#8A0B1E','#B498DF','#1B786D']
# colors = ['#0000FF','#00FF00','#FF0000','#33F0FF','#FFA833','#FFF933','#000000','#33E0FF','#FF33E6','#D433FF','#888A0B','#8A0B1E','#B498DF','#1B786D']

In [None]:
for opt in opts:
    # print(opt)
    data = opts_data[opt]
    times[opt] = data.get('Time')
    train_times[opt] = np.cumsum(times[opt])
    test_times[opt] = np.cumsum(times[opt])
    total_train_time[opt] = np.sum(times[opt])
    train_losses[opt] = data.get('TrainLoss').reshape(EPOCHS+1,)
    test_losses[opt] = data.get('TestLoss').reshape(EPOCHS+1,)


In [None]:
# plot train_losses vs Iterations
plot_loss_metrics(None,train_losses,'Train Loss vs EPOCHS', 'EPOCHS','Train Loss')
# plot test_losses vs Iterations
plot_loss_metrics(None,test_losses,'Test Loss vs EPOCHS', 'EPOCHS','Test Loss')
# # plot test_losses vs Iterations
plot_loss_metrics(train_times,train_losses,'Train Loss vs Time', 'Time','Train Loss')
# plot test_losses vs Iterations
plot_loss_metrics(test_times,test_losses,'Test Loss vs Time', 'Time','Test Loss')