In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer
import timeit

import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable

In [None]:
def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Adapter that lets an arbitrary callable sit inside an ``nn.Module`` graph.

    Args:
        lambd: callable applied to the input tensor in ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        # stored as a plain attribute (not a submodule); it has no parameters
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution.
        option: shortcut used when ``stride != 1`` or the channel count changes:
            'A' - parameter-free: take every second row/column and zero-pad
                  ``planes // 4`` channels on each side, which matches the
                  residual branch only when ``in_planes + 2 * (planes // 4)
                  == planes`` and the stride is 2;
            'B' - 1x1 convolution + BatchNorm projection.
            When no projection is needed the shortcut is the identity and
            ``option`` is ignored.

    Raises:
        ValueError: if a projection is needed and ``option`` is neither 'A' nor 'B'
            (previously this silently fell back to an identity shortcut and
            failed later with a shape mismatch in ``forward``).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A.
                # F.pad's 6-tuple pads (W_left, W_right, H_top, H_bottom,
                # C_front, C_back): only the channel dimension is padded.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem conv, three residual stages, global average pool.

    Args:
        block: residual block class, called as ``block(in_planes, planes, stride)``
            and exposing a class attribute ``expansion``.
        num_blocks: number of blocks in each of the three stages; the stages use
            16, 32 and 64 planes, the last two starting with stride 2.
        num_classes: output size of the final linear layer.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        # running input-channel count, advanced by _make_layer as stages are built
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # layer3 emits 64 * block.expansion channels (64 for BasicBlock)
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool over both spatial dims. For square feature maps this
        # equals the former avg_pool2d(out, out.size()[3]); unlike that call it
        # also averages the full map when height != width.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20():
    """Build the CIFAR ResNet with BasicBlocks and three blocks in each of its three stages."""
    blocks_per_stage = [3] * 3
    return ResNet(BasicBlock, blocks_per_stage)

In [None]:
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Compute device: keep the original preference for the second GPU ("cuda:1"),
# but fall back to the first GPU on single-GPU machines (where "cuda:1" raises
# an invalid-device-ordinal error) and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed so the shared initial weights (and later data shuffling) are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = resnet20()
model.to(device)
# Every optimizer experiment below reloads this file, so all runs start from
# identical initial weights.
torch.save(model.state_dict(), "cifar10_resnet.pth")
criterion = nn.CrossEntropyLoss()

In [None]:
# CIFAR-10 loaders: augmented, shuffled training data (batch 128) and plain,
# ordered test data (batch 64).
# NOTE(review): these mean/std values match the widely used ImageNet channel
# statistics rather than CIFAR-10's own; confirm this is intended.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

train_loader = torch.utils.data.DataLoader(
    dataset=datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]),
    ),
    shuffle=True,
    batch_size=128,
    pin_memory=True,
    num_workers=4,
)

val_loader = torch.utils.data.DataLoader(
    dataset=datasets.CIFAR10(
        root='./data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
    ),
    shuffle=False,
    batch_size=64,
    pin_memory=True,
    num_workers=4,
)

In [None]:
def multi_acc(y_pred, y_test):
    """Percentage of rows whose arg-max prediction equals the label.

    Args:
        y_pred: scores of shape (batch, num_classes); the arg-max is taken over dim 1.
        y_test: integer class labels of shape (batch,).

    Returns:
        0-d float tensor: accuracy in percent, rounded to the nearest integer.
    """
    # argmax of the raw scores equals argmax of their log_softmax (which only
    # subtracts a per-row constant), so no softmax pass is needed
    y_pred_tags = torch.argmax(y_pred, dim=1)

    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)

    return torch.round(acc * 100)

# Learning rates swept for every optimizer below.
learning_rate_list = [1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 8e-2, 1e-1]

In [None]:
# Adam baseline: for every learning rate, train a fresh ResNet-20 that starts from
# the shared initial weights in "cifar10_resnet.pth" for 200 epochs.
# adam_data receives five per-epoch lists per learning rate, in this order:
#   cumulative wall-clock seconds (training + evaluation), train loss,
#   train accuracy (%), test loss, test accuracy (%).
# Losses/accuracies are means of per-batch values.
adam_data = []

