# FCN 논문 구현 및 Segmentation 실습

## Prep

### 라이브러리

In [1]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import torchvision
from torchvision import models
from torchvision import transforms

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### 데이터 확인

In [None]:
# glob() returns paths in arbitrary, filesystem-dependent order; sort them so
# that positional access (e.g. image_list[0]) is deterministic across runs.
image_list = sorted(glob('/kaggle/input/flood-area-segmentation/Image/*'))
mask_list = sorted(glob('/kaggle/input/flood-area-segmentation/Mask/*'))

# Cell output: number of image files and mask files found.
len(image_list), len(mask_list)

In [None]:
# Show the first four images on a 2x4 grid, each immediately followed by its
# mask (even axes hold images, odd axes hold masks).
fig, axes = plt.subplots(2, 4, figsize=(15, 8))
axes = axes.flatten()

for i, (img_ax, mask_ax) in enumerate(zip(axes[0::2], axes[1::2])):
    img_path = image_list[i]
    # Derive the mask path by swapping 'Image' -> 'Mask' and 'jpg' -> 'png'.
    mask_path = img_path.replace('Image', 'Mask').replace('jpg', 'png')

    panels = (
        (img_ax, img_path, f'Image {i+1}'),
        (mask_ax, mask_path, f'Mask {i+1}'),
    )
    for ax, path, title in panels:
        ax.imshow(Image.open(path))
        ax.set_title(title)
        ax.axis('off')

In [None]:
# Draw the first eight images with their masks blended on top at 50% opacity.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(15, 8))
axes = axes.ravel()
for i, ax in enumerate(axes):
    img_path = image_list[i]
    # Derive the mask path by swapping 'Image' -> 'Mask' and 'jpg' -> 'png'.
    mask_path = img_path.replace('Image', 'Mask').replace('jpg', 'png')
    img, mask = Image.open(img_path), Image.open(mask_path)
    ax.imshow(img)
    ax.imshow(mask, alpha=0.5)  # semi-transparent mask over the image
    ax.axis('off')

### 마스크 Binary화 필요

In [None]:
# Inspect the first mask to see which pixel values it actually contains.
if mask_list:
    # Image.open is lazy; np.array() forces the read while the file is open.
    with Image.open(mask_list[0]) as mask_file:
        mask = np.array(mask_file)
    print(mask)
    # Count values over the mask's pixels (not over the list of file paths).
    unique, counts = np.unique(mask, return_counts=True)
    print(unique)
    print(counts)

### 이미지 데이터 차원 통일

In [None]:
# Collect every image whose array is not 3-dimensional with 3 channels in the
# last axis, printing its shape, rank and path along the way.
exclude = []

for img_path in image_list:
    img = np.array(Image.open(img_path))
    has_three_channels = img.ndim == 3 and img.shape[2] == 3
    if has_three_channels:
        continue
    print(img.shape)
    print(img.ndim)
    print(img_path)
    exclude.append(img_path)

### 차원 불일치 데이터 제거

In [None]:
# Drop the flagged images together with their masks. Filtering against a set
# (rather than calling list.remove per item) keeps this cell idempotent:
# re-running it, or a flagged image whose mask file is absent, does not raise
# ValueError.
excluded_images = set(exclude)
excluded_masks = {
    path.replace('Image', 'Mask').replace('jpg', 'png')
    for path in excluded_images
}
# Slice-assign so the existing list objects are updated in place.
image_list[:] = [path for path in image_list if path not in excluded_images]
mask_list[:] = [path for path in mask_list if path not in excluded_masks]

### Dataset

