In [None]:
# 標準モジュール(install不要)
import os
import math
import random
from datetime import datetime

# myself
from config import setting
from module import const
from module import image_loader  # get_train_transform, LoadDataSet

import torch

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torch import nn
from tqdm import tqdm as tqdm

import torch.nn.functional as F

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Directory for UNet checkpoints, stored as an attribute on the project's `const` module.
const.CHECKPOINT_PATH_UNet = setting.const.CHECKPOINT_PATH + "/UNet"

In [49]:
# Image height and width handed to the dataset loader and its transform.
const.IMG_HEIGHT = 256
const.IMG_WIDTH = 256

# Number of training epochs. The longer setting is kept commented out for reference.
# const.NUM_EPOCHS = 300
const.NUM_EPOCHS = 10

In [None]:
# Build the dataset and inspect one sample.
train_dataset = image_loader.LoadDataSet(
    setting.const.TRAIN_PATH,
    const.IMG_HEIGHT,
    const.IMG_WIDTH,
    transform=image_loader.get_train_transform(const.IMG_HEIGHT, const.IMG_WIDTH),
)

# Indexing goes through the dataset's __getitem__ magic method; it is equivalent to:
# image, mask = train_dataset.__getitem__(0)
image, mask = train_dataset[0]
print(image.shape)
print(mask.shape)

# Print total number of unique images.
# len() goes through __len__. The bare call on the next line is equivalent, but its
# return value is discarded here (only the print below produces output).
train_dataset.__len__()
print(len(train_dataset))

In [None]:
def format_image(img):
    """Convert a normalized channel-first image into a displayable uint8 image.

    Args:
        img: Array-like with three axes, channel axis first. The channel axis
            must have length 3 so it broadcasts against the per-channel
            mean/std below.

    Returns:
        numpy.ndarray of dtype uint8 with the channel axis moved last.
    """
    # Move the channel axis to the end so the per-channel constants broadcast.
    img = np.transpose(np.asarray(img), (1, 2, 0))
    # Undo the mean/std normalization applied by the image augmentation.
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    img = (std * img + mean) * 255
    # Clamp before the integer cast: values slightly outside [0, 255] would
    # otherwise wrap around and show up as speckled pixels.
    img = np.clip(img, 0, 255)
    return img.astype(np.uint8)


def format_mask(mask):
    """Turn a channel-first mask into a 2-D plane oriented for display.

    The channel axis is moved last and singleton axes are dropped; the plane
    is then rotated with ``np.rot90(k=3)`` and mirrored left-right. For a
    single-channel mask the net effect is the transpose of the (H, W) plane.

    Args:
        mask: Array-like with three axes, channel axis first.

    Returns:
        numpy.ndarray holding the re-oriented mask.
    """
    channels_last = np.transpose(mask, (1, 2, 0))
    plane = np.squeeze(channels_last)
    # TODO: for some reason the mask comes out of LoadDataSet flipped
    rotated = np.rot90(plane, 3)
    return np.fliplr(rotated)


def visualize_dataset(n_images, num_range, predict=None):
    """Plot randomly chosen (image, mask) pairs from the global ``train_dataset``.

    Args:
        n_images: Number of samples to draw; one figure row per sample.
        num_range: Exclusive upper bound of the dataset indices to sample from.
        predict: Unused; kept so existing calls keep working.

    Raises:
        ValueError: From ``random.sample`` when ``n_images`` exceeds ``num_range``.
    """
    # TODO: horizontal/vertical flips are applied at display time; make sure
    # they are not (same for the Predicted view).
    images = random.sample(range(0, num_range), n_images)
    # squeeze=False keeps `ax` two-dimensional even when only one row is requested.
    figure, ax = plt.subplots(
        nrows=len(images), ncols=2, figsize=(5, 8), squeeze=False
    )
    print(images)
    for row, img_no in enumerate(images):
        # Fetch the sampled index itself, so the picture matches the number
        # shown in its title.
        image, mask = train_dataset[img_no]
        image = format_image(image)
        mask = format_mask(mask)
        ax[row, 0].imshow(image)
        ax[row, 1].imshow(mask, interpolation="nearest", cmap="gray")
        ax[row, 0].set_title(f"Input Image No.{img_no+1}")
        ax[row, 1].set_title(f"Label Mask No.{img_no+1}")
        ax[row, 0].set_axis_off()
        ax[row, 1].set_axis_off()
    plt.tight_layout()
    plt.show()


