## 1. Pretrained ResNet

In [1]:
from torchvision.models.resnet import resnet18
import torch
import torch.nn as nn

In [2]:
model = resnet18(pretrained=True)



In [3]:
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [4]:
model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10)
)

## 2. 데이터 전처리 정의

In [5]:
import tqdm

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor
from torchvision.transforms import Resize, RandomCrop, RandomHorizontalFlip
from torchvision.transforms import Normalize
from torch.utils.data.dataloader import DataLoader

from torch.optim.adam import Adam

transforms = Compose([
    Resize((128,128)),
    RandomCrop((128,128), padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
])

## 3. 데이터 불러오기

In [6]:
training_data = CIFAR10(root='./', train=True, download=True, transform=transforms)
test_data = CIFAR10(root='./', train=False, download=True, transform=transforms)

train_loader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## 4. 모델 정의하기

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## 5. torchsummary로 모델 살펴보기

In [8]:
from torchsummary import summary
summary(model, input_size=(3, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
              ReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
             ReLU-10           [-1, 64, 32, 32]               0
       BasicBlock-11           [-1, 64, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          36,864
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64,

## 6. 학습 루프 정의하기

In [9]:
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)
max_acc = 0

for epoch in range(30):
    iterator = tqdm.tqdm(train_loader)
    
    for data, label in iterator:
        optim.zero_grad()
        
        preds = model(data.to(device))
        
        loss = nn.CrossEntropyLoss()(preds, label.to(device))
        loss.backward()
        optim.step()
        
        iterator.set_description(f"epoch: {epoch+1} loss:{loss.item()}")
    
    with torch.no_grad():
        iterator = tqdm.tqdm(test_loader)

        acc = 0 
        
        for data, label in iterator:
            preds = model(data.to(device))
            
            output = preds.data.max(1)[1]
            corr = output.eq(label.to(device).data).sum().item()
            acc += corr
            
        
        acc = acc / len(test_data)
        print(f"Accuracy: {acc:.4f}")
        
        if max_acc < acc:
            max_acc = acc
            print("Save Pth!")
            torch.save(model.state_dict(), "ResNet.pth")
        

epoch: 1 loss:0.08227650076150894: 100%|██████████| 782/782 [00:38<00:00, 20.25it/s]
100%|██████████| 157/157 [00:05<00:00, 27.71it/s]


Accuracy: 0.9200
Save Pth!


epoch: 2 loss:0.08800821006298065: 100%|██████████| 782/782 [00:37<00:00, 21.09it/s] 
100%|██████████| 157/157 [00:07<00:00, 22.27it/s]


Accuracy: 0.9278
Save Pth!


epoch: 3 loss:0.44674810767173767: 100%|██████████| 782/782 [00:36<00:00, 21.26it/s] 
100%|██████████| 157/157 [00:06<00:00, 22.71it/s]


Accuracy: 0.9363
Save Pth!


epoch: 4 loss:0.16540215909481049: 100%|██████████| 782/782 [00:36<00:00, 21.43it/s] 
100%|██████████| 157/157 [00:06<00:00, 22.73it/s]


Accuracy: 0.9335


epoch: 5 loss:0.30489614605903625: 100%|██████████| 782/782 [00:37<00:00, 21.12it/s]  
100%|██████████| 157/157 [00:04<00:00, 37.31it/s]


Accuracy: 0.9350


epoch: 6 loss:0.10446347296237946: 100%|██████████| 782/782 [00:37<00:00, 21.05it/s]  
100%|██████████| 157/157 [00:06<00:00, 24.93it/s]


Accuracy: 0.9380
Save Pth!


epoch: 7 loss:0.22283832728862762: 100%|██████████| 782/782 [00:36<00:00, 21.21it/s]  
100%|██████████| 157/157 [00:04<00:00, 35.06it/s]


Accuracy: 0.9372


epoch: 8 loss:0.432537317276001: 100%|██████████| 782/782 [00:38<00:00, 20.07it/s]    
100%|██████████| 157/157 [00:06<00:00, 22.98it/s]


Accuracy: 0.9390
Save Pth!


epoch: 9 loss:0.052063729614019394: 100%|██████████| 782/782 [00:38<00:00, 20.46it/s] 
100%|██████████| 157/157 [00:06<00:00, 24.85it/s]


Accuracy: 0.9376


epoch: 10 loss:0.03899811580777168: 100%|██████████| 782/782 [00:37<00:00, 21.05it/s]  
100%|██████████| 157/157 [00:06<00:00, 23.24it/s]


Accuracy: 0.9385


epoch: 11 loss:0.08099459111690521: 100%|██████████| 782/782 [00:36<00:00, 21.19it/s]  
100%|██████████| 157/157 [00:05<00:00, 26.76it/s]


Accuracy: 0.9377


epoch: 12 loss:0.0926501601934433: 100%|██████████| 782/782 [00:36<00:00, 21.20it/s]   
100%|██████████| 157/157 [00:07<00:00, 22.13it/s]


Accuracy: 0.9374


epoch: 13 loss:0.13458070158958435: 100%|██████████| 782/782 [00:37<00:00, 21.08it/s]  
100%|██████████| 157/157 [00:06<00:00, 22.79it/s]


Accuracy: 0.9391
Save Pth!


epoch: 14 loss:0.4098188579082489: 100%|██████████| 782/782 [00:37<00:00, 21.00it/s]    
100%|██████████| 157/157 [00:06<00:00, 22.51it/s]


Accuracy: 0.9367


epoch: 15 loss:0.2615851163864136: 100%|██████████| 782/782 [00:36<00:00, 21.47it/s]   
100%|██████████| 157/157 [00:05<00:00, 28.16it/s]


Accuracy: 0.9428
Save Pth!


epoch: 16 loss:0.11665176600217819: 100%|██████████| 782/782 [00:36<00:00, 21.40it/s]   
100%|██████████| 157/157 [00:07<00:00, 22.41it/s]


Accuracy: 0.9403


epoch: 17 loss:0.0004999365191906691: 100%|██████████| 782/782 [00:36<00:00, 21.64it/s] 
100%|██████████| 157/157 [00:06<00:00, 22.63it/s]


Accuracy: 0.9319


epoch: 18 loss:0.22610709071159363: 100%|██████████| 782/782 [00:36<00:00, 21.55it/s]   
100%|██████████| 157/157 [00:06<00:00, 22.94it/s]


Accuracy: 0.9421


epoch: 19 loss:0.03165188059210777: 100%|██████████| 782/782 [00:37<00:00, 21.11it/s]   
100%|██████████| 157/157 [00:05<00:00, 26.95it/s]


Accuracy: 0.9381


epoch: 20 loss:0.08353224396705627: 100%|██████████| 782/782 [00:37<00:00, 20.75it/s]   
100%|██████████| 157/157 [00:03<00:00, 40.32it/s]


Accuracy: 0.9407


epoch: 21 loss:0.2644971013069153: 100%|██████████| 782/782 [00:38<00:00, 20.52it/s]    
100%|██████████| 157/157 [00:06<00:00, 23.01it/s]


Accuracy: 0.9388


epoch: 22 loss:0.22489969432353973: 100%|██████████| 782/782 [00:37<00:00, 20.64it/s]   
100%|██████████| 157/157 [00:04<00:00, 31.85it/s]


Accuracy: 0.9410


epoch: 23 loss:0.0005150189972482622: 100%|██████████| 782/782 [00:37<00:00, 20.86it/s] 
100%|██████████| 157/157 [00:06<00:00, 23.48it/s]


Accuracy: 0.9422


epoch: 24 loss:0.00022818308207206428: 100%|██████████| 782/782 [00:37<00:00, 20.92it/s]
100%|██████████| 157/157 [00:04<00:00, 37.48it/s]


Accuracy: 0.9354


epoch: 25 loss:0.008493714965879917: 100%|██████████| 782/782 [00:37<00:00, 20.72it/s]  
100%|██████████| 157/157 [00:06<00:00, 22.46it/s]


Accuracy: 0.9429
Save Pth!


epoch: 26 loss:0.025439750403165817: 100%|██████████| 782/782 [00:38<00:00, 20.30it/s]  
100%|██████████| 157/157 [00:06<00:00, 22.45it/s]


Accuracy: 0.9396


epoch: 27 loss:0.1082405149936676: 100%|██████████| 782/782 [00:36<00:00, 21.17it/s]    
100%|██████████| 157/157 [00:05<00:00, 27.66it/s]


Accuracy: 0.9358


epoch: 28 loss:0.06607886403799057: 100%|██████████| 782/782 [00:37<00:00, 20.59it/s]   
100%|██████████| 157/157 [00:06<00:00, 23.21it/s]


Accuracy: 0.9372


epoch: 29 loss:0.09490983188152313: 100%|██████████| 782/782 [00:39<00:00, 19.88it/s]   
100%|██████████| 157/157 [00:06<00:00, 22.49it/s]


Accuracy: 0.9426


epoch: 30 loss:0.022054603323340416: 100%|██████████| 782/782 [00:37<00:00, 20.83it/s]  
100%|██████████| 157/157 [00:06<00:00, 23.03it/s]

Accuracy: 0.9404



