In [None]:
import os

import tensorflow as tf
import numpy as np
import h5py
import matplotlib.pyplot as plt
import pandas as pd
from dotenv import find_dotenv, load_dotenv
from tqdm import tqdm

from src.models.models import load_model
from src.data.tf_data import TFDataCreator
from src.data.utils import get_split
from src.models.losses import dice_loss, dice_coefficient_hard
from src.models.utils import config_gpu

load_dotenv(find_dotenv())
%matplotlib inline

In [None]:
# Project GPU setup (src.models.utils.config_gpu). NOTE(review): the meaning of the
# positional arguments (0, 16) is not visible here — presumably a device index and a
# memory setting; confirm against the helper and consider naming them.
config_gpu(0, 16)

In [None]:
# Task name; used below to build the processed-HDF5 path and, via its "Task04" prefix,
# to select the TFDataCreator implementation.
task: str = "Task04_Hippocampus"

In [None]:
# Trained model directory. Its name contains "split_0"; keep it in agreement with
# `split_id` (presumably the model was trained on that split, so evaluating it on
# another split's test ids would mix in training data — confirm).
model_path: str = "../models/S4Unet__ks_3__nf_4_8_16_32_64___split_0__20221018-172719"

In [None]:
# Restore the trained network stored under `model_path` (project helper).
model = load_model(model_path)

In [None]:
split_id = 0
# Read the split definition once and take all three partitions from it,
# instead of calling get_split three times with identical arguments.
split = get_split(split_id, os.environ["SPLITPATH"])
ids_train = split["training"]
ids_val = split["validation"]
ids_test = split["testing"]

In [None]:
# Evaluation pipeline over the held-out test ids. The ids come from the split, and the
# volumes are read from the task's *_training.hdf5 file (presumably the test ids are a
# subset of the ids stored there — confirm).
# The HDF5 handle is intentionally left open: the tf.data pipeline presumably reads
# from it lazily while `ds` is iterated in later cells.
hdf5_file = h5py.File(f"../data/processed/{task}/{task}_training.hdf5", "r")
data_creator = TFDataCreator.get(task.split("_")[0])(
    hdf5_file,
    # No shuffling for evaluation: keeps the inspected batch and the order of the
    # per-image results table reproducible between runs.
    shuffle=False,
    params_augmentation={
        "rotation": False,
        "random_center": False,
    },
)
ds = data_creator.get_tf_data_with_id(ids_test).batch(4)

In [None]:
# Materialise only the first batch of the dataset as NumPy arrays for the inspection cells.
x, y, image_id = next(ds.take(1).as_numpy_iterator())

In [None]:
# Shape of the inspected input batch; later cells index it as x[b, :, :, s, 0].
x.shape

In [None]:
# Forward pass on the inspected batch. `.numpy()` matches the evaluation loop and
# yields a plain ndarray for the slicing / np.unique cells that follow.
y_pred = model(x).numpy()

In [None]:
# Side-by-side view of input, ground truth and prediction for label channel 1.
s = 32  # slice index along axis 3; assumed to be < x.shape[3]
b = 3  # batch element to show; needs a batch of at least 4 samples
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ("input (channel 0)", x[b, :, :, s, 0]),
    ("ground truth (label 1)", y[b, :, :, s, 1]),
    ("prediction (label 1)", y_pred[b, :, :, s, 1]),
]
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
plt.show()



In [None]:
# Same view for label channel 2; reuses `s` and `b` from the previous cell.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ("input (channel 0)", x[b, :, :, s, 0]),
    ("ground truth (label 2)", y[b, :, :, s, 2]),
    ("prediction (label 2)", y_pred[b, :, :, s, 2]),
]
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
plt.show()



In [None]:
# Ground-truth label channel 2 on its own for sample `b`, slice `s`.
fig, ax = plt.subplots()
ax.imshow(y[b, :, :, s, 2])
ax.set_title(f"ground truth (label 2), sample {b}, slice {s}")
plt.show()

In [None]:
# Layer-by-layer architecture summary (assumes load_model returns a Keras model — confirm).
model.summary()

In [None]:
# Distinct values of predicted channel 0 on the inspected slice: a quick check of
# whether the output is binary or continuous (channel 0 presumably background — confirm).
np.unique(y_pred[b, :, :, s, 0])

In [None]:
# Evaluate every test batch and collect per-image hard Dice scores for label
# channels 1 and 2 (channel 0 is not scored here).
# Loop variables carry a `_batch` suffix so they do not overwrite the `x`, `y`,
# `y_pred` and `b` used by the inspection cells above (re-running those after this
# cell would otherwise silently show a different batch).
# Rows are accumulated in lists and turned into one DataFrame at the end, instead of
# calling pd.concat once per image (quadratic in the number of images).
dice_rows = []
dice_index = []
for x_batch, y_batch, ids_batch in tqdm(ds, desc="evaluating"):
    y_pred_batch = model(x_batch).numpy()
    # Indexed per sample below, so these are assumed to hold one score per batch element.
    dices_1 = dice_coefficient_hard(y_batch[..., 1], y_pred_batch[..., 1]).numpy()
    dices_2 = dice_coefficient_hard(y_batch[..., 2], y_pred_batch[..., 2]).numpy()
    for k, raw_id in enumerate(ids_batch.numpy()):
        dice_index.append(raw_id.decode("utf-8"))
        dice_rows.append({"dice_1": dices_1[k], "dice_2": dices_2[k]})

# Explicit columns keep the schema (and results.describe()) valid even when `ds`
# yields no batches.
results = pd.DataFrame(dice_rows, index=dice_index, columns=["dice_1", "dice_2"])


In [None]:
# Summary statistics of the per-image Dice scores over the test split.
results.describe()

In [None]:
# Per-image Dice table, indexed by decoded image id.
results