In [None]:
import gc
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models as sm
from tqdm import tqdm

from pipeline import Pipeline
from train_utils import predict_and_assemble

## Build the pipeline and load the evaluation volumes (test samples a and b, train sample 3)

In [None]:
data_params = {
    "data_dir": "./dataset/",
    "patch_size": 512,
    "downsampling": 1.0,
    # Alternative (contiguous window) used elsewhere: "layers": 40, "z_start": 0
    # NOTE(review): the checkpoint directory loaded below is named
    # "..._12-24_and_28-36_layers"; these ranges (11..23 and 27..35) only match
    # that name if it is 1-based and inclusive — confirm the channel selection
    # matches what the model was trained on.
    "layers": list(range(11, 24)) + list(range(27, 36)),
    "batch_size": 4,
    "train_transform": None,  # no augmentation transform configured for evaluation
    "use_adapt_hist": False,
}

backbone = "resnet18"

pipeline = Pipeline(**data_params)

# Build the backbone-specific preprocessing function once and reuse it for
# every volume instead of re-creating it per call.
preprocess_input = sm.get_preprocessing(backbone)

volume_a, mask_a, _ = pipeline.load_sample(split="test", index="a")
volume_a = preprocess_input(volume_a)

volume_b, mask_b, _ = pipeline.load_sample(split="test", index="b")
volume_b = preprocess_input(volume_b)

# The train sample also returns labels, which the plots use as ground truth.
volume_3, mask_3, labels_3 = pipeline.load_sample(split="train", index="3")
volume_3 = preprocess_input(volume_3)

gc.collect()
print("Loading complete.")

## Load Model and Checkpoint

In [None]:
# Directory of the training run whose weights are evaluated in this notebook.
CHECKPOINT_PATH = 'logs/20230521-102231resnet18_adam_jaccard_noadapthist_transform_512_12-24_and_28-36_layers/'
# Checkpoint prefix inside CHECKPOINT_PATH; also used to name the submission file.
# NOTE(review): this generic value does not identify the run, so submissions
# produced from different checkpoints overwrite one another — consider
# deriving it from the run directory.
CHECKPOINT_NAME = 'checkpoint'

# Architecture must match the one used for training (the run directory names
# resnet18, the same value as `backbone`). encoder_weights=None skips loading
# pretrained encoder weights; all weights are overwritten by load_weights below.
model = sm.Unet(
    backbone,
    input_shape=pipeline.get_input_shape(),
    encoder_weights=None,
    classes=1,
)

# rstrip avoids the "//" the trailing slash in CHECKPOINT_PATH would produce.
model.load_weights(f"{CHECKPOINT_PATH.rstrip('/')}/{CHECKPOINT_NAME}")

## Predict and Assemble

In [None]:
# Threshold handed to predict_and_assemble (presumably used to build the
# all_binary_pred_* outputs — confirm in train_utils). Named separately from the
# per-figure `threshold` values the plotting cells below set for themselves.
PRED_THRESHOLD = 0.95

# One call per volume; each returns (all_pred, all_binary_pred, pred).
all_pred_3, all_binary_pred_3, pred_3 = predict_and_assemble(pipeline, volume_3, mask_3, PRED_THRESHOLD, model)
all_pred_a, all_binary_pred_a, pred_a = predict_and_assemble(pipeline, volume_a, mask_a, PRED_THRESHOLD, model)
all_pred_b, all_binary_pred_b, pred_b = predict_and_assemble(pipeline, volume_b, mask_b, PRED_THRESHOLD, model)

## Plot the results

### Train sample 3

