In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
seed = 1
batch_size = 64
test_batch_size = 64
momentum = 0.9
epochs = 50
log_interval = 100
lr = 0.01
save_model = False
num_classes = 10
criterion = nn.CrossEntropyLoss()
torch.backends.cudnn.benchmark=True


In [3]:
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5, 1)
        self.conv2 = nn.Conv2d(64, 256, 5, 1)
        self.fc1= nn.Linear(6400, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout1 = nn.Dropout2d(0.2)
        self.dropout2 = nn.Dropout2d(0.2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.dropout1(x)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return F.relu(self.fc2(x))

In [4]:
def train(model, device, train_loader, optimizer, log_interval, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [5]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).sum().item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [6]:
use_cuda =  torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

torch.manual_seed(seed)


train_dataset=datasets.CIFAR10('./data_c', train=True, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
#                       transforms.Normalize((0.1307,), (0.3081,))
                       
                   ]))

test_dataset=datasets.CIFAR10('./data_c', train=False, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
#                       transforms.Normalize((0.1307,), (0.3081,))
    ]))

image, label = train_dataset[0]
print(image.size())
print(label)

Files already downloaded and verified
Files already downloaded and verified
torch.Size([3, 32, 32])
6


In [7]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, shuffle=True, num_workers=2)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=test_batch_size, shuffle=False, num_workers=2)


In [8]:
for images, labels in train_loader:
    print(images.size())
    print(images[0].size())
    print(labels.size())
    
    break

torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([64])


In [9]:
model = CNNNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=5e-4)

def run():
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, log_interval, epoch)
        test(model, device, test_loader)

    if (save_model):
        torch.save(model.state_dict(),"mlp.pt")

In [10]:
%%time

run()


Test set: Average loss: 0.0227, Accuracy: 4816/10000 (48%)


Test set: Average loss: 0.0190, Accuracy: 5720/10000 (57%)


Test set: Average loss: 0.0177, Accuracy: 6004/10000 (60%)


Test set: Average loss: 0.0159, Accuracy: 6444/10000 (64%)


Test set: Average loss: 0.0154, Accuracy: 6558/10000 (66%)


Test set: Average loss: 0.0137, Accuracy: 6946/10000 (69%)


Test set: Average loss: 0.0140, Accuracy: 6975/10000 (70%)


Test set: Average loss: 0.0134, Accuracy: 7099/10000 (71%)


Test set: Average loss: 0.0130, Accuracy: 7208/10000 (72%)


Test set: Average loss: 0.0131, Accuracy: 7249/10000 (72%)


Test set: Average loss: 0.0129, Accuracy: 7226/10000 (72%)


Test set: Average loss: 0.0129, Accuracy: 7284/10000 (73%)


Test set: Average loss: 0.0139, Accuracy: 7228/10000 (72%)


Test set: Average loss: 0.0133, Accuracy: 7335/10000 (73%)


Test set: Average loss: 0.0136, Accuracy: 7350/10000 (74%)


Test set: Average loss: 0.0149, Accuracy: 7257/10000 (73%)


Test set: Average loss:


Test set: Average loss: 0.0149, Accuracy: 7297/10000 (73%)


Test set: Average loss: 0.0152, Accuracy: 7283/10000 (73%)


Test set: Average loss: 0.0153, Accuracy: 7317/10000 (73%)


Test set: Average loss: 0.0150, Accuracy: 7358/10000 (74%)


Test set: Average loss: 0.0149, Accuracy: 7377/10000 (74%)


Test set: Average loss: 0.0156, Accuracy: 7381/10000 (74%)


Test set: Average loss: 0.0156, Accuracy: 7352/10000 (74%)


Test set: Average loss: 0.0156, Accuracy: 7407/10000 (74%)


Test set: Average loss: 0.0158, Accuracy: 7396/10000 (74%)


Test set: Average loss: 0.0158, Accuracy: 7343/10000 (73%)


Test set: Average loss: 0.0160, Accuracy: 7303/10000 (73%)


Test set: Average loss: 0.0163, Accuracy: 7283/10000 (73%)


Test set: Average loss: 0.0164, Accuracy: 7309/10000 (73%)


Test set: Average loss: 0.0165, Accuracy: 7355/10000 (74%)


Test set: Average loss: 0.0156, Accuracy: 7443/10000 (74%)


Test set: Average loss: 0.0166, Accuracy: 7351/10000 (74%)


Test set: Average loss:


Test set: Average loss: 0.0158, Accuracy: 7390/10000 (74%)


Test set: Average loss: 0.0159, Accuracy: 7380/10000 (74%)


Test set: Average loss: 0.0158, Accuracy: 7385/10000 (74%)


Test set: Average loss: 0.0158, Accuracy: 7399/10000 (74%)


Test set: Average loss: 0.0159, Accuracy: 7376/10000 (74%)


Test set: Average loss: 0.0165, Accuracy: 7303/10000 (73%)


Test set: Average loss: 0.0162, Accuracy: 7374/10000 (74%)


Test set: Average loss: 0.0163, Accuracy: 7314/10000 (73%)


Test set: Average loss: 0.0154, Accuracy: 7352/10000 (74%)


Test set: Average loss: 0.0161, Accuracy: 7319/10000 (73%)


Test set: Average loss: 0.0158, Accuracy: 7369/10000 (74%)


Test set: Average loss: 0.0157, Accuracy: 7322/10000 (73%)


Test set: Average loss: 0.0164, Accuracy: 7279/10000 (73%)


Test set: Average loss: 0.0157, Accuracy: 7418/10000 (74%)


Test set: Average loss: 0.0162, Accuracy: 7370/10000 (74%)

CPU times: user 4min 28s, sys: 1min 18s, total: 5min 47s
Wall time: 6min 5s
