In [None]:
import gc
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import segmentation_models as sm
from tqdm import tqdm

from pipeline import Pipeline

In [None]:
# --- Data / pipeline configuration -------------------------------------------
data_dir = "./dataset/"
patch_size = 128      # side length of the square patches (128 -> 128x128)
downsampling = 0.75   # scale factor applied when loading: 0.5 halves each side, 1 leaves images as-is
z_start = 0           # first slice index along the z direction
z_dim = 40            # number of z slices to use; must satisfy z_start + z_dim <= 65
batch_size = 16

# --- Training schedule ----------------------------------------------------------
# Kept for reference; the evaluation cells in this notebook do not use these values.
epochs, steps_per_epoch, val_step = 100, 50, 50

# --- Load a single sample ---------------------------------------------------------
pipeline = Pipeline(data_dir, patch_size, downsampling, z_dim, z_start, batch_size)
# Only training sample #3 is loaded; the later cells evaluate the model on it.
volume_3, mask_3, labels_3 = pipeline.load_sample(split="train", index=3)

# Reclaim any garbage left over from loading before building the model.
gc.collect()
print("Loading complete.")

In [None]:
# For now, only a single loaded sample is evaluated; the fold-based datasets below are disabled.
# train_ds, val_ds = pipeline.make_datasets_for_fold(dev_folds['dev_1'])

In [None]:
# U-Net with a ResNet-50 encoder and one output channel. encoder_weights=None means the
# encoder starts without pretrained weights; the checkpoint loaded next supplies them.
unet_kwargs = dict(
    input_shape=pipeline.get_input_shape(),
    encoder_weights=None,
    classes=1,
)
model = sm.Unet('resnet50', **unet_kwargs)
model.load_weights('chkpt/checkpoint')

In [None]:
# Run the model on every patch of the loaded sample and save a side-by-side figure
# (thresholded prediction vs. label) for each patch as predictions/<i>.png.
val_ds = pipeline.make_iterated_data_generator(volume_3, mask_3, labels_3)
threshold = 0.5  # probability cut-off for counting a pixel as positive

# savefig does not create missing directories, so ensure the output folder exists.
output_dir = Path("predictions")
output_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(1, 2)
for i, (patch, label) in enumerate(tqdm(val_ds(), desc="predicting")):
    # predict_on_batch skips the per-call setup and progress-bar output of predict(),
    # which is wasteful when called once per single-sample batch.
    pred = model.predict_on_batch(np.expand_dims(patch, 0))[0]

    # Clear both panels: repeated imshow calls on the same Axes would otherwise stack
    # images, so memory and redraw cost would grow with every iteration.
    for panel in ax:
        panel.clear()

    # Pin the colour range for the binary mask. Auto-scaling a constant image maps it
    # to a single colour, so all-0 and all-1 predictions would otherwise look identical.
    ax[0].imshow((pred > threshold).astype(float), cmap='gray', vmin=0.0, vmax=1.0)
    ax[0].set_title("Prediction")
    # NOTE(review): the label's value range is not visible here, so it is still
    # auto-scaled; a constant label patch renders as a single colour.
    ax[1].imshow(label, cmap='gray')
    ax[1].set_title("Label")
    fig.savefig(output_dir / f'{i}.png')