In [None]:
# Display cutoff for the binary panel; note the test-set figures below use 0.95.
threshold = 0.5
fig, axs = plt.subplots(1, 3, figsize=(10, 3))
axs[0].imshow(labels_3, cmap='gray')
axs[0].set_title("Ground Truth")
axs[1].imshow(all_pred_3, cmap='gray')
axs[1].set_title("Predictions")
axs[2].imshow(all_pred_3 > threshold, cmap='gray')
# Show the cutoff in the title so figures with different thresholds are distinguishable.
axs[2].set_title(f'Binary Predictions (> {threshold})')
fig.suptitle("Predictions on train sample 3")
# Save before plt.show(): the inline backend closes the figure on show.
# fig.savefig(f'imgs/{CHECKPOINT_NAME}-threshold-{threshold}-volume-3-compare.svg')
plt.show()

### Test sample a

In [None]:
# Display cutoff for the binary panel (the train-sample figure above uses 0.5).
threshold = 0.95
fig, axs = plt.subplots(1, 3, figsize=(12, 2.5))
axs[0].imshow(mask_a, cmap='gray')
axs[0].set_title("Mask")
axs[1].imshow(all_pred_a, cmap='gray')
axs[1].set_title("Predictions")
axs[2].imshow(all_pred_a > threshold, cmap='gray')
# Show the cutoff in the title so figures with different thresholds are distinguishable.
axs[2].set_title(f'Binary Predictions (> {threshold})')
fig.suptitle("Predictions on test set a")
# Save before plt.show(): the inline backend closes the figure on show.
# fig.savefig(f'imgs/{CHECKPOINT_NAME}-threshold-{threshold}-test-a-compare.svg')
plt.show()

### Test sample b

In [None]:
# Display cutoff for the binary panel (the train-sample figure above uses 0.5).
threshold = 0.95
fig, axs = plt.subplots(1, 3, figsize=(10, 3))
axs[0].imshow(mask_b, cmap='gray')
axs[0].set_title("Mask")
axs[1].imshow(all_pred_b, cmap='gray')
axs[1].set_title("Predictions")
axs[2].imshow(all_pred_b > threshold, cmap='gray')
# Show the cutoff in the title so figures with different thresholds are distinguishable.
axs[2].set_title(f'Binary Predictions (> {threshold})')
fig.suptitle("Predictions on test set b")
# Save before plt.show(): the inline backend closes the figure on show.
# fig.savefig(f'imgs/{CHECKPOINT_NAME}-threshold-{threshold}-test-b-compare.svg')
plt.show()

## Submission file

In [None]:
def rle(output, threshold=0.5):
    """Run-length encode a prediction map for the submission CSV.

    The map is flattened in row-major (C) order and binarised with
    ``> threshold``. Each run of foreground pixels is emitted as a
    ``start length`` pair with a 1-based start position, all pairs joined
    by single spaces.

    Args:
        output: array-like prediction map (probabilities or a 0/1 mask).
        threshold: values strictly greater than this count as foreground.
            Defaults to 0.5, the cutoff this function always used.

    Returns:
        The run-length string, or ``""`` when no pixel is foreground.
    """
    flat = (np.asarray(output).ravel() > threshold).astype(np.int8)
    # Pad with background on both ends so that runs touching the first or
    # last pixel still produce a start and an end transition. Without this,
    # such runs were dropped and the start/end arrays could mismatch.
    padded = np.concatenate(([0], flat, [0]))
    # Indices where the value changes; +1 turns them into 1-based positions
    # in the unpadded image, alternating run-start / run-end (exclusive).
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Convert every end position into the length of its run, in place.
    runs[1::2] -= runs[::2]
    return " ".join(map(str, runs))

In [None]:
# NOTE(review): rle() binarises at 0.5, while the test-set figures above
# display predictions at 0.95 — confirm which cutoff the submission should use.
submission = {
    "Id": ["a", "b"],
    "Predicted": [rle(all_pred_a), rle(all_pred_b)],
}

# to_csv does not create missing parent directories.
Path("submission").mkdir(exist_ok=True)
pd.DataFrame.from_dict(submission).to_csv(f'submission/submission-{CHECKPOINT_NAME}.csv', index=False)

In [None]:
# Preview the submission rows; the Predicted column holds the RLE strings.
df = pd.DataFrame(submission)
df