for learning_rate in learning_rate_list:

    model = resnet20()
    # map_location: load onto the current device even if the checkpoint was
    # written from a different one
    model.load_state_dict(torch.load("cifar10_resnet.pth", map_location=device))
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    timelist_adam = []
    losslist_adam_train = []
    acclist_adam_train = []
    losslist_adam_test = []
    acclist_adam_test = []

    start_time = timeit.default_timer()

    for e in range(1, 200 + 1):
        epoch_loss = 0
        epoch_acc = 0
        model.train()
        for X_batch, y_batch in train_loader:

            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            y_pred = model(X_batch)

            loss = criterion(y_pred, y_batch)
            acc = multi_acc(y_pred, y_batch)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

        with torch.no_grad():

            val_epoch_loss = 0
            val_epoch_acc = 0

            model.eval()
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)

                y_val_pred = model(X_val_batch)

                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()

        train_loss = epoch_loss / len(train_loader)
        train_acc = epoch_acc / len(train_loader)
        test_loss = val_epoch_loss / len(val_loader)
        test_acc = val_epoch_acc / len(val_loader)

        if e % 50 == 0:
            print(f'Epoch {e:03}: | train Loss: {train_loss:.5f} | train Acc: {train_acc:.3f}')
            print(f'Epoch {e:03}: | test Loss: {test_loss:.5f} | test Acc: {test_acc:.3f}')

        epoch_time = timeit.default_timer()
        timelist_adam.append(epoch_time - start_time)
        losslist_adam_train.append(train_loss)
        acclist_adam_train.append(train_acc)
        losslist_adam_test.append(test_loss)
        acclist_adam_test.append(test_acc)

    adam_data.append(timelist_adam)
    adam_data.append(losslist_adam_train)
    adam_data.append(acclist_adam_train)
    adam_data.append(losslist_adam_test)
    adam_data.append(acclist_adam_test)


In [None]:
# Plain SGD baseline: for every learning rate, train a fresh ResNet-20 that starts
# from the shared initial weights in "cifar10_resnet.pth" for 200 epochs.
# sgd_data receives five per-epoch lists per learning rate, in this order:
#   cumulative wall-clock seconds (training + evaluation), train loss,
#   train accuracy (%), test loss, test accuracy (%).
# Losses/accuracies are means of per-batch values.
sgd_data = []

for learning_rate in learning_rate_list:

    model = resnet20()
    # map_location: load onto the current device even if the checkpoint was
    # written from a different one
    model.load_state_dict(torch.load("cifar10_resnet.pth", map_location=device))
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    timelist_sgd = []
    losslist_sgd_train = []
    acclist_sgd_train = []
    losslist_sgd_test = []
    acclist_sgd_test = []

    start_time = timeit.default_timer()

    for e in range(1, 200 + 1):
        epoch_loss = 0
        epoch_acc = 0
        model.train()
        for X_batch, y_batch in train_loader:

            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            y_pred = model(X_batch)

            loss = criterion(y_pred, y_batch)
            acc = multi_acc(y_pred, y_batch)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

        with torch.no_grad():

            val_epoch_loss = 0
            val_epoch_acc = 0

            model.eval()
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)

                y_val_pred = model(X_val_batch)

                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()

        train_loss = epoch_loss / len(train_loader)
        train_acc = epoch_acc / len(train_loader)
        test_loss = val_epoch_loss / len(val_loader)
        test_acc = val_epoch_acc / len(val_loader)

        if e % 50 == 0:
            print(f'Epoch {e:03}: | train Loss: {train_loss:.5f} | train Acc: {train_acc:.3f}')
            print(f'Epoch {e:03}: | test Loss: {test_loss:.5f} | test Acc: {test_acc:.3f}')

        epoch_time = timeit.default_timer()
        timelist_sgd.append(epoch_time - start_time)
        losslist_sgd_train.append(train_loss)
        acclist_sgd_train.append(train_acc)
        losslist_sgd_test.append(test_loss)
        acclist_sgd_test.append(test_acc)

    sgd_data.append(timelist_sgd)
    sgd_data.append(losslist_sgd_train)
    sgd_data.append(acclist_sgd_train)
    sgd_data.append(losslist_sgd_test)
    sgd_data.append(acclist_sgd_test)


