In [None]:
# !pip3 install monai
# !pip3 install -U Setuptools
# # !pip install torch
# # !pip install torchvision
# !pip install optuna
# !pip install git+https://github.com/qubvel/segmentation_models.pytorch
# !pip install adabelief-pytorch==0.2.0

## Train ECA-NFNet

In [None]:
import logging
import os
import sys
import tempfile
import glob
import time
import matplotlib.pyplot as plt
import numpy as np

import setuptools
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader 
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference, SimpleInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    RandZoomd,
    RandFlipd,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
    AsChannelFirstd,
    AsChannelLast,
    Resized,
    RandScaleCropd,
    RandRotated,
    SaveImage,
    RandSpatialCropd,
    Resize,
)
from monai.visualize import plot_2d_or_3d_image

## MONAI configuration

In [None]:
# Print MONAI and dependency versions for reproducibility, then route INFO-level
# logging records to stdout so they appear in the notebook output.
monai.config.print_config()
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

## Data folder
#### 設定資料夾路徑，其中含有 "Train_Images" 與 "Train_Annotations_png" 兩個子資料夾

In [None]:
# Root data folder; expected to contain "Train_Images" and "Train_Annotations_png".
data_path = "./SEG_Train_Datasets/"
# List its contents as a quick sanity check.
os.listdir(path=data_path)

## Load training data
#### 以字典檔的形式，將每一張圖象(image)對應到其遮罩(mask)，並將字典檔放入列表中，再從列表中切分 training data 與 validation data

In [None]:
# Number of (sorted) image/mask pairs held out for validation; the rest are used for training.
val_data = 237

train_images = sorted(glob.glob(os.path.join(data_path, "Train_Images", "*.jpg")))
train_segs = sorted(glob.glob(os.path.join(data_path, "Train_Annotations_png", "*.png")))

# Images and masks are paired purely by sorted position, so both folders must hold the
# same number of files; otherwise zip() silently truncates and mis-pairs every sample
# after the first missing file.
assert len(train_images) == len(train_segs), (
    f"{len(train_images)} images but {len(train_segs)} masks"
)
assert len(train_images) > val_data, "not enough samples for the requested validation split"

# One dict per sample, in the key format the MONAI dictionary transforms expect.
pairs = [{"img": img, "seg": seg} for img, seg in zip(train_images, train_segs)]

# training data: everything after the first `val_data` pairs
train_files = pairs[val_data:]
print(f" {len(train_images[val_data:])} train_images and {len(train_segs[val_data:])} train_segs")

# validation data: the first `val_data` pairs
val_files = pairs[:val_data]
print(f" {len(train_images[:val_data])} val_images and {len(train_segs[:val_data])} val_segs")

In [None]:
# Preview the first few image/mask path pairs.
train_files[0:5]

## Define transforms for image and segmentation

