In [None]:
import sys


sys.path.append("../src/")

In [None]:
from sennet.core.mmap_arrays import read_mmap_array, create_mmap_array
from sennet.custom_modules.metrics.surface_dice_metric import create_table_neighbour_code_to_surface_area
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np

In [None]:
from pathlib import Path


# Ensembled predictions: one `chunk_*` sub-directory per slab of slices.
root_path = Path("/home/clay/research/kaggle/sennet/data_dumps/predicted/ensembled/kidney_3_dense/")
# Lexicographic sort keeps chunks in slice order because their indices are zero-padded (e.g. chunk_02).
mmap_paths = sorted(root_path.glob("chunk_*"))
assert mmap_paths, f"no chunk_* directories found under {root_path}"
for p in mmap_paths:
    print(p)
# Opened read-only, like the label volume below: the metric only reads these arrays.
mmaps = [read_mmap_array(p / "mean_prob", mode="r") for p in mmap_paths]
print([m.shape for m in mmaps])

In [None]:
# Ground-truth label volume for kidney_3_dense, opened read-only.
label_path = "/home/clay/research/kaggle/sennet/data_dumps/processed/kidney_3_dense/label"
label = read_mmap_array(label_path, mode="r")

In [None]:
# Device on which every tensor of the surface-Dice computation is placed (a CUDA GPU is assumed).
device: str = "cuda"
# Integer dtype for the 0..255 cube-pattern indices used to index the area lookup table.
dtype_index: torch.dtype = torch.int32


def compute_area(y: list, unfold: nn.Unfold, area: torch.Tensor) -> torch.Tensor:
    """
    Look up the surface area inside every 2x2x2 cube spanned by two consecutive slices.

    Args:
      y (list[Tensor]): a pair of consecutive (H, W) mask slices with 0/1 values
      unfold: nn.Unfold(kernel_size=(2, 2), padding=1)
      area (Tensor): surface area for each of the 256 cube patterns, shape (256,)

    Returns:
      Tensor of shape (n_cubes,) holding the surface area of each cube; with the
      unfold above, n_cubes = (H + 1) * (W + 1).
    """
    # Stack the two slices as channels: (batch=1, nch=2, H, W).
    # Values are 0/1 bits, but Unfold only accepts floating-point input.
    yy = torch.stack(y, dim=0).to(torch.float16).unsqueeze(0)

    # Slide a 2x2 window over both channels: 2 channels * 2x2 = 8 values per cube.
    cubes_float = unfold(yy).squeeze(0)  # (8, n_cubes)

    # Pack the 8 bits of each cube into a single pattern index in 0..255.
    # The index dtype must be int32 or int64 (not uint8) to index `area` below:
    # int32 works in torch 2.0.0, but torch 1.13.1 raises IndexError for int32
    # (use int64/long there).
    # Allocated on the inputs' device rather than the global `device`, so the
    # function works wherever its inputs live.
    cubes_byte = torch.zeros(cubes_float.size(1), dtype=dtype_index, device=cubes_float.device)
    for k in range(8):
        cubes_byte += cubes_float[k, :].to(dtype_index) << k

    # Lookup table: pattern index -> surface area (float).
    return area[cubes_byte]


def _surface_dice_terms(
    y0: torch.Tensor,
    y1: torch.Tensor,
    y0_pred: torch.Tensor,
    y1_pred: torch.Tensor,
    unfold: nn.Unfold,
    area: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the (numerator, denominator) surface-Dice contributions of one pair of consecutive slices."""
    # Surface area in each 2x2x2 cube spanning the two slices, shape (n_cubes,).
    area_pred = compute_area([y0_pred, y1_pred], unfold, area)
    area_true = compute_area([y0, y1], unfold, area)

    # Cubes that carry surface in both the prediction and the ground truth.
    idx = torch.logical_and(area_pred > 0, area_true > 0)

    num = area_pred[idx].sum() + area_true[idx].sum()
    denom = area_pred.sum() + area_true.sum()
    return num, denom


def compute_surface_dice_score(
    mean_prob_chunks: list[np.ndarray],
    label: np.ndarray,
    threshold: float,
) -> float:
    """
    Compute the surface Dice score for one 3D volume.

    Args:
      mean_prob_chunks: predicted probabilities split along the slice axis. Each
        chunk is indexed as chunk[c, :, :]. Taken in list order, the chunks must
        cover exactly label.shape[0] slices.
      label: ground-truth mask indexed as label[i, :, :], i.e. (n_slices, H, W).
      threshold: a voxel is predicted foreground when its probability is > threshold.

    Returns:
      Surface Dice as a Python float (0.0 when neither volume has any surface).
    """
    n_slices = label.shape[0]
    h = label.shape[1]
    w = label.shape[2]
    # Predictions and labels are walked in lockstep, so their depths must match.
    assert sum(p.shape[0] for p in mean_prob_chunks) == n_slices

    # Surface area lookup table: Tensor[float32] (256,)
    area = create_table_neighbour_code_to_surface_area((1, 1, 1))
    area = torch.from_numpy(area).to(device)

    # 2x2 window over two stacked slices -> one 2x2x2 cube per position.
    unfold = torch.nn.Unfold(kernel_size=(2, 2), padding=1)

    # An all-zero slice pads the volume on both ends, so its outer faces count as surface.
    zeros = torch.zeros((h, w), dtype=torch.uint8, device=device)

    num = 0  # numerator of surface Dice
    denom = 0  # denominator
    y0 = y0_pred = zeros
    i = 0
    for chunk in mean_prob_chunks:
        for c in range(chunk.shape[0]):
            y1 = torch.from_numpy(label[i, :, :].copy()).to(device)
            y1_pred = torch.from_numpy(chunk[c, :, :] > threshold).to(device)

            pair_num, pair_denom = _surface_dice_terms(y0, y1, y0_pred, y1_pred, unfold, area)
            num += pair_num
            denom += pair_denom

            # Next slice
            y0, y0_pred = y1, y1_pred
            i += 1

    # Close the volume: pair the last slice with the trailing zero padding.
    # Without this step the surface on the far face of the volume is never counted.
    pair_num, pair_denom = _surface_dice_terms(y0, zeros, y0_pred, zeros, unfold, area)
    num += pair_num
    denom += pair_denom

    dice = num / denom.clamp(min=1e-8)
    return dice.item()


In [None]:
import time


# Voxels whose mean probability exceeds this value are treated as foreground.
THRESHOLD = 0.2

# perf_counter is a monotonic, high-resolution clock intended for measuring elapsed time.
t0 = time.perf_counter()
dice = compute_surface_dice_score(
    [m.data for m in mmaps],
    label.data,
    THRESHOLD,
)
elapsed = time.perf_counter() - t0
print(f"surface dice @ threshold={THRESHOLD}: {dice:.6f} ({elapsed:.2f} s)")