In [1]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import os 

In [2]:
ct_file = "../Task1/image/left_knee.nii.gz"
mask_file = "..//Task1/output/bone_segmentation_task1_1.nii.gz"
output_folder = "segmented_regions"

In [None]:
def visualize(image_data,name):
    # --- Choose slices to visualize ---
    num_slices = 9
    slice_indices = np.linspace(0, image_data.shape[2] - 1, num_slices, dtype=int)

    # --- Plot slices ---
    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    axes = axes.flatten()

    for i, slice_idx in enumerate(slice_indices):
        axes[i].imshow(image_data[:, :, slice_idx], cmap='gray')
        axes[i].set_title(f"Slice {slice_idx}")
        axes[i].axis('off')
    plt.savefig(name)
    plt.tight_layout()
    plt.show()

def load_nifti(file_path):
    """
    Load a NIfTI file and return the data array and affine.
    """
    nifti = nib.load(file_path)
    data = nifti.get_fdata()
    affine = nifti.affine
    return data, affine


def save_nifti(data, affine, output_path):
    """
    Save a NumPy array as a NIfTI file.
    """
    nifti = nib.Nifti1Image(data.astype(np.float32), affine)
    nib.save(nifti, output_path)

In [None]:
image_data,_=load_nifti(ct_file)
visualize(image_data,'test')

In [None]:
def segment_knee_regions(ct_path, mask_path, output_dir):
    """
    Segment Tibia, Femur, and Background from a CT volume using the mask.
    """
    # Load the CT scan and segmentation mask
    ct_data, ct_affine = load_nifti(ct_path)
    mask_data, _ = load_nifti(mask_path)

    # Define region labels (change if your mask uses different values)
    TIBIA_LABEL = 1
    FEMUR_LABEL = 2
    BACKGROUND_LABEL = 0

    # Generate binary masks
    tibia_mask = (mask_data == TIBIA_LABEL)
    femur_mask = (mask_data == FEMUR_LABEL)
    background_mask = (mask_data == BACKGROUND_LABEL)

    # Apply masks to CT data
    tibia_volume = ct_data * tibia_mask
    femur_volume = ct_data * femur_mask
    background_volume = ct_data * background_mask

    # Save each region as a new .nii.gz file
    os.makedirs(output_dir, exist_ok=True)
    save_nifti(tibia_volume, ct_affine, os.path.join(output_dir, "tibia_volume.nii.gz"))
    save_nifti(femur_volume, ct_affine, os.path.join(output_dir, "femur_volume.nii.gz"))
    save_nifti(background_volume, ct_affine, os.path.join(output_dir, "background_volume.nii.gz"))

    print("Segmentation done. Files saved to:", output_dir)


In [None]:
segment_knee_regions(ct_file, mask_file, output_folder)

In [None]:
def overlay_mask(image_slice, mask_slice, alpha=0.4, cmap_img='gray', cmap_mask='jet'):
    """
    Overlay a segmentation mask on a single image slice.
    
    Parameters:
    - image_slice: 2D numpy array of the image slice.
    - mask_slice: 2D numpy array of the mask slice (same size as image_slice).
    - alpha: float, transparency of the mask overlay.
    - cmap_img: str, colormap for the image.
    - cmap_mask: str, colormap for the mask.
    """
    plt.imshow(image_slice, cmap=cmap_img)
    if mask_slice is not None and np.any(mask_slice):
        plt.imshow(mask_slice, cmap=cmap_mask, alpha=alpha)
    plt.axis('off')

def visualize_nifti_with_mask(image_path, mask_path=None, num_slices=9, axis=2):
    """
    Visualize NIfTI image slices with optional mask overlays.
    
    Parameters:
    - image_path: str, path to the image .nii.gz file.
    - mask_path: str or None, path to the mask .nii.gz file (optional).
    - num_slices: int, number of slices to visualize.
    - axis: int, axis along which to slice (0=sagittal, 1=coronal, 2=axial).
    """
    # Load image and optional mask
    image, ct_affine = load_nifti(image_path)
    mask, _ = load_nifti(mask_path)
    assert mask is None or image.shape == mask.shape, "Image and mask must have the same shape"

    # Select slices evenly along the chosen axis
    slice_indices = np.linspace(0, image.shape[axis] - 1, num_slices, dtype=int)
    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    axes = axes.flatten()

    for i, idx in enumerate(slice_indices):
        ax = axes[i]
        plt.sca(ax)

        if axis == 0:
            img_slice = image[idx, :, :]
            msk_slice = mask[idx, :, :] if mask is not None else None
        elif axis == 1:
            img_slice = image[:, idx, :]
            msk_slice = mask[:, idx, :] if mask is not None else None
        else:  # default to axial
            img_slice = image[:, :, idx]
            msk_slice = mask[:, :, idx] if mask is not None else None

        overlay_mask(img_slice, msk_slice)
        ax.set_title(f"Slice {idx}")
    plt.savefig('slices with masks ')
    plt.tight_layout()
    plt.show()


In [None]:
visualize_nifti_with_mask(
    ct_file,
    mask_file,  # Set to None if you don’t have a mask
    num_slices=9,
    axis=2  # axial view
)

