In [1]:
import os
import wandb
import glob
import torch
import monai
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

### Hyper-paramter

In [2]:
torch.manual_seed(1024)
np.random.seed(1024)
device = torch.device(
    "cuda:1"
    if torch.cuda.is_available()
    else
    "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Data Augmentation
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop([224, 224]),
    transforms.GaussianBlur(3)
])

target_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop([224, 224])
])

lr = 1e-4
batch_size = 8
weight_decay = 0
num_epochs = 20

run = wandb.init(
    project="Unet",
    # Track hyperparameters and run metadata
    config={
        "learning rate": lr,
        "batch_size": batch_size,
        "weight decay": weight_decay,
        "Epoches number": num_epochs,
        "transform": str(transform),
        "target transform": str(target_transform)
    })

Using cuda:1 device
Compose(
    Resize(size=[224, 224], interpolation=bilinear)
    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
    RandomResizedCrop(size=[224, 224], scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))
)
2023-05-24 11:45:29,805 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mming686[0m ([33mdeeplearning-med[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Create Segmentation Dataset

In [3]:
class SegDataset(Dataset):
    def __init__(self, data_root, transform, target_transform, train=True):
        self.data_root = data_root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.gt_files_path = []
        # find all patient directories
        patient_directories = glob.glob(os.path.join(self.data_root, 'patient*'))
        # find all files with the suffix _gt.npy
        for patient_directory in patient_directories:
            per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.npy'))
            for path in per_patient_file_path:
                self.gt_files_path.append(path)
        
    def __len__(self):
        return len(self.gt_files_path)
    
    def __getitem__(self, index):
        gt_image_path = self.gt_files_path[index]
        image_path = gt_image_path[:-7] + ".npy"
        image = np.load(image_path)
        gt_image = np.load(gt_image_path)
        image = torch.tensor(image[None,:,:]).float()
        gt_image = torch.tensor(gt_image).long()
            
        # Convert the ground truth label to one-hot encoding
        one_hot_label = torch.nn.functional.one_hot(gt_image, num_classes=4)

        # Transpose the tensor to have dimensions (C, H, W)
        one_hot_label = one_hot_label.permute(2, 0, 1)

        # Remove the background channel (dimension 0)
        one_hot_label = one_hot_label[1:, :, :]
        
        # Use seed to make sure image and target has same transform
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        image = self.transform(image)
        random.seed(seed)
        torch.manual_seed(seed)
        target = self.target_transform(one_hot_label)
        
        return image, target

In [4]:
dataset = SegDataset(data_root = './database/training', 
                     transform = transform, 
                     target_transform = target_transform)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

### Init Model

In [5]:
model = smp.Unet(
    encoder_name="resnet50",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels
    classes=3,                      # model output channels (number of classes)
)

preprocess_input = get_preprocessing_fn('resnet50', pretrained='imagenet')

model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

### Train

In [None]:
# train
losses = []
best_loss = 1e10
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0
    for step, (img, gt) in enumerate(tqdm(dataloader)):
        img = img.to(device)
        mask = model(img)
        loss = seg_loss(mask, gt.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    epoch_loss /= step
    losses.append(epoch_loss)
    wandb.log({"loss": epoch_loss})
    print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
    # save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), './model/unet-test/model_best.pth')

100%|██████████| 238/238 [01:22<00:00,  2.87it/s]


EPOCH: 0, Loss: 1.2420823460892787


100%|██████████| 238/238 [01:23<00:00,  2.86it/s]


EPOCH: 1, Loss: 0.7412007285069816


100%|█████████▉| 237/238 [01:24<00:00,  2.81it/s]