In [1]:
import io
from PIL import Image
import os
import wandb
import glob
import torch
import monai
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn


### Set-up transforms

In [2]:
torch.manual_seed(1024)
np.random.seed(1024)
device = torch.device(
    "cuda:1"
    if torch.cuda.is_available()
    else
    "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Data Augmentation
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize([0.45, ], [0.35, ]),
    transforms.Resize([256, 256]),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop([256, 256]),
    
])

target_transform = transforms.Compose([
    transforms.Resize([256, 256], interpolation=transforms.InterpolationMode.NEAREST),
    transforms.RandomRotation(45, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.RandomResizedCrop([256, 256], interpolation=transforms.InterpolationMode.NEAREST)
])

Using cuda:1 device


### Init hyperparameters

In [3]:
lr = 1e-4
batch_size = 8
weight_decay = 1e-5
num_epochs = 25

run = wandb.init(
    project="Training Models",
    name='Linknet - Resnet34',
    # Track hyperparameters and run metadata
    config={
        "learning rate": lr,
        "batch_size": batch_size,
        "weight decay": weight_decay,
        "Epoches number": num_epochs,
        "transform": str(transform),
        "target transform": str(target_transform)
    })

2023-05-30 21:35:08,235 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mr-j-poelarends[0m ([33mdeeplearning-med[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Create Segmentation Dataset

In [4]:
class SegDataset(Dataset):
    def __init__(self, data_root, transform, target_transform, train=True):
        self.data_root = data_root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.gt_files_path = []
        
        # find all patient directories
        patient_directories = glob.glob(os.path.join(self.data_root, 'patient*'))
        # find all files with the suffix _gt.npy
        train_size = int(len(patient_directories)*0.8)
        
        if self.train:
            for patient_directory in patient_directories[0:train_size]:
                per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.npy'))
                for path in per_patient_file_path:
                    self.gt_files_path.append(path)
        else:
            for patient_directory in patient_directories[train_size:]:
                per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.npy'))
                for path in per_patient_file_path:
                    self.gt_files_path.append(path)
                    
    def __len__(self):
        return len(self.gt_files_path)
    
    def __getitem__(self, index):
        gt_image_path = self.gt_files_path[index]
        image_path = gt_image_path[:-7] + ".npy"
        image = np.load(image_path)
        gt_image = np.load(gt_image_path)
        image = torch.tensor(image[None,:,:]).float()
        gt_image = torch.tensor(gt_image).long()
        # Convert the ground truth label to one-hot encoding
        one_hot_label = torch.nn.functional.one_hot(gt_image, num_classes=4)

        # Transpose the tensor to have dimensions (C, H, W)
        one_hot_label = one_hot_label.permute(2, 0, 1)

        # Remove the background channel (dimension 0)
        one_hot_label = one_hot_label[1:, :, :]
        
        # Use seed to make sure image and target has same transform
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        image = self.transform(image)

        random.seed(seed)
        torch.manual_seed(seed)
        target = self.target_transform(one_hot_label)
        return image, target

In [5]:
train_dataset = SegDataset(data_root = './database/training', 
                     transform = transform, 
                     target_transform = target_transform,
                     train = True)

val_dataset = SegDataset(data_root = './database/training', 
                     transform = transform, 
                     target_transform = target_transform,
                     train = False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

### Init Model

In [6]:
model = smp.Linknet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels
    classes=3,                      # model output channels (number of classes)
)

#preprocess_input = get_preprocessing_fn('resnet50', pretrained='imagenet')

model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

### Train

In [7]:
def vis_img(img, mask):
    # img: (B, 256, 64, 64), {: (B, 1, 256, 256)
    #print(f"{img.shape=}, {mask.shape=}")
    img = np.squeeze(img)
    mask = np.squeeze(mask)
    plt.figure()  
    plt.imshow(img, 'gray')
    overlay_mask_1 = np.ma.masked_where(mask[0] == 0, img)
    overlay_mask_2 = np.ma.masked_where(mask[1] == 0, img)
    overlay_mask_3 = np.ma.masked_where(mask[2] == 0, img)
    plt.imshow(overlay_mask_1, 'Greens', alpha = 1, interpolation='nearest')
    plt.imshow(overlay_mask_2, 'Reds', alpha = 1, interpolation='nearest')
    plt.imshow(overlay_mask_3, 'Purples', alpha = 1, interpolation='nearest')
    buffer = io.BytesIO()
    plt.savefig(buffer, format='jpeg')
    buffer.seek(0)

    # Convert the in-memory buffer to a NumPy array
    image_array = np.array(Image.open(buffer))
    
    plt.close()
    
    return image_array

In [8]:
# for step, (img, gt) in enumerate(tqdm(train_loader)):
#     vis_img(img[5][0].numpy(), gt[5])
#     break;

In [9]:
# train
best_loss = 1e10
val_freq = 2

for epoch in range(num_epochs):
    # Train
    model.train()
    epoch_loss = 0
    for step, (img, gt) in enumerate(tqdm(train_loader)):
        img = img.to(device)
        mask = model(img)
        loss = seg_loss(mask, gt.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    print(f'EPOCH: {epoch + 1}, Train Loss: {epoch_loss}')
    
    # Validation
    if epoch % val_freq == 0:
        model.eval()
        val_loss = 0
        last_image_batch = None
        last_gt_mask_batch = None
        last_pr_mask_batch = None
        with torch.no_grad():
            for step, (img, gt) in enumerate(tqdm(val_loader)):
                img = img.to(device)
                mask = model(img)
                loss = seg_loss(mask, gt.to(device))
                val_loss += loss.item()
                last_image_batch = img
                last_gt_mask_batch = gt
                last_pr_mask_batch = mask
            
    print(f'EPOCH: {epoch + 1}, Validation Loss: {val_loss}')
    
    last_image = last_image_batch.detach().cpu().numpy()[0][0]
    last_gt = last_gt_mask_batch.detach().cpu().numpy()[0]
    last_pr = last_pr_mask_batch.detach().cpu().numpy()[0]
    
    threshold = 0.95  # Set your desired threshold value
    binary_mask = (last_pr > threshold)
    
    ground_truth = vis_img(last_image, last_gt)
    predicted = vis_img(last_image, binary_mask)
    # Log
    wandb.log({"loss": epoch_loss,
               "val_loss": val_loss,
               "ground_truth": wandb.Image(ground_truth),
               "prediction": wandb.Image(predicted)})
    
    # save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), './model/linknet-resnet34/model_best.pth')
torch.save(model.state_dict(), './model/linknet-resnet34/model_final_epoch.pth')
        


100%|██████████| 192/192 [00:37<00:00,  5.16it/s]


EPOCH: 1, Train Loss: 185.93024796247482


100%|██████████| 46/46 [00:05<00:00,  8.35it/s]


EPOCH: 1, Validation Loss: 43.04719823598862


100%|██████████| 192/192 [00:35<00:00,  5.36it/s]


EPOCH: 2, Train Loss: 161.82809907197952
EPOCH: 2, Validation Loss: 43.04719823598862


100%|██████████| 192/192 [00:35<00:00,  5.36it/s]


EPOCH: 3, Train Loss: 135.51280772686005


100%|██████████| 46/46 [00:05<00:00,  8.90it/s]


EPOCH: 3, Validation Loss: 31.273395657539368


100%|██████████| 192/192 [00:35<00:00,  5.36it/s]


EPOCH: 4, Train Loss: 112.99926188588142
EPOCH: 4, Validation Loss: 31.273395657539368


100%|██████████| 192/192 [00:37<00:00,  5.06it/s]


EPOCH: 5, Train Loss: 96.6851849257946


100%|██████████| 46/46 [00:04<00:00,  9.39it/s]


EPOCH: 5, Validation Loss: 20.191386118531227


100%|██████████| 192/192 [00:36<00:00,  5.23it/s]


EPOCH: 6, Train Loss: 83.74024805426598
EPOCH: 6, Validation Loss: 20.191386118531227


100%|██████████| 192/192 [00:36<00:00,  5.30it/s]


EPOCH: 7, Train Loss: 77.46669572591782


100%|██████████| 46/46 [00:06<00:00,  6.66it/s]


EPOCH: 7, Validation Loss: 17.228460729122162


100%|██████████| 192/192 [00:39<00:00,  4.83it/s]


EPOCH: 8, Train Loss: 74.92693395912647
EPOCH: 8, Validation Loss: 17.228460729122162


100%|██████████| 192/192 [00:38<00:00,  4.97it/s]


EPOCH: 9, Train Loss: 70.60948587954044


100%|██████████| 46/46 [00:05<00:00,  8.50it/s]


EPOCH: 9, Validation Loss: 16.220476284623146


100%|██████████| 192/192 [00:38<00:00,  5.04it/s]


EPOCH: 10, Train Loss: 68.89738744497299
EPOCH: 10, Validation Loss: 16.220476284623146


100%|██████████| 192/192 [00:35<00:00,  5.36it/s]


EPOCH: 11, Train Loss: 64.48595505952835


100%|██████████| 46/46 [00:05<00:00,  8.47it/s]


EPOCH: 11, Validation Loss: 15.231720834970474


100%|██████████| 192/192 [00:36<00:00,  5.19it/s]


EPOCH: 12, Train Loss: 60.88318522274494
EPOCH: 12, Validation Loss: 15.231720834970474


100%|██████████| 192/192 [00:37<00:00,  5.06it/s]


EPOCH: 13, Train Loss: 59.29736603796482


100%|██████████| 46/46 [00:05<00:00,  7.98it/s]


EPOCH: 13, Validation Loss: 14.834540486335754


100%|██████████| 192/192 [00:38<00:00,  4.94it/s]


EPOCH: 14, Train Loss: 58.620212972164154
EPOCH: 14, Validation Loss: 14.834540486335754


100%|██████████| 192/192 [00:42<00:00,  4.55it/s]


EPOCH: 15, Train Loss: 60.054640263319016


100%|██████████| 46/46 [00:05<00:00,  7.75it/s]


EPOCH: 15, Validation Loss: 12.706520937383175


100%|██████████| 192/192 [00:42<00:00,  4.48it/s]


EPOCH: 16, Train Loss: 56.14017079770565
EPOCH: 16, Validation Loss: 12.706520937383175


100%|██████████| 192/192 [00:38<00:00,  4.93it/s]


EPOCH: 17, Train Loss: 55.69070006161928


100%|██████████| 46/46 [00:05<00:00,  8.54it/s]


EPOCH: 17, Validation Loss: 12.148647278547287


100%|██████████| 192/192 [00:43<00:00,  4.41it/s]


EPOCH: 18, Train Loss: 55.577513709664345
EPOCH: 18, Validation Loss: 12.148647278547287


100%|██████████| 192/192 [00:38<00:00,  4.94it/s]


EPOCH: 19, Train Loss: 52.68721070885658


100%|██████████| 46/46 [00:05<00:00,  8.11it/s]


EPOCH: 19, Validation Loss: 12.954071432352066


100%|██████████| 192/192 [00:39<00:00,  4.81it/s]


EPOCH: 20, Train Loss: 51.00518652796745
EPOCH: 20, Validation Loss: 12.954071432352066


100%|██████████| 192/192 [00:37<00:00,  5.06it/s]


EPOCH: 21, Train Loss: 52.13398556411266


100%|██████████| 46/46 [00:09<00:00,  4.97it/s]


EPOCH: 21, Validation Loss: 12.855387702584267


100%|██████████| 192/192 [00:41<00:00,  4.64it/s]


EPOCH: 22, Train Loss: 51.56878016144037
EPOCH: 22, Validation Loss: 12.855387702584267


100%|██████████| 192/192 [00:39<00:00,  4.90it/s]


EPOCH: 23, Train Loss: 53.94743222743273


100%|██████████| 46/46 [00:05<00:00,  8.79it/s]


EPOCH: 23, Validation Loss: 12.873438894748688


100%|██████████| 192/192 [00:41<00:00,  4.61it/s]


EPOCH: 24, Train Loss: 49.926273591816425
EPOCH: 24, Validation Loss: 12.873438894748688


100%|██████████| 192/192 [00:38<00:00,  5.02it/s]


EPOCH: 25, Train Loss: 51.95261249691248


100%|██████████| 46/46 [00:05<00:00,  8.54it/s]


EPOCH: 25, Validation Loss: 12.662543073296547
