In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

In [None]:
# Input CT scan, bone-label mask produced in Task 1, and the folder for the
# per-region volumes written by segment_knee_regions.
ct_file = "../Task1/image/left_knee.nii.gz"
mask_file = "../Task1/output/bone_segmentation_task1_1.nii.gz"  # was "..//Task1/..." (stray double slash)
output_folder = "segmented_regions"

In [None]:
def visualize(image_data, name):
    """Plot nine evenly spaced slices of a 3D volume in a 3x3 grid, save and show the figure.

    Parameters:
    - image_data: 3D array; slices are taken along the last axis (``image_data[:, :, k]``).
    - name: path passed to ``Figure.savefig`` (matplotlib infers the format from the
      extension, or appends its default format when there is none).
    """
    num_slices = 9
    # dtype=int truncates the evenly spaced positions to valid integer slice indices
    slice_indices = np.linspace(0, image_data.shape[2] - 1, num_slices, dtype=int)

    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    for ax, slice_idx in zip(axes.flat, slice_indices):
        ax.imshow(image_data[:, :, slice_idx], cmap='gray')
        ax.set_title(f"Slice {slice_idx}")
        ax.axis('off')

    # Apply the layout *before* saving; saving first meant the written file
    # never received the tight layout shown on screen.
    fig.tight_layout()
    fig.savefig(name)
    plt.show()

def load_nifti(file_path):
    """
    Read the NIfTI image stored at ``file_path``.

    Returns a ``(data, affine)`` pair: the voxel array produced by
    ``get_fdata()`` and the image's affine matrix.
    """
    image = nib.load(file_path)
    return image.get_fdata(), image.affine


def save_nifti(data, affine, output_path):
    """
    Write ``data`` to ``output_path`` as a NIfTI image with the given affine.

    The array is cast to float32 before being wrapped in a ``Nifti1Image``.
    """
    as_float32 = data.astype(np.float32)
    nib.save(nib.Nifti1Image(as_float32, affine), output_path)

In [None]:
# Preview the raw CT volume before segmentation (figure saved under the name 'test').
image_data, _ = load_nifti(ct_file)
visualize(image_data=image_data, name='test')

In [None]:
def segment_knee_regions(ct_path, mask_path, output_dir,
                         tibia_label=1, femur_label=2, background_label=0):
    """
    Split a CT volume into Tibia, Femur and Background volumes using a label mask.

    Each output keeps the CT intensities where the mask equals the region's label
    and is zero elsewhere. Outputs are written with the CT's affine to
    ``output_dir`` as ``tibia_volume.nii.gz``, ``femur_volume.nii.gz`` and
    ``background_volume.nii.gz``.

    Parameters:
    - ct_path: path to the CT .nii/.nii.gz file.
    - mask_path: path to the label-mask .nii/.nii.gz file.
    - output_dir: directory for the outputs (created if it does not exist).
    - tibia_label, femur_label, background_label: mask values of each region
      (defaults 1, 2 and 0 are the values previously hard-coded here).

    Raises:
    - ValueError: if the CT and mask arrays have different shapes. Without this
      check, broadcast-compatible shapes would silently produce wrong volumes.
    """
    ct_data, ct_affine = load_nifti(ct_path)
    mask_data, _ = load_nifti(mask_path)
    if ct_data.shape != mask_data.shape:
        raise ValueError(
            f"CT shape {ct_data.shape} does not match mask shape {mask_data.shape}")

    os.makedirs(output_dir, exist_ok=True)
    region_files = {
        "tibia_volume.nii.gz": tibia_label,
        "femur_volume.nii.gz": femur_label,
        "background_volume.nii.gz": background_label,
    }
    for filename, label in region_files.items():
        # CT intensities inside the region, zero everywhere else
        region_volume = ct_data * (mask_data == label)
        save_nifti(region_volume, ct_affine, os.path.join(output_dir, filename))

    print("Segmentation done. Files saved to:", output_dir)


In [None]:
# Write the per-region CT volumes (tibia / femur / background) to output_folder.
segment_knee_regions(ct_path=ct_file, mask_path=mask_file, output_dir=output_folder)

