# Structure 이해

In [1]:
import torch
import torch.nn as nn

In [2]:
inputs = torch.Tensor(1, 1, 28, 28)   # batch size, channel, height, width
print(f"tensor size: {inputs.shape}")

tensor size: torch.Size([1, 1, 28, 28])


In [3]:
conv1 = nn.Conv2d(1, 32, 3, padding=1) # 1channel을 입력 받아 32channel을 뽑아냄.
                                      # kernel size = 3, padding = 1
print(conv1)

Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [4]:
conv2 = nn.Conv2d(32, 64, 3, padding=1)
print(conv2)

Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [5]:
pool = nn.MaxPool2d(2)   # 정수 하나를 인자로 넣으면 kernel size, stride가 둘 다 해당값으로 지정됨
print(pool)

MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)


In [6]:
out = conv1(inputs)
print(out.shape)

torch.Size([1, 32, 28, 28])


In [7]:
out = pool(out)
print(out.shape)

torch.Size([1, 32, 14, 14])


In [8]:
out = conv2(out)
print(out.shape)

torch.Size([1, 64, 14, 14])


In [9]:
out = pool(out)
print(out.shape)

torch.Size([1, 64, 7, 7])


## view()로 tensor를 펼침

In [10]:
out = out.view(out.size(0), -1)
print(out.shape)

torch.Size([1, 3136])


In [11]:
#input_dim = 3136, output_dim = 10
fc = nn.Linear(out.shape[1], 10)
out = fc(out)
print(out.shape)

torch.Size([1, 10])


# CNN으로 MNIST 분류하기

In [12]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
from tqdm import tqdm_notebook

In [13]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

# torch.manual_seed(77)

# if device == 'cuda':
#     torch.cuda.manual_seed_all(77)
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(77)
if device == 'cuda':
    torch.cuda.manual_seed_all(77)
print(f"uning: {device}")



uning: cuda


In [14]:
learning_rate = 1e-3
training_epochs = 200
batch_size = 100

In [15]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import os

def get_alphabet(root: str, batch_size: int):
    
    train_path = os.path.join(root, 'train')
    test_path = os.path.join(root, 'test')
    
    alphabet_train1 = ImageFolder(root = train_path,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
#                                      transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
                                     transforms.Grayscale(1),
                                     transforms.RandomRotation(5),
                                     transforms.RandomInvert()
                                 ]),
                                 target_transform=None)
    
    alphabet_train2 = ImageFolder(root = train_path,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
#                                     transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
                                     transforms.Grayscale(1)
                                 ]),
                                 target_transform=None)
    
    alphabet_train3 = ImageFolder(root = train_path,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
#                                     transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
                                     transforms.Grayscale(1),
                                     transforms.CenterCrop(20),
                                     transforms.Resize(28)
                                 ]),
                                 target_transform=None)
    
    
    
    alphabet_test = ImageFolder(root = test_path,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
#                                      transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
                                     transforms.Grayscale(1)
                                 ]),
                                 target_transform=None)
    
    train_loader = DataLoader(alphabet_train1,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=8)

    test_loader = DataLoader(alphabet_test,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False,
                             num_workers=8) 
    
    return (train_loader, test_loader)

dset_root = './data'
train_loader, test_loader = get_alphabet(root = dset_root, batch_size = batch_size)

In [16]:
# mnist_train = dsets.MNIST(root = 'MNIST_data/',
#                           train=True, # train data로 download
#                           transform=transforms.ToTensor(),
#                           download=True)

# mnist_test = dsets.MNIST(root = 'MNIST_data/',
#                           train=False, # train data로 download
#                           transform=transforms.ToTensor(),
#                           download=True)

# data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
#                                           batch_size=batch_size,
#                                           shuffle=True,
#                                           drop_last=True)


## Class model 설계

In [17]:
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Layer 1
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        
        # Layer 2 
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        
        # 전결합층 7x7x64 inputs -> 10 outputs
        self.fc = torch.nn.Linear(7*7*64, 26, bias=True)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1) # flatten for 전결합층
        out = self.fc(out)
        return out

In [18]:
# class CNN(torch.nn.Module):
#     def __init__(self, input_size, num_classes):
#         super(CNN, self).__init__()
#         # Layer 1
#         self.layer1 = torch.nn.Sequential(
#             torch.nn.Conv2d(input_size[0], 32, kernel_size=5),
#             torch.nn.ReLU(),
#             torch.nn.MaxPool2d(kernel_size=2))
        
