In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
device = torch.device("cpu")

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()

device

device(type='mps')

In [3]:
training_data = datasets.MNIST("data", train=True, download=True, transform=ToTensor())
test_data = datasets.MNIST("data", train=False, download=True, transform=ToTensor())

In [4]:
training_data[0][0].shape

torch.Size([1, 28, 28])

In [5]:
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc1 = nn.Linear(28 * 28, 64)
    self.fc2 = nn.Linear(64, 10)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x

In [6]:
net = Net()

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [8]:
from torch.utils.data import DataLoader

In [9]:
train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
out = next(iter(train_loader))
out[0][0].flatten()
net(out[0][0].flatten())

tensor([-0.0079, -0.0335,  0.0080,  0.0099, -0.0358, -0.1462, -0.1551,  0.0116,
        -0.0042, -0.0389], grad_fn=<ViewBackward0>)

In [10]:
from torch.utils.data import DataLoader

class Trainer:
  def __init__(self, model, loss_fn, optim_fn, training_data, test_data):
    self.model = model.to(device)
    self.loss_fn = loss_fn()
    self.optim = optim_fn(self.model.parameters(), lr=0.001)
    self.training_data = training_data
    self.test_data = test_data

  def train(self):
    train_loader = DataLoader(self.training_data, batch_size=64, shuffle=True)

    for data, targets in train_loader:
      self.optim.zero_grad()

      data = data.squeeze(1).flatten(1,2)
      data = data.to(device)
      output = self.model(data)
      targets = targets.to(device)

      loss = self.loss_fn(output, targets)
      loss.backward()

      self.optim.step()

      correct = torch.sum(output.argmax(1) == targets).item()
      total = output.size(0)

      print(f"Training accuracy: {round((correct / total) * 100, 2)}\tTraining loss: {loss.item()}\n")
      print("")

  def validate(self):
    test_loader = DataLoader(self.test_data, batch_size=64, shuffle=True)

    correct = 0
    total = 0

    for data, targets in test_loader:
      data = data.squeeze(1).flatten(1,2)
      data = data.to(device)

      output = self.model(data)
      targets = targets.to(device)

      correct += torch.sum(output.argmax(1) == targets).item()
      total += output.size(0)

    print(f"Validation accuracy: {round((correct / total) * 100, 2)}\n")

In [11]:
trainer = Trainer(net, nn.CrossEntropyLoss, optim.Adam, training_data, test_data)

epoch = 10
for i in range(epoch):
  print(f"Epoch {i+1}..")
  trainer.train()
  trainer.validate()
  print()

Epoch 1..
Training accuracy: 6.25	Training loss: 2.2967677116394043


Training accuracy: 14.06	Training loss: 2.2740840911865234


Training accuracy: 40.62	Training loss: 2.234233856201172


Training accuracy: 45.31	Training loss: 2.2137842178344727


Training accuracy: 59.38	Training loss: 2.157383918762207


Training accuracy: 50.0	Training loss: 2.1164369583129883


Training accuracy: 53.12	Training loss: 2.0723776817321777


Training accuracy: 57.81	Training loss: 2.045067310333252


Training accuracy: 59.38	Training loss: 1.995074987411499


Training accuracy: 53.12	Training loss: 1.9744784832000732


Training accuracy: 56.25	Training loss: 1.9310221672058105


Training accuracy: 62.5	Training loss: 1.8745418787002563


Training accuracy: 64.06	Training loss: 1.8435271978378296


Training accuracy: 48.44	Training loss: 1.8882322311401367


Training accuracy: 64.06	Training loss: 1.822387456893921


Training accuracy: 60.94	Training loss: 1.745124340057373


Training accuracy: 62.5

In [None]:
# torch.save(net.state_dict(), "model.pt")

In [12]:
model = Net()
model.load_state_dict(torch.load("model.pt", weights_only=True, map_location=device))
model.eval()

Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)