In [12]:
import itertools
from IPython.display import Image
from IPython import display
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import random

In [2]:
trn_dataset = datasets.MNIST('./mnist_data/',
                             download=False,
                             train=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(), # image to Tensor
                                 transforms.Normalize((0.1307,), (0.3081,)) # image, label
                             ]))

val_dataset = datasets.MNIST("./mnist_data/",
                             download=False,
                             train=False,
                             transform= transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ),(0.3081, ))
                           ]))

In [3]:
batch_size = 64
trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                        drop_last=True)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                        drop_last=True)

In [4]:
# construct model on cuda if available
use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu") # GPU 사용 가능하면 사용하고 아니면 CPU 사용
print("다음 기기로 학습합니다:", device)

다음 기기로 학습합니다: cpu


In [5]:
class CNNClassifier(nn.Module):

    def __init__(self):
        # 항상 torch.nn.Module을 상속받고 시작
        super(CNNClassifier, self).__init__()
        conv1 = nn.Conv2d(1, 6, 5, 1)  # 6@24*24
        # activation ReLU
        pool1 = nn.MaxPool2d(2)  # 6@12*12
        conv2 = nn.Conv2d(6, 16, 5, 1)  # 16@8*8
        # activation ReLU
        pool2 = nn.MaxPool2d(2)  # 16@4*4

        self.conv_module = nn.Sequential(
            conv1,
            nn.ReLU(),
            pool1,
            conv2,
            nn.ReLU(),
            pool2
        )

        fc1 = nn.Linear(16 * 4 * 4, 120)
        # activation ReLU
        fc2 = nn.Linear(120, 84)
        # activation ReLU
        fc3 = nn.Linear(84, 10)

        self.fc_module = nn.Sequential(
            fc1,
            nn.ReLU(),
            fc2,
            nn.ReLU(),
            fc3
        )

        # gpu로 할당
        if use_cuda:
            self.conv_module = self.conv_module.cuda()
            self.fc_module = self.fc_module.cuda()

    def forward(self, x):
        out = self.conv_module(x)  # @16*4*4
        # make linear
        dim = 1
        for d in out.size()[1:]:  # 16, 4, 4
            dim = dim * d
        out = out.view(-1, dim)
        out = self.fc_module(out)
        return F.softmax(out, dim=1)

cnn = CNNClassifier()

In [6]:
# loss
criterion = nn.CrossEntropyLoss()
# backpropagation method
learning_rate = 1e-3
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
# hyper-parameters
num_epochs = 2
num_batches = len(trn_loader)

trn_loss_list = []
val_loss_list = []
for epoch in range(num_epochs):
    trn_loss = 0.0
    for i, data in enumerate(trn_loader):
        x, label = data
        if use_cuda:
            x = x.cuda()
            label = label.cuda()
        # grad init
        optimizer.zero_grad()
        # forward propagation
        model_output = cnn(x)
        # calculate loss
        loss = criterion(model_output, label)
        # back propagation
        loss.backward()
        # weight update
        optimizer.step()

        # trn_loss summary
        trn_loss += loss.item()
        # del (memory issue)
        del loss
        del model_output

        # 학습과정 출력
        if (i + 1) % 100 == 0:  # every 100 mini-batches
            with torch.no_grad():  # very very very very important!!!
                val_loss = 0.0
                for j, val in enumerate(val_loader):
                    val_x, val_label = val
                    if use_cuda:
                        val_x = val_x.cuda()
                        val_label = val_label.cuda()
                    val_output = cnn(val_x)
                    v_loss = criterion(val_output, val_label)
                    val_loss += v_loss

            print("epoch: {}/{} | step: {}/{} | trn loss: {:.4f} | val loss: {:.4f}".format(
                epoch + 1, num_epochs, i + 1, num_batches, trn_loss / 100, val_loss / len(val_loader)
            ))

            trn_loss_list.append(trn_loss / 100)
            val_loss_list.append(val_loss / len(val_loader))
            trn_loss = 0.0

