In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the attached Kaggle input datasets.
for walk_dir, _walk_subdirs, walk_files in os.walk('/kaggle/input'):
    for walk_name in walk_files:
        full_path = os.path.join(walk_dir, walk_name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Config
# Seed handed to random.seed and torch.manual_seed in the imports cell.
seed: int = 42
# Fraction of subjects used for training; the remaining subjects form the validation set.
training_split_ratio: float = 0.9
# Number of epochs passed to train() in the training cell.
num_epochs: int = 5

# Feature flags. train_whole_images is read by the training cell to choose between
# training and loading saved weights.
# NOTE(review): compute_histograms and train_patches are not read anywhere in this notebook.
compute_histograms: bool = True
train_whole_images: bool = True
train_patches: bool = False

In [None]:
# Pinned third-party dependencies. %pip (rather than !pip) installs into the
# environment of the running kernel, so the packages are importable in later cells.
%pip install --quiet --upgrade pip
%pip install --quiet highresnet==0.10.2
%pip install --quiet unet==0.7.3
%pip install --quiet torchio==0.18.11
!apt-get -qq install tree

In [None]:
# NOTE(review): torch is presumably pinned to 1.7.0 because prepare_batch uses
# torch.movedim, which first appeared in PyTorch 1.7 — confirm, and check that the
# preinstalled torchvision (imported in the next cell) is compatible with this version.
%pip install torch==1.7.0

In [None]:
import copy
import enum
import random
import warnings
import tempfile
import subprocess
import multiprocessing
from pathlib import Path

import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image

import torchio as tio
from torchio import AFFINE, DATA

import numpy as np
import nibabel as nib
from unet import UNet
from scipy import stats
import SimpleITK as sitk
import matplotlib.pyplot as plt

from IPython import display
from tqdm.notebook import tqdm

# Seed every global RNG the notebook may draw from (`seed` is set in the config cell).
# torch.manual_seed also seeds all CUDA devices.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

print('TorchIO version:', tio.__version__)

In [None]:
path_dataset = "../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
root_dir = Path(path_dataset)
# The globs expect files laid out as <root>/<subject_dir>/<name>_<modality>.nii.
flair_path = sorted(root_dir.glob('*/*_flair.nii'))
t1_path = sorted(root_dir.glob('*/*_t1.nii'))
t1ce_path = sorted(root_dir.glob('*/*_t1ce.nii'))
t2_path = sorted(root_dir.glob('*/*_t2.nii'))

seg_path = sorted(root_dir.glob('*/*_seg.nii'))

# Subject IDs left out of the dataset (matched as substrings of the FLAIR file name).
# NOTE(review): presumably subject 355 is excluded because its segmentation file does
# not match '*_seg.nii' and so has no entry in seg_path — confirm.
EXCLUDED_SUBJECT_IDS = ("355",)

image_paths = []
label_paths = []


def convert_path_str(arr):
    """Append one ``{"flair", "t1", "t1ce", "t2"} -> path`` dict to ``image_paths`` per FLAIR file.

    The T1/T1ce/T2 paths are derived by swapping only the ``_flair.nii`` suffix of the
    file name, so directory names (even ones containing "flair") are never rewritten.

    Args:
        arr: iterable of FLAIR file paths (``str`` or ``Path``).

    Raises:
        ValueError: if a non-excluded file name does not end with ``_flair.nii``.
    """
    suffix = "_flair.nii"
    for flair_file in map(Path, arr):
        name = flair_file.name
        if any(subject_id in name for subject_id in EXCLUDED_SUBJECT_IDS):
            continue
        if not name.endswith(suffix):
            raise ValueError(f"expected a file ending in {suffix!r}, got {flair_file}")
        stem = name[:-len(suffix)]
        image_paths.append({
            modality: str(flair_file.with_name(f"{stem}_{modality}.nii"))
            for modality in ("flair", "t1", "t1ce", "t2")
        })


convert_path_str(flair_path)

label_paths.extend(str(p) for p in seg_path)

# The subjects cell pairs these two lists by position with zip(), so one missing or
# extra segmentation would silently shift every later image/label pair. Fail loudly
# unless every pair comes from the same subject directory.
if len(image_paths) != len(label_paths):
    raise ValueError(
        f"{len(image_paths)} image sets but {len(label_paths)} segmentations; "
        "check EXCLUDED_SUBJECT_IDS and the dataset layout")
for images, label in zip(image_paths, label_paths):
    if Path(images["flair"]).parent != Path(label).parent:
        raise ValueError(f"image/label mismatch: {images['flair']} vs {label}")


In [None]:
# One tio.Subject per (modalities, segmentation) pair: the four MRI sequences are
# intensity images, the segmentation is loaded as a label map under the key "brain".
subjects = [
    tio.Subject(
        flair=tio.ScalarImage(modalities["flair"]),
        t1=tio.ScalarImage(modalities["t1"]),
        t1ce=tio.ScalarImage(modalities["t1ce"]),
        t2=tio.ScalarImage(modalities["t2"]),
        brain=tio.LabelMap(seg_file),
    )
    for modalities, seg_file in zip(image_paths, label_paths)
]
dataset = tio.SubjectsDataset(subjects)
print('Dataset size:', len(dataset), 'subjects')

In [None]:
# Training-time pipeline: reorient to canonical orientation, add random motion,
# bias-field and noise artefacts with the listed probabilities, flip along axis 0,
# then apply exactly one spatial transform chosen by the OneOf weights (0.8 / 0.2).
augmentations = [
    tio.ToCanonical(),
    tio.RandomMotion(p=0.2),
    tio.RandomBiasField(p=0.3),
    tio.RandomNoise(p=0.5),
    tio.RandomFlip(axes=(0,)),
    tio.OneOf({
        tio.RandomAffine(): 0.8,
        tio.RandomElasticDeformation(): 0.2,
    }),
]
training_transform = tio.Compose(augmentations)
# Validation only reorients; no random augmentation.
validation_transform = tio.Compose([tio.ToCanonical()])

num_subjects = len(dataset)
num_training_subjects = int(training_split_ratio * num_subjects)

# Positional split (no shuffling): subjects keep the sorted order of the path cell,
# the first num_training_subjects train and the rest validate.
training_subjects, validation_subjects = (
    subjects[:num_training_subjects],
    subjects[num_training_subjects:],
)

training_set = tio.SubjectsDataset(training_subjects, transform=training_transform)
validation_set = tio.SubjectsDataset(validation_subjects, transform=validation_transform)

for split_label, split_set in (('Training set:', training_set), ('Validation set:', validation_set)):
    print(split_label, len(split_set), 'subjects')

In [None]:
# The augmentation pipelines, the train/validation split and both SubjectsDatasets
# are built in the previous cell; rebuilding them here would only duplicate that code.
# Instead, sanity-check the split before training starts. The split is positional,
# so keeping it unchanged matters when a checkpoint is resumed in the training cell:
# reshuffling would move validation subjects into training.
assert len(training_subjects) + len(validation_subjects) == num_subjects, 'split lost subjects'
assert len(training_subjects) > 0, 'training split is empty'
assert len(validation_subjects) > 0, 'validation split is empty; lower training_split_ratio'

print('Training set:', len(training_set), 'subjects')
print('Validation set:', len(validation_set), 'subjects')

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

import numpy as np
import torch
import torch
import torch.nn as nn
import os
import nibabel as nib
from albumentations import Compose, HorizontalFlip
from skimage.util import montage
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import time
import torch.nn.functional as F

from IPython.display import clear_output
import seaborn as sns

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (Conv3d 3x3x3 -> GroupNorm -> ReLU) stages; spatial size is preserved.

    The first convolution maps ``in_channels`` to ``out_channels`` and the second keeps
    ``out_channels``. ``num_groups`` is the GroupNorm group count, which GroupNorm
    requires to divide ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
                nn.ReLU(inplace=True),
            ])
        # A single nn.Sequential named `double_conv` keeps the state_dict keys
        # (double_conv.0/1/3/4.*) compatible with saved checkpoints.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

    
