In [None]:
# colab
# Mount Google Drive into the Colab VM at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Unet - Chest CT Dataset

## 데이터 전처리

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision import models

import os
import numpy as np
import pandas as pd
import shutil
import cv2
import matplotlib.pyplot as plt

from IPython import display
from PIL import Image

### Kaggle 데이터셋 다운로드
- 데이터 : Chest CT Segmentation(Chest CT scans together with segmentation masks for lung, heart, and trachea)
- 캐글 데이터 주소: https://www.kaggle.com/datasets/polomarco/chest-ct-segmentation

In [None]:
# Install (or upgrade) the Kaggle command-line client used by the download cell further down.
!pip install kaggle --upgrade

In [None]:
# Open Colab's upload dialog; the following cells expect a file named kaggle.json to be uploaded here.
from google.colab import files
files.upload()

In [None]:
# Copy the uploaded kaggle.json into ~/.kaggle and restrict it to owner read/write (mode 600).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# NOTE(review): this cell repeats the previous cell verbatim; running it again just
# re-copies kaggle.json and re-applies the same permissions.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d polomarco/chest-ct-segmentation

In [None]:
!unzip -qq '/content/chestct/chest-ct-segmentation.zip'
%cd ..

In [None]:
# Load the dataset's index CSV and preview its first rows.
data_raw = pd.read_csv('/content/chestct/train.csv')
data_raw.head()

In [None]:
# (rows, columns) of the index table.
data_raw.shape

In [None]:
def get_id(x):
    """Return the part of ``x`` that precedes the first underscore.

    If ``x`` contains no underscore the whole string is returned.
    """
    prefix, _, _ = x.partition('_')
    return prefix

# Derive an 'id' column from each ImageId (the text before its first underscore),
# then preview the table with the new column.
data_raw['id'] = data_raw['ImageId'].apply(get_id)
data_raw.head()

In [None]:
# Distinct values of the 'id' column, and how many there are.
cli_ids = data_raw.id.unique()
len(cli_ids)

In [None]:
# Number of entries in the raw image folder, then in the raw mask folder.
print(len(os.listdir('/content/chestct/images/images')), len(os.listdir('/content/chestct/masks/masks')))

In [None]:
# Show every row that belongs to the last entry of cli_ids (index -1).
cli_id = -1
cli_data = data_raw[data_raw.id == cli_ids[cli_id]]
cli_data

In [None]:
def get_cli_data(data_raw, cli_id):
    """Return the id and the image/mask file names of one id group.

    Args:
        data_raw: DataFrame with the columns ``ImageId``, ``MaskId`` and ``id``.
        cli_id: Integer position into ``data_raw['id'].unique()``; negative
            values index from the end.

    Returns:
        A tuple ``(id_file, image_file, mask_file)``: the selected id value and
        two arrays holding the ``ImageId`` and ``MaskId`` values of its rows.

    Raises:
        IndexError: If ``cli_id`` is outside the range of unique ids.
        KeyError: If one of the expected columns is missing.
    """
    cli_ids = data_raw['id'].unique()
    cli_data = data_raw[data_raw['id'] == cli_ids[cli_id]]

    # Column names are case-sensitive: the previous ``imageId`` / ``Maskid``
    # attribute lookups do not match the CSV's ``ImageId`` / ``MaskId`` columns.
    image_file = cli_data['ImageId'].values
    mask_file = cli_data['MaskId'].values
    id_file = cli_data['id'].values[0]

    return id_file, image_file, mask_file

In [None]:
# Empty base directory: paths joined onto it resolve relative to the current working directory.
data_dir = ''

In [None]:
# Preview one id group: show every slice next to its binarised mask.
idx = 0
cli_id, image_files, mask_files = get_cli_data(data_raw, idx)

# Read from the nested raw folders the rest of the notebook uses; images and
# masks live in *different* folders (masks were previously read from 'images').
image_dir = '/content/chestct/images/images'
mask_dir = '/content/chestct/masks/masks'

