# Breast Tumor in Ultrasound 2D Segmentation with MONAI

1. Transforms for dictionary format data.
1. Load Nifti image with metadata.
1. Cache IO and transforms to accelerate training and validation.
1. 2D MDA-Net model, DiceCE loss function, and mean Dice metric for the 2D segmentation task.
1. Deterministic training for reproducibility.

Target: Breast Tumor  
Modality: Ultrasound  
Dataset: 2D images in NII (80% Training + 10% Validation + 10% Testing)

### MDA-Net

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandAdjustContrastd,
    RandZoomd,
    RandGridDistortiond,
)
from monai.handlers.utils import from_engine
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import time
import glob
import tqdm
import numpy as np
from IPython.core.debugger import set_trace
from torch.utils.tensorboard import SummaryWriter
import random
from monai.data.utils import pad_list_data_collate
import pdb
import torch.nn as nn
import nibabel as nib

from MDANet import MDA_Net


## Setup imports

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Print MONAI's configuration report (`monai.config.print_config`) to the cell output.
print_config()

## Dataset path

In [None]:
# Root folder of the dataset. The cells below read the imagesTr/labelsTr,
# imagesVl/labelsVl and imagesTs/labelsTs sub-folders from it.
# Set the BREAST_US_DATA_DIR environment variable to point the notebook at a
# different copy of the data without editing this cell.
data_dir = os.environ.get(
    "BREAST_US_DATA_DIR",
    "/home/users/jvilaca/raul/3_2_CBIS_Croped_padding",
)
print(data_dir)

## Set Breast Ultrasound dataset path

In [None]:
def _list_image_label_pairs(root, suffix):
    """Collect the sorted image/label NIfTI paths of one dataset split.

    Args:
        root: dataset root folder.
        suffix: split suffix appended to "images"/"labels" ("Tr", "Vl", "Ts").

    Returns:
        A tuple ``(images, labels, files)`` where ``images`` and ``labels`` are
        the sorted ``*.nii.gz`` paths and ``files`` is a list of
        ``{"image": ..., "label": ...}`` dicts pairing them by position.

    Raises:
        ValueError: if the two folders hold a different number of files.
            Pairing is positional, so a mismatch would otherwise be truncated
            silently by ``zip``.
    """
    images = sorted(
        glob.glob(os.path.join(root, "images" + suffix, "*.nii.gz")))
    labels = sorted(
        glob.glob(os.path.join(root, "labels" + suffix, "*.nii.gz")))
    if len(images) != len(labels):
        raise ValueError(
            f"images{suffix} has {len(images)} files but "
            f"labels{suffix} has {len(labels)}"
        )
    files = [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(images, labels)
    ]
    return images, labels, files


# Train
train_images, train_labels, train_files = _list_image_label_pairs(data_dir, "Tr")

# Val
val_images, val_labels, val_files = _list_image_label_pairs(data_dir, "Vl")

# Test
test_images, test_labels, test_files = _list_image_label_pairs(data_dir, "Ts")

In [None]:
# Show the image/label path pairs that make up the test split.
print(test_files)

## Set deterministic training for reproducibility

In [None]:
# Fix the seed through MONAI's determinism helper so runs are reproducible.
set_determinism(seed=0)

## Setup transforms for training and validation

Here we use several transforms to load and augment the dataset:
1. `LoadImaged` loads the breast ultrasound images and labels from NIfTI format files.
1. `EnsureChannelFirstd` makes sure the data has a "channel first" shape.
1. `ScaleIntensityRanged` clips the image intensities to the range [0, 255] and scales them to [0, 1].
1. `RandFlipd` randomly flips the image and label along spatial axis 0.
1. `RandScaleIntensityd` and `RandShiftIntensityd` randomly scale and shift the image intensities.
1. `RandGaussianNoised` and `RandGaussianSmoothd` randomly add Gaussian noise to, or smooth, the image.
1. `RandAdjustContrastd` randomly changes the image contrast with a gamma in [0.5, 2.5].
1. `RandZoomd` randomly zooms the image and label by a factor between 1 and 1.3.
1. `RandGridDistortiond` randomly applies a grid distortion (limit ±0.2) to the image and label.
1. `EnsureTyped` converts the numpy array to PyTorch Tensor for further steps.