epoch: 1/2 | step: 100/937 | trn loss: 1.9089 | val loss: 1.6622
epoch: 1/2 | step: 200/937 | trn loss: 1.6536 | val loss: 1.6156
epoch: 1/2 | step: 300/937 | trn loss: 1.6180 | val loss: 1.6074
epoch: 1/2 | step: 400/937 | trn loss: 1.6130 | val loss: 1.5998
epoch: 1/2 | step: 500/937 | trn loss: 1.5978 | val loss: 1.5868
epoch: 1/2 | step: 600/937 | trn loss: 1.5428 | val loss: 1.5232
epoch: 1/2 | step: 700/937 | trn loss: 1.5121 | val loss: 1.5176
epoch: 1/2 | step: 800/937 | trn loss: 1.5089 | val loss: 1.5035
epoch: 1/2 | step: 900/937 | trn loss: 1.5065 | val loss: 1.4991
epoch: 2/2 | step: 100/937 | trn loss: 1.4967 | val loss: 1.4929
epoch: 2/2 | step: 200/937 | trn loss: 1.5018 | val loss: 1.4936
epoch: 2/2 | step: 300/937 | trn loss: 1.5027 | val loss: 1.4906
epoch: 2/2 | step: 400/937 | trn loss: 1.4962 | val loss: 1.4913
epoch: 2/2 | step: 500/937 | trn loss: 1.4950 | val loss: 1.4904
epoch: 2/2 | step: 600/937 | trn loss: 1.4956 | val loss: 1.4885
epoch: 2/2 | step: 700/93

In [19]:
with torch.no_grad(): # torch.no_grad()를 하면 gradient 계산을 수행하지 않는다.
#     X_test, Y_test = val_loader
    for i, data in enumerate(val_loader):
        x_test, Y_test = data

        prediction = cnn(x_test)
        print(prediction)
        correct_prediction = torch.argmax(prediction, 1) == Y_test
        accuracy = correct_prediction.float().mean()
        print('Accuracy:', accuracy.item())