# Show three samples; `num_range` is the dataset size used as the sampling bound.
num_range = len(train_dataset)
visualize_dataset(3, num_range)

In [None]:
# Fraction of the dataset held out for validation.
split_ratio = 0.25
n_samples = len(train_dataset)
valid_size = int(np.round(n_samples * split_ratio, 0))
# Take the training size as the remainder so the two sizes always add up to
# the dataset length, which random_split requires for integer lengths.
train_size = n_samples - valid_size
# NOTE(review): both subsets wrap the same `train_dataset`, so the validation
# samples presumably go through the training transform as well; confirm that
# this is intended.
train_data, valid_data = random_split(train_dataset, [train_size, valid_size])
train_loader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(dataset=valid_data, batch_size=10)

print("Length of train data: {}".format(len(train_data)))
print("Length of validation data: {}".format(len(valid_data)))

In [42]:
# UNet
class UNet(nn.Module):
    """U-Net style encoder-decoder that ends in a sigmoid.

    Five double-convolution stages are separated by four 2x max-poolings on
    the way down; four stride-2 transposed convolutions bring the resolution
    back up, each followed by concatenation with the matching encoder feature
    map and another double convolution. A final 1x1 convolution maps to
    ``output_channels`` and a sigmoid is applied.

    Because of the four poolings and the channel-wise concatenations, the
    input height and width need to be divisible by 16 for the skip tensors
    to line up.
    """

    def __init__(self, input_channels, output_channels):
        """Create the layers and initialize the Conv2d weights.

        Args:
            input_channels: Number of channels of the input tensor.
            output_channels: Number of channels produced by the final 1x1 convolution.
        """
        super().__init__()
        # Contracting path (the part called "FCN" in the reference material)
        self.conv1 = conv_bn_relu(input_channels, 64)
        self.conv2 = conv_bn_relu(64, 128)
        self.conv3 = conv_bn_relu(128, 256)
        self.conv4 = conv_bn_relu(256, 512)
        self.conv5 = conv_bn_relu(512, 1024)
        self.down_pooling = nn.MaxPool2d(2)

        # Expanding path (the part called "Up Sampling" in the reference material).
        # Each conv after an up-pooling takes twice the up-pooled channel count
        # because the encoder feature map is concatenated onto it in forward().
        self.up_pool6 = up_pooling(1024, 512)
        self.conv6 = conv_bn_relu(1024, 512)
        self.up_pool7 = up_pooling(512, 256)
        self.conv7 = conv_bn_relu(512, 256)
        self.up_pool8 = up_pooling(256, 128)
        self.conv8 = conv_bn_relu(256, 128)
        self.up_pool9 = up_pooling(128, 64)
        self.conv9 = conv_bn_relu(128, 64)
        self.conv10 = nn.Conv2d(64, output_channels, 1)

        # Kaiming-normal weights and zero biases for every nn.Conv2d instance.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out")
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch, channel axis at dim 1 (the skip connections
                concatenate along dim=1).

        Returns:
            Tensor of sigmoid outputs with ``output_channels`` channels.
        """
        # Scaling: the input is divided by 255.
        # NOTE(review): if the dataset transform already normalizes the images
        # (format_image in this notebook undoes a mean/std normalization), this
        # division is applied on top of it; confirm that this is intended.
        x = x / 255.0

        # Contracting path (the part called "FCN" in the reference material)
        x1 = self.conv1(x)
        p1 = self.down_pooling(x1)
        x2 = self.conv2(p1)
        p2 = self.down_pooling(x2)
        x3 = self.conv3(p2)
        p3 = self.down_pooling(x3)
        x4 = self.conv4(p3)
        p4 = self.down_pooling(x4)
        x5 = self.conv5(p4)

        # Expanding path ("Up Sampling" in the reference material); torch.cat
        # implements the skip connections.
        p6 = self.up_pool6(x5)
        x6 = torch.cat([p6, x4], dim=1)
        x6 = self.conv6(x6)

        p7 = self.up_pool7(x6)
        x7 = torch.cat([p7, x3], dim=1)
        x7 = self.conv7(x7)

        p8 = self.up_pool8(x7)
        x8 = torch.cat([p8, x2], dim=1)
        x8 = self.conv8(x8)

        p9 = self.up_pool9(x8)
        x9 = torch.cat([p9, x1], dim=1)
        x9 = self.conv9(x9)

        output = self.conv10(x9)
        # The model itself applies the sigmoid, so callers receive values in (0, 1).
        output = torch.sigmoid(output)

        return output


def conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build two stacked Conv2d -> BatchNorm2d -> ReLU stages as one Sequential.

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size shared by both convolutions.
        stride: Stride shared by both convolutions.
        padding: Padding shared by both convolutions.

    Returns:
        nn.Sequential holding the six layers in order.
    """
    conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
    stages = []
    # The first stage changes the channel count; the second keeps it.
    for stage_in in (in_channels, out_channels):
        stages.append(nn.Conv2d(stage_in, out_channels, **conv_options))
        stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)


