In [1]:
!pip install SimpleITK
!pip install opencv-python
!pip install segmentation-models-pytorch
!pip install monai

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [7]:
import os
import cv2
import glob
import torch
import monai
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import SimpleITK as sitk

torch.manual_seed(1024)
np.random.seed(1024)
device = torch.device(
    "cuda:1"
    if torch.cuda.is_available()
    else
    "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda:1 device


In [8]:
def convert_to_npy(source_path):
    """
    Convert the nii file to npz, reduce RAM cost at the creation of dataset

    Args:
        source_path: file path to nii files

    Returns:
    """
    patient_directories = glob.glob(os.path.join(source_path, 'patient*'))

    for patient_dir in patient_directories:
        gt_image_paths = glob.glob(os.path.join(patient_dir, '*_gt.nii'))
        for gt_image_path in gt_image_paths:
            prefix = gt_image_path[:-7]
            image_path = prefix + ".nii"
            gt_image_array = sitk.GetArrayFromImage(sitk.ReadImage(gt_image_path))
            image_array = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
            for i, image in enumerate(image_array):
                np.save(prefix + '_slice' + str(i) + '.npy', image)
            for i, gt_image in enumerate(gt_image_array):
                np.save(prefix + '_slice' + str(i) + '_gt.npy', gt_image)

In [9]:
# convert_to_npy('./database/training')

In [10]:
class SegDataset(Dataset):
    def __init__(self, data_root, transform=None, train=True):
        self.data_root = data_root
        self.transform = transform
        self.train = train
        self.gt_files_path = []
        # find all patient directories
        patient_directories = glob.glob(os.path.join(self.data_root, 'patient*'))

        # find all files with the suffix _gt.npy
        for patient_directory in patient_directories:
            per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.npy'))
            for path in per_patient_file_path:
                self.gt_files_path.append(path)
        
    def __len__(self):
        return len(self.gt_files_path)
    
    def __getitem__(self, index):
        gt_image_path = self.gt_files_path[index]
        image_path = gt_image_path[:-7] + ".npy"
        image = np.load(image_path)
        gt_image = np.load(gt_image_path)
        image = torch.tensor(image[None,:,:]).float()
        gt_image = torch.tensor(gt_image).long()
            
        # Convert the ground truth label to one-hot encoding
        one_hot_label = torch.nn.functional.one_hot(gt_image, num_classes=4)

        # Transpose the tensor to have dimensions (C, H, W)
        one_hot_label = one_hot_label.permute(2, 0, 1)

        # Remove the background channel (dimension 0)
        one_hot_label = one_hot_label[1:, :, :]
        
        transform = transforms.Resize([224, 224])
        
        return transform(image), transform(one_hot_label)

In [11]:
# transform = transforms.Compose([
#     transforms.Resize([224, 224]),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

dataset = SegDataset(data_root = './database/training')
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
print(len(dataset))
for img, gt in dataloader:
    # img: (B, 256, 64, 64), {: (B, 1, 256, 256)
    print(f"{img.shape=}, {gt.shape=}")
    break

1902
img.shape=torch.Size([8, 1, 224, 224]), gt.shape=torch.Size([8, 3, 224, 224])


In [12]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

model = smp.Unet(
    encoder_name="resnet50",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

preprocess_input = get_preprocessing_fn('resnet50', pretrained='imagenet')

#model = torch.nn.DataParallel(model)
model.to(device)

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

In [13]:
# Set up the optimizer, hyperparameter tuning will improve performance here
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [14]:
#%% train
num_epochs = 20
losses = []
best_loss = 1e10
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0
    # train
    for step, (img, gt) in enumerate(tqdm(dataloader)):
        img = img.to(device)
        print(img.shape)
        mask = model(img)
        loss = seg_loss(mask, gt.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    epoch_loss /= step
    losses.append(epoch_loss)
    print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
    # save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), './model/unet-test/model_best.pth')

  0%|          | 0/238 [00:00<?, ?it/s]

torch.Size([8, 1, 224, 224])





RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED