In [105]:
DEBUG = False

In [106]:
import os
import gc
import time
import random
import numpy as np
import pandas as pd
from PIL import Image
import nibabel as nib
from scipy.ndimage import zoom
from glob import glob
from tqdm import tqdm
from typing import Optional
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
from torch.utils.data import DataLoader, Dataset
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

from monai.transforms import Resize
import  monai.transforms as transforms

import timm
from timm.models.layers import Conv2dSame


device = torch.device('cuda')
torch.backends.cudnn.benchmark = True

In [107]:
class Conv3dSame(nn.Conv3d):
    """ Tensorflow like 'SAME' convolution wrapper for 3D convolutions
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
    ):
        # Calculate padding for SAME behavior
        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        elif isinstance(kernel_size, (tuple, list)):
            padding = [(k - 1) // 2 for k in kernel_size]
        else:
            raise ValueError("kernel_size must be int or iterable of int")

        super(Conv3dSame, self).__init__(
            in_channels, out_channels, kernel_size,
            stride, padding, dilation, groups, bias,
        )

    def forward(self, x):
        return conv3d_same(
            x, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )

def conv3d_same(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    # Here you would need to implement or call the actual 3D convolution logic
    # with SAME padding, as PyTorch's native convolution does not have this padding mode.
    # This is a placeholder implementation and should be replaced with actual logic.
    conv = nn.functional.conv3d(
        x, weight, bias, stride, padding, dilation, groups
    )
    return conv

In [108]:
# Define paths
base_path = 'data'
segmentations_path = os.path.join(base_path, 'segmentations')
train_images_path1 = os.path.join(base_path, 'sfd', 'train_images')
train_images_path2 = os.path.join(base_path, 'train_images')
output_path = os.path.join(base_path, 'processed_3d')

kernel_type = 'timm3d_res18d_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
load_kernel = None
load_last = True
n_blocks = 4
n_folds = 5
backbone = 'resnet18d'

image_sizes = [128, 128, 128]
R = Resize(image_sizes)

init_lr = 3e-3
batch_size = 4
drop_rate = 0.
drop_path_rate = 0.
loss_weights = [1, 1]
p_mixup = 0.1

data_dir = output_path
use_amp = True
num_workers = 4
out_dim = 7

n_epochs = 1000

log_dir = './logs'
model_dir = './models'
os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

In [109]:
# Create output directory if it doesn't exist
os.makedirs(output_path, exist_ok=True)

In [110]:
def load_and_process_image_volume(study_path, target_size=(128, 128, 128)):
    # List and sort slice file paths
    t_paths = sorted(glob(os.path.join(study_path, "*")),
                     key=lambda x: int(os.path.basename(x).split('.')[0]))
    
    # Determine the number of scans and calculate quantile indices
    n_scans = len(t_paths)
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., target_size[2])).round().astype(int)
    t_paths = [t_paths[i] for i in indices]
    
    # Load and process slices
    slices = []
    for img_path in t_paths:
        img = Image.open(img_path)
        img = img.resize((target_size[0], target_size[1]))
        slices.append(np.array(img))
    
    # Stack slices into a 3D volume
    volume = np.stack(slices, axis=-1)
    
    # Normalize and scale the volume
    volume = volume - np.min(volume)
    volume = volume / (np.max(volume) + 1e-4)
    volume = (volume * 255).astype(np.uint8)
    
    return volume

def load_and_process_mask(mask_path, target_size=(128, 128, 128), num_classes=7):
    # Load the mask using nibabel
    mask_org = nib.load(mask_path).get_fdata()

    # Adjust mask orientation if needed
    mask_org = mask_org.transpose(1, 0, 2)[::-1, :, ::-1]  # Adjust orientation to (d, w, h)

    # Resize mask to target size
    if mask_org.shape != target_size:
        factors = [t / s for t, s in zip(target_size, mask_org.shape)]
        mask_org = zoom(mask_org, factors, order=0)  # Nearest-neighbor interpolation for masks

    # Create multi-channel mask
    mask = np.zeros((num_classes, target_size[0], target_size[1], target_size[2]))
    for cid in range(num_classes):
        mask[cid] = (mask_org == (cid + 1))

    # Convert mask to [0, 255] and return as uint8
    mask = mask.astype(np.uint8) * 255

    return mask

In [111]:
if not os.path.exists(output_path):
    # Get list of study IDs with segmentations
    segmentation_ids = [f.split('.')[:-1] for f in os.listdir(segmentations_path) if f.endswith('.nii')]

    for study_id_parts in tqdm(segmentation_ids, desc="Processing studies"):
        study_id = '.'.join(study_id_parts)
        # Check for study folder in both locations
        study_path = None
        if os.path.exists(os.path.join(train_images_path1, study_id)):
            study_path = os.path.join(train_images_path1, study_id)
        elif os.path.exists(os.path.join(train_images_path2, study_id)):
            study_path = os.path.join(train_images_path2, study_id)
        
        if study_path is None:
            print(f"Warning: No image folder found for study {study_id}")
            continue
        
        # Process image volume
        image_volume = load_and_process_image_volume(study_path)
        
        # Process mask
        mask_path = os.path.join(segmentations_path, f"{study_id}.nii")
        mask_volume = load_and_process_mask(mask_path)
        
        # Save processed data
        np.save(os.path.join(output_path, f"{study_id}_image.npy"), image_volume)
        np.save(os.path.join(output_path, f"{study_id}_mask.npy"), mask_volume)

    print("Processing complete!")

In [112]:
# Define the data directory
data_dir_parent = 'data'

# Read the training data
df_train = pd.read_csv(os.path.join(data_dir_parent, 'train.csv'))

# Get the list of mask files
mask_files = [f for f in os.listdir(os.path.join(data_dir_parent, 'processed_3d')) if f.endswith('_mask.npy')]

# Create a DataFrame for mask files
df_mask = pd.DataFrame({
    'mask_file': mask_files,
})

# Extract StudyInstanceUID from mask filenames
df_mask['StudyInstanceUID'] = df_mask['mask_file'].apply(lambda x: x.split('_mask.npy')[0])

# Create full paths for mask files
df_mask['mask_file'] = df_mask['mask_file'].apply(lambda x: os.path.join(data_dir_parent, 'processed_3d', x))

# Merge training data with mask data
df = df_train.merge(df_mask, on='StudyInstanceUID', how='left')

# Function to determine the correct image folder
def get_image_folder(study_id):
    folder1 = os.path.join(data_dir_parent, 'train_images', study_id)
    folder2 = os.path.join(data_dir_parent, 'sfd', 'train_images', study_id)
    if os.path.exists(folder1):
        return folder1
    elif os.path.exists(folder2):
        return folder2
    else:
        return ''  # or return None, depending on how you want to handle missing folders

# Add image folder path
df['image_folder'] = df['StudyInstanceUID'].apply(get_image_folder)

# Fill NA values in mask_file column
df['mask_file'] = df['mask_file'].fillna('')

# Filter for segmentation data
df_seg = df[df['mask_file'] != ''].reset_index(drop=True)

# Perform K-Fold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
df_seg['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(kf.split(df_seg)):
    df_seg.loc[valid_idx, 'fold'] = fold

# Display the last few rows of the resulting DataFrame
print(df_seg.tail())

# Save the DataFrame if needed
# df_seg.to_csv(os.path.join(data_dir_parent, 'df_seg.csv'), index=False)

                                              slice_path  \
29827  ../input/rsna-2022-cervical-spine-fracture-det...   
29828  ../input/rsna-2022-cervical-spine-fracture-det...   
29829  ../input/rsna-2022-cervical-spine-fracture-det...   
29830  ../input/rsna-2022-cervical-spine-fracture-det...   
29831  ../input/rsna-2022-cervical-spine-fracture-det...   

                StudyInstanceUID  patient_overall  C1  C2  C3  C4  C5  C6  C7  \
29827  1.2.826.0.1.3680043.30524                1   0   0   0   0   0   1   1   
29828  1.2.826.0.1.3680043.30524                1   0   0   0   0   0   1   1   
29829  1.2.826.0.1.3680043.30524                1   0   0   0   0   0   1   1   
29830  1.2.826.0.1.3680043.30524                1   0   0   0   0   0   1   1   
29831  1.2.826.0.1.3680043.30524                1   0   0   0   0   0   1   1   

       slice_id  width  height  \
29827        26    512     512   
29828       185    512     512   
29829       221    512     512   
29830        13 

In [113]:
class SEGDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.file_list = [f.split('_')[0] for f in os.listdir(data_dir) if f.endswith('_image.npy')]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        file_id = self.file_list[index]
        
        image_file = os.path.join(self.data_dir, f'{file_id}_image.npy')
        mask_file = os.path.join(self.data_dir, f'{file_id}_mask.npy')
        
        image = np.load(image_file).astype(np.float32)
        mask = np.load(mask_file).astype(np.float32)

        if self.transform:
            res = self.transform({'image': image, 'mask': mask})
            image = res['image']
            mask = res['mask']

        image = image / 255.
        mask = (mask > 127).astype(np.float32)

        image = torch.tensor(image).float()
        mask = torch.tensor(mask).float()

        return image, mask

In [114]:
transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=2),
    transforms.RandAffined(keys=["image", "mask"], translate_range=[int(x*y) for x, y in zip(image_sizes, [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

transforms_valid = transforms.Compose([
])

In [115]:
class TimmSegModel(nn.Module):
    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, 3, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = UnetDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features

In [116]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [117]:
def binary_dice_score(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    threshold: Optional[float] = None,
    nan_score_on_empty=False,
    eps: float = 1e-7,
) -> float:

    if threshold is not None:
        y_pred = (y_pred > threshold).to(y_true.dtype)

    intersection = torch.sum(y_pred * y_true).item()
    cardinality = (torch.sum(y_pred) + torch.sum(y_true)).item()

    score = (2.0 * intersection) / (cardinality + eps)

    has_targets = torch.sum(y_true) > 0
    has_predicted = torch.sum(y_pred) > 0

    if not has_targets:
        if nan_score_on_empty:
            score = np.nan
        else:
            score = float(not has_predicted)
    return score


def multilabel_dice_score(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    threshold=None,
    eps=1e-7,
    nan_score_on_empty=False,
):
    ious = []
    num_classes = y_pred.size(0)
    for class_index in range(num_classes):
        iou = binary_dice_score(
            y_pred=y_pred[class_index],
            y_true=y_true[class_index],
            threshold=threshold,
            nan_score_on_empty=nan_score_on_empty,
            eps=eps,
        )
        ious.append(iou)

    return ious


def dice_loss(input, target):
    input = torch.sigmoid(input)
    smooth = 1.0
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))


def bce_dice(input, target, loss_weights=loss_weights):
    loss1 = loss_weights[0] * nn.BCEWithLogitsLoss()(input, target)
    loss2 = loss_weights[1] * dice_loss(input, target)
    return (loss1 + loss2) / sum(loss_weights)

criterion = bce_dice

In [118]:
def mixup(input, truth, clip=[0, 1]):
    indices = torch.randperm(input.size(0))
    shuffled_input = input[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    input = input * lam + shuffled_input * (1 - lam)
    return input, truth, shuffled_labels, lam


def train_func(model, loader_train, optimizer, scaler=None):
    model.train()
    train_loss = []
    bar = tqdm(loader_train)
    for images, gt_masks in bar:
        optimizer.zero_grad()
        images = images.cuda()
        gt_masks = gt_masks.cuda()

        do_mixup = False
        if random.random() < p_mixup:
            do_mixup = True
            images, gt_masks, gt_masks_sfl, lam = mixup(images, gt_masks)

        with amp.autocast():
            logits = model(images)
            loss = criterion(logits, gt_masks)
            if do_mixup:
                loss2 = criterion(logits, gt_masks_sfl)
                loss = loss * lam  + loss2 * (1 - lam)

        train_loss.append(loss.item())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bar.set_description(f'smth:{np.mean(train_loss[-30:]):.4f}')

    return np.mean(train_loss)


def valid_func(model, loader_valid):
    model.eval()
    valid_loss = []
    outputs = []
    ths = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    batch_metrics = [[]] * 7
    bar = tqdm(loader_valid)
    with torch.no_grad():
        for images, gt_masks in bar:
            images = images.cuda()
            gt_masks = gt_masks.cuda()

            logits = model(images)
            loss = criterion(logits, gt_masks)
            valid_loss.append(loss.item())
            for thi, th in enumerate(ths):
                pred = (logits.sigmoid() > th).float().detach()
                for i in range(logits.shape[0]):
                    tmp = multilabel_dice_score(
                        y_pred=logits[i].sigmoid().cpu(),
                        y_true=gt_masks[i].cpu(),
                        threshold=0.5,
                    )
                    batch_metrics[thi].extend(tmp)
            bar.set_description(f'smth:{np.mean(valid_loss[-30:]):.4f}')
            
    metrics = [np.mean(this_metric) for this_metric in batch_metrics]
    print('best th:', ths[np.argmax(metrics)], 'best dc:', np.max(metrics))

    return np.mean(valid_loss), np.max(metrics)


In [119]:
def run(fold):

    log_file = os.path.join(log_dir, f'{kernel_type}.txt')
    model_file = os.path.join(model_dir, f'{kernel_type}_fold{fold}_best.pth')

    train_ = df_seg[df_seg['fold'] != fold].reset_index(drop=True)
    valid_ = df_seg[df_seg['fold'] == fold].reset_index(drop=True)
    dataset_train = SEGDataset(train_, 'train', transform=transforms_train)
    dataset_valid = SEGDataset(valid_, 'valid', transform=transforms_valid)
    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = TimmSegModel(backbone, pretrained=True)
    model = convert_3d(model)
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=init_lr)
    scaler = torch.cuda.amp.GradScaler()
    from_epoch = 0
    metric_best = 0.
    loss_min = np.inf

    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, n_epochs)

    print(len(dataset_train), len(dataset_valid))

    for epoch in range(1, n_epochs+1):
        scheduler_cosine.step(epoch-1)

        print(time.ctime(), 'Epoch:', epoch)

        train_loss = train_func(model, loader_train, optimizer, scaler)
        valid_loss, metric = valid_func(model, loader_valid)

        content = time.ctime() + ' ' + f'Fold {fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {valid_loss:.5f}, metric: {(metric):.6f}.'
        print(content)
        with open(log_file, 'a') as appender:
            appender.write(content + '\n')

        if metric > metric_best:
            print(f'metric_best ({metric_best:.6f} --> {metric:.6f}). Saving model ...')
            torch.save(model.state_dict(), model_file)
            metric_best = metric

        # Save Last
        if not DEBUG:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler else None,
                    'score_best': metric_best,
                },
                model_file.replace('_best', '_last')
            )

    del model
    torch.cuda.empty_cache()
    gc.collect()


In [120]:
run(0)
run(1)
run(2)
run(3)
run(4)

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'train'