<a href="https://colab.research.google.com/github/theiostream/rational/blob/master/CT2US.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## CT2US

This tool automates the generation of simulated ultrasound image and label pairs from CT images (.nii/.nii.gz).

---


## WIP:
- A visualizer for the final (slices) and intermediate (segmentation) results will be added.
- The progress bar will be improved.
- Code for two alternative optimized segmentation pipelines is still being developed:
  - one focuses on avoiding internal TotalSegmentator steps being saved to memory;
  - the other optimizes further by properly using GPU and CPU acceleration.
    
---
An old version of the optimized pipelines is included here; check https://github.com/lczamprogno/CT2US for potential updates.

In [None]:
#@title This needs to be run once and then the session needs to be restarted
%pip install totalsegmentator dask numba cupy-cuda12x dask_cuda torchvision xmltodict torchio cucim "bokeh>=3.1.0" di gradio tensordict pathlib

In [None]:
#@title Necessary step to use gradio UI within notebook
# Loads the IPython extension named `gradio` into the running kernel.
%load_ext gradio

In [None]:
#@title IMPORTANT: Acquire a totalsegmentator key (https://backend.totalsegmentator.com/license-academic/) and set it in the line in this block marked by a comment
import sys
from pathlib import PosixPath as path
import os

import json
import numpy as np
import xmltodict
from itertools import tee, islice

from nibabel import nifti1
from numpy import uint8

import totalsegmentator.python_api as ts
from totalsegmentator.config import setup_nnunet, setup_totalseg, set_config_key, get_weights_dir

# Root folder that holds the CT2US configs and the TotalSegmentator home directory.
this_folder = path("../CT2US").resolve()

# sys.path entries must be strings: the import system ignores any other type,
# so the resolved path is converted before being appended.
sys.path.append(str(this_folder))
ts_cfg_path = this_folder / ".totalsegmentator"
ts_cfg_path.mkdir(exist_ok=True, parents=True)
os.environ["TOTALSEG_HOME_DIR"] = str(ts_cfg_path)

setup_nnunet()
setup_totalseg()
# TODO: Request and set the totalsegmentator license here
# The license can be supplied through the TOTALSEG_LICENSE_NUMBER environment variable so that it
# does not have to be stored in the notebook; the literal below is only the fallback.
# NOTE(review): a license number committed to a shared notebook is effectively public -
# consider replacing the fallback with a personal key or removing it.
ts.set_license_number(os.environ.get("TOTALSEG_LICENSE_NUMBER", "aca_IY2KSBZZUM5QZO"))
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.utilities.file_path_utilities import get_output_folder

from totalsegmentator.libs import download_model_with_license_and_unpack, download_url_and_unpack
from totalsegmentator.map_to_binary import commercial_models

from numba import jit, njit, cuda
import cupy as cp
import cupyx.scipy.ndimage as cusci
import asyncio
from dask.distributed import Client, LocalCluster
from dask import delayed
from dask.distributed import print as pr
import dask.bag as db
import dask
import torch
from torch.distributions import Normal

import matplotlib.pyplot as plt
import os
from torchvision import transforms
from math import pi


from typing import List, Dict

import torchvision.transforms.functional as F


In [None]:
#@title Setup configs
!wget "https://drive.google.com/uc?export=download&id=1JqoSIrLnZt3Dfet3qTnUtPR77eFfckAb"
!unzip '/content/uc?export=download&id=1JqoSIrLnZt3Dfet3qTnUtPR77eFfckAb' -d '/content/CT2US'
!rm '/content/uc?export=download&id=1JqoSIrLnZt3Dfet3qTnUtPR77eFfckAb'
with open(path.joinpath(this_folder, "configs", "name2label.json")) as n:
    name2label = dict(json.load(n))

In [None]:
#@title US slice simulation code (from https://github.com/danivelikova/lotus/blob/main/models/us_rendering_model.py)
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision import transforms
from math import pi

# Per-tissue default parameter tables. Every table has nine entries, one per label, in the
# order given by the "indexes" row below; all of them are moved to the 'cuda' device.
# 2 - lung; 3 - fat; 4 - vessel; 6 - kidney; 8 - muscle; 9 - background; 11 - liver; 12 - soft tissue; 13 - bone;
# Default Parameters from: https://github.com/Blito/burgercpp/blob/master/examples/ircad11/liver.scene , labels 8, 9 and 12 approximated from other labels

                     # indexes:           2       3     4     6     8      9     11    12    13
acoustic_imped_def_dict = torch.tensor([0.0004, 1.38, 1.61,  1.62, 1.62,  0.3,  1.65, 1.63, 7.8], requires_grad=True).to(device='cuda')    # Z in MRayl
attenuation_def_dict =    torch.tensor([1.64,   0.63, 0.18,  1.0,  1.09, 0.54,  0.7,  0.54, 5.0], requires_grad=True).to(device='cuda')    # alpha in dB cm^-1 at 1 MHz
mu_0_def_dict =           torch.tensor([0.78,   0.5,  0.001, 0.45,  0.45,  0.3,  0.4, 0.45, 0.78], requires_grad=True).to(device='cuda') # mu_0 - scattering_mu   mean brightness
mu_1_def_dict =           torch.tensor([0.56,   0.5,  0.0,   0.6,  0.64,  0.2,  0.8,  0.64, 0.56], requires_grad=True).to(device='cuda') # mu_1 - scattering density, Nr of scatterers/voxel
sigma_0_def_dict =        torch.tensor([0.1,    0.0,  0.01,  0.3,  0.1,   0.0,  0.14, 0.1,  0.1], requires_grad=True).to(device='cuda') # sigma_0 - scattering_sigma - brightness std


# Divisor of the squared impedance difference inside the exponential that builds the boundary map.
alpha_coeff_boundary_map = 0.1
# Slope of the sigmoid applied to (mu_1 map - random scattering probability).
beta_coeff_scattering = 10  #100 approximates it closer
# Upper end of the linear gain ramp (1 .. TGC) multiplied onto the accumulated attenuation.
TGC = 8
# When True, the tissue maps are clamped to fixed ranges before rendering.
CLAMP_VALS = True


def gaussian_kernel(size: int, mean: float, std: float):
    """Build a normalized, anisotropic 2D Gaussian kernel.

    Args:
        size: half-width of the kernel; the result has ``2 * size + 1`` entries per axis.
        mean: mean of both 1D Gaussians.
        std: standard deviation along the first axis; the second axis uses ``3 * std``.

    Returns:
        A float32 tensor of shape ``(2 * size + 1, 2 * size + 1)`` whose entries sum to 1.
    """
    offsets = torch.arange(-size, size + 1, dtype=torch.float32)

    # Density of each 1D Gaussian sampled at the integer offsets.
    along_rows = torch.distributions.Normal(mean, std).log_prob(offsets).exp()
    along_cols = torch.distributions.Normal(mean, std * 3).log_prob(offsets).exp()

    # Outer product of the two profiles, then normalization by the total mass.
    unnormalized = along_rows[:, None] * along_cols[None, :]
    total = torch.sum(unnormalized).reshape(1, 1)

    return unnormalized / total

# Fixed 7x7 smoothing kernel (half-width 3) stored with two leading singleton axes so it can be
# passed directly as the `weight` of conv2d.
# The kernel is already a tensor, so it is copied with clone().detach() rather than being passed
# through torch.tensor(), which warns when given a tensor.
g_kernel = gaussian_kernel(3, 0., 0.5)
g_kernel = g_kernel[None, None, :, :].clone().detach().to(dtype=torch.float32, device='cuda')