The random transforms are applied to the training data only; validation and test data are only loaded, made channel-first and intensity-scaled.

In [None]:
def _base_transforms():
    """Return fresh instances of the steps shared by every split.

    Loads image and label, makes them channel-first, and rescales the image
    intensities from [0, 255] to [0, 1] (values outside the range are clipped).
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=0, a_max=255,
            b_min=0.0, b_max=1.0, clip=True,
        ),
    ]


# Training pipeline: shared steps followed by random augmentation.
train_transforms = Compose(
    _base_transforms()
    + [
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        # Intensity augmentations touch the image only.
        RandScaleIntensityd(keys="image", factors=0.5, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.5, prob=0.5),
        RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=0.05),
        RandGaussianSmoothd(keys="image", prob=0.25),
        RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.5, 2.5)),
        # Spatial augmentations resample both keys. The label must use
        # nearest-neighbour so the mask keeps its original discrete values;
        # the image keeps each transform's default mode.
        RandZoomd(
            keys=["image", "label"], prob=0.5, min_zoom=1, max_zoom=1.3,
            mode=("area", "nearest"),
        ),
        RandGridDistortiond(
            keys=["image", "label"], prob=0.5, distort_limit=(-0.2, 0.2),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"]),
    ]
)

# Validation and test pipelines: shared steps only, no augmentation.
val_transforms = Compose(
    _base_transforms() + [EnsureTyped(keys=["image", "label"])]
)

test_transforms = Compose(
    _base_transforms() + [EnsureTyped(keys=["image", "label"])]
)

## Check transforms in DataLoader

In [None]:
# Pull one augmented training sample through a plain Dataset/DataLoader and
# show the image next to its label as a visual sanity check.
check_ds = Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

# First batch item, first channel of each tensor.
image = check_data["image"][0][0]
label = check_data["label"][0][0]
print(f"image shape: {image.shape}, label shape: {label.shape}")

plt.figure("check", (12, 6))
for position, (caption, picture, show_kwargs) in enumerate(
    (("image", image, {"cmap": "gray"}), ("label", label, {})), start=1
):
    plt.subplot(1, 2, position)
    plt.title(caption)
    plt.imshow(picture, **show_kwargs)
plt.show()

In [None]:
# Display the shape of the sample image taken from the check loader.
np.shape(image)

## Define CacheDataset and DataLoader for training and validation

Here we use CacheDataset to accelerate training and validation process, it's 10x faster than the regular Dataset.  
To achieve best performance, set `cache_rate=1.0` to cache all the data, if memory is not enough, set lower value.  
Users can also set `cache_num` instead of `cache_rate`, will use the minimum value of the 2 settings.  
And set `num_workers` to enable multi-threads during caching.  
If you want to try the regular Dataset, just switch to the commented code below.

In [None]:
# --- Training split -------------------------------------------------------
# cache_rate=1.0 asks CacheDataset to cache the whole split.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,
    num_workers=4,
)
# Non-cached alternative:
# train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=10,
    shuffle=True,
    num_workers=0,
)

# --- Validation split -----------------------------------------------------
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=0,
)
# Non-cached alternative:
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=5,
    num_workers=4,
)

# --- Test split -----------------------------------------------------------
test_ds = CacheDataset(
    data=test_files,
    transform=test_transforms,
    cache_rate=1.0,
    num_workers=0,
)
# Non-cached alternative:
# test_ds = Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=4,
)

In [None]:
# Hyperparameters etc.
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 5000
#LOAD_MODEL = False
# Best validation Dice seen so far and the 1-based epoch it was reached in;
# -1 means "no epoch evaluated yet".
best_metric = -1
best_metric_epoch = -1

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one mixed-precision training epoch and return the mean batch loss.

    Args:
        loader: iterable of batches; each batch is a dict holding "image" and
            "label" tensors with at least four dimensions.
        model: network called on the reshaped image batch.
        optimizer: optimizer stepped once per batch through ``scaler``.
        loss_fn: callable ``loss_fn(predictions, labels)`` returning a scalar
            tensor.
        scaler: gradient scaler (``torch.cuda.amp.GradScaler``-like) used for
            the backward pass and the optimizer step.

    Returns:
        A detached scalar tensor: the mean of the per-batch losses of the
        epoch. It is detached so the caller can log it without keeping the
        autograd graph of a batch alive.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    loss_total = None
    num_batches = 0

    for batch_data in loader:
        inputs = batch_data["image"].to(device=DEVICE)
        labels = batch_data["label"].to(device=DEVICE)

        # Force both tensors to (batch, 1, dim2, dim3). This presumably drops
        # a trailing singleton axis of the loaded 2D NIfTI data -- TODO confirm.
        shape = inputs.shape
        target_shape = [shape[0], 1, shape[2], shape[3]]
        inputs = torch.reshape(inputs, target_shape)
        labels = torch.reshape(labels, target_shape)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(inputs)
            loss = loss_fn(predictions, labels)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.detach()
        loss_total = batch_loss if loss_total is None else loss_total + batch_loss
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_fn received an empty loader")

    return loss_total / num_batches

In [None]:
def check_accuracy(loader, model, device, loss_fn):
    """Evaluate ``model`` on ``loader`` and report accuracy, Dice and loss.

    The model is switched to eval mode for the evaluation and back to train
    mode before returning. Each batch is passed through the model exactly
    once; the same logits feed both the loss and the thresholded prediction.

    Args:
        loader: iterable of batches; each batch is a dict holding "image" and
            "label" tensors with at least four dimensions.
        model: network to evaluate; its output is treated as logits.
        device: device the batch tensors are moved to.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar.

    Returns:
        ``(dice, loss)``: the Dice score and the loss, each averaged over the
        batches of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    loss_total = 0
    num_batches = 0
    model.eval()

    with torch.no_grad():
        for batch_data in loader:
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)

            # Force both tensors to (batch, 1, dim2, dim3).
            shape = inputs.shape
            target_shape = [shape[0], 1, shape[2], shape[3]]
            inputs = torch.reshape(inputs, target_shape)
            labels = torch.reshape(labels, target_shape)

            # Single forward pass, under autocast like the training step.
            with torch.cuda.amp.autocast():
                logits = model(inputs)
                loss_total += loss_fn(logits, labels)

            # Binarise the sigmoid probabilities at 0.5 (in float32).
            preds = (torch.sigmoid(logits.float()) > 0.5).float()
            num_correct += (preds == labels).sum()
            num_pixels += torch.numel(preds)
            # The small epsilon avoids 0/0 when prediction and label are empty.
            dice_score += (2 * (preds * labels).sum()) / (
                (preds + labels).sum() + 1e-8
            )
            num_batches += 1

    model.train()

    if num_batches == 0:
        raise ValueError("check_accuracy received an empty loader")

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    dice = dice_score / num_batches
    print(f"Dice score: {dice}")

    return dice, loss_total / num_batches

