In [3]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

%matplotlib inline

In [16]:
# Number of full passes over the training set.
EPOCHS = 1
# Mini-batch size shared by the train and test DataLoaders.
BATCH_SIZE = 128
# Number of test-set batches evaluated at each periodic validation check.
VAL_STEPS = 5

In [5]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
def cycle(iterable):
    """Yield the items of ``iterable`` over and over.

    Each pass starts a fresh ``for`` loop over ``iterable`` rather than replaying
    cached items, so a re-iterable object (such as a DataLoader) is iterated
    anew on every pass.

    If a pass produces no items at all (empty iterable, or a one-shot iterator
    that is already exhausted) the generator stops instead of spinning forever
    without ever yielding.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            # Nothing came out of this pass, so another pass would not either.
            return

In [7]:
# ToTensor turns each PIL image into a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# download=True only fetches MNIST into ./data when the files are missing, so the
# notebook also runs from a fresh checkout; an existing copy is reused as-is.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# The test split is not shuffled, so evaluation visits batches in a fixed order.
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [8]:
# def plot_img (img):
#     plt.imshow(img.numpy()[0], cmap='gray')
    
# data_iter = iter(trainloader)
# images, labels = next(data_iter)

In [9]:
class Net(nn.Module):
    """Small CNN mapping single-channel 28x28 images to 10 class logits.

    Two 5x5 convolutions with padding 2 (spatial size preserved), each followed
    by ReLU and 2x2 max-pooling (28 -> 14 -> 7), then two fully connected layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        # 64 channels of 7x7 remain after the two pooling stages on a 28x28 input.
        self.fc1 = nn.Linear(64*7*7, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        """Return raw (un-normalised) class logits of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 1, 28, 28).

        Raises:
            RuntimeError: if the spatial size does not reduce to 7x7, because
                the flattened features then no longer match ``fc1``.
        """
        # Conv 1
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)

        # Conv 2
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        # Flatten everything except the batch dimension. Unlike view(-1, 64*7*7),
        # this cannot silently fold surplus spatial elements into extra batch
        # rows when the input has an unexpected size.
        x = torch.flatten(x, 1)

        # Fc1
        x = F.relu(self.fc1(x))

        # Fc2
        return self.fc2(x)
        

In [10]:
# Throwaway instance used by the next cell to inspect layer sizes; `net` is
# re-created in a later cell before training.
net = Net()

In [11]:
# Sanity check: fc1 takes the flattened 64*7*7 feature map as its input.
net.fc1.in_features

3136

In [12]:
# Build a fresh model on the selected device and wrap it in DataParallel in one
# step (Module.to moves the parameters in place and returns the module itself).
net = nn.DataParallel(Net().to(device))

In [13]:
# Adam over every model parameter, paired with cross-entropy on the raw logits.
optimizer = optim.Adam(net.parameters(), lr=1e-3)
loss_function = nn.CrossEntropyLoss()

In [18]:
# NOTE(review): `preds` is only assigned inside the training loop in a later cell,
# so this cell raises NameError on a fresh top-to-bottom run. Its execution count
# shows it was run after the training cell.
preds

tensor([ 2,  7,  2,  2,  0,  8,  0,  3,  4,  3,  0,  0,  6,  3,
         2,  4,  8,  0,  4,  4,  1,  7,  2,  0,  5,  4,  2,  7,
         2,  3,  0,  9,  9,  1,  5,  8,  2,  3,  7,  4,  8,  7,
         5,  7,  1,  6,  1,  3,  9,  8,  6,  8,  5,  4,  4,  1,
         5,  9,  1,  8,  2,  1,  2,  7,  3,  4,  3,  1,  5,  6,
         9,  8,  3,  5,  7,  0,  2,  7,  1,  8,  8,  9,  8,  1,
         5,  9,  3,  8,  3,  7,  7,  1,  8,  1,  4,  0], device='cuda:0')

In [17]:
step = 0
# Endless stream of test batches: each validation check continues from where the
# previous one stopped instead of always re-using the first VAL_STEPS batches.
val_iterator = cycle(testloader)

for epoch in range(EPOCHS):

    for data in trainloader:

        # The validation branch below switches to eval mode, so switch back here.
        net.train()

        x, y = data
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = net(x)
        loss = loss_function(logits, y)
        loss.backward()
        optimizer.step()

        step += 1

        _, preds = torch.max(logits, 1)
        correct = (preds == y).sum().item()

        # Divide by the actual batch length: the last batch of an epoch can be
        # smaller than BATCH_SIZE, which would understate the accuracy.
        accuracy = correct / y.size(0)

        if step % 50 == 0:
            print('Epoch: %d, Step: %5d, Minibatch loss: %.3f, Minibatch accuracy: %.3f' %
                  (epoch, step, loss.item(), accuracy))

        if step % 100 == 0:

            net.eval()

            # Count samples rather than averaging per-batch ratios so a short
            # final test batch is weighted correctly.
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for _ in range(VAL_STEPS):
                    val_x, val_y = next(val_iterator)
                    val_x, val_y = val_x.to(device), val_y.to(device)

                    val_logits = net(val_x)
                    _, val_preds = torch.max(val_logits, 1)

                    val_correct += (val_preds == val_y).sum().item()
                    val_total += val_y.size(0)

            # The reported metric is an accuracy, not a loss.
            print('Validation accuracy: %.3f' % (val_correct / val_total))

Epoch: 0, Step:    50, Minibatch loss: 0.174, Minibatch accuracy: 0.930
Epoch: 0, Step:   100, Minibatch loss: 0.179, Minibatch accuracy: 0.945
Validation loss: 0.963
Epoch: 0, Step:   150, Minibatch loss: 0.049, Minibatch accuracy: 0.984
Epoch: 0, Step:   200, Minibatch loss: 0.082, Minibatch accuracy: 0.977
Validation loss: 0.970
Epoch: 0, Step:   250, Minibatch loss: 0.136, Minibatch accuracy: 0.961
Epoch: 0, Step:   300, Minibatch loss: 0.093, Minibatch accuracy: 0.961
Validation loss: 0.975
Epoch: 0, Step:   350, Minibatch loss: 0.071, Minibatch accuracy: 0.969
Epoch: 0, Step:   400, Minibatch loss: 0.014, Minibatch accuracy: 1.000
Validation loss: 0.978
Epoch: 0, Step:   450, Minibatch loss: 0.020, Minibatch accuracy: 0.992


In [17]:
# Testing: classification accuracy over the whole test set.

correct = 0
total = 0

# Put the model in evaluation mode for inference.
net.eval()

with torch.no_grad():
    for data in testloader:
        images, labels = data
        # Move the batch to the model's device so `predictions == labels` below
        # never compares a GPU tensor with a CPU tensor.
        images, labels = images.to(device), labels.to(device)
        logits = net(images)
        _, predictions = torch.max(logits, 1)
        # Count real samples; the final batch may be smaller than BATCH_SIZE.
        total += labels.shape[0]
        correct += (predictions == labels).sum().item()

print(correct/total)

0.9852


In [16]:
# Size of the batch most recently bound to `labels` by the test loop.
labels.shape[0]

128

$(A \times 1 \times B \times C \times 1 \times D)$
$(A \times B \times C \times D)$