In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt

# Reproducibility: seed the global torch RNG (all devices) so DataLoader
# shuffling and the model's weight initialisation are identical every time
# the notebook is restarted and run top to bottom.
SEED = 42
torch.manual_seed(SEED)

DATA_ROOT = 'MNIST_data/'

# Shared, stateless preprocessing: converts each PIL image to a float tensor
# scaled to [0, 1].
to_tensor = transforms.ToTensor()

mnist_train = datasets.MNIST(root=DATA_ROOT,
                             train=True,
                             transform=to_tensor,
                             download=True)

mnist_test = datasets.MNIST(root=DATA_ROOT,
                            train=False,
                            transform=to_tensor,
                            download=True)

# Very large batches: only a handful of optimizer steps per epoch.
BATSIZE = 10000

train_loader = DataLoader(dataset=mnist_train,
                          batch_size=BATSIZE,
                          shuffle=True,
                          num_workers=0)

# Evaluation metrics do not depend on sample order, so the test set is read
# in a fixed order (keeps any per-batch inspection reproducible).
test_loader = DataLoader(dataset=mnist_test,
                         batch_size=BATSIZE,
                         shuffle=False,
                         num_workers=0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class Net(torch.nn.Module):
    """Fully connected MNIST classifier: 784 -> 1024 -> 512 -> 256 -> 128 -> 10.

    Every hidden layer is followed by ReLU; the last layer returns raw
    logits (no softmax), which is the form ``torch.nn.CrossEntropyLoss``
    consumes.
    """

    def __init__(self):
        super().__init__()
        # Layers are constructed in fc1 .. fc5 order so parameter
        # registration (and therefore random initialisation) is unchanged.
        self.fc1 = torch.nn.Linear(28 * 28, 1024)
        self.fc2 = torch.nn.Linear(1024, 512)
        self.fc3 = torch.nn.Linear(512, 256)
        self.fc4 = torch.nn.Linear(256, 128)
        self.fc5 = torch.nn.Linear(128, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Map any tensor holding N*784 values to logits of shape (N, 10)."""
        hidden = x.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = self.relu(layer(hidden))
        return self.fc5(hidden)

net = Net().to(device)
cel = torch.nn.CrossEntropyLoss()

LR = 0.01
optimizer = torch.optim.Adam(net.parameters(), lr=LR)

EPOCHS = 10
for epoch in range(EPOCHS):
    # Be explicit about training mode: eval() may have been called on this
    # model elsewhere, and it matters once dropout/batch-norm are added.
    net.train()
    l_sum = 0.0
    n_seen = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        z = net(x)
        loss = cel(z, y)  # mean cross-entropy over this batch

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the epoch figure is a true
        # per-sample mean (independent of batch count, correct for a short
        # final batch) rather than a sum of batch means.
        l_sum += loss.item() * y.size(0)
        n_seen += y.size(0)

    epoch_loss = l_sum / n_seen if n_seen else float('nan')
    print(f'Epoch: {epoch+1:2d} / {EPOCHS}',
          f'Loss: {epoch_loss:0.6f}')

In [None]:
index = 10

net.eval()
# Fetch the sample once (each indexing call re-runs the transform).
image, y = mnist_test[index]
x = image.view(28, 28).to(device)

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    z = net(x)  # forward() flattens the 28x28 image into a batch of one
pred = z.argmax(dim=1).item()

print(f'Predicted: {pred}')
print(f'Label: {y}')

fig, ax = plt.subplots()
ax.imshow(x.cpu(), cmap='Greys')
ax.set_title(f'Test sample {index}: predicted {pred}, label {y}')
ax.axis('off')
plt.show()