In [0]:
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import torch.nn as nn
import time


# ---------------------------------- Hessian-vector product operator----------------------------------

class Operator():
    """Base class for an operator of a fixed size that acts on vectors.

    Subclasses supply the action by overriding :meth:`apply`; the ``@``
    operator is routed to that method.
    """

    def __init__(self, size):
        """Remember ``size`` so that :meth:`size` can report it later."""
        self._size = size

    def size(self):
        """Return the size passed to the constructor."""
        return self._size

    def apply(self, vec):
        """Apply the operator to ``vec``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def __matmul__(self, vec):
        """Support ``operator @ vec`` by forwarding to :meth:`apply`."""
        result = self.apply(vec)
        return result


class ModelHessianOperator(Operator):
    """Hessian-vector product operator for a model's loss on one data batch.

    The flattened gradient of the loss with respect to the model parameters
    is computed once (with its graph kept) when the data is set; every
    :meth:`apply` call then differentiates that gradient again, weighted by
    the supplied vector.
    """

    def __init__(self, model, criterion, data_input, data_target):
        """Store the model/criterion and build the gradient graph for the batch.

        The operator size is the total number of elements over all
        ``model.parameters()``.
        """
        n_params = int(sum(p.numel() for p in model.parameters()))
        super().__init__(n_params)
        self._model = model
        self._criterion = criterion
        self.set_model_data(data_input, data_target)

    def get_input(self):
        """Return the batch input most recently given to ``set_model_data``."""
        return self._data_input

    def get_target(self):
        """Return the batch target most recently given to ``set_model_data``."""
        return self._data_target

    def set_model_data(self, data_input, data_target):
        """Run the model on the batch and cache output, loss and flat gradient.

        ``create_graph=True`` keeps the gradient differentiable so that
        :meth:`apply` can take a second derivative through it.
        """
        self._data_input = data_input
        self._data_target = data_target
        self._output = self._model(data_input)
        self._loss = self._criterion(self._output, data_target)
        first_order = torch.autograd.grad(
            self._loss, self._model.parameters(), create_graph=True
        )
        self._grad = to_vector(first_order)

    def apply(self, vec):
        """Return the flattened derivative of ``<grad, vec>`` w.r.t. the parameters.

        ``vec`` is used as ``grad_outputs`` for the cached flat gradient.
        ``retain_graph=True`` keeps the graph alive so the method can be
        called repeatedly.
        """
        second_order = torch.autograd.grad(
            self._grad,
            self._model.parameters(),
            grad_outputs=vec,
            only_inputs=True,
            retain_graph=True,
        )
        return to_vector(second_order)


def to_vector(tensors):
    """Flatten every tensor in ``tensors`` and concatenate them into one 1-D tensor.

    ``tensors`` may be any iterable of tensors (it is consumed once); each
    element is made contiguous before being viewed as a flat vector.
    """
    flat_parts = []
    for tensor in tensors:
        flat_parts.append(tensor.contiguous().view(-1))
    return torch.cat(flat_parts)


# ----------------------------------- slq_upd.py -----------------------------------

import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator as ScipyLinearOperator
# from scipy.sparse.linalg import eigsh
# from warnings import warn
import scipy.sparse as sps


def _lanczos_m_upd(A, m, matrix_shape, nv=1, rademacher=False, SV=None):
    """Run up to ``m`` Lanczos steps with reorthogonalization for ``nv`` start vectors.

    Args:
        A: object exposing ``dot(x)`` for a 1-D vector ``x`` (for example a
            ``scipy.sparse.linalg.LinearOperator``, a sparse matrix or a
            dense ``ndarray``). It is always called with one contiguous 1-D
            column at a time, so a user callback never sees a 2-D array.
        m: number of Lanczos vectors to build; must be at least 1.
        matrix_shape: shape of the operator; only ``matrix_shape[0]`` is
            read, to size the random start vectors.
        nv: number of independent start vectors (columns).
        rademacher: when drawing random start vectors, use the sign of a
            normal draw (entries +-1) instead of the normal draw itself.
        SV: optional start vectors, shape ``(n, nv)`` (a 1-D array is
            treated as a single column). The caller's array is copied, not
            modified. When ``None``, random start vectors are drawn.

    Returns:
        ``(T, V)`` where ``T`` has shape ``(nv, m, m)`` and holds one
        tridiagonal matrix per start vector, and ``V`` has shape
        ``(n, m, nv)`` and holds the Lanczos vectors. If the iteration stops
        early the remaining entries stay zero.

    Raises:
        ValueError: if ``m < 1`` or ``SV`` does not provide ``nv`` columns.

    Note:
        A start vector lying in an invariant subspace makes an off-diagonal
        ``beta`` vanish; the subsequent normalization then divides by zero,
        exactly as in the original implementation.
    """
    orthtol = 1e-2  # largest |<v_j, w>| accepted after reorthogonalization
    breakdown_tol = 1e-2  # stop once every column's beta is at most this

    if m < 1:
        raise ValueError("m must be at least 1, got {}".format(m))

    if SV is None:
        SV = np.random.randn(matrix_shape[0], nv)
        if rademacher:
            SV = np.sign(SV)
    else:
        # Work on a float copy: the in-place normalization below must neither
        # modify the caller's array nor fail on integer input.
        SV = np.array(SV, dtype=np.float64)
        if SV.ndim == 1:
            SV = SV[:, None]
        if SV.ndim != 2 or SV.shape[1] != nv:
            raise ValueError(
                "SV must have shape (n, {}), got {}".format(nv, SV.shape)
            )

    def apply_columns(X):
        # Apply A to each column separately; results are promoted to float64
        # so the orthogonalization arithmetic does not inherit a lower
        # precision from the operator's output.
        columns = [
            np.asarray(
                A.dot(np.ascontiguousarray(X[:, k])), dtype=np.float64
            ).reshape(-1)
            for k in range(X.shape[1])
        ]
        return np.stack(columns, axis=1)

    V = np.zeros((SV.shape[0], m, nv))
    T = np.zeros((nv, m, m))

    SV /= np.linalg.norm(SV, axis=0)  # normalize each column
    V[:, 0, :] = SV

    w = apply_columns(SV)
    alpha = np.einsum('ij,ij->j', w, SV)
    T[:, 0, 0] = alpha
    if m == 1:
        # A single Lanczos vector has no off-diagonal entries to fill.
        return T, V

    w -= alpha[None, :] * SV
    beta = np.sqrt(np.einsum('ij,ij->j', w, w))

    T[:, 0, 1] = beta
    T[:, 1, 0] = beta

    np.divide(w, beta[None, :], out=w)
    V[:, 1, :] = w
    t = np.zeros((m, nv))

    for i in range(1, m):
        v_prev = V[:, i - 1, :]
        v_cur = V[:, i, :]

        w = apply_columns(v_cur)
        w -= beta[None, :] * v_prev  # n x nv
        np.einsum('ij,ij->j', w, v_cur, out=alpha)

        T[:, i, i] = alpha

        if i < m - 1:
            w -= alpha[None, :] * v_cur  # n x nv
            # Reorthogonalize against every vector built so far (columns of
            # V that are not filled yet are zero and contribute nothing).
            np.einsum('ijk,ik->jk', V, w, out=t)
            w -= np.einsum('ijk,jk->ik', V, t)
            np.einsum('ij,ij->j', w, w, out=beta)
            np.sqrt(beta, out=beta)
            np.divide(w, beta[None, :], out=w)

            T[:, i, i + 1] = beta
            T[:, i + 1, i] = beta

            # Repeat the reorthogonalization until *every* overlap is small
            # in magnitude. (Testing for "any entry <= tol" would always
            # succeed because of the zero rows and of negative overlaps.)
            innerprod = np.einsum('ijk,ik->jk', V, w)
            reortho = False
            for _ in range(100):
                if (np.abs(innerprod) <= orthtol).all():
                    reortho = True
                    break
                np.einsum('ijk,ik->jk', V, w, out=t)
                w -= np.einsum('ijk,jk->ik', V, t)
                np.divide(w, np.linalg.norm(w, axis=0)[None, :], out=w)
                innerprod = np.einsum('ijk,ik->jk', V, w)

            V[:, i + 1, :] = w

            # Stop when all columns have broken down or orthogonality could
            # not be restored.
            if (np.abs(beta) > breakdown_tol).sum() == 0 or not reortho:
                break
    return T, V



# ------------------------------ lanczos_upd.py ------------------------------

def lanczos(
    operator,
    size,
    num_lanczos_vectors,
    use_gpu=False,
):
    """Run the Lanczos iteration on a torch-backed operator.

    Args:
        operator: object whose ``apply(tensor)`` maps a float32 torch vector
            to a torch vector of the same length.
        size: dimension of the operator; the wrapped SciPy operator has
            shape ``(size, size)``.
        num_lanczos_vectors: number of Lanczos vectors to build.
        use_gpu: move each input vector to CUDA before calling
            ``operator.apply``.

    Returns:
        The ``(T, V)`` pair produced by ``_lanczos_m_upd``.
    """
    shape = (size, size)

    def _scipy_apply(x):
        # Down-cast on the host first so only float32 data is sent to the GPU.
        x = torch.from_numpy(x).float()
        if use_gpu:
            x = x.cuda()
        # detach() keeps the NumPy conversion valid even if the operator
        # returns a tensor that is still attached to an autograd graph.
        return operator.apply(x).detach().cpu().numpy()

    # Declaring the dtype up front stops SciPy from inferring it with an
    # extra probing mat-vec when the operator is constructed.
    scipy_op = ScipyLinearOperator(shape, _scipy_apply, dtype=np.float32)
    T, V = _lanczos_m_upd(A=scipy_op, m=num_lanczos_vectors, matrix_shape=shape, SV=None)
    return T, V



# ----------------------------------------- Training parameters -----------------------------------------

# Mini-batch size shared by the training and the test loader.
batch_size = 200

# MNIST is fetched into ./MNIST (download=True) and images are converted to tensors.
train_dataset = MNIST(root='MNIST', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = MNIST(root='MNIST', train=False, transform=transforms.ToTensor(), download=True)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Target number of optimizer steps; `epochs` is the (float) number of passes over
# the training set that this corresponds to. The training loops below truncate it
# with int(epochs).
n_iters = 500
epochs = n_iters / (len(train_dataset) / batch_size)
# Dimensions used by the (commented-out) logistic-regression model below.
input_dim = 784
output_dim = 10
# Initial SGD learning rate; NOTE(review): the ConvNN benchmark loop below
# reassigns lr_rate.
lr_rate = 0.001


# ----------------------------------------- Logistic Regression -----------------------------------------

# class LogisticRegression(torch.nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super(LogisticRegression, self).__init__()
#         self.linear = torch.nn.Linear(input_dim, output_dim)
#
#     def forward(self, x):
#         outputs = self.linear(x)
#         return outputs
#
#
# model = LogisticRegression(input_dim, output_dim)
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)
#
# iter_num = 0
# for epoch in range(int(epochs)):
#     for i, (images, labels) in enumerate(train_loader):
#         images = Variable(images.view(-1, 28 * 28))
#         labels = Variable(labels)
#
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#
#         iter_num+=1
#         if iter_num % 500==0:
#             # calculate Accuracy
#             correct = 0
#             total = 0
#             for images, labels in test_loader:
#                 images = Variable(images.view(-1, 28*28))
#                 outputs = model(images)
#                 _, predicted = torch.max(outputs.data, 1)
#                 total+= labels.size(0)
#                 # for gpu, bring the predicted and labels back to cpu for python operations to work
#                 correct+= (predicted == labels).sum()
#             accuracy = 100 * correct/total
#             print("Iteration: {}. Loss: {}. Accuracy: {}%.".format(iter_num, loss.item(), accuracy))
#
# test_batch = next(iter(test_loader))
# test_batch[0] = test_batch[0].view(-1, 28 * 28)
#
# data_input = test_batch[0]
# data_target = test_batch[1]
# op = ModelHessianOperator(model, criterion, data_input, data_target)
# size = to_vector(model.parameters()).shape[0]



# ----------------------------------------- ConvNNs -----------------------------------------

# model = nn.Sequential(
#     nn.Conv2d(1, 8, kernel_size=3),
#     nn.ReLU(),
#     nn.Conv2d(8, 8, kernel_size=3),
#     nn.ReLU(),

#     nn.MaxPool2d(2),

#     nn.Conv2d(8, 16, kernel_size=3),
#     nn.ReLU(),
#     nn.Conv2d(16, 16, kernel_size=3),
#     nn.ReLU(),

#     nn.MaxPool2d(2),
#     nn.Flatten(),
#     nn.Linear(256, 10),
# )


def create_model(a, b, c):
    """Build the ConvNN classifier with hidden fully-connected widths a, b, c.

    Layout: two 3x3 conv+ReLU pairs (1->8->8 channels), 2x2 max-pooling, two
    more 3x3 conv+ReLU pairs (8->16->16 channels), 2x2 max-pooling, flatten,
    then Linear+ReLU+Dropout blocks 256->a->b->c and a final Linear c->10.
    The 256 inputs of the first Linear layer correspond to 16 channels of
    4x4 feature maps, which is what 28x28 inputs produce.
    """
    def conv_relu(channels_in, channels_out):
        # An unpadded 3x3 convolution shrinks each spatial side by 2.
        return [nn.Conv2d(channels_in, channels_out, kernel_size=3), nn.ReLU()]

    layers = []
    layers += conv_relu(1, 8)      # 28 -> 26
    layers += conv_relu(8, 8)      # 26 -> 24
    layers.append(nn.MaxPool2d(2))  # 24 -> 12
    layers += conv_relu(8, 16)     # 12 -> 10
    layers += conv_relu(16, 16)    # 10 -> 8
    layers.append(nn.MaxPool2d(2))  # 8 -> 4
    layers.append(nn.Flatten())     # 16 * 4 * 4 = 256 features

    widths = [256, a, b, c]
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout()]
    layers.append(nn.Linear(c, 10))

    return nn.Sequential(*layers)


# Benchmark: for a family of ConvNNs of growing width, train briefly on MNIST and
# time the Lanczos iteration on the Hessian of the loss for one test batch.

# Hidden widths of the three fully-connected layers of each model.
a_lst = np.arange(25, 100, 5)
b_lst = a_lst - 5
c_lst = b_lst - 5

time_spent_lst = []   # Lanczos wall time per model, in seconds
params_num_lst = []   # parameter count per model

# Learning rate for the ConvNNs. It is set *before* the optimizers are created so
# that every model is trained with the same rate (previously the first model was
# still trained with the earlier value because the assignment came too late).
lr_rate = 0.00007

for model_idx, (a, b, c) in enumerate(zip(a_lst, b_lst, c_lst)):
    print('------------ {}/{} model -------------'. format(model_idx+1, len(a_lst)))
    model = create_model(a, b, c)
    print('a = {}, b = {}, c = {}'.format(a,b,c))
    params_num = sum(p.numel() for p in model.parameters())
    print('parameters number: {}'.format(params_num))
    params_num_lst.append(params_num)

    # model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)

    iter_num = 0
    for epoch in range(int(epochs)):
        # A separate index name keeps the outer model counter intact.
        for batch_idx, (images, labels) in enumerate(train_loader):
            # images = images.cuda()
            # labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            iter_num += 1
            if iter_num % 500 == 0:
                # Accuracy on the test set: Dropout is switched off and no
                # autograd graph is recorded while evaluating.
                model.eval()
                correct = 0
                total = 0
                with torch.no_grad():
                    for test_images, test_labels in test_loader:
                        # test_images = test_images.cuda()
                        # test_labels = test_labels.cuda()
                        test_outputs = model(test_images)
                        predicted = test_outputs.argmax(dim=1)
                        total += test_labels.size(0)
                        correct += (predicted == test_labels).sum().item()
                model.train()
                accuracy = 100 * correct / total
                print("Iteration: {}. Loss: {}. Accuracy: {}%.".format(iter_num, loss.item(), accuracy))

    # Build the Hessian in eval mode so that it belongs to the deterministic
    # network rather than to one random Dropout mask.
    model.eval()
    test_batch = next(iter(test_loader))
    data_input = test_batch[0]
    data_target = test_batch[1]
    # op = ModelHessianOperator(model, criterion, data_input.cuda(), data_target.cuda())
    op = ModelHessianOperator(model, criterion, data_input, data_target)
    size = op.size()  # total number of model parameters
    print ('The model has been trained')

    num_lanczos_vectors = int(0.5 * size)
    print('Starting Lanczoc method to find {} vectors'.format(num_lanczos_vectors))
    # perf_counter is monotonic and meant for measuring elapsed time.
    start = time.perf_counter()
    T, V = lanczos(operator=op, num_lanczos_vectors=num_lanczos_vectors, size=size, use_gpu=False)
    end = time.perf_counter()
    print(f'Time spent: {end - start}')
    time_spent_lst.append(end - start)
    # print(T)
    # print(V)



# ----------------------------------------- Results downloading -----------------------------------------

# Dump the collected parameter counts, one value per line.
with open('params_num.txt', 'w') as params_file:
    params_file.writelines("%s\n" % count for count in params_num_lst)

# Dump the collected Lanczos timings, one value per line.
with open('time_spent.txt', 'w') as timing_file:
    timing_file.writelines("%s\n" % seconds for seconds in time_spent_lst)

# Hand both files to the browser through the Colab helper module.
from google.colab import files
for result_name in ('params_num.txt', 'time_spent.txt'):
    files.download(result_name)

------------ 1/15 model -------------
a = 25, b = 20, c = 15
parameters number: 11572
The model has been trained
Starting Lanczoc method to find 5786 vectors
Time spent: 91.53587698936462
------------ 2/15 model -------------
a = 30, b = 25, c = 20
parameters number: 13367
The model has been trained
Starting Lanczoc method to find 6683 vectors
Time spent: 127.41355729103088
------------ 3/15 model -------------
a = 35, b = 30, c = 25
parameters number: 15262
The model has been trained
Starting Lanczoc method to find 7631 vectors
Time spent: 181.0299243927002
------------ 4/15 model -------------
a = 40, b = 35, c = 30
parameters number: 17257
The model has been trained
Starting Lanczoc method to find 8628 vectors
Time spent: 268.0902545452118
------------ 5/15 model -------------
a = 45, b = 40, c = 35
parameters number: 19352
The model has been trained
Starting Lanczoc method to find 9676 vectors
Time spent: 264.46770763397217
------------ 6/15 model -------------
a = 50, b = 45, c = 

KeyboardInterrupt: ignored