In [1]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import numpy as np

import torchvision
import torchvision.transforms as transforms

# 5. Optim & Criterion

- Loss (손실 계산)
- Update (parameter 업데이트) => torch.optim.

- 순서

    - Import
    - Dataset 만들기
    - Model 만들기
    - Optim, Loss 함수 결정하기
        - Optimizer : parameter를 넣어주어야 함. model 클래스의 model.parameters()
    - 학습을 위한 반복문 작성
    - 모델 저장

In [2]:
class my_network(nn.Module):
    
    def __init__(self):
        
        # 무조건 써야 함.
        super(my_network, self).__init__()
        
        # 사용할 함수들을 정의
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 30, 5)
        
        self.fc1 = nn.Linear(30*5*5, 128)
        self.fc2 = nn.Linear(128, 10)
        
    # input이 여러개 일 경우도 있음. x, y, 등등..
    def forward(self, x):
        
        # Network의 forward를 정의하는 장소
        x = F.relu(self.conv1(x), inplace=True)
        x = F.max_pool2d(x, (2,2))
        x = F.relu(self.conv2(x), inplace=True)
        x = F.max_pool2d(x, (2,2))
        
        # FC에 집어넣을 때는, 일렬 형태에 집어넣어야 함
        # 30 * 5 * 5 크기의 데이터에서 배치사이즈만큼 놓고
        # flatten 시켜주는 부분임.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)
        
        return x
    

In [3]:
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root="./data",
                                       train=True,
                                       download=True,
                                       transform=trans)

testset = torchvision.datasets.CIFAR10(root="./data",
                                      train=False,
                                      download=True,
                                      transform=trans)

trainloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
my_net = my_network()

In [5]:
optim = torch.optim.SGD(my_net.parameters(), lr = 0.001, momentum = 0.9)
loss_function = nn.CrossEntropyLoss()

In [7]:
epoch_num = 3

for epoch in range(epoch_num):
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        
        optim.zero_grad()
        out = my_net(inputs)
        loss = loss_function(out, labels)
        loss.backward()
        optim.step()
        
        if i % 64 == 0:
            print("%d => loss : %.3f " % (i, loss))
print("train over")

0 => loss : 1.989 
64 => loss : 2.047 
128 => loss : 2.149 
192 => loss : 1.804 
256 => loss : 2.250 
320 => loss : 2.339 
384 => loss : 1.809 
448 => loss : 1.541 
512 => loss : 2.195 
576 => loss : 2.117 
640 => loss : 1.966 
704 => loss : 1.467 
768 => loss : 1.589 
832 => loss : 1.924 
896 => loss : 1.502 
960 => loss : 1.456 
1024 => loss : 1.996 
1088 => loss : 1.430 
1152 => loss : 1.091 
1216 => loss : 1.566 
1280 => loss : 1.536 
1344 => loss : 1.719 
1408 => loss : 1.039 
1472 => loss : 1.925 
1536 => loss : 2.064 
1600 => loss : 1.202 
1664 => loss : 0.997 
1728 => loss : 2.166 
1792 => loss : 1.656 
1856 => loss : 1.952 
1920 => loss : 1.817 
1984 => loss : 1.514 
2048 => loss : 1.668 
2112 => loss : 2.104 
2176 => loss : 1.824 
2240 => loss : 1.091 
2304 => loss : 2.254 
2368 => loss : 0.965 
2432 => loss : 1.667 
2496 => loss : 1.772 
2560 => loss : 1.838 
2624 => loss : 1.879 
2688 => loss : 1.480 
2752 => loss : 1.705 
2816 => loss : 1.550 
2880 => loss : 2.261 
2944 =>

In [9]:
total = 0
correct = 0
for data in testloader:
    images, labels = data
    outputs = my_net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print("Accuracy of the network on the 10000 test images: %f" % (100 * correct / total))

Accuracy of the network on the 10000 test images: 66.000000