In [None]:
import torch_optimizer

# AdaHessian runs (torch_optimizer.Adahessian), same protocol as the Adam/SGD cells.
# adahes_data receives five per-epoch lists per learning rate, in this order:
#   cumulative wall-clock seconds (training + evaluation), train loss,
#   train accuracy (%), test loss, test accuracy (%).
adahes_data = []

for learning_rate in learning_rate_list:

    model = resnet20()
    # map_location: load onto the current device even if the checkpoint was
    # written from a different one
    model.load_state_dict(torch.load("cifar10_resnet.pth", map_location=device))
    model.to(device)

    optimizer = torch_optimizer.Adahessian(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-4,
        weight_decay=0.0,
        hessian_power=1.0,
    )

    timelist_adahes = []
    losslist_adahes_train = []
    acclist_adahes_train = []
    losslist_adahes_test = []
    acclist_adahes_test = []

    start_time = timeit.default_timer()

    for e in range(1, 200 + 1):
        epoch_loss = 0
        epoch_acc = 0

        model.train()
        for X_batch, y_batch in train_loader:

            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            y_pred = model(X_batch)

            loss = criterion(y_pred, y_batch)
            acc = multi_acc(y_pred, y_batch)

            # keep the gradient graph: Adahessian's step differentiates through
            # the gradients to estimate curvature
            loss.backward(create_graph=True)

            optimizer.step()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

        with torch.no_grad():

            val_epoch_loss = 0
            val_epoch_acc = 0

            model.eval()
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)

                y_val_pred = model(X_val_batch)

                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()

        train_loss = epoch_loss / len(train_loader)
        train_acc = epoch_acc / len(train_loader)
        test_loss = val_epoch_loss / len(val_loader)
        test_acc = val_epoch_acc / len(val_loader)

        if e % 50 == 0:
            print(f'Epoch {e:03}: | train Loss: {train_loss:.5f} | train Acc: {train_acc:.3f}')
            print(f'Epoch {e:03}: | test Loss: {test_loss:.5f} | test Acc: {test_acc:.3f}')

        epoch_time = timeit.default_timer()
        timelist_adahes.append(epoch_time - start_time)
        losslist_adahes_train.append(train_loss)
        acclist_adahes_train.append(train_acc)
        losslist_adahes_test.append(test_loss)
        acclist_adahes_test.append(test_acc)

    adahes_data.append(timelist_adahes)
    adahes_data.append(losslist_adahes_train)
    adahes_data.append(acclist_adahes_train)
    adahes_data.append(losslist_adahes_test)
    adahes_data.append(acclist_adahes_test)

In [None]:
# Minibatch L-BFGS runs with a strong-Wolfe line search.
# Only test metrics are tracked: lbfgs_data receives three per-epoch lists per
# learning rate, in this order: cumulative wall-clock seconds, test loss,
# test accuracy (%).
lbfgs_data = []


for learning_rate in learning_rate_list:

    model = resnet20()
    # map_location: load onto the current device even if the checkpoint was
    # written from a different one
    model.load_state_dict(torch.load("cifar10_resnet.pth", map_location=device))
    model.to(device)

    optimizer = optim.LBFGS(model.parameters(), lr=learning_rate, history_size=10, line_search_fn='strong_wolfe')

    timelist_lbfgs = []
    losslist_lbfgs_test = []
    acclist_lbfgs_test = []

    start_time = timeit.default_timer()

    for e in range(1, 200 + 1):

        model.train()

        for X_batch, y_batch in train_loader:

            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            def closure():
                # Recompute loss and gradients on the current batch. Per the
                # torch.optim.LBFGS docs, step() may call this several times
                # (e.g. during the line search).
                optimizer.zero_grad()
                y_pred = model(X_batch)
                loss = criterion(y_pred, y_batch)
                loss.backward()
                return loss

            optimizer.step(closure)

        with torch.no_grad():

            val_epoch_loss = 0
            val_epoch_acc = 0

            model.eval()
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)

                y_val_pred = model(X_val_batch)

                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()

        test_loss = val_epoch_loss / len(val_loader)
        test_acc = val_epoch_acc / len(val_loader)

        if e % 50 == 0:
            print(f'Epoch {e:03}: | test Loss: {test_loss:.5f} | test Acc: {test_acc:.3f}')

        epoch_time = timeit.default_timer()
        timelist_lbfgs.append(epoch_time - start_time)
        losslist_lbfgs_test.append(test_loss)
        acclist_lbfgs_test.append(test_acc)

    lbfgs_data.append(timelist_lbfgs)
    lbfgs_data.append(losslist_lbfgs_test)
    lbfgs_data.append(acclist_lbfgs_test)
    

