In [None]:
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import torch.optim as optim
import torch.nn.functional as F
import os

# Record the library versions this notebook ran against (one per line).
print(*(module.__version__ for module in (torch, torchvision)), sep='\n')

In [None]:
# Download CIFAR-10 only when the extracted batches are not already on disk, so the
# notebook still runs top-to-bottom on a fresh checkout without re-downloading later.
DATA_ROOT = './CIFAR10'
DOWNLOAD_DATASET = not os.path.isdir(os.path.join(DATA_ROOT, 'cifar-10-batches-py'))

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps every
# channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    transform=transform,
    download=DOWNLOAD_DATASET
)
print('train set:', len(train_data))

test_data = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=transform,
    download=DOWNLOAD_DATASET
)
print('test set:', len(test_data))

In [None]:
# Class names, then the name -> integer label mapping used as training targets.
print(train_data.classes, train_data.class_to_idx, sep='\n')

In [None]:
print(f"{train_data.data[16].shape}")  # raw stored array (not passed through `transform`); presumably (H, W, C) — confirm

In [None]:
# Show and save a few raw test images with their class label.
# Save *before* plt.show(): in Jupyter's inline backend (and after a blocking GUI
# window is closed) show() releases the figure, so a later savefig() writes a blank image.
for idx in (2021, 5230, 16):
    fig, ax = plt.subplots()
    ax.imshow(test_data.data[idx])
    ax.set_title('test[{}]: {}'.format(idx, test_data.classes[test_data.targets[idx]]))
    fig.savefig('test_{}.png'.format(idx))
    plt.show()
    plt.close(fig)

In [None]:
class ConvNet(nn.Module):
    """Three-stage CNN classifier for 3x32x32 images.

    Each stage is Conv3x3 (padding=1) -> ReLU -> 2x2 max-pool, so the spatial size
    goes 32 -> 16 -> 8 -> 4 while channels go 3 -> 16 -> 32 -> 64. The flattened
    64*4*4 features pass through Linear(1024, 32) -> ReLU -> Linear(32, 10).

    forward() returns log-probabilities (log_softmax over dim 1), shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and Sequential wrappers are what define the state_dict
        # keys (conv1.0.weight, ..., fc1.0.bias, fc2.weight), so they are fixed.
        # Construction order is also fixed so seeded initialisation is unchanged.
        self.conv1 = self._stage(3, 16)
        self.conv2 = self._stage(16, 32)
        self.conv3 = self._stage(32, 64)
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 4 * 4, 32),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(32, 10)

    @staticmethod
    def _stage(in_channels, out_channels):
        """Conv3x3 with same-padding -> ReLU -> 2x2 max-pool (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def forward(self, x):
        features = self.conv3(self.conv2(self.conv1(x)))
        flat = features.view(features.size(0), -1)
        logits = self.fc2(self.fc1(flat))
        return F.log_softmax(logits, dim=1)

# Seed torch's global RNG before creating the model so weight initialisation (and
# later draws from the same global RNG, e.g. shuffling) repeat across re-runs.
SEED = 0
torch.manual_seed(SEED)

model = ConvNet()
print(model)

In [None]:
def _train_one_epoch(model, device, train_loader, optimizer, loss_func, epoch, log_interval):
    """Run one optimisation pass over ``train_loader``; return ``(mean_loss, accuracy)``.

    Both averages are taken over the samples actually yielded this epoch. They stay
    correct when the loader uses a ``sampler`` or ``drop_last``, where
    ``len(train_loader.dataset)`` would be the wrong denominator.
    Returns ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        seen += batch_size
        # assumes loss_func averages over the batch (reduction='mean'); re-weight it
        # so the epoch mean is per-sample even when the last batch is smaller
        total_loss += loss.item() * batch_size
        # the accuracy count needs no autograd graph
        correct += (output.detach().argmax(dim=1) == target).sum().item()

        if log_interval and (batch_idx + 1) % log_interval == 0:
            # progress is reported *after* this batch, hence batch_idx + 1
            print("Train Epoch: {} [{:5d}/{:5d} ({:.4f}%)]\tLoss: {:.6f}".format(
                epoch,
                seen,
                len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader),
                loss.item()
            ))

    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen


def train(model, device, train_loader, optimizer, loss_func, epochs, summary,
          checkpoint_dir='./model', log_interval=50):
    """Train ``model`` for ``epochs`` epochs, log per-epoch metrics, save a final checkpoint.

    Args:
        model: module to train; moved to ``device`` in place.
        device: torch device for the model and every batch.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_func: criterion called as ``loss_func(output, target)``.
        epochs: number of passes over ``train_loader``. ``0`` trains nothing and
            writes no checkpoint.
        summary: writer with ``add_scalars(tag, values, step)`` (e.g. a
            tensorboardX ``SummaryWriter``). It receives ``loss/ave_loss`` and
            ``acc/acc`` once per epoch.
        checkpoint_dir: directory for the checkpoint; created if missing.
        log_interval: print a progress line every this many batches (0/None = never).

    After the last epoch, ``{'epoch': epochs - 1, 'state_dict': model.state_dict()}``
    is saved to ``<checkpoint_dir>/cifar10_epoch{epochs - 1}.pth``.
    """
    model.to(device)

    for epoch in range(epochs):
        train_loss, acc = _train_one_epoch(
            model, device, train_loader, optimizer, loss_func, epoch, log_interval)
        summary.add_scalars("loss", {'ave_loss': train_loss}, epoch)
        summary.add_scalars("acc", {'acc': acc}, epoch)

    if epochs <= 0:
        # no epoch ran, so there is nothing to checkpoint
        return

    last_epoch = epochs - 1
    # create the directory so torch.save does not fail on a fresh checkout
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
        'epoch': last_epoch,
        'state_dict': model.state_dict()
    }, os.path.join(checkpoint_dir, 'cifar10_epoch{}.pth'.format(last_epoch)))


In [None]:
def test(model, device, test_loader):
    model.to(device)
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)

            output = model(data) 
            _, preds = torch.max(output.data, 1)

            correct += (preds == target).sum().item()
            total += len(target)

    print('>> Test accuracy: {:.4f}'.format(correct / total))

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 50
BATCH_SIZE = 128
LR = 0.001

train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # pinned host memory only speeds up host->GPU copies; without CUDA it is useless
    # and PyTorch may warn about it
    pin_memory=DEVICE.type == "cuda")

optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

# ConvNet.forward already returns log_softmax outputs, so NLLLoss is the matching
# criterion. CrossEntropyLoss would apply log_softmax a second time; the value is
# the same, but the work is redundant.
loss_func = nn.NLLLoss()
summary = SummaryWriter("./logs")

# Close the writer even if training is interrupted so buffered events reach ./logs.
try:
    train(model, DEVICE, train_loader, optimizer, loss_func, EPOCHS, summary)
finally:
    summary.close()

In [None]:
# Evaluate only the 10 test images at indices id_num*10 ... id_num*10 + 9, not the
# full test set; the printed accuracy is over those 10 images.
id_num = 16
indices = range(id_num * 10, id_num * 10 + 10)
assert indices[-1] < len(test_data), "id_num selects indices past the end of the test set"

test_loader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    sampler=indices,
    # pinned memory only helps host->GPU copies
    pin_memory=DEVICE.type == "cuda")

test(model=model, device=DEVICE, test_loader=test_loader)