## Packages

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim

## Dataset

In [None]:
# Load the MNIST training split from local disk, converting each image with
# transforms.ToTensor(). Set download=True on a first run when the files are
# not yet present under `root`.
traindata = datasets.MNIST(
    root='data/mnist_train',
    train=True,
    download=False,
    transform=transforms.ToTensor(),
)

len(traindata)

## Sample

In [None]:
# Draw one random training example. np.random.randint's upper bound is
# exclusive, so passing len(traindata) makes every index, including the last
# one, reachable (the old `60000-1` bound silently excluded index 59999).
img, label = traindata[np.random.randint(len(traindata))]
print(label)

In [None]:
# Dump the raw pixel tensor of the sampled image.
print(img)

In [None]:
# Shape of a single sample tensor (shown as the cell's output).
img.size()

In [None]:
# Render the sampled digit: lay its 784 pixel values out as a 28x28 grid.
digit = img.reshape(28, 28)
plt.imshow(digit, cmap='gray')

In [None]:

# Flatten the image into one row and give it a leading batch axis of size 1,
# producing a (1, 784) input for the fully connected model.
img_train = img.reshape(1, -1)
img_train.shape


## Model

In [None]:

# Two-layer MLP classifier: 784 flattened pixels -> 100 hidden units (ReLU)
# -> 10 class scores, normalized into per-row probabilities by Softmax.
# There is deliberately no ReLU after the last Linear: clamping the scores at
# 0 before Softmax would map every negative score to the same value and zero
# out their gradients, so those classes could never be told apart.
model = nn.Sequential(
    nn.Linear(784, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
    nn.Softmax(dim=1),
)



In [None]:
# Forward the single flattened sample through the (untrained) model and show
# the resulting row of 10 Softmax outputs together with its shape.
res = model(img_train)
print(res)
res.shape

## Loss
### One-hot encoding

In [None]:
# Build the one-hot target for the sampled label as a (1, 10) float tensor,
# matching the shape of the model output `res`.
one_hot = nn.functional.one_hot(torch.tensor([label]), num_classes=10).float()

label, one_hot, one_hot.size()


In [None]:
# Mean squared error between predicted probabilities and one-hot targets,
# averaged over all elements. `reduction='mean'` is the current spelling of
# the deprecated `reduce=True, size_average=True` pair (same result, no
# deprecation warning).
loss = torch.nn.MSELoss(reduction='mean')

loss(res, one_hot)

## Train

### Optimizer

In [None]:
# Plain stochastic gradient descent over every trainable model parameter.
learning_rate = 1.0
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:

# Iterate the training set in shuffled mini-batches of `batch_sz` samples.
batch_sz = 600
train_loader = torch.utils.data.DataLoader(
    dataset=traindata,
    batch_size=batch_sz,
    shuffle=True,
)


In [None]:
# Train for `n_epochs` full passes over `train_loader`, printing the sum of
# the per-batch losses after every epoch.
n_epochs = 100
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for img, label in train_loader:
        # Derive shapes from the batch itself rather than `batch_sz`, so a
        # short final batch (dataset size not a multiple of batch_sz) works.
        one_hot = nn.functional.one_hot(label, num_classes=10).float()
        res = model(img.flatten(start_dim=1))
        curr_loss = loss(res, one_hot)

        optimizer.zero_grad()
        curr_loss.backward()
        optimizer.step()

        # Accumulate a Python float: adding the loss tensor itself would chain
        # every batch's autograd history onto `epoch_loss` for the whole epoch.
        epoch_loss += curr_loss.item()
    print("Epoch: %d, Loss: %f" % (epoch, epoch_loss))



        


## Test accuracy

In [None]:
# Measure classification accuracy on the held-out MNIST test split.
# NOTE(review): reuses the training root on the assumption that it also holds
# the test files (torchvision's MNIST download fetches both splits); set
# download=True if they are missing.
testdata = datasets.MNIST('data/mnist_train', train=False, download=False,
                          transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
                                          shuffle=False)
correct = 0
total = 0
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for imgs, labels in test_loader:
        outputs = model(imgs.flatten(start_dim=1))
        predicted = outputs.argmax(dim=1)  # most probable class per row
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
print("Accuracy: %f" % (correct / total))

## Visualization

In [None]:
# Wipe any previous TensorBoard logs, then open a writer logging to runs/mnist.
# NOTE(review): no writer.add_* call appears anywhere in this notebook, and the
# writer is created after training, so this log directory will stay empty
# unless logging calls are added.
!rm -rf runs
writer = SummaryWriter('runs/mnist')

In [None]:
# Push any pending events to disk, then release the writer.
writer.flush()
writer.close()

In [None]:
!tensorboard --logdir=runs/resnet