tensor([[9.9994e-01, 1.8472e-16, 1.3718e-11, 1.5439e-10, 2.1063e-12, 6.2013e-05,
         1.8560e-07, 8.9588e-11, 1.5506e-10, 5.9226e-09],
        [2.0564e-21, 1.6848e-20, 8.1726e-17, 1.0000e+00, 1.2946e-21, 2.1238e-09,
         2.3107e-25, 1.9979e-16, 3.4673e-13, 1.0874e-13],
        [1.0000e+00, 4.4609e-25, 5.3846e-14, 1.4288e-21, 9.3287e-23, 1.2050e-10,
         2.4658e-15, 5.1311e-13, 1.3367e-15, 2.0968e-15],
        [1.7579e-10, 3.3360e-12, 1.0000e+00, 6.2224e-08, 9.0482e-19, 3.4186e-14,
         3.4377e-17, 2.4878e-12, 3.3225e-06, 2.9191e-18],
        [4.9216e-16, 6.6534e-15, 1.0000e+00, 3.0125e-14, 7.3259e-20, 2.6045e-23,
         3.6504e-21, 2.6054e-15, 3.0705e-15, 3.0402e-26],
        [1.4737e-20, 2.3595e-23, 3.8721e-22, 4.4252e-30, 1.0000e+00, 1.3302e-22,
         4.0650e-16, 8.3424e-19, 1.8075e-22, 4.3878e-18],
        [2.5231e-11, 6.5015e-15, 4.0662e-19, 3.2368e-11, 6.5705e-17, 1.0000e+00,
         7.6199e-10, 2.1878e-14, 1.1444e-10, 3.4481e-12],
        [1.5328e-15, 1.0843

tensor([[5.4000e-15, 9.3990e-17, 1.9912e-13, 1.9980e-07, 4.5100e-18, 5.9530e-08,
         1.0116e-14, 1.4303e-17, 1.0000e+00, 4.2284e-13],
        [1.8411e-17, 9.1719e-18, 8.7749e-18, 3.8308e-10, 3.5480e-10, 9.4478e-11,
         2.3491e-22, 2.2257e-11, 6.5537e-08, 1.0000e+00],
        [7.1923e-21, 4.2805e-23, 3.6247e-19, 4.1360e-29, 1.0000e+00, 9.9097e-25,
         5.6291e-17, 2.7903e-18, 2.3433e-22, 1.1465e-19],
        [4.3823e-13, 4.8027e-14, 1.8993e-07, 2.2394e-06, 9.5611e-20, 6.4826e-09,
         3.0810e-15, 8.1244e-16, 1.0000e+00, 6.3078e-15],
        [1.9273e-09, 1.0000e+00, 1.0015e-07, 3.5458e-12, 2.0050e-08, 3.1685e-12,
         2.6028e-10, 3.2290e-07, 6.4645e-09, 1.4922e-11],
        [2.5673e-10, 1.0000e+00, 5.0647e-08, 4.3961e-10, 6.1963e-09, 4.9079e-10,
         1.0227e-09, 1.3275e-07, 7.6762e-09, 2.3297e-11],
        [2.5847e-09, 1.0000e+00, 2.3390e-06, 7.6561e-10, 5.8638e-08, 2.4931e-10,
         2.8097e-09, 1.6730e-06, 1.6141e-08, 1.6228e-10],
        [1.5046e-25, 8.4329

Accuracy: 0.921875
tensor([[1.0000e+00, 5.8475e-21, 1.5103e-10, 6.7733e-16, 1.6280e-20, 2.2414e-07,
         1.5655e-12, 4.0367e-13, 4.2819e-14, 3.7427e-13],
        [1.4583e-13, 2.2608e-20, 8.6080e-18, 3.0555e-25, 2.4144e-10, 2.3391e-13,
         1.0000e+00, 1.4032e-22, 2.0745e-16, 5.7602e-23],
        [8.9119e-14, 1.5348e-11, 1.7000e-09, 5.0485e-03, 2.3398e-17, 2.9193e-07,
         2.7785e-18, 7.6345e-09, 9.9495e-01, 1.2260e-10],
        [4.5777e-01, 2.4315e-17, 1.3302e-12, 8.4825e-16, 1.0401e-16, 4.9526e-01,
         4.6974e-02, 1.2801e-14, 2.1598e-07, 5.8994e-14],
        [1.4407e-09, 1.0000e+00, 8.6316e-10, 5.2381e-15, 1.1834e-07, 4.9902e-13,
         5.0975e-10, 4.2416e-09, 2.6446e-11, 2.6281e-14],
        [2.8508e-22, 1.3414e-26, 5.8384e-30, 6.4460e-09, 2.1906e-25, 1.0000e+00,
         1.8827e-20, 1.0337e-23, 8.6967e-12, 1.8824e-15],
        [9.4588e-10, 3.5050e-12, 1.7856e-12, 4.9227e-13, 7.9482e-14, 7.6014e-08,
         9.0007e-14, 1.0000e+00, 2.8937e-11, 1.4883e-12],
        

tensor([[1.3757e-23, 7.8344e-21, 1.0115e-17, 1.0000e+00, 2.0530e-24, 4.8280e-10,
         2.0863e-26, 8.6835e-18, 1.8102e-13, 1.1024e-17],
        [8.2957e-16, 8.3781e-21, 2.9711e-17, 2.7532e-14, 1.4745e-08, 5.0695e-12,
         4.6172e-21, 4.5153e-10, 8.1524e-12, 1.0000e+00],
        [1.8558e-15, 2.9470e-20, 1.4656e-18, 5.5798e-21, 2.3007e-11, 8.3209e-11,
         1.0000e+00, 5.0159e-23, 2.4766e-17, 3.0273e-25],
        [4.3060e-12, 4.0233e-14, 9.1864e-15, 3.7325e-14, 4.2107e-10, 2.9101e-08,
         1.0000e+00, 1.9919e-19, 1.5593e-09, 2.4120e-20],
        [5.7282e-17, 3.2916e-25, 2.2172e-19, 5.3670e-17, 9.7681e-08, 4.0293e-13,
         6.7278e-22, 5.8919e-15, 2.8930e-15, 1.0000e+00],
        [4.3834e-15, 1.0694e-18, 4.2763e-20, 8.3239e-11, 2.6702e-13, 1.0000e+00,
         6.8664e-11, 1.0797e-16, 2.3185e-10, 1.7974e-08],
        [9.9999e-01, 4.9840e-11, 2.0647e-08, 5.3708e-13, 8.9656e-08, 2.6568e-09,
         7.8447e-06, 1.1081e-08, 2.0529e-10, 7.7293e-10],
        [3.1743e-07, 9.9996

tensor([[2.5155e-24, 3.7629e-22, 3.5667e-20, 1.0000e+00, 7.5591e-24, 9.3697e-09,
         3.1433e-26, 6.2696e-18, 7.7420e-16, 8.1642e-17],
        [2.2403e-18, 2.2181e-23, 1.4906e-18, 3.5078e-28, 1.0000e+00, 2.7957e-23,
         1.0261e-14, 7.4010e-21, 7.2525e-22, 3.8148e-20],
        [1.0663e-12, 2.0206e-10, 6.0980e-09, 8.8648e-12, 1.0000e+00, 2.7660e-12,
         4.8716e-07, 2.1185e-11, 4.7522e-12, 5.8277e-15],
        [2.6504e-29, 2.8042e-27, 1.5509e-27, 9.9992e-01, 4.8178e-29, 8.0783e-05,
         2.3640e-27, 7.5206e-23, 6.1236e-19, 7.7455e-20],
        [8.1695e-23, 6.9322e-18, 2.5935e-16, 1.0000e+00, 7.0150e-22, 8.5126e-10,
         1.7902e-22, 4.2086e-17, 1.5895e-14, 3.0705e-18],
        [7.6119e-15, 7.5972e-13, 1.0000e+00, 4.9368e-16, 7.0156e-25, 1.3255e-24,
         1.1573e-26, 1.9838e-12, 8.1399e-15, 6.6562e-25],
        [4.7137e-19, 4.8339e-17, 1.3726e-14, 7.8271e-09, 2.5189e-21, 5.5082e-11,
         7.0277e-18, 1.3927e-20, 1.0000e+00, 1.8093e-18],
        [2.6964e-25, 3.3708

Accuracy: 0.984375
tensor([[3.5367e-09, 5.6784e-11, 1.8412e-10, 1.0571e-13, 8.0781e-07, 7.7618e-08,
         1.0000e+00, 9.4216e-14, 4.7929e-10, 9.4231e-16],
        [2.0644e-14, 1.1021e-17, 6.5706e-14, 2.7889e-10, 3.9979e-21, 2.7202e-10,
         1.0882e-18, 4.3815e-17, 1.0000e+00, 6.8428e-14],
        [4.8195e-11, 9.9712e-01, 2.8750e-03, 1.0363e-13, 1.7708e-13, 2.6326e-14,
         8.5412e-09, 2.6444e-12, 2.0808e-07, 2.1458e-21],
        [4.4913e-09, 9.9999e-01, 8.3201e-07, 1.5467e-06, 1.8220e-07, 1.2108e-07,
         2.7574e-09, 9.3174e-06, 4.1387e-07, 7.0200e-08],
        [1.2860e-15, 1.2282e-13, 1.7343e-09, 2.4062e-11, 3.3745e-18, 7.1674e-12,
         6.8748e-14, 1.2317e-16, 1.0000e+00, 5.0898e-16],
        [2.2680e-15, 2.2706e-18, 1.8978e-17, 3.4516e-05, 5.0092e-19, 9.9997e-01,
         3.6792e-15, 2.2219e-15, 3.3707e-08, 5.9884e-10],
        [3.7732e-08, 1.4930e-09, 2.5473e-08, 1.8669e-07, 1.2754e-11, 1.3525e-06,
         1.0428e-08, 2.5424e-11, 1.0000e+00, 2.2253e-09],
        

tensor([[4.3006e-18, 9.4894e-24, 9.1117e-27, 1.8507e-12, 4.7021e-20, 1.0000e+00,
         7.8008e-15, 5.1672e-22, 4.3327e-14, 9.3234e-13],
        [2.5250e-11, 7.8847e-20, 6.6916e-15, 1.1236e-15, 5.0785e-05, 1.0265e-10,
         3.4143e-15, 1.2022e-10, 2.7087e-12, 9.9995e-01],
        [7.7754e-15, 3.0650e-14, 1.8228e-17, 5.6967e-09, 6.3140e-08, 1.0863e-07,
         1.2801e-17, 7.2569e-11, 5.0729e-08, 1.0000e+00],
        [4.4028e-13, 2.1646e-15, 1.0691e-13, 8.7375e-10, 9.3359e-06, 6.0219e-09,
         1.8225e-16, 2.3755e-08, 1.1505e-09, 9.9999e-01],
        [1.6745e-08, 9.9999e-01, 2.0091e-06, 7.1826e-08, 2.0162e-06, 1.1996e-07,
         1.6437e-07, 4.5794e-06, 8.7166e-07, 2.0368e-08],
        [5.2542e-27, 4.9109e-26, 3.5734e-28, 6.9955e-29, 1.0000e+00, 4.2720e-21,
         8.4753e-20, 1.0435e-19, 4.3935e-20, 8.3168e-18],
        [3.9477e-20, 2.6670e-18, 9.2072e-15, 1.0000e+00, 4.2879e-20, 2.1839e-09,
         8.2498e-24, 5.2562e-15, 8.9949e-13, 1.8239e-13],
        [5.9633e-11, 2.2659

tensor([[1.2973e-12, 2.2035e-14, 1.7555e-13, 8.2724e-09, 7.7866e-07, 4.2898e-08,
         4.3096e-16, 5.9700e-06, 3.1641e-08, 9.9999e-01],
        [3.6122e-08, 1.0000e+00, 1.4314e-08, 1.0877e-12, 7.5252e-08, 8.7169e-11,
         4.5740e-09, 1.1834e-08, 4.5285e-08, 2.9786e-13],
        [8.7269e-08, 4.3630e-07, 1.4176e-05, 2.7247e-05, 2.4909e-11, 1.8913e-06,
         1.2079e-08, 1.4394e-09, 9.9996e-01, 1.2153e-09],
        [1.4951e-14, 7.6491e-16, 2.8279e-12, 5.8903e-16, 1.5810e-17, 2.1931e-18,
         4.9652e-26, 1.0000e+00, 5.0103e-16, 1.6133e-11],
        [2.2517e-17, 5.6919e-12, 1.0995e-18, 4.6926e-14, 1.0000e+00, 5.2457e-12,
         2.6857e-12, 1.1293e-12, 3.0210e-13, 1.6048e-11],
        [3.5056e-22, 2.6158e-25, 9.2375e-28, 2.4189e-06, 1.8751e-25, 1.0000e+00,
         9.1818e-22, 2.0789e-21, 6.9129e-11, 7.3861e-15],
        [1.5591e-07, 6.1597e-01, 5.8933e-02, 5.4162e-05, 2.8452e-08, 7.2032e-09,
         1.6959e-11, 3.2483e-01, 2.1251e-04, 3.2838e-07],
        [1.0000e+00, 4.0202

tensor([[1.5670e-16, 7.4319e-21, 3.4881e-25, 6.5084e-12, 5.7160e-19, 1.0000e+00,
         2.0247e-12, 1.4233e-19, 1.6017e-13, 2.3199e-12],
        [2.0172e-19, 4.1521e-24, 4.6281e-21, 7.9289e-14, 3.3859e-11, 3.7844e-12,
         1.7686e-25, 1.4794e-12, 5.6289e-12, 1.0000e+00],
        [4.0290e-09, 1.0000e+00, 1.0382e-10, 4.2116e-16, 1.7268e-07, 2.5732e-13,
         2.5895e-09, 4.2105e-09, 6.0432e-11, 3.5856e-15],
        [2.1007e-10, 2.3208e-05, 4.6132e-05, 2.3032e-01, 2.4046e-06, 7.6386e-01,
         5.3797e-03, 1.4326e-11, 3.7440e-04, 2.3961e-12],
        [9.0801e-21, 3.0562e-26, 1.4194e-22, 1.8600e-25, 7.5386e-17, 3.8684e-12,
         1.0000e+00, 3.4412e-32, 8.4169e-16, 1.6262e-32],
        [1.0000e+00, 1.5073e-18, 3.2694e-13, 5.4583e-19, 2.7683e-13, 2.6318e-13,
         2.1407e-10, 5.8361e-12, 4.8437e-16, 8.5420e-15],
        [1.0629e-18, 8.6289e-16, 1.1565e-15, 1.2405e-23, 1.0000e+00, 2.2377e-22,
         2.8923e-14, 1.0912e-15, 6.0867e-19, 1.2216e-20],
        [4.8016e-19, 1.2637

tensor([[1.0314e-23, 9.6771e-29, 8.0559e-33, 5.2063e-11, 1.7798e-23, 1.0000e+00,
         7.9877e-21, 3.1937e-23, 7.9267e-16, 1.0564e-13],
        [1.7582e-22, 1.2878e-21, 1.3874e-23, 2.8527e-23, 1.0000e+00, 1.4739e-19,
         6.8949e-18, 9.6753e-19, 1.2217e-20, 5.8932e-16],
        [9.6771e-16, 1.1312e-11, 1.2397e-05, 7.4458e-05, 2.6728e-20, 1.8416e-11,
         1.8083e-16, 2.3015e-15, 9.9991e-01, 1.1708e-20],
        [8.7468e-13, 4.4708e-14, 5.1549e-09, 1.2874e-13, 2.6059e-16, 1.1871e-17,
         7.9462e-25, 1.0000e+00, 2.0960e-12, 1.3996e-10],
        [1.3968e-15, 1.5348e-14, 1.6508e-14, 9.9988e-01, 1.4414e-13, 1.2301e-04,
         4.5901e-15, 5.6083e-11, 4.3439e-13, 2.4861e-09],
        [4.8386e-19, 1.4349e-17, 7.3219e-13, 1.0000e+00, 6.0216e-22, 1.1467e-09,
         2.3770e-23, 5.2156e-13, 4.4412e-12, 5.0877e-16],
        [4.8235e-15, 2.4934e-15, 1.0000e+00, 3.4257e-16, 2.0852e-26, 2.0272e-25,
         2.4510e-28, 1.1756e-11, 1.7700e-16, 7.7409e-27],
        [1.9758e-25, 8.4616

tensor([[6.7736e-31, 4.5566e-24, 3.1120e-19, 1.0000e+00, 3.7302e-31, 1.2499e-16,
         4.7747e-36, 2.2223e-21, 5.1581e-16, 5.0956e-25],
        [2.2891e-11, 3.0515e-10, 1.5427e-12, 1.4171e-05, 1.2453e-05, 1.0698e-05,
         1.7127e-13, 1.1218e-07, 2.3514e-06, 9.9996e-01],
        [1.2411e-12, 1.3889e-11, 1.0000e+00, 7.9549e-14, 2.4077e-19, 2.5768e-19,
         3.1883e-19, 2.4110e-12, 3.8105e-10, 5.8783e-20],
        [5.9871e-08, 1.0000e+00, 1.3246e-07, 2.0711e-14, 6.8349e-08, 2.6922e-14,
         2.6455e-10, 3.6147e-08, 7.9945e-10, 5.1371e-14],
        [5.9536e-20, 2.8969e-28, 6.2351e-32, 3.1542e-15, 6.8877e-26, 1.0000e+00,
         3.6127e-18, 6.8486e-24, 2.3202e-13, 1.2667e-16],
        [1.0012e-15, 2.4077e-23, 5.7801e-15, 2.1401e-15, 1.0725e-08, 2.3507e-13,
         1.1878e-21, 1.9511e-13, 9.5634e-13, 1.0000e+00],
        [7.1141e-16, 4.7039e-18, 1.6936e-18, 4.9499e-09, 3.3787e-09, 2.8862e-08,
         3.8467e-20, 1.1972e-11, 2.0684e-10, 1.0000e+00],
        [7.9649e-19, 3.8656

tensor([[4.3815e-14, 1.0116e-16, 1.7524e-14, 8.0873e-08, 2.3808e-15, 3.4397e-08,
         9.4141e-20, 7.7233e-13, 9.9999e-01, 9.4457e-06],
        [2.5812e-19, 8.0223e-19, 1.1793e-14, 1.6941e-15, 9.6535e-24, 5.7653e-21,
         9.2821e-33, 1.0000e+00, 1.6829e-20, 6.8351e-15],
        [3.1875e-15, 2.2360e-23, 1.0954e-19, 2.2626e-14, 3.2435e-18, 1.0000e+00,
         8.7907e-10, 5.2197e-22, 8.9480e-08, 6.5789e-13],
        [1.2331e-14, 8.3661e-25, 3.5939e-22, 1.6266e-24, 2.5049e-17, 8.8702e-11,
         1.0000e+00, 2.8707e-29, 2.2675e-11, 6.8970e-25],
        [4.3640e-15, 9.8814e-22, 1.0000e+00, 2.7080e-16, 3.2221e-22, 2.0039e-25,
         1.7159e-26, 4.1147e-17, 1.1496e-20, 9.4625e-25],
        [1.2537e-26, 2.1624e-24, 1.4414e-15, 1.0000e+00, 3.9866e-29, 4.9029e-16,
         2.7254e-32, 5.6640e-20, 7.8398e-17, 2.6500e-23],
        [1.9825e-12, 4.1032e-15, 9.9955e-01, 4.2376e-04, 2.4980e-19, 1.1752e-09,
         1.0707e-20, 3.2733e-07, 2.1581e-05, 1.1402e-13],
        [3.0227e-20, 1.0839

tensor([[3.6205e-11, 2.7227e-12, 1.7970e-11, 1.7625e-08, 7.3718e-15, 2.2393e-07,
         5.7970e-13, 9.1332e-13, 1.0000e+00, 3.0147e-10],
        [1.8148e-11, 1.1602e-16, 9.0699e-17, 2.1030e-16, 3.7440e-12, 1.4239e-07,
         1.0000e+00, 5.8968e-22, 9.6317e-11, 8.5596e-22],
        [1.0000e+00, 2.0771e-24, 6.2176e-16, 5.7806e-22, 2.0290e-15, 4.8271e-09,
         2.6303e-09, 1.8314e-14, 7.2898e-18, 8.9836e-12],
        [5.9118e-19, 1.1103e-23, 6.4928e-28, 2.2673e-13, 1.4861e-22, 1.0000e+00,
         8.5837e-17, 3.7688e-22, 2.0578e-16, 1.5503e-13],
        [1.6606e-08, 1.0000e+00, 1.6566e-09, 1.4507e-15, 1.8496e-07, 1.4362e-13,
         1.4676e-09, 9.2104e-09, 2.3340e-10, 1.4951e-14],
        [1.1555e-08, 9.9978e-01, 8.6908e-08, 7.8033e-09, 3.2580e-09, 2.2479e-07,
         2.0585e-06, 1.0307e-09, 2.1793e-04, 6.6531e-12],
        [8.2585e-15, 2.8277e-18, 6.4319e-16, 9.6804e-01, 5.8835e-21, 3.1957e-02,
         2.0908e-19, 2.0872e-11, 2.1834e-13, 7.9387e-11],
        [6.3039e-12, 3.2137

Accuracy: 1.0
tensor([[3.1800e-17, 6.9778e-22, 4.6567e-19, 6.7280e-15, 2.5692e-07, 1.9184e-12,
         3.6509e-22, 1.8576e-11, 4.4673e-13, 1.0000e+00],
        [1.7778e-10, 6.4576e-12, 1.0000e+00, 1.5294e-09, 6.6315e-13, 2.4283e-15,
         1.3519e-14, 3.4595e-11, 3.3828e-06, 5.4137e-14],
        [1.8928e-10, 2.2205e-07, 6.2529e-10, 6.1596e-07, 3.7742e-08, 1.6049e-07,
         1.6808e-13, 9.9998e-01, 6.0939e-09, 1.9827e-05],
        [3.6163e-18, 1.3216e-17, 9.1803e-14, 3.1289e-18, 3.9644e-22, 1.2264e-20,
         5.4846e-29, 1.0000e+00, 1.1733e-18, 2.8023e-16],
        [7.9140e-12, 3.8651e-12, 1.0000e+00, 6.3900e-13, 1.1921e-15, 9.1316e-19,
         1.9018e-15, 1.4133e-12, 3.4276e-10, 4.3863e-21],
        [3.9126e-16, 4.2112e-19, 2.3518e-19, 1.9892e-19, 1.8609e-17, 1.1751e-15,
         1.8680e-23, 1.0000e+00, 1.0460e-19, 4.0026e-13],
        [4.9101e-28, 4.4199e-27, 6.0631e-33, 2.4359e-06, 4.8334e-26, 1.0000e+00,
         2.0132e-21, 1.6729e-26, 5.5111e-15, 1.5006e-19],
        [1.00

tensor([[2.0045e-31, 5.7721e-35, 4.8049e-41, 3.8930e-17, 7.3036e-31, 1.0000e+00,
         1.2549e-23, 9.2519e-34, 3.7957e-19, 9.4855e-21],
        [5.8827e-09, 1.5917e-10, 9.8605e-01, 5.1220e-06, 6.4112e-11, 5.2220e-10,
         4.9002e-15, 1.3944e-02, 1.6158e-09, 2.3504e-10],
        [5.4655e-22, 1.5537e-24, 1.4365e-21, 4.9813e-27, 1.0000e+00, 4.7511e-21,
         1.6526e-17, 3.5373e-20, 3.4725e-21, 5.7817e-16],
        [1.0000e+00, 1.9242e-16, 5.4703e-10, 2.3600e-15, 1.5561e-13, 3.7253e-10,
         4.2337e-11, 1.1625e-08, 2.7244e-11, 2.8581e-11],
        [5.1202e-21, 4.7018e-18, 1.0000e+00, 1.2015e-14, 1.1558e-28, 3.2080e-27,
         1.0040e-30, 9.2622e-16, 5.3629e-18, 1.2785e-30],
        [4.0742e-08, 2.3204e-14, 2.5309e-09, 2.1040e-16, 1.0000e+00, 3.0835e-15,
         1.6146e-07, 4.3100e-11, 3.1262e-17, 1.8524e-13],
        [2.0858e-06, 9.9964e-01, 2.7611e-04, 3.2218e-06, 7.2753e-06, 1.8211e-06,
         9.3899e-06, 4.8734e-05, 1.3888e-05, 7.7949e-08],
        [3.2894e-18, 8.4567

tensor([[2.0147e-23, 9.8141e-26, 1.7807e-22, 5.8651e-28, 1.0000e+00, 3.7311e-23,
         7.5194e-20, 4.8265e-19, 2.3951e-24, 2.5606e-17],
        [3.2659e-21, 6.5491e-25, 5.3660e-29, 9.2284e-13, 1.5857e-20, 1.0000e+00,
         3.6198e-17, 7.9047e-22, 1.2343e-14, 1.7683e-11],
        [2.7388e-19, 1.5809e-21, 2.8416e-27, 1.5035e-10, 4.1015e-19, 1.0000e+00,
         7.2890e-15, 1.6917e-20, 1.8439e-16, 1.1428e-13],
        [5.5752e-10, 1.0000e+00, 1.0090e-09, 3.8007e-13, 1.2522e-06, 4.3438e-12,
         3.8889e-09, 2.3689e-07, 1.0426e-10, 4.0672e-13],
        [2.6403e-28, 5.7725e-29, 1.6015e-29, 1.8575e-13, 2.3353e-27, 1.0000e+00,
         7.2277e-16, 1.9945e-32, 3.1030e-12, 3.9492e-25],
        [2.6431e-13, 3.2836e-17, 2.6195e-14, 1.5944e-11, 4.7835e-06, 2.2754e-10,
         1.0899e-17, 5.6516e-10, 5.2061e-11, 1.0000e+00],
        [7.7920e-14, 7.4216e-16, 2.6356e-11, 1.2768e-09, 4.4694e-20, 2.2935e-10,
         2.6689e-17, 1.9058e-16, 1.0000e+00, 3.5640e-12],
        [1.4535e-14, 3.8858

tensor([[1.0000e+00, 1.8377e-14, 8.9758e-09, 3.9864e-15, 8.7915e-12, 5.3064e-11,
         3.7470e-10, 9.0147e-08, 7.7125e-12, 9.9962e-12],
        [2.2428e-11, 1.6323e-22, 2.5764e-15, 2.0288e-20, 6.9505e-20, 6.3449e-05,
         9.9994e-01, 3.2217e-26, 1.1869e-09, 2.1027e-22],
        [2.3688e-08, 7.1063e-10, 1.1748e-01, 6.1460e-06, 2.5065e-16, 8.8419e-13,
         6.8605e-22, 8.8250e-01, 1.3901e-05, 7.9589e-10],
        [7.4492e-24, 2.3954e-31, 2.2261e-30, 2.5417e-16, 2.2534e-29, 1.0000e+00,
         1.3127e-21, 1.0746e-26, 6.5187e-14, 3.2404e-16],
        [3.5815e-11, 1.5655e-17, 5.1979e-11, 9.7821e-14, 1.2588e-16, 2.2216e-11,
         7.8325e-16, 1.4055e-15, 1.0000e+00, 1.6928e-11],
        [1.0000e+00, 1.9206e-23, 1.5500e-15, 1.0151e-22, 9.2040e-18, 5.5336e-14,
         2.1770e-12, 5.0355e-14, 2.2716e-20, 1.7311e-16],
        [3.4488e-24, 4.4996e-20, 3.2188e-17, 1.0000e+00, 3.2973e-26, 8.2781e-10,
         1.1574e-26, 1.1420e-18, 1.8075e-10, 3.6928e-19],
        [3.8262e-19, 2.2711

tensor([[1.2929e-16, 8.6222e-13, 4.4434e-16, 4.5270e-19, 1.0000e+00, 9.7933e-18,
         4.9587e-14, 4.9098e-12, 4.4937e-12, 1.6384e-13],
        [1.5875e-17, 8.8672e-19, 9.1043e-19, 7.7761e-23, 1.0000e+00, 8.5759e-18,
         1.3340e-14, 2.8808e-13, 1.1326e-13, 2.1871e-10],
        [4.4959e-12, 4.0776e-10, 1.4866e-08, 6.6761e-06, 2.8667e-13, 1.6129e-07,
         1.7461e-11, 2.9777e-12, 9.9999e-01, 1.1779e-11],
        [1.0372e-11, 4.8917e-17, 1.1789e-13, 6.3287e-06, 1.0583e-15, 1.3032e-05,
         1.9581e-14, 4.1093e-16, 9.9998e-01, 2.5041e-08],
        [7.2038e-18, 1.0630e-17, 1.0000e+00, 1.9476e-09, 4.6823e-29, 8.6753e-21,
         2.4811e-28, 2.3129e-16, 9.3865e-12, 1.6699e-27],
        [3.3309e-12, 3.4540e-16, 1.4116e-13, 1.2245e-09, 4.1196e-06, 7.9301e-08,
         9.1351e-16, 2.3732e-08, 2.9677e-10, 1.0000e+00],
        [3.7093e-21, 3.7900e-25, 1.0774e-25, 2.7289e-07, 3.9789e-24, 1.0000e+00,
         1.6105e-19, 4.9545e-22, 1.4694e-11, 2.0901e-12],
        [1.1895e-27, 5.1440

In [None]:
with torch.no_grad(): # torch.no_grad()를 하면 gradient 계산을 수행하지 않는다.

    # MNIST 테스트 데이터에서 무작위로 하나를 뽑아서 예측을 해본다
    r = random.randint(0, len(val_loader) - 1)
    X_single_data = val_loader[r:r + 1][0].view(-1, 28 * 28).float().to(device)
    Y_single_data = val_loader[r:r + 1][1]

    print('Label: ', Y_single_data.item())
    single_prediction = cnn(X_single_data)
    print('Prediction: ', torch.argmax(single_prediction, 1).item())

    plt.imshow(x_test[r:r + 1].view(28, 28), cmap='Greys', interpolation='nearest')
    plt.show()