In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

USE_MPS = True  # prefer Apple's Metal (MPS) backend when it is available
SEED = 42       # seeds torch's global RNG (weight init, DataLoader shuffling, Dropout)


def select_device(prefer_mps: bool = True) -> torch.device:
    """Return the best available torch device and print which one was chosen.

    Preference order: MPS (only when ``prefer_mps`` is true), then CUDA, then CPU.
    Unlike a plain if/else on the flag, an unavailable MPS backend falls through
    to CUDA instead of going straight to CPU.

    Args:
        prefer_mps: try Apple's MPS backend first when True.

    Returns:
        A ``torch.device`` of type ``"mps"``, ``"cuda"`` or ``"cpu"``.
    """
    # getattr guards builds of torch that ship without the mps backend module.
    mps_backend = getattr(torch.backends, "mps", None)
    if prefer_mps and mps_backend is not None and mps_backend.is_available():
        print("🍎 Using Apple MPS GPU")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("🔥 Using NVIDIA CUDA GPU")
        return torch.device("cuda")
    print("💻 Using CPU")
    return torch.device("cpu")


device = select_device(USE_MPS)
torch.manual_seed(SEED)


In [None]:
# Tensor conversion followed by normalization with the (mean, std) constants
# used throughout this notebook for MNIST.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)


def _load_mnist(is_train):
    """Return the MNIST train/test split under ./data, downloading it if missing."""
    return datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)


train_ds = _load_mnist(True)
test_ds = _load_mnist(False)

# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64)


In [None]:
import numpy as np

# Peek at one training batch: show up to 10 digits with their labels.
sample_images, sample_labels = next(iter(train_loader))
n_show = min(10, len(sample_images))  # guard against a batch smaller than the 2x5 grid

fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i < n_show:
        # Tensors are normalized (not in [0, 1]); imshow autoscales the colormap
        # to the data range, so they still render correctly.
        ax.imshow(sample_images[i][0], cmap="gray")
        ax.set_title(str(sample_labels[i].item()))
fig.suptitle("Sample training digits")
plt.show()


## CNN 訓練 (training)

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier for square images (MNIST 28x28 by default).

    Layers: Conv(3x3) -> ReLU -> Conv(3x3) -> ReLU -> MaxPool(2) -> Flatten
    -> Linear -> ReLU -> Dropout(0.25) -> Linear.  The layer order is unchanged,
    so ``state_dict`` keys (``net.0``, ``net.2``, ``net.6``, ``net.9``) stay
    compatible with previously saved checkpoints.

    Args:
        in_channels: number of input channels (1 for grayscale MNIST).
        num_classes: number of output logits.
        input_size: height/width of the square input image; it determines the
            flattened feature size fed to the first Linear layer.

    Raises:
        ValueError: if ``input_size`` is too small to survive the convs and pooling.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10, input_size: int = 28):
        super().__init__()
        # Two unpadded 3x3 convs shrink each side by 4 px, MaxPool(2) halves it:
        # 28 -> 26 -> 24 -> 12, hence 64 * 12 * 12 = 9216 features for MNIST.
        side = (input_size - 4) // 2
        if side < 1:
            raise ValueError(f"input_size must be >= 6, got {input_size}")
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * side * side, 128),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, in_channels, input_size, input_size) tensor to raw logits (batch, num_classes)."""
        return self.net(x)

# Build the network and move its parameters onto the selected device.
model = CNN()
model.to(device)  # nn.Module.to() moves parameters in place and returns self
model  # last expression: display the layer summary


In [None]:
# The model's last layer is a plain Linear (raw logits), which is what
# CrossEntropyLoss expects; parameters are optimized with Adam.
LEARNING_RATE = 1e-3
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()


In [None]:
EPOCHS = 15  # number of full passes over the training set


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return the per-sample mean loss.

    Args:
        model: network to train; it is switched to train mode (enables Dropout).
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function; assumed to use ``reduction="mean"`` (the
            CrossEntropyLoss default) — TODO confirm if changed.
        optimizer: optimizer bound to ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        Mean loss over all samples seen (0.0 for an empty loader).
    """
    model.train()
    running_loss = 0.0
    n_samples = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # The loss is a batch *mean*; weight it by batch size so the short last
        # batch does not count as much as a full one.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        n_samples += batch_size
    return running_loss / max(n_samples, 1)


train_losses = []
for epoch in range(EPOCHS):
    avg = train_one_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(avg)
    print(f"Epoch {epoch+1} | loss = {avg:.4f}")


In [None]:
# Plot the per-epoch mean training loss. The x-axis uses the same 1-based
# epoch numbers the training loop prints.
epoch_numbers = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_losses, marker="o")
ax.set(title="Training Loss", xlabel="Epoch", ylabel="Loss")
ax.grid(True)
plt.show()


