In [None]:
%ls

In [None]:
%cd /disk1/colonoscopy_dataset/cropped/ADC

# Data Analysis

Data는 총 4 가지가 있다. 
- ADC (ADeno-Carcinoma)
- HGD (High-Grade Dysplasia)
- LGD (Low-Grade Dysplasia)
- NOR (Normal)

데이터의 생김새는 다음과 같다. Raw image와 masking이 된(labeled) image가 있다.
- data: \[phase\]\_IMG_[patient #].jpg
- label: \[phase\]\_MASK_[patient #].jpg

데이터들은 original data에서 검은 부분을 최대한 지운, 즉 cropped 된 상태이다. 
이미지의 사이즈는 다 달라서 resize를 해주어야 한다.

가) ADC의 데이터는 patient의 image와 mask image로 각각 data와 label로 나누어 이해하면 된다. data의 수와 그에 따라 label이 있는지 확인해보자.

나) 3개씩 뽑아서 어떤 형태인지 확인해보자.

In [None]:
# Data visualize를 위한 imports
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
# GPU 한 개만 할당을 위함
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import numpy as np
import pandas as pd
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

pd.set_option('display.max.colwidth', 30)

# 각 이미지들이 속해 있는 경로
ADC_img_path = '/disk1/colonoscopy_dataset/cropped/ADC/'
HGD_img_path = '/disk1/colonoscopy_dataset/cropped/HGD/'
LGD_img_path = '/disk1/colonoscopy_dataset/cropped/LGD/'
NOR_img_path = '/disk1/colonoscopy_dataset/cropped/NOR/'

In [None]:
# 가) 전체 파일 수, ADC data 수와 labeled data 수 확인

ADC_file_list = os.listdir(ADC_img_path) # ADC 전체 파일 목록
HGD_file_list = os.listdir(HGD_img_path) # HGD 전체 파일 목록
LGD_file_list = os.listdir(LGD_img_path) # LGD 전체 파일 목록
NOR_file_list = os.listdir(NOR_img_path) # NOR 전체 파일 목록

ADC_data_list = [x for x in ADC_file_list if 'IMG' in x]        # ADC data 파일 목록
ADC_labeled_list = [x for x in ADC_file_list if 'MASK' in x]    # ADC mask 파일 목록

HGD_data_list = [x for x in HGD_file_list if 'IMG' in x]        # HGD data 파일 목록
HGD_labeled_list = [x for x in HGD_file_list if 'MASK' in x]    # HGD mask 파일 목록

LGD_data_list = [x for x in LGD_file_list if 'IMG' in x]        # LGD data 파일 목록
LGD_labeled_list = [x for x in LGD_file_list if 'MASK' in x]    # LGD mask 파일 목록

NOR_data_list = [x for x in NOR_file_list if 'IMG' in x]        # NOR data 파일 목록
NOR_labeled_list = [x for x in NOR_file_list if 'MASK' in x]    # NOR mask 파일 목록

totalNum_list = [len(ADC_file_list), len(HGD_file_list), len(LGD_file_list), len(NOR_file_list)]
totalData_list = [len(ADC_data_list), len(HGD_data_list), len(LGD_data_list), len(NOR_data_list)]
totalMask_list = [len(ADC_labeled_list), len(HGD_labeled_list), len(LGD_labeled_list), len(NOR_labeled_list)]

data = [totalNum_list, totalData_list, totalMask_list]
table = pd.DataFrame(data = data, index = ['Total #', 'data #', 'mask data #'], columns = ['ADC', 'HGD', 'LGD', 'NOR'])

print(table)

In [None]:
# 나) 3개씩 뽑아서 확인해보자.
print("[1] ADC [2] HGD [3] LGD [4] NOR") # 각 번호에 맞는 데이터를 보여준다.
#num = input()
num = '1'

path_dic = {"1": ADC_img_path, "2": HGD_img_path, "3": LGD_img_path, "4": NOR_img_path}
file_dic = {"1": ADC_data_list, "2": HGD_data_list, "3": LGD_data_list, "4": NOR_data_list}
title_dic = {"1": "ADC", "2": "HGD", "3": "LGD", "4": "NOR"}