class Down(nn.Module):
    """Encoder step: 2x2x2 max-pooling (stride 2, rounding sizes down) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2, 2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name and Sequential indices (encoder.0 / encoder.1) form the state_dict keys.
        self.encoder = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.encoder(x)

    
class Up(nn.Module):
    """Decoder step: upsample ``x1`` by 2, match it to the spatial size of skip tensor ``x2``,
    concatenate ``[x2, x1]`` on the channel axis and apply a DoubleConv.

    ``in_channels`` is the channel count after concatenation. With ``trilinear=True``
    (default) upsampling is parameter-free; otherwise a ConvTranspose3d mapping
    ``in_channels // 2`` to ``in_channels // 2`` channels is used.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        if not trilinear:
            half = in_channels // 2
            self.up = nn.ConvTranspose3d(half, half, kernel_size=2, stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # F.pad takes (before, after) pairs starting from the last axis: dim 4, 3, then 2.
        # A negative difference crops instead of padding.
        padding = []
        for axis in (4, 3, 2):
            gap = x2.size()[axis] - upsampled.size()[axis]
            padding += [gap // 2, gap - gap // 2]
        upsampled = F.pad(upsampled, padding)
        return self.conv(torch.cat([x2, upsampled], dim=1))

    
class Out(nn.Module):
    """Output head: a 1x1x1 Conv3d projecting ``in_channels`` feature maps to ``out_channels`` logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class UNet3d(nn.Module):
    """3D U-Net: DoubleConv stem, four Down encoders, four Up decoders with skip
    connections, and a 1x1x1 Out head returning raw logits.

    With ``c = n_channels`` the widths are: stem c; encoders 2c, 4c, 8c, 8c;
    decoders 4c, 2c, c, c; head ``n_classes``. Each decoder output is padded/cropped to
    its skip tensor, so the logits have the same spatial size as the stem output,
    which uses padding=1 convolutions and therefore matches the input.
    """

    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        c = n_channels
        # Attribute names are the state_dict key prefixes of saved checkpoints.
        self.conv = DoubleConv(in_channels, c)
        self.enc1 = Down(c, 2 * c)
        self.enc2 = Down(2 * c, 4 * c)
        self.enc3 = Down(4 * c, 8 * c)
        self.enc4 = Down(8 * c, 8 * c)

        self.dec1 = Up(16 * c, 4 * c)
        self.dec2 = Up(8 * c, 2 * c)
        self.dec3 = Up(4 * c, c)
        self.dec4 = Up(2 * c, c)
        self.out = Out(c, n_classes)

    def forward(self, x):
        # Encoder: keep every stage output; all but the deepest become skip connections.
        features = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features.append(encoder(features[-1]))

        # Decoder: start at the bottleneck and consume skips from deepest to shallowest.
        mask = features.pop()
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            mask = decoder(mask, features.pop())
        return self.out(mask)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU. Always a
# torch.device, so code inspecting it (e.g. device.type) works on both paths.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tensor axis along which the input modalities and the target masks are concatenated
# (the channel axis of the 5-D batch tensors).
AXIS_CONCAT = 1
# The three spatial axes of those 5-D tensors (used by the commented-out Dice score).
SPATIAL_DIMENSIONS = 2, 3, 4

# Kind of pass run by run_epoch: training or validation. Each member's value is the
# label printed in the epoch's loss line.
Action = enum.Enum(
    'Action',
    [('TRAIN', 'Training'), ('VALIDATE', 'Validation')],
    module=__name__,
)

# Prepare the model inputs and targets for one batch.
def prepare_batch(batch, device):
    """Build ``(inputs, targets)`` tensors on ``device`` from a collated TorchIO batch.

    inputs:  the ``flair``, ``t1``, ``t1ce`` and ``t2`` images concatenated, in that
             order, along ``AXIS_CONCAT``.
    targets: three binary masks from the ``brain`` label map, concatenated along
             ``AXIS_CONCAT`` in the order
               WT: label in {1, 2, 4}
               TC: label in {1, 4}
               ET: label == 4
             Any other label value is 0 in every mask. Masks keep the label map's dtype.

    Both tensors get the same axis reordering
    (``torch.movedim(x, (0, 1, 2, 3, 4), (0, 1, 4, 3, 2))``, i.e. axes 2 and 4 swapped),
    so images and masks stay voxel-aligned.
    """
    # Stack the four MRI sequences as input channels.
    inputs = torch.cat(
        [batch[name][DATA].to(device) for name in ('flair', 't1', 't1ce', 't2')],
        dim=AXIS_CONCAT)
    inputs = torch.movedim(inputs, (0, 1, 2, 3, 4), (0, 1, 4, 3, 2))

    # Region masks come from label comparisons rather than in-place relabelling of
    # copies, so each mask is strictly 0/1 even if an unexpected label value appears.
    mask_tumors = batch['brain'][DATA].to(device)
    mask_WT = (mask_tumors == 1) | (mask_tumors == 2) | (mask_tumors == 4)
    mask_TC = (mask_tumors == 1) | (mask_tumors == 4)
    mask_ET = mask_tumors == 4

    masks = torch.cat(
        [region.to(mask_tumors.dtype) for region in (mask_WT, mask_TC, mask_ET)],
        dim=AXIS_CONCAT)
    targets = torch.movedim(masks, (0, 1, 2, 3, 4), (0, 1, 4, 3, 2))

    return inputs, targets

# def get_dice_score(output, target, epsilon=1e-9):
#   # Tính dice score
#   p0 = output
#   g0 = target
#   p1 = 1 - p0
#   g1 = 1 - g0
#   tp = (p0 * g0).sum(dim=SPATIAL_DIMENSIONS)
#   fp = (p0 * g1).sum(dim=SPATIAL_DIMENSIONS)
#   fn = (p1 * g0).sum(dim=SPATIAL_DIMENSIONS)
#   num = 2 * tp
#   denom = 2 * tp + fp + fn + epsilon
#   dice_score = num / denom

#   return dice_score

# def get_dice_loss(output, target):
#   # Tính loss với dice score  
#   return 1 - get_dice_score(output, target)

############### FocalTversky Loss ###################
import torch.nn.functional as F
import torch.nn as nn

# Focal Tversky hyper-parameters: ALPHA weights false positives, BETA weights false
# negatives, GAMMA is the exponent applied to (1 - Tversky index).
# NOTE(review): ALPHA > BETA penalises false positives more than false negatives;
# confirm this is intended (recall-oriented setups usually weight FN more heavily).
ALPHA = 0.7
GAMMA = 2
BETA = 0.3

def get_focalTverskyLoss(inputs, targets, smooth=1, alpha=ALPHA, beta=BETA, gamma=GAMMA):
    """Focal Tversky loss pooled over every element of ``inputs`` and ``targets``.

    Both tensors are flattened, so a single scalar is returned for the whole batch and
    all channels. With soft counts TP = sum(p*g), FP = sum(p*(1-g)), FN = sum((1-p)*g):

        tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
        loss    = (1 - tversky) ** gamma

    Args:
        inputs: predicted probabilities.
        targets: ground-truth masks, same number of elements as ``inputs``.
        smooth: additive smoothing in numerator and denominator.
        alpha, beta, gamma: see the module-level constants above.
    """
    pred = inputs.reshape(-1)
    truth = targets.reshape(-1)

    true_pos = (pred * truth).sum()
    false_pos = (pred * (1 - truth)).sum()
    false_neg = ((1 - pred) * truth).sum()

    tversky_index = (true_pos + smooth) / (true_pos + alpha * false_pos + beta * false_neg + smooth)
    return (1 - tversky_index) ** gamma

def forward(model, inputs):
    """Run ``model`` on ``inputs`` and return its output cast to float32.

    ``inputs`` is cast to float32 and moved to the global ``device`` first. UserWarnings
    emitted during this call are suppressed; the warning filters are restored on exit.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        batch_on_device = inputs.float().to(device)
        return model(batch_on_device).float()

def get_model_and_optimizer(device):
    """Return ``(model, optimizer)``: a UNet3d (4 input channels, 3 output classes,
    base width 24) placed on ``device``, and an AdamW optimizer with default settings."""
    net = UNet3d(in_channels=4, n_classes=3, n_channels=24)
    net = net.to(device)
    return net, torch.optim.AdamW(net.parameters())

def run_epoch(epoch_idx, action, loader, model, optimizer):
    """Run one pass over ``loader`` and return the mean batch loss.

    With ``Action.TRAIN`` the model is put in training mode and updated by ``optimizer``
    after every batch; otherwise it is put in eval mode and gradients are disabled.
    The mean loss is printed, prefixed by ``action.value``.

    Args:
        epoch_idx: epoch number (not used inside; kept for callers).
        action: ``Action.TRAIN`` or ``Action.VALIDATE``.
        loader: iterable of TorchIO batches accepted by ``prepare_batch``.
        model: network called through ``forward``.
        optimizer: optimizer stepped only when training.

    Returns:
        float: mean of the per-batch losses, or ``nan`` if ``loader`` yields no batches.
    """
    is_training = action == Action.TRAIN
    model.train(is_training)
    epoch_losses = []
    for batch in tqdm(loader):
        inputs, targets = prepare_batch(batch, device)
        with torch.set_grad_enabled(is_training):
            logits = forward(model, inputs)
            # Independent sigmoid per channel: the WT/TC/ET masks overlap (label 4 is
            # in all three), so the channels are not mutually exclusive classes.
            probabilities = torch.sigmoid(logits)
            batch_loss = get_focalTverskyLoss(probabilities, targets).mean()
            if is_training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
        epoch_losses.append(batch_loss.item())

    if not epoch_losses:
        # Avoid numpy's "mean of empty slice" warning and a meaningless nan print.
        print(f'{action.value}: no batches')
        return float('nan')
    mean_loss = float(np.mean(epoch_losses))
    print(f'{action.value} mean loss: {mean_loss:0.3f}')
    return mean_loss

def train(num_epochs, training_loader, validation_loader, model, optimizer, weights_stem, start):
    """Validate once, then alternate training and validation for ``num_epochs`` epochs.

    After each epoch the model's state_dict (only; the optimizer state is not saved) is
    written to ``./{weights_stem}_whole_epoch_{epoch_idx + start}.pth``, where
    ``epoch_idx`` counts from 1 within this call. ``start`` is therefore an offset added
    to the epoch number in the file name — presumably the number of epochs a resumed
    checkpoint had already completed.
    """
    def validate(epoch):
        run_epoch(epoch, Action.VALIDATE, validation_loader, model, optimizer)

    validate(0)  # baseline validation loss before any update
    for epoch_idx in range(1, num_epochs + 1):
        print('Starting epoch', epoch_idx)
        run_epoch(epoch_idx, Action.TRAIN, training_loader, model, optimizer)
        validate(epoch_idx)
        checkpoint_path = f'./{weights_stem}_whole_epoch_{epoch_idx+start}.pth'
        torch.save(model.state_dict(), checkpoint_path)

In [None]:
training_batch_size = 1
validation_batch_size = training_batch_size  # validation uses the same batch size

# Both loaders use one worker process per CPU core; only training batches are shuffled.
loader_workers = multiprocessing.cpu_count()

training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=training_batch_size, shuffle=True, num_workers=loader_workers)

validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=validation_batch_size, num_workers=loader_workers)