#         # Layer 2 
#         self.layer2 = torch.nn.Sequential(
#             torch.nn.Conv2d(32, 64, kernel_size=5),
#             torch.nn.ReLU(),
#             torch.nn.MaxPool2d(kernel_size=2))
        
#         # 전결합층 7x7x64 inputs -> 10 outputs
#         self.fc = torch.nn.Linear(4*4*64, num_classes)
        
#     def forward(self, x):
#         out = self.layer1(x)
#         out = self.layer2(out)
#         out = out.view(out.size(0), -1) # flatten for 전결합층
#         out = self.fc(out)
#         return out

In [19]:
model = CNN().to(device)

# criterion = torch.nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr= learning_rate)

# total_batch = len(data_loader)
# print(f"총 배치의 수: {total_batch}")
# model = CNN((1, 28, 28), 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), learning_rate)
criterion = torch.nn.CrossEntropyLoss().to(device)  # loss function
# train_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=64, shuffle=True)
# test_loader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=64, shuffle=True)


In [20]:
# for epoch in range(training_epochs):
#     avg_cost = 0

#     for X, Y in data_loader: # 미니 배치 단위로 꺼내온다. X는 미니 배치, Y느 ㄴ레이블.
#         # image is already size of (28x28), no reshape
#         # label is not one-hot encoded
#         X = X.to(device)
#         Y = Y.to(device)

#         optimizer.zero_grad()
#         hypothesis = model(X)
#         cost = criterion(hypothesis, Y)
#         cost.backward()
#         optimizer.step()

#         avg_cost += cost / total_batch

#     print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))
import numpy as np

