In [2]:
import torch

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as opt
from torch.autograd import Variable

## Swish Function

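The Swish function multiplies its input by the sigmoid of that input:

$f(x) = x \cdot \sigma(x)$, where $\sigma(x) = \dfrac{1}{1 + e^{-x}}$ is the sigmoid function.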

In [3]:
def swish(x):
    """Apply the Swish activation element-wise: the input scaled by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate

In [4]:
class Swish(nn.Module):
    """Swish activation as a layer: ``f(x) = x * sigmoid(slope * x)``.

    The module has no learnable parameters, so it can be dropped into an
    ``nn.Sequential`` like any other activation.

    Args:
        slope: Multiplier applied to the input inside the sigmoid. The
            default of 1 gives the plain Swish ``x * sigmoid(x)``.
    """

    def __init__(self, slope=1):
        super().__init__()
        # Keep the argument instead of discarding it, so a non-default
        # value actually changes the activation.
        self.slope = slope

    def swish(self, x):
        """Return ``x * sigmoid(slope * x)``, computed element-wise."""
        return x * torch.sigmoid(self.slope * x)

    def forward(self, x):
        """Apply the activation to ``x``; the output has the same shape as ``x``."""
        # Call the method on this instance rather than a module-level
        # function, so the layer does not depend on another cell having run.
        return self.swish(x)

In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# Training split: downloaded into ./data if missing, images converted to tensors.
train_dataset = MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Held-out test split, loaded with the same storage root and conversion.
test_dataset = MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
# hyperparameters
train_batch_size = 100
test_batch_size = 1000
# Backward-compatible alias for the original, misspelled name.
test_batch_szie = test_batch_size

# train dataloader: sample order is reshuffled on every pass over the data
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True
    )

# test dataloader: fixed order, larger batches since no gradients are stored per step
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False
    )

In [8]:
class CNN(nn.Module):
    """Small convolutional classifier with Swish activations.

    Two 3x3 convolutions (padding 1, so spatial size is preserved) are
    followed by a single 2x2 max-pool that halves height and width. The
    dense head expects ``14*14*128`` features per sample, i.e. single-channel
    28x28 inputs, and produces 10 class scores.
    """

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            Swish(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            Swish(),
            nn.MaxPool2d(stride=2, kernel_size=2)
        )

        self.dense = nn.Sequential(
            nn.Linear(in_features=14*14*128, out_features=1024),
            Swish(),
            nn.Linear(1024, 10)
        )

    def forward(self, x):
        """Return class scores of shape ``(batch, 10)`` for a batch of images ``x``."""
        output = self.conv_layers(x)
        # Flatten everything except the batch dimension. Pinning the batch
        # dimension (instead of letting -1 absorb it) means a wrongly sized
        # image raises a shape error in the Linear layer rather than being
        # silently regrouped into a different number of samples.
        output = output.flatten(start_dim=1)
        output = self.dense(output)
        return output

In [10]:
# Build the network, then move its parameters onto the selected device.
model = CNN()
model = model.to(device)

In [11]:
# hyperparameter
learning_rate = 1e-3

# Cross-entropy on the raw class scores; Adam updates every parameter of the model.
loss_func = nn.CrossEntropyLoss()
optimizer = opt.Adam(params=model.parameters(), lr=learning_rate)

In [12]:
# hyperparameter
num_epochs = 5

# Put the model in training mode explicitly, so re-running this cell behaves
# the same even if the model was switched to eval mode in the meantime.
model.train()

for epoch in range(num_epochs):
    for idx, (images, labels) in enumerate(train_loader):
        # Plain tensors carry autograd state themselves; the deprecated
        # Variable wrapper is not needed.
        images = images.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the loss of the current batch every 100 batches.
        if (idx+1)%100 == 0:
            # .item() extracts the Python number from the 0-dim loss tensor.
            print("Epoch: %d, Batch: %d, Loss: %.4f" %(epoch+1, idx+1, loss.item()))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 1, Batch: 100, Loss: 0.1843
Epoch: 1, Batch: 200, Loss: 0.0398
Epoch: 1, Batch: 300, Loss: 0.1091
Epoch: 1, Batch: 400, Loss: 0.0513
Epoch: 1, Batch: 500, Loss: 0.0686
Epoch: 1, Batch: 600, Loss: 0.0308
Epoch: 2, Batch: 100, Loss: 0.0104
Epoch: 2, Batch: 200, Loss: 0.0168
Epoch: 2, Batch: 300, Loss: 0.1110
Epoch: 2, Batch: 400, Loss: 0.0257
Epoch: 2, Batch: 500, Loss: 0.0138
Epoch: 2, Batch: 600, Loss: 0.0705
Epoch: 3, Batch: 100, Loss: 0.0012
Epoch: 3, Batch: 200, Loss: 0.0030
Epoch: 3, Batch: 300, Loss: 0.0525
Epoch: 3, Batch: 400, Loss: 0.0034
Epoch: 3, Batch: 500, Loss: 0.0462
Epoch: 3, Batch: 600, Loss: 0.0252
Epoch: 4, Batch: 100, Loss: 0.1020
Epoch: 4, Batch: 200, Loss: 0.0097
Epoch: 4, Batch: 300, Loss: 0.0001
Epoch: 4, Batch: 400, Loss: 0.0038
Epoch: 4, Batch: 500, Loss: 0.0191
Epoch: 4, Batch: 600, Loss: 0.0006
Epoch: 5, Batch: 100, Loss: 0.0065
Epoch: 5, Batch: 200, Loss: 0.0185
Epoch: 5, Batch: 300, Loss: 0.0373
Epoch: 5, Batch: 400, Loss: 0.0005
Epoch: 5, Batch: 500

In [13]:
correct = 0
total = 0

# Remember the current mode so evaluation leaves the model as it found it.
was_training = model.training
model.eval()

# Inference only: disable gradient tracking so no autograd graph is built.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)

        # Predicted class = index of the highest score along dimension 1.
        _, pred = torch.max(outputs, 1)

        # .item() keeps the running count a plain Python int.
        correct += (pred == labels.to(device)).sum().item()
        total += labels.size(0)

model.train(was_training)

print('Accuracy:%.3f%%' %(100.0 * float(correct)/float(total)))

Accuracy:98.600%
