In [None]:
%matplotlib widget

import os
import matplotlib.pyplot as plt
import ipywidgets as widgets
import nibabel as nib
import pydicom
import numpy as np
import pandas as pd

In [None]:
def plot_image(filepath=None, img_data=None, **kwargs):
    """Show a volume one slice at a time, with a slider choosing the slice.

    Slices are taken along the last axis (``data[..., k]``).

    Args:
        filepath: Path to an image readable by ``nib.load``; ``~`` is expanded.
            Takes precedence over ``img_data`` when both are given.
        img_data: Already-loaded array, used only when ``filepath`` is None.
            It is indexed as ``img_data[..., k]`` and must provide ``min()``,
            ``max()`` and ``shape``.
        **kwargs: Forwarded unchanged to ``imshow``.

    Raises:
        ValueError: If both ``filepath`` and ``img_data`` are None.
    """
    if filepath is not None:
        nii_data = nib.load(os.path.expanduser(filepath)).get_fdata()
    elif img_data is not None:
        nii_data = img_data
    else:
        raise ValueError("Either filepath or img_data must be not None")

    # Interactive mode is off while the figure is created (presumably so it is
    # not rendered on its own); the canvas is shown explicitly in the VBox below.
    plt.ioff()
    fig, ax = plt.subplots()
    plt.ion()
    # Colour limits come from the whole volume so they stay fixed across slices.
    im = ax.imshow(nii_data[..., 0], vmin=nii_data.min(), vmax=nii_data.max(), **kwargs)

    out = widgets.Output()

    # The decorator routes anything printed/raised by the callback into `out`.
    @out.capture()
    def update(change):
        im.set_data(nii_data[..., change['new']])
        # Must be *called*; a bare attribute reference never requests a redraw.
        fig.canvas.draw_idle()

    slider = widgets.IntSlider(value=0, min=0, max=nii_data.shape[-1] - 1)
    # Only react to changes of the slider value, not to every trait change.
    slider.observe(update, names='value')
    display(widgets.VBox([slider, fig.canvas]))
    display(out)

In [None]:
def plot_multiple_images(filepaths, **kwargs):
    """Show several same-shaped volumes side by side with one shared slider.

    Each file is loaded with ``nib.load`` and sliced along its last axis
    (``data[..., k]``); the slider sets ``k`` for every panel at once.

    Args:
        filepaths: Iterable of image paths; ``~`` is expanded in each.
        **kwargs: Forwarded unchanged to every ``imshow`` call.

    Raises:
        ValueError: If ``filepaths`` is empty.
        AssertionError: If the loaded volumes do not all have the same shape.
    """
    filepaths = list(filepaths)
    if not filepaths:
        raise ValueError("filepaths must contain at least one image path")

    images_data = []
    shape = None
    for fpath in filepaths:
        images_data.append(nib.load(os.path.expanduser(fpath)).get_fdata())
        if shape is None:
            shape = images_data[-1].shape
        else:
            assert shape == images_data[-1].shape, "All images must have the same shape!"

    # Interactive mode is off while the figure is created (presumably so it is
    # not rendered on its own); the canvas is shown explicitly in the VBox below.
    plt.ioff()
    # squeeze=False keeps `axes` a 2-D array even for a single image, so the
    # indexing below also works when only one path is passed.
    fig, axes = plt.subplots(1, len(filepaths), squeeze=False)
    plt.ion()

    # One image artist per panel, each scaled to its own volume's value range.
    ims = [
        ax.imshow(data[..., 0], vmin=data.min(), vmax=data.max(), **kwargs)
        for ax, data in zip(axes[0], images_data)
    ]

    out = widgets.Output()

    # The decorator routes anything printed/raised by the callback into `out`.
    @out.capture()
    def update(change):
        for im, data in zip(ims, images_data):
            im.set_data(data[..., change['new']])
        # Must be *called*; a bare attribute reference never requests a redraw.
        fig.canvas.draw_idle()

    slider = widgets.IntSlider(value=0, min=0, max=shape[-1] - 1)
    # Only react to changes of the slider value, not to every trait change.
    slider.observe(update, names='value')
    display(widgets.VBox([slider, fig.canvas]))
    display(out)

