# MNIST 실습 with MPS
- MNIST dataset: 0 ~ 9 숫자가 적혀 있는 28x28 크기의 흑백 손글씨 이미지 데이터셋
- MNIST는 손글씨 이미지와, 그 이미지가 어떤 숫자를 의미하는지를 나타내는 label의 pair들로 구성
- 학습과 평가를 MPS(Metal Performance Shaders)를 이용해서 진행
- [학습 및 평가 결과 보고서](https://github.com/yuiyeong/deeplearning/blob/main/docs/report_presentation01_practice.md)

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

In [1]:
# Directory handed to torchvision.datasets.MNIST as `root` (with download=True)
# by both the training and the evaluation code.
ROOT_DIR = "../data"

In [2]:
class SimpleMPSModel(nn.Module):
    """Three-layer fully connected network that emits one value per sample.

    Each linear layer, the last one included, is followed by a ReLU, so the
    output is never negative.

    Args:
        input_dim: Number of features per sample once the input is flattened.
        n_dim: Width of the two hidden layers.
    """

    def __init__(self, input_dim, n_dim):
        super().__init__()

        self.layer1 = nn.Linear(input_dim, n_dim)
        self.layer2 = nn.Linear(n_dim, n_dim)
        self.layer3 = nn.Linear(n_dim, 1)

        self.activation = nn.ReLU()

    def forward(self, x):
        """Flatten all axes after the first, then apply the three layers.

        Returns:
            Tensor with one row per input sample and a single column.
        """
        hidden = x.flatten(start_dim=1)
        for layer in (self.layer1, self.layer2, self.layer3):
            hidden = self.activation(layer(hidden))
        return hidden

In [3]:
def get_device():
    """Return the torch device to run on, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

In [4]:
def train(device, n_epochs, lr, batch_size, num_workers, pin_memory):
    """Train a fresh SimpleMPSModel on the MNIST training split with SGD.

    The digit label is used as a regression target: the loss is the mean
    squared difference between the model's single output and the label.

    Args:
        device: torch device the model and every batch are moved to.
        n_epochs: Number of passes over the training split.
        lr: Learning rate given to SGD.
        batch_size: Samples per mini-batch.
        num_workers: Forwarded to the DataLoader.
        pin_memory: Forwarded to the DataLoader.

    Returns:
        The trained model (train mode is never switched off here).
    """
    mnist_train = torchvision.datasets.MNIST(
        root=ROOT_DIR,
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # 1 * 28 * 28 flattened input features, 1024 hidden units.
    model = SimpleMPSModel(1 * 28 * 28, 1024).to(device)
    model.train()
    optimizer = SGD(model.parameters(), lr=lr)

    for epoch in range(n_epochs):
        running_loss = 0.0
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)

            model.zero_grad()
            batch_loss = ((model(images)[:, 0] - targets) ** 2).mean()
            batch_loss.backward()
            optimizer.step()

            # Accumulates per-batch means, so the printed figure is a sum of
            # batch averages rather than a per-sample average.
            running_loss += batch_loss.item()

        print("  ", f"Epoch {epoch + 1:3d} | Sum of Loss: {running_loss}")
    return model

In [5]:
def evaluate_model(device, model, batch_size):
    """Compute the mean squared error of ``model`` on the MNIST test split.

    Args:
        device: torch device each batch is moved to before the forward pass.
        model: Network to score; it is put into eval mode and left there.
        batch_size: Samples per evaluation batch.

    Returns:
        The sample-weighted average of the per-batch losses, as a float.
    """
    # Switch the model to evaluation mode.
    model.eval()

    mnist_test = torchvision.datasets.MNIST(
        root=ROOT_DIR,
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=batch_size,
        shuffle=False,
    )

    weighted_loss = 0.0
    seen = 0

    # Gradient tracking is disabled for the whole evaluation pass.
    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)
            batch_n = images.size(0)

            batch_loss = ((model(images)[:, 0] - targets) ** 2).mean()

            # Weight each batch mean by its size so batches of different
            # sizes contribute proportionally to the average.
            weighted_loss += batch_loss.item() * batch_n
            seen += batch_n

    return weighted_loss / seen

In [6]:
def test(num, device, n_epochs, lr, batch_size, num_workers=4, pin_memory=True):
    """Run one experiment: print its settings, train, then print the test loss.

    Args:
        num: Experiment number shown in the banner line.
        device: torch device used for both training and evaluation.
        n_epochs: Number of training epochs.
        lr: Learning rate passed to train().
        batch_size: Batch size used for training and for evaluation.
        num_workers: DataLoader worker count passed to train().
        pin_memory: DataLoader pin_memory flag passed to train().
    """
    print("<" * 20, f"{num} 번째 테스트", ">" * 20)
    settings = (
        ("device:", device),
        ("batch_size:", batch_size),
        ("n_epochs:", n_epochs),
        ("lr:", lr),
        ("num_workers:", num_workers),
        ("pin_memory:", pin_memory),
    )
    for label, value in settings:
        print(label, value)
    print("-" * 80)

    fitted = train(
        device=device,
        n_epochs=n_epochs,
        lr=lr,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    print("Finish Training.")

    print(
        "Average Loss:",
        evaluate_model(device=device, model=fitted, batch_size=batch_size),
    )
    print("=" * 120)

In [7]:
# Select the compute device once; the test() calls in this notebook pass it as `device`.
device = get_device()
# Bare expression so the notebook echoes which device was selected.
device

device(type='mps')

In [8]:
# Experiment 1: lr=0.001, batch_size=64, 100 epochs
# (num_workers and pin_memory are left at test()'s defaults).
test(
    num=1,
    device=device,
    n_epochs=100,
    lr=0.001,
    batch_size=64,
)

<<<<<<<<<<<<<<<<<<<< 1 번째 테스트 >>>>>>>>>>>>>>>>>>>>
device: mps
batch_size: 64
n_epochs: 100
lr: 0.001
num_workers: 4
pin_memory: True
--------------------------------------------------------------------------------
   Epoch   1 | Sum of Loss: 4656.526816248894
   Epoch   2 | Sum of Loss: 2520.8836246728897
   Epoch   3 | Sum of Loss: 1846.4811725616455
   Epoch   4 | Sum of Loss: 1462.0174854397774
   Epoch   5 | Sum of Loss: 1225.9045073390007
   Epoch   6 | Sum of Loss: 1075.2036587297916
   Epoch   7 | Sum of Loss: 972.8748465180397
   Epoch   8 | Sum of Loss: 895.1848843991756
   Epoch   9 | Sum of Loss: 836.3045299947262
   Epoch  10 | Sum of Loss: 783.4455498158932
   Epoch  11 | Sum of Loss: 741.9677709192038
   Epoch  12 | Sum of Loss: 706.2197566330433
   Epoch  13 | Sum of Loss: 671.0826960802078
   Epoch  14 | Sum of Loss: 641.7007465213537
   Epoch  15 | Sum of Loss: 616.4688709527254
   Epoch  16 | Sum of Loss: 591.2432393729687
   Epoch  17 | Sum of Loss: 567.918739438057

In [9]:
# Experiment 2: lr=0.001, batch_size=128, 100 epochs
# (num_workers and pin_memory are left at test()'s defaults).
test(
    num=2,
    device=device,
    n_epochs=100,
    lr=0.001,
    batch_size=128,
)

<<<<<<<<<<<<<<<<<<<< 2 번째 테스트 >>>>>>>>>>>>>>>>>>>>
device: mps
batch_size: 128
n_epochs: 100
lr: 0.001
num_workers: 4
pin_memory: True
--------------------------------------------------------------------------------
   Epoch   1 | Sum of Loss: 3150.8514816761017
   Epoch   2 | Sum of Loss: 1679.7778406143188
   Epoch   3 | Sum of Loss: 1379.1406512260437
   Epoch   4 | Sum of Loss: 1144.5657110214233
   Epoch   5 | Sum of Loss: 965.2897183895111
   Epoch   6 | Sum of Loss: 835.3759979009628
   Epoch   7 | Sum of Loss: 740.6199172139168
   Epoch   8 | Sum of Loss: 670.0202758312225
   Epoch   9 | Sum of Loss: 614.7123789787292
   Epoch  10 | Sum of Loss: 570.9902322292328
   Epoch  11 | Sum of Loss: 535.584322988987
   Epoch  12 | Sum of Loss: 505.4246777892113
   Epoch  13 | Sum of Loss: 480.5290069580078
   Epoch  14 | Sum of Loss: 459.2700951099396
   Epoch  15 | Sum of Loss: 441.00994765758514
   Epoch  16 | Sum of Loss: 423.262957662344
   Epoch  17 | Sum of Loss: 408.3256037831306

In [10]:
# Experiment 3: lr=0.001, batch_size=256, 100 epochs
# (num_workers and pin_memory are left at test()'s defaults).
test(
    num=3,
    device=device,
    n_epochs=100,
    lr=0.001,
    batch_size=256,
)

<<<<<<<<<<<<<<<<<<<< 3 번째 테스트 >>>>>>>>>>>>>>>>>>>>
device: mps
batch_size: 256
n_epochs: 100
lr: 0.001
num_workers: 4
pin_memory: True
--------------------------------------------------------------------------------
   Epoch   1 | Sum of Loss: 2076.8065502643585
   Epoch   2 | Sum of Loss: 1012.5580270290375
   Epoch   3 | Sum of Loss: 881.9028687477112
   Epoch   4 | Sum of Loss: 793.462349653244
   Epoch   5 | Sum of Loss: 719.9371078014374
   Epoch   6 | Sum of Loss: 655.372013092041
   Epoch   7 | Sum of Loss: 598.4144226312637
   Epoch   8 | Sum of Loss: 548.2477571964264
   Epoch   9 | Sum of Loss: 504.94272780418396
   Epoch  10 | Sum of Loss: 467.46191585063934
   Epoch  11 | Sum of Loss: 436.0094689130783
   Epoch  12 | Sum of Loss: 409.0290081501007
   Epoch  13 | Sum of Loss: 385.3111048936844
   Epoch  14 | Sum of Loss: 364.8206013441086
   Epoch  15 | Sum of Loss: 347.1907663345337
   Epoch  16 | Sum of Loss: 331.1636708378792
   Epoch  17 | Sum of Loss: 317.0799863934517


In [11]:
# Experiment 4: lr=0.01, batch_size=64, 100 epochs
# (num_workers and pin_memory are left at test()'s defaults).
test(
    num=4,
    device=device,
    n_epochs=100,
    lr=0.01,
    batch_size=64,
)

<<<<<<<<<<<<<<<<<<<< 4 번째 테스트 >>>>>>>>>>>>>>>>>>>>
device: mps
batch_size: 64
n_epochs: 100
lr: 0.01
num_workers: 4
pin_memory: True
--------------------------------------------------------------------------------
   Epoch   1 | Sum of Loss: 2085.6829941272736
   Epoch   2 | Sum of Loss: 864.7714472711086
   Epoch   3 | Sum of Loss: 649.5532964020967
   Epoch   4 | Sum of Loss: 528.716418273747
   Epoch   5 | Sum of Loss: 445.16699853539467
   Epoch   6 | Sum of Loss: 381.3316645473242
   Epoch   7 | Sum of Loss: 332.2157292589545
   Epoch   8 | Sum of Loss: 285.7224821895361
   Epoch   9 | Sum of Loss: 256.30247639864683
   Epoch  10 | Sum of Loss: 223.39756705611944
   Epoch  11 | Sum of Loss: 197.61556823179126
   Epoch  12 | Sum of Loss: 177.30622961744666
   Epoch  13 | Sum of Loss: 157.12740667909384
   Epoch  14 | Sum of Loss: 142.71166347712278
   Epoch  15 | Sum of Loss: 127.06254420801997
   Epoch  16 | Sum of Loss: 115.15930036082864
   Epoch  17 | Sum of Loss: 105.289682075

In [12]:
# Experiment 5: lr=0.01, batch_size=128, 100 epochs
# (num_workers and pin_memory are left at test()'s defaults).
test(
    num=5,
    device=device,
    n_epochs=100,
    lr=0.01,
    batch_size=128,
)

<<<<<<<<<<<<<<<<<<<< 5 번째 테스트 >>>>>>>>>>>>>>>>>>>>
device: mps
batch_size: 128
n_epochs: 100
lr: 0.01
num_workers: 4
pin_memory: True
--------------------------------------------------------------------------------
   Epoch   1 | Sum of Loss: 1382.9782609939575
   Epoch   2 | Sum of Loss: 616.0506313443184
   Epoch   3 | Sum of Loss: 447.3467983007431
   Epoch   4 | Sum of Loss: 370.7292150557041
   Epoch   5 | Sum of Loss: 314.6265758574009
   Epoch   6 | Sum of Loss: 278.6782314777374
   Epoch   7 | Sum of Loss: 249.84544822573662
   Epoch   8 | Sum of Loss: 231.7423243522644
   Epoch   9 | Sum of Loss: 204.89162692427635
   Epoch  10 | Sum of Loss: 191.89213261008263
   Epoch  11 | Sum of Loss: 173.74577514827251
   Epoch  12 | Sum of Loss: 163.15161250531673
   Epoch  13 | Sum of Loss: 153.6233011484146
   Epoch  14 | Sum of Loss: 141.1562547981739
   Epoch  15 | Sum of Loss: 130.07516093552113
   Epoch  16 | Sum of Loss: 119.10069768875837
   Epoch  17 | Sum of Loss: 115.089725837

In [13]:
# Experiment 6: lr=0.01, batch_size=256, 100 epochs
# (num_workers and pin_memory are left at test()'s defaults).
test(
    num=6,
    device=device,
    n_epochs=100,
    lr=0.01,
    batch_size=256,
)

<<<<<<<<<<<<<<<<<<<< 6 번째 테스트 >>>>>>>>>>>>>>>>>>>>
device: mps
batch_size: 256
n_epochs: 100
lr: 0.01
num_workers: 4
pin_memory: True
--------------------------------------------------------------------------------
   Epoch   1 | Sum of Loss: 904.2223970890045
   Epoch   2 | Sum of Loss: 457.9501736164093
   Epoch   3 | Sum of Loss: 336.07708871364594
   Epoch   4 | Sum of Loss: 270.41756999492645
   Epoch   5 | Sum of Loss: 227.63025611639023
   Epoch   6 | Sum of Loss: 206.59857857227325
   Epoch   7 | Sum of Loss: 188.93117955327034
   Epoch   8 | Sum of Loss: 166.9390529692173
   Epoch   9 | Sum of Loss: 149.57255464792252
   Epoch  10 | Sum of Loss: 148.959058791399
   Epoch  11 | Sum of Loss: 136.8480058312416
   Epoch  12 | Sum of Loss: 128.68334129452705
   Epoch  13 | Sum of Loss: 121.15905177593231
   Epoch  14 | Sum of Loss: 115.82982394099236
   Epoch  15 | Sum of Loss: 111.69681245088577
   Epoch  16 | Sum of Loss: 102.5598232448101
   Epoch  17 | Sum of Loss: 98.318710595