def down_pooling():
    """Return a max-pooling layer with kernel size 2."""
    pool = nn.MaxPool2d(kernel_size=2)
    return pool


def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
    """Build a ConvTranspose2d -> BatchNorm2d -> ReLU up-sampling stage.

    Args:
        in_channels: Channels entering the transposed convolution.
        out_channels: Channels it produces.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.

    Returns:
        nn.Sequential holding the three layers in order.
    """
    # Transposed convolution performs the up-sampling.
    upsample = nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size=kernel_size, stride=stride
    )
    normalize = nn.BatchNorm2d(out_channels)
    activate = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, normalize, activate)

In [43]:
class DiceBCELoss(nn.Module):
    """Sum of the soft Dice loss and the mean binary cross-entropy.

    ``forward`` applies a sigmoid to ``inputs`` itself, so it expects raw
    logits. The ``UNet`` defined in this notebook already ends in a sigmoid;
    feeding its output here would apply the sigmoid twice.
    """

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility
        # only; neither is used.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute the combined loss.

        Args:
            inputs: Raw (pre-sigmoid) predictions.
            targets: Ground-truth values with the same number of elements.
            smooth: Constant added to the Dice numerator and denominator.

        Returns:
            Scalar tensor ``BCE + (1 - Dice)``.
        """
        # Remove this line if the model already ends in a sigmoid or an
        # equivalent activation.
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors. reshape() is used instead of
        # view() so non-contiguous tensors (e.g. transposed ones) work too.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (
            inputs.sum() + targets.sum() + smooth
        )
        bce = F.binary_cross_entropy(inputs, targets, reduction="mean")

        return bce + dice_loss


class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - Dice``, computed over all elements at once.

    No activation is applied here; ``inputs`` are used as given.
    """

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility
        # only; neither is used.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute the Dice loss.

        Args:
            inputs: Predictions, used as-is (no sigmoid is applied here).
            targets: Ground-truth values with the same number of elements.
            smooth: Constant added to numerator and denominator; it keeps the
                ratio defined when both tensors sum to zero.

        Returns:
            Scalar tensor ``1 - (2*|X.Y| + smooth) / (|X| + |Y| + smooth)``.
        """
        # Flatten label and prediction tensors. reshape() is used instead of
        # view() so non-contiguous tensors (e.g. transposed ones) work too.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [44]:
class IoU(nn.Module):
    """Soft intersection-over-union score computed over all elements at once."""

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility
        # only; neither is used.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute the IoU score.

        Args:
            inputs: Predictions, used as-is.
            targets: Ground-truth values with the same number of elements.
            smooth: Constant added to numerator and denominator; it keeps the
                ratio defined when the union is zero.

        Returns:
            Scalar tensor ``(intersection + smooth) / (union + smooth)``.
        """
        # reshape() is used instead of view() so non-contiguous tensors work too.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        # The intersection is counted once in each operand, so subtract it once.
        union = total - intersection

        iou_score = (intersection + smooth) / (union + smooth)

        return iou_score

In [None]:
# <--------------- Create model, optimizer, loss and metric --------------->
# Place the model on the notebook-wide `device` so the cell also runs on
# machines without CUDA.
model = UNet(3, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = DiceLoss()
accuracy_metric = IoU()

# float("inf") instead of np.Inf: the latter alias was removed in NumPy 2.0.
valid_loss_min = float("inf")

best_model_file = "/best_model.pth"

# torch.save does not create missing directories.
os.makedirs(const.CHECKPOINT_PATH_UNet, exist_ok=True)

# Per-epoch means, one entry per epoch.
total_train_loss = []
total_train_score = []
total_valid_loss = []
total_valid_score = []

losses_value = 0
for epoch in range(const.NUM_EPOCHS):
    # <--------------- Training --------------->
    train_loss = []
    train_score = []
    valid_loss = []
    valid_score = []
    # BatchNorm layers must use batch statistics while training.
    model.train()
    pbar = tqdm(train_loader, desc="description")
    for x_train, y_train in pbar:
        x_train = x_train.to(device)
        y_train = y_train.to(device)
        optimizer.zero_grad()
        output = model(x_train)
        # Loss
        loss = criterion(output, y_train)
        losses_value = loss.item()
        # Metric (detached: it is only reported, never back-propagated)
        score = accuracy_metric(output.detach(), y_train)
        loss.backward()
        optimizer.step()
        train_loss.append(losses_value)
        train_score.append(score.item())
        pbar.set_description(
            f"Epoch: {epoch+1}, loss: {losses_value}, IoU: {score.item()}"
        )
    # <--------------- Validation --------------->
    # eval() makes BatchNorm use its running statistics and stops the
    # validation batches from updating them.
    model.eval()
    with torch.no_grad():
        for image, mask in val_loader:
            image = image.to(device)
            mask = mask.to(device)
            output = model(image)
            # Loss
            loss = criterion(output, mask)
            losses_value = loss.item()
            # Metric
            score = accuracy_metric(output, mask)
            valid_loss.append(losses_value)
            valid_score.append(score.item())

    total_train_loss.append(np.mean(train_loss))
    total_train_score.append(np.mean(train_score))
    total_valid_loss.append(np.mean(valid_loss))
    total_valid_score.append(np.mean(valid_score))
    print(f"Train Loss: {total_train_loss[-1]}, Train IOU: {total_train_score[-1]}")
    print(f"Valid Loss: {total_valid_loss[-1]}, Valid IOU: {total_valid_score[-1]}")

    checkpoint = {
        "epoch": epoch + 1,
        # Stored as a plain float rather than a NumPy scalar so the file
        # contains only tensors and built-in Python types.
        "valid_loss_min": float(total_valid_loss[-1]),
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }

    ## Per-epoch checkpoints take disk space; enable only when needed.
    # checkpoint_file = "/checkpoint_{}_weight.pth".format(epoch+1)
    # torch.save(checkpoint, const.CHECKPOINT_PATH_UNet + checkpoint_file)

    # Save the checkpoint whenever the validation loss is the best so far.
    if total_valid_loss[-1] <= valid_loss_min:
        print(
            "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
                valid_loss_min, total_valid_loss[-1]
            )
        )
        torch.save(checkpoint, const.CHECKPOINT_PATH_UNet + best_model_file)
        valid_loss_min = total_valid_loss[-1]

    print("")

In [None]:
# # 数値の配列を文字列にして返す関数
# def convert_list_el_to_str_from_num(arr_num, num_decimals = 0):
#     arr_str = []
#     for num in arr_num:
#         if num_decimals == 0:
#             stg = str(num)
#         else:
#             stg = str(math.floor(num * 10 ** num_decimals) / (10 ** num_decimals))
#             while True:
#                 if not(stg[-1] in ["0", "."]):
#                     break
#                 elif stg == "0.0":
#                     stg = "0"
#                 else:
#                     stg = stg[:-1]

#         arr_str.append(stg)
#     return arr_str


def convert_list_el_to_str_from_num(arr_num, num_decimals=0):
    """Convert numbers to compact strings, e.g. for use as tick labels.

    Args:
        arr_num: Iterable of numbers.
        num_decimals: When 0, each number is passed through ``str`` unchanged.
            Otherwise each number is floored to this many decimal places and
            trailing zeros after the decimal point are dropped.

    Returns:
        List of strings, one per input number, in input order.
    """
    arr_str = []
    scale = 10**num_decimals
    for num in arr_num:
        if num_decimals == 0:
            stg = str(num)
        else:
            stg = str(math.floor(num * scale) / scale)
            # Only strip zeros that follow a decimal point; stripping blindly
            # would turn "10.0" into "1" and mangle exponent forms like "1e+20".
            if "." in stg:
                stg = stg.rstrip("0").rstrip(".")
        arr_str.append(stg)
    return arr_str


# Global Matplotlib style; set before the figure is created so that every
# artist picks it up.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.major.width"] = 1.0
plt.rcParams["ytick.major.width"] = 1.0
plt.rcParams["font.size"] = 18
plt.rcParams["axes.linewidth"] = 1.0
# sns.set_style(style="darkgrid")

epochs = list(range(1, const.NUM_EPOCHS + 1))

