In [1]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import os
import torchvision
import torchvision.transforms as transforms

## 모델 쌓기

In [24]:
class UNet(torch.nn.Module):
    """U-Net segmentation network (Ronneberger et al., 2015) with unpadded convolutions.

    Every 3x3 convolution is unpadded, so each one trims 2 pixels from height and
    width; with the defaults a 1x572x572 input produces a 2x388x388 output.
    Unlike the paper, each convolution is followed by BatchNorm2d (a common
    implementation choice).

    Args:
        in_channels: Channels of the input image. Default 1 (grayscale), the value
            that was previously hard-coded.
        num_classes: Channels produced by the final 1x1 convolution. Default 2,
            the value that was previously hard-coded.
        base_channels: Width of the first stage; stage ``k`` (0-based) uses
            ``base_channels * 2**k`` channels. Default 64, as in the paper.

    Attribute names (and therefore ``state_dict`` keys) are the same as in the
    original implementation. The HxWxC shape comments assume the defaults and a
    572x572 input.
    """

    @staticmethod
    def _step(input_channel, output_channel, kernel_size=3, stride=1):
        """Return Conv2d (no padding) -> BatchNorm2d -> ReLU as one nn.Sequential."""
        return nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, stride=stride),
            nn.BatchNorm2d(num_features=output_channel),
            nn.ReLU())

    @staticmethod
    def _center_crop(skip: Tensor, target: Tensor) -> Tensor:
        """Center-crop ``skip`` on its last two (spatial) dims to ``target``'s size.

        Offsets use ``int(round(diff / 2.0))``, the same rule as torchvision's
        ``CenterCrop``, so the crop window is identical to the previous
        implementation even when the size difference is odd.

        Raises:
            ValueError: if ``skip`` is spatially smaller than ``target``.
        """
        height, width = skip.shape[-2:]
        target_h, target_w = target.shape[-2:]
        if target_h > height or target_w > width:
            raise ValueError(
                f"skip map {tuple(skip.shape[-2:])} is smaller than target "
                f"{tuple(target.shape[-2:])}; cannot center-crop")
        top = int(round((height - target_h) / 2.0))
        left = int(round((width - target_w) / 2.0))
        return skip[..., top:top + target_h, left:left + target_w]

    def __init__(self, in_channels: int = 1, num_classes: int = 2, base_channels: int = 64):
        super().__init__()
        # Stage widths; with the default base_channels=64 these are 64, 128, 256, 512, 1024.
        c1, c2, c3, c4, c5 = (base_channels * 2 ** k for k in range(5))

        # ---------------- Contracting path ----------------
        # Each stage: two conv-BN-ReLU steps that increase the channel count,
        # then a 2x2 / stride-2 max-pool that halves the resolution.
        self.conv1 = nn.Sequential(
            self._step(in_channels, c1),  # 572x572x1 => 570x570x64
            self._step(c1, c1))           # 570x570x64 => 568x568x64
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 568x568x64 => 284x284x64

        self.conv2 = nn.Sequential(
            self._step(c1, c2),  # 284x284x64 => 282x282x128
            self._step(c2, c2))  # 282x282x128 => 280x280x128
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 280x280x128 => 140x140x128

        self.conv3 = nn.Sequential(
            self._step(c2, c3),  # 140x140x128 => 138x138x256
            self._step(c3, c3))  # 138x138x256 => 136x136x256
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 136x136x256 => 68x68x256

        self.conv4 = nn.Sequential(
            self._step(c3, c4),  # 68x68x256 => 66x66x512
            self._step(c4, c4),  # 66x66x512 => 64x64x512
            nn.Dropout(p=0.5))   # regularization: better generalization / robustness to noise
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x64x512 => 32x32x512

        # ---------------- Bottleneck ----------------
        # Transition between the paths; widens the channels 512 -> 1024.
        self.bottleNeck = nn.Sequential(
            self._step(c4, c5),  # 32x32x512 => 30x30x1024
            self._step(c5, c5))  # 30x30x1024 => 28x28x1024

        # ---------------- Expanding path ----------------
        # Each stage: a 2x2 / stride-2 transposed conv doubles the resolution and
        # halves the channels; forward() then concatenates the center-cropped skip
        # map of the same level (doubling the channels back), and two conv-BN-ReLU
        # steps reduce them again.
        # 28x28x1024 => 56x56x512
        self.upconv1 = nn.ConvTranspose2d(in_channels=c5, out_channels=c4, kernel_size=2, stride=2)
        self.ex_conv1 = nn.Sequential(
            self._step(c5, c4),  # 56x56x1024 => 54x54x512
            self._step(c4, c4))  # 54x54x512 => 52x52x512

        # 52x52x512 => 104x104x256
        self.upconv2 = nn.ConvTranspose2d(in_channels=c4, out_channels=c3, kernel_size=2, stride=2)
        self.ex_conv2 = nn.Sequential(
            self._step(c4, c3),  # 104x104x512 => 102x102x256
            self._step(c3, c3))  # 102x102x256 => 100x100x256

        # 100x100x256 => 200x200x128
        self.upconv3 = nn.ConvTranspose2d(in_channels=c3, out_channels=c2, kernel_size=2, stride=2)
        self.ex_conv3 = nn.Sequential(
            self._step(c3, c2),  # 200x200x256 => 198x198x128
            self._step(c2, c2))  # 198x198x128 => 196x196x128

        # 196x196x128 => 392x392x64
        self.upconv4 = nn.ConvTranspose2d(in_channels=c2, out_channels=c1, kernel_size=2, stride=2)
        self.ex_conv4 = nn.Sequential(
            self._step(c2, c1),  # 392x392x128 => 390x390x64
            self._step(c1, c1))  # 390x390x64 => 388x388x64

        # 1x1 convolution mapping features to per-pixel class scores
        # (kept under the name `fc` for state_dict compatibility).
        self.fc = nn.Conv2d(c1, num_classes, kernel_size=1, stride=1)  # 388x388x64 => 388x388x2

    def forward(self, x: Tensor) -> Tensor:
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Returns a tensor of shape (N, num_classes, H', W') with H' < H and W' < W
        because the convolutions are unpadded.
        """
        # Contracting path: keep each stage's pre-pool output for the skip connections.
        layer1 = self.conv1(x)
        layer2 = self.conv2(self.pool1(layer1))
        layer3 = self.conv3(self.pool2(layer2))
        layer4 = self.conv4(self.pool3(layer3))

        # Bottleneck: 32x32x512 => 28x28x1024
        out = self.bottleNeck(self.pool4(layer4))

        # Expanding path: upsample, center-crop the same-level contracting feature map to
        # the upsampled size, concatenate along channels (skip map first), then convolve.
        stages = (
            (self.upconv1, self.ex_conv1, layer4),
            (self.upconv2, self.ex_conv2, layer3),
            (self.upconv3, self.ex_conv3, layer2),
            (self.upconv4, self.ex_conv4, layer1),
        )
        for upconv, ex_conv, skip in stages:
            up = upconv(out)
            out = ex_conv(torch.cat((self._center_crop(skip, up), up), dim=1))

        return self.fc(out)