rows = 3
columns = 2
plt.rcParams['figure.figsize'] = (8, 8)
random_file_list = (np.random.choice(file_dic[num], 3)).tolist()
match_file_list = [x.replace('IMG', 'MASK') for x in random_file_list]
# np.array(Image.open(os.path.join(self.data_dir, label_path)))
cnt = 0
for a, b in zip(random_file_list, match_file_list):
    # IMG
    cnt += 1
    img = Image.open(os.path.join(path_dic[num], a))
    #print(np.array(Image.open(os.path.join(path_dic[num], a)).resize((572,572))).min())
    ti = title_dic[num] + '_IMG_' + a[-9:-4]
    plt.subplot(rows, columns, cnt)
    plt.title(ti)
    plt.imshow(img)
    cnt += 1
    # MASK
    img = Image.open(os.path.join(path_dic[num], b))
    ti = title_dic[num] + '_IMG_' + b[-9:-4]
    #print(np.array(Image.open(os.path.join(path_dic[num], b)).resize((572,572))).shape)
    plt.subplot(rows, columns, cnt)
    plt.title(ti)
    plt.imshow(img)
plt.show()

# Dataset (train, valid) 준비

우선적으로 ADC에 대해서만 학습할 계획이다.

가) 갖고 있는 504장을 8:2으로 나누어서 dataloader를 구현하자. (train:valid)

In [None]:
import os
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# 데이터 로더를 구현하기
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None, type=None):
        self.data_dir = data_dir
        self.transform = transform
        self.type = type

        lst_data = os.listdir(self.data_dir)

        lst_label = [f for f in lst_data if 'MASK' in f]
        lst_input = [f for f in lst_data if 'IMG' in f]

        lst_label.sort()
        lst_input.sort()

        target_label = []
        target_input = []

        train_length = int(len(lst_input)*0.8)
        # train mode
        if type == 0:
            target_label = lst_label[:train_length]
            target_input = lst_input[:train_length]
        # test mode
        else:
            target_label = lst_label[train_length:]
            target_input = lst_input[train_length:]
        
        self.lst_label = target_label
        self.lst_input = target_input

    def __len__(self):
        return len(self.lst_label)

    def __getitem__(self, index):
        label = np.asarray(((Image.open(os.path.join(self.data_dir, self.lst_label[index]))).resize((256,256))))
        input = np.asarray((Image.open(os.path.join(self.data_dir, self.lst_input[index])).resize((256,256))))

        # 정규화
        label[label>150] = 255
        label[label<=150] = 0
        
        label = label/255.0
        input = input/255.0

        # 이미지와 레이블의 차원 = 2일 경우(채널이 없을 경우, 흑백 이미지), 새로운 채널(축) 생성
        if label.ndim == 2:
            label = label[:, :, np.newaxis]
        if input.ndim == 2:
            input = input[:, :, np.newaxis]

        data = {'input': input, 'label': label}

        # transform이 정의되어 있다면 transform을 거친 데이터를 불러옴
        if self.transform:
            data = self.transform(data)

        return data

In [None]:
original_data_path = '/disk1/colonoscopy_dataset/cropped/LGD'

# 데이터로더 잘 구현되었는지 확인
dataset_train = Dataset(data_dir=original_data_path, type=0)
data = dataset_train.__getitem__(0) # 한 이미지 불러오기
input = data['input']
label = data['label']

print(input.shape, label.shape)

In [None]:
# # 불러온 이미지 시각화
# plt.subplot(121)
# plt.imshow(label, cmap='gray')
# plt.title('label')

# plt.subplot(122)
# plt.imshow(input)
# plt.title('input')

# # 불러온 이미지 시각화
# plt.subplot(121)
# plt.hist(input.flatten(), bins=50)
# plt.title('input')

# plt.subplot(122)
# plt.hist(label.flatten(), bins=50)
# plt.title('label')

# plt.tight_layout()
# plt.show()

In [None]:
class ToTensor(object):
    def __call__(self, data):
        label, input = data['label'], data['input']

        label = label.transpose((2, 0, 1)).astype(np.float32)
        input = input.transpose((2, 0, 1)).astype(np.float32)

        data = {'label': torch.from_numpy(label), 'input': torch.from_numpy(input)}

        return data

