In [1]:
%matplotlib inline

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as dtrans

# Log the versions of both libraries so the run can be reproduced later.
for _lib in (torch, torchvision):
    print(_lib.__version__)

1.0.1.post2
0.2.2


In [49]:
# Use the GPU when one is visible, otherwise fall back to the CPU so that the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Mini-batch size shared by the training and test loaders.
batch_size = 64
# Seed the global torch RNG so weight init and shuffling are repeatable.
torch.manual_seed(10)

<torch._C.Generator at 0x7f02edab0a70>

In [42]:
# transform =  dtrans.Compose([dtrans.ToTensor()])
# train_data = dset.MNIST(root='../../data/', train=True, transform=transform)
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# transform =  dtrans.Compose([dtrans.ToTensor()])
# test_data = dset.MNIST(root='../../data/', train=False, transform=transform)
# test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [44]:
# Settings shared by both splits: dataset location and the per-channel
# normalisation (subtract 0.5, divide by 0.5 on each of the three channels).
_data_root = '../../data/'
_normalise = dtrans.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))

# Training split: random horizontal flip and a random 32x32 crop taken from the
# image padded by 4, followed by tensor conversion and normalisation.
transform = dtrans.Compose([
    dtrans.RandomHorizontalFlip(),
    dtrans.RandomCrop(size=(32,32), padding=(4,4)),
    dtrans.ToTensor(),
    _normalise,
])
train_data = dset.CIFAR10(root=_data_root, train=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

# Test split: no augmentation, only tensor conversion and normalisation.
# `transform` is deliberately rebound, so after this cell it names the test pipeline.
transform = dtrans.Compose([dtrans.ToTensor(), _normalise])
test_data = dset.CIFAR10(root=_data_root, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)

In [34]:
class LookAhead(optim.Optimizer):
    """Lookahead wrapper around an already-constructed base ("fast") optimiser.

    The wrapped optimiser ``base_optim`` updates the fast weights on every call
    to :meth:`step`. After every ``k`` such updates the slow weights (detached
    copies kept in ``self.phi``) are moved a fraction ``alpha`` of the way
    towards the fast weights, and the fast weights are reset to the new slow
    weights.

    Args:
        params: iterable of the tensors that ``base_optim`` updates. They are
            stored in ``self.theta``; one slow copy per tensor goes in ``self.phi``.
        alpha: interpolation factor used when pulling the slow weights towards
            the fast weights.
        k: number of fast-optimiser steps between two slow-weight updates.
        base_optim: the optimiser instance that performs the fast updates.
        A_args: keyword arguments of the base optimiser. They are only stored
            on the instance (``self.A_args``); this class does not read them.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """

    def __init__(self, params, alpha, k, base_optim, A_args):
        if k < 1:
            raise ValueError('k must be a positive integer, got %r' % (k,))
        self.alpha = alpha
        self.A = base_optim
        self.A_args = A_args
        self.k = k
        # Number of fast steps taken since the last slow-weight update.
        self.inner_step = 0

        self._sync_params(params)

        # The slow copies are what this Optimizer instance formally owns.
        defaults = {}
        super(LookAhead, self).__init__(self.phi, defaults)

    def _sync_params(self, params):
        """Record the fast tensors and create one detached slow copy of each."""
        with torch.no_grad():
            self.theta = list(params)
            # detach().clone() gives an independent copy on the same device and
            # avoids the copy-construct warning raised by torch.tensor(tensor).
            self.phi = [p.detach().clone() for p in self.theta]

    def step(self, closure=None):
        """Take one fast step and, after every ``k``-th one, update the slow weights.

        Args:
            closure: optional callable forwarded to the base optimiser's ``step``.

        Returns:
            Whatever the base optimiser's ``step`` returns.
        """
        # Always consume the gradient of the current batch; the slow-weight
        # update happens in addition to, not instead of, a fast step.
        loss = self.A.step(closure)
        self.inner_step += 1

        if self.inner_step >= self.k:
            self.inner_step = 0
            with torch.no_grad():
                # Pair each slow copy with the fast tensor it was cloned from so
                # that every parameter is synchronised, however the base
                # optimiser happens to group its parameters.
                for slow, fast in zip(self.phi, self.theta):
                    slow.data.add_(self.alpha * (fast.data - slow.data))
                    # In-place copy keeps the fast tensor's storage unchanged.
                    fast.data.copy_(slow.data)
        return loss

    def zero_grad(self, *args, **kwargs):
        """Clear the gradients of the fast weights via the base optimiser."""
        self.A.zero_grad(*args, **kwargs)

In [31]:
class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear -> ReLU -> Linear.

    Args:
        in_features: number of values per sample after flattening. The default
            (28*28*1) matches the original hard-coded input size.
        hidden_features: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=28*28*1, hidden_features=100, num_classes=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return the logits for a batch ``x`` whose first dimension is the batch."""
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat = x.reshape(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [12]:
def train(model, optimiser, train_loader, device):
    """Run one optimisation pass over ``train_loader`` and print a summary line.

    Args:
        model: module to optimise; it is switched to training mode.
        optimiser: object providing ``zero_grad()`` and ``step()``.
        train_loader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Prints the mean of the per-batch cross-entropy losses and the total number
    of misclassified samples (predictions are taken before the weight update
    of the same batch).
    """
    model.train()

    loss_epoch, error, n_batches = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        out = model(images)
        loss = F.cross_entropy(out, labels)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        loss_epoch += loss.item()
        error += (out.argmax(1) != labels).sum().item()
        n_batches += 1

    # Divide by the real number of batches (not the last enumerate index);
    # max(..., 1) keeps an empty loader from raising.
    print('[Train] Loss: %.4f, Err: %d,' %(loss_epoch/max(n_batches, 1), error))

    
def test(model, test_loader, device):
    model.eval()
    
    loss_epoch, error = 0, 0
    for i, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            images, labels = images.to(device), labels.to(device)
            out = model(images)
            loss = F.cross_entropy(out, labels)        
        loss_epoch += loss.item()
        error += (out.argmax(1) != labels).sum().item()        
    print('[Test] Loss: %.4f, Err: %d,' %(loss_epoch/i, error))

In [55]:
# Number of passes over the training set.
epochs = 200

# net = MLP().to(device)
from models import resnet
net = resnet.ResNet18().to(device)

# Keyword arguments of the inner SGD optimiser; the same dict is also handed
# to the LookAhead wrapper.
A_args = dict(lr=0.5, weight_decay=0.0001)
sgd_optim = optim.SGD(net.parameters(), **A_args)
# Step schedule on the inner optimiser's learning rate, with milestones at
# epochs 100 and 150.
sgd_scheduler = optim.lr_scheduler.MultiStepLR(sgd_optim, milestones=[100, 150])
# Wrapper that synchronises slow and fast weights around `sgd_optim`.
la_optim = LookAhead(
    net.parameters(),
    alpha=0.5,
    k=5,
    base_optim=sgd_optim,
    A_args=A_args,
)



In [56]:
for epoch in range(epochs):
    # Evaluate first, so each test line describes the weights entering the epoch.
    test(net, test_loader, device)
    train(net, la_optim, train_loader, device)
    # Advance the schedule by exactly one epoch. Passing the loop index
    # (`step(epoch)`) pins the scheduler to the epoch that just finished, which
    # delays every milestone by one epoch and is deprecated in later PyTorch
    # releases.
    sgd_scheduler.step()

[Test] Loss: 2.3178, Err: 9000,


KeyboardInterrupt: 