## 1. 사전학습된 VGG 불러오기

In [1]:
import torch
import torch.nn as nn
from torchvision.models.vgg import vgg16

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = vgg16(pretrained=True)
fc = nn.Sequential(
    nn.Linear(512*7*7, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(4096, 10),
)

model.classifier = fc
model.to(device)



VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [3]:
import tqdm

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, RandomCrop, ToTensor, Resize, Normalize, RandomHorizontalFlip
from torch.utils.data.dataloader import DataLoader

from torch.optim.adam import Adam

transforms = Compose([
    Resize(224),
    RandomCrop((224,224), padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.44465), std=(0.247, 0.243, 0.261))
])

training_data = CIFAR10(root='./', download=True, train=True, transform=transforms)
test_data = CIFAR10(root='./', download=True, train=False, transform=transforms)

train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## 2. 모델 훈련하기

In [4]:
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for epoch in range(30):
    iterator = tqdm.tqdm(train_loader)
    
    for data, label in iterator:
        optim.zero_grad()
        
        preds = model(data.to(device))
        
        loss = criterion(preds, label.to(device))
        loss.backward()
        optim.step()
        
        iterator.set_description(f"epoch:{epoch+1} loss:{loss.item()}")
        
torch.save(model.state_dict(), "CIFAR_pretrained.pt")

epoch:1 loss:0.07583153992891312: 100%|██████████| 782/782 [04:52<00:00,  2.67it/s]
epoch:2 loss:0.447729229927063: 100%|██████████| 782/782 [04:50<00:00,  2.69it/s]   
epoch:3 loss:0.276898592710495: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s]   
epoch:4 loss:0.03564669191837311: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s] 
epoch:5 loss:0.3124959170818329: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s]  
epoch:6 loss:0.041126225143671036: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s]
epoch:7 loss:0.0038204751908779144: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s]
epoch:8 loss:0.19886036217212677: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s]  
epoch:9 loss:0.01972886174917221: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s]  
epoch:10 loss:0.1995793730020523: 100%|██████████| 782/782 [04:50<00:00,  2.69it/s]   
epoch:11 loss:0.006823040544986725: 100%|██████████| 782/782 [04:51<00:00,  2.68it/s] 
epoch:12 loss:0.0005984307499602437: 100%|██████████| 782/7

## 3. 모델 성능평가하기

In [5]:
model.load_state_dict(torch.load("CIFAR_pretrained.pt", map_location=device))

num_corr = 0

with torch.no_grad():
    for data, label in test_loader:
        output = model(data.to(device))
        
        preds = output.max(1)[1]
        num_corr += preds.eq(label.to(device).data).sum().item()
        
print(f"Accuracy:{num_corr/len(test_data)}")

Accuracy:0.9296