In [None]:
class FloodDataset(Dataset):
    """Image / binary-mask pairs for flood-area segmentation.

    Args:
        image_list: Sequence of image file paths.
        mask_list: Sequence of mask file paths. It is stored on the instance
            but not used for lookup: the mask of a sample is located by
            rewriting the image path (see ``__getitem__``).
        transform: Callable applied to the image and to the mask separately.
            Its result for the mask must support ``> 0.5`` and ``.float()``
            (e.g. a tensor).
    """

    def __init__(self, image_list, mask_list, transform):
        self.image_list = image_list
        self.mask_list = mask_list
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load, transform and return the ``(image, mask)`` pair at ``idx``.

        The image is converted to 3-channel RGB and the mask to a single
        channel before the transform, so every sample has the same channel
        layout. The transformed mask is thresholded at 0.5 into {0., 1.}.
        """
        image_path = self.image_list[idx]
        # Derive the mask path by swapping 'Image' -> 'Mask' and 'jpg' -> 'png'.
        mask_path = image_path.replace('Image', 'Mask').replace('jpg', 'png')

        # Image.open is lazy and keeps the file open; convert() reads the
        # pixel data and returns a new image, so the handle can be closed
        # right away instead of leaking one per sample.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')

        image = self.transform(image)
        mask = self.transform(mask)

        # The transform may leave intermediate values (e.g. after resizing),
        # so snap the mask back to a hard 0/1 target.
        mask = (mask > 0.5).float()

        return image, mask

In [None]:
# Preprocessing pipeline: convert to a tensor first, then resize the tensor
# to 224x224.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ]
)

### DataLoader

In [None]:
# Wrap the file lists in a Dataset and report how many samples it holds.
flood_dataset = FloodDataset(image_list, mask_list, transform=transform)
print(len(flood_dataset))

# 80% of the samples go to training, the remainder to validation.
train_size = int(0.8 * len(flood_dataset))
val_size = len(flood_dataset) - train_size

# Seed the split so that, for a given ordering of image_list, the same samples
# land in train / validation on every run and metrics stay comparable.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(
    flood_dataset, [train_size, val_size], generator=split_generator
)
print(f"Train dataset size: {len(train_ds)} | Validation dataset size: {len(val_ds)}")

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False)

# Pull one training batch and print the tensor shapes as a sanity check.
imgs, masks = next(iter(train_dl))
print(imgs.shape, masks.shape)

## 학습 및 추론


### 모델 구현
아래 그림과 조건에 맞게 모델을 구현해봅니다.
![](https://velog.velcdn.com/images%2Fleejaejun%2Fpost%2Fc3b69a0d-2329-4903-9baa-de2949af87fb%2Fimage.png)

- VGG16 Backbone을 활용
    - VGG16 Backbone은 총 5개의 Conv block으로 구성
    - 이미지를 Backbone에 통과
    - 추가 Conv layer 2개를 통과
    - 마지막 Feature map Channel size = class num과 같아야 함
- Skip connection
    - 3, 4번 block을 통과한 Feature map을 재사용
    - Upsampling Feature map과 더하기 위하여 Channel size 조정 필요
    - 결합 방식: tensor_1 + tensor_2(더하기 연산)
- Upsampling
    - 두 개의 Upsampling 레이어를 통과
    - 총 세 번 Upsampling(2배 증가 2번, 8배 증가 1번)
    - 첫 번째 2배 증가 후 4번 Block과 결합
    - 두 번째 2배 증가 후 3번 Block과 결합

In [8]:
class FCN8s(nn.Module):
    """FCN-8s semantic-segmentation network built on a VGG16 backbone.

    The VGG16 feature extractor is cut into three stages that end at its
    third, fourth and fifth max-pool layers. The deepest features pass through
    two extra convolutions (``conv6``, ``conv7``), are scored by a 1x1
    convolution and are then upsampled 2x, 2x and 8x. The 1x1-scored outputs
    of stage 4 and stage 3 are added after the first and the second 2x
    upsampling respectively.

    Both 2x steps go through the same ``upconv2`` module and therefore share
    its weights.

    Input height and width are expected to be multiples of 32; otherwise the
    skip additions can fail on mismatched sizes or the output no longer has
    the input resolution.

    Args:
        n_classes: Number of channels in the output score map.
        pretrained: Forwarded to ``torchvision.models.vgg16``. The default
            ``True`` loads the pretrained backbone weights; ``False`` builds
            the backbone without loading them.
    """

    def __init__(self, n_classes, pretrained=True):
        super().__init__()
        vgg = models.vgg16(pretrained=pretrained)
        features = list(vgg.features.children())

        # Backbone stages, each ending with a max-pool layer.
        self.block3 = nn.Sequential(*features[:17])    # conv1 .. pool3
        self.block4 = nn.Sequential(*features[17:24])  # conv4 .. pool4
        self.block5 = nn.Sequential(*features[24:])    # conv5 .. pool5

        # Two extra convolutions on top of the backbone; kernel 7 with
        # padding 3 and kernel 1 both keep the spatial size.
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=7, padding=3)
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)

        # 1x1 convolutions that map each feature map to n_classes channels so
        # the skip connections can be added to the upsampled scores.
        self.conv1x1_pool3 = nn.Conv2d(256, n_classes, kernel_size=1)
        self.conv1x1_pool4 = nn.Conv2d(512, n_classes, kernel_size=1)
        self.conv1x1_output = nn.Conv2d(4096, n_classes, kernel_size=1)

        # Transposed convolutions for exact 2x and 8x upsampling.
        self.upconv2 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, stride=2, padding=1)
        self.upconv8 = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=8, stride=8, padding=0)

    def forward(self, x):
        """Return per-pixel class scores (no activation applied) for ``x``.

        Shape comments below are for a 224x224 input, written as (C, H, W).
        """
        # Encoder
        p3 = self.block3(x)           # (3, 224, 224) -> (256, 28, 28), H/8
        p4 = self.block4(p3)          # -> (512, 14, 14), H/16
        p5 = self.block5(p4)          # -> (512, 7, 7),   H/32
        p5 = F.relu(self.conv6(p5))   # -> (4096, 7, 7)
        p5 = F.relu(self.conv7(p5))   # -> (4096, 7, 7)

        # Decoder
        output = self.conv1x1_output(p5)                         # (n_classes, 7, 7)
        output = self.upconv2(output) + self.conv1x1_pool4(p4)   # 2x, fuse stage 4
        output = self.upconv2(output) + self.conv1x1_pool3(p3)   # 2x, fuse stage 3
        output = self.upconv8(output)                            # 8x, back to input size

        return output

In [9]:
# Create the single-class model and move it to the selected device.
model = FCN8s(n_classes=1)
model = model.to(device)

# Smoke test: push one random 3x224x224 input through the network and print
# the shape of the result.
input_image = torch.randn(1, 3, 224, 224)
input_image = input_image.to(device)

output = model(input_image)
print(output.shape)



torch.Size([1, 1, 224, 224])


### 학습 루프 설정

In [None]:
# Training configuration: epoch count, binary cross-entropy on logits, and an
# Adam optimizer over all model parameters.
epochs = 20
criterion = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(model.parameters(), lr=5e-4)

### 평가 지표

In [None]:
def IoU(output, mask):
    """Intersection over Union between a thresholded prediction and a mask.

    Steps:
        1. Binarize the prediction at a 0.5 threshold.
        2. Count the pixels where prediction and ground truth overlap.
        3. Count the pixels in the union of prediction and ground truth.
        4. Guard the division against an empty union.

    Args:
        output: Prediction tensor; values above 0.5 count as foreground.
            NOTE(review): presumably probabilities - raw logits should be
            passed through a sigmoid first; confirm at the call sites.
        mask: Ground-truth tensor, multiplied element-wise with the binarized
            prediction (assumed to hold 0/1 values).

    Returns:
        A 0-dim tensor ``intersection / (union + 1e-7)``; 0 when both the
        prediction and the mask are empty.
    """
    threshold = 0.5
    eps = 1e-7
    pred = (output > threshold).float()
    intersection = torch.sum(pred * mask)
    union = torch.sum(pred) + torch.sum(mask) - intersection

    # eps belongs in the denominator: added to the quotient it cannot prevent
    # 0 / 0, which yields NaN for an empty prediction on an empty mask.
    return intersection / (union + eps)
    
    
    
def PA(output, mask):
    """Pixel accuracy of a thresholded prediction against a mask.

    The prediction is binarized at 0.5, compared element-wise with ``mask``,
    and the number of matching pixels is divided by the total number of
    prediction elements.
    """
    predicted = output.gt(0.5).float()
    # Matching pixels cover both true positives and true negatives.
    n_correct = predicted.eq(mask).sum()
    return n_correct / predicted.numel()

### 학습 루프 함수

In [None]:
def train_and_validate(model, train_loader, val_loader, optim, criterion, epochs):
    """Train ``model`` and evaluate it on the validation set after each epoch.

    Args:
        model: Network to optimize; it is called as ``model(img)``.
        train_loader: Iterable of ``(img, mask)`` training batches.
        val_loader: Iterable of ``(img, mask)`` validation batches.
        optim: Optimizer already bound to the model's parameters.
        criterion: Loss called as ``criterion(output, mask)``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Six lists with one float per epoch, each the mean over that epoch's
        batches: ``(train_losses, train_IoUs, train_PAs, val_losses,
        val_IoUs, val_PAs)``.
    """
    train_losses, train_IoUs, train_PAs = [], [], []
    val_losses, val_IoUs, val_PAs = [], [], []

    # Run the batches on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device

    # Guard the averages against an empty loader.
    n_train = max(len(train_loader), 1)
    n_val = max(len(val_loader), 1)

    for epoch in range(epochs):
        train_loss = 0
        train_IoU = 0
        train_PA = 0
        model.train()
        for img, mask in tqdm(train_loader):
            img, mask = img.to(model_device), mask.to(model_device)

            optim.zero_grad()
            output = model(img)
            loss = criterion(output, mask)
            loss.backward()
            optim.step()

            # IoU / PA threshold at 0.5, so feed them probabilities.
            # Assumes the model emits raw logits (as a criterion such as
            # BCEWithLogitsLoss expects) - confirm for other setups.
            prob = torch.sigmoid(output.detach())
            train_loss += loss.item()
            train_IoU += float(IoU(prob, mask))
            train_PA += float(PA(prob, mask))

        train_losses.append(train_loss / n_train)
        train_IoUs.append(train_IoU / n_train)
        train_PAs.append(train_PA / n_train)

        val_loss = 0
        val_IoU = 0
        val_PA = 0
        model.eval()
        with torch.no_grad():
            for img, mask in tqdm(val_loader):
                img, mask = img.to(model_device), mask.to(model_device)

                output = model(img)
                prob = torch.sigmoid(output)
                val_loss += criterion(output, mask).item()
                val_IoU += float(IoU(prob, mask))
                val_PA += float(PA(prob, mask))

        val_losses.append(val_loss / n_val)
        val_IoUs.append(val_IoU / n_val)
        val_PAs.append(val_PA / n_val)

        print(
            f"Epoch {epoch + 1}/{epochs} | "
            f"train loss {train_losses[-1]:.4f}, IoU {train_IoUs[-1]:.4f}, PA {train_PAs[-1]:.4f} | "
            f"val loss {val_losses[-1]:.4f}, IoU {val_IoUs[-1]:.4f}, PA {val_PAs[-1]:.4f}"
        )

    return train_losses, train_IoUs, train_PAs, val_losses, val_IoUs, val_PAs


### 학습

In [None]:
# Run the training schedule, then unpack the per-epoch histories.
history = train_and_validate(model, train_dl, val_dl, optim, criterion, epochs)
train_losses, train_IoUs, train_PAs, val_losses, val_IoUs, val_PAs = history

### 로그 시각화

In [None]:
# Plot the train and validation curve of each logged quantity, one figure per
# quantity. The accuracy histories are the lists train_PAs / val_PAs (plural);
# the singular names train_PA / val_PA are not defined at notebook level.
curves = (
    (train_losses, val_losses, 'loss'),
    (train_IoUs, val_IoUs, 'IoU'),
    (train_PAs, val_PAs, 'Pixel Accuracy'),
)
for train_log, val_log, name in curves:
    plt.plot(train_log, label=f'train {name}')
    plt.plot(val_log, label=f'val {name}')
    plt.legend()
    plt.show()

In [None]:
def plot_batch(model, data_loader):
    """Plot image, ground-truth mask and predicted mask for one batch.

    Only the first batch of ``data_loader`` is used; one figure with three
    panels is shown per sample in that batch. Nothing is plotted when the
    loader yields no batch.

    Args:
        model: Segmentation network called as ``model(img)``. Its output is
            treated as raw logits: a sigmoid is applied before the 0.5
            threshold. NOTE(review): confirm the model has no final sigmoid.
        data_loader: Iterable of ``(img, mask)`` batches laid out as
            (batch, channel, height, width).
    """
    model.eval()

    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return
    img, mask = first_batch

    # Run the batch on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(img.to(model_device))
        # Thresholding logits directly at 0.5 would correspond to a
        # probability cut-off of about 0.62; convert to probabilities first.
        output = (torch.sigmoid(logits) > 0.5).float()

    # Channels-last layout for matplotlib.
    img = img.cpu().numpy().transpose(0, 2, 3, 1)
    mask = mask.cpu().numpy().transpose(0, 2, 3, 1)
    output = output.cpu().numpy().transpose(0, 2, 3, 1)

    # Iterate over the samples actually present: the batch can be smaller
    # than data_loader.batch_size.
    for i in range(len(img)):
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        ax[0].imshow(img[i])
        ax[0].set_title('image')
        ax[1].imshow(mask[i], cmap='gray')
        ax[1].set_title('mask')
        ax[2].imshow(output[i], cmap='gray')
        ax[2].set_title('predicted mask')
        plt.show()

### 결과 비교

In [None]:
# Visualize the model's predictions on the first validation batch.
plot_batch(model=model, data_loader=val_dl)