# Canvas layout: [ 512 px image | 50 px blank gap | 512 px mask ].
canvas = np.zeros(shape=(512, 512*2+50, 3), dtype=np.uint8)

for i in range(len(image_files)):
    image = cv2.imread(os.path.join(image_dir, image_files[i]))
    mask = cv2.imread(os.path.join(mask_dir, mask_files[i]))
    if image is None or mask is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f'could not read {image_files[i]} / {mask_files[i]}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
    # Binarise: everything below 240 -> 0, the rest -> 255.
    # (The second comparison used to be `<= 240`, which also matched the zeros
    # just written and turned the whole mask white.)
    mask[mask < 240] = 0
    mask[mask >= 240] = 255

    canvas[:, :512, :] = image
    # The mask goes after the 50 px gap; the old slice started at 512 and was
    # 562 columns wide, which a 512-wide mask cannot fill.
    canvas[:, 512+50:, :] = mask

    # Render inline with matplotlib: cv2.imshow/cv2.waitKey need a desktop GUI
    # (not available in Colab), and the canvas is already RGB, which is what
    # plt.imshow expects. clear_output(wait=True) replaces the previous frame.
    display.clear_output(wait=True)
    plt.figure(figsize=(12, 6))
    plt.imshow(canvas)
    plt.title(f'{cli_id}  slice {i + 1}/{len(image_files)}')
    plt.axis('off')
    plt.show()

## 데이터셋 구축


- 이미지와 마스크 파일이 각각 하나의 폴더에 모여 있으므로, train, val, test 세트로 분리가 필요함

In [None]:
!mkdir data

In [None]:
# Enter ./data so that the relative mkdir calls in the next cells create their folders inside it.
%cd data

- 디렉토리 만들기

In [None]:
!mkdir 'train' 'val' 'test'

In [None]:
!mkdir 'train/images' 'train/masks' 'val/images' 'val/masks' 'test/images' 'test/masks'

- 'id'를 기준으로 분리 (train, val, test 비율: 0.8, 0.1, 0.1)
- 'id'는 총 112개

In [None]:
# Step back out to the parent directory, then show how many ids there are to split.
%cd ..
len(cli_ids)

In [None]:
# Split the ids into train / val / test
# The ids are taken in their existing order: the first 80% go to train, the next
# 10% to val, and whatever remains to test.
split_ratio = [0.8, 0.1, 0.1]
train_len = int(len(cli_ids) * split_ratio[0])
val_len = int(len(cli_ids) * split_ratio[1])
test_len = len(cli_ids) - train_len - val_len

print('{}, {}, {}'.format(train_len, val_len, test_len))

# Contiguous slices replace the per-index if/elif bucketing.
train_ids = list(cli_ids[:train_len])
val_ids = list(cli_ids[train_len:train_len + val_len])
test_ids = list(cli_ids[train_len + val_len:])

print('{}, {}, {}'.format(len(train_ids), len(val_ids), len(test_ids)))

- 위에서 나눈 'id'를 기준으로 이미지와 마스크 파일을 각 디렉토리에 복사하기

In [None]:
# Root folder that holds the train/val/test split.
data_dir = '/content/data'

In [None]:
# train/image
# Copy every raw image whose id is in the training split.
to_file_path = '/content/data/train/images/'
from_file_path = '/content/chestct/images/images/'

# Exact lookup of the id (text before the first '_') in a set. This replaces the
# nested loop over all ids, stops shadowing the builtin ``id``, and cannot match
# a file whose id merely starts with another id, as startswith() could.
# Assumes file names follow the '<id>_...' pattern used by ImageId.
split_ids = set(train_ids)
for file_name in os.listdir(from_file_path):
    if file_name.split('_')[0] in split_ids:
        shutil.copyfile(from_file_path + file_name, to_file_path + file_name)

