## Unet

- Computer Vision Model
- 깊은 신경망에서 channel 수를 늘리고 해상도를 줄이는 과정에서 얻은 정보를 버리지 않고 활용하는 방법
- 이미지를 압축하고 다시 복원하는 부분에서 압축하면서 손실되는 공간 정보를 concat을 통해 활용하는 방법
- DownSample : 신경망에서 channel의 수를 늘리고 해상도를 줄이면서 특징을 추출하는 부분
- UpSample : 신경망에서 channel의 수를 줄이고, 해상도를 높이면서 이미지를 복원하는 부분

In [2]:
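# U-Net baseline network
## Encoder level = two (Conv 3*3 + BatchNorm + ReLU) blocks, then 2*2 max pooling
## Decoder level = 2*2 transposed conv, concat with the encoder feature map, then Conv 3*3 blocks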

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
from torchsummary import summary
from torch import optim
import torchvision
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class CBR2d(nn.Module):
    """Convolution -> batch normalization -> ReLU, applied as one unit.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added by the convolution.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True):
        super().__init__()

        # Kept as plain attributes so callers can inspect the block's width.
        self.in_channels = in_channels
        self.out_channels = out_channels

        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        normalization = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()

        # The container keeps the name ``CBR2d`` and the conv/norm/relu order,
        # so parameter names (``CBR2d.0.weight``, ...) stay the same.
        self.CBR2d = nn.Sequential(convolution, normalization, activation)

    def forward(self, x):
        """Run ``x`` through the convolution, batch norm and ReLU in order."""
        return self.CBR2d(x)



class Unet(nn.Module):
    """U-Net encoder/decoder assembled from CBR2d (Conv + BatchNorm + ReLU) blocks.

    The encoder widens the feature maps (64 -> 128 -> 256 -> 512 -> 1024
    channels) and halves the resolution with a 2x2 max pooling after each of
    the first four levels. The decoder mirrors it: every level upsamples with
    a stride-2 transposed convolution, concatenates the encoder feature map of
    the same level along the channel axis (skip connection) and narrows the
    channels again with CBR2d blocks. A final 1x1 convolution maps the 64
    decoder channels to the output channels.

    Args:
        in_channels: Number of channels of the input image. Defaults to 3,
            the value that used to be hard-coded.
        num_classes: Number of channels produced by the final 1x1
            convolution. Defaults to 1, the value that used to be hard-coded.

    Note:
        The input height and width should be divisible by 16. The four max
        poolings round down, while each transposed convolution exactly doubles
        the size, so other sizes make the skip-connection ``torch.cat`` fail
        on mismatching spatial dimensions.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()

        # Encoder level 1
        self.enc1_1 = CBR2d(in_channels=in_channels, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Encoder level 2
        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Encoder level 3
        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # Encoder level 4
        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck (level 5): widen to 1024 channels, then back to 512
        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        # Decoder level 4. The first block of every decoder level takes twice
        # the channels because the upsampled map is concatenated with the
        # encoder map of the same level.
        self.uppool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2, padding=0, bias=True)
        self.dec4_2 = CBR2d(in_channels=512*2, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        # Decoder level 3
        self.uppool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2, padding=0, bias=True)
        self.dec3_2 = CBR2d(in_channels=256*2, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        # Decoder level 2
        self.uppool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2, padding=0, bias=True)
        self.dec2_2 = CBR2d(in_channels=128*2, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        # Decoder level 1
        self.uppool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True)
        self.dec1_2 = CBR2d(in_channels=64*2, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        # 1x1 convolution head: 64 decoder channels -> num_classes channels
        self.fc = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        """Run the encoder, the bottleneck and the decoder with skip connections.

        Args:
            x: Batched image tensor whose dimension 1 holds ``in_channels``
                channels (the skip connections concatenate along ``dim=1``).

        Returns:
            Output of the final 1x1 convolution: ``num_classes`` channels at
            the spatial size of ``x`` (given a size divisible by 16).
        """
        # Encoder level 1: two conv blocks, then downsample
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        # Encoder level 2: two conv blocks, then downsample
        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        # Encoder level 3: two conv blocks, then downsample
        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        # Encoder level 4: two conv blocks, then downsample
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        # Bottleneck (level 5)
        enc5_1 = self.enc5_1(pool4)
        dec5_1 = self.dec5_1(enc5_1)

        # Decoder level 4: upsample, then attach the level-4 encoder output
        uppool4 = self.uppool4(dec5_1)
        cat4 = torch.cat((uppool4, enc4_2), dim=1)

        # The concat yields 512 + 512 = 1024 channels; dec4_2 brings them
        # back to 512 and dec4_1 narrows to 256 for the next level.
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        # Decoder level 3: same pattern
        uppool3 = self.uppool3(dec4_1)
        cat3 = torch.cat((uppool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        # Decoder level 2: same pattern
        uppool2 = self.uppool2(dec3_1)
        cat2 = torch.cat((uppool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        # Decoder level 1: same pattern
        uppool1 = self.uppool1(dec2_1)
        cat1 = torch.cat((uppool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        # Project the 64 decoder channels to the requested output channels
        return self.fc(dec1_1)


        

In [4]:
## Subclassing nn.Module makes every submodule and parameter assigned on the
## model object get tracked automatically, so parameters() and
## named_parameters() can reach all of them.

model = Unet()
print(f"Model structure: {model}\n\n")

# Walk every registered parameter and show its name, size and first two slices.
parameter_iter = model.named_parameters()
for name, param in parameter_iter:
    print(f'Layer: {name} | Size: {param.size()} | Values: {param[:2]} \n')

Model structure: Unet(
  (enc1_1): CBR2d(
    (CBR2d): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (enc1_2): CBR2d(
    (CBR2d): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2_1): CBR2d(
    (CBR2d): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (enc2_2): CBR2d(
    (CBR2d): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum

In [5]:
# The factory below reuses the name ``Unet``. The class therefore has to be
# captured before the name is rebound; otherwise ``Unet()`` inside the function
# would call the function itself and recurse until RecursionError. The
# isinstance guard keeps the captured class intact when this cell is run again
# (at that point ``Unet`` already refers to the function).
if isinstance(Unet, type):
    _UnetModel = Unet


def Unet(num_classes=1):
    """Build a U-Net whose final 1x1 convolution emits ``num_classes`` channels.

    Args:
        num_classes: Number of output channels of the prediction head.

    Returns:
        A freshly initialised instance of the U-Net model class.
    """
    model = _UnetModel()
    # Swap the head only when the requested width differs from what the model
    # class already provides, so the default call returns the model untouched.
    if model.fc.out_channels != num_classes:
        model.fc = nn.Conv2d(
            in_channels=model.fc.in_channels,
            out_channels=num_classes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
    return model

In [None]:

# Smoke test: push a random batch of four 3x224x224 inputs through the network
# and print the size of what comes out.
batch_shape = (4, 3, 224, 224)
x = torch.randn(*batch_shape)
model = Unet()
output = model(x)
print('output size:', output.shape)

# summary(model, input_size=(3, 224, 224))

: 