#1. Residual block 구현하기

In [1]:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
   def __init__(self, in_channels, out_channels, kernel_size=3):
       super(BasicBlock, self).__init__()


       #conv layer 정의
       self.c1 = nn.Conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, padding=1)
       self.c2 = nn.Conv2d(out_channels, out_channels,
                           kernel_size=kernel_size, padding=1)

       self.downsample = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=1)

       #batchnorm 정의
       self.bn1 = nn.BatchNorm2d(num_features=out_channels)
       self.bn2 = nn.BatchNorm2d(num_features=out_channels)

       self.relu = nn.ReLU()
   def forward(self, x):

       #initial input 저장
       x_ = x

       x = self.c1(x)
       x = self.bn1(x)
       x = self.relu(x)
       x = self.c2(x)
       x = self.bn2(x)

       #conv 결과와 input의 채널 수를 맞춰줌
       x_ = self.downsample(x_)

       #conv layer의 결과와 저장해놨던 input을 더해줌
       x += x_
       x = self.relu(x)

       return x


#2. ResNet Model 구현하기

In [2]:
#ResNet-9 만들기

class ResNet(nn.Module):
   def __init__(self, num_classes=10):
       super(ResNet, self).__init__()


       #BasicBlock
       self.b1 = BasicBlock(in_channels=3, out_channels=64)
       self.b2 = BasicBlock(in_channels=64, out_channels=128)
       self.b3 = BasicBlock(in_channels=128, out_channels=256)


       #average pooling 적용
       self.pool = nn.AvgPool2d(kernel_size=2, stride=2)


       self.fc1 = nn.Linear(in_features=4096, out_features=2048)
       self.fc2 = nn.Linear(in_features=2048, out_features=512)
       self.fc3 = nn.Linear(in_features=512, out_features=num_classes)

       self.relu = nn.ReLU()

   def forward(self, x):

       #Basic block과 pooling layer 통과
       x = self.b1(x)
       x = self.pool(x)
       x = self.b2(x)
       x = self.pool(x)
       x = self.b3(x)
       x = self.pool(x)


       #flatten (fc layer)
       x = torch.flatten(x, start_dim=1)

       #예측값 출력
       x = self.fc1(x)
       x = self.relu(x)
       x = self.fc2(x)
       x = self.relu(x)
       x = self.fc3(x)

       return x


#3. CIFAR10 전처리하기

In [3]:
import tqdm

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor
from torchvision.transforms import RandomHorizontalFlip, RandomCrop
from torchvision.transforms import Normalize
from torch.utils.data.dataloader import DataLoader

from torch.optim.adam import Adam

transforms = Compose([
   RandomCrop((32, 32), padding=4), #random crop
   RandomHorizontalFlip(p=0.5), #randomhorizontal flip
   ToTensor(),
   Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
])

In [4]:
training_data = CIFAR10(root="./", train=True, download=True, transform=transforms)
test_data = CIFAR10(root="./", train=False, download=True, transform=transforms)

train_loader = DataLoader(training_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

100%|██████████| 170M/170M [00:08<00:00, 20.0MB/s]


In [5]:
#모델 정의
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet(num_classes=10)
model.to(device)

ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [6]:
#lr, optimizer, loss 설정
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)

for epoch in range(30):
   iterator = tqdm.tqdm(train_loader)
   for data, label in iterator:
       # 최적화를 위해 기울기 초기화
       optim.zero_grad()

       # 모델의 예측값
       preds = model(data.to(device))

       # 손실 계산 및 역전파
       loss = nn.CrossEntropyLoss()(preds, label.to(device))
       loss.backward()
       optim.step()

       iterator.set_description(f"epoch:{epoch+1} loss:{loss.item()}")

torch.save(model.state_dict(), "ResNet.pth")

epoch:1 loss:0.9004710912704468: 100%|██████████| 1563/1563 [00:46<00:00, 33.77it/s]
epoch:2 loss:1.228370189666748: 100%|██████████| 1563/1563 [00:44<00:00, 35.31it/s]
epoch:3 loss:1.140533447265625: 100%|██████████| 1563/1563 [00:44<00:00, 34.76it/s]
epoch:4 loss:0.7456884980201721: 100%|██████████| 1563/1563 [00:44<00:00, 34.92it/s]
epoch:5 loss:0.5660049915313721: 100%|██████████| 1563/1563 [00:44<00:00, 34.99it/s]
epoch:6 loss:0.6312590837478638: 100%|██████████| 1563/1563 [00:44<00:00, 34.96it/s]
epoch:7 loss:0.21800464391708374: 100%|██████████| 1563/1563 [00:44<00:00, 35.26it/s]
epoch:8 loss:0.3861290216445923: 100%|██████████| 1563/1563 [00:44<00:00, 35.10it/s]
epoch:9 loss:0.5747197270393372: 100%|██████████| 1563/1563 [00:44<00:00, 35.43it/s]
epoch:10 loss:0.3364240527153015: 100%|██████████| 1563/1563 [00:44<00:00, 34.75it/s]
epoch:11 loss:0.4822887182235718: 100%|██████████| 1563/1563 [00:44<00:00, 35.02it/s]
epoch:12 loss:0.748232364654541: 100%|██████████| 1563/1563 [00:

#4. Test 데이터 분류하기

In [7]:
model.load_state_dict(torch.load("ResNet.pth", map_location=device))

num_corr = 0

with torch.no_grad():
   for data, label in test_loader:

       output = model(data.to(device))
       preds = output.data.max(1)[1]
       corr = preds.eq(label.to(device).data).sum().item()
       num_corr += corr

   print(f"Accuracy:{num_corr/len(test_data)}")

Accuracy:0.8744