In [None]:
# train/mask
# Copy every raw mask whose id is in the training split.
to_file_path = '/content/data/train/masks/'
from_file_path = '/content/chestct/masks/masks/'

# Exact lookup of the id (text before the first '_') in a set. This replaces the
# nested loop over all ids, stops shadowing the builtin ``id``, and cannot match
# a file whose id merely starts with another id, as startswith() could.
# Assumes mask file names follow the '<id>_...' pattern — TODO confirm.
split_ids = set(train_ids)
for file_name in os.listdir(from_file_path):
    if file_name.split('_')[0] in split_ids:
        shutil.copyfile(from_file_path + file_name, to_file_path + file_name)

In [None]:
# val/image
# Copy every raw image whose id is in the validation split.
to_file_path = '/content/data/val/images/'
from_file_path = '/content/chestct/images/images/'

# Exact lookup of the id (text before the first '_') in a set. This replaces the
# nested loop over all ids, stops shadowing the builtin ``id``, and cannot match
# a file whose id merely starts with another id, as startswith() could.
# Assumes file names follow the '<id>_...' pattern used by ImageId.
split_ids = set(val_ids)
for file_name in os.listdir(from_file_path):
    if file_name.split('_')[0] in split_ids:
        shutil.copyfile(from_file_path + file_name, to_file_path + file_name)

In [None]:
# val/mask
# Copy every raw mask whose id is in the validation split.
to_file_path = '/content/data/val/masks/'
from_file_path = '/content/chestct/masks/masks/'

# Exact lookup of the id (text before the first '_') in a set. This replaces the
# nested loop over all ids, stops shadowing the builtin ``id``, and cannot match
# a file whose id merely starts with another id, as startswith() could.
# Assumes mask file names follow the '<id>_...' pattern — TODO confirm.
split_ids = set(val_ids)
for file_name in os.listdir(from_file_path):
    if file_name.split('_')[0] in split_ids:
        shutil.copyfile(from_file_path + file_name, to_file_path + file_name)

In [None]:
# test/image
# Copy every raw image whose id is in the test split.
to_file_path = '/content/data/test/images/'
from_file_path = '/content/chestct/images/images/'

# Exact lookup of the id (text before the first '_') in a set. This replaces the
# nested loop over all ids, stops shadowing the builtin ``id``, and cannot match
# a file whose id merely starts with another id, as startswith() could.
# Assumes file names follow the '<id>_...' pattern used by ImageId.
split_ids = set(test_ids)
for file_name in os.listdir(from_file_path):
    if file_name.split('_')[0] in split_ids:
        shutil.copyfile(from_file_path + file_name, to_file_path + file_name)

In [None]:
# test/mask
# Copy every raw mask whose id is in the test split.
to_file_path = '/content/data/test/masks/'
from_file_path = '/content/chestct/masks/masks/'

# Exact lookup of the id (text before the first '_') in a set. This replaces the
# nested loop over all ids, stops shadowing the builtin ``id``, and cannot match
# a file whose id merely starts with another id, as startswith() could.
# Assumes mask file names follow the '<id>_...' pattern — TODO confirm.
split_ids = set(test_ids)
for file_name in os.listdir(from_file_path):
    if file_name.split('_')[0] in split_ids:
        shutil.copyfile(from_file_path + file_name, to_file_path + file_name)

- 잘 나누었는지 확인

In [None]:
# Check that the files were copied into each split: for every split, the image
# count and the mask count are printed side by side.
print('train image:',len(os.listdir('/content/data/train/images')), '\ttrain masks: ',len(os.listdir('/content/data/train/masks')))
print('valid image:',len(os.listdir('/content/data/val/images')), '\tvalid masks: ',len(os.listdir('/content/data/val/masks')))
print('test  image:',len(os.listdir('/content/data/test/images')), '\ttest  masks: ',len(os.listdir('/content/data/test/masks')))

