## 1. ResNet

- ResNet은 CNN 중에서 가장 많이 쓰이는 모델
- VGG는 19층 이상으로 쌓을 수 없다. 다만 ResNet은 Skip Connection 기술을 이용해서 기울시 소실을 방지함

In [1]:
import torch
import torch.nn as nn

## 2. ResNet 기본 블록 정의하기

In [2]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(BasicBlock, self).__init__()
        
        self.c1 = nn.Conv2d(in_channels, out_channels, 
                            kernel_size=kernel_size, padding=1)
        self.c2 = nn.Conv2d(out_channels, out_channels, 
                            kernel_size=kernel_size, padding=1)
        self.downsample = nn.Conv2d(in_channels, out_channels, 
                                    kernel_size=1)
        
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x_ = x
        
        x = self.c1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.c2(x)
        x = self.bn2(x)
        
        # 합성곱의 결과와 입력의 채널 수를 맞춤
        x_ = self.downsample(x_)
        
        x += x_
        
        x = self.relu(x)
        
        return x 
        

## 3. ResNet 모델 정의하기

In [3]:
class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        
        self.b1 = BasicBlock(in_channels=3, out_channels=64)
        self.b2 = BasicBlock(in_channels=64, out_channels=128)
        self.b3 = BasicBlock(in_channels=128, out_channels=256)
        
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(in_features=4096, out_features=2048)
        self.fc2 = nn.Linear(in_features=2048, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=num_classes)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.b1(x)
        x = self.pool(x)
        x = self.b2(x)
        x = self.pool(x)
        x = self.b3(x)
        x = self.pool(x)
        
        x = torch.flatten(x, start_dim=1)
        
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
                
        return x

## 4. 데이터 전처리 정의

In [4]:
import tqdm

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor
from torchvision.transforms import RandomCrop, RandomHorizontalFlip
from torchvision.transforms import Normalize
from torch.utils.data.dataloader import DataLoader

from torch.optim.adam import Adam

transforms = Compose([
    RandomCrop((32,32), padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
])

## 5. 데이터 불러오기

In [5]:
training_data = CIFAR10(root='./', train=True, download=True, transform=transforms)
test_data = CIFAR10(root='./', train=False, download=True, transform=transforms)

train_loader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## 6. 모델 정의하기

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet(num_classes=10)
model.to(device)

ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


## 7. 학습 루프 정의하기

In [8]:
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)
max_acc = 0

for epoch in range(30):
    iterator = tqdm.tqdm(train_loader)
    
    for data, label in iterator:
        optim.zero_grad()
        
        preds = model(data.to(device))
        
        loss = nn.CrossEntropyLoss()(preds, label.to(device))
        loss.backward()
        optim.step()
        
        iterator.set_description(f"epoch: {epoch+1} loss:{loss.item()}")
    
    with torch.no_grad():
        iterator = tqdm.tqdm(test_loader)

        acc = 0 
        
        for data, label in iterator:
            preds = model(data.to(device))
            
            output = preds.data.max(1)[1]
            corr = output.eq(label.to(device).data).sum().item()
            acc += corr
            
        
        acc = acc / len(test_data)
        print(f"Accuracy: {acc:.4f}")
        
        if max_acc < acc:
            max_acc = acc
            print("Save Pth!")
            torch.save(model.state_dict(), "ResNet.pth")
        

epoch: 1 loss:0.9484152793884277: 100%|██████████| 782/782 [00:12<00:00, 63.69it/s]
100%|██████████| 157/157 [00:01<00:00, 121.36it/s]


Accuracy: 0.6916
Save Pth!


epoch: 2 loss:0.5299801826477051: 100%|██████████| 782/782 [00:12<00:00, 63.25it/s] 
100%|██████████| 157/157 [00:01<00:00, 120.90it/s]


Accuracy: 0.7324
Save Pth!


epoch: 3 loss:0.8240244388580322: 100%|██████████| 782/782 [00:12<00:00, 63.30it/s] 
100%|██████████| 157/157 [00:01<00:00, 120.80it/s]


Accuracy: 0.7697
Save Pth!


epoch: 4 loss:0.3468015491962433: 100%|██████████| 782/782 [00:12<00:00, 63.28it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.24it/s]


Accuracy: 0.7926
Save Pth!


epoch: 5 loss:0.48342499136924744: 100%|██████████| 782/782 [00:12<00:00, 63.13it/s]
100%|██████████| 157/157 [00:01<00:00, 120.94it/s]


Accuracy: 0.7943
Save Pth!


epoch: 6 loss:0.27809107303619385: 100%|██████████| 782/782 [00:12<00:00, 63.14it/s]
100%|██████████| 157/157 [00:01<00:00, 120.42it/s]


Accuracy: 0.8199
Save Pth!


epoch: 7 loss:0.4314178228378296: 100%|██████████| 782/782 [00:12<00:00, 63.39it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.05it/s]


Accuracy: 0.8207
Save Pth!


epoch: 8 loss:0.18510626256465912: 100%|██████████| 782/782 [00:12<00:00, 63.19it/s]
100%|██████████| 157/157 [00:01<00:00, 120.75it/s]


