In [0]:
import numpy as np
import os

In [0]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms, datasets


In [0]:
lr = 1e-3
batch_size = 64
num_epoch = 10

ckpt_dir = './chekpoint'
log_dir = './log'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()

    self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, padding=0, bias=True)
    self.pool1 = nn.MaxPool2d(kernel_size=2)
    self.relu1 = nn.ReLU()

    self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=0, bias=True)
    self.drop2 = nn.Dropout2d(p=0.5)
    self.pool2 = nn.MaxPool2d(kernel_size=2)
    self.relu2 = nn.ReLU()

    self.fc1 = nn.Linear(in_features=320, out_features=50, bias=True)
    self.relu1_fc1 = nn.ReLU()
    self.drop1_fc1 = nn.Dropout(p=0.5)

    self.fc2 = nn.Linear(in_features=50, out_features=10, bias=True)

  def forward(self, x):
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.relu1(x)

    x = self.conv2(x)
    x = self.drop2(x)
    x = self.pool2(x)
    x = self.relu2(x)

    x = x.view(-1,320)

    x = self.fc1(x)
    x = self.relu1_fc1(x)
    x = self.drop1_fc1(x)

    x = self.fc2(x)
  
    return x

In [0]:
def save(ckpt_dir, net, optim, epoch):
  if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
  
  torch.save({'net': net.state_dict(), 'optim' : optim.state_dict()},
             './%s/model_epoch%d.pth' % (ckpt_dir, epoch))
  
def load(ckpt_dir, net, optim):
  ckpt_lst = os.listdir(ckpt_dir)
  ckpt_lst.sort()

  dict_model = torch.load('./%s/%s' % (ckpt_dir, ckpt_lst[-1]))

  net.load_state_dict(dict_model['net'])
  optim.load_state_dict(dict_model['optim'])

  return net, optim

In [0]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,),std=(0.5,))])

dataset = datasets.MNIST(download=True, root='./', train=True, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

num_data = len(loader.dataset)
num_batch = np.ceil((num_data/batch_size))

In [0]:
net = Net().to(device)
params = net.parameters()

fn_loss = nn.CrossEntropyLoss().to(device)
fn_pred = lambda output: torch.softmax(output, dim=1)
fn_acc = lambda pred, label: ((pred.max(dim=1)[1] == label).type(torch.float)).mean() 

optim = torch.optim.Adam(params, lr=lr)

writer = SummaryWriter(log_dir=log_dir)

In [29]:
for epoch in range(1, num_epoch + 1):
  net.train()

  loss_arr = []
  acc_arr = []

  for batch, (input, label) in enumerate(loader, 1): #index start:1
    input = input.to(device)
    label = label.to(device)

    output = net(input)
    pred = fn_pred(output)

    optim.zero_grad()

    loss = fn_loss(output, label)
    acc = fn_acc(pred, label)

    loss.backward()

    optim.step()

    loss_arr += [loss.item()]
    acc_arr += [acc.item()]

    print('TRAIN: Epoch %04d/%04d | Batch %04d/%04d | Loss: %.4f | Acc: %.4f' 
          %(epoch, num_epoch, batch, num_batch, loss, acc))
    
  writer.add_scalar('loss', np.mean(loss_arr), epoch)
  writer.add_scalar('acc', np.mean(acc_arr), epoch)

  save(ckpt_dir = ckpt_dir, net = net, optim = optim, epoch = epoch)

writer.close()

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
TRAIN: Epoch 0005/0010 | Batch 0630/0938 | Loss: 0.1535 | Acc: 0.9531
TRAIN: Epoch 0005/0010 | Batch 0631/0938 | Loss: 0.1392 | Acc: 0.9531
TRAIN: Epoch 0005/0010 | Batch 0632/0938 | Loss: 0.2082 | Acc: 0.9062
TRAIN: Epoch 0005/0010 | Batch 0633/0938 | Loss: 0.1103 | Acc: 0.9688
TRAIN: Epoch 0005/0010 | Batch 0634/0938 | Loss: 0.3950 | Acc: 0.8906
TRAIN: Epoch 0005/0010 | Batch 0635/0938 | Loss: 0.1710 | Acc: 0.9531
TRAIN: Epoch 0005/0010 | Batch 0636/0938 | Loss: 0.1108 | Acc: 0.9688
TRAIN: Epoch 0005/0010 | Batch 0637/0938 | Loss: 0.1539 | Acc: 0.9688
TRAIN: Epoch 0005/0010 | Batch 0638/0938 | Loss: 0.1008 | Acc: 0.9844
TRAIN: Epoch 0005/0010 | Batch 0639/0938 | Loss: 0.0874 | Acc: 0.9844
TRAIN: Epoch 0005/0010 | Batch 0640/0938 | Loss: 0.2152 | Acc: 0.9219
TRAIN: Epoch 0005/0010 | Batch 0641/0938 | Loss: 0.1221 | Acc: 0.9688
TRAIN: Epoch 0005/0010 | Batch 0642/0938 | Loss: 0.2489 | Acc: 0.9219
TRAIN: Epoch 0005/0010 | Batch 0643/0938

In [35]:
eval_dataset = datasets.MNIST(download=True, root='./', train=False, transform=transform)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

num_data = len(eval_loader.dataset)
num_batch = np.ceil((num_data/batch_size))

net, optim = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

with torch.no_grad():
  net.eval()

  loss_arr = []
  acc_arr = []

  for batch, (input, label) in enumerate(eval_loader, 1): #index start:1
    input = input.to(device)
    label = label.to(device)

    output = net(input)
    pred = fn_pred(output)

    loss = fn_loss(output, label)
    acc = fn_acc(pred, label)

    loss_arr += [loss.item()]
    acc_arr += [acc.item()]

    print('TEST: Batch %04d/%04d | Loss: %.4f | Acc: %.4f' 
          %( batch, num_batch, loss, acc))

TEST: Batch 0001/0157 | Loss: 0.0132 | Acc: 1.0000
TEST: Batch 0002/0157 | Loss: 0.0122 | Acc: 1.0000
TEST: Batch 0003/0157 | Loss: 0.0070 | Acc: 1.0000
TEST: Batch 0004/0157 | Loss: 0.0208 | Acc: 0.9844
TEST: Batch 0005/0157 | Loss: 0.0272 | Acc: 0.9844
TEST: Batch 0006/0157 | Loss: 0.0410 | Acc: 0.9844
TEST: Batch 0007/0157 | Loss: 0.0870 | Acc: 0.9844
TEST: Batch 0008/0157 | Loss: 0.0917 | Acc: 0.9844
TEST: Batch 0009/0157 | Loss: 0.0147 | Acc: 1.0000
TEST: Batch 0010/0157 | Loss: 0.1100 | Acc: 0.9688
TEST: Batch 0011/0157 | Loss: 0.0431 | Acc: 0.9844
TEST: Batch 0012/0157 | Loss: 0.0572 | Acc: 0.9688
TEST: Batch 0013/0157 | Loss: 0.0052 | Acc: 1.0000
TEST: Batch 0014/0157 | Loss: 0.0141 | Acc: 1.0000
TEST: Batch 0015/0157 | Loss: 0.0930 | Acc: 0.9688
TEST: Batch 0016/0157 | Loss: 0.1289 | Acc: 0.9688
TEST: Batch 0017/0157 | Loss: 0.0784 | Acc: 0.9688
TEST: Batch 0018/0157 | Loss: 0.0164 | Acc: 1.0000
TEST: Batch 0019/0157 | Loss: 0.0319 | Acc: 1.0000
TEST: Batch 0020/0157 | Loss: 0