In [None]:
# Root of the split folders; passed to build_dataloader() further down.
data_dir = '/content/data'

In [1]:
class MyDataset(torch.utils.data.Dataset):
    """Paired image / segmentation-mask dataset.

    Files are read from ``<data_dir>/<phase>/images`` and
    ``<data_dir>/<phase>/masks``. Each item is an image resized to 224x224 and
    a ``(224, 224)`` int64 label map in which 0 means "no mask channel set" and
    1..3 is one plus the index of the first mask channel (in ``cv2.imread``
    channel order) that is set.

    Args:
        data_dir: Root folder that contains one sub-folder per phase.
        phase: Name of the split sub-folder (the notebook uses 'train' / 'val').
        transform: Optional callable applied to the resized image only.
    """

    def __init__(self, data_dir, phase, transform=None):
        self.phase = phase
        self.images_dir = os.path.join(data_dir, phase, 'images')
        self.masks_dir = os.path.join(data_dir, phase, 'masks')
        # os.listdir() gives no ordering guarantee. Sort both listings so that
        # position i in image_files and mask_files stays aligned.
        # Assumes masks are named like their images apart from a fixed infix — TODO confirm.
        self.image_files = sorted(
            file_name for file_name in os.listdir(self.images_dir) if file_name.endswith('jpg')
        )
        self.mask_files = sorted(
            file_name for file_name in os.listdir(self.masks_dir) if file_name.endswith('jpg')
        )
        assert len(self.image_files) == len(self.mask_files), (
            f'{len(self.image_files)} images but {len(self.mask_files)} masks in phase {phase!r}'
        )

        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_path = os.path.join(self.images_dir, self.image_files[index])
        mask_path = os.path.join(self.masks_dir, self.mask_files[index])

        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path)
        if image is None:
            # cv2.imread returns None instead of raising on an unreadable file.
            raise FileNotFoundError(f'could not read image: {image_path}')
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')

        image = cv2.resize(image, dsize=(224, 224), interpolation=cv2.INTER_LINEAR)
        # Bilinear resizing yields in-between values at region borders; the
        # threshold below turns the mask back into a strict on/off map.
        mask = cv2.resize(mask, dsize=(224, 224), interpolation=cv2.INTER_LINEAR)

        # A channel counts as "set" where its value is at least 240.
        channels = (mask >= 240).astype(np.uint8)
        # Background is 1 exactly where none of the three channels is set.
        background = 1 - channels.max(axis=-1, keepdims=True)

        # Stack [background, ch0, ch1, ch2] and take the first set entry per
        # pixel, giving a class index in 0..3.
        stacked = np.concatenate([background, channels], axis=-1)
        label_map = np.argmax(stacked, axis=-1)

        if self.transform:
            image = self.transform(image)

        target = torch.from_numpy(label_map).long()

        return image, target

In [4]:
# Image preprocessing: convert to a tensor with ToTensor, then normalise each of
# the three channels with the mean/std constants below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [5]:
def collate_fn(batch):
    """Stack a list of ``(image, target)`` pairs into two batched tensors.

    Args:
        batch: Sequence of 2-tuples of tensors; all images must share one shape
            and all targets must share one shape.

    Returns:
        ``(images, targets)`` where each tensor has a new leading batch dimension.
    """
    image_list = [image for image, _ in batch]
    target_list = [target for _, target in batch]

    return torch.stack(image_list, dim=0), torch.stack(target_list, dim=0)

In [6]:
def build_dataloader(data_dir, batch_size=4):
    """Build the training and validation data loaders.

    Args:
        data_dir: Root folder that contains the 'train' and 'val' sub-folders.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A dict with the keys 'train' and 'val', each mapping to a DataLoader.
    """
    # Must be a dict: the loaders are stored and looked up by phase name.
    # (It used to be a list, so the first string-keyed assignment raised TypeError.)
    dataloaders = {}

    for phase in ('train', 'val'):
        dataset = MyDataset(data_dir=data_dir, phase=phase, transform=transform)
        # NOTE(review): neither loader shuffles, the training one included; this
        # is kept exactly as written. Confirm it is intentional.
        dataloaders[phase] = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

    return dataloaders

