In [1]:
import torch
import torch.utils.data as dataset_torch
from abc import ABC
import os
import SimpleITK as sitk
import time
import numpy as np
from lungmask import utils

from torchvision import transforms

from lungmask import mask
import SimpleITK as sitk

In [2]:
import os

data_extensions = [
    '.nii.gz',
]
roi_extensions = [
    '.roi.nii.gz',
]


def is_image_file(filename, mode='data'):
    if mode == 'roi':
        return any(filename.endswith(extension) for extension in roi_extensions)
    elif mode == 'data':
        if not 'roi' in filename and \
                any(filename.endswith(extension) for extension in data_extensions):
            return True
        else:
            return False
    else:
        raise ValueError('Undefined mode %s while reading data' % mode)


def make_dataset(dir, max_dataset_size=float("inf"), mode='data'):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        # return of os.walk: root dir, folders, files
        for fname in fnames:
            if is_image_file(fname, mode):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]




In [3]:
from torchsummary import summary

model = mask.get_model('unet','LTRCLobes')
model.to('cuda')

summary(model, (1,256,256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
       BatchNorm2d-3         [-1, 64, 256, 256]             128
            Conv2d-4         [-1, 64, 256, 256]          36,928
              ReLU-5         [-1, 64, 256, 256]               0
       BatchNorm2d-6         [-1, 64, 256, 256]             128
     UNetConvBlock-7         [-1, 64, 256, 256]               0
            Conv2d-8        [-1, 128, 128, 128]          73,856
              ReLU-9        [-1, 128, 128, 128]               0
      BatchNorm2d-10        [-1, 128, 128, 128]             256
           Conv2d-11        [-1, 128, 128, 128]         147,584
             ReLU-12        [-1, 128, 128, 128]               0
      BatchNorm2d-13        [-1, 128, 128, 128]             256
    UNetConvBlock-14        [-1, 128, 1

In [9]:
class BaseDataset(dataset_torch.Dataset, ABC):
    def __init__(self, dir):
        """
        dir: File directory.
        """
        self.dir = dir
        self.img_list = sorted(make_dataset(dir, mode='data'))
        self.roi_list = sorted(make_dataset(dir, mode='roi'))
        
        self.A_size = len(self.img_list)  # get the size of dataset
        self.B_size = len(self.roi_list)  # get the size of roi-set
        
        assert(self.A_size == self.B_size)
        if self.A_size == 0:
            raise(RuntimeError("Found 0 datafiles in: " + dir))

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int)      -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor)       -- an image in the input domain
            B (tensor)       -- its corresponding image in the target domain
            A_paths (str)    -- image paths
            B_paths (str)    -- image paths
        """
        A_path = self.img_list[index] #% self.A_size]  # make sure index is within then range
        B_path = self.roi_list[index] #% self.B_size]  # make sure index is within then range

        # apply image transformation
        A = sitk.ReadImage(A_path)  # data
        B = sitk.ReadImage(B_path)  # roi

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different number of images,
        we take a maximum of
        """
        return max(self.A_size, self.B_size)

class LungLabelsDS_inf(dataset_torch.Dataset):
    def __init__(self, ds, lb):
        self.dataset = ds
        self.label = lb

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx, None, :, :], self.label[idx, None, :, :]
    
def preprocess(img_data, roi_data):
    img_raw = sitk.GetArrayFromImage(img_data)
    roi_raw = sitk.GetArrayFromImage(roi_data)
    directions = np.asarray(img_data.GetDirection())
    if len(directions) == 9:
        img_raw = np.flip(img_raw, np.where(directions[[0, 4, 8]][::-1] < 0)[0])
        roi_raw = np.flip(roi_raw, np.where(directions[[0, 4, 8]][::-1] < 0)[0])

    tvolslices, labelslices, xnew_box = utils.preprocess(img_raw, label=roi_raw, resolution=[256, 256])
    
#     print(np.shape(xnew_box))
    
    tvolslices = np.divide((tvolslices + 1024), 1624)
    torch_ds_val = LungLabelsDS_inf(tvolslices, xnew_box)
    
    return torch.utils.data.DataLoader(torch_ds_val, batch_size=1, shuffle=True, num_workers=1, pin_memory=False)

In [10]:
import torch
import torch.nn as nn
import torch.functional as f
import numpy as np


class DICELossMultiClass(nn.Module):

    def __init__(self):
        super(DICELossMultiClass, self).__init__()

    def forward(self, output, mask_original):
        num_classes = output.size(1)
        
        dice_eso = 0
        for i in range(num_classes):
            probs = torch.squeeze(output[:, i, :, :], 1)
#             mask = torch.squeeze(mask[:, i, :, :], 1)
            mask = torch.where(mask_original == i, torch.ones_like(Y).to(device), torch.zeros_like(Y).to(device))

            num = probs * mask
            num = torch.sum(num, 2)
            num = torch.sum(num, 1)

            # print( num )

            den1 = probs * probs
            # print(den1.size())
            den1 = torch.sum(den1, 2)
            den1 = torch.sum(den1, 1)

            # print(den1.size())

            den2 = mask * mask
            # print(den2.size())
            den2 = torch.sum(den2, 2)
            den2 = torch.sum(den2, 1)

            # print(den2.size())
            eps = 0.0000001
            dice = 2 * ((num + eps) / (den1 + den2 + eps))
            # dice_eso = dice[:, 1:]
            dice_eso += dice

        loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
        return loss

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = mask.get_model('unet','LTRCLobes')
model.to(device)

criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-1)

dir = r'/home/avitech-pc4/Nam/data'
dataset = BaseDataset(dir)

total_iters = 0
torch.cuda.empty_cache()


for epoch in range(10):

    for i, data in enumerate(dataset):

        dataloader_val = preprocess(data['A'], data['B'])
        
        Loss = 0
        epoch_loss = 0

        for i, (X, Y) in enumerate(dataloader_val):
            pass
#             if (i % 2 == 0) and (i // 80 == 0):
                
#                 X = X.float().to(device)
#                 label = torch.tensor(Y[:,0,:,:],dtype=torch.long).to(device)

#                 prediction = model(X)
                
#                 Loss += criterion(prediction, label)

#         optimizer.zero_grad()
#         Loss.backward()
#         epoch_loss += Loss.item()
#         print(epoch_loss)
#         optimizer.step()


            
            


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return self.dataset[idx, None, :, :].astype(np.float), self.label[idx, None, :, :].astype(np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return self.dataset[idx, None, :, :].astype(np.float), self.label[idx, None, :, :].astype(np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return self.dataset[idx, None, :, :].astype(np.float), self.label[idx, None, :, :].astype(np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return self.dataset[idx, None, :, :].astype(np.float), self.label[idx, None, :, :].astype(np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.

In [40]:
Y.shape

torch.Size([1, 1, 256, 256])

In [41]:
prediction.shape

torch.Size([1, 6, 256, 256])

In [56]:
loss = nn.NLLLoss()

label = torch.tensor(Y[:,0,:,:],dtype=torch.long)

loss(prediction, label.to(device))

tensor(0.1807, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [53]:
label = torch.tensor(Y[:,0,:,:],dtype=torch.long)

In [54]:
label.shape

torch.Size([1, 256, 256])