In [1]:
import io
import optuna
from PIL import Image
import os
import wandb
import glob
import torch
import monai
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

In [2]:
torch.manual_seed(1024)
np.random.seed(1024)
device = torch.device(
    "cuda:1"
    if torch.cuda.is_available()
    else
    "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Data Augmentation
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize([0.45, ], [0.35, ]),
    transforms.Resize([256, 256]),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop([256, 256])
])

target_transform = transforms.Compose([
    transforms.Resize([256, 256], interpolation=transforms.InterpolationMode.NEAREST),
    transforms.RandomRotation(45, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.RandomResizedCrop([256, 256], interpolation=transforms.InterpolationMode.NEAREST)
])

Using cuda:1 device


In [3]:
class SegDataset(Dataset):
    def __init__(self, data_root, transform, target_transform, train=True, to3d=False):
        self.data_root = data_root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.to3d = to3d
        self.gt_files_path = []

        # find all patient directories
        patient_directories = glob.glob(os.path.join(self.data_root, 'patient*'))
        # find all files with the suffix _gt.npy
        train_size = int(len(patient_directories)*0.8)

        if self.train:
            for patient_directory in patient_directories[0:train_size]:
                per_patient_file_path = glob.glob(
                    os.path.join(patient_directory, '*_gt.npy'))
                for path in per_patient_file_path:
                    self.gt_files_path.append(path)
        else:
            for patient_directory in patient_directories[train_size:]:
                per_patient_file_path = glob.glob(os.path.join(
                    patient_directory, '*_gt.npy'))
                for path in per_patient_file_path:
                    self.gt_files_path.append(path)

    def __len__(self):
        return len(self.gt_files_path)

    def __getitem__(self, index):
        gt_image_path = self.gt_files_path[index]
        image_path = gt_image_path[:-7] + ".npy"
        image = np.load(image_path)
        gt_image = np.load(gt_image_path)
        image = torch.tensor(image[None, :, :]).float()
        gt_image = torch.tensor(gt_image).long()
        # Convert the ground truth label to one-hot encoding
        one_hot_label = torch.nn.functional.one_hot(gt_image, num_classes=4)

        # Transpose the tensor to have dimensions (C, H, W)
        one_hot_label = one_hot_label.permute(2, 0, 1)

        # Remove the background channel (dimension 0)
        one_hot_label = one_hot_label[1:, :, :]

        # Use seed to make sure image and target has same transform
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        image = self.transform(image)

        random.seed(seed)
        torch.manual_seed(seed)
        target = self.target_transform(one_hot_label)

        # Convert 1d grayscale image to 3d, for transformer backbone
        if self.to3d:
            rgb_tensor = image.repeat(3, 1, 1)
            return rgb_tensor, target
        return image, target

In [4]:
train_dataset = SegDataset(data_root='./database/training',
                           transform=transform,
                           target_transform=target_transform,
                           train=True,
                           to3d=False)

val_dataset = SegDataset(data_root='./database/training',
                         transform=transform,
                         target_transform=target_transform,
                         train=False,
                         to3d=False)

In [5]:
def vis_img(img, mask):
    img = np.squeeze(img)
    mask = np.squeeze(mask)
    plt.figure()
    plt.imshow(img, 'gray')
    overlay_mask_1 = np.ma.masked_where(mask[0] == 0, img)
    overlay_mask_2 = np.ma.masked_where(mask[1] == 0, img)
    overlay_mask_3 = np.ma.masked_where(mask[2] == 0, img)
    plt.imshow(overlay_mask_1, 'Greens', alpha=1, interpolation='nearest')
    plt.imshow(overlay_mask_2, 'Reds', alpha=1, interpolation='nearest')
    plt.imshow(overlay_mask_3, 'Purples', alpha=1, interpolation='nearest')
    buffer = io.BytesIO()
    plt.savefig(buffer, format='jpeg')
    buffer.seek(0)

    # Convert the in-memory buffer to a NumPy array
    image_array = np.array(Image.open(buffer))
    plt.close()
    return image_array

In [8]:
def train_model(trial):

    # Hyperparameters to be optimized
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
    batch_size = 16
    weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-5])
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD"])
    loss_func = trial.suggest_categorical("loss_func",
                                          ["DiceLoss",
                                           "GeneralizedDiceLoss",
                                           "TverskyLoss"])

    # Choose loss function
    if loss_func == "DiceLoss":
        seg_loss = monai.losses.DiceLoss(sigmoid=True,
                                         squared_pred=True,
                                         reduction='mean')
    elif loss_func == "GeneralizedDiceLoss":
        seg_loss = monai.losses.GeneralizedDiceLoss(sigmoid=True,
                                                    squared_pred=True,
                                                    reduction='mean')
    elif loss_func == "TverskyLoss":
        seg_loss = monai.losses.TverskyLoss(sigmoid=True,
                                            squared_pred=True,
                                            reduction='mean')

    run = wandb.init(
        project="Unet-Res50-Tune",
        name=f'Unet-Res50-trial-{trial.number}',
        config={
            "number of epoches": 30,
            "learning rate": lr,
            "batch_size": batch_size,
            "optimizer": optimizer_name,
            "weight decay": weight_decay,
            "loss function": loss_func,
            "transform": str(transform),
            "target transform": str(target_transform)
        })

    # Set up DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    # Set up model and optimizer
    model = smp.Unet(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=1,
        classes=3
    )

    model.to(device)

    if optimizer_name == "Adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif optimizer_name == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay)

    best_loss = 1e10
    # Start training and validation
    for epoch in range(30):
        # Train
        model.train()
        epoch_loss = 0
        for step, (img, gt) in enumerate(tqdm(train_loader)):
            img = img.to(device)
            mask = model(img)
            loss = seg_loss(mask, gt.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f'EPOCH: {epoch + 1}, Train Loss: {epoch_loss}')
        # Validation
        model.eval()
        val_loss = 0
        last_image_batch = None
        last_gt_mask_batch = None
        last_pr_mask_batch = None
        with torch.no_grad():
            for step, (img, gt) in enumerate(tqdm(val_loader)):
                img = img.to(device)
                mask = model(img)
                loss = seg_loss(mask, gt.to(device))
                val_loss += loss.item()
                last_image_batch = img
                last_gt_mask_batch = gt
                last_pr_mask_batch = mask

        print(f'EPOCH: {epoch + 1}, Validation Loss: {val_loss}')

        last_image = last_image_batch.detach().cpu().numpy()[0][0]
        last_gt = last_gt_mask_batch.detach().cpu().numpy()[0]
        last_pr = last_pr_mask_batch.detach().cpu().numpy()[0]

        threshold = 0.95
        binary_mask = (last_pr > threshold)

        ground_truth = vis_img(last_image, last_gt)
        predicted = vis_img(last_image, binary_mask)
        # Log
        wandb.log({"loss": epoch_loss,
                   "val_loss": val_loss,
                   "ground_truth": wandb.Image(ground_truth),
                   "prediction": wandb.Image(predicted)})

        # save the best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), './model/unet-tune/model_best.pth')

    # Return the validation loss as this is the value to be optimized
    run.finish()
    return best_loss

In [9]:
study = optuna.create_study(direction="minimize")
study.optimize(train_model, n_trials=100)

# Print out the best hyperparameters
print(study.best_params)

[I 2023-06-07 15:56:53,894] A new study created in memory with name: no-name-9a3003ef-62cb-4905-a4bc-b362fe307c35
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)


2023-06-07 15:56:54,139 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mming686[0m ([33mdeeplearning-med[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 96/96 [02:17<00:00,  1.43s/it]


EPOCH: 1, Train Loss: 87.9100975394249


100%|██████████| 23/23 [00:21<00:00,  1.09it/s]


EPOCH: 1, Validation Loss: 20.88624197244644


  5%|▌         | 5/96 [00:03<01:11,  1.27it/s]
[W 2023-06-07 16:00:41,606] Trial 0 failed with parameters: {'lr': 1.2742501079003229e-05, 'batch_size': 16, 'weight_decay': 1e-05, 'optimizer': 'SGD', 'loss_func': 'DiceLoss'} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/jovyan/.local/lib/python3.8/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_1729/3584275792.py", line 77, in train_model
    epoch_loss += loss.item()
KeyboardInterrupt
[W 2023-06-07 16:00:41,608] Trial 0 failed with value None.


KeyboardInterrupt: 