In [None]:
# Accuracy on the held-out test split: fraction of argmax predictions that
# match the labels. eval() disables Dropout; no_grad() skips autograd bookkeeping.
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        predicted = model(images).argmax(dim=1)
        correct += (predicted == labels).sum().item()

acc = correct / len(test_ds)
print("Test accuracy =", acc)


## 儲存模型 (save model)

In [None]:
# .pth or .pt is the conventional suffix for PyTorch checkpoint files.
model_path = "mnist_cnn.pth"

# Save only the learned weights (state_dict), not the pickled module object;
# reload them into a freshly constructed CNN().
torch.save(model.state_dict(), model_path)
print(f"模型已儲存至 {model_path}")

## 讀取模型 (load model)

In [None]:
LOAD_MODEL = False  # set True to restore weights from disk into `model`

if LOAD_MODEL:
    # 1. Rebuild the architecture; it must match the CNN class used in training,
    #    otherwise load_state_dict() fails on mismatched keys/shapes.
    loaded_model = CNN()

    # 2. Load the weight file. map_location remaps tensors saved on a different
    #    device (e.g. CUDA vs MPS vs CPU) onto the current one. weights_only=True
    #    restricts unpickling to tensors/plain containers — all a state_dict needs.
    model_path = "mnist_cnn.pth"
    state_dict = torch.load(model_path, map_location=device, weights_only=True)

    # 3. Copy the weights into the model.
    loaded_model.load_state_dict(state_dict)

    # 4. Move to the active device (CPU or MPS/GPU) and switch to eval mode.
    loaded_model.to(device)
    loaded_model.eval()

    # The prediction cells call `model`, not `loaded_model`; rebind it so the
    # restored weights are actually the ones used.
    model = loaded_model

    print("模型載入成功，可以直接進行預測。")

## User 輸入 (user input)

In [None]:
from PIL import Image
import torchvision.transforms.functional as F


img_path = "5.1.png"   # put your own digit image here

# Convert to 8-bit grayscale ("L"). convert() returns an independent copy, so the
# context manager can close the underlying file handle right away.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("L")

plt.imshow(img, cmap="gray")
plt.title(img_path)
plt.axis("off")
plt.show()


In [None]:
# The MNIST training data here is light strokes on a dark background. If the user
# image has a light background (mean brightness > 127, the same threshold the crop
# cell uses), invert it. Checking first also makes this cell safe to re-run:
# an already-inverted image has a dark mean and is left alone.
if np.asarray(img).mean() > 127:
    img = F.invert(img)

In [None]:
import numpy as np

arr = np.array(img)

# Auto-detect the foreground colour from overall brightness.
if arr.mean() > 127:   # white background, dark digit
    mask = arr < 200
else:                  # black background, light digit
    mask = arr > 50

ys, xs = np.nonzero(mask)

if ys.size == 0:
    # Nothing passed the threshold (e.g. a blank image): keep the full frame
    # instead of crashing on min()/max() of an empty array.
    img_crop = img.copy()
else:
    top, bottom = int(ys.min()), int(ys.max())
    left, right = int(xs.min()), int(xs.max())
    # PIL crop boxes are (left, upper, right, lower) with right/lower exclusive,
    # so +1 keeps the last foreground column and row.
    img_crop = img.crop((left, top, right + 1, bottom + 1))

plt.imshow(img_crop, cmap="gray")
plt.axis("off")
plt.show()


In [None]:
IS_CROP = False  # True: use the tight crop from the cell above; False: the whole image

import torchvision.transforms as T

# Same ToTensor + mean/std normalization as training, plus a resize to 28x28.
# Named separately so it does not shadow the training `transform`.
# NOTE(review): Resize((28, 28)) does not preserve aspect ratio, so a tall, narrow
# crop gets stretched; consider padding to a square first.
infer_transform = T.Compose([
    T.Resize((28, 28)),      # ⚠️ critical: the CNN's first Linear layer is sized for 28x28 input
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,))
])

source_img = img_crop if IS_CROP else img
# unsqueeze(0) adds the batch dimension: grayscale image -> (1, 1, 28, 28).
x = infer_transform(source_img).unsqueeze(0).to(device)

plt.imshow(x[0][0].cpu(), cmap="gray")
plt.title("Final 28x28")
plt.axis("off")
plt.show()


In [None]:
model.eval()  # disable Dropout for inference

with torch.no_grad():
    out = model(x)                           # raw logits, shape (1, 10)
    probs = torch.softmax(out, dim=1)
    pred = out.argmax(dim=1).item()
    confidence = probs[0, pred].item()       # softmax probability of the predicted class

print("模型猜的是:", pred)
print(f"confidence = {confidence:.2%}")