# Pick the epoch tick spacing from the run length so the labels stay readable.
epoch_labels = None
if const.NUM_EPOCHS < 20:
    epoch_ticks = np.arange(1, const.NUM_EPOCHS + 1, step=1)
    epoch_labels = convert_list_el_to_str_from_num(epoch_ticks)
    epoch_xlim = (1, const.NUM_EPOCHS)
elif const.NUM_EPOCHS < 60:
    epoch_ticks = np.arange(1, const.NUM_EPOCHS + 1, step=5)
    epoch_xlim = (0, const.NUM_EPOCHS)
else:
    epoch_ticks = np.arange(1, const.NUM_EPOCHS + 1, step=10)
    epoch_xlim = (0, const.NUM_EPOCHS)

# One figure with two side-by-side axes. Creating the axes explicitly avoids
# the stray empty figure and lets the epoch ticks be applied to both panels.
fig, (ax_loss, ax_score) = plt.subplots(1, 2, figsize=(15, 5))

# ---- Loss panel ----
sns.lineplot(x=epochs, y=total_train_loss, label="Train Loss", ax=ax_loss)
sns.lineplot(x=epochs, y=total_valid_loss, label="Valid Loss", ax=ax_loss)
# Tick labels are written out by hand.
ax_loss.set_yticks(np.arange(0, 0.25 + 0.01, step=0.05))
ax_loss.set_yticklabels(["0", "0.05", "0.1", "0.15", "0.2", "0.25"])
ax_loss.set_ylim(0, 0.25)
ax_loss.set_title("Loss")
ax_loss.set_ylabel("DiceLoss")

# ---- Score panel ----
sns.lineplot(x=epochs, y=total_train_score, label="Train Score", ax=ax_score)
sns.lineplot(x=epochs, y=total_valid_score, label="Valid Score", ax=ax_score)
ax_score.set_yticks(np.arange(0.65, 1 + 0.01, step=0.05))
ax_score.set_yticklabels(["0.65", "0.7", "0.75", "0.8", "0.85", "0.9", "0.95", "1"])
ax_score.set_ylim(0.65, 1)
ax_score.set_title("Score (IoU)")
ax_score.set_ylabel("IoU")

# ---- Shared epoch axis ----
for panel in (ax_loss, ax_score):
    panel.set_xticks(epoch_ticks)
    if epoch_labels is not None:
        panel.set_xticklabels(epoch_labels)
    panel.set_xlim(*epoch_xlim)
    panel.minorticks_on()
    panel.set_xlabel("Epochs")

plt.show()

In [52]:
# Load the best-model checkpoint.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, a checkpoint containing non-tensor objects such as NumPy
# scalars may be rejected; confirm against the installed version.
checkpoint = torch.load(
    const.CHECKPOINT_PATH_UNet + best_model_file, map_location=device
)
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"]
valid_loss_min = checkpoint["valid_loss_min"]

In [None]:
def visualize_predict(model, n_images, num_range):
    """Plot input, label mask and predicted mask for random validation samples.

    Samples are drawn from the dataset behind the global ``val_loader`` and
    the numbers in the titles are 1-based positions in that dataset, so each
    title identifies the sample actually shown.

    Args:
        model: Network to run; it is switched to eval mode for the prediction
            and restored to its previous mode afterwards.
        n_images: Number of samples requested; capped at the number available.
        num_range: Exclusive upper bound on the sampled positions; capped at
            the validation dataset size.
    """
    dataset = val_loader.dataset
    population = min(num_range, len(dataset))
    n_images = min(n_images, population)
    if n_images <= 0:
        print("No validation samples to visualize.")
        return

    images = random.sample(range(0, population), n_images)
    print(images)
    # squeeze=False keeps `ax` two-dimensional even for a single row.
    figure, ax = plt.subplots(
        nrows=n_images, ncols=3, figsize=(15, 18), squeeze=False
    )

    # Fetch exactly the sampled items and stack them into one batch.
    samples = [dataset[idx] for idx in images]
    data = torch.stack([torch.as_tensor(image) for image, _ in samples])
    mask = torch.stack([torch.as_tensor(label) for _, label in samples])

    # Run on whatever device the model lives on instead of assuming CUDA.
    target_device = next(model.parameters()).device
    was_training = model.training
    # eval() makes BatchNorm use its running statistics, so the prediction
    # does not depend on which samples share the batch.
    model.eval()
    try:
        with torch.no_grad():
            o = model(data.to(target_device))
    finally:
        model.train(was_training)

    for i, img_no in enumerate(images):
        tm = o[i][0].cpu().numpy()
        img = format_image(data[i].cpu().numpy())
        msk = format_mask(mask[i].cpu().numpy())
        ax[i, 0].imshow(img)
        ax[i, 1].imshow(msk, interpolation="nearest", cmap="gray")
        # NOTE(review): the label mask goes through format_mask but the
        # prediction is shown as produced; confirm the two orientations match.
        ax[i, 2].imshow(tm, interpolation="nearest", cmap="gray")
        ax[i, 0].set_title(f"Input Image No.{img_no+1}")
        ax[i, 1].set_title(f"Label Mask No.{img_no+1}")
        ax[i, 2].set_title(f"Predicted Mask No.{img_no+1}")
        ax[i, 0].set_axis_off()
        ax[i, 1].set_axis_off()
        ax[i, 2].set_axis_off()
    plt.tight_layout()
    plt.show()


