In [18]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [19]:
# Report whether a CUDA GPU is visible to PyTorch. The device name is only
# queried when one exists, because torch.cuda.get_device_name() raises on
# machines without CUDA and would break a fresh "Run All".
print(torch.cuda.is_available())
if torch.cuda.is_available():
  print(torch.cuda.get_device_name())

True
NVIDIA GeForce RTX 4060


In [20]:
# Build the MNIST training split first, then the test split. Both are stored
# under ./data (download=True is passed for each) and use the ToTensor transform.
trainD, testD = (
  datasets.MNIST(root="data", train=is_train, download=True, transform=ToTensor())
  for is_train in (True, False)
)

In [21]:
BATCH_SIZE = 128
# Prefer the available accelerator and fall back to the CPU otherwise.
# current_accelerator() on its own can return None, which would leave DEVICE unset.
DEVICE = (
  torch.accelerator.current_accelerator()
  if torch.accelerator.is_available()
  else torch.device("cpu")
)

train_DL = DataLoader(trainD, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test loader is not shuffled.
test_DL = DataLoader(testD, batch_size=BATCH_SIZE, shuffle=False)

# Peek at a single batch to confirm tensor shapes and the label dtype.
for X, y in test_DL:
  print("shape x:", X.shape) # batch size, channels, height, width
  print("shape y:", y.shape, y.dtype) # one integer label per image in the batch
  break

shape x: torch.Size([128, 1, 28, 28])
shape y: torch.Size([128]) torch.int64


In [22]:
# Classification loss shared by training and evaluation; it is applied to the
# logits returned by the model and the integer class labels from the loaders.
loss_fn = nn.CrossEntropyLoss()

In [23]:
class Mnist(nn.Module):
  """Fully connected image classifier, sized for MNIST by default.

  Args:
    in_features: number of values per flattened image (28*28 for MNIST).
    hidden_features: width of each of the two hidden layers.
    num_classes: number of output logits.

  The defaults reproduce the original 784-256-256-10 network and the layer
  layout is unchanged, so the state-dict keys stay the same.
  """

  def __init__(self, in_features=28*28, hidden_features=256, num_classes=10):
    super().__init__()
    self.flatten = nn.Flatten()
    self.model = nn.Sequential(
      nn.Linear(in_features, hidden_features),
      nn.ReLU(),
      nn.Linear(hidden_features, hidden_features),
      nn.ReLU(),
      nn.Linear(hidden_features, num_classes)
    )

  def forward(self, X):
    """Return the raw class scores (logits), one row per input sample."""
    X = self.flatten(X) # (batch, 1, 28, 28) -> (batch, 784); the batch dim is kept
    logits = self.model(X)
    return logits

In [24]:
def accuracy(output, y):
  """Return the percentage (0-100) of rows whose top-scoring class equals the label.

  `output` holds one row of class scores per sample; `y` holds the labels.
  The result is a 0-dim tensor.
  """
  # The index of the largest score in each row is the predicted class.
  predicted = output.argmax(dim=1)
  hits = predicted.eq(y).float()
  return 100 * hits.mean()

In [25]:
def train(dataloader, model, loss_fn, optimizer=None, lr=1e-3, device=None):
  """Run one training pass of `model` over `dataloader`.

  Args:
    dataloader: yields (inputs, labels) batches and exposes `.dataset`.
    model: module whose parameters are updated in place.
    loss_fn: callable mapping (predictions, labels) to a scalar loss.
    optimizer: optimizer to step; when None, plain SGD is applied by hand.
    lr: step size for the manual SGD path (ignored when an optimizer is given).
    device: where each batch is moved; defaults to the notebook-level DEVICE.

  Prints the running sample count and the current loss every 100 batches.
  """
  model.train()
  if device is None:
    device = DEVICE
  size = len(dataloader.dataset)
  seen = 0 # samples processed so far, counted from the batches themselves

  for batch_i, (X, y) in enumerate(dataloader): # in batches from DL
    X, y = X.to(device), y.to(device)
    seen += len(X)

    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()

    if optimizer is None:
      # -- MANUAL SGD
      with torch.no_grad():
        for param in model.parameters():
          # Parameters that received no gradient (e.g. frozen ones) are skipped.
          if param.grad is not None:
            param -= param.grad * lr # gradient-descent step on the weights
      model.zero_grad() # reset gradients
    else:
      # optimization
      optimizer.step()
      optimizer.zero_grad() # clear gradients

    if batch_i % 100 == 0:
      print(f"Curr: {seen}/{size}, Loss = {loss.item():.4f}")


In [26]:
def testing(dataloader, model, loss_fn):
  model.eval()
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  test_loss = correct = 0

  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(DEVICE), y.to(DEVICE)

      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(dim=1) == y).type(torch.float).sum().item()

  avg_loss = test_loss / num_batches
  accuracy = correct / size

  print("Test completed,")
  print(f"Accuracy: {accuracy*100:.2f}%, Avg Loss: {avg_loss:.4f}\n")

In [27]:
def fit(epochs: int):
  """Train a fresh Mnist network for `epochs` epochs and save its weights.

  Each epoch runs one pass over train_DL followed by an evaluation on
  test_DL. The state dict is written to `mnist_weights.pth` at the end.
  """
  net = Mnist().to(DEVICE)
  adam = torch.optim.Adam(net.parameters(), lr=1e-3)
  weights_file = "mnist_weights.pth"

  print("Starting!")
  for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    train(train_DL, net, loss_fn, adam)
    testing(test_DL, net, loss_fn)
  print("Done!")

  torch.save(net.state_dict(), weights_file)
  print("Weights saved to `mnist_weights.pth`")

In [28]:
EPOCHS = 5 # number of train-then-evaluate rounds passed to fit()

# Train the network, evaluating after every epoch, then save the weights.
fit(EPOCHS)

Starting!
Epoch 1
Curr: 128/60000, Loss = 2.3082
Curr: 12928/60000, Loss = 0.2922
Curr: 25728/60000, Loss = 0.2259
Curr: 38528/60000, Loss = 0.2181
Curr: 51328/60000, Loss = 0.1445
Test completed,
Accuracy: 95.30%, Avg Loss: 0.1476

Epoch 2
Curr: 128/60000, Loss = 0.1632
Curr: 12928/60000, Loss = 0.1247
Curr: 25728/60000, Loss = 0.1349
Curr: 38528/60000, Loss = 0.0664
Curr: 51328/60000, Loss = 0.1203
Test completed,
Accuracy: 96.89%, Avg Loss: 0.0990

Epoch 3
Curr: 128/60000, Loss = 0.1111
Curr: 12928/60000, Loss = 0.1582
Curr: 25728/60000, Loss = 0.0839
Curr: 38528/60000, Loss = 0.0498
Curr: 51328/60000, Loss = 0.0542
Test completed,
Accuracy: 97.26%, Avg Loss: 0.0889

Epoch 4
Curr: 128/60000, Loss = 0.0337
Curr: 12928/60000, Loss = 0.0209
Curr: 25728/60000, Loss = 0.1685
Curr: 38528/60000, Loss = 0.0557
Curr: 51328/60000, Loss = 0.0402
Test completed,
Accuracy: 97.67%, Avg Loss: 0.0745

Epoch 5
Curr: 128/60000, Loss = 0.0561
Curr: 12928/60000, Loss = 0.0229
Curr: 25728/60000, Loss = 