class Normalization(object):
    def __init__(self, mean=0.5, std=0.5):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        label, input = data['label'], data['input']

        input = (input - self.mean) / self.std

        data = {'label': label, 'input': input}

        return data

class RandomFlip(object):
    def __call__(self, data):
        label, input = data['label'], data['input']

        if np.random.rand() > 0.5:
            label = np.fliplr(label)
            input = np.fliplr(input)

        if np.random.rand() > 0.5:
            label = np.flipud(label)
            input = np.flipud(input)

        data = {'label': label, 'input': input}

        return data

In [None]:
## 네트워크 저장하기
def save(ckpt_dir, model, optim, epoch):
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    torch.save({'model': model.state_dict(), 'optim': optim.state_dict()},
               "%s/model_epoch%d.pth" % (ckpt_dir, epoch))

## 네트워크 불러오기
def load(ckpt_dir, model, optim):
    if not os.path.exists(ckpt_dir):
        epoch = 0
        return model, optim, epoch

    ckpt_lst = os.listdir(ckpt_dir)
    ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    dict_model = torch.load('%s/%s' % (ckpt_dir, ckpt_lst[-1]))

    model.load_state_dict(dict_model['model'])
    optim.load_state_dict(dict_model['optim'])
    epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])

    return model, optim, epoch

In [None]:
# 경로 설정하기
base_dir = '/home/sundongk/Resnet50_Unet'
data_dir = original_data_path
ckpt_dir = os.path.join(base_dir, "checkpoint")
log_dir = os.path.join(base_dir, "log")

In [None]:
# 데이터셋 만들기
batch_size = 4

transform = transforms.Compose([
    ToTensor(),
    #RandomFlip() 
])

dataset_train = Dataset(data_dir=original_data_path, transform=transform, type=0)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)

transform = transforms.Compose([
    ToTensor()
])

dataset_val = Dataset(data_dir=original_data_path, transform=transform, type=1)
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=True, num_workers=8)

dataloaders = {
    'train': loader_train,
    'val': loader_val
}

print(loader_train)

# 네트워크 설계

- Resnet50을 backbone으로 하는 U-net

In [None]:
import torch
import torch.nn as nn
import torchvision

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x

class Bridge(nn.Module):    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)

class UpBlockForUNetWithResNet50(nn.Module):
    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels == None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels == None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x

class UNetWithResnet50Encoder(nn.Module):
    DEPTH = 6

    def __init__(self, n_classes=1):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))

        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        pre_pools = dict()
        pre_pools[f"layer_0"] = x
        x = self.input_block(x)
        pre_pools[f"layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)
        del pre_pools
        if with_output_feature_map:
            return x, output_feature_map
        else:
            return x

# 다양한 loss 함수 정의와 학습 함수

In [None]:
from collections import defaultdict
import torch.nn.functional as F
import time

checkpoint_path = os.path.join(base_dir, 'chk')
train_loss = []
val_loss = []

def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    

    intersection = (pred * target).sum(dim=2).sum(dim=2)
    
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    
    return loss.mean()

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss

def print_metrics(metrics, epoch_samples, phase):
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))

    print("{}: {}".format(phase, ", ".join(outputs)))

def train_model(model, optimizer, scheduler, st_epoch, num_epochs=25):
    best_loss = 1e10

    for epoch in range(st_epoch+1, num_epochs+1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for batch, data in enumerate(dataloaders[phase]):
                labels = data['label'].to(device)
                inputs = data['input'].to(device)
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            if phase == 'train':
                train_loss.append(epoch_loss)
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
            
            # save the model weights
            if phase == 'val':
                val_loss.append(epoch_loss)
                if epoch_loss < best_loss:
                    print(f"saving best model to {ckpt_dir}")
                    best_loss = epoch_loss
                    save(ckpt_dir=checkpoint_path, model=model, optim=optimizer, epoch=epoch)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        
    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    #model.load_state_dict(torch.load(checkpoint_path))
    model, optim, st_epoch = load(ckpt_dir=checkpoint_path, model=model, optim=optimizer)
    return model

## 학습 및 불러오기

In [16]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time

st_epoch = 0
num_class = 1
model = UNetWithResnet50Encoder().to(device)

optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)