class UltrasoundRendering(torch.nn.Module):
    """Render a simulated ultrasound image from a 2D map of tissue labels.

    Each label (2, 3, 4, 6, 8, 9, 11, 12, 13) is looked up in five per-tissue tables
    (acoustic impedance, attenuation, mu_0, mu_1, sigma_0). The resulting maps are combined in
    `rendering`, and the image is resampled onto a fan-shaped grid in `warp_img`.
    Tensors are created on / moved to the 'cuda' device throughout.

    Args:
        params: mapping of options; only ``params["debug"]`` is read by this class. When it is
            truthy, intermediate maps are saved as PNG files through `plot_fig`.
        default_param: when True the tables are plain detached copies of the module-level
            defaults; otherwise they are registered as ``torch.nn.Parameter``.
    """

    def __init__(self, params, default_param=False):
        super(UltrasoundRendering, self).__init__()
        self.params = params

        if default_param:
            self.acoustic_impedance_dict = acoustic_imped_def_dict.detach().clone()
            self.attenuation_dict = attenuation_def_dict.detach().clone()
            self.mu_0_dict = mu_0_def_dict.detach().clone()
            self.mu_1_dict = mu_1_def_dict.detach().clone()
            self.sigma_0_dict = sigma_0_def_dict.detach().clone()

        else:
            # torch.nn.Parameter shares storage with the tensor it wraps. Wrapping the
            # module-level defaults directly would let in-place updates of the parameters
            # overwrite those defaults, so every table gets its own copy.
            self.acoustic_impedance_dict = torch.nn.Parameter(acoustic_imped_def_dict.detach().clone())
            self.attenuation_dict = torch.nn.Parameter(attenuation_def_dict.detach().clone())

            self.mu_0_dict = torch.nn.Parameter(mu_0_def_dict.detach().clone())
            self.mu_1_dict = torch.nn.Parameter(mu_1_def_dict.detach().clone())
            self.sigma_0_dict = torch.nn.Parameter(sigma_0_def_dict.detach().clone())

        # Tissue names in the same order as the entries of the parameter tables.
        self.labels = ["lung", "fat", "vessel", "kidney", "muscle", "background", "liver", "soft tissue", "bone"]

        # Per-slice maps; they are (re)filled by forward().
        self.attenuation_medium_map, self.acoustic_imped_map, self.sigma_0_map, self.mu_1_map, self.mu_0_map  = ([] for i in range(5))


    def map_dict_to_array(self, dictionary, arr):
        """Replace every label in `arr` by the matching entry of `dictionary`.

        Args:
            dictionary: 1D table with one value per label, ordered like `mapping_keys`.
            arr: tensor of labels; it is used as an index, so it must hold integer values.
                A label that is not in `mapping_keys` yields no table index, which leaves
                `keys` and `values` with different lengths in the assignment below.

        Returns:
            A tensor shaped like `arr` holding the looked-up values.
        """
        mapping_keys = torch.tensor([2, 3, 4, 6, 8, 9, 11, 12, 13], dtype=torch.long).to(device='cuda')
        keys = torch.unique(arr).to(device='cuda')

        # Column index of each present label inside mapping_keys.
        index = torch.where(mapping_keys[None, :] == keys[:, None])[1]
        values = torch.gather(dictionary, dim=0, index=index)
        values = values.to(device='cuda')
        # values.register_hook(lambda grad: print(grad))    # check the gradient during training

        # Dense lookup table indexed by label value.
        mapping = torch.zeros(keys.max().item() + 1).to(device='cuda')
        mapping[keys] = values
        return mapping[arr]


    def plot_fig(self, fig, fig_name, grayscale):
        """Save `fig` (tensor or array) as ``results_test/<fig_name>.png``.

        Args:
            fig: 2D image; tensors are moved to CPU and converted to numpy first.
            fig_name: file name without extension.
            grayscale: when True the image is drawn with a gray colormap fixed to [0, 1].
        """
        save_dir='results_test/'
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_dir, exist_ok=True)

        plt.clf()

        if torch.is_tensor(fig):
            fig = fig.cpu().detach().numpy()

        if grayscale:
            plt.imshow(fig, cmap='gray', vmin=0, vmax=1, interpolation='none', norm=None)
        else:
            plt.imshow(fig, interpolation='none', norm=None)
        plt.axis('off')
        plt.savefig(save_dir + fig_name + '.png', bbox_inches='tight',transparent=True, pad_inches=0)


    def clamp_map_ranges(self):
        """Clamp the tissue maps: attenuation and impedance to [0, 10], the others to [0, 1]."""
        self.attenuation_medium_map = torch.clamp(self.attenuation_medium_map, 0, 10)
        self.acoustic_imped_map = torch.clamp(self.acoustic_imped_map, 0, 10)
        self.sigma_0_map = torch.clamp(self.sigma_0_map, 0, 1)
        self.mu_1_map = torch.clamp(self.mu_1_map, 0, 1)
        self.mu_0_map = torch.clamp(self.mu_0_map, 0, 1)


    def rendering(self, H, W, z_vals=None, refl_map=None, boundary_map=None):
        """Combine the tissue maps into an intensity image.

        Args:
            H, W: sizes of the random noise maps, created as ``torch.randn(H, W)``.
            z_vals: sample positions along each ray, as returned by `render_rays`.
            refl_map: reflection coefficient map.
            boundary_map: boundary map.

        Returns:
            Tuple ``(intensity_map, attenuation_total, reflection_total_plot, scatterers_map,
            scattering_probability, border_convolution, texture_noise, b, r)``, where
            ``intensity_map`` is ``b + r`` clamped to [0, 1].
        """

        dists = torch.abs(z_vals[..., :-1, None] - z_vals[..., 1:, None])     # dists.shape=(W, H-1, 1)
        dists = dists.squeeze(-1)                                             # dists.shape=(W, H-1)
        dists = torch.cat([dists, dists[:, -1, None]], dim=-1)                # dists.shape=(W, H)

        attenuation = torch.exp(-self.attenuation_medium_map * dists)
        attenuation_total = torch.cumprod(attenuation, dim=1, dtype=torch.float32, out=None)

        # Linear gain ramp from 1 to TGC along dim 1, repeated for every row.
        gain_coeffs = np.linspace(1, TGC, attenuation_total.shape[1])
        gain_coeffs = np.tile(gain_coeffs, (attenuation_total.shape[0], 1))
        gain_coeffs = torch.tensor(gain_coeffs).to(device='cuda')
        attenuation_total = attenuation_total * gain_coeffs     # apply TGC

        reflection_total = torch.cumprod(1. - refl_map * boundary_map, dim=1, dtype=torch.float32, out=None)
        reflection_total = reflection_total.squeeze(-1)
        # Log of the accumulated reflection; eps keeps log() away from zero.
        reflection_total_plot = torch.log(reflection_total + torch.finfo(torch.float32).eps)

        texture_noise = torch.randn(H, W, dtype=torch.float32).to(device='cuda')
        scattering_probability = torch.randn(H, W, dtype=torch.float32).to(device='cuda')

        scattering_zero = torch.zeros(H, W, dtype=torch.float32).to(device='cuda')

        z = self.mu_1_map - scattering_probability
        sigmoid_map = torch.sigmoid(beta_coeff_scattering * z)

        # approximating  Eq. (4) to be differentiable:
        # where(scattering_probability <= mu_1_map,
        #                     texture_noise * sigma_0_map + mu_0_map,
        #                     scattering_zero)
        scatterers_map =  (sigmoid_map) * (texture_noise * self.sigma_0_map + self.mu_0_map) + (1 -sigmoid_map) * scattering_zero   # Eq. (6)

        psf_scatter_conv = torch.nn.functional.conv2d(input=scatterers_map[None, None, :, :], weight=g_kernel, stride=1, padding="same")
        psf_scatter_conv = psf_scatter_conv.squeeze()

        b = attenuation_total * psf_scatter_conv    # Eq. (3)

        border_convolution = torch.nn.functional.conv2d(input=boundary_map[None, None, :, :], weight=g_kernel, stride=1, padding="same")
        border_convolution = border_convolution.squeeze()

        r = attenuation_total * reflection_total * refl_map * border_convolution # Eq. (2)

        intensity_map = b + r   # Eq. (1)
        intensity_map = intensity_map.squeeze()
        intensity_map = torch.clamp(intensity_map, 0, 1)

        return intensity_map, attenuation_total, reflection_total_plot, scatterers_map, scattering_probability, border_convolution, texture_noise, b, r


    def render_rays(self, W, H):
        """Return sample positions of shape (W, H): each row is ``linspace(0, 1, H) * 4``."""
        N_rays = W
        t_vals = torch.linspace(0., 1., H).to(device='cuda')   # 0-1 linearly spaced, shape H
        z_vals = t_vals.unsqueeze(0).expand(N_rays , -1) * 4

        return z_vals

    # warp the linear US image to approximate US image from curvilinear US probe
    def warp_img(self, inputImage):
        """Resample a 2D image onto a fan-shaped grid and resize the result to 256x256.

        Pixels of the 360x220 target grid whose angle or radius falls outside the fan are set
        to 0 before resizing.
        """
        resultWidth = 360
        resultHeight = 220
        centerX = resultWidth / 2
        centerY = -120.0
        maxAngle =  60.0 / 2 / 180 * pi #rad
        minAngle = -maxAngle
        minRadius = 140.0
        maxRadius = 340.0

        h, w = inputImage.squeeze().shape

        import torch.nn.functional as F

        # Create x and y grids
        x = torch.arange(resultWidth).float() - centerX
        y = torch.arange(resultHeight).float() - centerY
        # 'ij' is the layout meshgrid used implicitly before; passing it explicitly keeps the
        # result unchanged and avoids the warning about the missing indexing argument.
        xx, yy = torch.meshgrid(x, y, indexing='ij')

        # Calculate angle and radius
        angle = torch.atan2(xx, yy)
        radius = torch.sqrt(xx ** 2 + yy ** 2)

        # Create masks for angle and radius
        angle_mask = (angle > minAngle) & (angle < maxAngle)
        radius_mask = (radius > minRadius) & (radius < maxRadius)

        # Calculate original column and row
        origCol = (angle - minAngle) / (maxAngle - minAngle) * w
        origRow = (radius - minRadius) / (maxRadius - minRadius) * h

        # Reshape input image to be a batch of 1 image
        inputImage = inputImage.float().unsqueeze(0).unsqueeze(0)

        # Scale original column and row to be in the range [-1, 1]
        origCol = origCol / (w - 1) * 2 - 1
        origRow = origRow / (h - 1) * 2 - 1

        # Transpose input image to have channels first
        inputImage = inputImage.permute(0, 1, 3, 2)

        # Use grid_sample to interpolate
        grid = torch.stack([origCol, origRow], dim=-1).unsqueeze(0).to('cuda')
        resultImage = F.grid_sample(inputImage, grid, mode='bilinear', align_corners=True)

        # Apply masks and set values outside of mask to 0
        resultImage[~(angle_mask.unsqueeze(0).unsqueeze(0) & radius_mask.unsqueeze(0).unsqueeze(0))] = 0.0
        resultImage_resized = transforms.Resize((256,256))(resultImage).float().squeeze()

        return resultImage_resized


    def forward(self, ct_slice):
        """Render one 2D label slice into a warped, simulated ultrasound image.

        Args:
            ct_slice: 2D integer tensor of tissue labels (see `map_dict_to_array`).

        Returns:
            The 256x256 output of `warp_img`, rotated by three quarter turns.

        NOTE(review): `render_rays` receives ``(shape[0], shape[1])`` as ``(W, H)`` while
        `rendering` receives the same pair as ``(H, W)``; this only lines up for square
        slices - confirm the intended orientation for non-square input.
        """
        if self.params["debug"]: self.plot_fig(ct_slice, "ct_slice", False)

        #init tissue maps
        #generate 2D acousttic_imped map
        self.acoustic_imped_map = self.map_dict_to_array(self.acoustic_impedance_dict, ct_slice)#.astype('int64'))

        #generate 2D attenuation map
        self.attenuation_medium_map = self.map_dict_to_array(self.attenuation_dict, ct_slice)

        if self.params["debug"]:
            self.plot_fig(self.acoustic_imped_map, "acoustic_imped_map", False)
            self.plot_fig(self.attenuation_medium_map, "attenuation_medium_map", False)

        self.mu_0_map = self.map_dict_to_array(self.mu_0_dict, ct_slice)

        self.mu_1_map = self.map_dict_to_array(self.mu_1_dict, ct_slice)

        self.sigma_0_map = self.map_dict_to_array(self.sigma_0_dict, ct_slice)

        # Impedance differences are taken along dim 0 of the rotated map; a zero row is
        # prepended so diff_arr keeps the shape of the map.
        self.acoustic_imped_map = torch.rot90(self.acoustic_imped_map, 1, [0, 1])
        diff_arr = torch.diff(self.acoustic_imped_map, dim=0)

        diff_arr = torch.cat((torch.zeros(diff_arr.shape[1], dtype=torch.float32).unsqueeze(0).to(device='cuda'), diff_arr))

        # 0 where the impedance does not change, approaching 1 for large jumps.
        boundary_map =  -torch.exp(-(diff_arr**2)/alpha_coeff_boundary_map) + 1

        boundary_map = torch.rot90(boundary_map, 3, [0, 1])

        if self.params["debug"]:
           self.plot_fig(diff_arr, "diff_arr", False)
           self.plot_fig(boundary_map, "boundary_map", True)

        shifted_arr = torch.roll(self.acoustic_imped_map, -1, dims=0)
        shifted_arr[-1:] = 0

        sum_arr = self.acoustic_imped_map + shifted_arr
        # Avoid division by zero where both neighbouring impedances are 0.
        sum_arr[sum_arr == 0] = 1
        div = diff_arr / sum_arr

        refl_map = div ** 2
        refl_map = torch.sigmoid(refl_map)      # 1 / (1 + (-refl_map).exp())
        refl_map = torch.rot90(refl_map, 3, [0, 1])

        if self.params["debug"]: self.plot_fig(refl_map, "refl_map", True)

        z_vals = self.render_rays(ct_slice.shape[0], ct_slice.shape[1])

        if CLAMP_VALS:
            self.clamp_map_ranges()

        ret_list = self.rendering(ct_slice.shape[0], ct_slice.shape[1], z_vals=z_vals, refl_map=refl_map, boundary_map=boundary_map)

        intensity_map  = ret_list[0]

        if self.params["debug"]:
            self.plot_fig(intensity_map, "intensity_map", True)

            # File names for the entries of ret_list, in the same order.
            result_list = ["intensity_map", "attenuation_total", "reflection_total",
                            "scatters_map", "scattering_probability", "border_convolution",
                            "texture_noise", "b", "r"]

            for k in range(len(ret_list)):
                result_np = ret_list[k]
                if torch.is_tensor(result_np):
                    result_np = result_np.detach().cpu().numpy()

                # Entry 2 is the log-reflection map, which is not limited to [0, 1].
                if k==2:
                    self.plot_fig(result_np, result_list[k], False)
                else:
                    self.plot_fig(result_np, result_list[k], True)
                # print(result_list[k], ", ", result_np.shape)

        intensity_map_masked = self.warp_img(intensity_map)
        intensity_map_masked = torch.rot90(intensity_map_masked, 3)

        if self.params["debug"]:  self.plot_fig(intensity_map_masked, "intensity_map_masked", True)

        return intensity_map_masked


