In [14]:
# print("Importing ...")
import torch
import torch.nn as nn
import torch.optim as optim
from   torch.utils.data import DataLoader
import torchvision.transforms as transforms # Using TorchIO may help in 3D augmentation *
import nibabel as nib
import numpy as np
import random

# Define your model architecture here
# print("Defining Classes ...")

class DoubleConv(nn.Module):
    """Two stacked (Conv3d -> BatchNorm3d -> ReLU) stages.

    Each convolution uses a 3x3x3 kernel with padding 1, so the spatial size
    of the input is kept; only the channel count changes.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second stage keeps
        # out_channels; the layers end up at Sequential indices 0..5.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm3d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder step: 2x max-pooling followed by a DoubleConv.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` with a 2x2x2 window, then run the double convolution."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, convolve.

    Args:
        in_channels: Channel count fed to the DoubleConv (and, when
            ``bilinear`` is False, to the transposed convolution).
        out_channels: Channel count produced by the DoubleConv.
        bilinear: If True, upsample with trilinear interpolation (factor 2);
            otherwise use a stride-2 ConvTranspose3d that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            upsampler = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            upsampler = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.up = upsampler
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse both.

        Args:
            x1: Feature map to upsample.
            x2: Skip-connection feature map; it comes first in the channel
                concatenation.

        Returns:
            The DoubleConv output of the concatenated tensor.
        """
        x1 = self.up(x1)

        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so walk the three trailing dimensions in reverse order.
        pad_spec = []
        for dim in (4, 3, 2):
            gap = x2.size()[dim] - x1.size()[dim]
            before = gap // 2
            pad_spec.extend((before, gap - before))
        x1 = nn.functional.pad(x1, tuple(pad_spec))

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1x1 convolution that maps features to the output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels to produce.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Return the point-wise convolution of ``x``."""
        projected = self.conv(x)
        return projected

class UNet3D(nn.Module):
    """3D U-Net with four downsampling and four upsampling steps.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 512 channels; each decoder
    step fuses the upsampled tensor with the encoder output of matching depth.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of channels of the returned tensor.

    TODO: add dropout.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Encoder
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: each `in_channels` is upsampled channels + skip channels.
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        # Output head
        self.outc = OutConv(64, out_channels)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        # Collect every encoder output; the last one is the bottleneck.
        encoder_outputs = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            encoder_outputs.append(down(encoder_outputs[-1]))

        # Pop the bottleneck, then consume the skips from deepest to shallowest.
        x = encoder_outputs.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, encoder_outputs.pop())

        return self.outc(x)

# Define a custom transform class for applying the same random crop
class RandomCrop3D:
    """Apply one and the same random 3D crop to an (inputs, targets) pair.

    Both tensors are indexed as ``[:, :, d0, d1, d2]``, i.e. the crop acts on
    dimensions 2, 3 and 4 and leaves the first two dimensions untouched.

    Args:
        output_size: Sequence of three ints, the crop extent along
            dimensions 2, 3 and 4.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Crop ``sample`` at a random position.

        Args:
            sample: ``(inputs, targets)`` pair; the crop position is drawn
                from the shape of ``inputs`` and reused for ``targets``.

        Returns:
            ``(inputs, targets)`` cropped to ``output_size``.

        Raises:
            ValueError: If ``inputs`` is smaller than ``output_size`` along
                any cropped dimension.
        """
        inputs, targets = sample

        # Size of the three cropped dimensions.
        spatial_shape = inputs.shape[2:]

        # Fail with a readable message instead of the opaque
        # "empty range for randrange()" raised by random.randint.
        for axis in range(3):
            if spatial_shape[axis] < self.output_size[axis]:
                raise ValueError(
                    f"RandomCrop3D: crop size {tuple(self.output_size)} exceeds "
                    f"input spatial size {tuple(spatial_shape)}"
                )

        # Draw the start index of the crop along each dimension.
        starts = [
            random.randint(0, spatial_shape[axis] - self.output_size[axis])
            for axis in range(3)
        ]

        # One shared index tuple guarantees inputs and targets stay aligned.
        window = (slice(None), slice(None)) + tuple(
            slice(start, start + extent)
            for start, extent in zip(starts, self.output_size)
        )

        return inputs[window], targets[window]

# Define the output size for random cropping
output_size = (128, 128, 128)  # crop extent along the three cropped dimensions

# Transform pipeline applied to (inputs, targets) pairs. Only the paired random
# crop is active.
# NOTE(review): torchvision's RandomVerticalFlip / RandomHorizontalFlip were
# previously listed here but disabled — presumably they need a paired 3D
# equivalent before they can be used on (inputs, targets) tuples; confirm.
paired_random_crop = RandomCrop3D(output_size)
transform = transforms.Compose([paired_random_crop])
del paired_random_crop  # keep the notebook namespace identical to before


# Define your dataset class for loading CT images and masks

class CTImageDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) volume pairs read from disk with nibabel.

    Args:
        image_paths: Sequence of paths to the image volumes.
        mask_paths: Sequence of paths to the label volumes; entry ``i``
            belongs to ``image_paths[i]``.
    """

    def __init__(self, image_paths, mask_paths):
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    @staticmethod
    def _read_volume(path):
        """Load ``path`` with nibabel and return it with a leading size-1 axis."""
        voxels = nib.load(path).get_fdata()
        return torch.from_numpy(voxels).unsqueeze(0)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``.

        The image is converted to float32 and the mask to int64; both carry
        one extra leading axis (a single channel) in front of the volume axes.
        """
        image = self._read_volume(self.image_paths[index]).float()
        mask = self._read_volume(self.mask_paths[index]).long()
        return image, mask

In [15]:
# Define your training function

def train(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Every batch is passed through the module-level ``transform`` pipeline
    before it is moved to ``device``, so only the transformed sub-volume is
    transferred instead of the full volume.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(outputs, masks)`` after the
            size-1 dimension 1 of ``masks`` has been squeezed out.
        optimizer: Optimizer holding the parameters of ``model``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The sum of the per-batch losses divided by ``len(train_loader)``.

    Raises:
        ZeroDivisionError: If ``train_loader`` has length zero.
    """
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        # Transform on the incoming tensors first, then transfer only the
        # result: this keeps the untransformed volume out of device memory.
        images, masks = transform((images, masks))
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)

        # Drop the size-1 dimension 1 of the masks before computing the loss.
        loss = criterion(outputs, torch.squeeze(masks, dim=1))

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

In [16]:
# Set your training parameters
# print("Setting Parameters & Instantiating ...")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10
batch_size = 1  # 4 was noted as an alternative value
learning_rate = 0.0001  # 0.001 was noted as an alternative value

# Create the model instance: one input channel, three output channels
# (used as per-class scores by CrossEntropyLoss below).

model = UNet3D(in_channels=1, out_channels=3)
model = model.to(device)

# Create the dataset and data loader instances.
# Paths use forward slashes: the former "Data\S..." literals relied on the
# invalid escape sequence "\S" and only resolved on Windows.

image_paths_train = ["Data/SPIROMCS-Case36-Vx3.nii.gz"        , "Data/SPIROMCS-Case43-Vx3.nii.gz"]
mask_paths_train  = ["Data/SPIROMCS-Case36-012Labelmap.nii.gz", "Data/SPIROMCS-Case43-012Labelmap.nii.gz"]
train_dataset = CTImageDataset(image_paths_train, mask_paths_train)  # cases 36 and 43
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)  # shuffle=True was noted as an alternative

# Define the loss function and optimizer.

criterion = nn.CrossEntropyLoss()  # NOTE(review): ignore_index was flagged as an option to consider
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Count the number of trainable parameters
# Add up the element counts of every parameter that receives gradients.
num_parameters = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Number of trainable parameters: {num_parameters}")

In [None]:
# Start the training loop
# Announce the start of training.
print("Start Training ...")

# Call train() once per epoch and report the loss value it returns.
for epoch in range(epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}")

# Save the trained model: only the state_dict is written, to "model.pth"
# relative to the current working directory, once all epochs have finished.

torch.save(model.state_dict(), "model.pth")

In [None]:
# Load the trained model
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load("model.pth", map_location=device))
model.eval()

# Load and preprocess the input image.
# Forward slashes avoid the invalid "\S" escape of the former Windows-style
# literals and resolve on every platform.
image = nib.load("Data/SPIROMCS-Case36-Vx3.nii.gz").get_fdata()
mask  = nib.load("Data/SPIROMCS-Case36-012Labelmap.nii.gz").get_fdata()

image = torch.from_numpy(image).unsqueeze(0).float()  # add the single channel dimension
mask  = torch.from_numpy(mask ).unsqueeze(0).long()   # add the single channel dimension

input_tensor = image.unsqueeze(0)  # add batch dimension
mask_tensor  = mask .unsqueeze(0)  # add batch dimension

# Apply the shared transform pipeline BEFORE the device transfer so that only
# the transformed sub-volume is copied to the device. Because the pipeline is
# a random crop, inference runs on one randomly positioned sub-volume.
input_tensor, mask_tensor = transform((input_tensor, mask_tensor))
input_tensor = input_tensor.to(device)
mask_tensor  = mask_tensor.to(device)

# Perform inference without tracking gradients.
with torch.no_grad():
    output_tensor = model(input_tensor)

# Post-process the output tensor: pick the highest-scoring channel per voxel
# (class labels, matching the CrossEntropyLoss used for training).
output_tensor = torch.argmax(output_tensor, dim=1)

# Convert the output tensor to a numpy array without the batch dimension.
output_array = output_tensor.squeeze(0).cpu().numpy()

In [None]:
# import itkwidgets
# # import itk
# # itk_np = itk.GetImageFromArray(output_array)
# # itk.imwrite(output_array, output_file_name)
# # view(output_array)
# # # view ?
# your_array = np.random.random((64, 64, 64))
# # Visualize the 3D numpy array
# itkwidgets.view(your_array)


# import numpy as np
# import itk

# # Generate or load your 3D numpy array
# your_array = np.random.random((64, 64, 64))  # Replace with your own 3D numpy array

# # Convert numpy array to ITK image
# image = itk.GetImageFromArray(your_array)

# # Visualize the ITK image
# itkwidgets.view(image)


# import itkwidgets
# import numpy as np
# import itk

# # Generate or load your 3D numpy array
# # your_array = np.random.random((64, 64, 64))  # Replace with your own 3D numpy array

# # Ensure the array has the correct shape and datatype
# your_array = np.asarray(output_array, dtype=np.float32)

# # Create an ITK image from the numpy array
# # image = itk.image_view_from_array(your_array)
# image = itk.GetImageFromArray(your_array)

# # Visualize the ITK image
# itkwidgets.view(image)

In [None]:
# import ipyvolume as ipv
# import numpy as np

# # Generate or load your 3D numpy array
# your_array = np.random.random((64, 64, 64))  # Replace with your own 3D numpy array

# # Visualize the 3D numpy array
# ipv.quickvolshow(your_array)

# # Display the visualization
# # ipv.show()


# import numpy as np
# import ipyvolume as ipv
# V = np.zeros((128,128,128)) # our 3d array
# # outer box
# V[30:-30,30:-30,30:-30] = 0.75
# V[35:-35,35:-35,35:-35] = 0.0
# # inner box
# V[50:-50,50:-50,50:-50] = 0.25
# V[55:-55,55:-55,55:-55] = 0.0
# ipv.quickvolshow(V, level=[0.25, 0.75], opacity=0.03, level_width=0.1, data_min=0, data_max=1)
# ipv.show()

In [None]:
# class ImageSliceViewer3D & ipywidgets
import ipywidgets as ipyw
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

class ImageSliceViewer3D:
    """
    Interactive slice browser for volumetric images in Jupyter or IPython
    notebooks.

    Creating an instance displays a radio-button selector for the slice
    plane; every selection displays a slider choosing the slice to draw.

    Arguments:
    volume = 3D input image
    figsize = default (8,8), size of the matplotlib figure
    cmap = default 'plasma', name of the matplotlib colormap. More colormaps
    are listed at:
    https://matplotlib.org/users/colormaps.html

    """

    def __init__(self, volume, figsize=(8,8), cmap='plasma'):
        self.volume = volume
        self.figsize = figsize
        self.cmap = cmap
        # Colour limits come from the whole volume, so every slice is drawn
        # on the same scale.
        self.v = [np.min(volume), np.max(volume)]

        # Slice-plane selector; interact() calls view_selection on each change.
        plane_picker = ipyw.RadioButtons(
            options=['x-y','y-z', 'z-x'], value='x-y',
            description='Slice plane selection:', disabled=False,
            style={'description_width': 'initial'})
        ipyw.interact(self.view_selection, view=plane_picker)

    def view_selection(self, view):
        """Reorient the volume for ``view`` and show the slice slider."""
        # Axis permutation that moves the slicing axis to the last position.
        axes_for_view = {"x-y": [0,1,2], "y-z": [1,2,0], "z-x": [2,0,1]}
        self.vol = np.transpose(self.volume, axes_for_view[view])
        last_slice = self.vol.shape[2] - 1

        # Slider over the valid slice indices; interact() calls plot_slice.
        slice_slider = ipyw.IntSlider(
            min=0, max=last_slice, step=1, continuous_update=False,
            description='Image Slice:')
        ipyw.interact(self.plot_slice, z=slice_slider)

    def plot_slice(self, z):
        """Draw slice ``z`` of the currently oriented volume."""
        self.fig = plt.figure(figsize=self.figsize)
        lowest, highest = self.v
        colormap = plt.get_cmap(self.cmap)
        plt.imshow(self.vol[:,:,z], cmap=colormap, vmin=lowest, vmax=highest)
        
# Demo input: a 256 x 256 x 96 array of random numbers.
x = np.random.rand(256,256,96)

# Constructing the viewer is enough to display its interactive widgets.
ImageSliceViewer3D(x)

# The static rendering of Github does not display the image widget, and the 
# ability to interact with the image widget did not work with nbviewer when 
# last (26/05/2018) checked. 

In [None]:
# Ideas/Notes:
# Normalization ****
# Dropout
# Several threads and gpus

# nn.CrossEntropyLoss(): label_smoothing=0.0?!!

# np.prod(input_tensor.size())/8*32 =
# print(input_tensor.storage().nbytes())

# import sys
# !{sys.executable} -m pip install scipy

# python -c "import torch; print(torch.cuda.is_available())"