# CIFAR10

## Packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt



## Dataset

In [None]:
# CIFAR-10 train/test splits; ToTensor converts each image to a float tensor in [0, 1].
# `datasets` is imported directly from torchvision, so use it rather than the
# never-imported `torchvision` module name.
trainset = datasets.CIFAR10(root='data/cifar10_train', train=True, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

testset = datasets.CIFAR10(root='data/cifar10_test', train=False, download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)

# A notebook cell only displays its final expression, so show both sizes together.
len(trainset), len(testset)

## Sample

In [None]:
# Draw one training sample uniformly at random. randint's upper bound is exclusive,
# so passing len(trainset) covers every valid index without hard-coding a dataset size.
img, label = trainset[np.random.randint(len(trainset))]
print(label)

In [None]:
# Dump the raw tensor values of the sampled image.
print(img)

In [None]:
# Display the sampled image's tensor dimensions.
img.shape

In [None]:
# ToTensor yields channel-first (C, H, W) images, while imshow expects (H, W, C)
# for colour data; a 28x28 reshape cannot hold a CIFAR-10 image at all.
plt.imshow(img.permute(1, 2, 0))

In [None]:

# Flatten the single image into one row, giving a batch of size 1 for the model.
img_train = img.view(1, -1)
img_train.shape

## Model

In [None]:

# Two-layer fully connected classifier over flattened CIFAR-10 images.
# ToTensor produces 3 x 32 x 32 images, so the input layer takes 3 * 32 * 32 features
# (784 was the MNIST 28 x 28 size and would reject CIFAR inputs).
model = nn.Sequential(
        nn.Linear(3 * 32 * 32, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
        # No ReLU on the logits: clamping them at 0 would give every class with a
        # negative score the same probability after the softmax.
        nn.Softmax(dim=1))



In [None]:
# Forward the single flattened sample; the result holds one row of class scores.
predict = model(img_train)
print(predict)
predict.shape

## Loss

### one-hot encoding

In [None]:
# One-hot target with a leading batch dimension so it lines up with predict's (1, 10) shape.
label_one_hot = torch.zeros(1, 10)
label_one_hot[0, label] = 1.0

label, label_one_hot, label_one_hot.size()

### Mean square loss

In [None]:
# Mean squared error between the softmax output and the one-hot target.
loss = nn.MSELoss(reduction="mean")

loss(predict, label_one_hot)

## Train

### Data loader

In [None]:
# Mini-batch loader for the training loops below. The dataset variable is `trainset`;
# `traindata` was never defined in this notebook.
batch_sz = 600
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_sz, shuffle=True)

### Learning rate

In [None]:
# Step size for plain SGD.
learning_rate = 5e-1

### Optimizer

In [None]:
# Plain stochastic gradient descent over every model parameter.
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

### Batch training

In [None]:
# Train for n_epochs, printing the summed per-batch loss after each epoch.
n_epochs = 100
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for img, label in train_loader:
        # The final batch is smaller than batch_sz whenever the dataset size is not a
        # multiple of it, so size the tensors from the batch actually received.
        n = label.size(0)
        label_one_hot = torch.zeros(n, 10).scatter_(1, label.view(n, 1), 1.0)
        predict = model(img.view(n, -1))
        curr_loss = loss(predict, label_one_hot)

        optimizer.zero_grad()
        curr_loss.backward()
        optimizer.step()

        # .item() yields a plain float; summing the tensor itself would chain every
        # batch's autograd graph into epoch_loss for the whole epoch.
        epoch_loss += curr_loss.item()
    print("Epoch: %d, Loss: %f" % (epoch, epoch_loss))

## Test accuracy

### Test set

In [None]:
# Classify one random CIFAR-10 test image and show it. The model was trained on
# CIFAR-10, so it must be evaluated on CIFAR-10 rather than MNIST.
testdata = datasets.CIFAR10('data/cifar10_test', train=False, download=True, transform=transforms.ToTensor())  # train=False selects the test split

test_loader = torch.utils.data.DataLoader(testdata, batch_size=1, shuffle=True)

img, label = next(iter(test_loader))

# Inference only: no autograd graph is needed.
with torch.no_grad():
    predict = model(img.flatten(start_dim=1))

_, predicted_label = torch.max(predict, dim=1)

print("predicted: %d, actual: %d" % (predicted_label.item(), label.item()))

# img is (1, C, H, W); drop the batch dimension and move channels last for imshow.
plt.imshow(img[0].permute(1, 2, 0))

# Full test-set accuracy (not enabled; uses the test_loader defined above):
# correct = 0
# total = 0
# with torch.no_grad():
#     for imgs, labels in test_loader:
#         batch_size = imgs.shape[0]
#         outputs = model(imgs.view(batch_size, -1))
#         _, predicted = torch.max(outputs, dim=1)
#         total += labels.shape[0]
#         correct += int((predicted == labels).sum())
# print("Accuracy: %f" % (correct / total))
# Accuracy: 0.794000  (presumably from the earlier MNIST version of this notebook; re-measure for CIFAR-10)

## Visualization

In [None]:
import shutil

# Remove logs from previous runs. Unlike the `!rm -rf` shell escape, this is plain
# Python: it works outside IPython and on Windows, and it is a no-op if `runs` is absent.
shutil.rmtree('runs', ignore_errors=True)
writer = SummaryWriter('runs/mnist')

### Log the training loss

In [None]:
# Continue training for n_epochs, logging each epoch's summed loss to TensorBoard.
n_epochs = 10
learning_rate = 0.1
# The optimizer stored its lr when it was constructed, so rebinding the variable
# alone would change nothing; push the new rate into every parameter group.
for param_group in optimizer.param_groups:
    param_group['lr'] = learning_rate

for epoch in range(n_epochs):
    epoch_loss = 0.0
    for img, label in train_loader:
        # Size from the actual batch: the last one can be smaller than batch_sz.
        n = label.size(0)
        label_one_hot = torch.zeros(n, 10).scatter_(1, label.view(n, 1), 1.0)
        predict = model(img.view(n, -1))
        curr_loss = loss(predict, label_one_hot)

        optimizer.zero_grad()
        curr_loss.backward()
        optimizer.step()

        # Accumulate a plain float so no autograd history is kept across batches.
        epoch_loss += curr_loss.item()

    writer.add_scalar("Loss/train", epoch_loss, epoch)
    print("Epoch: %d, Loss: %f" % (epoch, epoch_loss))

### Log the model graph

In [None]:
# Trace the model once with a real batch so TensorBoard can render its graph.
# Flattening from dim 1 keeps whatever batch size the loader returned instead of
# assuming it equals batch_sz.
img, _ = next(iter(train_loader))
writer.add_graph(model, img.flatten(start_dim=1))

In [None]:
# Write any pending events to disk, then close the writer.
writer.flush()
writer.close()

In [None]:
!tensorboard --logdir=runs/mnist