In [None]:
def group_product(xs, ys):
    """Inner product of two lists of tensors: sum over i of sum(xs[i] * ys[i]).

    Pairs are formed with zip, so trailing entries of the longer list are ignored.
    Empty input yields the int 0.
    """
    total = 0
    for x, y in zip(xs, ys):
        total = total + torch.sum(x * y)
    return total

def normalization(v):
    """Scale a list of tensors to (approximately) unit joint Euclidean norm.

    Returns a new list. 1e-6 is added to the norm so an all-zero input does not
    divide by zero.
    """
    norm = group_product(v, v) ** 0.5
    norm = norm.cpu().item()
    denom = norm + 1e-6
    return [part / denom for part in v]

class NysHessianpartial():
    """Randomized Nystrom sketch of a (shifted) minibatch Hessian.

    After ``update_Hessian`` runs, ``self.U`` (d x rank, orthonormal columns, from a
    thin SVD) and ``self.S`` (rank,) hold a low-rank approximation
    ``H ~= U diag(S) U^T``. Here ``H`` is the Hessian of the batch loss with
    respect to the model's trainable parameters and ``d`` is their total number
    of entries.
    """

    def __init__(self, rank, rho):
        # number of random test vectors, i.e. the sketch rank
        self.rank = rank
        # rho is the regularization in Nystrom sketch; this class only stores it
        self.rho = rho

    def get_params_grad(self, model):
        """Return ``(params, grads)`` for every parameter with ``requires_grad``.

        Each grad is ``param.grad + 0.``, a new tensor that stays attached to the
        autograd graph when backward ran with ``create_graph=True``. It is the
        float ``0.`` when the parameter has no gradient yet.
        """
        params = []
        grads = []
        for param in model.parameters():
            if not param.requires_grad:
                continue
            params.append(param)
            grads.append(0. if param.grad is None else param.grad + 0.)
        return params, grads

    @staticmethod
    def _normalize(vecs):
        """Scale a list of tensors to joint Euclidean norm ~1 (1e-6 guards zero)."""
        norm = sum(torch.sum(vi * vi) for vi in vecs) ** 0.5
        norm = norm.cpu().item()
        return [vi / (norm + 1e-6) for vi in vecs]

    def update_Hessian(self, X_batch, y_batch, model, criterion, device):
        """Recompute ``self.U`` and ``self.S`` from one batch.

        Steps:

        1. Form ``Y = (H + shift*I) Omega`` with Hessian-vector products, where
           ``Omega`` has normalized Gaussian columns.
        2. Factor ``Omega^T Y = C C^T`` (lower Cholesky).
        3. Set ``B = Y C^{-T}``, so ``B B^T = Y (Omega^T Y)^{-1} Y^T``.
        4. Take the thin SVD ``B = U Sigma V^T``; then ``self.U = U`` and
           ``self.S = max(Sigma^2 - shift, 0)``.

        Side effects: runs one forward pass and one ``backward(create_graph=True)``.
        That overwrites ``p.grad`` of every parameter. Any BatchNorm layers update
        their running statistics if the model is in train mode.

        Args:
            X_batch, y_batch: batch used for the Hessian (moved to ``device``).
            model: the network whose parameters define the Hessian.
            criterion: loss function ``criterion(model(X), y)``.
            device: device on which to build the sketch.
        """
        shift = 0.001
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        # The gradient graph depends only on the batch and the (fixed)
        # parameters, so build it once and reuse it for all `rank`
        # Hessian-vector products. The original rebuilt it in every iteration,
        # which cost `rank` forward/backward passes and updated BatchNorm
        # running statistics `rank` times with the same batch.
        model.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward(create_graph=True)
        params, gradsH = self.get_params_grad(model)
        # remember the size for each group of parameters
        self.size_vec = [p.size() for p in params]

        test_cols = []  # flattened columns of Omega
        hv_cols = []    # flattened columns of Y = (H + shift*I) Omega
        for _ in range(self.rank):
            # Gaussian probe matching each parameter's shape/dtype, jointly normalized
            v = self._normalize([torch.randn(p.shape, dtype=p.dtype, device=device) for p in params])
            hv = torch.autograd.grad(gradsH, params, grad_outputs=v, only_inputs=True, retain_graph=True)
            # add the initial shift: (H + shift*I) v
            hv_cols.append(torch.cat([(h + shift * vi).reshape(-1) for h, vi in zip(hv, v)]))
            test_cols.append(torch.cat([vi.reshape(-1) for vi in v]))

        hv_matrix_ex = torch.column_stack(hv_cols)
        test_matrix_ex = torch.column_stack(test_cols)
        # Omega^T (H + shift*I) Omega, rank x rank
        choleskytarget = torch.mm(test_matrix_ex.t(), hv_matrix_ex)
        try:
            C_ex = torch.linalg.cholesky(choleskytarget)
        except RuntimeError:
            # torch.linalg.cholesky raises (a RuntimeError subclass) when the
            # target is not positive definite, e.g. under negative curvature.
            # Shift its eigenvalues by |lambda_min| on top of the initial shift
            # and factor the shifted matrix instead.
            # NOTE(review): Y still carries only the initial shift while the
            # enlarged shift is subtracted from Sigma^2 below; this mirrors the
            # original code, so confirm it is the intended correction.
            eigs, eigvectors = torch.linalg.eigh(choleskytarget)
            shift = shift + torch.abs(torch.min(eigs)).item()
            eigs = eigs + shift
            C_ex = torch.linalg.cholesky(eigvectors @ torch.diag(eigs) @ eigvectors.T)

        # B = Y C^{-T}: solve X C^T = Y with C^T upper triangular. The former
        # code solved against C^T from the left, which produced Y C^{-1} and
        # hence Y (C^T C)^{-1} Y^T, not the Nystrom approximation Y (C C^T)^{-1} Y^T.
        B_ex = torch.linalg.solve_triangular(C_ex.t(), hv_matrix_ex, upper=True, left=False)
        U, S, _ = torch.linalg.svd(B_ex, full_matrices=False)
        self.U = U
        # eigenvalues of the sketch with the shift removed, clipped at zero
        self.S = torch.clamp(torch.square(S) - shift, min=0.0)