In [None]:
model, optimizer = get_model_and_optimizer(device)
# Resume from a saved checkpoint instead of starting from random weights.
# `train_whole_images` comes from the config cell.
load_pretrain = True
# Number of epochs the loaded checkpoint already completed. train() adds it to the
# epoch index when naming checkpoints, so after epoch_8 new files continue at epoch_9.
# Defined up front so train() also works when no checkpoint is loaded.
start = 0
if train_whole_images:
    weights_stem = 'whole_images'
    if load_pretrain:
        weights_path = '../input/weight-unet3d/whole_images_whole_epoch_8.pth'
        start = 8
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # Only model weights are restored; train() does not save optimizer state,
        # so AdamW's moment estimates restart from scratch.
        model.load_state_dict(torch.load(weights_path, map_location=device))
    train(num_epochs, training_loader, validation_loader, model, optimizer, weights_stem, start)
else:
    # NOTE(review): Google Drive / Colab path — it does not exist on Kaggle; point it at
    # an uploaded dataset before using this branch.
    weights_path = '/content/drive/MyDrive/BigData/Project/Computer Vision Project/EDA/weights/patches_epoch_2.pth'
    model.load_state_dict(torch.load(weights_path, map_location=device))

In [None]:
# import torch.nn.functional as F
# import torch.nn as nn

