In [1]:
import torch
import os
import nibabel as nib
from scipy.ndimage import zoom

In [2]:
class CustomDataset(torch.utils.data.Dataset):

    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = os.listdir(image_dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        image = nib.load(img_path)
        mask = nib.load(mask_path).get_fdata()
        z_dim = image.header.get_zooms()[2]
        image = image.get_fdata()
        image = zoom(image, (0.5, 0.5, z_dim / 2))
        mask = zoom(mask, (0.5, 0.5, z_dim / 2))

        return image, mask



In [3]:
from torch.utils.data import DataLoader

train_data = CustomDataset("WORD-V0.1.0/imagesTr", "WORD-V0.1.0/labelsTr")
test_data = CustomDataset("WORD-V0.1.0/imagesVal", "WORD-V0.1.0/labelsVal")


train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=2, shuffle=True)

In [4]:
device = "cuda:3"

In [5]:
from torch import nn
class customUNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.C1 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(16),
            nn.ReLU()
        )

        self.C1Pool = nn.MaxPool3d(2)

        self. C2 = nn.Sequential(
            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(32),
            nn.ReLU()
        )
        
        self.C2Pool = nn.MaxPool3d(2)

        self.C3 = nn.Sequential(
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU()
        )
        
        self.C3Pool = nn.MaxPool3d(2)

        self.C4 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )
        
        self.C4Pool = nn.MaxPool3d(2)

        self.Mid = nn.Sequential(
            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.Conv3d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(256),
            nn.ReLU()
        )

        self.U4 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )

        self.S4 = nn.Sequential(
            nn.Conv3d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )

        self.U3 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU()
        )

        self.S3 = nn.Sequential(
            nn.Conv3d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU()
        )

        self.U2 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU()
        )

        self.S2 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm3d(32),
                nn.ReLU(),
                nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm3d(32),
                nn.ReLU()
        )

        self.U1 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU()
        )

        self.S1 = nn.Sequential(
            nn.Conv3d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(16),
            nn.ReLU()
        )

        self.output = nn.Sequential(
            nn.Conv3d(in_channels=16, out_channels=17, kernel_size=1, padding=0),
            nn.Sigmoid()
        )
        
    def forward(self, input):
        C1 = self.C1(input)
        P1 = self.C1Pool(C1)
        C2 = self.C2(P1)
        P2 = self.C2Pool(C2)
        C3 = self.C3(P2)
        P3 = self.C3Pool(C3)
        C4 = self.C4(P3)
        P4 = self.C4Pool(C4)

        middle = self.Mid(P4)

        U4 = self.U4(middle)
        UC4 = torch.cat([C4, U4], dim=1)
        S4 = self.S4(UC4)
        U3 = self.U3(S4)
        UC3 = torch.cat([C3, U3], dim=1)
        S3 = self.S3(UC3)
        U2 = self.U2(S3)
        UC2 = torch.cat([C2, U2], dim=1)
        S2 = self.S2(UC2)
        U1 = self.U1(S2)
        UC1 = torch.cat([C1, U1], dim=1)
        S1 = S1(UC1)
        output = output(S1)
        
        return output

In [7]:
batch = next(iter(test_dataloader))
(x, y) = batch
x = x.to(device)
print(x.shape)




torch.Size([2, 1, 512, 512, 512])


In [8]:
model.eval()

customUNet(
  (C1): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (C1Pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (C2): Sequential(
    (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (C2Pool): MaxPool3d(kernel_size=2, stride=2, pad

In [9]:

output = model.forward(x)

OutOfMemoryError: CUDA out of memory. Tried to allocate 13.50 GiB. GPU 3 has a total capacty of 23.69 GiB of which 6.40 GiB is free. Including non-PyTorch memory, this process has 17.29 GiB memory in use. Of the allocated memory 17.02 GiB is allocated by PyTorch, and 21.26 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
output.shape

torch.Size([2, 16, 256, 256, 172])