In [1]:
!pip install torch
!pip install torchvision



In [2]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [3]:
training_epochs = 15
batch_size = 100

# load MNIST data
root = './data'
mnist_train = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=True)
mnist_test = dset.MNIST(root=root, train=False, transform=transforms.ToTensor(), download=True)

# data loader
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
# MNIST linear model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
linear = torch.nn.Linear(784, 10, bias=True).to(device) #MNIST : 28x28, 0~9
torch.nn.init.normal_(linear.weight)

Parameter containing:
tensor([[ 1.4775e-01,  1.3107e+00,  8.6400e-02,  ..., -2.5632e-01,
         -1.1636e+00, -1.1360e+00],
        [ 1.2392e-01, -9.3654e-01,  7.5130e-01,  ...,  8.7881e-01,
         -2.0702e+00, -4.1303e-02],
        [ 8.4781e-01,  6.4852e-03, -3.2452e-01,  ...,  5.3248e-01,
         -1.1766e+00, -1.1945e+00],
        ...,
        [ 1.4280e+00,  9.5169e-01,  2.9085e-01,  ...,  1.5451e+00,
         -1.5432e+00, -3.6936e-01],
        [-7.3526e-04,  7.7152e-01, -1.9856e-01,  ...,  1.7256e-02,
         -1.1603e+00,  2.0420e-01],
        [ 1.9518e+00,  2.5703e-01,  1.1583e+00,  ..., -1.3782e+00,
          2.1885e+00, -2.8006e-01]], device='cuda:0', requires_grad=True)

In [5]:
# Loss func(Cross Entropy) & optimizer(SGD)

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

In [7]:
# training loop
for epoch in range(training_epochs):
  for i, (imgs, labels) in enumerate(train_loader):
    imgs, labels = imgs.to(device), labels.to(device)
    imgs = imgs.view(-1, 28 * 28)

    outputs = linear(imgs)
    loss = criterion(outputs, labels)

    #optimizer.zero_grad, loss backward & optimizer step
    optimizer.zero_grad() 
    loss.backward()
    optimizer.step()

    _, argmax = torch.max(outputs, 1)
    accuracy = (labels == argmax).float().mean()

    if (i+1) % 100 == 0:
      print('Epoch [{}/{}], Step [{}/{}], Loss : {:.4f}, Accuracy : {:.2f}%'.format(
          epoch+1, training_epochs, i+1, len(train_loader), loss.item(), accuracy.item()*100
      ))

Epoch [1/15], Step [100/600], Loss : 3.1196, Accuracy : 48.00%
Epoch [1/15], Step [200/600], Loss : 1.6864, Accuracy : 70.00%
Epoch [1/15], Step [300/600], Loss : 1.6245, Accuracy : 66.00%
Epoch [1/15], Step [400/600], Loss : 1.4174, Accuracy : 71.00%
Epoch [1/15], Step [500/600], Loss : 1.1798, Accuracy : 75.00%
Epoch [1/15], Step [600/600], Loss : 0.9240, Accuracy : 76.00%
Epoch [2/15], Step [100/600], Loss : 1.0619, Accuracy : 77.00%
Epoch [2/15], Step [200/600], Loss : 1.5597, Accuracy : 71.00%
Epoch [2/15], Step [300/600], Loss : 1.5725, Accuracy : 74.00%
Epoch [2/15], Step [400/600], Loss : 1.1395, Accuracy : 83.00%
Epoch [2/15], Step [500/600], Loss : 1.4035, Accuracy : 71.00%
Epoch [2/15], Step [600/600], Loss : 0.8926, Accuracy : 81.00%
Epoch [3/15], Step [100/600], Loss : 0.8428, Accuracy : 83.00%
Epoch [3/15], Step [200/600], Loss : 0.9122, Accuracy : 78.00%
Epoch [3/15], Step [300/600], Loss : 1.0519, Accuracy : 81.00%
Epoch [3/15], Step [400/600], Loss : 0.9724, Accuracy :

In [8]:
# test
linear.eval()
with torch.no_grad():
  correct = 0
  total = 0
  for i, (imgs, labels) in enumerate(test_loader):
    imgs, labels = imgs.to(device), labels.to(device)
    imgs = imgs.view(-1, 28 * 28)

    outputs = linear(imgs)

    _, argmax = torch.max(outputs, 1) 
    total += imgs.size(0)
    correct += (labels == argmax).sum().item()

  print('Test accuracy for {} images: {:.2f}%'.format(total, correct / total * 100))


Test accuracy for 10000 images: 89.08%
