In [4]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms

# Load MNIST dataset
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Create dataloaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=100, shuffle=False)

# Define the model
model = nn.Linear(784, 10)

# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
for epoch_i in range(10):
    for images, labels in train_loader:
        images = images.view(-1, 28*28)
        
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.view(-1, 28*28)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    accuracy = correct / total
    print(f'Epoch {epoch_i}, Loss: {loss.item()}, Accuracy: {accuracy}')

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.view(-1, 28*28)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
test_accuracy = correct / total
print(f'Test Accuracy: {test_accuracy}')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:02<00:00, 3637040.76it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 557948.74it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 1648877/1648877 [00:01<00:00, 1509600.37it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch 0, Loss: 0.6993998885154724, Accuracy: 0.84545
Epoch 1, Loss: 0.5205710530281067, Accuracy: 0.8652833333333333
Epoch 2, Loss: 0.41092732548713684, Accuracy: 0.8750833333333333
Epoch 3, Loss: 0.519656240940094, Accuracy: 0.8804166666666666
Epoch 4, Loss: 0.43699049949645996, Accuracy: 0.8854
Epoch 5, Loss: 0.4621402323246002, Accuracy: 0.8886833333333334
Epoch 6, Loss: 0.3594604432582855, Accuracy: 0.8912
Epoch 7, Loss: 0.36092299222946167, Accuracy: 0.8929166666666667
Epoch 8, Loss: 0.3290359079837799, Accuracy: 0.8952
Epoch 9, Loss: 0.4095393717288971, Accuracy: 0.8969666666666667
Test Accuracy: 0.9034
