In [5]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.softmax(self.fc3(x))
        return x

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64

trainset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model = model.train()
n_epoch = 10

for epoch in range(n_epoch):
    with tqdm(total=len(trainloader), unit="batch") as pbar:
        pbar.set_description(f"Epoch[{epoch}/{n_epoch}]")
        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)

model = model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy: {} %'.format(100 * correct / total))

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

  0%|          | 0/938 [00:00<?, ?batch/s]

Accuracy: 96.73 %
