In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [3]:
# Preprocessing: convert each image to a tensor, then normalize it with
# mean 0.5 and std 0.5 (single channel).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Load the MNIST train and test splits (downloaded into ./data when missing).
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Mini-batch loaders: training batches are reshuffled every epoch, test batches keep dataset order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:992)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:09<00:00, 1.00MB/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:992)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 128kB/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:992)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.17MB/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:992)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.54MB/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [4]:
class SimpleANN(nn.Module):
    """Fully connected classifier with two ReLU hidden layers.

    With the default sizes the network flattens 28x28 inputs into 784-element
    vectors and produces 10 raw class scores per sample (no softmax is applied
    in ``forward``).

    Args:
        input_size: Number of features per sample after flattening.
        hidden1_size: Width of the first hidden layer.
        hidden2_size: Width of the second hidden layer.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size=28 * 28, hidden1_size=128, hidden2_size=64, num_classes=10):
        super(SimpleANN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden1_size)    # input layer -> hidden layer
        self.fc2 = nn.Linear(hidden1_size, hidden2_size)  # hidden layer -> hidden layer
        self.fc3 = nn.Linear(hidden2_size, num_classes)   # hidden layer -> output layer

    def forward(self, x):
        """Return the output-layer scores for a batch of inputs.

        Args:
            x: Tensor whose element count is a multiple of the configured
                input size; it is flattened to ``(-1, input_size)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the raw ``fc3`` outputs.
        """
        # Flatten each sample into a 1-D feature vector. ``view`` alone raises
        # on non-contiguous tensors (e.g. a transposed batch), so make the
        # input contiguous first; this is a no-op for contiguous inputs.
        x = x.contiguous().view(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

nn.Module : 모든 신경망 모듈의 기본 클래스, 사용자 정의 신경망은 이 클래스를 상속받음.  
nn.Linear(in_features, out_features) : 선형 변환을 적용하는 완전 연결 레이어를 정의.  
view : 텐서의 모양(shape)을 변경한다(원소 개수는 그대로). 여기서는 28×28 이미지를 784차원 벡터로 펼친다.  
torch.relu() : ReLU 활성화 함수를 적용한다.  

In [5]:
# Seed PyTorch's global RNG so that weight initialization and the DataLoader's
# shuffling start from a known state on every "Restart Kernel -> Run All".
torch.manual_seed(42)

# Initialize the model
model = SimpleANN()

# Define the loss function and the optimization algorithm
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)    # lr = learning rate, momentum = momentum value

# Train the model
for epoch in range(10):  # train for 10 epochs
    # Explicitly select training mode so re-running this cell after an
    # evaluation pass that called model.eval() still trains correctly.
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        # Reset the gradients accumulated by the previous step
        optimizer.zero_grad()

        # Forward pass + backward pass + parameter update
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss of the last 100 mini-batches, then restart the window
        running_loss += loss.item()
        if i % 100 == 99:  # print every 100 mini-batches
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

[Epoch 1, Batch 100] loss: 1.317
[Epoch 1, Batch 200] loss: 0.491
[Epoch 1, Batch 300] loss: 0.392
[Epoch 1, Batch 400] loss: 0.333
[Epoch 1, Batch 500] loss: 0.314
[Epoch 1, Batch 600] loss: 0.293
[Epoch 1, Batch 700] loss: 0.287
[Epoch 1, Batch 800] loss: 0.254
[Epoch 1, Batch 900] loss: 0.224
[Epoch 2, Batch 100] loss: 0.210
[Epoch 2, Batch 200] loss: 0.201
[Epoch 2, Batch 300] loss: 0.177
[Epoch 2, Batch 400] loss: 0.182
[Epoch 2, Batch 500] loss: 0.181
[Epoch 2, Batch 600] loss: 0.173
[Epoch 2, Batch 700] loss: 0.168
[Epoch 2, Batch 800] loss: 0.163
[Epoch 2, Batch 900] loss: 0.173
[Epoch 3, Batch 100] loss: 0.140
[Epoch 3, Batch 200] loss: 0.128
[Epoch 3, Batch 300] loss: 0.146
[Epoch 3, Batch 400] loss: 0.127
[Epoch 3, Batch 500] loss: 0.131
[Epoch 3, Batch 600] loss: 0.121
[Epoch 3, Batch 700] loss: 0.129
[Epoch 3, Batch 800] loss: 0.129
[Epoch 3, Batch 900] loss: 0.129
[Epoch 4, Batch 100] loss: 0.111
[Epoch 4, Batch 200] loss: 0.104
[Epoch 4, Batch 300] loss: 0.105
[Epoch 4, 

nn.CrossEntropyLoss : 다중 클래스 분류 문제에서 주로 사용되는 손실 함수. 예측 값과 실제 값 사이의 교차 엔트로피 손실을 계산.  
optim.SGD : 확률적 경사 하강법(Stochastic Gradient Descent) 최적화 알고리즘 정의.  
optimizer.zero_grad() : 이전 단계에서 계산된 기울기 초기화.  
loss.backward() : 역전파를 통해 기울기 계산.  
optimizer.step() : 계산된 기울기를 바탕으로 가중치 업데이트.

In [6]:
# Switch to evaluation mode before measuring accuracy. The current model only
# contains Linear layers, so this does not change its outputs, but it keeps the
# evaluation correct if mode-dependent layers (dropout, batch norm) are added.
model.eval()

correct = 0  # number of test samples whose predicted class equals the label
total = 0    # number of test samples seen so far
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in testloader:
        outputs = model(images)
        # torch.max over dim 1 returns (values, indices); the index of the
        # largest score is the predicted class.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

Accuracy of the network on the 10000 test images: 97.24%


torch.no_grad() : 평가 단계에서는 기울기 계산이 필요 없으므로, 자동 미분(autograd)의 기울기 추적을 비활성화하여 메모리 사용량을 줄인다.  
torch.max(outputs.data, 1) : 차원 1(클래스 축)을 따라 최댓값과 그 인덱스를 반환한다. 여기서는 인덱스, 즉 각 샘플에서 가장 높은 점수를 받은 클래스를 예측값으로 사용한다.  
labels.size(0) : 배치 크기(현재 배치의 샘플 수)를 반환.  
(predicted == labels).sum().item() : 예측 값, 실제 값 일치하는 샘플 수 계산.