for epoch in range(training_epochs):
    train_loss = []
    
    for i, (data, labels) in tqdm_notebook(enumerate(train_loader), total=len(train_loader)):
        # pass data through network
        data = data.to(device)
        labels = labels.to(device)
        
        outputs = model(data)
        
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        
    test_loss = []
    test_accuracy = []
    
    for i, (data, labels) in enumerate(test_loader):
        data = data.to(device)
        labels = labels.to(device)
        
        # pass data through network
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        loss = criterion(outputs, labels)
        test_loss.append(loss.item())
        test_accuracy.append((predicted == labels).sum().item() / predicted.size(0))
    print('epoch: {}, train loss: {}, test loss: {}, test accuracy: {}'.format(epoch, np.mean(train_loss), np.mean(test_loss), np.mean(test_accuracy)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, (data, labels) in tqdm_notebook(enumerate(train_loader), total=len(train_loader)):


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 0, train loss: 2.278970491141081, test loss: 1.4254045373988602, test accuracy: 0.610048094709582


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 1, train loss: 1.182452634908259, test loss: 1.0180785869652371, test accuracy: 0.7383388827229004


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 2, train loss: 0.8821254253853112, test loss: 0.8152954229768717, test accuracy: 0.7891009988901221


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 3, train loss: 0.7206474585691467, test loss: 0.6980510716730693, test accuracy: 0.8213725490196078


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 4, train loss: 0.6206999785499647, test loss: 0.5970687354510685, test accuracy: 0.8464520902700704


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 5, train loss: 0.5385069668991491, test loss: 0.5787141415870415, test accuracy: 0.8519607843137256


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 6, train loss: 0.48906874092062935, test loss: 0.49728849783258616, test accuracy: 0.8749352571217167


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 7, train loss: 0.44474262348376215, test loss: 0.5038096733250708, test accuracy: 0.8670033296337403


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 8, train loss: 0.40517922438448295, test loss: 0.44090539723072414, test accuracy: 0.889652238253792


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 9, train loss: 0.38055717351380736, test loss: 0.5100807856838658, test accuracy: 0.8549500554938957


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 10, train loss: 0.3584738585050218, test loss: 0.4586713489074752, test accuracy: 0.878128005919349


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 11, train loss: 0.33597224950790405, test loss: 0.3852908730085166, test accuracy: 0.9073954864964853


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 12, train loss: 0.3102787448442541, test loss: 0.3915001520851873, test accuracy: 0.8998409174990751


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 13, train loss: 0.2938103074848186, test loss: 0.38887585783904455, test accuracy: 0.9022715501294856


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 14, train loss: 0.28042518271831796, test loss: 0.35772002741413295, test accuracy: 0.9137883832778393


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 15, train loss: 0.25938807043712586, test loss: 0.3378192539203842, test accuracy: 0.9198335183129854


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 16, train loss: 0.2519422925252002, test loss: 0.3440965657245438, test accuracy: 0.9151239363669997


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 17, train loss: 0.23510391332092695, test loss: 0.32222004487829387, test accuracy: 0.9228745837957824


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 18, train loss: 0.22271656271186657, test loss: 0.33806686054142016, test accuracy: 0.9175767665556788


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 19, train loss: 0.2139279845287092, test loss: 0.31671599158138597, test accuracy: 0.9262634110247872


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 20, train loss: 0.1987036254722625, test loss: 0.3078342544020347, test accuracy: 0.9322715501294856


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 21, train loss: 0.19640080991666764, test loss: 0.2984348773112837, test accuracy: 0.9317277099519055


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 22, train loss: 0.18518458974722307, test loss: 0.2973369349778261, test accuracy: 0.9309655937846836


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 23, train loss: 0.18263344866863918, test loss: 0.31188835966277795, test accuracy: 0.923026267110618


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 24, train loss: 0.1760610135679599, test loss: 0.28598873795203444, test accuracy: 0.932090270070292


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 25, train loss: 0.1615736328822095, test loss: 0.2725267971883405, test accuracy: 0.939071402145764


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 26, train loss: 0.15500338685524184, test loss: 0.28380776238891314, test accuracy: 0.9347317795042543


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 27, train loss: 0.15227234446501825, test loss: 0.2891102029748683, test accuracy: 0.9343470218275989


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 28, train loss: 0.14462634567462374, test loss: 0.2747350694998255, test accuracy: 0.9388901220865703


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 29, train loss: 0.13591439322044607, test loss: 0.2839995421258346, test accuracy: 0.9377432482426932


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 30, train loss: 0.13584069831995293, test loss: 0.26838324992161877, test accuracy: 0.9413281539030706


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 31, train loss: 0.13325626259029377, test loss: 0.27194586775775226, test accuracy: 0.9401960784313725


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 32, train loss: 0.1284206980635645, test loss: 0.2857752760766812, test accuracy: 0.9341657417684054


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 33, train loss: 0.12313131084374618, test loss: 0.3158559098066586, test accuracy: 0.9235849056603773


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 34, train loss: 0.11409244295646204, test loss: 0.255696672137897, test accuracy: 0.938867924528302


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 35, train loss: 0.10936388403933961, test loss: 0.26825708994325603, test accuracy: 0.9401960784313727


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 36, train loss: 0.10598026590741938, test loss: 0.26435408375735553, test accuracy: 0.9426563078061413


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 37, train loss: 0.11084683326043887, test loss: 0.25032814831103917, test accuracy: 0.9434036256011838


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 38, train loss: 0.10469670985185076, test loss: 0.2515784175539354, test accuracy: 0.9449352571217167


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 39, train loss: 0.09415685747080715, test loss: 0.2705655502120279, test accuracy: 0.9422863485016646


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 40, train loss: 0.09778907210420584, test loss: 0.2670445556769956, test accuracy: 0.9411394746577876


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 41, train loss: 0.09532514273450943, test loss: 0.255988754955876, test accuracy: 0.9479319274879761


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 42, train loss: 0.09120835617068224, test loss: 0.24971406249645747, test accuracy: 0.9479245283018867


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 43, train loss: 0.08623306801382569, test loss: 0.24810741107278275, test accuracy: 0.9500221975582686


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 44, train loss: 0.08706604799226625, test loss: 0.2655652918202697, test accuracy: 0.9460599334073251


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 45, train loss: 0.08602234622958349, test loss: 0.260386068865938, test accuracy: 0.9460599334073252


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 46, train loss: 0.07414791249175323, test loss: 0.2890315780697285, test accuracy: 0.9377506474287828


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 47, train loss: 0.07914447903749533, test loss: 0.24338041880291025, test accuracy: 0.94944876063633


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 48, train loss: 0.07781841953692492, test loss: 0.23377398854859596, test accuracy: 0.9522863485016649


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 49, train loss: 0.07361426027637208, test loss: 0.2403121418934667, test accuracy: 0.949826119126896


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 50, train loss: 0.07578586288218503, test loss: 0.23715195106342435, test accuracy: 0.9515094339622642


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 51, train loss: 0.06682936704601161, test loss: 0.2619341637078181, test accuracy: 0.9492526822049574


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 52, train loss: 0.06856890411654604, test loss: 0.24664999386471398, test accuracy: 0.9503921568627451


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 53, train loss: 0.07321789921479649, test loss: 0.30528669065756303, test accuracy: 0.9386940436551978


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 54, train loss: 0.06776748918491649, test loss: 0.2565840860512459, test accuracy: 0.9483166851646319


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 55, train loss: 0.06697519429508247, test loss: 0.2649515301410882, test accuracy: 0.9485127635960043


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 56, train loss: 0.06643646663360414, test loss: 0.2345760462559619, test accuracy: 0.953795782463929


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 57, train loss: 0.05723041088276659, test loss: 0.2377039070926466, test accuracy: 0.9530262671106178


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 58, train loss: 0.05933354583248729, test loss: 0.22761121021478242, test accuracy: 0.9556603773584906


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 59, train loss: 0.053782315295393346, test loss: 0.24414812588958806, test accuracy: 0.9554716981132075


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 60, train loss: 0.056419474862195784, test loss: 0.24567872475621835, test accuracy: 0.9543544210136886


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 61, train loss: 0.061067968446877785, test loss: 0.24454989535558336, test accuracy: 0.9562412134665188


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 62, train loss: 0.05498525138682453, test loss: 0.23599742677168184, test accuracy: 0.9571846096929337


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 63, train loss: 0.05421490788467054, test loss: 0.235870645486943, test accuracy: 0.9562412134665187


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 64, train loss: 0.051218954195064725, test loss: 0.23685703704239064, test accuracy: 0.9560525342212357


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 65, train loss: 0.04850316252668563, test loss: 0.24096870715158797, test accuracy: 0.9588679245283019


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 66, train loss: 0.04873297709855251, test loss: 0.2503449165104133, test accuracy: 0.9560377358490565


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 67, train loss: 0.05099702746701951, test loss: 0.2586511461380518, test accuracy: 0.9530262671106178


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 68, train loss: 0.05045957354559505, test loss: 0.2679034973371704, test accuracy: 0.9503921568627451


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 69, train loss: 0.04972682412517315, test loss: 0.2335468009070335, test accuracy: 0.9594413614502405


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 70, train loss: 0.04466723826226371, test loss: 0.2636502288844226, test accuracy: 0.9539770625231223


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 71, train loss: 0.04687323447978997, test loss: 0.24321112054277141, test accuracy: 0.9556751757306697


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 72, train loss: 0.042611193684933824, test loss: 0.23836996323727774, test accuracy: 0.9588679245283017


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 73, train loss: 0.04861673474806594, test loss: 0.2523574744316064, test accuracy: 0.9539770625231223


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 74, train loss: 0.04443658873424283, test loss: 0.23363704679814992, test accuracy: 0.96


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 75, train loss: 0.04328520911803935, test loss: 0.2400641521161097, test accuracy: 0.9575471698113206


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 76, train loss: 0.0412435258149344, test loss: 0.2589440975021922, test accuracy: 0.9535997040325563


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 77, train loss: 0.04882350947991654, test loss: 0.2285222671991917, test accuracy: 0.958301886792453


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 78, train loss: 0.0436549421356176, test loss: 0.2533945152404244, test accuracy: 0.9530188679245284


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 79, train loss: 0.04441111107735196, test loss: 0.2564605785850084, test accuracy: 0.9543396226415093


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 80, train loss: 0.04119394401823229, test loss: 0.22196319449763252, test accuracy: 0.960377358490566


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 81, train loss: 0.04654137351008103, test loss: 0.24490932235494256, test accuracy: 0.9552830188679244


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 82, train loss: 0.036659837577190046, test loss: 0.25201003540873984, test accuracy: 0.9539622641509434


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 83, train loss: 0.03918722451680878, test loss: 0.23477827210005653, test accuracy: 0.9581132075471697


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 84, train loss: 0.040169334447455185, test loss: 0.26291549931069447, test accuracy: 0.9530188679245284


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 85, train loss: 0.04225068063897197, test loss: 0.25963062559526834, test accuracy: 0.9543544210136887


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 86, train loss: 0.03595276192845631, test loss: 0.22378292022589244, test accuracy: 0.9601960784313726


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 87, train loss: 0.036754316616679716, test loss: 0.2498211286224003, test accuracy: 0.959811320754717


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 88, train loss: 0.040835780730958504, test loss: 0.252263536606476, test accuracy: 0.9575545689974103


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 89, train loss: 0.035597245558165014, test loss: 0.2433133329427362, test accuracy: 0.96188679245283


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 90, train loss: 0.03436433027309249, test loss: 0.2665002911334049, test accuracy: 0.9516981132075472


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 91, train loss: 0.03603512511745066, test loss: 0.23858323905019546, test accuracy: 0.9571698113207547


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 92, train loss: 0.03546076597922365, test loss: 0.25983335936600166, test accuracy: 0.9549204587495376


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 93, train loss: 0.03309651208110154, test loss: 0.23620776362049412, test accuracy: 0.9583166851646319


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 94, train loss: 0.04280213437232305, test loss: 0.25347849472320444, test accuracy: 0.9571846096929337


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 95, train loss: 0.031853359701017325, test loss: 0.23890311247229856, test accuracy: 0.9613281539030706


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 96, train loss: 0.0294314913985545, test loss: 0.23512680443741804, test accuracy: 0.9624676285608582


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 97, train loss: 0.029641869949955435, test loss: 0.25710077104949447, test accuracy: 0.9575545689974103


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 98, train loss: 0.035585934120717866, test loss: 0.23397334394449332, test accuracy: 0.9626489086200517


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 99, train loss: 0.031636967179565545, test loss: 0.25235935838515255, test accuracy: 0.9549204587495377


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 100, train loss: 0.03352320366184358, test loss: 0.23913566137165412, test accuracy: 0.9620976692563816


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 101, train loss: 0.03193865373168592, test loss: 0.25573357716334527, test accuracy: 0.9584979652238255


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 102, train loss: 0.0335919507433573, test loss: 0.2615337030443732, test accuracy: 0.9547243803181649


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 103, train loss: 0.027183026570128277, test loss: 0.23819247128898804, test accuracy: 0.96188679245283


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 104, train loss: 0.028775879104159685, test loss: 0.24459379654869717, test accuracy: 0.960188679245283


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 105, train loss: 0.03489092453492049, test loss: 0.2509597185120549, test accuracy: 0.9586866444691081


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 106, train loss: 0.0330670990915678, test loss: 0.24908026574038952, test accuracy: 0.9575471698113207


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 107, train loss: 0.03454550588321581, test loss: 0.2653472497532109, test accuracy: 0.9547317795042546


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 108, train loss: 0.029668063577446446, test loss: 0.24540357832919876, test accuracy: 0.9605734369219386


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 109, train loss: 0.0326647682904877, test loss: 0.2683048948946514, test accuracy: 0.9552830188679244


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 110, train loss: 0.028535244183785835, test loss: 0.2508211492841958, test accuracy: 0.9590788013318534


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 111, train loss: 0.028454997122480563, test loss: 0.26462147605302305, test accuracy: 0.9573806881243062


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 112, train loss: 0.02756647989690464, test loss: 0.23681906697270022, test accuracy: 0.9641583425823157


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 113, train loss: 0.02781578434542098, test loss: 0.24477325680331802, test accuracy: 0.9596374398816131


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 114, train loss: 0.028944750660230056, test loss: 0.2558660987052926, test accuracy: 0.9571772105068442


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 115, train loss: 0.02905341434006914, test loss: 0.26198795064405167, test accuracy: 0.9587014428412873


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 116, train loss: 0.03165708962160352, test loss: 0.25533206301235006, test accuracy: 0.9577358490566037


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 117, train loss: 0.028736694962844922, test loss: 0.22881469061106163, test accuracy: 0.9609507954125045


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 118, train loss: 0.027768417053493977, test loss: 0.2454965624448685, test accuracy: 0.9590566037735849


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 119, train loss: 0.030102157522833295, test loss: 0.24959837661107193, test accuracy: 0.9575545689974101


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 120, train loss: 0.02792766750053488, test loss: 0.2420336155062717, test accuracy: 0.9590640029596744


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 121, train loss: 0.021240493448431153, test loss: 0.28042693875009583, test accuracy: 0.9558564557898629


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 122, train loss: 0.029241852227642084, test loss: 0.26594627380617103, test accuracy: 0.9560599334073251


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 123, train loss: 0.02903398894136444, test loss: 0.26386780859375336, test accuracy: 0.9568220495745466


  0%|          | 0/256 [00:00<?, ?it/s]

epoch: 124, train loss: 0.025635577986804492, test loss: 0.23690119696745895, test accuracy: 0.9620902700702925


  0%|          | 0/256 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
# 학습을 진행하지 않을 것이므로 torch.no_grad()
with torch.no_grad():
#     X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
#     Y_test = mnist_test.test_labels.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())

NameError: name 'mnist_test' is not defined