In [None]:
def plot_image_mask_overlay(image_path, mask_path, slice_index, axis=2, alpha=0.4):
    """
    Plot image slice, mask slice, and overlay side by side.

    Parameters:
    - image: 3D numpy array (CT volume)
    - mask: 3D numpy array (segmentation mask)
    - slice_index: int, index of the slice to visualize
    - axis: int, slicing axis (0=sagittal, 1=coronal, 2=axial)
    - alpha: float, transparency for overlay
    """
    image, ct_affine = load_nifti(image_path)
    mask, _ = load_nifti(mask_path)
    # Extract slices
    if axis == 0:
        img_slice = image[slice_index, :, :]
        mask_slice = mask[slice_index, :, :]
    elif axis == 1:
        img_slice = image[:, slice_index, :]
        mask_slice = mask[:, slice_index, :]
    else:
        img_slice = image[:, :, slice_index]
        mask_slice = mask[:, :, slice_index]

    # Plotting
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    axs[0].imshow(img_slice, cmap='gray')
    axs[0].set_title(f'Image Slice {slice_index}')
    axs[0].axis('off')

    axs[1].imshow(mask_slice, cmap='jet')
    axs[1].set_title(f'Mask Slice {slice_index}')
    axs[1].axis('off')

    axs[2].imshow(img_slice, cmap='gray')
    axs[2].imshow(mask_slice, cmap='jet', alpha=alpha)
    axs[2].set_title('Overlay')
    axs[2].axis('off')
    plt.savefig('slice_108_with_maska')
    plt.tight_layout()
    plt.show()

In [None]:
slice_index = 108  
axis = 2  # Axial view

plot_image_mask_overlay(ct_file, mask_file, slice_index, axis=axis)

In [1]:
pip install torch torchvision

