<a href="https://colab.research.google.com/github/yjyjy131/Study_Deep_Learning/blob/main/Pytorch_Basic/Chapter6_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
class ResidualBlock(nn.Module):
  """Basic two-convolution residual block (as used by ResNet-18/34).

  Computes ``relu(conv_block(x) + shortcut(x))`` where ``conv_block`` is
  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN, and ``shortcut`` is the
  identity, or a strided 1x1 conv + BN projection whenever the stride or
  the channel count changes (so both branches have the same shape).

  Args:
    in_channels: Number of channels of the input feature map.
    out_channels: Number of channels produced by the block.
    stride: Stride of the first 3x3 conv and of the projection shortcut.
  """

  def __init__(self, in_channels, out_channels, stride=1):
    super(ResidualBlock, self).__init__()
    self.stride = stride
    self.in_channels = in_channels
    self.out_channels = out_channels

    self.conv_block = nn.Sequential(
        nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(self.out_channels),
        nn.ReLU(),
        nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(self.out_channels)
    )

    # The shortcut must produce the same shape as conv_block's output:
    # project with a strided 1x1 conv when resolution or width changes,
    # otherwise pass the input through unchanged (Identity has no params,
    # so the state_dict is the same as before).
    if self.stride != 1 or self.in_channels != self.out_channels:
      self.downsample = nn.Sequential(
          nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(self.out_channels)
      )
    else:
      self.downsample = nn.Identity()

  def forward(self, x):
    """Apply the residual block to a 4-D ``(N, C, H, W)`` tensor ``x``."""
    out = self.conv_block(x)
    shortcut = self.downsample(x)
    return F.relu(shortcut + out)

In [None]:
class ResNet(nn.Module):
  """ResNet classifier assembled from ``ResidualBlock`` stages.

  A single stride-1 3x3 conv stem (no max-pool) is followed by four stages
  of 64/128/256/512 channels; stages 2-4 halve the spatial resolution.
  Global average pooling and a linear layer produce the logits. The stem
  expects 3-channel input; the original fixed 4x4 pooling suggests it was
  designed for 32x32 images (e.g. CIFAR-10), but global pooling here
  accepts other input sizes too.

  Args:
    num_blocks: Four ints, the number of residual blocks in each stage
      (e.g. ``[2, 2, 2, 2]`` for ResNet-18).
    num_classes: Size of the output logit vector.
  """

  def __init__(self, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    # Running channel count; _make_layer advances it stage by stage.
    self.in_channels = 64
    self.base = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU()
    )
    self.layer1 = self._make_layer(64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(512, num_blocks[3], stride=2)
    # Global average pooling to 1x1. Identical to AvgPool2d(4) for 32x32
    # input (4x4 final map) but does not hard-code the input resolution.
    self.gap = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Linear(512, num_classes)

  def _make_layer(self, out_channels, num_blocks, stride):
    """Build one stage of residual blocks.

    Only the first block uses ``stride`` (and may change the width); the
    remaining blocks use stride 1. Note that ``num_blocks <= 0`` still
    yields a single block. Updates ``self.in_channels`` so the next stage
    is wired to this stage's output width.
    """
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for block_stride in strides:
      layers.append(ResidualBlock(self.in_channels, out_channels, block_stride))
      self.in_channels = out_channels
    return nn.Sequential(*layers)

  def forward(self, x):
    """Return class logits of shape ``(N, num_classes)`` for input ``x``."""
    out = self.base(x)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = self.gap(out)
    # (N, 512, 1, 1) -> (N, 512) for the linear classifier.
    out = torch.flatten(out, 1)
    return self.fc(out)

  @staticmethod
  def modeltype(model, num_classes=10):
    """Construct an untrained ResNet by architecture name.

    Args:
      model: ``'resnet18'`` or ``'resnet34'``.
      num_classes: Size of the output logit vector.

    Raises:
      ValueError: if ``model`` is not a supported name.
    """
    configs = {
        'resnet18': [2, 2, 2, 2],
        'resnet34': [3, 4, 6, 3],
    }
    if model not in configs:
      raise ValueError(
          f"unknown model {model!r}; expected one of {sorted(configs)}")
    return ResNet(configs[model], num_classes=num_classes)