In [1]:
!pip install pyvista[jupyter] vtk trame jupyter-server-proxy nibabel

Collecting vtk
  Using cached vtk-9.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting trame
  Using cached trame-3.10.2-py3-none-any.whl.metadata (8.2 kB)
Collecting jupyter-server-proxy
  Using cached jupyter_server_proxy-4.4.0-py3-none-any.whl.metadata (8.7 kB)
Collecting nibabel
  Using cached nibabel-5.3.2-py3-none-any.whl.metadata (9.1 kB)
Collecting pyvista[jupyter]
  Using cached pyvista-0.45.2-py3-none-any.whl.metadata (15 kB)
Collecting matplotlib>=3.0.1 (from pyvista[jupyter])
  Using cached matplotlib-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting numpy>=1.21.0 (from pyvista[jupyter])
  Using cached numpy-2.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting pillow (from pyvista[jupyter])
  Using cached pillow-11.2.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (8.9 kB)
Collecting pooch (from pyvista[jupyter])
  Using cached pooch-1.8.2-py3-none-a

In [2]:
import pyvista as pv
# Select Trame as the backend PyVista uses to render inside Jupyter.
pv.set_jupyter_backend("trame")

In [3]:
import nibabel as nib
import numpy as np
import pyvista as pv

# 1) Read the NIfTI volume with NiBabel.
#    Passing dtype to get_fdata() loads the data directly as float32 instead
#    of materialising a float64 array first and then copying it with astype().
nifti = nib.load('cropped_norm/cropped_norm/pat0_cropped_norm.nii.gz')
vol = nifti.get_fdata(dtype=np.float32)
# The grid below is strictly 3D; fail early with a readable message otherwise.
assert vol.ndim == 3, f"expected a 3D volume, got shape {vol.shape}"

# 2) Create the PyVista uniform grid that will hold the voxels.
grid = pv.ImageData()

# 3) The intensities are stored as point data (one sample per grid point),
#    so the grid dimensions equal the array shape. Cell data would instead
#    need dimensions of shape + 1.
grid.dimensions = list(vol.shape)
# Geometry is left in voxel-index space: the origin/voxel size stored in the
# NIfTI header are NOT applied here. Any overlay must use the same geometry.
grid.origin     = (0, 0, 0)
grid.spacing    = (1, 1, 1)

# 4) Attach the voxel intensities.
#    Flatten in Fortran order (first array axis varies fastest) so the
#    flattened values line up with the grid's point ordering.
grid.point_data['MRI'] = vol.flatten(order='F')

In [4]:
# Build an explicit Plotter so further actors can be added to the same scene.
plotter = pv.Plotter()

# Volume-render the MRI intensities: gray colormap, sigmoid opacity transfer
# function, and shading enabled for depth cues.
plotter.add_volume(grid, scalars='MRI', cmap='gray', opacity='sigmoid', shade=True)

# Display the scene in the notebook through the Trame backend.
plotter.show(jupyter_backend='trame')

Widget(value='<iframe src="http://localhost:36439/index.html?ui=P_0x7008032bd9c0_0&reconnect=auto" class="pyvi…

In [11]:
# Load the segmentation label map.
# get_fdata() always returns floating point values, so round to the nearest
# integer before casting: a bare astype(np.uint8) truncates, which would turn
# a stored 0.9999 into label 0 instead of label 1.
seg = np.rint(
    nib.load('cropped_norm/cropped_norm/pat0_cropped_seg.nii.gz').get_fdata()
).astype(np.uint8)

# Same voxel-index geometry as the MRI grid so the two volumes overlap.
seg_grid = pv.ImageData()
seg_grid.dimensions = list(seg.shape)
seg_grid.origin     = (0, 0, 0)
seg_grid.spacing    = (1, 1, 1)
seg_grid.point_data['Seg'] = seg.flatten(order='F')

# One colour per label value; the list index is the label.
cmap = [
    'black',    # placeholder for 0 (will be made fully transparent)
    'red',      # label 1
    'green',    # label 2
    'blue',     # label 3
    'yellow',   # label 4
    'magenta',  # label 5
    'cyan',     # label 6
    'orange',   # label 7
    'purple'    # label 8
]
# Background (label 0) is hidden, every other label is fully opaque.
# The length is derived from cmap so the two lists cannot drift apart.
opacity = [0.0] + [1.0] * (len(cmap) - 1)


# Add on top of the MRI.
# clim pins the scalar range to the full label range [0, 8]; without it the
# range is taken from the data, so a volume that lacks the highest label
# would silently shift every label onto a different colour.
plotter.add_volume(
    seg_grid,
    scalars='Seg',
    cmap=cmap,
    opacity=opacity,             # label 0 transparent, labels 1-8 opaque
    clim=(0, len(cmap) - 1)
)
# Ask the already-created plotter to redraw with the new volume.
plotter.update()

In [10]:
# Sorted distinct label values present in the segmentation volume.
np.unique(seg)

array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=uint8)

In [None]:
import torch
import torch.nn as nn

class DoubleConv3D(nn.Module):
    """
    Two successive 3x3x3 convolutions, each followed by ReLU, with a
    channel-wise dropout (``nn.Dropout3d``) between them.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes.

    Initialisation is meant to mimic Keras' ``he_uniform`` convolution layers
    (the Keras model being mirrored is not part of this notebook): weights use
    He/Kaiming uniform initialisation and biases start at zero, which is the
    Keras default for convolution biases. Without the explicit zeroing the
    biases would keep PyTorch's random default.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        dropout_prob: Probability passed to ``nn.Dropout3d``.
    """
    def __init__(self, in_channels, out_channels, dropout_prob):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        # A single ReLU module is reused after both convolutions (it is stateless).
        self.relu  = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout3d(dropout_prob)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        # He (Kaiming) uniform weights + zero biases to mimic 'he_uniform'
        for conv in (self.conv1, self.conv2):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='relu')
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        """Apply conv -> ReLU -> dropout -> conv -> ReLU and return the result."""
        x = self.relu(self.conv1(x))
        x = self.dropout(x)
        x = self.relu(self.conv2(x))
        return x


