In [21]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import nibabel as nib
from scipy.ndimage import zoom
from glob import glob
from tqdm import tqdm
from typing import Optional

import torch
import torch.nn as nn
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

import timm
from timm.models.layers import Conv2dSame

In [12]:
class Conv3dSame(nn.Conv3d):
    """ Tensorflow like 'SAME' convolution wrapper for 3D convolutions
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
    ):
        # Calculate padding for SAME behavior
        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        elif isinstance(kernel_size, (tuple, list)):
            padding = [(k - 1) // 2 for k in kernel_size]
        else:
            raise ValueError("kernel_size must be int or iterable of int")

        super(Conv3dSame, self).__init__(
            in_channels, out_channels, kernel_size,
            stride, padding, dilation, groups, bias,
        )

    def forward(self, x):
        return conv3d_same(
            x, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )

def conv3d_same(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    # Here you would need to implement or call the actual 3D convolution logic
    # with SAME padding, as PyTorch's native convolution does not have this padding mode.
    # This is a placeholder implementation and should be replaced with actual logic.
    conv = nn.functional.conv3d(
        x, weight, bias, stride, padding, dilation, groups
    )
    return conv

In [13]:
# Define paths
base_path = 'data'
segmentations_path = os.path.join(base_path, 'segmentations')
train_images_path1 = os.path.join(base_path, 'sfd', 'train_images')
train_images_path2 = os.path.join(base_path, 'train_images')
output_path = os.path.join(base_path, 'processed_3d')

backbone='resnet18d'
drop_rate = 0.
drop_path_rate = 0.
n_blocks = 4
out_dim = 7
loss_weights = [1, 1]

In [14]:
# Create output directory if it doesn't exist
os.makedirs(output_path, exist_ok=True)

In [15]:
def load_and_process_image_volume(study_path, target_size=(128, 128, 128)):
    # List and sort slice file paths
    t_paths = sorted(glob(os.path.join(study_path, "*")),
                     key=lambda x: int(os.path.basename(x).split('.')[0]))
    
    # Determine the number of scans and calculate quantile indices
    n_scans = len(t_paths)
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., target_size[2])).round().astype(int)
    t_paths = [t_paths[i] for i in indices]
    
    # Load and process slices
    slices = []
    for img_path in t_paths:
        img = Image.open(img_path)
        img = img.resize((target_size[0], target_size[1]))
        slices.append(np.array(img))
    
    # Stack slices into a 3D volume
    volume = np.stack(slices, axis=-1)
    
    # Normalize and scale the volume
    volume = volume - np.min(volume)
    volume = volume / (np.max(volume) + 1e-4)
    volume = (volume * 255).astype(np.uint8)
    
    return volume

def load_and_process_mask(mask_path, target_size=(128, 128, 128), num_classes=7):
    # Load the mask using nibabel
    mask_org = nib.load(mask_path).get_fdata()

    # Adjust mask orientation if needed
    mask_org = mask_org.transpose(1, 0, 2)[::-1, :, ::-1]  # Adjust orientation to (d, w, h)

    # Resize mask to target size
    if mask_org.shape != target_size:
        factors = [t / s for t, s in zip(target_size, mask_org.shape)]
        mask_org = zoom(mask_org, factors, order=0)  # Nearest-neighbor interpolation for masks

    # Create multi-channel mask
    mask = np.zeros((num_classes, target_size[0], target_size[1], target_size[2]))
    for cid in range(num_classes):
        mask[cid] = (mask_org == (cid + 1))

    # Convert mask to [0, 255] and return as uint8
    mask = mask.astype(np.uint8) * 255

    return mask

In [16]:
if not os.path.exists(output_path):
    # Get list of study IDs with segmentations
    segmentation_ids = [f.split('.')[:-1] for f in os.listdir(segmentations_path) if f.endswith('.nii')]

    for study_id_parts in tqdm(segmentation_ids, desc="Processing studies"):
        study_id = '.'.join(study_id_parts)
        # Check for study folder in both locations
        study_path = None
        if os.path.exists(os.path.join(train_images_path1, study_id)):
            study_path = os.path.join(train_images_path1, study_id)
        elif os.path.exists(os.path.join(train_images_path2, study_id)):
            study_path = os.path.join(train_images_path2, study_id)
        
        if study_path is None:
            print(f"Warning: No image folder found for study {study_id}")
            continue
        
        # Process image volume
        image_volume = load_and_process_image_volume(study_path)
        
        # Process mask
        mask_path = os.path.join(segmentations_path, f"{study_id}.nii")
        mask_volume = load_and_process_mask(mask_path)
        
        # Save processed data
        np.save(os.path.join(output_path, f"{study_id}_image.npy"), image_volume)
        np.save(os.path.join(output_path, f"{study_id}_mask.npy"), mask_volume)

    print("Processing complete!")

In [17]:
class TimmSegModel(nn.Module):
    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, 3, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = UnetDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features

In [18]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output


m = TimmSegModel(backbone)
m = convert_3d(m)
m(torch.rand(1, 3, 128,128,128)).shape

torch.Size([1, 7, 128, 128, 128])

In [22]:
def binary_dice_score(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    threshold: Optional[float] = None,
    nan_score_on_empty=False,
    eps: float = 1e-7,
) -> float:

    if threshold is not None:
        y_pred = (y_pred > threshold).to(y_true.dtype)

    intersection = torch.sum(y_pred * y_true).item()
    cardinality = (torch.sum(y_pred) + torch.sum(y_true)).item()

    score = (2.0 * intersection) / (cardinality + eps)

    has_targets = torch.sum(y_true) > 0
    has_predicted = torch.sum(y_pred) > 0

    if not has_targets:
        if nan_score_on_empty:
            score = np.nan
        else:
            score = float(not has_predicted)
    return score


def multilabel_dice_score(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    threshold=None,
    eps=1e-7,
    nan_score_on_empty=False,
):
    ious = []
    num_classes = y_pred.size(0)
    for class_index in range(num_classes):
        iou = binary_dice_score(
            y_pred=y_pred[class_index],
            y_true=y_true[class_index],
            threshold=threshold,
            nan_score_on_empty=nan_score_on_empty,
            eps=eps,
        )
        ious.append(iou)

    return ious


def dice_loss(input, target):
    input = torch.sigmoid(input)
    smooth = 1.0
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))


def bce_dice(input, target, loss_weights=loss_weights):
    loss1 = loss_weights[0] * nn.BCEWithLogitsLoss()(input, target)
    loss2 = loss_weights[1] * dice_loss(input, target)
    return (loss1 + loss2) / sum(loss_weights)

criterion = bce_dice

NameError: name 'loss_weights' is not defined