In [None]:
#@title Segmentation, Composition and US slicing code
import scipy
from torch import device
# from torch.nn import OptimizedModule
def dict_2_map(d: dict, n_labels: int = 15) -> list[list[uint8]]:
    """Invert a ``source label -> target label`` dict into per-target lists of source labels.

    Args:
        d: maps each source label (anything ``uint8()`` accepts) to the integer index of its
            target label.
        n_labels: number of target buckets in the result; defaults to the previous fixed
            size of 15.

    Returns:
        A list of ``n_labels`` lists; entry ``t`` holds, as ``uint8``, every source label
        mapped to ``t``, in the iteration order of `d`.

    Raises:
        IndexError: if a target index does not fit into ``n_labels`` buckets.
    """
    buckets = [[] for _ in range(n_labels)]

    for source, target in d.items():
        buckets[target].append(uint8(source))

    return buckets

def batched(self, iterable, n):
    """Yield consecutive tuples of at most `n` items from `iterable`.

    The last tuple may be shorter; nothing is yielded for an empty iterable.
    `self` is accepted but not used.
    """
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        yield chunk

# Save time by initializing predictors once, instead of for each task
# Save time by initializing predictors once, instead of for each task
def initialize_predictors(device: str = 'cuda',
                        folds: list = (0,)) -> dict:
    """
    Initialize one nnUNetPredictor per dataset of the "total", "tissue_types" and "body" tasks.

    For every dataset whose folder is not yet listed in the weights directory, the weights are
    downloaded first: through the licensed download for tasks flagged as such, otherwise from
    the public release URL.

    Args:
        device (str): Device to run predictions on ('cuda', 'cpu', 'mps').
        folds (list | tuple): Fold indices forwarded as ``use_folds`` when loading each model.

    Returns:
        dict: Dictionary mapping each dataset id (e.g. 291) to its nnUNetPredictor instance.
    """
    # Each entry: (task name, dataset ids, dataset folders, release URL suffixes,
    #              trainer name, whether the download needs a license).
    tasks = [("total",
            [291, 292, 293, 294, 295],
            ["Dataset291_TotalSegmentator_part1_organs_1559subj",
            "Dataset292_TotalSegmentator_part2_vertebrae_1532subj",
            "Dataset293_TotalSegmentator_part3_cardiac_1559subj",
            "Dataset294_TotalSegmentator_part4_muscles_1559subj",
            "Dataset295_TotalSegmentator_part5_ribs_1559subj"],
            ["/v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip",
            "/v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip",
            "/v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip",
            "/v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip",
            "/v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"],
            "nnUNetTrainerNoMirroring",
            False),
            ("tissue_types",
            [481],
            ["Dataset481_tissue_1559subj"],
            [],
            "nnUNetTrainer",
            True),
            ("body",
            [299],
            ["Dataset299_body_1559subj"],
            ["/v2.0.0-weights/Dataset299_body_1559subj.zip"],
            "nnUNetTrainer",
            False)]

    # Licensed downloads are addressed by model name, so invert the name -> id table.
    commercial_models_inv = {v: k for k, v in commercial_models.items()}
    base_url = "https://github.com/wasserth/TotalSegmentator/releases/download"

    # Get weights directory
    weights_dir = get_weights_dir()

    predictors = {}
    for task_name, task_ids, paths, urls, trainer, with_license in tasks:
        print(f"INIT: {task_name} predictor")
        for i, task_id in enumerate(task_ids):
            model_dir = weights_dir / paths[i]

            # Fetch the weights only when the dataset folder is missing.
            # NOTE(review): both download helpers receive the per-dataset folder as their
            # target; confirm they expect that folder rather than the weights root.
            if paths[i] not in os.listdir(weights_dir):
                if with_license:
                    download_model_with_license_and_unpack(commercial_models_inv[task_id], model_dir)
                else:
                    download_url_and_unpack(base_url + urls[i], model_dir)

            # Initialize the predictor
            predictor = nnUNetPredictor(
                tile_step_size=0.5,
                use_gaussian=True,
                use_mirroring=True,
                perform_everything_on_device=(device == 'cuda'),
                device=torch.device(device, 0),
                verbose=True,
                allow_tqdm=True
            )
            # Initialize from the trained model folder
            predictor.initialize_from_trained_model_folder(
                str(model_dir / (trainer + "__nnUNetPlans__3d_fullres")),
                use_folds=folds,
                checkpoint_name='checkpoint_final.pth'
            )

            predictors[task_id] = predictor

    return predictors