## Create root path

In [None]:
# Create a time-stamped output folder for this run, e.g. "Jan04_18-32-41".
parent_dir = "/home/users/jvilaca/raul/Methods/MDA-UNet"
# Fixed English abbreviations keep the folder name independent of the locale.
mes_ext = {1: 'Jan', 2: 'Feb', 3: 'Mar', 4: 'Apr', 5: 'May', 6: 'Jun',
           7: 'Jul', 8: 'Aug', 9: 'Sep', 10: 'Oct', 11: 'Nov', 12: 'Dec'}
# Read the clock once so the month part and the day/time part of the name
# always describe the same instant.
now = time.localtime()
month = time.strftime("%m", now)
time_str = time.strftime("%d_%H-%M-%S", now)
filename = mes_ext[int(month)] + time_str

path = os.path.join(parent_dir, filename)
# makedirs also creates a missing parent folder. A run folder that already
# exists still raises FileExistsError, so earlier results are never reused.
os.makedirs(path)

root_dir = path
print(root_dir)

# Training!

In [None]:
# Build the network, loss and optimizer.
model = MDA_Net(img_ch=1, output_ch=1).to(DEVICE)
# sigmoid=True: the loss applies the sigmoid itself, so the model outputs logits.
loss_fn = DiceCELoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# Baseline evaluation of the untrained model.
check_accuracy(val_loader, model, DEVICE, loss_fn)
scaler = torch.cuda.amp.GradScaler()

writer = SummaryWriter()

