In [48]:
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch
import torch.nn as nn

In [49]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# the single channel with mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Only the training split requests a download into ./data; the test split is
# created right after it and reads from the same root directory.
training_set = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_set = datasets.MNIST(root="./data", train=False, transform=transform)

# Training batches are shuffled; evaluation uses larger, unshuffled batches.
train_loader = torch.utils.data.DataLoader(dataset=training_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1000)

In [50]:
class MyMnistNet(nn.Module):
  """Three-layer fully connected network: 784 -> 512 -> 256 -> 10.

  A ReLU follows each of the first two linear layers; the last layer's
  output is returned without any further activation.
  """

  def __init__(self):
    super().__init__()
    # Layers are created in the same order as they are applied in forward().
    self.fc1 = nn.Linear(in_features=784, out_features=512)
    self.fc2 = nn.Linear(in_features=512, out_features=256)
    self.fc3 = nn.Linear(in_features=256, out_features=10)

  def forward(self, x):
    """Flatten every dimension after the first and apply the three layers.

    Args:
      x: input tensor whose non-batch dimensions flatten to 784 values.

    Returns:
      Tensor with 10 values per input row (output of ``fc3``).
    """
    flat = x.flatten(start_dim=1)
    hidden = F.relu(self.fc1(flat))
    hidden = F.relu(self.fc2(hidden))
    return self.fc3(hidden)
    


In [51]:
def train(model, device, train_loader, optimizer, epoch):
  """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

  Args:
    model: module to optimise; it is switched to training mode first.
    device: device each batch is moved to before the forward pass.
    train_loader: iterable yielding ``(inputs, targets)`` pairs.
    optimizer: optimiser whose gradients are zeroed and stepped per batch.
    epoch: zero-based epoch index, used only in the progress message.
  """
  model.train()
  for batch_idx, batch in enumerate(train_loader):
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    # Standard step: clear old gradients, forward, cross-entropy, backward, update.
    optimizer.zero_grad()
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)
    loss.backward()
    optimizer.step()

    # Report progress on the first batch and every 100th batch after it.
    if batch_idx % 100 != 0:
      continue
    print(f"Epoch={epoch+1}, Batch={batch_idx+1:03}, Loss={loss.item():.4f}")

def test(model, device, test_loader):
  model.eval()
  correct = 0
  for x, y in test_loader:
    x = x.to(device)
    y = y.to(device)
    p_y_hat = model(x)
    y_hat = p_y_hat.argmax(dim=1, keepdim=True)
    correct += y_hat.eq(y.view_as(y_hat)).sum().item()
  
  accuracy = correct / len(test_loader.dataset)
  print(f"Test-set accuracy={accuracy :.04f}\n")

In [55]:
def main():
    """Train ``MyMnistNet`` for 10 epochs with SGD and save its weights.

    After every epoch the model is evaluated on ``test_loader``. When training
    finishes, the model's ``state_dict`` is written to ``result/MNIST.model``.
    """
    import os  # local import: only needed to prepare the output directory

    save_path = "result/MNIST.model"
    # torch.save does not create missing parent directories. Creating the
    # directory up front means a path problem surfaces before the training
    # run instead of discarding ten epochs of work at the very end.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MyMnistNet().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(10):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()

Epoch=1, Batch=001, Loss=2.3023
Epoch=1, Batch=101, Loss=1.8035
Epoch=1, Batch=201, Loss=1.1329
Epoch=1, Batch=301, Loss=0.5794
Epoch=1, Batch=401, Loss=0.5209
Epoch=1, Batch=501, Loss=0.4781
Epoch=1, Batch=601, Loss=0.3603
Epoch=1, Batch=701, Loss=0.2117
Epoch=1, Batch=801, Loss=0.5352
Epoch=1, Batch=901, Loss=0.5229
Test-set accuracy=0.9075

Epoch=2, Batch=001, Loss=0.3681
Epoch=2, Batch=101, Loss=0.1870
Epoch=2, Batch=201, Loss=0.4371
Epoch=2, Batch=301, Loss=0.5515
Epoch=2, Batch=401, Loss=0.1974
Epoch=2, Batch=501, Loss=0.4757
Epoch=2, Batch=601, Loss=0.2580
Epoch=2, Batch=701, Loss=0.2251
Epoch=2, Batch=801, Loss=0.2035
Epoch=2, Batch=901, Loss=0.2141
Test-set accuracy=0.9249

Epoch=3, Batch=001, Loss=0.2408
Epoch=3, Batch=101, Loss=0.2617
Epoch=3, Batch=201, Loss=0.1151
Epoch=3, Batch=301, Loss=0.1763
Epoch=3, Batch=401, Loss=0.2065
Epoch=3, Batch=501, Loss=0.1621
Epoch=3, Batch=601, Loss=0.1220
Epoch=3, Batch=701, Loss=0.3520
Epoch=3, Batch=801, Loss=0.1382
Epoch=3, Batch=901, 