class NysHessianOpt(Optimizer):
    r"""Gradient step preconditioned by a Nystrom Hessian sketch.

    ``step(lr)`` applies ``p <- p - lr * (U diag(S) U^T + rho I)^{-1} g``. Here
    ``g`` is the current gradient flattened across the parameter group, and
    ``U``/``S`` are read from ``self.nysh``. Call
    ``self.nysh.update_Hessian(...)`` before the first step so they exist.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups. ``U`` spans all sketched parameters at once, so a
            single group (with the same parameter order) is expected.
        rank (int): sketch rank
        rho (float): regularization added to the sketched Hessian
    """
    def __init__(self, params, rank = 100, rho = 0.1):
        # initialize the optimizer
        defaults = dict(rank = rank, rho = rho)
        self.nysh = NysHessianpartial(rank, rho)
        super(NysHessianOpt, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, lr):
        """Apply one preconditioned update in place.

        Args:
            lr: step size (Python float or 0-d tensor).
        """
        U, S = self.nysh.U, self.nysh.S
        for group in self.param_groups:
            rho = group['rho']
            # gradient as one long vector (reshape also handles non-contiguous grads)
            g = torch.cat([p.grad.reshape(-1) for p in group['params']])
            # (U S U^T + rho I)^{-1} g = U (S + rho)^{-1} U^T g + (g - U U^T g) / rho
            UTg = torch.mv(U.t(), g)
            g_new = torch.mv(U, UTg / (S + rho)) + (g - torch.mv(U, UTg)) / rho
            # scatter the flat direction back onto the parameters
            offset = 0
            for p in group['params']:
                numel = p.numel()
                gp = g_new[offset:offset + numel].view_as(p)
                offset += numel
                p.add_(-lr * gp)

