In [1]:
import torch
from torchvision import datasets
from torchvision import transforms
import numpy as np


In [2]:
# Folder that holds the CIFAR-10 archive and its extracted files.
data_path = './data'

# Training and validation splits with no transform; both calls request a download.
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(root=data_path, train=False, download=True)

# Training split again (no download this time), with ToTensor applied to every image on access.
tensor_cifar10 = datasets.CIFAR10(
    root=data_path,
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
# Per-channel mean / std handed to Normalize (presumably measured on the CIFAR-10
# training split -- TODO confirm). Defined once so the training and validation
# pipelines cannot drift apart.
CIFAR10_MEAN = (0.4915, 0.4823, 0.4468)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Shared preprocessing: convert to a tensor first, then standardise each channel.
normalize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Both splits read the files already present under data_path (download=False)
# and go through the exact same transform object.
transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False, transform=normalize_transform)
transformed_val_cifar10 = datasets.CIFAR10(
    data_path, train=False, download=False, transform=normalize_transform)

In [4]:
# Original dataset labels we keep (0 and 2), remapped to the contiguous ids 0 and 1.
# class_names is indexed by the remapped id.
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']

# Materialise the two-class subsets as plain lists of (image, new_label) pairs.
# Membership is checked against label_map's keys, i.e. labels 0 and 2.
cifar2 = [
    (image, label_map[target])
    for image, target in transformed_cifar10
    if target in label_map
]
cifar2_val = [
    (image, label_map[target])
    for image, target in transformed_val_cifar10
    if target in label_map
]

In [5]:
import torch.nn as nn

# Small convolutional classifier built with nn.Sequential.
# The first Linear layer expects 8 channels * 8 * 8 features, which corresponds to
# 3-channel 32x32 inputs after two 2x2 max-pools (32 -> 16 -> 8).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.MaxPool2d(2),  # fixed: the class is MaxPool2d; nn.Maxpool2d raises AttributeError
    # Collapse (N, 8, 8, 8) to (N, 512); without this the Linear layer below
    # would be applied to the last spatial dimension and fail on shape.
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 32),
    nn.Tanh(),
    nn.Linear(32, 2)
)

In [18]:
import torch.nn.functional as F

class Net(nn.Module):
    """Two-convolution classifier producing two output scores per image.

    Layers: conv(3->32) -> tanh -> 2x2 max-pool -> conv(32->16) -> tanh ->
    2x2 max-pool -> flatten -> linear(1024->32) -> tanh -> linear(32->2).
    ``fc1`` takes 8 * 8 * 16 features, so inputs must be 3-channel 32x32
    images (two poolings halve 32 down to 8).
    """

    def __init__(self):
        super().__init__()
        # Submodule names and creation order are kept stable so previously saved
        # state_dicts keep loading.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 16, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        """Return raw class scores of shape (batch, 2) for a batch of images.

        Raises:
            RuntimeError: if the flattened feature size does not match ``fc1``
                (i.e. the input is not 32x32).
        """
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        # Flatten per sample, keeping the batch dimension fixed. The previous
        # view(-1, 8 * 8 * 16) let the batch dimension absorb a wrong spatial size,
        # so a mis-sized input silently produced extra rows instead of an error.
        out = out.view(out.size(0), -1)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)

        return out

In [19]:
import datetime
import torch.optim as optim

# Mini-batch iterators over the two-class subsets, 64 samples per batch.
# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=cifar2,
    batch_size=64,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=cifar2_val,
    batch_size=64,
    shuffle=False,
)

In [20]:
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    """Train ``model`` for ``n_epochs`` passes over ``train_loader``.

    Each batch goes through a forward pass, ``loss_fn``, a backward pass and one
    optimizer step. On epoch 1 and on every 10th epoch a line with the current
    timestamp, the epoch number and the epoch's loss is printed. The reported
    loss is the sum of the per-batch loss values, not their mean.

    Args:
        n_epochs: number of passes over ``train_loader``.
        optimizer: optimizer holding the parameters of ``model``.
        model: callable mapping a batch of inputs to outputs.
        loss_fn: callable taking ``(outputs, labels)`` and returning a scalar loss.
        train_loader: iterable yielding ``(imgs, labels)`` batches.

    Returns:
        None.
    """
    for epoch in range(1, n_epochs + 1):
        epoch_loss = 0.0

        for batch_imgs, batch_labels in train_loader:
            batch_loss = loss_fn(model(batch_imgs), batch_labels)

            # Clear old gradients, backpropagate, then update the parameters.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item()

        # Report only on the first epoch and on multiples of ten.
        is_report_epoch = epoch == 1 or epoch % 10 == 0
        if not is_report_epoch:
            continue
        print('{} Epoch {}, Training loss {}'.format(datetime.datetime.now(), epoch, float(epoch_loss)))

In [21]:
# Seed torch's global RNG before creating the model so that weight initialisation
# (and any shuffling that draws from the global RNG) repeats on a re-run of this cell.
torch.manual_seed(0)

# Fresh network, plain SGD with learning rate 0.01, cross-entropy on the raw scores.
model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs=100,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
)

2020-04-16 16:51:00.028561 Epoch 1, Training loss 87.72056120634079
2020-04-16 16:51:13.958849 Epoch 10, Training loss 49.86131227016449
2020-04-16 16:51:29.228696 Epoch 20, Training loss 43.5352101624012
2020-04-16 16:51:44.611489 Epoch 30, Training loss 38.35382113605738
2020-04-16 16:52:00.516558 Epoch 40, Training loss 33.384215250611305
2020-04-16 16:52:15.900498 Epoch 50, Training loss 29.6724616214633
2020-04-16 16:52:31.415779 Epoch 60, Training loss 25.647116858512163
2020-04-16 16:52:46.735010 Epoch 70, Training loss 22.863835744559765
2020-04-16 16:53:02.101228 Epoch 80, Training loss 19.48645779490471
2020-04-16 16:53:17.601919 Epoch 90, Training loss 16.401355229318142
2020-04-16 16:53:32.788015 Epoch 100, Training loss 13.673299312591553


In [22]:
# Fraction of validation samples whose highest-scoring class equals the label.
total = 0
correct = 0

# Gradients are not needed for evaluation, so tracking is switched off.
with torch.no_grad():
    for imgs, labels in val_loader:
        scores = model(imgs)
        # torch.max over dim=1 returns (values, indices); the indices are the predictions.
        _, predicted = torch.max(scores, dim=1)
        correct += int((predicted == labels).sum())
        total += labels.shape[0]

print("Accuracy: %f" % (correct / total))

Accuracy: 0.872500


In [16]:
torch.save(model.state_dict(), data_path + 'bird_vs_airplanes.pt')

In [17]:
loaded_model = Net()
loaded_model.load_state_dict(torch.load(data_path + 'bird_vs_airplanes.pt'))

<All keys matched successfully>