# Erosion kernel body: for one voxel of one stacked volume, writes 1 to `ret` when the window
# cut out of `padded` equals `kernel` element-wise everywhere, otherwise 0.
# It is wrapped with cuda.jit and njit in CT2US.__init__.
# NOTE(review): numba documents cuda.grid for 1, 2 or 3 dimensions - confirm ndim=4 is accepted.
# NOTE(review): each slice below spans 2 * int((k - 1) / 2) elements, i.e. one fewer than the
# kernel width k for odd k - confirm that `kernel == window` compares equal shapes.
def bin_erosion(kernel:torch.Tensor, padded:torch.Tensor, ret:torch.Tensor):
    # Assumes stacked 3d and no normalization needed
    i, hdx, idx, jdx = cuda.grid(4)

    # Run kernel
    window = padded[i,
                    hdx-int((kernel.shape[0]-1) / 2):hdx+int((kernel.shape[0]-1) / 2),
                    idx-int((kernel.shape[0]-1) / 2):idx+int((kernel.shape[0]-1) / 2),
                    jdx-int((kernel.shape[0]-1) / 2):jdx+int((kernel.shape[0]-1) / 2)]
    # TODO: does this also get JITed?
    match = torch.all(kernel == window)
    ret[i, hdx, idx, jdx] = 1 if match else 0

# Dilation kernel body: for one voxel of one stacked volume, writes 1 to `ret` when at least
# one element of the window cut out of `padded` equals the matching `kernel` element, else 0.
# It is wrapped with cuda.jit and njit in CT2US.__init__.
# NOTE(review): numba documents cuda.grid for 1, 2 or 3 dimensions - confirm ndim=4 is accepted.
# NOTE(review): each slice below spans 2 * int((k - 1) / 2) elements, i.e. one fewer than the
# kernel width k for odd k - confirm that `kernel == window` compares equal shapes.
def bin_dilation(kernel:torch.Tensor, padded:torch.Tensor, ret:torch.Tensor):
    # Assumes stacked 3d and no normalization needed
    i, hdx, idx, jdx = cuda.grid(4)

    # Run kernel
    window = padded[i,
                    hdx-int((kernel.shape[0]-1) / 2):hdx+int((kernel.shape[0]-1) / 2),
                    idx-int((kernel.shape[0]-1) / 2):idx+int((kernel.shape[0]-1) / 2),
                    jdx-int((kernel.shape[0]-1) / 2):jdx+int((kernel.shape[0]-1) / 2)]
    # TODO: does this also get JITed?
    match = torch.any(kernel == window)
    ret[i, hdx, idx, jdx] = 1 if match else 0

