# Data 분석
데이터는 총 4가지 클래스로 구성되어 있다.
- ADC (Adenocarcinoma)
- HGD (High-Grade Dysplasia)
- LGD (Low-Grade Dysplasia)
- NOR (Normal)

데이터의 형태는 다음과 같다. 원본(raw) image와 masking이 된(labeled) image가 있으며, 파일 이름의 \[phase\]는 클래스 이름(ADC / HGD / LGD / NOR)이다.
- data: <font color=yellow>\[phase\]\_IMG\_\[patient #\].jpg</font>
- label: <font color=yellow>\[phase\]\_MASK\_\[patient #\].jpg</font>

데이터들은 original data에서 검은 부분을 최대한 지운, 즉 cropped 된 상태이다. 
이미지의 사이즈는 다 달라서 resize를 해주어야 한다.

In [None]:
# Data visualize를 위한 imports
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
# GPU 한 개만 할당을 위함
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from configparser import Interpolation
from saveLoad import save, load
from torchvision.transforms.functional import to_pil_image

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', device)
# torch.cuda.current_device() raises when no CUDA device is usable, so the
# CUDA-specific line is only printed when the GPU path was actually selected.
if device.type == "cuda":
    print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

pd.set_option('display.max.colwidth', 30)

# Path configuration
original_data_path = '/disk1/colonoscopy_dataset/cropped/' # ADC / HGD / LGD / NOR
base_dir = '/home/sundongk/Multiclass_segmentation_2'
data_dir = original_data_path
ckpt_dir = os.path.join(base_dir, "checkpoint")

# Directory of each class, derived from the dataset root so the root only has
# to be changed in one place. The trailing '' keeps the trailing slash.
ADC_img_path = os.path.join(original_data_path, 'ADC', '')
HGD_img_path = os.path.join(original_data_path, 'HGD', '')
LGD_img_path = os.path.join(original_data_path, 'LGD', '')
NOR_img_path = os.path.join(original_data_path, 'NOR', '')

In [None]:
# (a) Count all files, raw images and mask images of every class


def _images_and_masks(file_names):
    """Split a directory listing into sorted raw-image and sorted mask names."""
    images = sorted(name for name in file_names if 'IMG' in name)
    masks = sorted(name for name in file_names if 'MASK' in name)
    return images, masks


# Complete directory listing of every class
ADC_file_list = os.listdir(ADC_img_path)
HGD_file_list = os.listdir(HGD_img_path)
LGD_file_list = os.listdir(LGD_img_path)
NOR_file_list = os.listdir(NOR_img_path)

# Raw-image ("IMG") and mask ("MASK") file names of every class, sorted by name
ADC_data_list, ADC_labeled_list = _images_and_masks(ADC_file_list)
HGD_data_list, HGD_labeled_list = _images_and_masks(HGD_file_list)
LGD_data_list, LGD_labeled_list = _images_and_masks(LGD_file_list)
NOR_data_list, NOR_labeled_list = _images_and_masks(NOR_file_list)

total_img_name_list = [ADC_data_list, HGD_data_list, LGD_data_list, NOR_data_list]
total_label_name_list = [ADC_labeled_list, HGD_labeled_list, LGD_labeled_list, NOR_labeled_list]

# Per-class counts, in the column order ADC, HGD, LGD, NOR
totalNum_list = [len(files) for files in (ADC_file_list, HGD_file_list, LGD_file_list, NOR_file_list)]
totalData_list = [len(names) for names in total_img_name_list]
totalMask_list = [len(names) for names in total_label_name_list]

data = [totalData_list, totalMask_list, totalNum_list]
table = pd.DataFrame(
    data=data,
    index=['data #', 'mask data #', 'Total #'],
    columns=['ADC', 'HGD', 'LGD', 'NOR'],
)

print(table)

# Dataset 만들기

- Train, validation, test를 각각 8 : 1 : 1의 비율로 나누어 만든다. 
  - 총 데이터: 3004장 = (2403, 300, 301)
  - Train, validation의 batch size는 4로 하고, test의 batch size는 1로 한다.

- 전처리 사항
  - 256×256으로 resize + bilinear interpolation
  - random affine (shear=10, scale=(0.8, 1.2))
  - random horizontal flip()
  - 전체 데이터의 평균과 표준편차로 normalize (mean=[0.590, 0.351, 0.259], std=[0.241, 0.199, 0.163])

In [None]:
# Train : validation : test = 8 : 1 : 1, computed separately for every class


def _split_lengths(total):
    """Return (train, valid, test) sizes: 80 %, 10 % and the remainder of `total`."""
    n_train = int(total*0.8)
    n_valid = int(total*0.1)
    return n_train, n_valid, total - n_train - n_valid


# Trailing numbers are the author's recorded sizes (train / valid / test);
# recorded totals over all classes: 2403 / 300 / 301.
ADC_train_length, ADC_valid_length, ADC_test_length = _split_lengths(len(ADC_data_list))  # 403 / 50 / 51
HGD_train_length, HGD_valid_length, HGD_test_length = _split_lengths(len(HGD_data_list))  # 400 / 50 / 50
LGD_train_length, LGD_valid_length, LGD_test_length = _split_lengths(len(LGD_data_list))  # 800 / 100 / 100
NOR_train_length, NOR_valid_length, NOR_test_length = _split_lengths(len(NOR_data_list))  # 800 / 100 / 100

In [None]:
# Dataset implementation consumed by the DataLoaders.
# NOTE: the class deliberately keeps its original name, which rebinds the
# imported torch ``Dataset`` base class after this definition.
class Dataset(Dataset):
    """Image / binary-mask pairs of one split (train, validation or test).

    File names are taken from the module-level ``*_data_list`` /
    ``*_labeled_list`` lists and cut with the module-level ``*_train_length``
    / ``*_valid_length`` values, class by class in the order ADC, HGD, LGD,
    NOR.

    Args:
        data_dir: Root directory containing the ``ADC``, ``HGD``, ``LGD`` and
            ``NOR`` sub-directories.
        transform: Optional callable applied to the RGB input image only.
        type: ``0`` selects the train split, ``1`` the validation split and
            any other value (including ``None``) the test split.

    NOTE(review): ``transform`` never sees the mask. The train pipeline in
    this file contains RandomAffine and RandomHorizontalFlip, so an augmented
    image and its un-augmented mask can be spatially misaligned. Fixing that
    needs a joint image/mask transform - TODO confirm and address.
    """

    def __init__(self, data_dir, transform=None, type=None):
        self.data_dir = data_dir
        self.transform = transform
        self.type = type

        # (image names, mask names, train size, validation size) per class
        per_class = (
            (ADC_data_list, ADC_labeled_list, ADC_train_length, ADC_valid_length),
            (HGD_data_list, HGD_labeled_list, HGD_train_length, HGD_valid_length),
            (LGD_data_list, LGD_labeled_list, LGD_train_length, LGD_valid_length),
            (NOR_data_list, NOR_labeled_list, NOR_train_length, NOR_valid_length),
        )

        target_input = []
        target_label = []
        for images, masks, n_train, n_valid in per_class:
            if type == 0:        # train mode
                window = slice(None, n_train)
            elif type == 1:      # validation mode
                window = slice(n_train, n_train + n_valid)
            else:                # test mode
                window = slice(n_train + n_valid, None)
            target_input += images[window]
            target_label += masks[window]

        self.lst_label = target_label
        self.lst_input = target_input

    def __len__(self):
        """Return the number of image/mask pairs in this split."""
        return len(self.lst_label)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with ``'input'`` (the RGB image after ``transform``, or the
            PIL image when no transform is set) and ``'label'`` (float32
            tensor of shape (1, 256, 256) holding 0.0 / 1.0).
        """
        input_name = self.lst_input[index]

        # The class sub-directory is inferred from the class tag in the name.
        datapath = ''
        for class_name in ('ADC', 'HGD', 'LGD', 'NOR'):
            if class_name in input_name:
                datapath = os.path.join(self.data_dir, class_name)
                break

        # convert()/resize() return fully loaded copies, so the underlying
        # files can be closed as soon as the with-blocks end.
        with Image.open(os.path.join(datapath, input_name)) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(os.path.join(datapath, self.lst_label[index])) as raw_mask:
            # The mask size is fixed here and must match the Resize used in
            # ``transform`` for the input image.
            mask = raw_mask.convert('L').resize((256, 256))

        # Binarise: grey values above 150 become 1.0, everything else 0.0.
        # np.array() makes a fresh array; the original thresholded the result
        # of np.asarray() in place, which fails with "assignment destination
        # is read-only" when that array is not writable.
        label = (np.array(mask) > 150).astype(np.float32)
        label = label[np.newaxis, :, :]  # (H, W) -> (1, H, W)

        if self.transform is not None:
            image = self.transform(image)

        data = {'input': image, 'label': torch.from_numpy(label)}

        return data

In [None]:
batch_size = 4
test_batch_size = 1


def _dataset_normalize():
    """Normalize with the mean / std measured over the whole dataset."""
    return transforms.Normalize(mean=[0.590, 0.351, 0.259],
                                std=[0.241, 0.199, 0.163])


# Pre-processing per split; only the train split is randomly augmented.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((256,256), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _dataset_normalize(),
    ]),
    # The validation Resize relies on torchvision's default interpolation.
    'validation': transforms.Compose([
        transforms.Resize((256,256)),
        transforms.ToTensor(),
        _dataset_normalize(),
    ]),
    'test': transforms.Compose([
        transforms.Resize((256,256), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        _dataset_normalize(),
    ]),
}

# Dataset `type` codes: 0 = train, 1 = validation, 2 = test
image_datasets = {
    split: Dataset(data_dir=original_data_path, transform=data_transforms[split], type=mode)
    for mode, split in enumerate(('train', 'validation', 'test'))
}

# Only the train loader shuffles; the test loader yields one sample per batch.
dataloaders = {
    'train': DataLoader(image_datasets['train'], batch_size=batch_size,
                        shuffle=True, num_workers=8),
    'validation': DataLoader(image_datasets['validation'], batch_size=batch_size,
                             shuffle=False, num_workers=8),
    'test': DataLoader(image_datasets['test'], batch_size=test_batch_size,
                       shuffle=False, num_workers=8),
}

# 네트워크 설계

- Resnet50을 backbone으로 하는 U-net

In [None]:
import torch
import torch.nn as nn
import torchvision

class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and, optionally, a ReLU.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by the convolution.
        padding, kernel_size, stride: Passed straight to ``nn.Conv2d``.
        with_nonlinearity: When False the ReLU is skipped in ``forward``.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.with_nonlinearity = with_nonlinearity
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return bn(conv(x)), passed through the ReLU unless it is disabled."""
        normalized = self.bn(self.conv(x))
        if not self.with_nonlinearity:
            return normalized
        return self.relu(normalized)

class Bridge(nn.Module):
    """Two stacked ConvBlocks applied between the encoder and the decoder."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first = ConvBlock(in_channels, out_channels)
        second = ConvBlock(out_channels, out_channels)
        self.bridge = nn.Sequential(first, second)

    def forward(self, x):
        """Run `x` through both ConvBlocks in order."""
        bridged = self.bridge(x)
        return bridged

class UpBlockForUNetWithResNet50(nn.Module):
    """Decoder step: upsample, concatenate the skip connection, two ConvBlocks.

    Args:
        in_channels: Channel count entering the first ConvBlock, i.e. the
            upsampled features plus the concatenated skip features.
        out_channels: Channel count produced by the block.
        up_conv_in_channels: Channel count of ``up_x`` entering the upsampling
            layer; defaults to ``in_channels``.
        up_conv_out_channels: Channel count leaving the upsampling layer;
            defaults to ``out_channels``.
        upsampling_method: ``"conv_transpose"`` (stride-2 transposed
            convolution) or ``"bilinear"`` (x2 bilinear upsampling followed
            by a 1x1 convolution).

    Raises:
        ValueError: If ``upsampling_method`` is not one of the two supported
            names.
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            # The 1x1 convolution receives `up_x`, so it must be sized with
            # the up-conv channel counts, exactly like the transposed
            # convolution above. Using in_channels/out_channels only worked
            # when the up-conv counts were left at their defaults.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1)
            )
        else:
            # Previously an unknown name left `self.upsample` undefined and
            # only failed later, inside forward().
            raise ValueError(
                "upsampling_method must be 'conv_transpose' or 'bilinear', got {!r}".format(upsampling_method)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        # Skip connection: concatenate along the channel dimension.
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x

class UNetWithResnet50Encoder(nn.Module):
    """U-Net whose encoder is built from torchvision's pretrained ResNet-50.

    The first three child modules of the ResNet form ``input_block``, the
    fourth is ``input_pool`` and every ``nn.Sequential`` child becomes one
    encoder stage. Five up blocks decode back towards the input resolution,
    each consuming one stored skip tensor, and a 1x1 convolution maps the
    final 64 channels to ``n_classes`` output channels.
    """

    DEPTH = 6

    def __init__(self, n_classes=4):
        super().__init__()
        backbone = torchvision.models.resnet.resnet50(pretrained=True)
        stages = list(backbone.children())

        self.input_block = nn.Sequential(*stages[:3])
        self.input_pool = stages[3]
        self.down_blocks = nn.ModuleList(
            [stage for stage in stages if isinstance(stage, nn.Sequential)]
        )
        self.bridge = Bridge(2048, 2048)
        self.up_blocks = nn.ModuleList([
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            # The last two skips carry 64 and 3 channels, so the concatenated
            # width differs from the upsampling width and is given explicitly.
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256, up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128, up_conv_out_channels=64),
        ])

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        """Return the class map, plus the last decoder features on request.

        :param x: input batch; it is stored unchanged as the outermost skip
        :param with_output_feature_map: also return the tensor fed to ``out``
        :return: ``out(features)`` or ``(out(features), features)``
        """
        deepest = UNetWithResnet50Encoder.DEPTH - 1

        # Skip tensors keyed "layer_<level>"; level 0 is the raw input.
        skips = {"layer_0": x}
        x = self.input_block(x)
        skips["layer_1"] = x
        x = self.input_pool(x)

        for level, stage in enumerate(self.down_blocks, 2):
            x = stage(x)
            # The deepest stage feeds the bridge and is not kept as a skip.
            if level != deepest:
                skips[f"layer_{level}"] = x

        x = self.bridge(x)

        # Walk back up, consuming the skips from deepest to shallowest.
        for step, up_block in enumerate(self.up_blocks, 1):
            x = up_block(x, skips[f"layer_{deepest - step}"])

        output_feature_map = x
        x = self.out(x)
        del skips
        if with_output_feature_map:
            return x, output_feature_map
        return x

In [None]:
from torchsummary import summary

# Instantiate the U-Net (default n_classes=4), move it to `device` and print
# torchsummary's report for a 3x256x256 input.
model = UNetWithResnet50Encoder().to(device)
summary(model, (3, 256, 256))

# 다양한 loss 함수 정의와 학습 함수

In [None]:
from collections import defaultdict
import torch.nn.functional as F
import time
import pickle

# Directory passed as `ckpt_dir` to save()/load() for the model checkpoints.
checkpoint_path = os.path.join(base_dir, 'chk')
# Per-epoch loss histories; train_model() appends to both lists.
train_loss = []; val_loss = []

def dice_loss(pred, target, smooth = 1.):
    """Return the mean soft Dice loss, ``1 - dice``.

    Dimension 2 and the dimension after it are summed away, so both tensors
    need at least four dimensions. ``smooth`` is added to numerator and
    denominator, which keeps the ratio defined when both inputs are all zero.
    """
    def _sum_last_two(tensor):
        return tensor.sum(dim=2).sum(dim=2)

    pred = pred.contiguous()
    target = target.contiguous()

    overlap = _sum_last_two(pred * target)
    denominator = _sum_last_two(pred) + _sum_last_two(target) + smooth
    dice = (2. * overlap + smooth) / denominator

    return (1 - dice).mean()

def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Blend BCE-with-logits and Dice loss and accumulate them in `metrics`.

    `pred` is treated as logits: it goes into
    ``binary_cross_entropy_with_logits`` as-is and through a sigmoid before
    the Dice term. Each term, multiplied by the batch size
    (``target.size(0)``), is added to ``metrics['bce']``, ``metrics['dice']``
    and ``metrics['loss']``.

    Returns the blended loss ``bce * bce_weight + dice * (1 - bce_weight)``.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)
    dice = dice_loss(torch.sigmoid(pred), target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    batch_len = target.size(0)
    for key, term in (('bce', bce), ('dice', dice), ('loss', loss)):
        metrics[key] += term.data.cpu().numpy() * batch_len

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print every accumulated metric divided by `epoch_samples` on one line.

    Output format: ``<phase>: <name>: <value>, <name>: <value>, ...``
    """
    formatted = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, formatted))

def train_model(model, optimizer, scheduler, st_epoch, num_epochs=25):
    """Train and validate `model`, checkpointing on the best validation loss.

    Reads the module-level ``dataloaders``, ``device``, ``base_dir`` and
    ``checkpoint_path`` and appends the per-epoch losses to the module-level
    ``train_loss`` / ``val_loss`` lists.

    Args:
        model: Network whose output has more than two channels; channels 0
            and 1 together are treated as the foreground.
        optimizer: Optimizer stepped after every training batch.
        scheduler: LR scheduler stepped once after every training phase.
        st_epoch: Last completed epoch; training starts at ``st_epoch + 1``.
        num_epochs: Index of the final epoch (inclusive).

    Returns:
        The model returned by ``load()`` for ``checkpoint_path`` (presumably
        the best checkpoint; ``load`` is defined outside this file).
    """
    best_loss = 1e10

    # When resuming (st_epoch > 0), restore the best validation loss of the
    # earlier run so a worse epoch cannot overwrite the stored checkpoint.
    # The original assigned in the wrong direction
    # (`mydict['val_loss'][-1] = best_loss`), so nothing was ever restored,
    # and it crashed when the loss directory or file did not exist.
    # NOTE: pickle.load can execute code from the file - only load loss files
    # written by this notebook.
    loss_info_path = os.path.join(base_dir, 'loss/data_dict.pkl')
    if st_epoch > 0 and os.path.isfile(loss_info_path):
        with open(loss_info_path, 'rb') as f:
            mydict = pickle.load(f)
        previous_val_losses = mydict.get('val_loss') or []
        if len(previous_val_losses) > 0:
            best_loss = min(previous_val_losses)

    for epoch in range(st_epoch+1, num_epochs+1):
        print('Epoch {}/{}'.format(epoch, num_epochs)); print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for batch, data in enumerate(dataloaders[phase]):
                if batch % 100 == 0:
                    print('{}/{}'.format(batch, len(dataloaders[phase])))
                labels = data['label'].to(device)
                inputs = data['input'].to(device)

                # No autograd graph is needed (or wanted) during validation.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)

                    # calc_loss() expects logits. The foreground probability
                    # is the softmax mass p of channels 0 and 1, and its logit
                    # log(p / (1 - p)) equals
                    # logsumexp(out[:, :2]) - logsumexp(out[:, 2:]).
                    # The original passed p itself, so the loss applied a
                    # sigmoid to a probability and could never predict a
                    # value below 0.5.
                    foreground_logit = (torch.logsumexp(outputs[:, :2, :, :], dim=1)
                                        - torch.logsumexp(outputs[:, 2:, :, :], dim=1))
                    foreground_logit = foreground_logit.unsqueeze(1)

                    loss = calc_loss(foreground_logit, labels, metrics)

                if is_train:
                    optimizer.zero_grad() # zero the parameter gradients
                    loss.backward()
                    optimizer.step()

                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            if is_train:
                train_loss.append(epoch_loss)
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

            # save the model weights
            if phase == 'validation':
                val_loss.append(epoch_loss)
                if epoch_loss < best_loss:
                    print(f"saving best model to {checkpoint_path}")
                    best_loss = epoch_loss
                    save(ckpt_dir=checkpoint_path, model=model, optim=optimizer, epoch=epoch)

        time_elapsed = time.time() - since
        # `phase` is 'validation' here, so this reports the validation loss.
        print('{} loss: {:.4f}'.format(phase, epoch_loss))
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # Hand back whatever load() restores from the checkpoint directory.
    model, _restored_optim, _restored_epoch = load(ckpt_dir=checkpoint_path, model=model, optim=optimizer)
    return model

## 학습 및 불러오기

In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time

st_epoch = 0
num_class = 4

# Adam over the parameters that require gradients only.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer_ft = optim.Adam(trainable_parameters, lr=1e-4)

# Multiply the learning rate by 0.1 every 100 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)

In [None]:
# Train through epoch 100 and rebind `model` to the model train_model() returns.
model = train_model(model, optimizer_ft, exp_lr_scheduler, st_epoch, num_epochs=100)

In [None]:
# Restore model, optimizer and epoch counter from `checkpoint_path` through the
# project-local load() helper, then show the restored epoch.
model, optimizer_ft, st_epoch = load(ckpt_dir=checkpoint_path, model=model, optim=optimizer_ft)
print(st_epoch)

## Loss 저장

In [None]:
import pickle
# Persist the loss histories
loss_dict = {'train_loss': train_loss, 'val_loss': val_loss}

# Write to <base_dir>/loss/data_dict.pkl, the path that train_model() and the
# loss-visualisation cell read from. The original wrote 'data_dict.pkl' into
# the current working directory, so the readers never saw the new file.
loss_dir = os.path.join(base_dir, 'loss')
os.makedirs(loss_dir, exist_ok=True)
with open(os.path.join(loss_dir, 'data_dict.pkl'), 'wb') as f:
    pickle.dump(loss_dict, f)

## Loss 시각화

In [None]:
import pickle
# Load the saved loss histories from <base_dir>/loss/data_dict.pkl
loss_info_path = os.path.join(base_dir, 'loss/data_dict.pkl')
with open(loss_info_path, 'rb') as loss_file:
    mydict = pickle.load(loss_file)

# Rebind the module-level histories to the stored values.
train_loss = mydict['train_loss']
val_loss = mydict['val_loss']

In [None]:
# Plot the train / validation loss curves
plt.figure(figsize=(15, 6))
loss_axes = plt.subplot(1, 2, 1)
loss_axes.set_ylabel('loss')
loss_axes.set_xlabel('epoch')
loss_axes.plot(train_loss, 'b', label='train loss')
loss_axes.plot(val_loss, 'g', label='val loss')
loss_axes.legend(loc='best')
plt.show()

## 평가 함수

In [None]:
def changeTensor(tensor, labelmode=0, threshold=0.1):
    """Return a detached CPU copy of `tensor` with a leading size-1 dim removed.

    With ``labelmode == 1`` the copy is additionally binarised in two
    sequential steps: values above `threshold` become 1, then values at or
    below `threshold` become 0. The input tensor itself is never modified.
    """
    snapshot = tensor.detach().cpu().clone().squeeze(0)
    if labelmode != 1:
        return snapshot

    # The second mask is computed after the first fill, as in the original.
    snapshot.masked_fill_(snapshot > threshold, 1)
    snapshot.masked_fill_(snapshot <= threshold, 0)
    return snapshot
    
# Overlay a colour-mapped mask on an image
def do_overlay(img, mask):
    """Add a JET colour map of `mask` to `img` and rescale by the maximum.

    `mask` is multiplied by 255 and cast to uint8 before colour mapping, so
    it is presumably expected to lie in [0, 1] - TODO confirm with callers.
    """
    mask_u8 = np.uint8(255*mask)
    heatmap = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    blended = heatmap + np.float32(img)
    return blended / np.max(blended)

In [None]:
# Pixel Accuracy / Intersection over Union / Dice coefficient (= F1 score)
def get_PA_IOU_and_DICE(pred, true, TH):
    """Binarise `pred` and `true` at `TH` and compare them pixel by pixel.

    Args:
        pred: Prediction tensor; values above `TH` count as foreground.
        true: Ground-truth tensor with the same number of elements,
            binarised with the same threshold.
        TH: Threshold applied to both tensors.

    Returns:
        Tuple ``(pa, iou, dice)``: the fraction of pixels on which the two
        masks agree, intersection over union (``-1`` when both masks are
        empty) and the Dice coefficient smoothed with +1 in numerator and
        denominator.
    """
    # Flatten straight to boolean masks; the inputs are left untouched.
    pred_mask = pred.detach().cpu().numpy().reshape(-1) > TH
    true_mask = true.detach().cpu().numpy().reshape(-1) > TH

    # Number of pixels. The original used `len(pred)^2`, which is a bitwise
    # XOR with 2 (not a power) and inflated the pixel accuracy.
    size = pred_mask.size
    pa = (size - np.logical_xor(true_mask, pred_mask).sum()) / size

    intersection = np.logical_and(true_mask, pred_mask).sum()
    union = np.logical_or(true_mask, pred_mask).sum()

    if union == 0:
        iou = -1  # sentinel: both masks are empty, IoU is undefined
    else:
        iou = intersection / union

    dice = (2.*intersection + 1) / (pred_mask.sum() + true_mask.sum() + 1)

    return pa, iou, dice

## 평가

In [None]:
# Evaluate the segmentation on the test set for several thresholds
TH = [.1, .2, .3, .4, .5, .6, .7, .8, .9]
print(len(dataloaders['test']))

# Run the network once and reuse the predictions for every threshold; the
# original repeated the full forward pass for each of the nine thresholds.
# Predictions are kept on the CPU.
model.eval()
test_predictions = []
with torch.no_grad():
    for data in dataloaders['test']:
        logits = model(data['input'].to(device))
        # Foreground probability = softmax mass of channels 0 and 1, the same
        # quantity train_model() derives its loss from. The original
        # thresholded the raw logit sum, which is not on the [0, 1] scale the
        # thresholds assume.
        foreground = torch.softmax(logits, dim=1)[:, :2, :, :].sum(dim=1)
        test_predictions.append((foreground.cpu(), data['label']))

# The test split is ordered ADC, HGD, LGD, NOR and the test loader yields one
# sample per batch, so the first `lesion_count` entries are the lesion
# classes. IoU and Dice are averaged over those only (NOR is dropped). This
# replaces the hard-coded 202, which kept one sample more than the recorded
# split sizes (51 + 50 + 100) contain.
lesion_count = ADC_test_length + HGD_test_length + LGD_test_length

for th in TH:
    acc_PA = []
    acc_IOU = []
    acc_DICE = []

    for foreground, label in test_predictions:
        # Pixel accuracy, IoU and Dice coefficient (= F1 score)
        pa, iou, dice = get_PA_IOU_and_DICE(foreground, label, th)
        acc_PA.append(pa)
        acc_IOU.append(iou)
        acc_DICE.append(dice)

    del acc_IOU[lesion_count:]
    del acc_DICE[lesion_count:]
    print(np.round(np.mean(acc_PA),4))
    print(np.round(np.mean(acc_IOU),4))
    print(np.round(np.mean(acc_DICE), 4))
    print("-"*10)

# Activation map 확인

In [None]:
# Show the raw network output (channel 0) next to input and label for test
# samples 10-14. The unused threshold loop (which also overwrote the global
# `TH`) and the unused cv2.addWeighted blend of the four channels were removed.
with torch.no_grad():
    model.eval()

    for batch, data in enumerate(dataloaders['test'], 1):
        if batch < 10:
            continue
        if batch == 15:
            break
        label = data['label'].to(device)
        image = data['input'].to(device)  # renamed: `input` shadowed the builtin
        output = model(image)
        print(output.shape)
        # (1, C, H, W) -> (H, W, C) for plotting
        output = output.cpu().squeeze(0).permute(1,2,0).numpy()

        # Undo the normalisation (mean=[0.590, 0.351, 0.259],
        # std=[0.241, 0.199, 0.163]) to show the original colours. The
        # arithmetic creates a new array, so the model input is not modified
        # in place; clipping keeps the values inside imshow's [0, 1] range.
        og = image.cpu().squeeze(0).permute(1,2,0).numpy()
        og = og * np.array([0.241, 0.199, 0.163], dtype=np.float32) \
            + np.array([0.590, 0.351, 0.259], dtype=np.float32)
        og = np.clip(og, 0.0, 1.0)

        plt.figure(figsize=(8,8))
        # Input image
        input_image_example = plt.subplot(1,3,1)
        input_image_example.set_title('input Image')
        plt.imshow(og)

        label_image_example = plt.subplot(1,3,2)
        label_image_example.set_title('Label Image')
        plt.imshow(to_pil_image(changeTensor(label)))

        test_image_example = plt.subplot(1,3,3)
        test_image_example.set_title('Prediction Image')
        # Channel 0 only; use output[:, :, k] to inspect another channel.
        plt.imshow(output[:, :, 0], 'Reds_r')

        plt.show()
        # plt.clf() after close('all') only created a new empty figure.
        plt.close('all')

# 네트워크 파라미터 구성

In [None]:
from torchsummary import summary
# Print torchsummary's report for `model` with a 3x256x256 input.
summary(model, input_size=(3, 256, 256))

In [None]:
# NOTE(review): this cell looks like a leftover copy of the tail of the
# evaluation loop. Its live statements were indented under nothing, which is
# an IndentationError; they are dedented here so the cell can run. It relies
# on `acc_PA`, `acc_IOU` and `acc_DICE` from the evaluation cell - TODO
# confirm whether the cell is still needed at all.

# r = Image.open(os.path.join('./', 'Prediction_' + str(batch) + '.png'))
# plt.imshow(r)

# plt.imshow(to_pil_image(changeTensor(output, 1)))

# change_example = plt.subplot(1,4,4)
# change_example.set_title('Change Image')
# r = Image.open(os.path.join('./', 'Prediction_' + str(batch) + '.png'))
# plt.imshow(r)

plt.show()
plt.close('all')
plt.clf()

# # Save the image
# img = changeTensor(output,1).cpu().numpy()
# img = np.swapaxes(img, 0, 1)
# img = np.swapaxes(img, 1, 2)

# for x in img:
#     for y in x:
#         # white becomes red
#         if y[0] == 1 and y[1] == 1 and y[2] == 1:
#             y[0] = 1
#             y[1] = y[2] = 0

#         elif y[0] == 1 and y[1] == 0 and y[2] == 0:
#             y[0] = y[2] = 0
#             y[1] = 1

# plt.imsave('Prediction_' + str(batch) +'.png', img)

# Keep only the first 202 IoU / Dice entries before averaging.
del acc_IOU[202:]
del acc_DICE[202:]
print(np.round(np.mean(acc_PA),4))
print(np.round(np.mean(acc_IOU),4))
print(np.round(np.mean(acc_DICE), 4))
print("-"*10)