# Request six prediction rows; `num_range` bounds the sampled indices.
visualize_predict(model, 6, num_range)

In [None]:
def visualize_full_predict(model):
    """Predict every image under TRAIN_PATH, then plot and save the results.

    The dataset is rebuilt with both flip probabilities set to 0.0. For each
    sample the input image, label mask and predicted mask are shown in one
    figure row and written as PNG files into a timestamped directory under
    ``{const.APP_PATH}/tmp/{const.TRAIN_DIR}``.

    Args:
        model: Network to run; it is switched to eval mode for the prediction
            and restored to its previous mode afterwards.
    """
    full_dataset = image_loader.LoadDataSet(
        setting.const.TRAIN_PATH,
        const.IMG_HEIGHT,
        const.IMG_WIDTH,
        transform=image_loader.get_train_transform(
            const.IMG_HEIGHT, const.IMG_WIDTH, horizontal_flip=0.0, vertical_flip=0.0
        ),
    )

    n_images = len(full_dataset)
    print(n_images)
    if n_images == 0:
        return

    # Predict in small batches; pushing the whole dataset through in a single
    # batch can exhaust device memory.
    batch_size = 10
    all_loader = DataLoader(dataset=full_dataset, batch_size=batch_size)

    # Run on whatever device the model lives on instead of assuming CUDA.
    target_device = next(model.parameters()).device
    was_training = model.training
    # eval() makes BatchNorm use its running statistics, so the outputs do not
    # depend on how the samples are batched.
    model.eval()
    results = []
    try:
        with torch.no_grad():
            for data, mask in all_loader:
                o = model(data.to(target_device))
                for img, msk, pred in zip(data, mask, o):
                    results.append(
                        (img.cpu().numpy(), msk.cpu().numpy(), pred[0].cpu().numpy())
                    )
    finally:
        model.train(was_training)

    # squeeze=False keeps `ax` two-dimensional even for a single row.
    figure, ax = plt.subplots(
        nrows=len(results), ncols=3, figsize=(15, 180), squeeze=False
    )

    formatted_time = datetime.now().strftime("%Y%m%d%H%M")
    output_directory = f"{const.APP_PATH}/tmp/{const.TRAIN_DIR}/{formatted_time}"
    os.makedirs(output_directory, exist_ok=True)

    for i, (img, msk, tm) in enumerate(results):
        img_no = i
        img = format_image(img)
        msk = format_mask(msk)
        ax[i, 0].imshow(img)
        ax[i, 1].imshow(msk, interpolation="nearest", cmap="gray")
        ax[i, 2].imshow(tm, interpolation="nearest", cmap="gray")
        ax[i, 0].set_title(f"Input Image No.{img_no+1}")
        ax[i, 1].set_title(f"Label Mask No.{img_no+1}")
        ax[i, 2].set_title(f"Predicted Mask No.{img_no+1}")
        ax[i, 0].set_axis_off()
        ax[i, 1].set_axis_off()
        ax[i, 2].set_axis_off()
        # NOTE(review): the masks are saved without cmap="gray", unlike the
        # on-screen plots; confirm whether the saved files should be grayscale.
        plt.imsave(f"{output_directory}/Input_Image_No_{img_no+1}.png", img)
        plt.imsave(f"{output_directory}/Label_Mask_No_{img_no+1}.png", msk)
        plt.imsave(f"{output_directory}/Predicted_Mask_No_{img_no+1}.png", tm)
    plt.tight_layout()
    plt.show()


# Run the full-dataset visualization with the current `model`.
visualize_full_predict(model)