class CT2US(torch.nn.Module):
    def seg_predictor(self, imgs, properties, task, resamp_thr):
        return self.predictors[task].predict_from_data_iterator(
                                        self.iterator(self.predictors[task], imgs, properties),
                                        save_probabilities=False,
                                        num_processes_segmentation_export=resamp_thr
                                    )

        # Does not work for some reason, totalsegmentator returns zeros instead of labels
    def seg_old(self, imgs, properties, task, resamp_thr):
        """Segment every image by calling the TotalSegmentator Python API once per image.

        Args:
            imgs: iterable of inputs accepted by ``ts.totalsegmentator(input=...)``.
            properties: unused; kept so all segmentators share one call signature.
            task: TotalSegmentator task name.
            resamp_thr: forwarded as ``nr_thr_resamp``.

        Returns:
            A list with one ``np.uint8`` array per input image.
        """
        # TODO Get list from "total" labels and find matches in totalsegmentator
        #      (to be passed as roi_subset for the "total" task once that is supported here).
        ret = []
        for img in imgs:
            segmentation = ts.totalsegmentator(
                                input=img,
                                task=task,
                                nr_thr_resamp=resamp_thr
                            )
            ret.append(np.asarray(segmentation.dataobj, dtype=np.uint8))

        return ret

    def __init__(self, method: str = 'paths', device: str = 'cuda'):
        """Select the segmentation / composition / slicing implementations and load the configs.

        Args:
            method: one of 'old', 'new', 'predictor'. 'paths' (the default) is accepted as
                another spelling of 'predictor'.
            device: 'cuda' selects the current CUDA device when one is available; any other
                value, or a missing CUDA runtime, selects the CPU.

        Raises:
            KeyError: if `method` is not one of the supported names.
        """
        super(CT2US, self).__init__()
        methods = {'old', 'new', 'predictor', 'paths'}
        if method not in methods:
            raise KeyError(f"Method not supported, choose from {methods}")
        self.method = method

        # The dispatch tables below are keyed by 'predictor'; 'paths' selects the same entries.
        dispatch_key = 'predictor' if method == 'paths' else method

        if device == 'cuda' and torch.cuda.is_available():
            self.device = torch.device(device, torch.cuda.current_device())
            self.m = cp
        else:
            self.device = torch.device('cpu')
            self.m = np

        # JIT wrappers around the module-level morphology kernels.
        self._dil_cuda = cuda.jit(bin_dilation)
        self._er_cuda = cuda.jit(bin_erosion)
        self._dil_cpu = njit(bin_dilation)
        self._er_cpu = njit(bin_erosion)

        if dispatch_key == 'predictor':
            predictors = initialize_predictors(device=device, folds=[0])
            self.predictors = predictors
            self.predictor_keys = predictors.keys()

        def _segment_new(imgs, properties, task, resamp_thr):
            # Raise instead of handing an exception instance back to the caller as a result.
            raise NotImplementedError("WIP")  # self.predict_tensor_iter

        segmentator = {
            'new': _segment_new,
            'predictor': self.seg_predictor,
            'old': self.seg_old
        }

        self.segmentator = segmentator[dispatch_key]
        us = {
            'new': self.to_us_sim_new,
            'predictor': self.to_us_sim_old,
            'old': self.to_us_sim_old
        }
        self.us = us[dispatch_key]

        composer = {
            'new': self.stacked_assemble_tid,
            'predictor': self.assemble,
            'old': self.assemble
        }
        self.composer = composer[dispatch_key]

        hparams = {
            'debug' : False,
            'device' : device
        }

        self.ultrasound_rendering = UltrasoundRendering(hparams, default_param=True)

        # `path` is the PosixPath alias imported at the top of the notebook.
        config_dir = path(this_folder, "configs")

        with open(config_dir / "total_lmaps.json") as p:
            total_lmap = dict(json.load(p))

        with open(config_dir / "name2label.json") as n:
            self.name2label = dict(json.load(n))

        # Per-target lists of source labels built from the label map.
        self.tmap = dict_2_map(total_lmap)

    def bin_dilation(self, imgs:torch.Tensor, kernel_size:int ,iterations:int):
        kernel = torch.ones((kernel_size, kernel_size, kernel_size), dtype=torch.uint8)
        if imgs.is_cuda:
            d_imgs = cuda.as_cuda_array(imgs.detach())
            kernel = cuda.as_cuda_array(kernel.detach())
            threadsperblock = (1, kernel_size, kernel_size, kernel_size)
            blocks = (imgs.shape[0],
                        np.ceil(imgs.shape[1] / threadsperblock[1]),
                        np.ceil(imgs.shape[2] / threadsperblock[2]),
                        np.ceil(imgs.shape[3] / threadsperblock[3]))
            for _ in iterations:
                ret = cuda.as_cuda_array(torch.zeros(imgs.shape, device=imgs.device).detach())
                padded = cuda.as_cuda_array(
                            imgs.to_padded_tensor(
                                padding=0,
                                output_size=(imgs.shape[0], imgs.shape[1] + 2, imgs.shape[2] + 2,imgs.shape[3] + 2)
                            ).detach())
                self._dil_cuda[blocks, threadsperblock](kernel, padded, ret)
                d_imgs = ret
        else:
            d_imgs = imgs.detach().numpy()
            kernel = kernel.detach()
            for _ in iterations:
                ret = torch.zeros(imgs.shape, device=imgs.device).detach()
                padded = imgs.to_padded_tensor(
                            padding=0,
                            output_size=(imgs.shape[0], imgs.shape[1] + 2, imgs.shape[2] + 2,imgs.shape[3] + 2)
                        ).detach()

                self._dil_cpu(kernel, padded, ret)
                d_imgs = ret

        return d_imgs

    def bin_erosion(self, imgs:torch.Tensor, kernel_size:int ,iterations:int):
        kernel = torch.ones((kernel_size, kernel_size, kernel_size), dtype=torch.uint8)
        if imgs.is_cuda:
            d_imgs = cuda.as_cuda_array(imgs.detach())
            kernel = cuda.as_cuda_array(kernel.detach())
            threadsperblock = (1, kernel_size, kernel_size, kernel_size)
            blocks = (imgs.shape[0],
                        np.ceil(imgs.shape[1] / threadsperblock[1]),
                        np.ceil(imgs.shape[2] / threadsperblock[2]),
                        np.ceil(imgs.shape[3] / threadsperblock[3]))
            for _ in iterations:
                ret = cuda.as_cuda_array(torch.zeros(imgs.shape, device=imgs.device).detach())
                padded = cuda.as_cuda_array(
                            imgs.to_padded_tensor(
                                padding=0,
                                output_size=(imgs.shape[0], imgs.shape[1] + 2, imgs.shape[2] + 2,imgs.shape[3] + 2)
                            ).detach())
                self._er_cuda[blocks, threadsperblock](kernel, padded, ret)
                d_imgs = ret
        else:
            d_imgs = imgs.detach().numpy()
            kernel = kernel.detach()
            for _ in iterations:
                ret = torch.zeros(imgs.shape, device=imgs.device).detach()
                padded = imgs.to_padded_tensor(
                            padding=0,
                            output_size=(imgs.shape[0], imgs.shape[1] + 2, imgs.shape[2] + 2,imgs.shape[3] + 2)
                        ).detach()

                self._er_cpu(kernel, padded, ret)
                d_imgs = ret

        return d_imgs

    # Adapted from nnUNetPredictor
    def iterator(self,
                predictor: nnUNetPredictor,
                imgs: list[np.ndarray],
                properties: list[dict]):

        # MAYBE: look at data_iterators.preprocess_fromnpy_save_to_queue for vstack use for ROI foreground masking

        # pp = predictor.get_data_iterator_from_raw_npy_data(
        #     imgs,
        #     properties
        # )

        # preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)

        # properties = {key: [i[key] for i in properties] for key in properties[0]}

        # data, seg = preprocessor.run_case_npy(
        #                 np.stack(imgs),
        #                 None,
        #                 properties,
        #                 predictor.plans_manager,
        #                 predictor.configuration_manager,
        #                 predictor.dataset_json
        #             )

        # pass

        preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)
        for a, p in zip(imgs, properties):
            data, seg = preprocessor.run_case_npy(a,
                                                  None,
                                                  p,
                                                  predictor.plans_manager,
                                                  predictor.configuration_manager,
                                                  predictor.dataset_json)
            yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None}

    def convert_logits_to_segmentation(self, prediction, properties, predictor):
        """Turn the logits of one case into a label map with the pre-cropping shape.

        Adapted from nnU-Net's export step: the logits are resampled back to the
        shape the case had after cropping, converted to class labels, pasted into
        an empty volume of the pre-cropping shape, and transposed back with
        ``plans_manager.transpose_backward``.

        Args:
            prediction: logits returned by the predictor for a single case.
            properties: the case's ``data_properties`` dict; the keys read here are
                ``'spacing'``, ``'shape_after_cropping_and_before_resampling'``,
                ``'shape_before_cropping'`` and ``'bbox_used_for_cropping'``.
            predictor: object exposing ``plans_manager``, ``configuration_manager``
                and ``label_manager``.

        Returns:
            ``np.ndarray`` label map (``uint8``, or ``uint16`` when there are 255 or
            more foreground labels).
        """
        plans_manager = predictor.plans_manager
        configuration_manager = predictor.configuration_manager
        label_manager = predictor.label_manager

        spacing_transposed = [properties['spacing'][i] for i in plans_manager.transpose_forward]
        shape_before_resampling = properties['shape_after_cropping_and_before_resampling']
        # When the configuration has one spacing value fewer than the data has axes,
        # the leading axis keeps the case's own spacing.
        if len(configuration_manager.spacing) == len(shape_before_resampling):
            current_spacing = configuration_manager.spacing
        else:
            current_spacing = [spacing_transposed[0], *configuration_manager.spacing]

        predicted_logits = configuration_manager.resampling_fn_probabilities(prediction,
                                                shape_before_resampling,
                                                current_spacing,
                                                spacing_transposed)
        # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
        # apply_inference_nonlin will convert to torch
        predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
        del predicted_logits
        segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

        # The paste below targets a NumPy array, so bring a tensor result to host memory first.
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # put segmentation in bbox (revert cropping)
        segmentation_reverted_cropping = np.zeros(properties['shape_before_cropping'],
                                                dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
        slicer = tuple(slice(*i) for i in properties['bbox_used_for_cropping'])
        segmentation_reverted_cropping[slicer] = segmentation
        del segmentation

        # revert the axis permutation applied through transpose_forward
        return segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)

    # Adapted from nnUNetPredictor
    def predict_tensor_iter(self,
                        data_iterator,
                        device: str = 'cuda') -> list[list]:
        """Run every predictor in ``self.predictors`` over each preprocessed case.

        Args:
            data_iterator: iterable of dicts carrying at least ``'data'`` (the input
                handed to ``predict_sliding_window_return_logits``) and
                ``'data_properties'``.
            device: device name or ``torch.device``; it is only used to decide
                whether the CUDA cache is emptied after each case.

        Returns:
            One entry per case. Each entry is a list holding, for every predictor in
            ``self.predictors`` (in iteration order), whatever
            ``self.convert_logits_to_segmentation`` returned for that predictor.
        """
        # Accept both 'cuda' / 'cuda:0' strings and torch.device objects.
        on_cuda = torch.device(device).type == 'cuda'

        # HYPERPARAMETER: number of threads to use for prediction
        default_num_processes = 4
        old_threads = torch.get_num_threads()
        torch.set_num_threads(min(default_num_processes, old_threads))

        results = []
        try:
            for preprocessed in data_iterator:
                data = preprocessed['data']
                print(f'perform_everything_on_device: {on_cuda}')
                properties = preprocessed['data_properties']

                case_segmentations = []
                for predictor in self.predictors.values():
                    prediction = None

                    for _params in predictor.list_of_parameters:
                        # NOTE(review): loading each parameter set into the network is
                        # still disabled (original note: "messing with state dict
                        # names..."), so every pass uses the weights the network
                        # currently holds.
                        # if not isinstance(predictor.network, OptimizedModule):
                        #     predictor.network.load_state_dict(params)
                        # else:
                        #     predictor.network._orig_mod.load_state_dict(params)

                        if prediction is None:
                            prediction = predictor.predict_sliding_window_return_logits(data)
                        else:
                            prediction += predictor.predict_sliding_window_return_logits(data)

                    # Average over this predictor's own parameter sets.
                    n_parameter_sets = len(predictor.list_of_parameters)
                    if n_parameter_sets > 1:
                        prediction /= n_parameter_sets

                    case_segmentations.append(
                        self.convert_logits_to_segmentation(prediction, properties, predictor))

                print(f'\nDone with image of shape {data.shape}:')

                # clear lru cache
                compute_gaussian.cache_clear()
                # clear device cache
                if on_cuda:
                    torch.cuda.empty_cache()

                results.append(case_segmentations)
        finally:
            # Give the process its original intra-op thread count back.
            torch.set_num_threads(old_threads)

        return results

    def assemble(self,
                task:str,
                segs:list[np.ndarray],
                bases:list[np.ndarray],
                prev:list[np.ndarray]) -> list[np.ndarray]:
        """Merge the segmentations of one task into the running label volumes.

        ``prev`` is updated in place, one volume per entry of ``segs``:

        * ``'total'``: for every index ``i`` of ``self.tmap`` with a non-empty id
          list, ``i`` is added wherever the segmentation holds one of those ids.
        * ``'tissue_types'``: voxels with value 1 or 2 receive the "fat" label.
        * ``'body'``: a skin shell is derived from the body mask and the CT
          intensities in ``bases`` and written with the "skin" label; afterwards
          every voxel still at 0 receives the "bg" label.

        Any other task name leaves ``prev`` untouched.

        Args:
            task: ``'total'``, ``'tissue_types'`` or ``'body'``.
            segs: one segmentation volume per case.
            bases: one intensity volume per case (only read for ``'body'``).
            prev: label volumes built so far; mutated and returned.

        Returns:
            ``prev``.
        """
        pr("ASSEMBLY STARTED")
        xp = self.m  # array namespace used throughout the non-stacked pipeline

        if task == 'total':
            for j, seg in enumerate(segs):
                seg_u8 = xp.asarray(seg, dtype=xp.uint8)  # converted once per volume
                for label_id, source_ids in enumerate(self.tmap):
                    if len(source_ids) > 0:  # if there are any keys for this value
                        prev[j] += xp.where(xp.isin(seg_u8, xp.array(source_ids)),
                                            xp.uint8(label_id), xp.uint8(0))

        elif task == 'tissue_types':
            fat_label = xp.uint8(name2label["body"]["fat"])
            for j, seg in enumerate(segs):
                tissue = xp.asarray(seg)
                prev[j][tissue == 1] = fat_label
                prev[j][tissue == 2] = fat_label

        elif task == 'body':
            skin_label = xp.uint8(name2label["body"]["skin"])
            bg_label = xp.uint8(name2label["body"]["bg"])
            dilation_kernel = xp.ones(shape=(2, 2, 2))

            for j, seg in enumerate(segs):
                outer = xp.asarray(seg)

                # Adapted from TotalSegmentator: shell between a dilated and an eroded mask.
                body = cusci.binary_dilation(outer == 1, iterations=1).astype(xp.uint8)
                body_inner = cusci.binary_erosion(outer, iterations=3, brute_force=True).astype(xp.uint8)
                skin = body - body_inner

                # Segment by density
                # Roughly the skin density range. Made large to make segmentation not have holes
                # (0 to 250 would have many small holes in skin)
                density_mask = (bases[j] > -200) & (bases[j] < 250)
                skin[~density_mask] = 0

                # NOTE(review): the earlier revision also labelled the connected blobs of
                # `skin` here and built a size-filtered mask (blobs of <= 10 or > 30 voxels
                # dropped), but that mask was never used afterwards, so the dead
                # computation is omitted. Re-introduce it only together with a decision
                # on the size interval.

                # Only voxels equal to 1 count as skin before the final dilation.
                skin_region = cusci.binary_dilation(skin == 1, structure=dilation_kernel)

                prev[j][skin_region] = skin_label
                # Whatever no task has labelled so far becomes background.
                prev[j][prev[j] == 0] = bg_label

        pr("ASSEMBLY COMPLETED")

        return prev

    def stacked_assemble_tname(self, task:str,
                segs:list[np.ndarray],
                stacked_bases:list[np.ndarray],
                prev: list[np.ndarray]) -> list[np.ndarray]:
        """Merge one task's segmentations into a stacked (batched) label tensor.

        Batched counterpart of ``assemble`` working on torch tensors.
        ``stacked_bases`` and ``prev`` are expected to be tensors with the case
        index on axis 0 (``forward`` builds them that way for the non-'old'
        method); ``prev`` is updated in place.

        * ``'total'``: for every index ``i`` of ``self.tmap`` with a non-empty id
          list, ``i`` is added wherever the segmentation holds one of those ids.
        * ``'tissue_types'``: voxels with value 1 or 2 receive the "fat" label.
        * ``'body'``: a skin shell is derived from the body mask and the
          intensities, blob-filtered per case, dilated and written with the "skin"
          label; every voxel still at 0 then receives the "bg" label.

        Args:
            task: ``'total'``, ``'tissue_types'`` or ``'body'``.
            segs: one segmentation volume per case.
            stacked_bases: stacked intensity volumes; its first dimension sets how
                many entries of ``segs`` are used.
            prev: stacked label tensor built so far; mutated and returned.

        Returns:
            ``prev``.
        """
        pr("ASSEMBLY STARTED")

        labels = prev
        n_cases = stacked_bases.shape[0]

        def stack_segs(dtype=None):
            # One tensor per case on self.device, stacked along a new leading axis.
            return torch.stack(
                [torch.as_tensor(segs[i], device=self.device, dtype=dtype) for i in range(n_cases)],
                dim=0)

        if task == "total":
            stacked_totals = stack_segs()

            for label_id, source_ids in enumerate(self.tmap):
                if len(source_ids) > 0:  # if there are any keys for this value
                    hit = torch.isin(stacked_totals, torch.as_tensor(source_ids, device=self.device))
                    labels += hit.to(labels.dtype) * label_id

        elif task == "tissue_types":
            stacked_tissues = stack_segs()
            labels[(stacked_tissues == 1) | (stacked_tissues == 2)] = int(name2label["body"]["fat"])

        elif task == "body":
            stacked_outers = stack_segs(torch.uint8)

            # Adapted code snippet from totalsegmentator
            # Tensors have no .astype(); .to() performs the dtype conversion.
            body = self.bin_dilation(stacked_outers == 1, kernel_size=3, iterations=1).to(torch.uint8)
            body_inner = self.bin_erosion(stacked_outers == 1, kernel_size=3, iterations=3).to(torch.uint8)
            skin = body - body_inner

            # Segment by density
            # Roughly the skin density range. Made large to make segmentation not have holes
            # (0 to 250 would have many small holes in skin)
            density_mask = (stacked_bases > -200) & (stacked_bases < 250)
            skin[~density_mask] = 0

            # Blob filtering runs per case, as in `assemble`, so components of
            # neighbouring cases are never joined through the batch axis.
            kept = []
            for case_skin in skin:
                if case_skin.is_cuda:
                    import cupy  # GPU labelling needs a CuPy view of the tensor
                    blob_ids, _ = cusci.label(cupy.asarray(case_skin))
                else:
                    blob_ids, _ = scipy.ndimage.label(case_skin.cpu().numpy())
                blob_ids = torch.as_tensor(blob_ids, device=case_skin.device).long()

                counts = torch.bincount(blob_ids.flatten())  # number of pixels in each blob

                # If only one blob (only background) abort because nothing to remove
                if len(counts) > 1:
                    # NOTE(review): this keeps only blobs of 11-30 voxels, so any larger
                    # connected skin region is dropped as well - confirm the interval.
                    remove_idx = torch.nonzero((counts <= 10) | (counts > 30)).flatten()
                    blob_ids[torch.isin(blob_ids, remove_idx)] = 0
                    blob_ids[blob_ids > 0] = 1
                kept.append(blob_ids)
            mask = torch.stack(kept, dim=0)
            # End of snippet from totalsegmentator

            skin_region = self.bin_dilation(mask == 1, kernel_size=3, iterations=2).bool()
            labels[skin_region] = int(name2label["body"]["skin"])

            # Whatever no task has labelled so far becomes background.
            labels[labels == 0] = int(name2label["body"]["bg"])

        pr("ASSEMBLY COMPLETED")

        return labels

    def stacked_assemble_tid(self, task:str,
                segs:list[np.ndarray],
                stacked_bases:list[np.ndarray],
                prev: list[np.ndarray]) -> list[np.ndarray]:
        """Merge one task's segmentations into a stacked (batched) label tensor.

        ``stacked_bases`` and ``prev`` are expected to be torch tensors with the
        case index on axis 0 (``forward`` builds them that way for the non-'old'
        method); ``prev`` is updated in place.

        * ``'total'``: for every index ``i`` of ``self.tmap`` with a non-empty id
          list, ``i`` is added wherever the segmentation holds one of those ids.
        * ``'tissue_types'``: voxels with value 1 or 2 receive the "fat" label.
        * ``'body'``: a skin shell is derived from the body mask and the
          intensities, blob-filtered per case, dilated and written with the "skin"
          label; every voxel still at 0 then receives the "bg" label.

        Args:
            task: ``'total'``, ``'tissue_types'`` or ``'body'``.
            segs: one segmentation volume per case.
            stacked_bases: stacked intensity volumes; its first dimension sets how
                many entries of ``segs`` are used.
            prev: stacked label tensor built so far; mutated and returned.

        Returns:
            ``prev``.
        """
        pr("ASSEMBLY STARTED")

        labels = prev
        n_cases = stacked_bases.shape[0]

        def stack_segs():
            # One tensor per case on self.device, stacked along a new leading axis.
            return torch.stack(
                [torch.as_tensor(segs[i], device=self.device) for i in range(n_cases)],
                dim=0)

        if task == "total":
            stacked_totals = stack_segs()

            for label_id, source_ids in enumerate(self.tmap):
                if len(source_ids) > 0:  # if there are any keys for this value
                    hit = torch.isin(stacked_totals, torch.as_tensor(source_ids, device=self.device))
                    labels += hit.to(labels.dtype) * label_id

        elif task == "tissue_types":
            stacked_tissues = stack_segs()
            labels[(stacked_tissues == 1) | (stacked_tissues == 2)] = int(name2label["body"]["fat"])

        elif task == "body":
            stacked_outers = stack_segs()

            # Adapted from TotalSegmentator
            # Tensors have no .astype(); .to() performs the dtype conversion.
            body = self.bin_dilation(stacked_outers == 1, kernel_size=3, iterations=1).to(torch.uint8)
            body_inner = self.bin_erosion(stacked_outers == 1, kernel_size=3, iterations=3).to(torch.uint8)

            skin = body - body_inner

            # Segment by density
            # Roughly the skin density range. Made large to make segmentation not have holes
            # (0 to 250 would have many small holes in skin)
            density_mask = (stacked_bases > -200) & (stacked_bases < 250)
            skin[~density_mask] = 0

            # Blob filtering runs per case, as in `assemble`, so components of
            # neighbouring cases are never joined through the batch axis.
            kept = []
            for case_skin in skin:
                if case_skin.is_cuda:
                    import cupy  # GPU labelling needs a CuPy view of the tensor
                    blob_ids, _ = cusci.label(cupy.asarray(case_skin))
                else:
                    blob_ids, _ = scipy.ndimage.label(case_skin.cpu().numpy())
                blob_ids = torch.as_tensor(blob_ids, device=case_skin.device).long()

                counts = torch.bincount(blob_ids.flatten())  # number of pixels in each blob

                # If only one blob (only background) abort because nothing to remove
                if len(counts) > 1:
                    # NOTE(review): this keeps only blobs of 11-30 voxels, so any larger
                    # connected skin region is dropped as well - confirm the interval.
                    remove_idx = torch.nonzero((counts <= 10) | (counts > 30)).flatten()
                    blob_ids[torch.isin(blob_ids, remove_idx)] = 0
                    blob_ids[blob_ids > 0] = 1
                kept.append(blob_ids)
            mask = torch.stack(kept, dim=0)
            # End of snippet from totalsegmentator

            skin_region = self.bin_dilation(mask == 1, kernel_size=3, iterations=2).bool()
            labels[skin_region] = int(name2label["body"]["skin"])

            # Whatever no task has labelled so far becomes background.
            labels[labels == 0] = int(name2label["body"]["bg"])

        pr("ASSEMBLY COMPLETED")

        return labels

    def to_us_sim_new(self, segs:list[np.ndarray], dest_us: list[str], step_size:int) -> list[list[np.ndarray]]:
        """Render simulated ultrasound images from slices of each label volume.

        For every volume in ``segs`` each ``step_size``-th slice along the third
        axis is resized to 380x380 (nearest neighbour), centre-cropped to 256 and
        passed through an ``UltrasoundRendering`` model on CUDA. Each rendered
        slice is saved as a PNG under the matching ``dest_us`` directory and also
        collected as a NumPy array.

        Args:
            segs: label volumes; ``.get()`` is called on each one (presumably a
                CuPy device-to-host copy - confirm against the caller).
            dest_us: output directory per volume, created if missing.
            step_size: stride between rendered slices.

        Returns:
            One list of rendered slice arrays per volume.
        """
        pr("US SIMULATION STARTED")

        hparams = dict(debug=False, device='cuda')
        renderer = UltrasoundRendering(params=hparams, default_param=True).to(hparams['device'])

        resize = transforms.Resize([380, 380], transforms.InterpolationMode.NEAREST)
        prepare_slice = transforms.Compose([transforms.ToTensor(), resize, transforms.CenterCrop(256)])
        to_pil = transforms.ToPILImage()

        results = []
        for vol_idx, volume in enumerate(segs):
            labelmap = volume.get()
            dest = pthlib(dest_us[vol_idx]).joinpath("slice_")
            os.makedirs(dest.parent, exist_ok=True)

            rendered = []
            for slice_idx in range(0, labelmap.shape[2], step_size):
                model_input = prepare_slice(labelmap[:, :, slice_idx].astype('int64')).squeeze()

                us_image = renderer(model_input)
                rendered.append(us_image.cpu().numpy())

                to_pil(us_image.cpu().squeeze()).save(f"{dest}_{slice_idx}.png")

            results.append(rendered)

        print("US SIMULATION COMPLETED")

        return results

    def to_us_sim_old(self, segs:list[np.ndarray], dest_us: list[str],  step_size:int) -> list[list[np.ndarray]]:
        """Simulate ultrasound images for slices of each label volume and save them.

        Every ``step_size``-th slice along the third axis of each volume is
        converted to a tensor, resized to 380x380 with nearest-neighbour
        interpolation, centre-cropped to 256 and rendered by an
        ``UltrasoundRendering`` model placed on CUDA. The rendering is written as
        ``<dest_us[i]>/slice__<index>.png`` and kept as a NumPy array.

        Args:
            segs: label volumes; each must provide ``.get()`` (presumably CuPy
                arrays - confirm against the caller).
            dest_us: output directory per volume, created if missing.
            step_size: stride between rendered slices.

        Returns:
            One list of rendered slice arrays per volume.
        """
        pr("US SIMULATION STARTED")

        hparams = {'debug': False, 'device': 'cuda'}
        us_model = UltrasoundRendering(params=hparams, default_param=True)
        us_model = us_model.to(hparams['device'])

        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize([380, 380], transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(256),
        ])

        all_volumes = []
        for case in range(len(segs)):
            host_labels = segs[case].get()
            dest = pthlib(dest_us[case]).joinpath("slice_")
            os.makedirs(dest.parent, exist_ok=True)

            case_images = []
            depth = host_labels.shape[2]
            for slice_idx in range(0, depth, step_size):
                as_int64 = host_labels[:, :, slice_idx].astype('int64')
                us_image = us_model(pipeline(as_int64).squeeze())

                case_images.append(us_image.cpu().numpy())

                pil_image = transforms.ToPILImage()(us_image.cpu().squeeze())
                pil_image.save(f"{dest}_{slice_idx}.png")

            all_volumes.append(case_images)

        print("US SIMULATION COMPLETED")

        return all_volumes


    def forward(self,
                    imgs: list[nifti1.Nifti1Image|np.ndarray|torch.Tensor],
                    properties:list[dict],
                    dest_label: list[str],
                    dest_us: list[str],
                    step_size: int,
                    save_labels: bool,
                ) -> list[list[np.ndarray]]:

        if not self.method == 'old':
            bases = torch.stack([
                        torch.as_tensor(
                            img.get_fdata(),
                            device=self.device
                        ) for img in imgs]
                    ).cuda(self.device)

            f_labels = torch.stack([
                            torch.zeros(
                                bases[i].shape,
                                dtype=torch.uint8,
                                device=self.device
                            ) for i in range(bases.shape[0])],
                            axis=0
                        ).cuda(self.device)
        else:
            bases = [self.m.array(img.dataobj, dtype=self.m.float32) for img in imgs]
            f_labels = [self.m.zeros(bases[idx].shape, dtype=self.m.uint8) for idx in range(len(imgs))]


        if not self.method == "old":
            tasks = list(self.predictor_keys)
        else:
            tasks = ["total", "tissue_types", "body"]

        print("SEGMENTATING:")

        tmp = []
        for idx in range(len(tasks)):
            # tmp = predictors[idx].predict_from_list_of_npy_arrays(imgs, None, properties, None, 2, save_probabilities=False, num_processes_segmentation_export=2)

            tmp.append(self.segmentator(imgs, properties, tasks[idx], 4))
            print("SEG DONE!")

            # f_labels = self.composer(tasks[idx], tmp, stacked_bases, f_labels)

        for idx in range(len(tasks)):
            f_labels = self.composer(tasks[idx], tmp[idx], bases, f_labels)

        us = self.us(f_labels.copy(), dest_us, step_size)

        if save_labels:
            for idx in range(len(f_labels)):
                SimpleITKIO().write_seg(f_labels[idx].get().transpose(2, 1, 0), dest_label[idx], properties[idx])
                print(f"SAVED TO '{dest_label[idx]}'")

        # r_us = self.us[self.method](f_labels, bases, dest_us, dest_label, step_size, save_labels)

        # return f_labels, stacked_bases, dest_us, dest_label, step_size, save_labels
        return "Success"