In [None]:
model = train_model(model, optimizer_ft, exp_lr_scheduler, st_epoch, num_epochs=70)

In [17]:
model, optimizer_ft, st_epoch = load(ckpt_dir=checkpoint_path, model=model, optim=optimizer_ft)
print(st_epoch)

22


# 시각화 및 평가

In [None]:
# Loss 시각화
plt.figure(figsize=(15, 6)) 
plt.subplot(1,2,1)
plt.title('loss')
plt.xlabel('epoch')
plt.plot(train_loss, 'b', label='train loss')
plt.plot(val_loss, 'g', label='val loss')
plt.legend(loc='upper right')

plt.show()

In [18]:
# Visualization
from torchvision.transforms.functional import to_pil_image

transform = transforms.Compose([
    ToTensor()
])

batch_size = 1

dataset_train = Dataset(data_dir=original_data_path, transform=transform, type=0)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)

dataset_val = Dataset(data_dir=original_data_path, transform=transform, type=1)
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=8)

def changeTensor(tensor, labelmode=0, threshold=0.1):
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    if labelmode == 1:
        image[image>threshold] = 1
        image[image<=threshold] = 0 
    
    return image

In [19]:
# Pixel Accuracy / Intersection Over Union / Dice coefficient(=F1 score)
def get_PA_IOU_and_DICE(pred, true, TH):
    pred = changeTensor(pred, 1, TH)
    true = changeTensor(true, 1, TH)
    
    pred = pred.detach().cpu().numpy().reshape(-1)
    true = true.detach().cpu().numpy().reshape(-1)
    
    size = len(pred)^2
    pa = (size - np.logical_xor(true, pred).sum()) / size
    
    intersection = np.logical_and(true, pred)
    union = np.logical_or(true, pred)
    
    if np.sum(union) == 0:
        iou = -1
    else:
        iou = np.sum(intersection) / np.sum(union)

    intersection = (pred * true).sum()
    dice = (2.*intersection + 1) / (pred.sum() + true.sum() + 1)

    return pa, iou, dice

In [20]:
TH = [.1, .2, .3, .4, .5, .6, .7, .8, .9]
# Results(various metrics) for Training set
for th in TH:
    with torch.no_grad():
        model.eval()
        acc_PA = []
        acc_IOU = []
        acc_DICE = []
        
        for batch, data in enumerate(loader_val, 1):
            label = data['label'].to(device)
            input = data['input'].to(device)
            output = model(input)
        
            # Calculate PA & IOU & Dice coefficient(=F1 score)
            pa, iou, dice = get_PA_IOU_and_DICE(output, label, th)
            acc_PA.append(pa)
            acc_IOU.append(iou)
            acc_DICE.append(dice)

            # # 불러온 이미지 시각화
            # input_image_example = plt.subplot(1,3,1)
            # input_image_example.set_title('input Image Example')
            # plt.imshow(to_pil_image(changeTensor(input)))

            # label_image_example = plt.subplot(1,3,2)
            # label_image_example.set_title('Label Image Example')
            # plt.imshow(to_pil_image(changeTensor(label)), cmap='gray')

            # test_image_example = plt.subplot(1,3,3)
            # test_image_example.set_title('Test Image Example')
            # plt.imshow(to_pil_image(changeTensor(output, 1)), cmap='gray')

            # plt.show()
        print(np.round(np.mean(acc_PA),4))
        print(np.round(np.mean(acc_IOU),4))
        print(np.round(np.mean(acc_DICE), 4))
        print("-"*10)
            

0.9255
0.751
0.8469
----------
0.9255
0.7509
0.8469
----------
0.9255
0.7509
0.8469
----------
0.9255
0.7509
0.8469
----------
0.9255
0.7508
0.8469
----------
0.9255
0.7508
0.8468
----------
0.9255
0.7508
0.8468
----------
0.9255
0.7508
0.8468
----------
0.9255
0.7507
0.8468
----------


# 네트워크 파라미터 구성

In [None]:
from torchsummary import summary
summary(model, input_size=(3, 512, 512))