In [None]:
# Nystrom-preconditioned SGD runs (NysHessianOpt), same protocol as the Adam/SGD cells.
# skechysgd_data receives five per-epoch lists per learning rate, in this order:
#   cumulative wall-clock seconds (training + evaluation), train loss,
#   train accuracy (%), test loss, test accuracy (%).
skechysgd_data = []

# Refresh the Hessian sketch whenever hes_iter % hes_interval == 0: on the very
# first step and then every 2 * len(train_loader) - 1 steps (about every two epochs).
hes_interval = 2 * len(train_loader) - 1

for learning_rate in learning_rate_list:
    model = resnet20()
    # map_location: load onto the current device even if the checkpoint was
    # written from a different one
    model.load_state_dict(torch.load("cifar10_resnet.pth", map_location=device))
    model.to(device)

    optimizer = NysHessianOpt(model.parameters())

    # global step counter across epochs, used to schedule sketch refreshes
    hes_iter = 0

    timelist_skechysgd = []
    losslist_skechysgd_train = []
    acclist_skechysgd_train = []
    losslist_skechysgd_test = []
    acclist_skechysgd_test = []

    lr = torch.tensor(learning_rate)

    start_time = timeit.default_timer()

    for e in range(1, 200 + 1):
        epoch_loss = 0
        epoch_acc = 0
        model.train()
        for X_batch, y_batch in train_loader:

            if hes_iter % hes_interval == 0:
                # update Hessian and sketch on the current batch
                optimizer.nysh.update_Hessian(X_batch, y_batch, model, criterion, device)

            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()

            y_pred = model(X_batch)

            loss = criterion(y_pred, y_batch)
            acc = multi_acc(y_pred, y_batch)

            loss.backward()

            optimizer.step(lr)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
            hes_iter += 1

        with torch.no_grad():

            val_epoch_loss = 0
            val_epoch_acc = 0

            model.eval()
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)

                y_val_pred = model(X_val_batch)

                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()

        train_loss = epoch_loss / len(train_loader)
        train_acc = epoch_acc / len(train_loader)
        test_loss = val_epoch_loss / len(val_loader)
        test_acc = val_epoch_acc / len(val_loader)

        if e % 50 == 0:
            print(f'Epoch {e:03}: | train Loss: {train_loss:.5f} | train Acc: {train_acc:.3f}')
            print(f'Epoch {e:03}: | test Loss: {test_loss:.5f} | test Acc: {test_acc:.3f}')

        epoch_time = timeit.default_timer()
        timelist_skechysgd.append(epoch_time - start_time)
        losslist_skechysgd_train.append(train_loss)
        acclist_skechysgd_train.append(train_acc)
        losslist_skechysgd_test.append(test_loss)
        acclist_skechysgd_test.append(test_acc)

    skechysgd_data.append(timelist_skechysgd)
    skechysgd_data.append(losslist_skechysgd_train)
    skechysgd_data.append(acclist_skechysgd_train)
    skechysgd_data.append(losslist_skechysgd_test)
    skechysgd_data.append(acclist_skechysgd_test)

In [None]:
def group_product(xs, ys):
    """Inner product of two lists of tensors, summed across all pairs.

    :param xs: list of tensors
    :param ys: list of tensors, paired elementwise with ``xs`` via zip
    :return: sum of ``torch.sum(x * y)`` over the pairs (int 0 if there are none)
    """
    return sum(torch.sum(a * b) for a, b in zip(xs, ys))