In [None]:
#@title Acquire samples
todo_dir = path.joinpath(this_folder, "sample")
todo_dir.mkdir(exist_ok=True)

if not any(todo_dir.iterdir()):
    !wget -O /content/CT2US/sample/sample.zip "https://www.dropbox.com/scl/fi/mvr3l7ndar1b36c441ht1/synapse_raw_test0.zip?rlkey=7bzc6r7aqs0eyyh0eaidhnn31&st=z58ym3b5&dl=1"
    !unzip '/content/CT2US/sample/sample.zip' -d '/content/CT2US/sample'
    !rm '/content/CT2US/sample/sample.zip'


In [None]:
#@title gradio interface code
import shutil
import gradio as gr
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Directory the UI copies uploaded CT volumes into, resolved against the kernel's
# working directory.
absolute_path = str(pthlib("../ct2us/imgs").resolve())

# Run on the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def process(fileobjs: list, step_size:int, save_labels:bool=False, progress=gr.Progress()):
    """Gradio callback: stage the uploaded CT files and run the CT2US pipeline on them.

    The staging directory ``absolute_path`` is created if needed and emptied, the
    uploads are copied into it, and every batch yielded by a ``CTDataset``
    loader is pushed through a ``CT2US`` instance (method ``"old"``).

    Args:
        fileobjs: uploaded files as delivered by the Gradio "files" input; each
            item must expose the file path as ``.name``.
        step_size: stride between rendered ultrasound slices.
        save_labels: whether the assembled label maps are written as well.
        progress: Gradio progress tracker wrapping the batch loop.

    Returns:
        A short status message for the text output.
    """
    if not fileobjs:
        # Nothing was uploaded (Gradio passes None or an empty list).
        return "No files were provided."

    # Start from an empty staging directory; it may not exist on a fresh runtime.
    os.makedirs(absolute_path, exist_ok=True)
    for stale in glob.glob(f"{absolute_path}/*"):
        if os.path.isdir(stale) and not os.path.islink(stale):
            shutil.rmtree(stale)
        else:
            os.remove(stale)

    for f in fileobjs:
        shutil.copyfile(f.name, os.path.join(absolute_path, pthlib(f.name).name))
        # NOTE(review): rmtree only removes directories; on a regular file it fails
        # and the error is ignored, so the uploaded temp file stays in place.
        # Confirm whether the upload should really be deleted here.
        shutil.rmtree(f.name, ignore_errors=True)

    local_dataset = CTDataset(
        device=device,
        img_dir=absolute_path,
        method='old',
        resample=None
    )
    batch_size = 1

    ct_dataloader = DataLoader(local_dataset, batch_size=batch_size, collate_fn=local_dataset.collate_fn)

    ct2us = CT2US(method="old", device=device)

    for data in progress.tqdm(ct_dataloader, total=len(ct_dataloader), desc="Processing CT Scans"):
        imgs, properties, dest_labels, dest_us = data
        ct2us(imgs, properties, dest_labels, dest_us, step_size, save_labels)

    return "All done!"

In [None]:
#@title Run this for actual UI
%%blocks

with gr.Blocks() as demo:
    gr.Markdown("# CT to US Simulation")
    gr.Markdown("This generates ultrasound images from CT scans by segmentating into tissue types, which are then used to simulate the corresponding images.")
    gr.Interface(fn=process,
                inputs=[
                   "files",
                   gr.Slider(value=5, minimum=1, maximum=50, step=1, label="Step Size"),
                   "checkbox"
                ],
                outputs=["text"])