In [None]:
# Smoke test: fetch the first batch of each loader and print its tensor shapes.
dataloaders = build_dataloader(data_dir=data_dir)

for phase in ['train', 'val']:
    for images, targets in dataloaders[phase]:
        # The second field used to print ``mask.shape``; no ``mask`` is bound by
        # this loop, the batched masks are in ``targets``.
        print(f'Image shape :  {images.shape}\tmask shape :  {targets.shape}')

        break

## VGG16 Backbone을 활용하여 UNet 구현

In [26]:
def Conv_Layer(in_channels, out_channels, kernel_size=3, padding=1):
    """Return two stacked Conv2d -> BatchNorm2d -> ReLU stages as one ``nn.Sequential``.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both use the same ``kernel_size`` and ``padding``.
    """
    stages = []
    for stage_in in (in_channels, out_channels):
        stages.append(nn.Conv2d(in_channels=stage_in, out_channels=out_channels,
                                kernel_size=kernel_size, padding=padding))
        stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)

In [54]:
def UpConv_Layer(in_channels, out_channels):
    """Return a ConvTranspose2d(kernel 2, stride 2) -> BatchNorm2d -> ReLU block.

    With kernel size 2 and stride 2 the transposed convolution doubles the
    height and width of its input while mapping ``in_channels`` to ``out_channels``.
    """
    upsample = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
    normalise = nn.BatchNorm2d(out_channels)
    activate = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, normalise, activate)

In [55]:
# Instantiate VGG16-BN without pretrained weights to inspect its layer layout.
# ``None`` is the documented "no weights" value; ``False`` is only accepted
# through a deprecated legacy path that emits a warning.
vgg16 = models.vgg16_bn(weights=None)
vgg16



VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [56]:
class Encoder(nn.Module):
    """U-Net contracting path built from the VGG16-BN feature extractor.

    Args:
        weights: Forwarded to ``torchvision.models.vgg16_bn``. ``False`` is
            accepted for backward compatibility and treated as ``None``
            (no pretrained weights).

    ``forward`` returns ``(out, encode_features)``: the bottleneck tensor with
    1024 channels, and a list with the outputs of blocks 1-4 (64, 128, 256 and
    512 channels) that serve as skip connections.
    """

    def __init__(self, weights):
        super().__init__()

        # torchvision's documented "no weights" value is None; a bare False only
        # works through a deprecated legacy path that emits a warning.
        if weights is False:
            weights = None
        backbone = models.vgg16_bn(weights=weights).features

        # In vgg16_bn.features the MaxPool2d layers sit at indices 6, 13, 23, 33
        # and 43, and convolution blocks 3-5 each hold three conv/BN/ReLU
        # triplets. Every block after the first therefore starts with the
        # pooling layer that closes the previous one. The former slice bounds
        # (13:20, 20:27, 27:34) assumed two-convolution blocks: they put a
        # pooling layer in the middle of block 4 and never used the backbone's
        # last three convolutions. Output shapes are the same with either slicing.
        self.conv_block1 = nn.Sequential(*backbone[0:6])
        self.conv_block2 = nn.Sequential(*backbone[6:13])
        self.conv_block3 = nn.Sequential(*backbone[13:23])
        self.conv_block4 = nn.Sequential(*backbone[23:33])
        # The final 43rd-index pooling layer is left out; a pair of 1x1
        # convolutions widens the bottleneck from 512 to 1024 channels.
        self.conv_block5 = nn.Sequential(*backbone[33:43], Conv_Layer(512, 1024, kernel_size=1, padding=0))

    def forward(self, x):
        encode_features = []

        out = x
        for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            out = block(out)
            # Keep each block's output for the decoder's skip connections.
            encode_features.append(out)

        out = self.conv_block5(out)

        return out, encode_features

