In [12]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
device = torch.device("cuda")

In [3]:
training_data = datasets.MNIST("data", train=True, download=True, transform=ToTensor())
test_data = datasets.MNIST("data", train=False, download=True, transform=ToTensor())

In [9]:
training_data[0][0].shape

torch.Size([1, 28, 28])

In [None]:
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc1 = nn.Linear(28 * 28, 64)
    self.fc2 = nn.Linear(64, 10)

  def forward(self, x):
    x = F.relu(self.fc1())
    x = self.fc2(x)
    return x

In [131]:
net = Net()

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [24]:
from torch.utils.data import DataLoader

In [None]:
train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
out = next(iter(train_loader))
net(out[0][0].flatten())
criterion

tensor([-0.0986, -0.1339, -0.1010, -0.1187, -0.0802,  0.0285, -0.0141,  0.0713,
         0.1764,  0.0260], grad_fn=<ViewBackward0>)

In [132]:
from torch.utils.data import DataLoader

class Trainer:
  def __init__(self, model, loss_fn, optim_fn, training_data, test_data):
    self.model = model.to(device)
    self.loss_fn = loss_fn()
    self.optim = optim_fn(self.model.parameters(), lr=0.001)
    self.training_data = training_data
    self.test_data = test_data

  def train(self):
    train_loader = DataLoader(self.training_data, batch_size=64, shuffle=True)

    for data, targets in train_loader:
      self.optim.zero_grad()

      data = data.squeeze(1).flatten(1,2)
      data = data.to(device)
      output = self.model(data)
      targets = targets.to(device)

      loss = self.loss_fn(output, targets)
      loss.backward()

      self.optim.step()

      correct = torch.sum(output.argmax(1) == targets).item()
      total = output.size(0)

      print(f"Training accuracy: {round((correct / total) * 100, 2)}")
      print(f"Training loss: {loss.item()}")

  def validate(self):
    test_loader = DataLoader(self.test_data, batch_size=64, shuffle=True)

    correct = 0
    total = 0

    for data, targets in test_loader:
      data = data.squeeze(1).flatten(1,2)
      data = data.to(device)

      output = self.model(data)
      targets = targets.to(device)

      correct += torch.sum(output.argmax(1) == targets).item()
      total += output.size(0)

    print(f"Validation accuracy: {round((correct / total) * 100, 2)}")

trainer = Trainer(net, nn.CrossEntropyLoss, optim.Adam, training_data, test_data)

epoch = 10
for i in range(epoch):
  print(f"Epoch {i+1}..")
  trainer.train()
  trainer.validate()
  print()

Epoch 1..
Training accuracy: 4.69
Training loss: 2.319502115249634
Training accuracy: 6.25
Training loss: 2.300431728363037
Training accuracy: 12.5
Training loss: 2.2541399002075195
Training accuracy: 28.12
Training loss: 2.232858896255493
Training accuracy: 39.06
Training loss: 2.195103168487549
Training accuracy: 46.88
Training loss: 2.1859772205352783
Training accuracy: 56.25
Training loss: 2.123955726623535
Training accuracy: 56.25
Training loss: 2.1278491020202637
Training accuracy: 56.25
Training loss: 2.0824170112609863
Training accuracy: 59.38
Training loss: 2.019536256790161
Training accuracy: 68.75
Training loss: 1.965165138244629
Training accuracy: 50.0
Training loss: 1.9912686347961426
Training accuracy: 68.75
Training loss: 1.8935294151306152
Training accuracy: 51.56
Training loss: 1.9244798421859741
Training accuracy: 57.81
Training loss: 1.836297869682312
Training accuracy: 60.94
Training loss: 1.81641685962677
Training accuracy: 53.12
Training loss: 1.777782917022705
Tr

In [136]:
torch.save(net.state_dict(), "model.pt")

In [138]:
model = Net()
model.load_state_dict(torch.load("model.pt", weights_only=True))
model.eval()

Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)