def group_add(params, update, alpha=1):
    """In place, ``params[i] += update[i] * alpha`` for every index of ``params``.

    :param params: list of tensors, modified in place through ``.data``
    :param update: list of tensors, indexed in step with ``params``
    :param alpha: scale applied to each update
    :return: the same ``params`` list object
    """
    for idx in range(len(params)):
        params[idx].data.add_(update[idx] * alpha)
    return params

def get_params_grad(model):
    """Collect trainable parameters and copies of their current gradients.

    Returns ``(params, grads)``. ``params`` holds the parameters with
    ``requires_grad`` in ``model.parameters()`` order. Each entry of ``grads`` is
    ``param.grad + 0.``, a new tensor that keeps any autograd graph attached when
    backward ran with ``create_graph=True``, or the float ``0.`` if the parameter
    has no gradient.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    grads = [p.grad + 0. if p.grad is not None else 0. for p in trainable]
    return trainable, grads

class NewtonCG(Optimizer):
    """Inexact Newton optimizer: solves ``H z = g`` with conjugate gradient (CG).

    ``step`` expects ``p.grad`` to hold the gradient ``g`` for the update. It also
    takes ``gradsH``, a list of gradients computed with ``create_graph=True``
    (possibly on a different batch), so that Hessian-vector products can be
    formed via ``torch.autograd.grad``.
    """

    def __init__(self, params):
        # initialize the optimizer (no per-group hyperparameters)
        defaults = dict()
        super(NewtonCG, self).__init__(params, defaults)

    @staticmethod
    def _dot(xs, ys):
        """Inner product of two lists of tensors, summed over all entries."""
        return sum(torch.sum(x * y) for x, y in zip(xs, ys))

    def cg_step(self, g, gradsH, cg_iter, tol, params=None):
        """Approximately solve ``H z = g`` with at most ``cg_iter`` CG iterations.

        Args:
            g: list of gradient tensors (the right-hand side), one per parameter.
            gradsH: gradients with a retained graph; ``H d`` is computed as
                ``autograd.grad(gradsH, params, grad_outputs=d)``.
            cg_iter: maximum number of CG iterations.
            tol: stop once the squared residual norm ``r^T r`` is <= ``tol``.
            params: parameters to differentiate w.r.t.; defaults to the first
                parameter group (the original behavior).

        Returns:
            list of tensors ``z`` shaped like ``params``. On non-positive
            curvature ``d^T H d <= 0``, CG stops. It returns the current
            iterate, or the gradient itself (a steepest-descent direction) if
            this happens on the first iteration.
        """
        if params is None:
            params = self.param_groups[0]['params']
        zs = [torch.zeros_like(p) for p in params]
        rs = [gi.detach() for gi in g]
        ds = list(rs)
        rr = self._dot(rs, rs)
        rr_prev = None

        for i in range(cg_iter):
            if rr <= tol:
                break
            if i != 0:
                beta = rr / rr_prev
                ds = [r + beta * d for r, d in zip(rs, ds)]
            hv = torch.autograd.grad(gradsH, params, grad_outputs=ds, only_inputs=True, retain_graph=True)
            curvature = self._dot(ds, hv)
            if curvature <= 0:
                # H is not positive definite along d: the CG step length would be
                # negative (an ascent step) or a division by zero.
                if i == 0:
                    zs = list(ds)
                break
            alpha = rr / curvature
            zs = [z + alpha * d for z, d in zip(zs, ds)]
            rs = [r - alpha * h for r, h in zip(rs, hv)]
            rr_prev, rr = rr, self._dot(rs, rs)

        return zs

    def step(self, lr, gradsH, cg_iter, tol):
        """Update each parameter in place: ``p <- p - lr * z`` with ``H z ~= p.grad``."""
        for group in self.param_groups:
            params = group['params']
            g = [p.grad for p in params]
            g_new = self.cg_step(g, gradsH, cg_iter, tol, params=params)
            # update model parameters
            with torch.no_grad():
                for p, g_newgroup in zip(params, g_new):
                    p.add_(-lr * g_newgroup)

In [None]:
# Snapshot one augmented, shuffled pass over the training set into two in-memory
# tensors. The Newton-CG cell slices consecutive chunks of them as its separate
# Hessian batch.
trainset = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True,
    num_workers=4, pin_memory=True)

# Size the buffers from the dataset instead of hard-coding 50000 samples, and
# fill them by running offset so any batch size (including a short last batch)
# works. Images are CIFAR-10 sized (3 x 32 x 32) after the transforms above.
num_train = len(trainset.dataset)
train_full_data = torch.empty(num_train, 3, 32, 32)
train_full_label = torch.empty(num_train, dtype=torch.long)

offset = 0
for d, l in trainset:
    batch_len = l.shape[0]
    train_full_data[offset:offset + batch_len] = d
    train_full_label[offset:offset + batch_len] = l
    offset += batch_len
assert offset == num_train, "snapshot did not cover the whole training set"

n = len(train_loader.dataset)
batch_size = train_loader.batch_size

In [None]:
# Newton-CG runs. Each step uses the minibatch gradient g and Hessian-vector
# products taken on the next chunk of the in-memory training snapshot.
# Only test metrics are tracked: newtoncg_data receives three per-epoch lists per
# learning rate, in this order: cumulative wall-clock seconds, test loss,
# test accuracy (%).
newtoncg_data = []

for learning_rate in learning_rate_list:

    model = resnet20()
    # map_location: load onto the current device even if the checkpoint was
    # written from a different one
    model.load_state_dict(torch.load("cifar10_resnet.pth", map_location=device))
    model.to(device)

    optimizer = NewtonCG(model.parameters())

    timelist_newtoncg = []
    losslist_newtoncg_test = []
    acclist_newtoncg_test = []

    start_time = timeit.default_timer()

    tol = 1e-6
    cg_iter = 10
    lr = torch.tensor(learning_rate)
    # index of the next snapshot chunk used for the Hessian batch
    cgupdate = 0
    hess_bs = trainset.batch_size

    for e in range(1, 100 + 1):

        model.train()
        for X_batch, y_batch in train_loader:

            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            lo = cgupdate * hess_bs
            hess_data = train_full_data[lo:lo + hess_bs].to(device)
            hess_label = train_full_label[lo:lo + hess_bs].to(device)

            # Gradient on the Hessian batch, keeping its graph so the optimizer
            # can take Hessian-vector products through it. gradsH are copies
            # (param.grad + 0.), so clearing p.grad below does not affect them.
            model.zero_grad()
            outputs_h = model(hess_data)
            loss = criterion(outputs_h, hess_label)
            loss.backward(create_graph=True)
            _, gradsH = get_params_grad(model)

            # Clear p.grad before the minibatch backward. Without this, backward()
            # accumulates, so the step used (Hessian-batch grad + minibatch grad),
            # and p.grad kept the Hessian graph alive.
            model.zero_grad(set_to_none=True)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()

            optimizer.step(lr, gradsH, cg_iter, tol)
            cgupdate = (cgupdate + 1) % len(trainset)

        with torch.no_grad():

            val_epoch_loss = 0
            val_epoch_acc = 0

            model.eval()
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)

                y_val_pred = model(X_val_batch)

                val_loss = criterion(y_val_pred, y_val_batch)
                val_acc = multi_acc(y_val_pred, y_val_batch)

                val_epoch_loss += val_loss.item()
                val_epoch_acc += val_acc.item()

        test_loss = val_epoch_loss / len(val_loader)
        test_acc = val_epoch_acc / len(val_loader)

        if e % 50 == 0:
            print(f'Epoch {e:03}: | test Loss: {test_loss:.5f} | test Acc: {test_acc:.3f}')

        epoch_time = timeit.default_timer()
        timelist_newtoncg.append(epoch_time - start_time)
        losslist_newtoncg_test.append(test_loss)
        acclist_newtoncg_test.append(test_acc)

    newtoncg_data.append(timelist_newtoncg)
    newtoncg_data.append(losslist_newtoncg_test)
    newtoncg_data.append(acclist_newtoncg_test)