# 기본 블록 정의하기

In [1]:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(BasicBlock, self).__init__()
        
        self.c1 = nn.Conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, padding=1)
        self.c2 = nn.Conv2d(out_channels, out_channels,
                           kernel_size=kernel_size, padding=1)        
        self.downsample = nn.Conv2d(in_channels, out_channels,
                           kernel_size=1)
        
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)        
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        x_ = x
        
        x = self.c1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.c2(x)
        x = self.bn2(x)
        
        x_ = self.downsample(x_)
        
        x += x_
        x = self.relu(x)
        
        return x

# ResNet 모델 정의하기

In [2]:
class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        
        self.b1 = BasicBlock(in_channels=3, out_channels=64)
        self.b2 = BasicBlock(in_channels=64, out_channels=128)
        self.b3 = BasicBlock(in_channels=128, out_channels=256)        
        
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(in_features=4096, out_features=2048)
        self.fc2 = nn.Linear(in_features=2048, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=num_classes)   
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.b1(x)
        x = self.pool(x)
        x = self.b2(x)
        x = self.pool(x)
        x = self.b3(x)
        x = self.pool(x)        
        x = torch.flatten(x, start_dim=1)
        
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        
        return x

# 모델 학습하기

In [3]:
import tqdm

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, RandomCrop, Normalize
from torch.utils.data.dataloader import DataLoader

from torch.optim.adam import Adam

train_transforms = Compose([
    RandomCrop((32, 32), padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465),
             std=(0.247, 0.243, 0.261))
])

test_transforms = Compose([
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465),
             std=(0.247, 0.243, 0.261))
])

train_data = CIFAR10(root='/home/restful3/datasets/torch', train=True, download=True, transform=train_transforms)
test_data = CIFAR10(root='/home/restful3/datasets/torch', train=True, download=True, transform=test_transforms)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

  warn(f"Failed to load image Python extension: {e}")


Files already downloaded and verified
Files already downloaded and verified


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

model = ResNet(num_classes=10)
model.to(device)

cuda


ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [19]:
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)

for epoch in range(30):
    iterator = tqdm.tqdm(train_loader)
    for data, label in iterator:
        data, label = data.to(device), label.to(device)
        optim.zero_grad()
        
        preds = model(data)
        loss = nn.CrossEntropyLoss()(preds, label)
        loss.backward()
        optim.step()
        
        iterator.set_description(f'epoch : {epoch+1}, loss : {loss.item():.2f}')
        
torch.save(model.state_dict(), './models/ResNet.pth')        

epoch : 1, loss : 0.77: 100%|███████████████| 1563/1563 [00:17<00:00, 91.62it/s]
epoch : 2, loss : 0.39: 100%|███████████████| 1563/1563 [00:17<00:00, 91.59it/s]
epoch : 3, loss : 0.65: 100%|███████████████| 1563/1563 [00:17<00:00, 90.69it/s]
epoch : 4, loss : 0.51: 100%|███████████████| 1563/1563 [00:17<00:00, 90.47it/s]
epoch : 5, loss : 0.53: 100%|███████████████| 1563/1563 [00:17<00:00, 90.80it/s]
epoch : 6, loss : 0.31: 100%|███████████████| 1563/1563 [00:16<00:00, 94.93it/s]
epoch : 7, loss : 0.18: 100%|███████████████| 1563/1563 [00:16<00:00, 94.97it/s]
epoch : 8, loss : 0.37: 100%|███████████████| 1563/1563 [00:16<00:00, 92.51it/s]
epoch : 9, loss : 0.76: 100%|███████████████| 1563/1563 [00:17<00:00, 90.34it/s]
epoch : 10, loss : 0.13: 100%|██████████████| 1563/1563 [00:17<00:00, 90.00it/s]
epoch : 11, loss : 0.17: 100%|██████████████| 1563/1563 [00:17<00:00, 90.65it/s]
epoch : 12, loss : 0.06: 100%|██████████████| 1563/1563 [00:17<00:00, 90.90it/s]
epoch : 13, loss : 0.63: 100

# 모델 성능 평가하기

In [9]:
model.load_state_dict(torch.load('./models/ResNet.pth', map_location=device))

num_corr = 0

with torch.no_grad():
    for data, label in test_loader:
        data, label = data.to(device), label.to(device)
        
        output = model(data)
        preds = output.data.max(1)[1]
        corr = preds.eq(label.data).sum().item()
        num_corr += corr
        
    print(f'Accuracy : {num_corr/len(test_data)}')

Accuracy : 0.9778
