In [6]:
import torch as torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision as torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [None]:
# ToTensor maps pixels to [0, 1]; Normalize then applies (x - 0.5) / 0.5 per
# RGB channel, so the network sees values in [-1, 1].
t = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: download=True fetches CIFAR-10 into ./data if it is missing;
# shuffle=True so mini-batches differ between epochs.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=t,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test split: same preprocessing, fixed order so evaluation is repeatable.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=t,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable names for the ten CIFAR-10 labels, in label-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


In [8]:
class Net(nn.Module):
    """VGG-style convolutional classifier for small square RGB images.

    Six 3x3 convolutions, all with ``padding=1`` so each preserves height and
    width, are interleaved with four 2x2 max-pools. Three fully connected
    layers follow. Each max-pool halves height and width (rounding down), so an
    ``image_size`` x ``image_size`` input reaches the classifier as a
    ``512 x (image_size // 16) x (image_size // 16)`` feature map
    (``512 x 2 x 2`` for 32x32 CIFAR-10 images).

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
        image_size: Side length of the square input images the network is
            built for. Defaults to 32 (CIFAR-10). Must be at least 16 so the
            feature map is non-empty after the four pooling stages.

    Raises:
        ValueError: If ``image_size`` is smaller than 16.
    """

    def __init__(self, num_classes=10, image_size=32):
        # nn.Module.__init__ must actually run: a bare ``super()`` only builds
        # a proxy, and assigning a submodule to an uninitialised Module raises
        # AttributeError.
        super().__init__()

        if image_size < 16:
            raise ValueError(
                f"image_size must be >= 16 to survive four 2x2 poolings, got {image_size}")

        # Stage 1: 3 -> 64 channels, then halve the spatial size.
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Every conv is padded: without padding a 32x32 input shrinks to 1x1
        # before conv5, where a 3x3 kernel no longer fits.
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv6 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)

        # Side length after four floor-halvings; derived rather than hard-coded
        # so fc1 always matches the flattened feature map.
        final_side = image_size // 16
        self.fc1 = nn.Linear(final_side * final_side * 512, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Compute raw class scores for a batch of images.

        Args:
            x: Float tensor of shape ``(N, 3, image_size, image_size)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding unnormalised logits
            (no softmax is applied).
        """
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool3(x)

        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.pool4(x)

        # Flatten each sample separately; the batch dimension must be kept.
        x = x.view(x.size(0), self.flat_size(x))

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def flat_size(self, x):
        """Return the number of features per sample in ``x``.

        This is the product of every dimension except the first (batch)
        dimension, i.e. the row length when ``x`` is flattened to ``(N, -1)``.
        """
        num_features = 1
        for s in x.size()[1:]:  # skip the batch dimension
            num_features *= s
        return num_features