## ResNet을 이용한 이미지 분류
- 합성곱 신경망 구조
- VGG: 이미지 분류를 위해 제안된 CNN의 일종
- CNN의 기울기 소실 문제 -> ResNet(스킵 커넥션 도입)
    - vgg에 비해 오버피팅이 일어나기 쉬움
    - 가중치 증가 -> 계산량 증가
- 이미지 분류, 세그멘테이션, 이미지 생성

In [1]:
import tqdm
import torch
import torch.nn as nn

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor
from torchvision.transforms import RandomHorizontalFlip, RandomCrop
from torchvision.transforms import Normalize
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam

from model.resNet import ResNet

  from .autonotebook import tqdm as notebook_tqdm


#### 학습

In [2]:
transforms = Compose([
    RandomCrop((32,32), padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
])

training_data = CIFAR10(root="./data", train=True, download=True, transform=transforms)
test_data = CIFAR10(root="./data", train=False, download=True, transform=transforms)

train_loader = DataLoader(training_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet(num_classes=10)
model.to(device)

lr = 1e-2
optim = Adam(model.parameters(), lr=lr)

for epoch in range(1):
  iterator = tqdm.tqdm(train_loader)
  for data, label in iterator:
    optim.zero_grad()
    preds = model(data.to(device))

    loss = nn.CrossEntropyLoss()(preds, label.to(device))
    loss.backward()
    optim.step()

    iterator.set_description(f"epoch:{epoch+1} loss:{loss.item()}")

torch.save(model.state_dict(), "./data/weights/ResNet.pth")

Files already downloaded and verified
Files already downloaded and verified


epoch:1 loss:2.308614730834961: 100%|██████████| 1563/1563 [08:57<00:00,  2.91it/s] 


#### 평가

In [4]:
model.load_state_dict(torch.load("./data/weights/ResNet.pth", map_location=device))

num_corr = 0

with torch.no_grad():
  for data, label in test_loader:

    output = model(data.to(device))
    preds = output.data.max(1)[1]
    corr = preds.eq(label.to(device).data).sum().item()
    num_corr += corr

  print(f"Accuracy:{num_corr/len(test_data)}")

Accuracy:0.1