In [None]:
def overlay_mask(image_slice, mask_slice, alpha=0.4, cmap_img='gray', cmap_mask='jet'):
    """
    Draw an image slice on the current axes and overlay a segmentation mask on top.

    Only non-zero (labelled) mask pixels are coloured. Zero-valued pixels are left
    transparent, so the colormap's lowest colour no longer tints the whole image.
    The colour scale is still the min..max of the full mask slice, so labels keep
    the colours they had before.

    Parameters:
    - image_slice: 2D numpy array of the image slice.
    - mask_slice: 2D numpy array of the mask slice (same size as image_slice), or None.
    - alpha: float, transparency of the mask overlay.
    - cmap_img: str, colormap for the image.
    - cmap_mask: str, colormap for the mask.
    """
    plt.imshow(image_slice, cmap=cmap_img)
    if mask_slice is not None and np.any(mask_slice):
        # Masked entries are not drawn by imshow, leaving the background untouched
        plt.imshow(np.ma.masked_equal(mask_slice, 0), cmap=cmap_mask, alpha=alpha,
                   vmin=np.min(mask_slice), vmax=np.max(mask_slice))
    plt.axis('off')

def visualize_nifti_with_mask(image_path, mask_path=None, num_slices=9, axis=2):
    """
    Visualize evenly spaced NIfTI image slices with optional mask overlays and save the grid.

    Parameters:
    - image_path: str, path to the image .nii.gz file.
    - mask_path: str or None, path to the mask .nii.gz file. When None only the
      image is drawn (previously load_nifti(None) raised before plotting).
    - num_slices: int, number of slices to visualize. The subplot grid is sized to
      fit; 9 slices give the original 3x3 grid.
    - axis: int, axis along which to slice (0=sagittal, 1=coronal, 2=axial).

    The figure is saved as 'slices with masks ' (name kept as before, including
    its trailing space) and then shown.
    """
    image, _ = load_nifti(image_path)
    # Only load a mask when a path was given
    mask = load_nifti(mask_path)[0] if mask_path is not None else None
    assert mask is None or image.shape == mask.shape, "Image and mask must have the same shape"

    # Select slices evenly along the chosen axis
    slice_indices = np.linspace(0, image.shape[axis] - 1, num_slices, dtype=int)

    # A fixed 3x3 grid raised IndexError for more than 9 slices. Size the grid
    # to num_slices instead (9 -> 3x3 with the original 10x10 figure size).
    ncols = max(1, int(np.ceil(np.sqrt(num_slices))))
    nrows = max(1, int(np.ceil(num_slices / ncols)))
    fig, axes = plt.subplots(nrows, ncols, figsize=(10 * ncols / 3, 10 * nrows / 3),
                             squeeze=False)
    axes = axes.flatten()

    for ax, idx in zip(axes, slice_indices):
        plt.sca(ax)  # overlay_mask draws on the current axes
        img_slice = np.take(image, idx, axis=axis)
        msk_slice = np.take(mask, idx, axis=axis) if mask is not None else None
        overlay_mask(img_slice, msk_slice)
        ax.set_title(f"Slice {idx}")
    # Blank out grid cells left over when num_slices does not fill the grid
    for ax in axes[len(slice_indices):]:
        ax.axis('off')

    # Lay out before saving so the written file matches the displayed figure
    fig.tight_layout()
    fig.savefig('slices with masks ')
    plt.show()


In [None]:
# Axial (axis=2) overview: nine evenly spaced CT slices with the bone mask overlaid.
visualize_nifti_with_mask(image_path=ct_file, mask_path=mask_file, num_slices=9, axis=2)

In [None]:
def plot_image_mask_overlay(image_path, mask_path, slice_index, axis=2, alpha=0.4, save_path=None):
    """
    Plot an image slice, the matching mask slice, and their overlay side by side.

    Parameters:
    - image_path: str, path to the image NIfTI file (read with load_nifti).
    - mask_path: str, path to the mask NIfTI file; must have the image's shape.
    - slice_index: int, index of the slice along `axis`.
    - axis: int, slicing axis (0=sagittal, 1=coronal, 2=axial).
    - alpha: float, transparency of the mask in the overlay panel.
    - save_path: where to save the figure. Defaults to
      f"slice_{slice_index}_with_mask"; the old name was hard-coded to slice 108
      (with a typo) whatever slice was plotted.
    """
    image, _ = load_nifti(image_path)
    mask, _ = load_nifti(mask_path)
    assert image.shape == mask.shape, "Image and mask must have the same shape"

    # Extract the 2D slices along the requested axis
    img_slice = np.take(image, slice_index, axis=axis)
    mask_slice = np.take(mask, slice_index, axis=axis)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    axs[0].imshow(img_slice, cmap='gray')
    axs[0].set_title(f'Image Slice {slice_index}')

    axs[1].imshow(mask_slice, cmap='jet')
    axs[1].set_title(f'Mask Slice {slice_index}')

    axs[2].imshow(img_slice, cmap='gray')
    # Colour only labelled pixels; vmin/vmax reuse the mask panel's colour scale
    axs[2].imshow(np.ma.masked_equal(mask_slice, 0), cmap='jet', alpha=alpha,
                  vmin=mask_slice.min(), vmax=mask_slice.max())
    axs[2].set_title('Overlay')

    for ax in axs:
        ax.axis('off')

    if save_path is None:
        save_path = f"slice_{slice_index}_with_mask"
    # Lay out before saving so the written file matches the displayed figure
    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()

