In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch import nn

In [2]:
batch_size = 32
valid_size = 0.2

In [13]:
train_data = datasets.MNIST(root='data', download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

In [14]:
print(type(train_data))

<class 'torchvision.datasets.mnist.MNIST'>


In [15]:
n = len(train_data)
indices = torch.randperm(n)

In [16]:
mid = int(valid_size * n)
train_indices = indices[:mid]
valid_indices = indices[mid:]

In [17]:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

In [18]:
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size)

In [19]:
model = nn.Sequential(
  nn.Linear(784, 512),
  nn.ReLU(),
  nn.Dropout(0.2),
  nn.Linear(512, 512),
  nn.ReLU(),
  nn.Dropout(0.2),
  nn.Linear(512, 10)
)

In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [21]:
for epoch in range(50):
  train_loss, valid_loss = 0.0, 0.0

  model.train()
  for idx, (data, target) in enumerate(train_loader):
    data = data.view(-1, 28*28)
    output = model(data)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss += loss.item() * data.size(0)

  model.eval()
  for idx, (data, target) in enumerate(valid_loader):
    data = data.view(-1, 28*28)
    output = model(data)
    loss = criterion(output, target)

    valid_loss += loss.item() * data.size(0)

  train_loss /= len(train_loader.dataset)
  valid_loss /= len(valid_loader.dataset)

  print(f'epoch: {epoch}, training loss: {train_loss:>0.4}, validation loss: {valid_loss:>0.4}')



epoch: 0, training loss: 0.4364, validation loss: 0.3884
epoch: 1, training loss: 0.2851, validation loss: 0.1815
epoch: 2, training loss: 0.151, validation loss: 0.116
epoch: 3, training loss: 0.113, validation loss: 0.0918
epoch: 4, training loss: 0.09533, validation loss: 0.08107
epoch: 5, training loss: 0.08524, validation loss: 0.07355
epoch: 6, training loss: 0.07932, validation loss: 0.06773
epoch: 7, training loss: 0.07358, validation loss: 0.06403
epoch: 8, training loss: 0.06991, validation loss: 0.06093
epoch: 9, training loss: 0.06633, validation loss: 0.05755
epoch: 10, training loss: 0.06387, validation loss: 0.05501
epoch: 11, training loss: 0.06015, validation loss: 0.05279
epoch: 12, training loss: 0.0582, validation loss: 0.05021
epoch: 13, training loss: 0.05553, validation loss: 0.0483
epoch: 14, training loss: 0.05337, validation loss: 0.04663
epoch: 15, training loss: 0.0516, validation loss: 0.04464
epoch: 16, training loss: 0.05001, validation loss: 0.0432
epoch

In [22]:
model.eval()
correct = 0
for idx, (data, target) in enumerate(test_loader):
  data = data.view(-1, 28*28)
  output = model(data)
  _, pred = torch.max(output, 1)
  correct += pred.eq(target.data.view_as(pred)).sum()

print(f'accuracy: {correct / len(test_data) * 100:>0.4}')


accuracy: 95.71