Accuracy: 0.8349
Save Pth!


epoch: 9 loss:0.3514291048049927: 100%|██████████| 782/782 [00:12<00:00, 63.42it/s] 
100%|██████████| 157/157 [00:01<00:00, 123.11it/s]


Accuracy: 0.8448
Save Pth!


epoch: 10 loss:0.40302810072898865: 100%|██████████| 782/782 [00:12<00:00, 63.13it/s]
100%|██████████| 157/157 [00:01<00:00, 122.47it/s]


Accuracy: 0.8405


epoch: 11 loss:0.3962790369987488: 100%|██████████| 782/782 [00:12<00:00, 63.48it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.96it/s]


Accuracy: 0.8467
Save Pth!


epoch: 12 loss:0.40217676758766174: 100%|██████████| 782/782 [00:12<00:00, 63.57it/s]
100%|██████████| 157/157 [00:01<00:00, 121.23it/s]


Accuracy: 0.8471
Save Pth!


epoch: 13 loss:0.35029876232147217: 100%|██████████| 782/782 [00:12<00:00, 63.49it/s]
100%|██████████| 157/157 [00:01<00:00, 121.47it/s]


Accuracy: 0.8512
Save Pth!


epoch: 14 loss:0.10817859321832657: 100%|██████████| 782/782 [00:12<00:00, 63.36it/s]
100%|██████████| 157/157 [00:01<00:00, 122.45it/s]


Accuracy: 0.8548
Save Pth!


epoch: 15 loss:0.5675683617591858: 100%|██████████| 782/782 [00:12<00:00, 63.28it/s]  
100%|██████████| 157/157 [00:01<00:00, 121.63it/s]


Accuracy: 0.8637
Save Pth!


epoch: 16 loss:0.22054070234298706: 100%|██████████| 782/782 [00:12<00:00, 63.75it/s] 
100%|██████████| 157/157 [00:01<00:00, 122.89it/s]


Accuracy: 0.8642
Save Pth!


epoch: 17 loss:0.6785425543785095: 100%|██████████| 782/782 [00:12<00:00, 63.44it/s]  
100%|██████████| 157/157 [00:01<00:00, 120.72it/s]


Accuracy: 0.8614


epoch: 18 loss:0.008485511876642704: 100%|██████████| 782/782 [00:12<00:00, 63.37it/s]
100%|██████████| 157/157 [00:01<00:00, 121.75it/s]


Accuracy: 0.8656
Save Pth!


epoch: 19 loss:0.16727641224861145: 100%|██████████| 782/782 [00:12<00:00, 63.65it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.69it/s]


Accuracy: 0.8714
Save Pth!


epoch: 20 loss:0.22029522061347961: 100%|██████████| 782/782 [00:12<00:00, 63.50it/s] 
100%|██████████| 157/157 [00:01<00:00, 122.02it/s]


Accuracy: 0.8738
Save Pth!


epoch: 21 loss:0.14009307324886322: 100%|██████████| 782/782 [00:12<00:00, 63.44it/s] 
100%|██████████| 157/157 [00:01<00:00, 122.01it/s]


Accuracy: 0.8671


epoch: 22 loss:0.15267018973827362: 100%|██████████| 782/782 [00:12<00:00, 63.08it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.96it/s]


Accuracy: 0.8723


epoch: 23 loss:0.16239304840564728: 100%|██████████| 782/782 [00:12<00:00, 63.55it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.78it/s]


Accuracy: 0.8709


epoch: 24 loss:0.026315953582525253: 100%|██████████| 782/782 [00:12<00:00, 63.29it/s]
100%|██████████| 157/157 [00:01<00:00, 121.06it/s]


Accuracy: 0.8774
Save Pth!


epoch: 25 loss:0.0424097515642643: 100%|██████████| 782/782 [00:12<00:00, 62.61it/s]  
100%|██████████| 157/157 [00:01<00:00, 121.65it/s]


Accuracy: 0.8763


epoch: 26 loss:0.07891726493835449: 100%|██████████| 782/782 [00:12<00:00, 63.21it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.82it/s]


Accuracy: 0.8751


epoch: 27 loss:0.01508516538888216: 100%|██████████| 782/782 [00:12<00:00, 63.17it/s] 
100%|██████████| 157/157 [00:01<00:00, 121.62it/s]


Accuracy: 0.8776
Save Pth!


epoch: 28 loss:0.1348452866077423: 100%|██████████| 782/782 [00:12<00:00, 63.19it/s]  
100%|██████████| 157/157 [00:01<00:00, 122.70it/s]


Accuracy: 0.8726


epoch: 29 loss:0.041985027492046356: 100%|██████████| 782/782 [00:12<00:00, 63.17it/s]
100%|██████████| 157/157 [00:01<00:00, 121.71it/s]


Accuracy: 0.8823
Save Pth!


epoch: 30 loss:0.038432877510786057: 100%|██████████| 782/782 [00:12<00:00, 63.45it/s]
100%|██████████| 157/157 [00:01<00:00, 122.03it/s]

Accuracy: 0.8820