In [None]:
# Root folder of the dataset, with environment variables and "~" expanded.
dataset_path = os.path.expanduser(
    os.path.expandvars("~/data/medicaldecathlon/Task10_Colon")
)

# Sub-folders and CSV files located directly under the dataset root.
(
    train_dataset_img_path,
    test_dataset_img_path,
    train_dataset_label_path,
    train_dataset_frailty_path,
    test_dataset_frailty_path,
) = (
    os.path.join(dataset_path, leaf)
    for leaf in ("imagesTr", "imagesTs", "labelsTr", "train_clean.csv", "test_clean.csv")
)

In [None]:
# Alphabetically sorted directory listings, skipping entries whose name starts with ".".
train_image_filenames = sorted(
    name for name in os.listdir(train_dataset_img_path) if not name.startswith(".")
)
test_image_filenames = sorted(
    name for name in os.listdir(test_dataset_img_path) if not name.startswith(".")
)
train_label_filenames = sorted(
    name for name in os.listdir(train_dataset_label_path) if not name.startswith(".")
)

In [None]:
# Show the filename stored at position 20 of the training image list.
train_image_filenames[20]

In [None]:
# Position (in the filename lists) of the training case to inspect.
img_idx = 20

# Show the image and the label file found at that position side by side.
plot_multiple_images(
    [
        os.path.join(folder, names[img_idx])
        for folder, names in (
            (train_dataset_img_path, train_image_filenames),
            (train_dataset_label_path, train_label_filenames),
        )
    ],
    cmap="gray",
)

In [None]:
# Load the first training label volume as a floating-point array.
b = nib.load(
    os.path.join(train_dataset_label_path, train_label_filenames[0])
).get_fdata()

# Indices of every non-zero element, one column per array axis (first three axes).
df = pd.DataFrame(
    {axis: coords for axis, coords in zip("xyz", np.where(b != 0))}
)

# x/y share axes and use unit-width bins spanning the longest side of the volume.
df[["x", "y"]].hist(
    sharex=True,
    sharey=True,
    bins=max(b.shape),
    range=(0, max(b.shape)),
    density=True,
)
# z gets unit-width bins spanning the third axis.
df[["z"]].hist(bins=b.shape[2], range=(0, b.shape[2]), density=True)

In [None]:
# Integer codes for the textual risk categories.
risk_codes = {"LOW": 0, "MEDIUM": 1, "HIGH": 2}

df_labels = pd.read_csv(train_dataset_frailty_path, index_col="PatientID")

# Encode the column in one step instead of writing integers into a text column
# row-subset by row-subset: that in-place mixed-type assignment is rejected when
# pandas reads the column as a dedicated string dtype. Values that are not one
# of the three known labels are left untouched, as before.
df_labels["Risk Category"] = df_labels["Risk Category"].map(
    lambda value: risk_codes.get(value, value)
)
# Every column must be integer-convertible; this raises if anything is not.
df_labels = df_labels.astype(int)

df_labels

In [None]:
# Histogram of every column of df_labels, showing raw counts (density=False).
df_labels.hist(density=False,)

In [None]:
# Number of rows for each distinct "Risk Category" value.
df_labels.value_counts(["Risk Category"])

In [None]:
# Integer codes for the textual risk categories.
risk_codes = {"LOW": 0, "MEDIUM": 1, "HIGH": 2}

df_labels_test = pd.read_csv(test_dataset_frailty_path, index_col="PatientID")

# Encode the column in one step instead of writing integers into a text column
# row-subset by row-subset: that in-place mixed-type assignment is rejected when
# pandas reads the column as a dedicated string dtype. Values that are not one
# of the three known labels are left untouched, as before.
df_labels_test["Risk Category"] = df_labels_test["Risk Category"].map(
    lambda value: risk_codes.get(value, value)
)
# Every column must be integer-convertible; this raises if anything is not.
df_labels_test = df_labels_test.astype(int)