Collecting torch
  Downloading torch-2.7.1-cp312-none-macosx_11_0_arm64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.7.1-cp312-none-macosx_11_0_arm64.whl (68.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.6/68.6 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading torchvision-0.22.1-cp312-cp312-macosx_11_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m16.8 MB/s[0m eta [

In [24]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

# --- Helper Functions for Inflation ---

def inflate_conv(conv2d, kernel_depth):
    """Inflates a Conv2d layer to a Conv3d layer by repeating weights."""
    if conv2d.in_channels % conv2d.groups != 0 or conv2d.out_channels % conv2d.groups != 0:
         raise ValueError("Inflating grouped convolutions is not straightforward with simple repetition.")

    # Get 2D parameters
    in_channels = conv2d.in_channels
    out_channels = conv2d.out_channels
    kernel_size_2d = conv2d.kernel_size
    stride_2d = conv2d.stride
    padding_2d = conv2d.padding
    dilation_2d = conv2d.dilation
    groups = conv2d.groups
    bias = conv2d.bias is not None

    # Create 3D convolution parameters
    # Temporal kernel size
    kernel_size_3d = (kernel_depth, kernel_size_2d[0], kernel_size_2d[1])
    
    # Typically stride 1 in time for simple inflation unless downsampling is desired
    # Based on the previous I3D code structure, transitions downsample spatially,
    # and the initial conv might downsample spatially. Temporal downsampling is handled
    # by the pooling layers in transition blocks.
    # So we'll use a temporal stride of 1 here for convolution inflation.
    stride_3d = (1, stride_2d[0], stride_2d[1])
    # Temporal padding to keep temporal dimension size
    padding_3d = (kernel_depth // 2, padding_2d[0], padding_2d[1])
    dilation_3d = (1, dilation_2d[0], dilation_2d[1]) # Dilation usually 1 in time

    # Create 3D Conv layer
    conv3d = nn.Conv3d(in_channels, out_channels, kernel_size_3d, stride_3d, padding_3d,
                       dilation=dilation_3d, groups=groups, bias=bias)

    # Inflate weights (replication and scaling)
    conv2d_weights = conv2d.weight.data #  Shape [out_channels, 3, H, W]
    #averaged_weights = conv2d_weights.mean(dim=1, keepdim=True) # Shape [out_channels, 1, H, W]

                 # Now repeat this 1-channel weight along the temporal dimension
    inflated_weights = conv2d_weights.unsqueeze(2).repeat(1, 1, kernel_depth, 1, 1)



    # Normalize by dividing by the depth (as per requirement ii)
    if kernel_depth > 0: # Avoid division by zero if somehow kernel_depth is 0
        inflated_weights = inflated_weights / kernel_depth

    # Copy inflated weights to the 3D Conv layer
    conv3d.weight.data.copy_(inflated_weights)

    # Copy bias if it exists
    if bias:
        conv3d.bias.data.copy_(conv2d.bias.data)

    return conv3d

def inflate_batch_norm(bn2d):
    """Inflates a BatchNorm2d layer to a BatchNorm3d layer."""
    bn3d = nn.BatchNorm3d(bn2d.num_features)
    # Copy parameters and running statistics
    bn3d.weight.data.copy_(bn2d.weight.data)
    bn3d.bias.data.copy_(bn2d.bias.data)
    bn3d.running_mean.copy_(bn2d.running_mean)
    bn3d.running_var.copy_(bn2d.running_var)
    # bn3d.num_batches_tracked.copy_(bn2d.num_batches_tracked) # Copy if your BN uses this
    return bn3d

def inflate_relu(relu2d):
    """Returns a ReLU layer (same for 2D and 3D)."""
    return nn.ReLU(inplace=relu2d.inplace)

def inflate_pool(pool2d, temporal_stride=1):
    """Inflates a Pooling layer to a 3D Pooling layer."""
    # Get 2D parameters
    kernel_size_2d = pool2d.kernel_size
    stride_2d = pool2d.stride
    padding_2d = pool2d.padding
    dilation_2d = pool2d.dilation if hasattr(pool2d, 'dilation') else 1 # MaxPool2d has dilation
    return_indices = pool2d.return_indices if hasattr(pool2d, 'return_indices') else False # MaxPool2d has this
    ceil_mode = pool2d.ceil_mode if hasattr(pool2d, 'ceil_mode') else False # Pool2d has this

    # Ensure kernel_size, stride, padding are tuples for consistency
    if not isinstance(kernel_size_2d, tuple): kernel_size_2d = (kernel_size_2d,) * 2
    if not isinstance(stride_2d, tuple): stride_2d = (stride_2d,) * 2
    if not isinstance(padding_2d, tuple): padding_2d = (padding_2d,) * 2
    if not isinstance(dilation_2d, tuple): dilation_2d = (dilation_2d,) * 2


    # Create 3D pooling parameters
    # Temporal kernel size (1 for no pooling in time by the pool layer itself, temporal_stride handles downsampling)
    kernel_size_3d = (1, kernel_size_2d[0], kernel_size_2d[1])
    # Temporal stride for downsampling
    stride_3d = (temporal_stride, stride_2d[0], stride_2d[1])
    # Temporal padding (0 as we don't pool over time)
    padding_3d = (0, padding_2d[0], padding_2d[1])
    dilation_3d = (1, dilation_2d[0], dilation_2d[1])

    if isinstance(pool2d, nn.MaxPool2d):
        return nn.MaxPool3d(kernel_size_3d, stride=stride_3d, padding=padding_3d,
                            dilation=dilation_3d, return_indices=return_indices, ceil_mode=ceil_mode)
    elif isinstance(pool2d, nn.AvgPool2d):
        # AvgPool2d also has count_include_pad attribute
        count_include_pad = pool2d.count_include_pad if hasattr(pool2d, 'count_include_pad') else True
        return nn.AvgPool3d(kernel_size_3d, stride=stride_3d, padding=padding_3d, ceil_mode=ceil_mode,
                            count_include_pad=count_include_pad) # Count include pad might need adjustment for 3D?
    else:
        raise ValueError(f"Unsupported pooling type: {type(pool2d)}")

# --- Inflated DenseNet Components ---

class InflatedDenseLayer(nn.Module):
    def __init__(self, dense_layer2d, conv_kernel_depth=3):
        super(InflatedDenseLayer, self).__init__()
        self.layers = nn.Sequential()
        for name, child in dense_layer2d.named_children():
            if isinstance(child, nn.BatchNorm2d):
                self.layers.add_module(name, inflate_batch_norm(child))
            elif isinstance(child, nn.ReLU):
                self.layers.add_module(name, inflate_relu(child))
            elif isinstance(child, nn.Conv2d):
                # For the bottleneck 1x1 conv, kernel_depth is 1
                # For the 3x3 conv, kernel_depth is user-specified (default 3)
                self.layers.add_module(name, inflate_conv(child, kernel_depth=child.kernel_size[0] if child.kernel_size != (1,1) else 1))

            else:
                raise ValueError(f"Unsupported layer type in DenseLayer: {type(child)}")
        # DenseLayer also has a drop_rate attribute
        self.drop_rate = dense_layer2d.drop_rate

    def forward(self, x):
        new_features = self.layers(x)
        if self.drop_rate > 0 and self.training: # Apply dropout only during training
            new_features = nn.functional.dropout(new_features, p=self.drop_rate, training=self.training)
        # DenseNet connectivity: concatenate input with new features
        return torch.cat([x, new_features], 1)

class InflatedTransition(nn.Module):
    def __init__(self, transition2d, temporal_pool_stride=2):
        super(InflatedTransition, self).__init__()
        self.layers = nn.Sequential()
        for name, child in transition2d.named_children():
            if isinstance(child, nn.BatchNorm2d):
                self.layers.add_module(name, inflate_batch_norm(child))
            elif isinstance(child, nn.ReLU):
                self.layers.add_module(name, inflate_relu(child))
            elif isinstance(child, nn.Conv2d): # This is the 1x1 convolution
                 self.layers.add_module(name, inflate_conv(child, kernel_depth=1))
            elif isinstance(child, nn.AvgPool2d): # This is the pooling layer for downsampling
                self.layers.add_module(name, inflate_pool(child, temporal_stride=temporal_pool_stride))
            else:
                raise ValueError(f"Unsupported layer type in Transition: {type(child)}")

    def forward(self, x):
        return self.layers(x)

# --- Main Inflation Function ---

def inflate_densenet121(densenet2d, frame_nb, conv_kernel_depth=3, temporal_pool_stride=2,input_channels=1):
    """
    Inflates a torchvision DenseNet121 model to a 3D model.

    Args:
        densenet2d (torchvision.models.densenet.DenseNet): The pre-trained 2D DenseNet121 model.
        frame_nb (int): The expected number of frames in the input video.
        conv_kernel_depth (int): The temporal kernel size to use for inflating 2D convs > 1x1.
                                 Default is 3 (3x3x3).
        temporal_pool_stride (int): The temporal stride to use for inflating spatial pooling
                                    layers in transition blocks. Default is 2.

    Returns:
        torch.nn.Module: The inflated 3D DenseNet model.
    """
    # Inflate the features part (contains initial conv, pool, dense blocks, transitions)
    features_3d = nn.Sequential()
    transition_nb = 0
    # Flag to identify the very first convolutional layer
    is_first_conv_layer = True
    for name, child in densenet2d.features.named_children():
        if isinstance(child, nn.Conv2d): # Initial Conv2d
            if is_first_conv_layer:
                 # --- SPECIAL HANDLING FOR THE VERY FIRST CONV LAYER ---
                 # The first Conv2d in DenseNet121 has in_channels=3.
                 # We need to change its in_channels to 'input_channels' (e.g., 1)
                 # and adapt its weights.

                 original_conv2d = child
                 original_in_channels = original_conv2d.in_channels # This is 3
                 if original_in_channels != 3:
                     # This inflation logic assumes the original model starts with 3 channels
                     warnings.warn(f"Expected first Conv2d to have 3 input channels, but got {original_in_channels}. Inflation logic might be incorrect.")

                 # Get original parameters except in_channels
                 out_channels = original_conv2d.out_channels
                 kernel_size_2d = original_conv2d.kernel_size
                 stride_2d = original_conv2d.stride
                 padding_2d = original_conv2d.padding
                 dilation_2d = original_conv2d.dilation
                 groups = original_conv2d.groups
                 bias = original_conv2d.bias is not None

                 # Create the new Conv3d with the DESIRED input_channels
                 # The temporal kernel size for the first layer is typically the spatial kernel size (7x7 -> 7x7x7)
                 kernel_depth_3d = kernel_size_2d[0] # Use spatial kernel size for time

                 kernel_size_3d = (kernel_depth_3d, kernel_size_2d[0], kernel_size_2d[1])
                 stride_3d = (1, stride_2d[0], stride_2d[1]) # Temporal stride 1
                 padding_3d = (kernel_depth_3d // 2, padding_2d[0], padding_2d[1])
                 dilation_3d = (1, dilation_2d[0], dilation_2d[1])

                 first_conv3d = nn.Conv3d(input_channels, out_channels, kernel_size_3d, stride_3d, padding_3d,
                                          dilation=dilation_3d, groups=groups, bias=bias)

                 # Inflate weights: Original weights shape [out_channels, 3, H, W]
                 original_weights = original_conv2d.weight.data

                 if input_channels == 1 and original_in_channels == 3:
                     # Average the original 3 input channel weights
                     adapted_weights_2d = original_weights.mean(dim=1, keepdim=True) # Shape [out_channels, 1, H, W]
                 elif input_channels == original_in_channels:
                     # Input channels match, just use original weights
                     adapted_weights_2d = original_weights
                 else:
                     # Handle other input channel numbers if needed (e.g., randomly initialize new weights)
                     warnings.warn(f"Handling inflation from {original_in_channels} to {input_channels} channels. "
                                   "Using averaged weights if original=3, otherwise weights might need custom handling.")
                     if original_in_channels == 3:
                         adapted_weights_2d = original_weights.mean(dim=1, keepdim=True).repeat(1, input_channels, 1, 1) # Repeat averaged weight
                     else:
                         # Fallback: Random initialization for the new layer if sizes don't match
                         print(f"Initializing weights for first Conv3d ({original_in_channels} -> {input_channels}) randomly.")
                         # The first_conv3d layer is already initialized randomly by default, so nothing more to do here.
                         # We can just skip the weight copying step.
                         adapted_weights_2d = None # Indicate no specific weights to copy


                 if adapted_weights_2d is not None:
                    # Repeat the adapted 2D weights along the temporal dimension
                    # Shape [out_channels, input_channels, 1, H, W] -> [out_channels, input_channels, D, H, W]
                    inflated_weights = adapted_weights_2d.unsqueeze(2).repeat(1, 1, kernel_depth_3d, 1, 1)

                    # Normalize by dividing by the depth
                    if kernel_depth_3d > 0:
                         inflated_weights = inflated_weights / kernel_depth_3d

                    # Copy inflated weights to the new Conv3d
                    first_conv3d.weight.data.copy_(inflated_weights)

                 if bias:
                     first_conv3d.bias.data.copy_(original_conv2d.bias.data)

                 features_3d.add_module(name, first_conv3d)
                 is_first_conv_layer = False # Mark that the first conv is processed

            else:
                 # --- STANDARD INFLATION FOR SUBSEQUENT CONV LAYERS ---
                 # These layers should use the standard inflate_conv logic
                 # Their in_channels will match the out_channels of the preceding 3D layer.
                 # We still need to handle the kernel depth for non-1x1 convs.
                 temporal_k_depth = child.kernel_size[0] if child.kernel_size != (1,1) else 1
                 features_3d.add_module(name, inflate_conv_standard(child, kernel_depth=temporal_k_depth))

             # The initial convolution in DenseNet121 is 7x7. Inflate it.
            
        elif isinstance(child, nn.BatchNorm2d): # Initial BatchNorm
             features_3d.add_module(name, inflate_batch_norm(child))
        elif isinstance(child, nn.ReLU): # Initial ReLU
             features_3d.add_module(name, inflate_relu(child))
        elif isinstance(child, nn.MaxPool2d): # Initial MaxPooling
             # Initial pool typically reduces spatial dimensions but not temporal
             # We'll make the temporal stride 1 here
             features_3d.add_module(name, inflate_pool(child, temporal_stride=1))
        elif isinstance(child, models.densenet._DenseBlock):
             # Inflate the DenseBlock
             block_3d = nn.Sequential()
             for nested_name, nested_child in child.named_children():
                 # Each child in a DenseBlock is a DenseLayer
                 assert isinstance(nested_child, models.densenet._DenseLayer)
                 block_3d.add_module(nested_name, InflatedDenseLayer(nested_child, conv_kernel_depth=conv_kernel_depth))
             features_3d.add_module(name, block_3d)
        elif isinstance(child, models.densenet._Transition):
             # Inflate the Transition layer
             features_3d.add_module(name, InflatedTransition(child, temporal_pool_stride=temporal_pool_stride))
             transition_nb += 1
        else:
            # print(f"Warning: Skipping unhandled layer type in features: {name} ({type(child)})")
            pass # Skip layers like OrderedDictWrapper if they appear

    # Calculate the final temporal dimension
    # Assumes each transition block reduces temporal dimension by temporal_pool_stride
    temporal_reduction_factor = int(math.pow(temporal_pool_stride, transition_nb))
    final_time_dim = frame_nb // temporal_reduction_factor
    if frame_nb % temporal_reduction_factor != 0:
         warnings.warn(f"Input frame_nb ({frame_nb}) is not perfectly divisible by temporal reduction factor ({temporal_reduction_factor}). "
                       "Final temporal dimension will be floor division result.")


    # Inflate the classifier part
    classifier_3d = nn.Sequential()
    for name, child in densenet2d.classifier.named_children():
        if isinstance(child, nn.Linear):
            # The linear layer input size needs to account for the flattened 3D features
            # It's final_time_dim * original_classifier_in_features
            original_in_features = child.in_features
            inflated_in_features = final_time_dim * original_in_features

            # Create new 3D linear layer
            linear3d = nn.Linear(inflated_in_features, child.out_features, bias=child.bias is not None)

            # Inflate weights
            linear2d_weights = child.weight.data # Shape [out_features, in_features_2d]

            # Reshape 2D weights to [out_features, in_features_2d, 1] (add a temporal dimension)
            inflated_weights = linear2d_weights.unsqueeze(2)
            # Repeat along the new temporal dimension
            inflated_weights = inflated_weights.repeat(1, 1, final_time_dim)
            # Reshape to match the 3D linear layer's expected shape [out_features, inflated_in_features]
            inflated_weights = inflated_weights.view(child.out_features, inflated_in_features)

            linear3d.weight.data.copy_(inflated_weights)

            # Copy bias if it exists
            if child.bias is not None:
                linear3d.bias.data.copy_(child.bias.data)

            classifier_3d.add_module(name, linear3d)

        else:
            # print(f"Warning: Skipping unhandled layer type in classifier: {name} ({type(child)})")
            pass # DenseNet's classifier is typically just a Linear layer

    # Combine features and classifier into a new Sequential model
    # We need to manually handle the final pooling and flatten step in the forward pass
    # like in the previous I3D code, as Sequential doesn't do this automatically.
    # So we'll create a custom module for the final model.

    class InflatedDenseNetModel(nn.Module):
        def __init__(self, features_3d, classifier_3d, final_time_dim, final_layer_nb):
            super().__init__()
            self.features = features_3d
            self.classifier = classifier_3d
            self.final_time_dim = final_time_dim
            self.final_layer_nb = final_layer_nb # This is the number of channels before global pooling

        def forward(self, x):
            x = self.features(x)
            # Apply ReLU after features (matches typical DenseNet flow before classification)
            x = nn.functional.relu(x)
            # Global spatial average pooling. Kernel size matches expected output spatial dims.
            # Assuming 7x7 spatial dims before classification for DenseNet121 on ImageNet size inputs
            spatial_kernel_h = x.shape[-2]
            spatial_kernel_w = x.shape[-1]
            x = nn.functional.avg_pool3d(x, kernel_size=(1, spatial_kernel_h, spatial_kernel_w))
            # Flatten for classifier
            # Original shape: [batch, channels, depth, 1, 1] after spatial pooling
            # Permute to [batch, depth, channels, 1, 1]
            x = x.permute(0, 2, 1, 3, 4).contiguous()
            # View to [batch, depth * channels]
            x = x.view(-1, self.final_time_dim * self.final_layer_nb)
            # Pass through classifier
            x = self.classifier(x)
            return x

    # Get the number of channels before the original 2D classifier
    final_layer_nb = densenet2d.classifier.in_features

    return InflatedDenseNetModel(features_3d, classifier_3d, final_time_dim, final_layer_nb)


# --- Testing the Inflation ---

import warnings

# a. Take a 2D pretrained DenseNet121 model
model_2d = models.densenet121(pretrained=True)
print("Loaded 2D DenseNet121 model.")

# Expected number of frames in your video input
input_frame_nb = 16
# Temporal kernel depth for inflating 3x3 convs
conv_k_depth = 3
# Temporal stride for pooling in transition blocks
pool_t_stride = 2

# Inflate the model
i3d_densenet_model = inflate_densenet121(
    model_2d,
    frame_nb=input_frame_nb,
    conv_kernel_depth=conv_k_depth,
    temporal_pool_stride=pool_t_stride
)

print(f"\nInflated 3D DenseNet model created for {input_frame_nb} frames.")
# Print the structure of the inflated model (optional, can be long)
# print(i3d_densenet_model)


# Create a dummy 3D input tensor
# Shape: [batch_size, channels, depth, height, width]
batch_size = 2
channels = 1
input_height = 224
input_width = 224

dummy_input_3d = torch.randn(batch_size, channels, input_frame_nb, input_height, input_width)
print("\nDummy input shape:", dummy_input_3d.shape)


# Perform a forward pass with the dummy input
try:
    i3d_densenet_model.eval() # Set model to evaluation mode (disables dropout, uses running stats for BN)
    with torch.no_grad(): # Disable gradient calculation for inference
        output = i3d_densenet_model(dummy_input_3d)

    print("\nForward pass successful!")
    print("Output shape:", output.shape)
    # The output shape should be [batch_size, num_classes]
    # For DenseNet121 pre-trained on ImageNet, num_classes is 1000
    print("Expected output shape:", (batch_size, 1000))

except Exception as e:
    print(f"\nForward pass failed: {e}")

Loaded 2D DenseNet121 model.

Inflated 3D DenseNet model created for 16 frames.

Dummy input shape: torch.Size([2, 1, 16, 224, 224])

Forward pass successful!
Output shape: torch.Size([2, 2048])
Expected output shape: (2, 1000)


NameError: name 'densenet2d' is not defined

## Task3

In [None]:
import torch
import torch.nn as nn
import nibabel as nib  # For loading NIfTI images
import numpy as np
# Assuming your previous code is in a file named 'i3d_densenet_inflation.py'
# from i3d_densenet_inflation import inflate_densenet121
import torchvision.models as models # Needed to load the 2D model


# --- (Include the inflate_densenet121 function and its helpers from previous responses) ---
# You should copy and paste the inflate_conv, inflate_batch_norm, inflate_relu,
# inflate_pool, InflatedDenseLayer, InflatedTransition, and inflate_densenet121
# definitions here or import them from your file.
# For brevity, I'll assume they are defined above this code block.

# --- Helper function to find specific convolution layers ---

def find_convolution_layers(module):
    """Recursively finds all Conv3d layers within a PyTorch module."""
    conv_layers = []
    for name, child in module.named_children():
        if isinstance(child, nn.Conv3d):
            conv_layers.append((name, child))
        else:
            # Recurse into submodules
            conv_layers.extend(find_convolution_layers(child))
    return conv_layers

# --- Main Feature Extraction Function ---

def extract_features_from_volume(model_3d, volume_path, frame_nb, target_layers=['last', 'third_last', 'fifth_last']):
    """
    Loads a NIfTI volume, runs it through the 3D CNN, and extracts feature maps
    from specified convolution layers, applying Global Average Pooling.

    Args:
        model_3d (torch.nn.Module): The inflated 3D DenseNet model.
        volume_path (str): Path to the NIfTI file (.nii or .nii.gz).
        frame_nb (int): The expected number of frames/slices in the input volume.
                         This should match the `frame_nb` used during inflation.
        target_layers (list): List of strings specifying which convolution layers
                              to extract features from. Options: 'last', 'second_last',
                              'third_last', etc., or a specific layer name (though names
                              can be complex in nested structures).

    Returns:
        dict: A dictionary where keys are the target layer names/indices
              and values are the corresponding N-dimensional feature vectors
              after Global Average Pooling. Returns an empty dictionary if no
              valid layers are specified or found.
    Raises:
        FileNotFoundError: If the volume_path does not exist.
        ValueError: If the volume's temporal dimension does not match frame_nb
                    or if specified target layers are not found.
        RuntimeError: If a forward pass fails.
    """
    # 1. Load NIfTI image
    if not os.path.exists(volume_path):
        raise FileNotFoundError(f"Volume not found at: {volume_path}")

    img = nib.load(volume_path)
    data = img.get_fdata()
    # Ensure data is float32, which is typical for model inputs
    data = data.astype(np.float32)

    # NIfTI data often has shape [depth, height, width] or similar spatial dims
    # We need to prepare it for the PyTorch model: [batch_size, channels, depth, height, width]
    # Assuming your NIfTI is grayscale (1 channel) or already has channel dim
    # If grayscale [depth, height, width], add a channel dimension at position 0 or 1
    # If it's [depth, height, width] and represents time series, depth is time.
    # Let's assume input NIfTI is [depth, height, width] (spatial) and we add a channel and batch dim.
    # Or, if it's a video/sequence [time, depth, height, width], we need to reorder.
    # Based on the model's expected input [batch_size, channels, depth, height, width],
    # let's assume the NIfTI data is [depth, height, width] and represents *one time step*,
    # and you are providing multiple volumes or slices that form the 'frame_nb' dimension.
    # A more realistic approach for medical imaging might be [slices, height, width]
    # where 'slices' corresponds to the depth dimension in the 3D convolution.
    # Let's assume the NIfTI is [depth, height, width] and treat 'depth' as the temporal dim.
    # We need to ensure it matches the expected `frame_nb`.

    # Adjust this reshaping based on your actual NIfTI data format and how you want to map it to [channels, depth, height, width]
    # Assuming NIfTI is [Depth, Height, Width] -> map to [1, Depth, Height, Width] (1 channel grayscale)
    # Then add batch dim -> [1, 1, Depth, Height, Width]
    # And finally ensure Depth matches frame_nb.
    # A common format might be [Depth, Height, Width, Channels] for color or multi-spectral.
    # Let's assume input NIfTI data has shape [SpatialDepth, SpatialHeight, SpatialWidth]
    # and you want to stack `frame_nb` such volumes or slices along the first dim to create the temporal dimension.
    # Or, maybe your NIfTI itself is [Time, SpatialDepth, SpatialHeight, SpatialWidth].

    # Let's simplify and assume the loaded 'data' from NIfTI is already in
    # a format where you can directly extract `frame_nb` slices/volumes that
    # when stacked or selected, form the input `[channels, depth, height, width]` tensor.
    # For a single 3D volume treated as the 'depth' dimension in the CNN:
    # Assume data is [Depth, Height, Width]. We want [1, Depth, Height, Width] for 1 channel.
    # The model expects [batch_size, channels, depth, height, width]
    # So, input_tensor will be [1, 1, Depth, Height, Width] where Depth = frame_nb

    # Let's refine the input tensor preparation:
    # Assuming your NIfTI data is a single 3D volume of shape [D, H, W]
    # And you want to treat D as the 'depth' dimension for the 3D CNN.
    # Your CNN expects [batch_size, channels, depth, height, width].
    # So, for a single volume, batch_size=1, channels=1 (grayscale), depth=D, height=H, width=W.
    # The `frame_nb` passed to this function should be D.

    input_data_shape = data.shape
    if len(input_data_shape) != 3:
        # Handle potential 4D NIfTI data if needed (e.g., [Depth, H, W, Channels] or [Time, D, H, W])
        # This part requires knowing your specific NIfTI format
        raise ValueError(f"Expected 3D NIfTI data [D, H, W], but got shape {input_data_shape}. "
                         "Adjust input processing based on your NIfTI format.")

    current_volume_depth = input_data_shape[0] # Assuming depth is the first dim

    if current_volume_depth != frame_nb:
         warnings.warn(f"NIfTI volume depth ({current_volume_depth}) does not match frame_nb ({frame_nb}) "
                       "used during model inflation. This might cause issues or requires padding/cropping.")
         # You might need padding or cropping here if the volume depth doesn't match frame_nb
         # For simplicity, we'll proceed but be aware this is a potential problem.
         # If your model expects a fixed frame_nb, you *must* match it.
         # If your model can handle variable temporal length (less common for fixed-size inputs),
         # you need to adjust the model's final layer logic.

    # Add channel and batch dimensions
    # data shape [D, H, W] -> tensor shape [1, 1, D, H, W]
    input_tensor = torch.from_numpy(data).unsqueeze(0).unsqueeze(0) # Add channel and batch dim

    print(f"Loaded volume shape: {input_data_shape}, Prepared tensor shape: {input_tensor.shape}")


    # Find all Conv3d layers and their names
    all_conv_layers = find_convolution_layers(model_3d)
    if not all_conv_layers:
        raise ValueError("No Conv3d layers found in the model's features.")

    # Map target layer names/indices to actual layers
    layers_to_extract = {}
    num_conv_layers = len(all_conv_layers)

    for target in target_layers:
        if target == 'last':
            if num_conv_layers >= 1:
                layers_to_extract['last'] = all_conv_layers[-1]
            else:
                warnings.warn("Target 'last' specified, but no Conv3d layers found.")
        elif target.endswith('_last'):
            try:
                index = int(target.split('_')[0])
                if index > 0 and num_conv_layers >= index:
                    layers_to_extract[target] = all_conv_layers[-index]
                else:
                    warnings.warn(f"Target '{target}' specified, but not enough Conv3d layers ({num_conv_layers}) found.")
            except ValueError:
                warnings.warn(f"Invalid target layer format: {target}. Use 'last', 'second_last', etc.")
        # Add support for specific layer names if needed
        # elif isinstance(target, str):
        #    found = False
        #    for layer_name, layer_module in all_conv_layers:
        #        if layer_name == target:
        #             layers_to_extract[target] = (layer_name, layer_module)
        #             found = True
        #             break
        #    if not found:
        #        warnings.warn(f"Target layer '{target}' not found by name.")
        else:
            warnings.warn(f"Invalid target layer specification: {target}.")

    if not layers_to_extract:
        warnings.warn("No valid target layers specified or found for feature extraction.")
        return {}

    # Hook to capture intermediate feature maps
    feature_maps = {}
    hooks = []

    def hook_fn(module, input, output):
        # We identify the layer by its module object reference
        layer_info = None
        for name, layer_module in all_conv_layers:
             if layer_module is module:
                 layer_info = (name, layer_module)
                 break
        if layer_info:
             # Store the output tensor. Clone it to avoid issues if the tensor is modified later.
             feature_maps[layer_info] = output.clone()

    # Register hooks on the target layers
    for _, (layer_name, layer_module) in layers_to_extract.items():
         # Find the actual module instance in the model's features sequential
         # This part can be tricky with deeply nested Sequential/ModuleList structures.
         # A simpler approach is to use the layer_module object directly from all_conv_layers
         # and register the hook on that specific instance.
         hooks.append(layer_module.register_forward_hook(hook_fn))


    # Set model to evaluation mode and disable gradients
    model_3d.eval()
    with torch.no_grad():
        # Run the forward pass
        # The forward pass will execute the hooks when the target layers are reached.
        try:
            # We run the full forward pass, but only care about the captured feature maps
            # The final output of the model is not needed for feature extraction
            _ = model_3d(input_tensor)
        except Exception as e:
            # Clean up hooks even if forward pass fails
            for h in hooks:
                h.remove()
            raise RuntimeError(f"Forward pass failed during feature extraction: {e}")


    # Clean up the hooks
    for h in hooks:
        h.remove()

    # 3. Apply Global Average Pooling (GAP) to extracted feature maps
    extracted_features = {}
    for (layer_name, layer_module), feature_map in feature_maps.items():
        # Feature map shape is [batch_size, channels, depth, height, width]
        # GAP should pool over depth, height, and width dimensions
        # For a single volume, batch_size is 1.
        # Output shape after GAP should be [batch_size, channels]

        # Ensure feature_map has at least 3 spatial/temporal dimensions to pool over
        if feature_map.dim() < 5:
             warnings.warn(f"Feature map for layer {layer_name} has unexpected dimensions {feature_map.shape}. Skipping GAP.")
             continue # Skip GAP if shape is not as expected

        # Apply 3D Global Average Pooling
        # Reduce dimensions 2, 3, and 4 (depth, height, width) to size 1
        pooled_features = nn.functional.avg_pool3d(feature_map, kernel_size=feature_map.size()[2:])

        # The shape is now [batch_size, channels, 1, 1, 1]. Squeeze out the singleton dimensions.
        feature_vector = pooled_features.squeeze(-1).squeeze(-1).squeeze(-1) # Shape [batch_size, channels]

        # For a single volume (batch_size=1), the shape is [1, channels]. Squeeze the batch dim.
        if feature_vector.shape[0] == 1:
             feature_vector = feature_vector.squeeze(0) # Shape [channels]

        extracted_features[(layer_name, type(layer_module).__name__)] = feature_vector


    return extracted_features

# --- Example Usage ---

import os # Needed for os.path.exists

if __name__ == "__main__":
    # Assuming your NIfTI files are organized by region
    nifti_dir = 'segmented_regions' # Change this to your directory
    regions = ['Tibia', 'Femur', 'Background'] # Or whatever your regions are

    # --- Load and Inflate the Model ---
    try:
        model_2d = models.densenet121(pretrained=True)
        input_frame_nb_for_inflation = 64 # Example: Inflate for volumes with 64 slices/frames
        # Make sure this matches the actual depth of your NIfTI volumes
        i3d_densenet_model = inflate_densenet121(model_2d, frame_nb=input_frame_nb_for_inflation)
        print("Model loaded and inflated successfully.")
    except Exception as e:
        print(f"Error loading or inflating model: {e}")
        exit() # Exit if model setup fails




    # Iterate through regions and potentially multiple volumes per region
    for region in regions:
        print(f"\nProcessing region: {region}")
        region_features = []
        #region_dir = os.path.join(nifti_dir, region)

        # Assuming NIfTI files are directly in the region directory
        volume_files = [f for f in os.listdir(nifti_dir) if f.endswith('.nii') or f.endswith('.nii.gz')]

        if not volume_files:
            print(f"No NIfTI files found in {region_dir}. Skipping.")
            continue

        for volume_file in volume_files:
            volume_path = os.path.join(nifti_dir, volume_file)
            print(f"  Processing volume: {volume_file}")

            try:
                # Define the layers you want to extract from.
                # This depends on the specific structure of the inflated DenseNet.
                # You might need to inspect the printed model structure to figure out which
                # indices correspond to the last, third-last, fifth-last *convolution* layers.
                # The 'last', 'third_last', etc. logic in find_convolution_layers handles this based on indices.
                layers_to_get = ['last', 'third_last', 'fifth_last'] # Example targets

                extracted_feats = extract_features_from_volume(
                    i3d_densenet_model,
                    volume_path,
                    frame_nb=input_frame_nb_for_inflation, # Must match inflation frame_nb
                    target_layers=layers_to_get
                )
                region_features.append((volume_file, extracted_feats))
                print(f"    Extracted features for layers: {list(extracted_feats.keys())}")
                # Example: print the shape of a feature vector
                # if extracted_feats:
                #      first_key = list(extracted_feats.keys())[0]
                #      print(f"    Shape of features from {first_key}: {extracted_feats[first_key].shape}")


            except (FileNotFoundError, ValueError, RuntimeError) as e:
                print(f"    Error processing volume {volume_file}: {e}")
                continue # Continue to the next volume even if one fails

        all_extracted_features[region] = region_features

    print("\n--- Feature Extraction Complete ---")
    # all_extracted_features is a dictionary:
    # {
    #   'Tibia': [(volume_file1, {layer1: features, layer2: features, ...}), (volume_file2, {...}), ...],
    #   'Femur': [...],
    #   'Background': [...]
    # }

    # You can now use the 'all_extracted_features' dictionary for downstream tasks
    # (e.g., training a classifier on these feature vectors)
    # Example: Access features for the first volume of the Tibia region, from the 'last' convolution layer
    # if 'Tibia' in all_extracted_features and all_extracted_features['Tibia']:
    #     first_tibia_volume_features = all_extracted_features['Tibia'][0][1]
    #     if ('last', 'Conv3d') in first_tibia_volume_features: # Check if 'last' was extracted and is Conv3d
    #          last_conv_features = first_tibia_volume_features[('last', 'Conv3d')]
    #          print("\nExample: Features from last Conv3d of first Tibia volume:")
    #          print(last_conv_features)
    #          print("Shape:", last_conv_features.shape)

Model loaded and inflated successfully.

Processing region: Tibia
  Processing volume: tibia.nii.gz




Loaded volume shape: (512, 512, 216), Prepared tensor shape: torch.Size([1, 1, 512, 512, 216])


Model loaded and inflated successfully.
  Loaded original volume data with shape: (512, 512, 216)
      Processing region: tibia
        Preprocessed tensor shape: torch.Size([1, 1, 16, 101, 216])
  Error processing volume:       Forward pass failed for region tibia: Given groups=1, weight of size [64, 3, 7, 7, 7], expected input[1, 1, 16, 101, 216] to have 3 channels, but got 1 channels instead

--- Region Feature Extraction Complete ---




In [27]:
region_features

[('Background_sample_volume.nii.gz',
  {('conv2',
    'Conv3d'): tensor([ 0.0051,  0.0122, -0.0055, -0.0065,  0.0012, -0.0178, -0.0320, -0.0056,
           -0.0361, -0.0274, -0.0244, -0.0153, -0.0003, -0.0026, -0.0136,  0.0120,
            0.0023,  0.0165, -0.0097, -0.0062, -0.0051,  0.0130, -0.0266,  0.0109,
           -0.0180, -0.0190, -0.0017,  0.0228, -0.0170,  0.0546,  0.0062,  0.0015])})]