In [1]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms

In [2]:
data_path = 'data/'
batch_size = 1

mean_tuple, std_tuple = (0.4914, 0.48216, 0.44653), (0.24703, 0.24349, 0.26159)

transformer = transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean_tuple, std_tuple)
                                     ])

training_set = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transformer)
test_set = datasets.CIFAR10(root=data_path, train=False, transform=transformer)

train_loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

Files already downloaded and verified


In [3]:
class BridgeNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.vgg19_layernames = list(models.vgg19(pretrained=True).children())[:-1]
        self.feature = nn.Sequential(*self.vgg19_layernames)
    
        self.cifar10_layer_list = nn.ModuleList([
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1)
        ])
                
    def forward(self, x):    
        x = self.feature(x)
        x = x.view(-1, 512)
        
        for layer in self.cifar10_layer_list:
            x = layer(x)
        
        return x

In [11]:
from sklearn.metrics import accuracy_score

model = BridgeNet()
loss_func = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

def train(model, train_loader, test_loader, loss_func, optimizer, num_epochs, log_interval=20):
    loss_list, acc_list = [], []
    loss_test_list, acc_test_list = [], []
    
    for epoch in range(num_epochs):
        for i, (data, target) in enumerate(train_loader):
            if torch.cuda.is_available():
                model = model.cuda()
                data, target = data.cuda(), target.cuda()
                
            data, target = Variable(data, requires_grad=True), Variable(target)
#             target = one_hot_embedding(target, 10)
            
            optimizer.zero_grad()
            output = model(data)
            loss_sum = loss_func(output, target).mean()
            
            loss_sum.backward()
            optimizer.step()
            
            loss_list.append(loss_sum.detach())
            
            if i % log_interval == 0:
                print("Epoch: {0}, Iter: {1}, Train loss: {2:.4f}".format(epoch, i, loss_sum))
            
    model.eval()
    
    test_loss = 0
    correct = 0
    num_rows = 0
                  
    with torch.no_grad():
        for i, (data_test, target_test) in enumerate(test_loader):
            if torch.cuda.is_available():
                data_test, target_test = data_test.cuda(), target_test.cuda()

            data_test, target_test = Variable(data_test), Variable(target_test)
            
            output = model(data_test)
            loss_test = loss_func(output, target_test).mean()
            
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            
        test_loss /= num_rows

        print("=====================================================================")
        print("Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
            test_loss, correct, num_rows, (100. * correct.item()) / num_rows))
        print("=====================================================================")

In [12]:
train(model, train_loader, test_loader, loss_func, optimizer, num_epochs=10)

Epoch: 0, Iter: 0, Train loss: 2.3476
Epoch: 0, Iter: 20, Train loss: 2.9893
Epoch: 0, Iter: 40, Train loss: 2.2469
Epoch: 0, Iter: 60, Train loss: 2.3012
Epoch: 0, Iter: 80, Train loss: 2.3290
Epoch: 0, Iter: 100, Train loss: 2.2997
Epoch: 0, Iter: 120, Train loss: 2.2877
Epoch: 0, Iter: 140, Train loss: 2.0350
Epoch: 0, Iter: 160, Train loss: 2.2418
Epoch: 0, Iter: 180, Train loss: 2.3570
Epoch: 0, Iter: 200, Train loss: 2.3390
Epoch: 0, Iter: 220, Train loss: 2.2546
Epoch: 0, Iter: 240, Train loss: 2.3376
Epoch: 0, Iter: 260, Train loss: 2.3337
Epoch: 0, Iter: 280, Train loss: 2.6288
Epoch: 0, Iter: 300, Train loss: 2.2586
Epoch: 0, Iter: 320, Train loss: 2.3534
Epoch: 0, Iter: 340, Train loss: 2.2132
Epoch: 0, Iter: 360, Train loss: 2.2299
Epoch: 0, Iter: 380, Train loss: 2.4397
Epoch: 0, Iter: 400, Train loss: 2.2697
Epoch: 0, Iter: 420, Train loss: 2.3274
Epoch: 0, Iter: 440, Train loss: 2.3153
Epoch: 0, Iter: 460, Train loss: 2.2340
Epoch: 0, Iter: 480, Train loss: 2.3025
Epoch:

KeyboardInterrupt: 

In [None]:
torch.__version__