In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

%matplotlib inline

In [23]:
# Hyperparameters shared by the cells below.
BATCH_SIZE = 128  # samples per mini-batch, used by both data loaders
EPOCHS = 100      # number of full passes over the training loader
VAL_STEPS = 5     # mini-batches evaluated in each validation round

In [13]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Only conversion to tensors is applied; no extra normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches MNIST only when it is not already present under ./data,
# so this cell also works on a fresh machine instead of raising because the
# files are missing.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# The held-out split is iterated in a fixed order (shuffle=False).
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [4]:
# def plot_img (img):
#     plt.imshow(img.numpy()[0], cmap='gray')
    
# data_iter = iter(trainloader)
# images, labels = next(data_iter)

In [5]:
class Net(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by two linear layers.

    Both convolutions use 5x5 kernels with padding 2, so they keep the spatial
    size; each 2x2 max-pool halves it. The first linear layer expects
    64*7*7 features, i.e. single-channel 28x28 inputs. The output is a tensor
    of 10 raw scores (logits) per sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        # Classifier head: flattened features -> 1024 hidden units -> 10 scores.
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        """Return the 10 class logits for a batch of images `x`."""
        # Each stage: convolution -> ReLU -> 2x2 max-pooling.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        # Flatten to (batch, 64*7*7) for the fully connected layers.
        flat = x.view(-1, self.fc1.in_features)

        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
        

In [24]:
# Instantiate the CNN, move its parameters to the selected device, then wrap it
# in DataParallel (the wrapper is what the later cells call as `net`).
net = nn.DataParallel(Net().to(device))

In [25]:
# Adam over all model parameters; cross-entropy is computed on the raw logits.
optimizer = optim.Adam(net.parameters(), lr=1e-3)
loss_function = nn.CrossEntropyLoss()

In [26]:
# Train for EPOCHS passes over `trainloader`. Mini-batch loss/accuracy is
# printed every LOG_EVERY optimiser steps, and every VALIDATE_EVERY steps the
# model is scored on the next VAL_STEPS batches of `testloader`.
LOG_EVERY = 500
VALIDATE_EVERY = 1000

step = 0
val_iterator = iter(testloader)

for epoch in range(EPOCHS):

    for x, y in trainloader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = net(x)
        loss = loss_function(logits, y)
        loss.backward()
        optimizer.step()

        step += 1

        if step % LOG_EVERY == 0:
            _, preds = torch.max(logits, 1)
            correct = (preds == y).sum().item()
            # Divide by the actual batch size: the last batch of an epoch can
            # hold fewer than BATCH_SIZE samples.
            accuracy = correct / y.size(0)
            print('Epoch: %d, Step: %5d, Minibatch loss: %.3f, Minibatch accuracy: %.3f' %
                  (epoch, step, loss.item(), accuracy))

        if step % VALIDATE_EVERY == 0:
            val_correct = 0
            val_total = 0

            # eval()/train() have no effect on a model without dropout or
            # batch-norm layers, but keep the loop correct if such layers are added.
            net.eval()
            with torch.no_grad():
                for _ in range(VAL_STEPS):
                    try:
                        val_x, val_y = next(val_iterator)
                    except StopIteration:
                        # The test loader is exhausted: start another pass
                        # instead of crashing in the middle of training.
                        val_iterator = iter(testloader)
                        val_x, val_y = next(val_iterator)

                    val_x, val_y = val_x.to(device), val_y.to(device)

                    val_logits = net(val_x)
                    _, val_preds = torch.max(val_logits, 1)

                    val_correct += (val_preds == val_y).sum().item()
                    val_total += val_y.size(0)
            net.train()

            # The reported metric is an accuracy (fraction correct), not a loss.
            print('Validation accuracy: %.3f' % (val_correct / max(val_total, 1)))

Epoch: 1, Step:   500, Minibatch loss: 0.032, Minibatch accuracy: 0.984
Epoch: 2, Step:  1000, Minibatch loss: 0.004, Minibatch accuracy: 1.000
Validation loss: 0.991
Epoch: 3, Step:  1500, Minibatch loss: 0.003, Minibatch accuracy: 1.000
Epoch: 4, Step:  2000, Minibatch loss: 0.003, Minibatch accuracy: 1.000
Validation loss: 0.986
Epoch: 5, Step:  2500, Minibatch loss: 0.004, Minibatch accuracy: 1.000
Epoch: 6, Step:  3000, Minibatch loss: 0.007, Minibatch accuracy: 1.000
Validation loss: 0.986
Epoch: 7, Step:  3500, Minibatch loss: 0.000, Minibatch accuracy: 1.000
Epoch: 8, Step:  4000, Minibatch loss: 0.001, Minibatch accuracy: 1.000
Validation loss: 0.988
Epoch: 9, Step:  4500, Minibatch loss: 0.001, Minibatch accuracy: 1.000


Process Process-47:
Process Process-25:
Process Process-26:
Process Process-48:
Traceback (most recent call last):


Epoch: 10, Step:  5000, Minibatch loss: 0.001, Minibatch accuracy: 1.000
Validation loss: 0.984


Traceback (most recent call last):
Traceback (most recent call last):


KeyboardInterrupt: 

Traceback (most recent call last):
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/zhmiao/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    sel

In [17]:
# Testing: fraction of correctly classified samples over the whole test loader.

correct = 0
total = 0

# No effect on a model without dropout/batch-norm layers; kept as the standard
# inference setup.
net.eval()
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # Move the batch to the model's device, as the training loop does;
        # otherwise predictions and labels can end up on different devices
        # and the comparison below fails.
        images, labels = images.to(device), labels.to(device)
        logits = net(images)
        _, predictions = torch.max(logits, 1)
        total += labels.shape[0]
        correct += (predictions == labels).sum().item()

print(correct/total)

0.9852


In [16]:
# Number of samples in the `labels` batch left in the kernel by an earlier cell.
labels.size(0)

128

$(A \times 1 \times B \times C \times 1 \times D)$
$(A \times B \times C \times D)$