In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable, grad
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.optimizer import Optimizer, required
import random
import matplotlib.pyplot as plt
import numpy as np
import copy
from data.stanford_dogs_data import dogs
# Seed every RNG this notebook imports so that "Restart & Run All" is
# reproducible. torch drives weight init, dropout and DataLoader shuffling;
# Python's `random` and numpy are seeded as well because, depending on the
# torchvision version, the random transforms may draw from them instead.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Hyperparameters shared by the cells below.
# batch_size / n_epoch control the data loaders and the number of training passes;
# learning_rate is the step size handed to the training loop;
# input_size is the side length (in pixels) of the crop produced by the transforms.
batch_size, n_epoch = 32, 2
learning_rate, input_size = 0.01, 224

In [None]:
# Training-time augmentation: random resized crop to input_size x input_size
# followed by a random horizontal flip.
input_transforms = transforms.Compose([
    transforms.RandomResizedCrop(input_size, ratio=(1, 1.3)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Evaluation must be deterministic: with the random augmentation above, the
# test metrics would change from one evaluation pass to the next. Resize the
# shorter side to input_size and take the central square crop instead.
test_transforms = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
])

trainset = dogs(root='./data', train=True, cropped=False,
                transform=input_transforms, download=True)
testset = dogs(root='./data', train=False, cropped=False,
               transform=test_transforms, download=True)

# Only the training loader is shuffled; the test loader keeps dataset order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


In [None]:
#SVRG Optimizer inherit from torch.optim.optimizer
class SVRG(Optimizer):
    """Variance-reduced gradient optimizer driven by a step counter.

    Per parameter, two buffers are kept in ``self.state``:

    * ``'outer'`` (``u_hat``): a stored gradient, initialised from the first
      gradient seen and refreshed every time the step counter reaches ``freq``.
    * ``'inner'`` (``w_hat``): a stored gradient that is zeroed on a refresh
      step and then filled with the gradient of the step right after it.

    A regular step applies ``param -= learn_rate * (grad - w_hat + u_hat)``.
    The very first call to :meth:`step` and every refresh step (counter equal
    to ``freq``) only update the buffers and leave the parameters untouched,
    so the caller is expected to supply a large-batch gradient on those steps.

    Args:
        params: iterable of parameters or parameter-group dicts.
        learn_rate: step size; must be non-negative.
        SVRG_inner_freq: value of the step counter at which the stored
            gradients are refreshed.

    Raises:
        ValueError: if ``learn_rate`` is negative.
    """

    def __init__(self, params, learn_rate=required, SVRG_inner_freq =20):
        if learn_rate is not required and learn_rate < 0.0:
            raise ValueError("Invalid learning rate: {}".format(learn_rate))

        defaults = dict(learn_rate=learn_rate, freq=SVRG_inner_freq)
        # Counters are shared by all parameter groups and advance once per step().
        self.counter = 0    # position within the current refresh period
        self.counter2 = 0   # equals 1 on the step that captures `w_hat`
        self.flag = False   # False only until the first step() has run
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is called once, before the update.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            freq = group['freq']
            lr = group['learn_rate']
            for param in group['params']:
                if param.grad is None:
                    continue
                w_grad = param.grad.data
                param_state = self.state[param]

                if 'outer' not in param_state:
                    # First time this parameter is seen: seed u_hat with the
                    # current gradient and start w_hat at zero.
                    param_state['outer'] = w_grad.clone()
                    param_state['inner'] = torch.zeros_like(param.data)

                u_hat = param_state['outer']
                w_hat = param_state['inner']

                if self.counter == freq:
                    # Refresh step: store the current gradient and reset w_hat.
                    u_hat.copy_(w_grad)
                    w_hat.zero_()

                if self.counter2 == 1:
                    w_hat.add_(w_grad)

                #dont update parameters when computing large batch (low variance gradients)
                if self.counter != freq and self.flag:
                    # Explicit scale-then-subtract instead of the deprecated
                    # add_(Number, Tensor) overload. The bracketed expression is
                    # a fresh temporary, so scaling it in place is safe.
                    param.data.sub_((w_grad - w_hat + u_hat).mul_(lr))

        self.flag = True

        if self.counter == freq:
            self.counter = 0
            self.counter2 = 0

        self.counter += 1
        self.counter2 += 1

        return loss

In [None]:
class Net(nn.Module):
    """Small CNN (two conv/batch-norm/pool/dropout stages + three linear layers)
    producing 120 output logits, together with a hand-written SVRG training loop.

    The first linear layer is sized ``16*53*53``, which is the flattened
    feature-map size for 3x224x224 inputs (224 -> 220 -> 110 -> 106 -> 53).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.batchnorm1 = nn.BatchNorm2d(6)
        self.dropout1 = nn.Dropout(p = 0.1)

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.batchnorm2 = nn.BatchNorm2d(16)
        self.dropout2 = nn.Dropout(p = 0.1)
        self.fc1 = nn.Linear(16*53*53, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 120)

    def forward(self, x):
        """Return the logits for a batch of images ``x``."""
        x = self.dropout1(self.pool(F.relu(self.batchnorm1(self.conv1(x)))))
        x = self.dropout2(self.pool(F.relu(self.batchnorm2(self.conv2(x)))))
        # Flatten per sample. Keeping the batch dimension explicit means a
        # wrongly sized input fails loudly in fc1 instead of being silently
        # redistributed across the batch by view(-1, ...).
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def partial_grad(self, data, target, loss_function):
        """Forward one batch, back-propagate its loss and return the loss tensor.

        Gradients are accumulated into ``.grad``; callers that need a fresh
        gradient must call ``zero_grad()`` first.
        """
        outputs = self.forward(data)
        loss = loss_function(outputs, target)
        loss.backward()
        return loss

    def calculate_loss_grad(self, dataset, loss_function, n_samples):
        """Accumulate the gradient over every batch of ``dataset``.

        Args:
            dataset: iterable of ``(inputs, labels)`` batches.
            loss_function: callable ``(outputs, labels) -> loss tensor``.
            n_samples: divisor used to average the per-batch losses and the
                gradient norm (the training loop passes the number of batches).

        Returns:
            ``(mean_loss, grad_norm)`` as plain Python floats, where
            ``grad_norm`` is the L2 norm of the accumulated gradient divided by
            ``n_samples``. The accumulated gradient itself is left in ``.grad``.
        """
        total_loss = 0.0
        for inputs, labels in dataset:
            total_loss += (1./n_samples) * self.partial_grad(inputs, labels, loss_function).item()

        # Work with Python floats so the result does not depend on the device
        # of the parameters and can be stored / plotted directly.
        sq_norm = 0.0
        for para in self.parameters():
            if para.grad is None:
                # Parameter did not take part in the loss; nothing to add.
                continue
            sq_norm += para.grad.data.norm(2).item() ** 2

        return total_loss, (1./n_samples) * sq_norm ** 0.5

    def backward(self, dataset,testloader, loss_function, n_epoch, learning_rate):
        """Train this network in place with SVRG.

        At the start of every epoch the current weights are snapshotted twice:
        ``u_hat`` accumulates the gradient over the whole ``dataset`` and
        ``w_hat`` is used to evaluate the snapshot gradient on each mini-batch.
        Each mini-batch step then applies
        ``w -= learning_rate * (grad(w) - grad(w_hat) + full_grad / n_batches)``.

        Args:
            dataset: training loader yielding ``(inputs, labels)``; must support ``len()``.
            testloader: loader passed to ``eval_net`` for the periodic test loss.
            loss_function: callable ``(outputs, labels) -> loss tensor``.
            n_epoch: number of passes over ``dataset``.
            learning_rate: step size of the update.

        Returns:
            ``(total_loss_epoch, grad_norm_epoch, record, train_re, test_re)``:
            per-epoch snapshot loss and gradient norm, followed by the running
            training loss, train loss and test loss, each recorded every 50 iterations.
        """
        # Number of mini-batches, taken from the loader that is actually being
        # trained on rather than from a notebook-level global.
        n_batches = len(dataset)
        total_loss_epoch = [0 for i in range(n_epoch)]
        grad_norm_epoch = [0 for i in range(n_epoch)]
        self.train()
        record=[]
        train_re=[]
        test_re=[]
        iters=0
        for epoch in range(n_epoch):
            running_loss = 0.0

            #SVRG1 parameters start
            # NOTE(review): the snapshots are in train mode too, so dropout
            # masks differ between `self` and `w_hat` on the same batch and
            # batch-norm running statistics drift; confirm this is intended.
            w_hat = copy.deepcopy(self)
            u_hat = copy.deepcopy(self)
            u_hat.zero_grad()
            total_loss_epoch[epoch], grad_norm_epoch[epoch] = u_hat.calculate_loss_grad(dataset, loss_function, n_batches)
            #SVRG1 parameters end

            for inputs, labels in dataset:
                #SVRG1 backward start
                w_hat.zero_grad()
                w_hat.partial_grad(inputs, labels, loss_function)
                self.zero_grad()
                cur_loss = self.partial_grad(inputs, labels, loss_function)
                with torch.no_grad():
                    for param, snap_param, full_param in zip(self.parameters(), w_hat.parameters(), u_hat.parameters()):
                        if param.grad is None:
                            continue
                        param -= learning_rate * (param.grad - snap_param.grad + (1./n_batches) * full_param.grad)
                running_loss += cur_loss.item()
                #SVRG1 backward end

                iters+=1

                # Evaluate the current weights every 50 iterations.
                # NOTE: eval_net() evaluates the module-level `net`, which is
                # this instance in the notebook's training cell.
                if iters%50==0:
                    train_loss, _ = eval_net(dataset)
                    test_loss, _ = eval_net(testloader)
                    print(epoch)
                    record.append(running_loss/50)
                    train_re.append(train_loss)
                    test_re.append(test_loss)
                    running_loss=0.0

        return total_loss_epoch, grad_norm_epoch,record,train_re,test_re
def eval_net(dataloader, model=None):
    """Compute the mean cross-entropy loss and the accuracy over ``dataloader``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches.
        model: network to evaluate; defaults to the module-level ``net``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged over the
        number of samples seen. Both are ``nan`` when the loader is empty.
    """
    if model is None:
        model = net  # module-level network created in the training cell

    # Sum (not mean) per batch so the final division by the sample count is
    # exact even when the last batch is smaller than the others.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct = 0
    total = 0
    total_loss = 0.0

    was_training = model.training
    model.eval()
    try:
        # No autograd graph is needed for evaluation.
        with torch.no_grad():
            for images, labels in dataloader:
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                # .item() keeps `correct` a Python int, so the accuracy below
                # is a true float division and not a tensor operation.
                correct += (predicted == labels).sum().item()
                total_loss += criterion(outputs, labels).item()
    finally:
        # Put the model back into the mode it was in when we were called.
        model.train(was_training)

    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total

In [None]:
# Build the network and the loss, then run the SVRG1 training loop.
net = Net()
criterion = nn.CrossEntropyLoss()

# Despite its name, this is the number of mini-batches in the training loader;
# the training loop uses it to average the accumulated gradient.
n_samples = len(trainloader)
print(n_samples)

# Curves for the SVRG1 run (the SGD / Adam / SVRG2 variants are disabled in
# Net.backward, so their record lists are not created here).
record1 = []
Train_record1 = []
Test_record1 = []

(total_loss_epoch,
 grad_norm_epoch,
 record1,
 Train_record1,
 Test_record1) = net.backward(trainloader, testloader, criterion, n_epoch, learning_rate)

print('Finished Training')

In [None]:
# Plot the three SVRG1 curves recorded every 50 iterations: running training
# loss, loss on the training loader and loss on the test loader.
svrg1_curves = (
    (record1, 'Train'),
    (Train_record1, 'Train_valid'),
    (Test_record1, 'Test_valid'),
)
for series, curve_label in svrg1_curves:
    plt.plot(series, label=curve_label)
plt.title('SVRG1')
plt.legend()
plt.show()

In [None]:
# The SGD curves only exist if the SGD variant of the training loop was run;
# in the notebook as written that code is commented out, so plotting
# unconditionally raises NameError on a fresh "Restart & Run All".
sgd_missing = [name for name in ('record', 'Train_record', 'Test_record')
               if name not in globals()]

if sgd_missing:
    print('Skipping SGD plot; run the SGD experiment first (undefined: {})'.format(
        ', '.join(sgd_missing)))
else:
    plt.plot(record,label='Train')
    plt.plot(Train_record,label='Train_valid')
    plt.plot(Test_record,label='Test_valid')
    plt.title('SGD')
    plt.legend()
    plt.show()