In [None]:
def _load_and_scale():
    """Return fresh instances of the preprocessing steps shared by train and val.

    Loads "img" and "seg", adds a channel axis to the mask, moves the image's
    channel axis first, and rescales the intensities of both keys.
    """
    return [
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["seg"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img", "seg"]),
    ]


# Training: shared preprocessing, a random flip, resizing of both image and mask
# to 800 x 800, then a random 90-degree rotation.
train_transforms = Compose(
    _load_and_scale()
    + [
        RandFlipd(keys=["img", "seg"], prob=0.3),
        Resized(keys=["img", "seg"], spatial_size=[800, 800]),
        RandRotate90d(keys=["img", "seg"], prob=0.3, spatial_axes=[0, 1]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Validation: only the image is resized; the mask keeps its original resolution.
val_transforms = Compose(
    _load_and_scale()
    + [
        Resized(keys=["img"], spatial_size=[800, 800]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

## Create DataLoader for train and validation data

In [None]:
# Choose batch_size to fit the available GPU memory.
batch_size = 24

# Training loader: reshuffled every epoch; pinned host memory only when CUDA is present.
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    collate_fn=list_data_collate,
)

# Validation loader: fixed order, same batch size.
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(dataset=val_ds, batch_size=batch_size, collate_fn=list_data_collate)

## Define metric and post-processing
####  設定 DiceMetric，與 output 使用的 transform，這部分採用老師所給的範例程式碼

In [None]:
# Mean Dice score over the batch, used to evaluate validation predictions.
dice_metric = DiceMetric(
    include_background=False,
    reduction="mean",
    get_not_nans=False,
)

# Turn raw network logits into a binary mask: sigmoid, then threshold at 0.5.
post_trans = Compose(
    [
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

## Set environment

In [None]:
# Expose GPUs 0-3 to CUDA (and therefore to DataParallel).
# Written in the documented comma-separated form, without spaces after the commas.
# NOTE(review): CUDA reads this variable when the CUDA runtime is first initialised;
# torch.cuda.is_available() has already been called while building the DataLoaders,
# so this assignment may have no effect here. To be safe, move it to the very top of
# the notebook, before any torch.cuda call.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


## Create visualize function

In [None]:
def visualize(**images):
    """Show every keyword-argument image side by side in one row.

    Each keyword becomes its subplot's title (underscores replaced by spaces,
    then title-cased). Axis ticks are hidden and images use the gray colormap.
    """
    count = len(images)
    plt.figure(figsize=(16, 16))
    for column, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, count, column)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace("_", " ").title())
        plt.imshow(picture, cmap="gray")
    plt.show()

## Build model
```
Encoder 採用 tu-eca_nfnet_l2

Decoder 採用 DeepLabV3Plus 
```

In [None]:
import segmentation_models_pytorch as smp

In [None]:
# Model name, used for the checkpoint file names and the prediction output folders
# (encoder "tu-eca_nfnet_l2" + decoder "DeepLabV3Plus").
encode_mod = "tu-eca_nfnet_l2_DeepLabV3Plus"

In [None]:
# Auxiliary classification head on top of the encoder. Because aux_params is given,
# the model returns a (mask, label) pair, which is why callers unpack `outputs, _`.
aux_params = {
    "pooling": "avg",    # one of 'avg', 'max'
    "dropout": 0.4,      # dropout ratio, default is None
    "activation": None,  # activation function, default is None
    "classes": 1,        # define number of output labels
}

# DeepLabV3+ decoder on the "tu-eca_nfnet_l2" encoder.
encodes = "tu-eca_nfnet_l2"
model_no = smp.DeepLabV3Plus(encoder_name=encodes, aux_params=aux_params)

In [None]:
# Wrap for multi-GPU data parallelism, then move the parameters to `device`.
model = torch.nn.DataParallel(model_no)
model = model.to(device)

In [None]:
from adabelief_pytorch import AdaBelief

# AdaBelief optimizer with decoupled weight decay and rectification disabled.
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-4,
    eps=1e-16,
    betas=(0.9, 0.99),
    weight_decouple=True,
    rectify=False,
    weight_decay=1e-4,
)
# Dice loss computed on sigmoid(logits) against the mask.
loss_function = monai.losses.DiceLoss(sigmoid=True)

## Training
#### 從老師的範例代碼中，對於訓練次數、儲存路徑等做部分微調

In [None]:
# Train for `total_epochs` epochs. After every `val_interval` epochs, evaluate on the
# validation split and keep the checkpoint with the lowest mean validation Dice loss.
total_epochs = 250
val_interval = 1
best_metric = float("inf")  # lowest mean validation Dice loss so far (lower is better)
best_metric_epoch = -1      # 1-based epoch at which best_metric was reached
epoch_loss_values = []      # mean training loss per epoch
metric_values = []          # mean validation Dice loss per validation round
dice_values = []            # mean validation Dice score per validation round
writer = SummaryWriter()    # TensorBoard logger

# Batches per epoch. len(DataLoader) rounds up, so the final partial batch is counted
# and the TensorBoard global step (epoch_len * epoch + step) never repeats across epochs.
epoch_len = len(train_loader)
# Dice loss used for validation reporting, built once instead of once per batch.
val_loss_function = monai.losses.DiceLoss(sigmoid=True)

try:
    for epoch in range(total_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{total_epochs}")

        model.train()
        epoch_loss = 0.0
        step = 0

        for batch_data in train_loader:
            step += 1
            print(f"Step {step}/{epoch_len}", end="\r")

            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs, _ = model(inputs)  # forward; the auxiliary head's output is unused
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        epoch_loss /= max(step, 1)  # guard against an empty loader
        epoch_loss_values.append(epoch_loss)
        local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        print("\n", f"{local_time} epoch {epoch + 1} training average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                steps = 0
                loss_val = 0.0

                # `val_batch` (not `val_data`) so the split-size global is not overwritten.
                for val_batch in val_loader:
                    val_images, val_labels = val_batch["img"].to(device), val_batch["seg"].to(device)

                    val_outputs, _ = model(val_images)  # forward
                    # Upsample logits from the 800 x 800 input to the label resolution.
                    # Resize treats the batch axis as its channel axis, so the leading -1
                    # keeps the real channel axis unchanged. Assumes every validation
                    # label is 1716 x 942 in MONAI's spatial order — TODO confirm.
                    val_outputs = Resize([-1, 1716, 942])(val_outputs)

                    val_loss = val_loss_function(val_outputs, val_labels)
                    loss_val += val_loss.item()

                    # Sigmoid + 0.5 threshold per sample -> binary masks for the metric.
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

                    if steps < 3:  # show the first few batches of each validation round
                        print("Loss (Validation) : ", val_loss)
                        visualize(
                            image=val_images[0].cpu().permute(1, 2, 0),
                            ground_truth_mask=val_labels[0].cpu().permute(1, 2, 0),
                            predicted_mask=val_outputs[0].cpu().permute(1, 2, 0),
                        )
                    steps += 1
                    dice_metric(y_pred=val_outputs, y=val_labels)

                val_loss_ave = loss_val / max(steps, 1)
                print("val_loss = ", val_loss_ave)

                # Aggregate the Dice score accumulated this round, then reset for the next.
                val_dice = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(val_loss_ave)
                dice_values.append(val_dice)
                writer.add_scalar("val_loss", val_loss_ave, epoch + 1)
                writer.add_scalar("val_mean_dice", val_dice, epoch + 1)

                # Checkpoint selection is by mean validation Dice loss (lower is better).
                if val_loss_ave < best_metric:
                    best_metric = val_loss_ave
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), f"{encode_mod}.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current val mean dice loss: {:.4f} current val mean dice: {:.4f} "
                    "best val mean dice loss: {:.4f} at epoch {}".format(
                        epoch + 1, val_loss_ave, val_dice, best_metric, best_metric_epoch
                    )
                )

                # Also save the latest weights every validation round.
                torch.save(model.state_dict(), f"last_epoch_{encode_mod}.pth")
finally:
    # Flush buffered TensorBoard events even if training is interrupted.
    writer.close()

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

## Check save model
#### 再跑一次所儲存的模型，確認Validation分數與儲存之最佳模型一致

In [None]:
# Reload the best checkpoint written during training. The weights were saved from the
# DataParallel-wrapped `model`, so their keys match this wrapped model. map_location
# places the loaded tensors on `device`, so the file also loads on a machine without
# the GPU it was saved from.
model.load_state_dict(torch.load(f"{encode_mod}.pth", map_location=device))

In [None]:
# Re-evaluate the reloaded checkpoint on the validation split. The running mean Dice
# is printed after each batch and the final mean Dice at the end.
model.eval()
dice_metric.reset()
with torch.no_grad():
    val_images = val_labels = val_outputs = None
    for val_data in val_loader:
        val_images = val_data["img"].to(device)
        val_labels = val_data["seg"].to(device)
        val_outputs, _ = model(val_images)  # forward; auxiliary head output discarded
        # Upsample logits to the label resolution, then sigmoid + threshold each sample.
        val_outputs = [
            post_trans(sample)
            for sample in decollate_batch(Resize([-1, 1716, 942])(val_outputs))
        ]
        dice_metric(y_pred=val_outputs, y=val_labels)
        print(dice_metric.aggregate())
    metric = dice_metric.aggregate().item()
    print("metric = ", metric)
    # Clear accumulated state so later evaluations start fresh.
    dice_metric.reset()

## Public dataset

In [None]:
# Collect the public test images (sorted) and wrap each path in the dict format
# the MONAI dictionary transforms expect.
tempdir = "STAS/demo/Public_Image"
test_images = sorted(glob.glob(os.path.join(tempdir, "*.jpg")))

print(f" {len(test_images)} test_images")

test_files = list(map(lambda path: {"img": path}, test_images))
test_files[:5]

In [None]:
# Inference preprocessing for images only: load, move channels first, rescale
# intensities, and resize to 800 x 800.
test_transforms = Compose(
    transforms=[
        LoadImaged(keys=["img"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=[800, 800]),
        EnsureTyped(keys=["img"]),
    ]
)
test_ds = monai.data.Dataset(test_files, test_transforms)
# batch_size=1 and no shuffling: one image per iteration, in file order.
test_loader = DataLoader(test_ds, batch_size=1, collate_fn=list_data_collate)

In [None]:
# Sorted list of the public test images; display the first file's stem as a preview.
alls = sorted(glob.glob(os.path.join(tempdir, "*.jpg")))
alls[0].rsplit("/", 1)[-1].partition(".")[0]

In [None]:
# Predict a mask for every public test image and save it as a binary PNG in
# ./Predict_<encode_mod>/outputspub.
model.eval()
with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        # Separate name so the `test_images` path list is not overwritten.
        test_batch = test_data["img"].to(device)

        test_outputs, test_out_label = model(test_batch)  # forward
        # Upsample logits to the original image size; assumes every test image is
        # 1716 x 942 in MONAI's spatial order, as the validation cells do.
        test_outputs = Resize([-1, 1716, 942])(test_outputs)

        for j, logits in enumerate(test_outputs):
            # Take the source file name from the loader's own dataset (same order, no
            # shuffling), so it cannot drift from a separately globbed list.
            src_path = test_loader.dataset.data[i * test_loader.batch_size + j]["img"]
            stem = src_path.split("/")[-1].split(".")[0]
            # Apply the same sigmoid + 0.5 threshold used in validation. SaveImage's
            # scale=255 clips values to [0, 1] before scaling, so saving raw logits
            # would produce grey, non-binary masks.
            mask = post_trans(logits.cpu())
            saverPD = SaveImage(
                output_dir=f"./Predict_{encode_mod}/outputspub",
                output_ext=".png",
                output_postfix=stem,
                scale=255,
                separate_folder=False,
            )
            saverPD(mask)
        

In [None]:
# PNG files currently in the public output folder; preview the first few.
alls = sorted(glob.glob(f"./Predict_{encode_mod}/outputspub/*.png"))
alls[0:5]

In [None]:
# Strip the "<prefix>_" that SaveImage puts in front of each output name
# ("<prefix>_<stem>.png" -> "<stem>.png"). The prefix is presumably SaveImage's
# numeric data index (e.g. "0"), so only all-digit prefixes are removed. This keeps
# the cell safe to re-run: a stem such as "Public_00000000" is not cut a second time.
for jss in alls:
    out_dir, fname = os.path.split(jss)
    prefix, sep, rest = fname.partition("_")
    if sep and prefix.isdigit():
        os.rename(jss, os.path.join(out_dir, rest))

## Private dataset

In [None]:
# Collect the private test images (sorted) and wrap each path in the dict format
# the MONAI dictionary transforms expect.
tempdir = "STAS/demo/Private_Image"
private_images = sorted(glob.glob(os.path.join(tempdir, "*.jpg")))

print(f" {len(private_images)} private_images")

private_files = [dict(img=path) for path in private_images]
private_files[:5]

In [None]:
# Same image-only inference preprocessing as for the public set, applied to the
# private images: load, channels first, rescale intensities, resize to 800 x 800.
_private_steps = [
    LoadImaged(keys=["img"]),
    AsChannelFirstd(keys=["img"]),
    ScaleIntensityd(keys=["img"]),
    Resized(keys=["img"], spatial_size=[800, 800]),
    EnsureTyped(keys=["img"]),
]
test_transforms = Compose(_private_steps)
test_ds = monai.data.Dataset(data=private_files, transform=test_transforms)
# One image per iteration, in file order.
test_loader = DataLoader(test_ds, batch_size=1, collate_fn=list_data_collate)

In [None]:
# Sorted list of the private test images; display the first file's stem as a preview.
alls = sorted(glob.glob(os.path.join(tempdir, "*.jpg")))
alls[0].rsplit("/", 1)[-1].partition(".")[0]

In [None]:
# Predict a mask for every private test image and save it as a binary PNG in
# ./Predict_<encode_mod>/outputspri.
model.eval()
with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        # Separate name so the `test_images` path list is not overwritten.
        test_batch = test_data["img"].to(device)

        test_outputs, test_out_label = model(test_batch)  # forward
        # Upsample logits to the original image size; assumes every test image is
        # 1716 x 942 in MONAI's spatial order, as the validation cells do.
        test_outputs = Resize([-1, 1716, 942])(test_outputs)

        for j, logits in enumerate(test_outputs):
            # Take the source file name from the loader's own dataset (same order, no
            # shuffling), so it cannot drift from a separately globbed list.
            src_path = test_loader.dataset.data[i * test_loader.batch_size + j]["img"]
            stem = src_path.split("/")[-1].split(".")[0]
            # Apply the same sigmoid + 0.5 threshold used in validation. SaveImage's
            # scale=255 clips values to [0, 1] before scaling, so saving raw logits
            # would produce grey, non-binary masks.
            mask = post_trans(logits.cpu())
            saverPD = SaveImage(
                output_dir=f"./Predict_{encode_mod}/outputspri",
                output_ext=".png",
                output_postfix=stem,
                scale=255,
                separate_folder=False,
            )
            saverPD(mask)
              

In [None]:
# PNG files currently in the private output folder; preview the first few.
alls = sorted(glob.glob(f"./Predict_{encode_mod}/outputspri/*.png"))
alls[0:5]

In [None]:
# Strip the "<prefix>_" that SaveImage puts in front of each output name
# ("<prefix>_<stem>.png" -> "<stem>.png"). The prefix is presumably SaveImage's
# numeric data index (e.g. "0"), so only all-digit prefixes are removed. This keeps
# the cell safe to re-run: a stem such as "Private_00000000" is not cut a second time.
for jss in alls:
    out_dir, fname = os.path.split(jss)
    prefix, sep, rest = fname.partition("_")
    if sep and prefix.isdigit():
        os.rename(jss, os.path.join(out_dir, rest))

## 移動 (move the private predictions into the public folder, then zip)

In [None]:
# Move the private-set predictions into the public output folder so both sets end up
# in one directory.
# NOTE(review): mv overwrites same-named files by default, so public and private
# file stems must not collide.
!mv Predict_{encode_mod}/outputspri/* Predict_{encode_mod}/outputspub

In [None]:
!zip -r Predict_tu-eca_nfnet_l2_DeepLabV3Plus.zip Predict_tu-eca_nfnet_l2_DeepLabV3Plus