# ALPHA = 0.8
# GAMMA = 2

# def get_focalTverskyLoss(inputs, targets, smooth=1, alpha=ALPHA, beta=BETA, gamma=GAMMA):     
#     #flatten label and prediction tensors
#     inputs = inputs.view(-1)
#     targets = targets.view(-1)

#     #True Positives, False Positives & False Negatives
#     TP = (inputs * targets).sum()    
#     FP = ((1-targets) * inputs).sum()
#     FN = (targets * (1-inputs)).sum()

#     Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)  
#     FocalTversky = (1 - Tversky)**gamma

#     return FocalTversky

In [None]:
# def get_focalLoss(inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
#     #flatten label and prediction tensors
#     inputs = inputs.view(-1)
#     targets = targets.view(-1)

#     #first compute binary cross-entropy 
#     BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
#     BCE_EXP = torch.exp(-BCE)
#     focal_loss = alpha * (1-BCE_EXP)**gamma * BCE

#     return focal_loss

In [None]:
# def get_dice_score(output, target, epsilon=1e-9):
#   # Tính dice score
#   p0 = output
#   g0 = target
#   p1 = 1 - p0
#   g1 = 1 - g0
#   tp = (p0 * g0).sum(dim=SPATIAL_DIMENSIONS)
#   fp = (p0 * g1).sum(dim=SPATIAL_DIMENSIONS)
#   fn = (p1 * g0).sum(dim=SPATIAL_DIMENSIONS)
#   num = 2 * tp
#   denom = 2 * tp + fp + fn + epsilon
#   dice_score = num / denom

#   return dice_score

# def get_dice_loss(output, target):
#   # Tính loss với dice score  
#   return 1 - get_dice_score(output, target)

# def get_BCEDiceLoss(output, target):
#     dice_loss = get_dice_loss(output, target)
#     bce_loss = F.binary_cross_entropy(output, target, reduction='mean')
    
#     return bce_loss + dice_loss