In [None]:
import os

import tensorflow as tf
import numpy as np
import h5py
import matplotlib.pyplot as plt
import pandas as pd

from src.models.models import unet_model
from src.features.extract_features import reshape_image_unet
from src.models.utils import predict_volume
from src.data.tf_data_hdf5 import preprocess_image, get_bb_mask_voxel
from src.data.utils import get_split
from src.models.evaluation import evaluate_pred_volume

In [None]:
# Restrict this kernel to GPU "1". NOTE(review): TensorFlow is already imported
# above; this only takes effect if no cell has initialized the GPU yet, so keep
# this cell ahead of any model construction.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
# Build an untrained U-Net that accepts any spatial size with 3 input channels;
# the trained weights are copied into it further down. The leading `3` is
# presumably the number of output channels/classes — TODO confirm in unet_model.
model = unet_model(
    3,
    input_shape=(None, None, 3),
    upsampling_kind="upsampling",
)

In [None]:
# Trained run to evaluate. The default is the original author's absolute path;
# set the PLC_MODEL_PATH environment variable to point at a local copy instead
# of editing this cell. The run name encodes `split_0`, which presumably has to
# match the index passed to get_split(...) below.
MODEL_PATH = os.environ.get(
    "PLC_MODEL_PATH",
    "/home/valentin/python_wkspce/plc_segmentation/models/unet__prtrnd_True__a_0.9__wt_1.0__wl_0.0upsmpl_upsampling__split_0__ovrsmpl_True__con_nothing20211215-131841/model_weight",
)
# compile=False: load the model uncompiled; only its weights are used (they are
# copied into `model` in a later cell).
model_trained = tf.keras.models.load_model(MODEL_PATH, compile=False)


In [None]:
# Transfer the trained parameters into the freshly built `model` (built with
# input_shape=(None, None, 3)); both models must share the same architecture.
model.set_weights(weights=model_trained.get_weights())

In [None]:
# Print the layer-by-layer architecture of the model built above.
model.summary()

In [None]:
DATA_DIR = "../data"  # relative to the notebook's working directory

# Opened read-only and intentionally left open: later cells read patient
# volumes from it. Call h5_file.close() when finished.
h5_file = h5py.File(f"{DATA_DIR}/processed/hdf5_2d/data.hdf5", "r")

# Patients excluded from the cohort, with the original author's reason.
EXCLUDED_PATIENTS = {
    "PatientLC_63": "Just one lung",
    "PatientLC_72": "the same as 70",
}
# Filter rather than list.remove(): a regenerated file that no longer contains
# one of these IDs should not abort the notebook with ValueError. The HDF5 key
# order is preserved.
patient_list = [pid for pid in h5_file if pid not in EXCLUDED_PATIENTS]
# NOTE(review): patient_list is not used by any cell below; evaluation uses the
# ids returned by get_split — confirm whether this exclusion still matters.

clinical_df = pd.read_csv(f"{DATA_DIR}/clinical_info.csv").set_index("patient_id")

In [None]:
# Patient IDs for split 0 — the loaded run's name also contains `split_0`.
(
    ids_train,
    ids_val,
    ids_test,
) = get_split(0)

In [None]:
# Volume-level evaluation on the held-out test patients.
evaluate_pred_volume(
    model,
    ids_test,
    h5_file,
    clinical_df,
)

In [None]:
# Same evaluation on the validation patients.
evaluate_pred_volume(
    model,
    ids_val,
    h5_file,
    clinical_df,
)

In [None]:
# Load one patient's full image and mask volumes into memory for inspection.
patient = "PatientLC_36"
patient_group = h5_file[patient]
image = patient_group["image"][()]
mask = patient_group["mask"][()]

In [None]:
# Raw image (channel 0) next to mask channel 1 for the same slice.
# The image slice is read straight from HDF5 so this cell still shows the raw
# data even after the preprocessing cell below has overwritten `image`.
SLICE_IDX = 15  # index along the third axis of the volume
raw_slice = h5_file[patient]["image"][:, :, SLICE_IDX, 0]

fig, (ax_image, ax_mask) = plt.subplots(1, 2)
ax_image.imshow(raw_slice)
ax_image.set_title(f"{patient} image, slice {SLICE_IDX}, ch 0")
ax_mask.imshow(mask[:, :, SLICE_IDX, 1])
ax_mask.set_title(f"{patient} mask, slice {SLICE_IDX}, ch 1")
plt.show()

In [None]:
# Dimensions of the raw image volume as stored in the HDF5 file.
np.shape(image)

In [None]:
# Rebuild the network input from the raw HDF5 volume instead of from `image`,
# so re-running this cell does not reshape/preprocess an already-processed
# array. `image` is still reassigned to the preprocessed input because later
# cells inspect it; the untouched volume stays available as `raw_image`.
raw_image = h5_file[patient]["image"][()]
# assumes mask channels 2 and 3 together give the region reshape_image_unet
# uses (presumably the two lungs) — TODO confirm channel semantics.
image = reshape_image_unet(raw_image, mask[..., 2] + mask[..., 3])
image = preprocess_image(image)
prediction = predict_volume(image, model)

In [None]:
# Shape of the predict_volume output; its channel 1 is thresholded at 0.5 below.
prediction.shape

In [None]:
# Number of voxels whose channel-1 output exceeds 0.5.
np.sum(np.greater(prediction[..., 1], 0.5))

In [None]:
# Bounding box of mask channel 1 (ellipsis indexing, as used for channels 2/3).
get_bb_mask_voxel(mask[..., 1])

In [None]:
# Mask channel 1 at index 30 along the third axis of the original volume.
fig, ax = plt.subplots()
ax.imshow(mask[:, :, 30, 1])


In [None]:
# Binarized channel-1 prediction at index 40 along the third axis.
# NOTE(review): the mask cell above shows slice 30, and `prediction` lives in
# the reshape_image_unet space, so the two indices may not refer to the same
# anatomy — confirm before comparing them visually.
fig, ax = plt.subplots()
ax.imshow(prediction[:, :, 40, 1] > 0.5)


In [None]:
# Channel 0 of `image` at slice 15 with its value scale. After the prediction
# cell has run, `image` holds the preprocessed network input, not the raw volume.
fig, ax = plt.subplots()
im = ax.imshow(image[:, :, 15, 0])
fig.colorbar(im, ax=ax)


In [None]:
# Standard deviation of channel 1 of `image` over the whole volume; presumably a
# quick sanity check of the normalization applied by preprocess_image.
np.std(a=image[:, :, :, 1])