df_labels_test.hist()

In [None]:
# (number of rows, number of columns) of df_labels.
df_labels.shape

In [None]:
# Normalized (density=True) histogram of the "Risk Category" column only.
df_labels[["Risk Category"]].hist(density=True,)

In [None]:
# Per-image statistics, stored column-wise (one list per statistic).
img_values = {
    "idx": [],
    "min": [],
    "p25": [],
    "p50": [],
    "p75": [],
    "max": [],
    "mass": [],
    "vol": [],
    "density": [],
}

# Images and labels are paired purely by position in the two sorted lists.
assert len(train_image_filenames) == len(train_label_filenames), (
    f"{len(train_image_filenames)} images but {len(train_label_filenames)} labels"
)

for img_idx in range(len(train_image_filenames)):
    img_path = os.path.join(train_dataset_img_path, train_image_filenames[img_idx])
    seg_path = os.path.join(train_dataset_label_path, train_label_filenames[img_idx])

    img = nib.load(os.path.expanduser(img_path)).get_fdata()
    seg = nib.load(os.path.expanduser(seg_path)).get_fdata()

    # Builtin `bool` replaces `np.bool`, a deprecated alias removed in NumPy 1.24.
    seg_mask = seg.astype(bool)
    # The round-trip comparison holds only if every label value is exactly 0 or 1.
    assert np.array_equal(seg, seg_mask), (
        f"{os.path.basename(seg_path)}: segmentation must contain only 0 and 1"
    )
    seg = seg_mask

    seg_mass = np.sum(img[seg])  # sum of image values inside the mask
    seg_vol = np.sum(seg)        # number of mask elements
    # An empty mask has no defined mean value; record NaN without dividing by zero.
    seg_density = seg_mass / seg_vol if seg_vol else float("nan")
    print(f"{img_idx} {os.path.basename(img_path)}\tdensity {seg_density:.3f} = mass {seg_mass:.3f} / vol {seg_vol:.3f}")

    # One quantile call per volume instead of five separate passes.
    q_min, q_25, q_50, q_75, q_max = np.quantile(img, [0.0, 0.25, 0.50, 0.75, 1.00])
    row = {
        "idx": img_idx,
        "min": q_min,
        "p25": q_25,
        "p50": q_50,
        "p75": q_75,
        "max": q_max,
        "mass": seg_mass,
        "vol": seg_vol,
        "density": seg_density,
    }
    # Explicit keys keep each value in its column regardless of dict ordering.
    for key, value in row.items():
        img_values[key].append(value)


In [None]:
# Summary statistics of the collected per-image values; left as the cell's last
# expression so the notebook renders it as a table rather than plain text.
pd.DataFrame(img_values).describe()

In [None]:
# Preview one validation scan: only the first entry of the first folder is
# shown, because both loops stop after a single iteration.
for eval_dir in map(lambda d: os.path.expanduser(os.path.expandvars(os.path.join("~/data/shade2022/validation_cleaned",d))), ["covid/images", "kidney/images"]):
    img_paths = []
    # Sorted so the previewed case does not depend on the OS listing order;
    # dot-entries (e.g. ".DS_Store") are skipped, as in the listings above.
    for data_handle in sorted(f for f in os.listdir(eval_dir) if not f.startswith(".")):
        data_dir = os.path.join(eval_dir, data_handle)
        scan_files = sorted(f for f in os.listdir(data_dir) if not f.startswith("."))
        img_paths.append(os.path.join(data_dir, scan_files[0]))
        scan = nib.load(img_paths[-1]).get_fdata()
        # Report the value range of the scan just loaded (not a leftover variable).
        print(scan.min(), scan.max(), img_paths[-1])
        # Clamp values to [-120, 240] before display.
        # NOTE(review): looks like an intensity window for CT data - confirm.
        scan_mod = np.clip(scan, a_min=-120, a_max=240)
        plot_image(
            img_data=scan_mod,
            cmap="gray")
        break
    break