In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import RandomSampler, BatchSampler, SequentialSampler, random_split, DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.jit import save, load

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
    
    
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    correct = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        loss.backward()
        optimizer.step()
        
    acc = 100. * correct / len(train_loader.dataset)
    print(f'Train Epoch: {epoch}\nTraining Loss: {loss.item()}\tAccuracy: {acc:.4f}')
    
def val(model, device, val_loader):
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    val_loss /= len(val_loader.dataset)
    acc = 100. * correct / len(val_loader.dataset)
    
    print(f'Validation Loss: {val_loss}, Accuracy: {acc:.4f}\n')
    
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)
    
    print(f'Test Loss: {test_loss}, Accuracy: {acc:.4f}')
    

# set random seed
torch.manual_seed(198964)

#set device
device = torch.device('cuda')

# arguments
batch_size = 64
test_batch_size = 1000
lr = 1.0
gamma = 0.7
epochs = 14

# transform
train_tf=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, ))
])
eval_tf=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, ))
])

# import data
train_data = datasets.MNIST(root='./mnist_data', train=True, download=True,
                            transform=train_tf)
test_data = datasets.MNIST(root='./mnist_data', train=False, download=True, 
                           transform=eval_tf)
(test_data, val_data) = random_split(test_data, [5000, 5000])

# kwargs
train_kwargs = {'batch_sampler': BatchSampler(sampler=RandomSampler(train_data),
                                              batch_size=batch_size,
                                              drop_last=False),
                'num_workers':3}
val_kwargs = {'batch_sampler': BatchSampler(sampler=SequentialSampler(val_data),
                                              batch_size=test_batch_size,
                                              drop_last=False), 
              'num_workers':3}
test_kwargs = {'batch_sampler': BatchSampler(sampler=SequentialSampler(test_data),
                                              batch_size=test_batch_size,
                                              drop_last=False),
               'num_workers':3}

# dataloader
train_loader = DataLoader(train_data, **train_kwargs)
val_loader = DataLoader(val_data, **val_kwargs)
test_loader = DataLoader(test_data, **test_kwargs)

# build model
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=lr)

scheduler = StepLR(optimizer, step_size = 1, gamma=gamma)
for epoch in range(1, epochs+1):
    train(model, device, train_loader, optimizer, epoch)
    val(model, device, val_loader)
    scheduler.step()
    
test(model, device, test_loader)

Train Epoch: 1
Training Loss: 0.02955252304673195	Accuracy: 93.5367
Validation Loss: 0.055548307037353514, Accuracy: 98.1600

Train Epoch: 2
Training Loss: 0.24216631054878235	Accuracy: 97.6700
Validation Loss: 0.040924477767944334, Accuracy: 98.6200

Train Epoch: 3
Training Loss: 0.009153272956609726	Accuracy: 98.3333
Validation Loss: 0.04083502883911133, Accuracy: 98.7400

Train Epoch: 4
Training Loss: 0.04233069717884064	Accuracy: 98.5467
Validation Loss: 0.03656873245239258, Accuracy: 98.6800

Train Epoch: 5
Training Loss: 0.005246618762612343	Accuracy: 98.7733
Validation Loss: 0.03338316230773926, Accuracy: 98.9200

Train Epoch: 6
Training Loss: 0.011628229171037674	Accuracy: 98.9367
Validation Loss: 0.03176677055358887, Accuracy: 99.0600

Train Epoch: 7
Training Loss: 0.03911621868610382	Accuracy: 99.0017
Validation Loss: 0.02960165710449219, Accuracy: 98.9600

Train Epoch: 8
Training Loss: 0.16776806116104126	Accuracy: 99.0350
Validation Loss: 0.033748940658569336, Accuracy: 98.

In [11]:
from torch.jit import script
my_script = torch.jit.script(Net())

my_script.save('mnist_script.pt')

In [13]:
torch.save(model.state_dict(), './mnist_state_dict.pt')