In [1]:
import io
from PIL import Image
import os
import wandb
import glob
import torch
import monai
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

%matplotlib inline

### Hyper-paramter

In [2]:
torch.manual_seed(1024)
np.random.seed(1024)
device = torch.device(
    "cuda:1"
    if torch.cuda.is_available()
    else
    "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Data Augmentation
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop([224, 224]),
    transforms.GaussianBlur(3)
])

target_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop([224, 224])
])

lr = 1e-4
batch_size = 8
weight_decay = 0
num_epochs = 20

run = wandb.init(
    project="Unet",
    # Track hyperparameters and run metadata
    config={
        "learning rate": lr,
        "batch_size": batch_size,
        "weight decay": weight_decay,
        "Epoches number": num_epochs,
        "transform": str(transform),
        "target transform": str(target_transform)
    })

Using cuda:1 device
2023-05-26 16:12:28,352 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mming686[0m ([33mdeeplearning-med[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Create Segmentation Dataset

In [3]:
class SegDataset(Dataset):
    def __init__(self, data_root, transform, target_transform, train=True):
        self.data_root = data_root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.gt_files_path = []
        # find all patient directories
        patient_directories = glob.glob(os.path.join(self.data_root, 'patient*'))
        # find all files with the suffix _gt.npy
        for patient_directory in patient_directories:
            per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.npy'))
            for path in per_patient_file_path:
                self.gt_files_path.append(path)
        
    def __len__(self):
        return len(self.gt_files_path)
    
    def __getitem__(self, index):
        gt_image_path = self.gt_files_path[index]
        image_path = gt_image_path[:-7] + ".npy"
        image = np.load(image_path)
        gt_image = np.load(gt_image_path)
        image = torch.tensor(image[None,:,:]).float()
        gt_image = torch.tensor(gt_image).long()
            
        # Convert the ground truth label to one-hot encoding
        one_hot_label = torch.nn.functional.one_hot(gt_image, num_classes=4)

        # Transpose the tensor to have dimensions (C, H, W)
        one_hot_label = one_hot_label.permute(2, 0, 1)

        # Remove the background channel (dimension 0)
        one_hot_label = one_hot_label[1:, :, :]
        
        # Use seed to make sure image and target has same transform
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        image = self.transform(image)
        random.seed(seed)
        torch.manual_seed(seed)
        target = self.target_transform(one_hot_label)
        
        return image, target

In [4]:
dataset = SegDataset(data_root = './database/training', 
                     transform = transform, 
                     target_transform = target_transform)
# Split into train set and validation set
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

### Init Model

In [5]:
model = smp.Unet(
    encoder_name="resnet50",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels
    classes=3,                      # model output channels (number of classes)
)

preprocess_input = get_preprocessing_fn('resnet50', pretrained='imagenet')

model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

### Train

In [6]:
def vis_img(img, mask):
    # img: (B, 256, 64, 64), {: (B, 1, 256, 256)
    print(f"{img.shape=}, {mask.shape=}")
    img = np.squeeze(img)
    mask = np.squeeze(mask)
    plt.figure()
    plt.imshow(img, 'gray')
    overlay_mask_0 = np.ma.masked_where(mask[0] == 0, img)
    overlay_mask_1 = np.ma.masked_where(mask[1] == 0, img)
    overlay_mask_2 = np.ma.masked_where(mask[2] == 0, img)
    plt.imshow(overlay_mask_0, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')
    plt.imshow(overlay_mask_1, 'Reds', alpha = 0.7, clim=[0,1], interpolation='nearest')
    plt.imshow(overlay_mask_2, 'Purples', alpha = 0.7, clim=[0,1], interpolation='nearest')
    buffer = io.BytesIO()
    plt.savefig(buffer, format='jpeg')
    buffer.seek(0)

    # Convert the in-memory buffer to a NumPy array
    image_array = np.array(Image.open(buffer))
    return image_array

In [None]:
# train
best_loss = 1e10
for epoch in range(num_epochs):
    # Train
    model.train()
    epoch_loss = 0
    for step, (img, gt) in enumerate(tqdm(train_loader)):
        img = img.to(device)
        mask = model(img)
        loss = seg_loss(mask, gt.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    epoch_loss /= step
    print(f'EPOCH: {epoch + 1}, Train Loss: {epoch_loss}')
    
    # Validation
    model.eval()
    val_loss = 0
    last_image_batch = None
    last_gt_mask_batch = None
    last_pr_mask_batch = None
    with torch.no_grad():
        for step, (img, gt) in enumerate(tqdm(val_loader)):
            img = img.to(device)
            mask = model(img)
            loss = seg_loss(mask, gt.to(device))
            val_loss += loss.item()
            last_image_batch = img
            last_gt_mask_batch = gt
            last_pr_mask_batch = mask
            
    val_loss /= step
    print(f'EPOCH: {epoch + 1}, Validation Loss: {val_loss}')
    
    last_image = last_image_batch.detach().cpu().numpy()[0][0]
    last_gt = last_gt_mask_batch.detach().cpu().numpy()[0]
    last_pr = last_pr_mask_batch.detach().cpu().numpy()[0]
    
    threshold = 0.95  # Set your desired threshold value
    binary_mask = (last_pr > threshold)
    
    ground_truth = vis_img(last_image, last_gt)
    predicted = vis_img(last_image, binary_mask)
    # Log
    wandb.log({"loss": epoch_loss,
               "val_loss": val_loss,
               "ground_truth": wandb.Image(ground_truth),
               "prediction": wandb.Image(predicted)})
    
    # save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), './model/unet-test/model_best.pth')

100%|██████████| 191/191 [00:41<00:00,  4.59it/s]


EPOCH: 1, Train Loss: 1.1762848788186124


100%|██████████| 48/48 [00:06<00:00,  7.75it/s]


EPOCH: 1, Validation Loss: 0.837933379284879
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:42<00:00,  4.54it/s]


EPOCH: 2, Train Loss: 0.7137712481774782


100%|██████████| 48/48 [00:08<00:00,  5.74it/s]


EPOCH: 2, Validation Loss: 0.6506437633899932
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:42<00:00,  4.51it/s]


EPOCH: 3, Train Loss: 0.586770035404908


100%|██████████| 48/48 [00:06<00:00,  7.43it/s]


EPOCH: 3, Validation Loss: 0.5767067639117546
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:41<00:00,  4.56it/s]


EPOCH: 4, Train Loss: 0.5431366805967531


100%|██████████| 48/48 [00:07<00:00,  6.58it/s]


EPOCH: 4, Validation Loss: 0.5536772925803002
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:43<00:00,  4.42it/s]


EPOCH: 5, Train Loss: 0.5272362749827536


100%|██████████| 48/48 [00:06<00:00,  7.70it/s]


EPOCH: 5, Validation Loss: 0.5347043351924166
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:39<00:00,  4.83it/s]


EPOCH: 6, Train Loss: 0.5144824864048707


100%|██████████| 48/48 [00:06<00:00,  6.91it/s]


EPOCH: 6, Validation Loss: 0.5370964591807508
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:42<00:00,  4.50it/s]