for epoch in range(NUM_EPOCHS):

    print(str(epoch) + '/' + str(NUM_EPOCHS) + ' epochs')

    # One pass over the training data.
    loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
    print('Loss: ' + str(loss))
    writer.add_scalar("Loss/train", loss, epoch+1)

    # Evaluate on the validation split.
    dice, loss_val = check_accuracy(val_loader, model, DEVICE, loss_fn)
    print('Dice: ' + str(dice))
    writer.add_scalar("Dice/val", dice, epoch+1)

    # Report the validation loss here, not the training loss printed above.
    print('Loss: ' + str(loss_val))
    writer.add_scalar("Loss/val", loss_val, epoch+1)

    metric = dice

    # Keep a checkpoint of the model with the best validation Dice so far.
    if metric > best_metric:
        best_metric = metric
        best_metric_epoch = epoch + 1
        torch.save({'model': model.state_dict(),'epoch': epoch+1,'optimizer': optimizer.state_dict()}, os.path.join(
            root_dir, "best_metric_model.pth"))
        print("=> Saving best metric model checkpoint")

# Flush pending TensorBoard events and release the log file.
writer.close()

# Always keep the weights of the last epoch as well.
torch.save({'model': model.state_dict(),'epoch': epoch+1,'optimizer': optimizer.state_dict()}, os.path.join(
                root_dir, "final_metric_model.pth"))
print("=> Saving final metric model checkpoint")

In [None]:
# Summarise the run: best validation metric and the epoch that produced it.
summary = (
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}"
)
print(summary)

# Testing!

## Load model and save prediction masks 

In [None]:
# Folder of the training run whose checkpoint is evaluated below.
root_dir = "/home/users/jvilaca/raul/Methods/MDA-UNet/Jan04_18-32-41"
# Predicted masks are written to a sub-folder of that run.
pathSave = root_dir + "/predLabelsSig"
# exist_ok=True lets this cell be re-run without failing on the existing folder.
os.makedirs(pathSave, exist_ok=True)
pathSave = pathSave + "/"
pathSave

In [None]:
# model_name = "final_metric_model.pth"
model_name = "best_metric_model.pth"

# map_location puts the stored tensors on DEVICE whatever device they were
# saved from.
checkpoint = torch.load(os.path.join(root_dir, model_name), map_location=DEVICE)
model.load_state_dict(checkpoint['model'])

model.eval()
sig = nn.Sigmoid()
with torch.no_grad():
    # test_loader uses batch_size=1 without shuffling, so its batches line up
    # one-to-one with the entries of test_files.
    for test_data, test_file in zip(test_loader, test_files):
        test_inputs, test_labels = (
            test_data["image"].to(device=DEVICE),
            test_data["label"].to(device=DEVICE),
        )
        # Force both tensors to (batch, 1, dim2, dim3).
        dimsTest = np.shape(test_inputs)
        test_inputs = torch.reshape(test_inputs, [dimsTest[0], 1, dimsTest[2], dimsTest[3]])
        test_labels = torch.reshape(test_labels, [dimsTest[0], 1, dimsTest[2], dimsTest[3]])

        # Logits -> probabilities -> binary mask.
        # NOTE(review): the cut-off here is 0.90 while check_accuracy uses
        # 0.5 -- confirm this is intended.
        kk = model(test_inputs)
        kk = sig(kk)
        kk = torch.where(kk > 0.90, 1, 0)
        # Single-sample mask with a trailing singleton axis: (dim2, dim3, 1).
        kk = np.expand_dims(kk[0, 0, :, :].cpu().detach().numpy(), axis=2)

        # Reuse the affine and header of the input image so the saved mask
        # shares its geometry.
        name = test_file['image']
        masknii = nib.load(name)
        m = masknii.affine
        header = masknii.header
        predicted = nib.Nifti1Image(kk, m, header)
        nib.save(predicted, os.path.join(pathSave, os.path.basename(name)))

In [None]:
# Show the last test sample handled by the inference loop: the input image,
# its ground-truth label and the predicted mask, side by side.
panels = (
    ("image", test_inputs[0, 0, :, :].cpu().detach().numpy()),
    ("label", test_labels[0, 0, :, :].cpu().detach().numpy()),
    ("output", kk[:, :, 0]),
)
plt.figure("check", (18, 6))
for column, (caption, pixels) in enumerate(panels, start=1):
    plt.subplot(1, 3, column)
    plt.title(caption)
    plt.imshow(pixels, cmap="gray")
plt.show()