In [None]:
# Inspect a single axial slice in detail: raw image, mask, and overlay.
axis = 2  # Axial view
slice_index = 108
plot_image_mask_overlay(image_path=ct_file, mask_path=mask_file,
                        slice_index=slice_index, axis=axis)

In [1]:
pip install torch torchvision

Collecting torch
  Downloading torch-2.7.1-cp312-none-macosx_11_0_arm64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.7.1-cp312-none-macosx_11_0_arm64.whl (68.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.6/68.6 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading torchvision-0.22.1-cp312-cp312-macosx_11_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m16.8 MB/s[0m eta [

Output shape: torch.Size([2, 16384])


In [7]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

# --- Helper Functions for Inflation ---

def inflate_conv(conv2d, kernel_depth):
    """Inflate a Conv2d into a Conv3d by repeating its weights along a new depth axis.

    The 2D kernel is copied into each of the ``kernel_depth`` temporal slices and
    divided by ``kernel_depth``. For an input that is constant along depth, the
    inflated layer therefore reproduces the 2D response (away from the depth
    borders). The bias, if any, is copied unchanged.

    Args:
        conv2d (nn.Conv2d): layer to inflate. Grouped convolutions work as-is: the
            [out, in/groups, kH, kW] weight just gains a depth axis.
        kernel_depth (int): temporal kernel size, >= 1. Odd values keep the depth
            dimension unchanged. Even values lengthen it by one (symmetric padding
            of ``kernel_depth // 2``) and emit a warning.

    Returns:
        nn.Conv3d: temporal stride 1, temporal dilation 1, temporal padding
        ``kernel_depth // 2``. Spatial hyper-parameters, groups, padding_mode,
        device and dtype are copied from ``conv2d``.

    Raises:
        ValueError: if ``kernel_depth`` < 1 or ``conv2d`` uses string padding
            ('same'/'valid'), which cannot be extended with a temporal entry.
    """
    if kernel_depth < 1:
        raise ValueError(f"kernel_depth must be >= 1, got {kernel_depth}")
    if isinstance(conv2d.padding, str):
        raise ValueError(f"String padding ({conv2d.padding!r}) is not supported for inflation")
    if kernel_depth % 2 == 0:
        warnings.warn(f"Inflating Conv2D to 3D with even kernel_depth ({kernel_depth}). "
                      "Temporal padding will be asymmetric or custom padding is needed.")

    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(kernel_depth, *conv2d.kernel_size),
        # Temporal downsampling is left to the transition pooling layers
        stride=(1, *conv2d.stride),
        padding=(kernel_depth // 2, *conv2d.padding),
        dilation=(1, *conv2d.dilation),
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
        padding_mode=conv2d.padding_mode,
        device=conv2d.weight.device,
        dtype=conv2d.weight.dtype,
    )

    with torch.no_grad():
        # [out, in/groups, kH, kW] -> [out, in/groups, kernel_depth, kH, kW], scaled by 1/depth
        inflated_weights = conv2d.weight.unsqueeze(2).repeat(1, 1, kernel_depth, 1, 1) / kernel_depth
        conv3d.weight.copy_(inflated_weights)
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)

    return conv3d

def inflate_batch_norm(bn2d):
    """Inflate a BatchNorm2d into an equivalent BatchNorm3d.

    All constructor hyper-parameters (eps, momentum, affine,
    track_running_stats) are carried over. Using BatchNorm3d defaults would
    silently change the normalization for layers built with non-default values.
    The learned affine parameters are copied when ``affine`` is set. The running
    statistics (including ``num_batches_tracked``) are copied when
    ``track_running_stats`` is set.
    """
    bn3d = nn.BatchNorm3d(
        bn2d.num_features,
        eps=bn2d.eps,
        momentum=bn2d.momentum,
        affine=bn2d.affine,
        track_running_stats=bn2d.track_running_stats,
    )
    with torch.no_grad():
        if bn2d.affine:
            bn3d.weight.copy_(bn2d.weight)
            bn3d.bias.copy_(bn2d.bias)
        if bn2d.track_running_stats:
            bn3d.running_mean.copy_(bn2d.running_mean)
            bn3d.running_var.copy_(bn2d.running_var)
            bn3d.num_batches_tracked.copy_(bn2d.num_batches_tracked)
    return bn3d

def inflate_relu(relu2d):
    """ReLU is dimension-agnostic: build a fresh nn.ReLU with the same ``inplace`` flag."""
    keep_inplace = relu2d.inplace
    return nn.ReLU(inplace=keep_inplace)

def inflate_pool(pool2d, temporal_stride=1):
    """Inflate a MaxPool2d / AvgPool2d into the matching 3D pooling layer.

    The temporal kernel size is 1 and the temporal padding 0, so the layer never
    pools across time. ``temporal_stride`` > 1 instead keeps every
    ``temporal_stride``-th frame. Spatial hyper-parameters and mode flags are
    copied from ``pool2d``: ceil_mode and, for max pooling, dilation and
    return_indices; for average pooling, count_include_pad and divisor_override.

    Args:
        pool2d (nn.MaxPool2d | nn.AvgPool2d): pooling layer to inflate.
        temporal_stride (int): stride along the depth axis.

    Returns:
        nn.MaxPool3d or nn.AvgPool3d.

    Raises:
        ValueError: for any other pooling type.
    """
    def _pair(value):
        # 2D pool hyper-parameters may be stored as an int or a 2-element tuple/list
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    # Check the type first: e.g. adaptive pools have no kernel_size/stride/padding,
    # so reading those attributes would raise AttributeError, not the intended ValueError.
    if not isinstance(pool2d, (nn.MaxPool2d, nn.AvgPool2d)):
        raise ValueError(f"Unsupported pooling type: {type(pool2d)}")

    kernel_size_3d = (1, *_pair(pool2d.kernel_size))
    stride_3d = (temporal_stride, *_pair(pool2d.stride))
    padding_3d = (0, *_pair(pool2d.padding))

    if isinstance(pool2d, nn.MaxPool2d):
        return nn.MaxPool3d(kernel_size_3d, stride=stride_3d, padding=padding_3d,
                            dilation=(1, *_pair(pool2d.dilation)),
                            return_indices=pool2d.return_indices, ceil_mode=pool2d.ceil_mode)
    return nn.AvgPool3d(kernel_size_3d, stride=stride_3d, padding=padding_3d,
                        ceil_mode=pool2d.ceil_mode,
                        count_include_pad=pool2d.count_include_pad,
                        divisor_override=pool2d.divisor_override)

# --- Inflated DenseNet Components ---

class InflatedDenseLayer(nn.Module):
    """3D counterpart of a torchvision DenseNet ``_DenseLayer``.

    The 2D layer's children are inflated in registration order: BatchNorm2d and
    ReLU directly, 1x1 convolutions with depth 1, and every other convolution with
    ``conv_kernel_depth``. ``forward`` runs them in sequence, applies dropout while
    training, and concatenates the new features onto the input along dim 1.
    """

    def __init__(self, dense_layer2d, conv_kernel_depth=3):
        super(InflatedDenseLayer, self).__init__()
        self.layers = nn.Sequential()
        for child_name, module in dense_layer2d.named_children():
            if isinstance(module, nn.BatchNorm2d):
                inflated = inflate_batch_norm(module)
            elif isinstance(module, nn.ReLU):
                inflated = inflate_relu(module)
            elif isinstance(module, nn.Conv2d):
                # Bottleneck 1x1 convs stay 1 deep; spatial convs get the requested depth
                depth = conv_kernel_depth if module.kernel_size != (1, 1) else 1
                inflated = inflate_conv(module, depth)
            else:
                raise ValueError(f"Unsupported layer type in DenseLayer: {type(module)}")
            self.layers.add_module(child_name, inflated)
        self.drop_rate = dense_layer2d.drop_rate

    def forward(self, x):
        out = self.layers(x)
        if self.training and self.drop_rate > 0:
            out = nn.functional.dropout(out, p=self.drop_rate, training=True)
        # Dense connectivity: keep the input and append the new feature maps
        return torch.cat([x, out], 1)

class InflatedTransition(nn.Module):
    """3D counterpart of a torchvision DenseNet ``_Transition`` block.

    BatchNorm2d and ReLU children are inflated directly, the Conv2d gets temporal
    depth 1, and the AvgPool2d is inflated with ``temporal_pool_stride`` as its
    depth stride. Any other child type raises ValueError.
    """

    def __init__(self, transition2d, temporal_pool_stride=2):
        super(InflatedTransition, self).__init__()
        self.layers = nn.Sequential()
        for child_name, module in transition2d.named_children():
            if isinstance(module, nn.BatchNorm2d):
                inflated = inflate_batch_norm(module)
            elif isinstance(module, nn.ReLU):
                inflated = inflate_relu(module)
            elif isinstance(module, nn.Conv2d):
                inflated = inflate_conv(module, kernel_depth=1)
            elif isinstance(module, nn.AvgPool2d):
                inflated = inflate_pool(module, temporal_stride=temporal_pool_stride)
            else:
                raise ValueError(f"Unsupported layer type in Transition: {type(module)}")
            self.layers.add_module(child_name, inflated)

    def forward(self, x):
        """Apply the inflated layers in their registration order."""
        return self.layers(x)

# --- Main Inflation Function ---

def _as_triple(value):
    """Return a 3D layer hyper-parameter as a tuple (an int is broadcast to all three axes)."""
    return tuple(value) if isinstance(value, (tuple, list)) else (value,) * 3


def _temporal_output_size(features_3d, frame_nb):
    """Depth (dim 2) that ``features_3d`` produces for an input with ``frame_nb`` frames.

    Applies PyTorch's output-length formula along the depth axis of every Conv3d,
    MaxPool3d and AvgPool3d, visited in ``modules()`` order. That order matches
    the forward order of the nested Sequentials built in inflate_densenet121.
    Convolutions inside dense layers are visited too; with odd temporal kernels,
    inflate_conv's ``kernel // 2`` padding keeps their depth unchanged.
    """
    depth = frame_nb
    for module in features_3d.modules():
        if not isinstance(module, (nn.Conv3d, nn.MaxPool3d, nn.AvgPool3d)):
            continue
        kernel = _as_triple(module.kernel_size)[0]
        stride = _as_triple(module.stride)[0]
        padding = _as_triple(module.padding)[0]
        dilation = _as_triple(getattr(module, 'dilation', 1))[0]
        span = depth + 2 * padding - dilation * (kernel - 1) - 1
        if getattr(module, 'ceil_mode', False):
            out = -(-span // stride) + 1
            # PyTorch drops a last window that would start inside the right padding
            if (out - 1) * stride >= depth + padding:
                out -= 1
        else:
            out = span // stride + 1
        depth = out
    return depth


def _inflate_linear(linear2d, time_dim):
    """Inflate the 2D classifier to accept ``time_dim`` concatenated per-frame feature vectors.

    The model's forward flattens features time-major ([t0 channels..., t1
    channels..., ...]). The 2D weight matrix is therefore tiled ``time_dim``
    times along the input axis (``repeat(1, time_dim)``) and divided by
    ``time_dim``. For features that are identical across time, the result equals
    the 2D classifier's output.
    """
    linear3d = nn.Linear(linear2d.in_features * time_dim, linear2d.out_features,
                         bias=linear2d.bias is not None)
    with torch.no_grad():
        linear3d.weight.copy_(linear2d.weight.repeat(1, time_dim) / time_dim)
        if linear2d.bias is not None:
            linear3d.bias.copy_(linear2d.bias)
    return linear3d


def inflate_densenet121(densenet2d, frame_nb, conv_kernel_depth=3, temporal_pool_stride=2):
    """
    Inflate a torchvision DenseNet121 into a 3D (I3D-style) model.

    Bug fixes compared with the previous version:
    - The classifier weights were tiled channel-major while features are
      flattened time-major. Weights were therefore applied to the wrong channels
      whenever the final temporal dimension was > 1.
    - The final temporal dimension was computed with floor division. Strided
      temporal pooling actually rounds up, so for non-divisible ``frame_nb`` the
      classifier's reshape silently changed the batch size.

    Args:
        densenet2d (torchvision.models.densenet.DenseNet): the pre-trained 2D DenseNet121.
        frame_nb (int): number of frames (depth) of the inputs the model will receive.
        conv_kernel_depth (int): temporal kernel size for inflated convs larger than 1x1.
            Default 3 (3x3x3).
        temporal_pool_stride (int): temporal stride of the inflated pooling layers
            in the transition blocks. Default 2.

    Returns:
        torch.nn.Module: ``InflatedDenseNetModel`` mapping [batch, 3, frame_nb, H, W]
        inputs to [batch, num_classes] logits.

    Raises:
        ValueError: if ``frame_nb`` < 1.
        TypeError: if ``densenet2d.classifier`` is not an ``nn.Linear``.
    """
    if frame_nb < 1:
        raise ValueError(f"frame_nb must be >= 1, got {frame_nb}")

    # Inflate the features part (initial conv/norm/relu/pool, dense blocks, transitions, final norm)
    features_3d = nn.Sequential()
    transition_nb = 0
    for name, child in densenet2d.features.named_children():
        if isinstance(child, nn.Conv2d):
            # Initial conv: reuse its spatial kernel size (7) as the temporal depth
            features_3d.add_module(name, inflate_conv(child, kernel_depth=child.kernel_size[0]))
        elif isinstance(child, nn.BatchNorm2d):
            features_3d.add_module(name, inflate_batch_norm(child))
        elif isinstance(child, nn.ReLU):
            features_3d.add_module(name, inflate_relu(child))
        elif isinstance(child, nn.MaxPool2d):
            # Initial pool: spatial downsampling only (temporal stride 1)
            features_3d.add_module(name, inflate_pool(child, temporal_stride=1))
        elif isinstance(child, models.densenet._DenseBlock):
            block_3d = nn.Sequential()
            for layer_name, layer in child.named_children():
                assert isinstance(layer, models.densenet._DenseLayer)
                block_3d.add_module(
                    layer_name, InflatedDenseLayer(layer, conv_kernel_depth=conv_kernel_depth))
            features_3d.add_module(name, block_3d)
        elif isinstance(child, models.densenet._Transition):
            features_3d.add_module(
                name, InflatedTransition(child, temporal_pool_stride=temporal_pool_stride))
            transition_nb += 1
        else:
            warnings.warn(f"Skipping unhandled layer type in features: {name} ({type(child)})")

    # Temporal size after all strided layers, computed from the layers actually built
    final_time_dim = _temporal_output_size(features_3d, frame_nb)
    temporal_reduction_factor = temporal_pool_stride ** transition_nb
    if frame_nb % temporal_reduction_factor != 0:
        warnings.warn(
            f"Input frame_nb ({frame_nb}) is not perfectly divisible by temporal reduction factor "
            f"({temporal_reduction_factor}). Final temporal dimension will be {final_time_dim}.")

    original_linear_classifier = densenet2d.classifier
    if not isinstance(original_linear_classifier, nn.Linear):
        raise TypeError(
            f"Expected original classifier to be nn.Linear, but got {type(original_linear_classifier)}")
    classifier_3d = nn.Sequential(_inflate_linear(original_linear_classifier, final_time_dim))
    # Channels entering the classifier (1024 for torchvision's DenseNet121)
    final_layer_nb = original_linear_classifier.in_features

    class InflatedDenseNetModel(nn.Module):
        def __init__(self, features_3d, classifier_3d, final_time_dim, final_layer_nb):
            super().__init__()
            self.features = features_3d
            self.classifier = classifier_3d
            self.final_time_dim = final_time_dim
            self.final_layer_nb = final_layer_nb  # channels before global pooling

        def forward(self, x):
            x = self.features(x)
            # ReLU after the features, as in the 2D DenseNet before classification
            x = nn.functional.relu(x)
            if x.shape[2] != self.final_time_dim:
                raise ValueError(
                    f"Expected temporal size {self.final_time_dim} after the features, got "
                    f"{x.shape[2]}; the input depth must match the frame_nb used for inflation.")
            # Global spatial average pooling -> [batch, channels, depth, 1, 1]
            x = nn.functional.avg_pool3d(x, kernel_size=(1, x.shape[-2], x.shape[-1]))
            # Time-major flatten: [batch, depth, channels, 1, 1] -> [batch, depth * channels]
            x = x.permute(0, 2, 1, 3, 4).reshape(x.shape[0], -1)
            return self.classifier(x)

    return InflatedDenseNetModel(features_3d, classifier_3d, final_time_dim, final_layer_nb)


# --- Testing the Inflation ---

# a. Take a 2D pretrained DenseNet121 model (ImageNet weights; `weights=` replaces
#    the deprecated `pretrained=True` flag)
model_2d = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
print("Loaded 2D DenseNet121 model.")

input_frame_nb = 16  # expected number of frames in the video input
conv_k_depth = 3     # temporal kernel depth for inflating 3x3 convs
pool_t_stride = 2    # temporal stride for pooling in transition blocks

i3d_densenet_model = inflate_densenet121(
    model_2d,
    frame_nb=input_frame_nb,
    conv_kernel_depth=conv_k_depth,
    temporal_pool_stride=pool_t_stride,
)
print(f"\nInflated 3D DenseNet model created for {input_frame_nb} frames.")

# Dummy input: [batch_size, channels, depth, height, width]
torch.manual_seed(0)
batch_size = 2
channels = 3
input_height = 224
input_width = 224
dummy_input_3d = torch.randn(batch_size, channels, input_frame_nb, input_height, input_width)
print("\nDummy input shape:", dummy_input_3d.shape)

# Forward pass in eval mode (BatchNorm uses running stats, dropout is disabled).
# Errors propagate instead of being swallowed, so a broken inflation fails the cell.
i3d_densenet_model.eval()
with torch.no_grad():
    output = i3d_densenet_model(dummy_input_3d)

expected_shape = (batch_size, model_2d.classifier.out_features)
print("\nForward pass successful!")
print("Output shape:", output.shape)
print("Expected output shape:", expected_shape)
assert tuple(output.shape) == expected_shape, f"Unexpected output shape {tuple(output.shape)}"

Loaded 2D DenseNet121 model.

Inflated 3D DenseNet model created for 16 frames.

Dummy input shape: torch.Size([2, 3, 16, 224, 224])

Forward pass successful!
Output shape: torch.Size([2, 1000])
Expected output shape: (2, 1000)


In [6]:
# Architecture of the 2D source network (bare last expression -> rich cell output)
model_2d

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [8]:
# Architecture of the inflated 3D network (bare last expression -> rich cell output)
i3d_densenet_model

InflatedDenseNetModel(
  (features): Sequential(
    (conv0): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
    (norm0): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=(1, 1, 1), ceil_mode=False)
    (denseblock1): Sequential(
      (denselayer1): InflatedDenseLayer(
        (layers): Sequential(
          (norm1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (norm2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        )
      )
      (denselayer

In [None]:
import torch.nn.functional as F

def global_avg_pool_3d(x):
    """Average a 5D [N, C, D, H, W] tensor over D, H and W, returning an [N, C] tensor."""
    pooled = F.adaptive_avg_pool3d(x, output_size=1)  # -> [N, C, 1, 1, 1]
    batch = x.size(0)
    return pooled.view(batch, -1)

def extract_region_features(model_3d, volume_dict):
    """
    Extract globally pooled features from three DenseNet stages for each region.

    Forward hooks on ``model_3d.features.denseblock2``, ``denseblock3`` and
    ``denseblock4`` capture the block outputs. These are reduced to
    [batch, channels] with ``global_avg_pool_3d``.

    Only ``model_3d.features`` is run. The classifier is not needed for these
    features, and the inflated classifier is sized for a fixed number of frames,
    so running it on a volume of another depth breaks the forward pass.

    Args:
        model_3d: Inflated DenseNet121 model exposing ``features.denseblock2..4``.
            It is switched to eval mode.
        volume_dict: Dict with keys 'Tibia', 'Femur', 'Background', each mapped to
            a volume tensor [1, C, D, H, W].

    Returns:
        Dictionary:
        {
            "Tibia": {"fifth_last": ..., "third_last": ..., "last": ...},
            "Femur": {...},
            "Background": {...}
        }
    """
    model_3d.eval()
    output_dict = {}
    feature_maps = {}

    def make_hook(name):
        def hook(module, inputs, output):
            feature_maps[name] = output
        return hook

    hooked_blocks = {
        'fifth_last': model_3d.features.denseblock2,
        'third_last': model_3d.features.denseblock3,
        'last': model_3d.features.denseblock4,
    }
    hooks = [block.register_forward_hook(make_hook(name)) for name, block in hooked_blocks.items()]
    try:
        with torch.no_grad():
            for region_name, volume in volume_dict.items():
                feature_maps.clear()
                model_3d.features(volume)  # hooks fill feature_maps
                output_dict[region_name] = {
                    name: global_avg_pool_3d(feat) for name, feat in feature_maps.items()
                }
    finally:
        # Always detach the hooks, even if a forward pass raises (e.g. out of memory)
        for h in hooks:
            h.remove()

    return output_dict


In [None]:
import nibabel as nib

def load_nifti_to_tensor(path):
    """
    Load a .nii/.nii.gz file as a float32 torch tensor with its last voxel axis moved first.

    ``permute(2, 0, 1)`` turns the ``get_fdata()`` array of shape (a, b, c) into
    (c, a, b), which the rest of this notebook treats as [D, H, W].
    """
    volume = torch.tensor(nib.load(path).get_fdata())
    return volume.permute(2, 0, 1).float()
def normalize_ct(ct_scan):
    """
    Min-max scale a CT tensor to [0, 1]; a constant volume maps to all zeros.

    NOTE(review): prepare_volumes called ``normalize_ct`` but no definition exists
    in this notebook (NameError on a fresh kernel). Plain min-max scaling is a
    neutral default. Replace it with HU windowing and/or ImageNet mean/std
    normalization if that is what the pretrained backbone should see.
    """
    ct_min, ct_max = ct_scan.min(), ct_scan.max()
    value_range = ct_max - ct_min
    if value_range == 0:
        return torch.zeros_like(ct_scan)
    return (ct_scan - ct_min) / value_range


def prepare_volumes(ct_scan, mask):
    """
    Create region-specific masked 3D volumes as input to the 3D CNN.

    The CT is normalized with ``normalize_ct`` and then multiplied by each region's
    binary mask.

    Args:
        ct_scan: Tensor of shape [D, H, W]
        mask: Tensor of the same shape, values: 0=Background, 1=Tibia, 2=Femur

    Returns:
        Dict with keys "Tibia", "Femur", "Background" (in that order); values are
        tensors of shape [1, 3, D, H, W].

    Raises:
        ValueError: if ``ct_scan`` and ``mask`` have different shapes.
    """
    if ct_scan.shape != mask.shape:
        raise ValueError(f"CT shape {tuple(ct_scan.shape)} does not match mask shape {tuple(mask.shape)}")
    ct_scan = normalize_ct(ct_scan)

    def to_3channel(vol):
        # [D, H, W] -> [1, 3, D, H, W]: the grayscale volume is copied into the
        # three input channels the inflated DenseNet's first conv expects
        return vol.unsqueeze(0).repeat(3, 1, 1, 1).unsqueeze(0)

    region_labels = {"Tibia": 1, "Femur": 2, "Background": 0}
    return {
        region: to_3channel(ct_scan * (mask == label).float())
        for region, label in region_labels.items()
    }


ct_path = ct_file
mask_path = mask_file

ct_scan = load_nifti_to_tensor(ct_path)   # [D, H, W] in the layout load_nifti_to_tensor returns
mask = load_nifti_to_tensor(mask_path)    # same layout; label values 0, 1, 2

# Region-specific inputs, each [1, 3, D, H, W]
region_volumes = prepare_volumes(ct_scan, mask)

device = "cuda" if torch.cuda.is_available() else "cpu"
region_volumes = {name: vol.to(device) for name, vol in region_volumes.items()}

# Build the 3D feature extractor. `inflate_densenet121_to_3d` is not defined in
# this notebook, so inflate the ImageNet DenseNet121 with inflate_densenet121,
# sized for this volume's depth (dim 2 of the region tensors).
# NOTE(review): a full-resolution CT through a 3D DenseNet needs a lot of memory;
# consider cropping/resampling the volumes first.
volume_depth = region_volumes["Tibia"].shape[2]
densenet2d = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
model_3d = inflate_densenet121(densenet2d, frame_nb=volume_depth).to(device)
features = extract_region_features(model_3d, region_volumes)

# Example: shape of the Tibia feature from the last hooked block ([batch, channels])
print(features["Tibia"]["last"].shape)