In [1]:
import copy
import time
import warnings

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from lib.models import ConvNet
from lib.train import train
from lib.BB import BB
from lib.AdaHessian import AdaHessian
warnings.filterwarnings("ignore")

In [2]:
# Preprocessing pipeline: convert each image to a tensor, then normalise every
# RGB channel with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

In [3]:
# CIFAR-10 train and test splits; both are stored under ./data, downloaded if
# needed, and passed through the same preprocessing pipeline.
cifar_kwargs = dict(root='data', download=True, transform=transform)
cifar_trainset = datasets.CIFAR10(train=True, **cifar_kwargs)
cifar_testset = datasets.CIFAR10(train=False, **cifar_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Sanity check: number of samples in the (train, test) splits.
tuple(len(split) for split in (cifar_trainset, cifar_testset))

(50000, 10000)

In [5]:
# Mini-batch loaders: only the training split is shuffled; the held-out split
# is iterated in a fixed order.
loader_batch_size = 256
train_dataloader = torch.utils.data.DataLoader(
    cifar_trainset, batch_size=loader_batch_size, shuffle=True
)
valid_dataloader = torch.utils.data.DataLoader(
    cifar_testset, batch_size=loader_batch_size, shuffle=False
)

In [6]:
# Prefer the second GPU (the device this notebook originally targeted), but
# fall back to the first GPU or the CPU so the notebook still runs on machines
# with fewer than two CUDA devices instead of failing at `model.to(device)`.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [7]:
# Loss first, then the network; the network is moved to the selected device
# right after construction.
criterion = nn.CrossEntropyLoss()
model = ConvNet().to(device)  # assumes ConvNet is an nn.Module, whose .to() returns the module itself

In [8]:
# Optimizers to compare in the sweep. Each name appears exactly once: a
# duplicate would be trained twice and its second result would overwrite the
# first in the name-keyed results dict.
optims = [
    'Adam',
    'SGD',
    'SGD momentum',
    'LBFGS',
    'AdaHessian',
    'BB',
]

In [9]:
def init_optim(optim):
    """Build a new optimizer over the parameters of the global ``model``.

    Parameters
    ----------
    optim : str
        Optimizer name. One of ``'Adam'``, ``'SGD'``, ``'momentum'``
        (also accepted as ``'SGD momentum'``), ``'LBFGS'``, ``'AdaHessian'``
        or ``'BB'``.

    Returns
    -------
    The newly constructed optimizer instance for the requested name.

    Raises
    ------
    ValueError
        If ``optim`` is not one of the supported names.
    """
    # Factories are lazy: only the selected optimizer is constructed, and an
    # unknown name is rejected before ``model`` is touched.
    factories = {
        'Adam': lambda params: torch.optim.Adam(params, lr=5e-4),
        'SGD': lambda params: torch.optim.SGD(params, lr=1e-2),
        'momentum': lambda params: torch.optim.SGD(params, momentum=0.9, lr=1e-2),
        'LBFGS': lambda params: torch.optim.LBFGS(params, lr=1e-3),
        'AdaHessian': lambda params: AdaHessian(params, lr=1e-1),
        'BB': lambda params: BB(params, lr=5e-3),
    }
    # The sweep list spells the momentum variant 'SGD momentum'; accept both
    # spellings so that entry no longer falls through every branch.
    factories['SGD momentum'] = factories['momentum']

    if optim not in factories:
        raise ValueError(
            f"Unknown optimizer {optim!r}; expected one of {sorted(factories)}"
        )
    return factories[optim](model.parameters())

In [10]:
# Build the LBFGS optimizer through the shared factory so this stand-alone run
# always uses the same settings as the optimizer sweep.
optimizer = init_optim('LBFGS')

In [None]:
# Stand-alone LBFGS run: 20 epochs with verbose output enabled.
res = train(
    model, optimizer, criterion,
    train_dataloader, valid_dataloader, device,
    epochs=20, optim='LBFGS', verbose=True,
)

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Optimizer sweep: train the same architecture once per optimizer and collect
# the value returned by `train` plus the wall-clock training time.
#
# Every optimizer must start from identical weights, otherwise each run would
# continue from wherever the previous optimizer left the model. A fresh network
# is created and its initial state snapshotted, then restored before each run.
# `model` is rebound at module level because `init_optim` reads the global.
model = ConvNet()
model.to(device)
initial_state = copy.deepcopy(model.state_dict())

results = {}   # optimizer name -> value returned by `train`
timings = {}   # optimizer name -> elapsed training time in seconds
for optim in optims:
    model.load_state_dict(initial_state)
    optimizer = init_optim(optim)
    # perf_counter is a monotonic clock; the previous `timeit.timeit()` call
    # benchmarked a no-op statement and relied on a module that was never imported.
    start = time.perf_counter()
    res = train(model, optimizer, criterion, train_dataloader,
                valid_dataloader, device, optim=optim, epochs=20, verbose=False)
    timings[optim] = time.perf_counter() - start
    results[optim] = res

In [None]:
# Display the collected sweep output: a dict keyed by optimizer name whose
# values are whatever `train` returned for that optimizer.
results