In [57]:
# Run a random 224x224 RGB batch through the encoder to inspect its outputs.
# weights=None is torchvision's documented way to request no pretrained weights.
encoder = Encoder(weights=None)
x = torch.randn(1, 3, 224, 224)
out, ft = encoder(x)

In [58]:
# Print the shape of every skip-connection feature map, then of the bottleneck output.
for i in ft:
    print(i.shape)
print(out.shape)

torch.Size([1, 64, 224, 224])
torch.Size([1, 128, 112, 112])
torch.Size([1, 256, 56, 56])
torch.Size([1, 512, 28, 28])
torch.Size([1, 1024, 14, 14])


In [59]:
class Decoder(nn.Module):
    """U-Net expanding path: four upsample -> concatenate-skip -> convolve stages.

    Each stage halves the channel count with ``UpConv_Layer``, concatenates the
    encoder feature map of the matching depth along the channel axis (which
    doubles the channels again) and fuses the result with ``Conv_Layer``.
    Channel progression: 1024 -> 512 -> 256 -> 128 -> 64.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1024 -> 512 channels.
        self.upconv_layer1 = UpConv_Layer(1024, 512)
        self.conv_block1 = Conv_Layer(1024, 512)

        # Stage 2: 512 -> 256 channels.
        self.upconv_layer2 = UpConv_Layer(512, 256)
        self.conv_block2 = Conv_Layer(512, 256)

        # Stage 3: 256 -> 128 channels.
        self.upconv_layer3 = UpConv_Layer(256, 128)
        self.conv_block3 = Conv_Layer(256, 128)

        # Stage 4: 128 -> 64 channels.
        self.upconv_layer4 = UpConv_Layer(128, 64)
        self.conv_block4 = Conv_Layer(128, 64)

    def forward(self, x, encode_features):
        stages = (
            (self.upconv_layer1, self.conv_block1),
            (self.upconv_layer2, self.conv_block2),
            (self.upconv_layer3, self.conv_block3),
            (self.upconv_layer4, self.conv_block4),
        )

        out = x
        # Stage k consumes encode_features[-k]: the deepest skip connection is used first.
        for depth, (upsample, fuse) in enumerate(stages, start=1):
            out = upsample(out)
            out = torch.cat([out, encode_features[-depth]], dim=1)
            out = fuse(out)

        return out

In [60]:
# Chain encoder and decoder on a random 224x224 RGB batch.
# weights=None is torchvision's documented way to request no pretrained weights.
encoder = Encoder(weights=None)
decoder = Decoder()
x = torch.rand(1, 3, 224, 224)
out, ftrs = encoder(x)
out = decoder(out, ftrs)

In [62]:
# Shape of the decoder output.
out.shape

torch.Size([1, 64, 224, 224])

In [66]:
class UNET(nn.Module):
    """Encoder + Decoder followed by a 1x1 convolution head.

    Args:
        num_classes: Number of output channels produced by the head.
        weights: Forwarded unchanged to ``Encoder``.

    ``forward`` returns the head's output: one channel per class for every
    spatial position of the decoder output.
    """

    def __init__(self, num_classes, weights):
        super().__init__()
        self.encoder = Encoder(weights=weights)
        self.decoder = Decoder()
        # 1x1 convolution mapping the decoder's 64 channels to class scores.
        self.head = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, padding=0)

    def forward(self, x):
        bottleneck, skips = self.encoder(x)
        decoded = self.decoder(bottleneck, skips)
        return self.head(decoded)

In [67]:
# End-to-end shape check: 4 output classes on a random 224x224 RGB batch.
# weights=None is torchvision's documented way to request no pretrained weights.
model = UNET(num_classes=4, weights=None)
x = torch.randn(1, 3, 224, 224)
out = model(x)
out.shape

torch.Size([1, 4, 224, 224])