EPOCH: 7, Train Loss: 0.495151686197833


100%|██████████| 48/48 [00:06<00:00,  6.97it/s]


EPOCH: 7, Validation Loss: 0.5042009486797008
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:41<00:00,  4.60it/s]


EPOCH: 8, Train Loss: 0.4948288541091116


100%|██████████| 48/48 [00:08<00:00,  5.74it/s]


EPOCH: 8, Validation Loss: 0.5173850465328136
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:42<00:00,  4.50it/s]


EPOCH: 9, Train Loss: 0.48874271072839437


100%|██████████| 48/48 [00:07<00:00,  6.42it/s]


EPOCH: 9, Validation Loss: 0.5027152471085812
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:43<00:00,  4.42it/s]


EPOCH: 10, Train Loss: 0.4872175522540745


100%|██████████| 48/48 [00:05<00:00,  8.89it/s]


EPOCH: 10, Validation Loss: 0.5199826543635511
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:41<00:00,  4.61it/s]


EPOCH: 11, Train Loss: 0.4823729047649785


100%|██████████| 48/48 [00:06<00:00,  7.03it/s]
  plt.figure()


EPOCH: 11, Validation Loss: 0.500339666579632
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


100%|██████████| 191/191 [00:42<00:00,  4.49it/s]


EPOCH: 12, Train Loss: 0.48246732479647586


100%|██████████| 48/48 [00:06<00:00,  7.24it/s]


EPOCH: 12, Validation Loss: 0.5030604239473951
img.shape=(224, 224), mask.shape=(3, 224, 224)
img.shape=(224, 224), mask.shape=(3, 224, 224)


 17%|█▋        | 33/191 [00:07<00:36,  4.29it/s]