class UNet3D(nn.Module):
    """
    3D U-Net whose output has the same spatial size as its input.

    The encoder applies four 2x max-pooling stages (channel widths
    ``f, 2f, 4f, 8f``) followed by a ``16f`` bottleneck; the decoder mirrors it
    with four stride-2 transposed convolutions, concatenating the matching
    encoder feature map (skip connection) before each double convolution.
    The docstring of the original states it mirrors a TensorFlow/Keras model;
    that model is not part of this notebook.

    Spatial sizes do not need to be multiples of 16: when pooling rounds an odd
    size down, the upsampled tensor is zero-padded back to the size of its skip
    connection. Every spatial dimension must still be at least 16 so that four
    poolings leave at least one voxel.

    Args:
        in_channels: Number of channels of the input volume.
        num_classes: Number of output channels of the final 1x1x1 convolution.
        base_features: Channel width ``f`` of the first encoder stage.
    """
    def __init__(self, in_channels=1, num_classes=1, base_features=16):
        super().__init__()
        f = base_features

        # Contracting path
        self.enc1 = DoubleConv3D(in_channels,        f,  dropout_prob=0.1)
        # One pooling module is shared by all encoder stages (it has no weights).
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.enc2 = DoubleConv3D(f,                 f*2, dropout_prob=0.1)
        self.enc3 = DoubleConv3D(f*2,               f*4, dropout_prob=0.2)
        self.enc4 = DoubleConv3D(f*4,               f*8, dropout_prob=0.2)
        self.enc5 = DoubleConv3D(f*8,               f*16,dropout_prob=0.3)

        # Expansive path. Each decoder block takes twice the channels of its
        # up-convolution output because the skip connection is concatenated.
        self.up5  = nn.ConvTranspose3d(f*16, f*8, kernel_size=2, stride=2)
        self.dec4 = DoubleConv3D(f*16, f*8, dropout_prob=0.2)
        self.up4  = nn.ConvTranspose3d(f*8,  f*4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv3D(f*8,  f*4, dropout_prob=0.2)
        self.up3  = nn.ConvTranspose3d(f*4,  f*2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv3D(f*4,  f*2, dropout_prob=0.1)
        self.up2  = nn.ConvTranspose3d(f*2,  f,    kernel_size=2, stride=2)
        self.dec1 = DoubleConv3D(f*2,  f,    dropout_prob=0.1)

        # Final 1x1x1 convolution
        self.final_conv = nn.Conv3d(f, num_classes, kernel_size=1)

    @staticmethod
    def _match_skip(upsampled, skip):
        """Zero-pad ``upsampled`` so its spatial size equals that of ``skip``.

        Max-pooling floors odd sizes, so after the 2x up-convolution a spatial
        axis can be one voxel shorter than the skip connection. The missing
        voxel is the trailing one dropped by the pooling, hence padding is
        added at the far end of the axis. Returns ``upsampled`` untouched when
        the sizes already agree.
        """
        deficits = [s - u for s, u in zip(skip.shape[2:], upsampled.shape[2:])]
        if not any(deficits):
            return upsampled
        # F.pad expects (before, after) pairs starting from the LAST axis.
        pad = []
        for deficit in reversed(deficits):
            pad.extend((0, deficit))
        return nn.functional.pad(upsampled, pad)

    def forward(self, x):
        """Run the U-Net on a batched ``(N, C, D, H, W)`` tensor.

        Returns the raw output of the final 1x1x1 convolution (no softmax or
        sigmoid is applied) with ``num_classes`` channels and the same spatial
        size as ``x``.
        """
        # Encoder
        c1 = self.enc1(x)
        p1 = self.pool(c1)

        c2 = self.enc2(p1)
        p2 = self.pool(c2)

        c3 = self.enc3(p2)
        p3 = self.pool(c3)

        c4 = self.enc4(p3)
        p4 = self.pool(c4)

        c5 = self.enc5(p4)

        # Decoder: upsample, align with the skip connection, concatenate on
        # the channel axis, then refine with a double convolution.
        u5 = self._match_skip(self.up5(c5), c4)
        u5 = torch.cat((u5, c4), dim=1)
        c6 = self.dec4(u5)

        u4 = self._match_skip(self.up4(c6), c3)
        u4 = torch.cat((u4, c3), dim=1)
        c7 = self.dec3(u4)

        u3 = self._match_skip(self.up3(c7), c2)
        u3 = torch.cat((u3, c2), dim=1)
        c8 = self.dec2(u3)

        u2 = self._match_skip(self.up2(c8), c1)
        u2 = torch.cat((u2, c1), dim=1)
        c9 = self.dec1(u2)

        outputs = self.final_conv(c9)
        return outputs


if __name__ == "__main__":
    # Quick sanity check: push one random 3-channel 128^3 volume through the
    # network and confirm the output keeps the spatial size.
    model = UNet3D(in_channels=3, num_classes=4)
    x = torch.randn(1, 3, 128, 128, 128)
    # Only the output shape is inspected, so skip autograd bookkeeping: this
    # avoids keeping every intermediate activation of a 128^3 volume alive.
    with torch.no_grad():
        y = model(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {y.shape}")
    # num_classes output channels, same batch and spatial dimensions as x.
    assert tuple(y.shape) == (1, 4, 128, 128, 128), f"unexpected output shape {tuple(y.shape)}"
