# ResNet

`torchvision`에 있는 `ResNet`을 참고해서 ResNet을 직접 구현해보자!

그리고 이를 cifar10 dataset에 적용해보자!

## 1. ResNet

In [2]:
import torch
from torch import nn

In [3]:
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [4]:
class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
            
        out += identity
        out = self.relu(out)
        
        return out

In [5]:
class Bottleneck(nn.Module):
    expansion = 4
    
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        out = self.bn3(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
            
        out += identity
        out = self.relu(out)
        
        return out

In [6]:
class ResNet(nn.Module):
    
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
        
    def _make_layer(self, block, planes, blocks, stride=1):
        
        downsample = None
        
        if stride != 1 or self.inplanes != planes * block.expansion: 
            
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride), #conv1x1(256, 512, 2)
                nn.BatchNorm2d(planes * block.expansion), #batchnrom2d(512)
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        
        self.inplanes = planes * block.expansion #self.inplanes = 128 * 4
        
        for _ in range(1, blocks): 
            layers.append(block(self.inplanes, planes)) # * 3

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [7]:
def resnet18(pretrained=False, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) #=> 2*(2+2+2+2) +1(conv1) +1(fc)  = 16 +2 =resnet 18
    return model

In [8]:
def resnet50(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) #=> 3*(3+4+6+3) +(conv1) +1(fc) = 48 +2 = 50
    return model

In [9]:
def resnet152(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) # 3*(3+8+36+3) +2 = 150+2 = resnet152    
    return mode

## 2. CIFAR-10 Dataset

앞의 VGGNet에서는 평균과 표준편차를 직접 구하지 않고 0.5라고 가정하고 `torchvision.transforms.Normalize()`를 수행하였다.

하지만 여기서는 직접 train data의 평균과 표준편차를 구한 후, 정규분포화를 수행한다.

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [11]:
batch_size = 128
validation_ratio = 0.2
random_seed = 10
initial_lr = 0.005
num_epoch = 30

In [12]:
import numpy as np
import torchvision
from torch.utils.data.sampler import SubsetRandomSampler

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(
    root='../datasets/CIFAR10_dataset', train=True, download=True, transform=transform)

Files already downloaded and verified


In [13]:
train_data_mean = trainset.data.mean( axis=(0,1,2) )
train_data_std = trainset.data.std( axis=(0,1,2) )

train_data_mean = train_data_mean / 255
train_data_std = train_data_std / 255

print(train_data_mean)
print(train_data_std)

[0.49139968 0.48215841 0.44653091]
[0.24703223 0.24348513 0.26158784]


In [14]:
transform_train = torchvision.transforms.Compose([
    # train data에는 image augmentation을 수행
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(train_data_mean, train_data_std)
])

transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(train_data_mean, train_data_std)
])

trainset = torchvision.datasets.CIFAR10(
    root='../datasets/CIFAR10_dataset', train=True, download=True, transform=transform_train)

validset = torchvision.datasets.CIFAR10(
    root='../datasets/CIFAR10_dataset', train=True, download=True, transform=transform_test)

testset = torchvision.datasets.CIFAR10(
    root='../datasets/CIFAR10_dataset', train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [15]:
num_train = len(trainset)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=train_sampler, num_workers=6
)

valid_loader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, sampler=valid_sampler, num_workers=6
)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=6
)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
num_classes = len(classes)

## 3. 학습

ResNet50을 사용해서 학습을 수행한다.

이번에는 앞의 VGG와 달리 val_loss를 비교하며 최상의 모델을 저장한 후, 이를 불러와 test에 사용한다.

In [16]:
resnet = resnet50(num_classes=10, zero_init_residual=True).to(device)

In [17]:
from torch import optim

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(resnet.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, verbose=True)

In [18]:
epochs = 30
min_val_loss = np.inf

for epoch in range(epochs):
    # Training
    running_loss = 0.0
    resnet.train() # training mode for dropout
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = resnet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # running_loss += loss.item()
        # # print every 30 mini-batches
        # if i % 30 == 29:
        #     print('[Train] Epoch %3d, Mini-batches %5d, Loss: %.7f' %
        #           (epoch+1, i+1, running_loss / 30))
        #     running_loss = 0.0
        running_loss += loss / len(train_loader)

    print('[Train] Epoch %3d/%d, Loss: %.7f' % (epoch+1, epochs, running_loss))
    
    # Validation
    correct = 0
    total = 0
    
    with torch.no_grad():
        val_loss = 0.0
        resnet.eval() # evaluation mode for dropout
        for data in valid_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = resnet(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss / len(valid_loader)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
        print('[Valid] Epoch %3d/%d, Loss: %.7f, Accuracy: %.1f%%' % 
              (epoch+1, epochs, val_loss, (100*correct / total))
            )
        
        # save best model
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(resnet.state_dict(), "../models/resnet50.pth")
        
    scheduler.step(val_loss)

[Train] Epoch   1/30, Loss: 2.5651298
[Valid] Epoch   1/30, Loss: 1.6872346, Accuracy: 37.5%
[Train] Epoch   2/30, Loss: 1.5723085
[Valid] Epoch   2/30, Loss: 1.4734730, Accuracy: 46.4%
[Train] Epoch   3/30, Loss: 1.3886424
[Valid] Epoch   3/30, Loss: 1.2665462, Accuracy: 54.6%
[Train] Epoch   4/30, Loss: 1.2443051
[Valid] Epoch   4/30, Loss: 1.1617709, Accuracy: 57.9%
[Train] Epoch   5/30, Loss: 1.1380402
[Valid] Epoch   5/30, Loss: 1.0614831, Accuracy: 62.2%
[Train] Epoch   6/30, Loss: 1.0548097
[Valid] Epoch   6/30, Loss: 1.0899091, Accuracy: 62.0%
[Train] Epoch   7/30, Loss: 0.9880305
[Valid] Epoch   7/30, Loss: 0.9734443, Accuracy: 65.8%
[Train] Epoch   8/30, Loss: 0.9506536
[Valid] Epoch   8/30, Loss: 0.9801546, Accuracy: 66.1%
[Train] Epoch   9/30, Loss: 0.9034039
[Valid] Epoch   9/30, Loss: 0.9201042, Accuracy: 68.6%
[Train] Epoch  10/30, Loss: 0.8751188
[Valid] Epoch  10/30, Loss: 0.8494295, Accuracy: 71.4%
[Train] Epoch  11/30, Loss: 0.8527765
[Valid] Epoch  11/30, Loss: 0.85

In [20]:
best_model = resnet50(num_classes=10, zero_init_residual=True).to(device)
best_model.load_state_dict(torch.load("../models/resnet50.pth"))

<All keys matched successfully>

In [22]:
correct = 0
total = 0

with torch.no_grad():
    resnet.eval() # evaluation mode for dropout
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = best_model(images)
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %.1f %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 73.7 %
