# Testing pipeline

Steps 0, 1 and 2 are used to develop and test the individual stages of the pipeline.

Later, you only need to (pre-)train the bbox and segmentation models directly from the terminal (see README.md). 

Once you have those two models you can start simulating/testing the Challenge inference and evaluation pipeline.
For that, go directly to steps 3 and 4.

## 0. Initial setup

### 0.1 Load and define

In [None]:
# Load libraries
import matplotlib.pyplot as plt

# NOTE(review): the star imports below are presumably what bring np, pd, torch,
# os, ROOT, DEVICE, ... into scope for the later cells -- confirm against
# helper_code / team_code.
from helper_code import *
from team_code import *

# Report the compute configuration used by this session
print(f"DEVICE: {DEVICE}")
print(f"WORLD_SIZE: {WORLD_SIZE}")

In [None]:
# Define data paths
# The three splits live in the same dataset folder; only the sub-folder differs.
signals_dataset_dir = f"{ROOT}/data/ptb-xl/Dataset{X_FREQUENCY}_Signals"
data_training_path = f"{signals_dataset_dir}/imagesTr_original"
data_vali_path = f"{signals_dataset_dir}/imagesTv_original"
data_test_path = f"{signals_dataset_dir}/imagesTs_original"

# Output locations (relative paths)
data_output_path = "data/"
model_folder = "model/"
model_folder_checkpoints = "model/checkpoints/"

### 0.2 Inspect

In [None]:
# Load the data
# Returns records, datasets and loaders for the three paths. Later cells use
# index 0 as training and index 1 as validation; presumably index 2 is the test
# split, following the order of list_of_paths -- TODO confirm.
records, data, loader = dataloader_wrapper(
    list_of_paths=[data_training_path, data_vali_path, data_test_path],
    test_settings=[False, False, True],  # one flag per path
    shuffle_settings=[True, False, False],  # one flag per path
    transform=None,
    single_signals=False,
    run_in_parallel=False,
)

In [None]:
# Check the mask. It should not have a border and should contain only two
# kinds of values (0 and 1).
mask = data[1][0]["mask"]  # first sample of the dataset at index 1
print(data[1][0]["info_dict"]["image_path"])
print(mask.shape)
print(mask[0])
# Value histogram of the first mask channel
print(pd.Series(mask[0].flatten()).value_counts())

In [None]:
# Show images from loader[1] (used as the validation loader elsewhere in this
# notebook) with their masks on top
inspection_plots(loader_to_use=loader[1], num_images_to_plot=5)

## 1. Total image

### 1.1 Test rotation

In [None]:
# Load the data
# Fresh loaders for the rotation test; no transform is applied.
records, data, loader = dataloader_wrapper(
    list_of_paths=[data_training_path, data_vali_path, data_test_path],
    test_settings=[False, False, True],  # one flag per path
    shuffle_settings=[True, False, False],  # one flag per path
    transform=None,
    run_in_parallel=False,
)

In [None]:
# Get predicted rotation angles
# For every image of loader[1] the rotation is estimated step by step (lines ->
# filtered lines -> median angle) and compared with (a) the value the loader
# already stored and (b) the true rotation from the image metadata.


def _angles_match(angle_a, angle_b):
    """Return a truthy value if both angles are equal or both are NaN."""
    # NaN never compares equal to itself, so "no angle found" on both sides
    # needs the explicit isnan check to count as a match.
    return angle_a == angle_b or (np.isnan(angle_a) and np.isnan(angle_b))


# Set to False to only print the wrongly predicted files without showing images
show_mismatch_images = True

rot_angle_list = {}
for batch_idx, batch_dicts in enumerate(loader[1]):
    for j in range(len(batch_dicts["image"])):
        # Channels-first tensor -> channels-last uint8 array for the line detection
        image = batch_dicts["image_original"][j].permute(1, 2, 0)
        image = image.numpy().astype(np.uint8)
        lines = get_lines(image, threshold_HoughLines=1200)
        filtered_lines = filter_lines(
            lines, degree_window=30, parallelism_count=3, parallelism_window=2
        )
        if filtered_lines is None:
            rot_angle = np.nan
        else:
            rot_angle = get_median_degrees(filtered_lines)
        rot_angle_list[batch_dicts["info_dict"]["image_path"][j]] = {
            "rot_angle_predicted": rot_angle,
            "rot_angle_predicted_loader": np.float64(
                batch_dicts["info_dict"]["rot_angle_predicted"][j].numpy()
            ),
            "actual_rotation": batch_dicts["info_dict"]["rotation"][j],
            "image": image,
            "lines": lines,
            "filtered_lines": filtered_lines,
        }

# Compare if loader and step wise prediction are the same
loader_mismatches = {
    k: v
    for k, v in rot_angle_list.items()
    if not _angles_match(v["rot_angle_predicted"], v["rot_angle_predicted_loader"])
}
predicted_as_loader = len(rot_angle_list) - len(loader_mismatches)
print(f"Same as loader: {predicted_as_loader} out of {len(rot_angle_list)}")
for k, v in loader_mismatches.items():
    print(
        f"File: {k}, predicted: {v['rot_angle_predicted']}, loader predicted: {v['rot_angle_predicted_loader']}"
    )

# Check how many are correctly predicted
wrong_predictions = {
    k: v
    for k, v in rot_angle_list.items()
    if not _angles_match(v["rot_angle_predicted"], v["actual_rotation"])
}
correctly_predicted = len(rot_angle_list) - len(wrong_predictions)
print(f"Correctly predicted: {correctly_predicted} out of {len(rot_angle_list)}")
for k, v in wrong_predictions.items():
    print(
        f"File: {k}, predicted: {v['rot_angle_predicted']}, actual: {v['actual_rotation']}"
    )
    if not show_mismatch_images:
        continue
    # Show all detected lines first, then the subset that survived filtering
    final_image = get_image_with_lines(v["image"], v["lines"])
    final_image.show()
    if v["filtered_lines"] is not None and len(v["filtered_lines"]) > 0:
        filtered_image = get_image_with_lines(v["image"], v["filtered_lines"])
        filtered_image.show()
    else:
        print("No filtered lines to plot")

### 1.2 Test getting scale info

In [None]:
# Load the data
# Fresh loaders for the scale test; no transform is applied.
records, data, loader = dataloader_wrapper(
    list_of_paths=[data_training_path, data_vali_path, data_test_path],
    test_settings=[False, False, True],  # one flag per path
    shuffle_settings=[True, False, False],  # one flag per path
    transform=None,
    run_in_parallel=False,
)

In [None]:
# Inspect the grid size / scale the loader predicted against the true values
# TODO: Check why this is predicting too big grids sometimes
threshold = 0.01  # Relative tolerance: 0.01 means 1 % of the true grid size
count_correct = 0
count_total = 0
deviation_x = []
deviation_y = []
for batch_idx, batch_dicts in enumerate(loader[0]):
    info = batch_dicts["info_dict"]
    for j in range(len(batch_dicts["image"])):
        pixels_per_grid_predicted = float(info["pixels_per_grid_predicted"][j])
        x_grid = float(info["x_grid"][j])
        y_grid = float(info["y_grid"][j])
        sec_per_pixel_predicted = float(info["sec_per_pixel_predicted"][j])
        mV_per_pixel_predicted = float(info["mV_per_pixel_predicted"][j])
        sec_per_pixel = float(info["sec_per_pixel"][j])
        mV_per_pixel = float(info["mV_per_pixel"][j])

        # Relative errors of the predicted time and voltage resolution
        deviation_x.append(abs(sec_per_pixel_predicted - sec_per_pixel) / sec_per_pixel)
        deviation_y.append(abs(mV_per_pixel_predicted - mV_per_pixel) / mV_per_pixel)

        count_total += 1
        if (abs(pixels_per_grid_predicted - x_grid) / x_grid > threshold) or (
            abs(pixels_per_grid_predicted - y_grid) / y_grid > threshold
        ):
            print(
                f"File: {info['image_path'][j]}, x_grid: {x_grid}, y_grid: {y_grid}, pixels_per_grid_predicted: {pixels_per_grid_predicted}"
            )
            # Re-run the line detection on the image part given by
            # IMAGES_PARTS_FOR_GRID_PREDICTION and show the detected lines
            image_with_grid_lines = batch_dicts["image"][j]
            image_height = image_with_grid_lines.shape[1]
            image_width = image_with_grid_lines.shape[2]
            x_min = int(IMAGES_PARTS_FOR_GRID_PREDICTION[0] * image_width)
            y_min = int(IMAGES_PARTS_FOR_GRID_PREDICTION[1] * image_height)
            x_max = int(IMAGES_PARTS_FOR_GRID_PREDICTION[2] * image_width)
            y_max = int(IMAGES_PARTS_FOR_GRID_PREDICTION[3] * image_height)
            image_np = image_with_grid_lines.permute(1, 2, 0).numpy().astype(np.uint8)
            image_cropped_np = image_np[y_min:y_max, x_min:x_max]
            lines = get_lines(image_cropped_np, threshold_HoughLines=430)
            lines_filtered = filter_lines(lines, degree_window=5, parallelism_count=1)
            im_to_show = get_image_with_lines(image_cropped_np, lines_filtered)
            im_to_show.show()
        else:
            count_correct += 1

print(f"Correctly predicted: {count_correct} out of {count_total}")
mean_deviation_x = round(np.nanmean(deviation_x), 4)
mean_deviation_y = round(np.nanmean(deviation_y), 4)
print(
    f"Mean deviation sec_per_pixel: {mean_deviation_x}, mean deviation mV_per_pixel: {mean_deviation_y}"
)

### 1.3 Bounding box model

#### Train

In [None]:
# Get the model and weights
# The weights are selected by name; the call also returns the preprocessing
# transform that is handed to the data loader in the next step.
model_pretrained, preprocess = get_bbox_model(
    "FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1"
)

In [None]:
# Get the data
# Same splits as before, but with the bbox model's preprocessing as transform.
records, data, loader = dataloader_wrapper(
    list_of_paths=[data_training_path, data_vali_path, data_test_path],
    test_settings=[False, False, True],  # one flag per path
    shuffle_settings=[True, False, False],  # one flag per path
    transform=preprocess,
    run_in_parallel=False,
)

In [None]:
# Initiate trainer
# Only parameters that require gradients are handed to the optimizer.
params = [p for p in model_pretrained.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params, lr=LR_BBOX, momentum=MOMENTUM_BBOX, weight_decay=WEIGHT_DECAY_BBOX
)
# Step-wise learning-rate decay: multiply by GAMMA_BBOX every STEP_SIZE_BBOX steps
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=STEP_SIZE_BBOX, gamma=GAMMA_BBOX
)
trainer = Trainer(
    model=model_pretrained,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    num_epochs=NUM_EPOCHS_BBOX,
    device=DEVICE,
    target_transform=get_bbox_type_targets,
    input_transform=get_bbox_inputs,
    model_dir=model_folder_checkpoints,
    criterion=None,  # We will use the loss function from the model
    run_in_parallel=RUN_IN_PARALLEL,
)

In [None]:
# Train
# loader[0] is used for training, loader[1] for validation
trained_model = trainer.fit(training_dataloader=loader[0], vali_dataloader=loader[1])

In [None]:
# Save the model
# The "Apply" section below loads model_folder + "lead_bbox_detection.pth";
# presumably save_torch_model appends the ".pth" suffix -- TODO confirm.
save_torch_model(model_folder, trained_model, "lead_bbox_detection")

#### Apply

In [None]:
# Load the model
# A fresh model is built with box_score_thresh=0.01 and then filled with the
# fine-tuned weights stored in the model folder.
model, preprocess = get_bbox_model(box_score_thresh=0.01)
# map_location="cpu" lets a checkpoint that was saved on a GPU machine load on
# a CPU-only one; load_state_dict copies the values into the model's own tensors.
finetuned_weights = torch.load(
    model_folder + "lead_bbox_detection.pth", map_location="cpu"
)
model.load_state_dict(finetuned_weights)

In [None]:
# Get the data
# Same splits as before, with the loaded bbox model's preprocessing as transform.
records, data, loader = dataloader_wrapper(
    list_of_paths=[data_training_path, data_vali_path, data_test_path],
    test_settings=[False, False, True],  # one flag per path
    shuffle_settings=[True, False, False],  # one flag per path
    transform=preprocess,
    run_in_parallel=False,
)

In [None]:
# Predict
# Runs the bbox model over every batch and collects, per image, its path, the
# raw model output, the image height and the input tensor.
image_paths = []
outputs = []
image_heights = []
input_images = {}
model.eval()
test_loader = loader[1]  # Use vali_loader for now
for batch_number, batch_dict in enumerate(test_loader, start=1):
    print(f"Predicting batch {batch_number} of {len(test_loader)}")
    with torch.no_grad():
        images = batch_dict["image"]
        batch_output = model(images)
        # Separate index name: the sample index must not overwrite the batch counter
        for sample_idx, prediction in enumerate(batch_output):
            image_path = batch_dict["info_dict"]["image_path"][sample_idx]
            image_paths.append(image_path)
            outputs.append(prediction)
            image_heights.append(batch_dict["image"][sample_idx].shape[1])
            input_images[image_path] = batch_dict["image"][sample_idx]

In [None]:
# Reorder and get the bbox with the highest score
# Result: image path -> selected boxes, plus the image height under "image_height"
bbox_per_image = {}
for image_path, output, image_height in zip(image_paths, outputs, image_heights):
    selected_boxes = select_highest_scored_box(output)
    bbox_per_image[image_path] = selected_boxes
    selected_boxes["image_height"] = image_height

In [None]:
# Plot
# Draws the selected boxes onto the first num_masks_to_plot images.
num_masks_to_plot = 5
fig, ax = plt.subplots(num_masks_to_plot, 1, figsize=(10, 5 * num_masks_to_plot))
# zip stops after num_masks_to_plot images, or earlier if there are fewer
for axis, img_path in zip(ax, bbox_per_image):
    img = input_images[img_path]
    # Min-max scale to 0..255 so the boxes can be drawn on a uint8 image
    img = (255.0 * (img - img.min()) / (img.max() - img.min())).to(torch.uint8)
    img = img[:3, ...]

    # Boxes of the image being plotted ("image_height" is metadata, not a box)
    bboxes = {
        k: v for k, v in bbox_per_image[img_path].items() if k != "image_height"
    }

    pred_boxes = torch.tensor(list(bboxes.values())).long()
    pred_labels = [str(k) for k in bboxes]
    colors = ["red"] * len(pred_labels)

    output_image = draw_bounding_boxes(
        img,
        pred_boxes,
        pred_labels,
        colors=colors,
        width=2,
        font_size=40,
        font="arial.ttf",
    )

    # Channels-first -> channels-last for imshow
    axis.imshow(output_image.permute(1, 2, 0))
    axis.set_title(os.path.basename(img_path))
plt.show()

In [None]:
# Convert to json format
json_per_image = {
    image_path: bbox_prediction_to_json(boxes)
    for image_path, boxes in bbox_per_image.items()
}

# Save all the jsons
# Disabled by default: the predictions are only printed. When enabled, each
# prediction is merged into the json file next to its image (or creates it).
save_jsons = False
if not save_jsons:
    print(json_per_image)
else:
    for image_path, json_content in json_per_image.items():
        json_path = image_path.replace(".png", ".json")
        json_content_to_store = json_content
        if os.path.exists(json_path):
            with open(json_path, "r") as f:
                existing_json_content = json.load(f)
            existing_json_content.update(json_content)
            json_content_to_store = existing_json_content
        with open(json_path, "w") as f:
            f.write(json.dumps(json_content_to_store, indent=4))

## 2. Lead level

### 2.1 Segmentation

#### Prepare

In [None]:
# Set environment variables
# Export the three nnU-Net folder settings so that commands started from this
# session can see them.
os.environ.update(
    {
        "nnUNet_raw": NNUNET_RAW,
        "nnUNet_preprocessed": NNUNET_PREPROCESSED,
        "nnUNet_results": NNUNET_RESULTS,
    }
)

#### Train

In [None]:
# See README.md

#### Apply single signal

In [None]:
# Predict
# Shell command: nnU-Net inference for Dataset200_SingleSignals (fold 0, 2d
# configuration), written to data/nnUNet_output/Dataset200_SingleSignals.
# NOTE(review): the input folder is a hard-coded absolute path, while the
# comparison step reads the images from f"{ROOT}/data/ptb-xl/..." -- confirm
# both point to the same location.
! nnUNetv2_predict -d Dataset200_SingleSignals -i /data/wolf6245/data/ptb-xl/Dataset200_SingleSignals/imagesTv -o data/nnUNet_output/Dataset200_SingleSignals -f  0 -tr nnUNetTrainer -c 2d -p nnUNetPlans

In [None]:
# Postprocess
# Shell command: applies the stored fold-0 postprocessing (postprocessing.pkl
# and plans.json) to the raw predictions and writes the result to
# data/nnUNet_output_pp/Dataset200_SingleSignals, using 8 processes (-np 8).
! nnUNetv2_apply_postprocessing -i data/nnUNet_output/Dataset200_SingleSignals -o data/nnUNet_output_pp/Dataset200_SingleSignals -pp_pkl_file data/nnUNet_results/Dataset200_SingleSignals/nnUNetTrainer__nnUNetPlans__2d/crossval_results_folds_0/postprocessing.pkl -np 8 -plans_json data/nnUNet_results/Dataset200_SingleSignals/nnUNetTrainer__nnUNetPlans__2d/crossval_results_folds_0/plans.json

In [None]:
# Load all masks and compare
# Pairs every predicted mask with its input image. The input image carries the
# "_0000" channel suffix that the mask file name does not have.
image_folder = f"{ROOT}/data/ptb-xl/Dataset200_SingleSignals/imagesTv"
mask_folder = "data/nnUNet_output_pp/Dataset200_SingleSignals"
masks = os.listdir(mask_folder)
# Keep only .png files: any other file in the folder has no matching image
mask_names = sorted(name for name in masks if name.endswith(".png"))
image_paths = [
    os.path.join(image_folder, mask_name.replace(".png", "_0000.png"))
    for mask_name in mask_names
]
mask_paths = [os.path.join(mask_folder, mask_name) for mask_name in mask_names]

# The three lists are built from the same names, so comparing their lengths can
# never fail; check that every mask really has an input image instead.
missing_images = [p for p in image_paths if not os.path.exists(p)]
if missing_images:
    raise FileNotFoundError(
        f"Number of images and masks do not match: {len(missing_images)} "
        f"mask(s) without image, e.g. {missing_images[0]}"
    )

image_mask_paths = list(zip(image_paths, mask_paths))
random.shuffle(image_mask_paths)  # plot a random selection

In [None]:
# Plot
# One row per image/mask pair, limited to the first num_masks_to_plot pairs
num_masks_to_plot = 5
fig, ax = plt.subplots(num_masks_to_plot, 1, figsize=(10, 2 * num_masks_to_plot))
pairs_to_plot = image_mask_paths[:num_masks_to_plot]
for i, (img_path, mask_path) in enumerate(pairs_to_plot):
    example_img = read_image(img_path)
    example_mask = read_image(mask_path)
    img_name = os.path.basename(img_path)
    ax[i] = plot_image(ax[i], example_img, example_mask)
    ax[i].set_title(img_name)
plt.show()

#### Apply whole image version

In [None]:
# Predict
# Shell command: nnU-Net inference for Dataset300_FullImages (fold 0, 2d
# configuration), written to data/nnUNet_output/Dataset300_FullImages.
# NOTE(review): the input folder is a hard-coded absolute path, while the
# comparison step reads the images from f"{ROOT}/data/ptb-xl/..." -- confirm
# both point to the same location.
! nnUNetv2_predict -d Dataset300_FullImages -i /data/wolf6245/data/ptb-xl/Dataset300_FullImages/imagesTv -o data/nnUNet_output/Dataset300_FullImages -f  0 -tr nnUNetTrainer -c 2d -p nnUNetPlans

In [None]:
# Postprocess
# Shell command: applies the stored fold-0 postprocessing (postprocessing.pkl
# and plans.json) to the raw predictions and writes the result to
# data/nnUNet_output_pp/Dataset300_FullImages, using 8 processes (-np 8).
! nnUNetv2_apply_postprocessing -i data/nnUNet_output/Dataset300_FullImages -o data/nnUNet_output_pp/Dataset300_FullImages -pp_pkl_file data/nnUNet_results/Dataset300_FullImages/nnUNetTrainer__nnUNetPlans__2d/crossval_results_folds_0/postprocessing.pkl -np 8 -plans_json data/nnUNet_results/Dataset300_FullImages/nnUNetTrainer__nnUNetPlans__2d/crossval_results_folds_0/plans.json

In [None]:
# Load all masks and compare
# Pairs every predicted mask with its input image. The input image carries the
# "_0000" channel suffix that the mask file name does not have.
image_folder = f"{ROOT}/data/ptb-xl/Dataset300_FullImages/imagesTv"
mask_folder = "data/nnUNet_output_pp/Dataset300_FullImages"
masks = os.listdir(mask_folder)
# Keep only .png files: any other file in the folder has no matching image
mask_names = sorted(name for name in masks if name.endswith(".png"))
image_paths = [
    os.path.join(image_folder, mask_name.replace(".png", "_0000.png"))
    for mask_name in mask_names
]
mask_paths = [os.path.join(mask_folder, mask_name) for mask_name in mask_names]

# The three lists are built from the same names, so comparing their lengths can
# never fail; check that every mask really has an input image instead.
missing_images = [p for p in image_paths if not os.path.exists(p)]
if missing_images:
    raise FileNotFoundError(
        f"Number of images and masks do not match: {len(missing_images)} "
        f"mask(s) without image, e.g. {missing_images[0]}"
    )

image_mask_paths = list(zip(image_paths, mask_paths))
random.shuffle(image_mask_paths)  # plot a random selection

In [None]:
# Plot
# One row per image/mask pair, limited to the first num_masks_to_plot pairs
num_masks_to_plot = 50
fig, ax = plt.subplots(num_masks_to_plot, 1, figsize=(10, 5 * num_masks_to_plot))
pairs_to_plot = image_mask_paths[:num_masks_to_plot]
for i, (img_path, mask_path) in enumerate(pairs_to_plot):
    example_img = read_image(img_path)
    example_mask = read_image(mask_path)
    img_name = os.path.basename(img_path)
    ax[i] = plot_image(ax[i], example_img, example_mask)
    ax[i].set_title(img_name)
plt.show()

### 2.2 Vectorisation

#### Get example data

In [None]:
# Get the data
# Same splits as before, this time with single_signals=True and no transform.
records, data, loader = dataloader_wrapper(
    list_of_paths=[data_training_path, data_vali_path, data_test_path],
    test_settings=[False, False, True],  # one flag per path
    shuffle_settings=[True, False, False],  # one flag per path
    transform=None,
    single_signals=True,
    run_in_parallel=False,
)

In [None]:
# Get one example
batch_dict = next(iter(loader[1]))

# Get info for one image from batch
j = 0
info = batch_dict["info_dict"]

# Get example
image = batch_dict["image"][j]
mask = batch_dict["mask"][j]

# Get info
image_path = info["image_path"][j]
record_path = info["signal_path"][j]
sec_per_pixel = info["sec_per_pixel"][j]
mV_per_pixel = info["mV_per_pixel"][j]
lead_name = info["lead_name"][j]
# The original sizes are stored as three batched entries; pick sample j of each
original_size_image = tuple(
    info["original_size_image"][dim][j].item() for dim in range(3)
)
original_size_mask = tuple(
    info["original_size_mask"][dim][j].item() for dim in range(3)
)

# Re-resize the mask and image
# Entries 1 and 2 of the stored sizes are used as the target size
print(f"Old image shape: {image.shape}, mask shape: {mask.shape}")
image = resize(image, original_size_image[1:3])
mask = resize(mask, original_size_mask[1:3])
print(f"Now image shape: {image.shape}, mask shape: {mask.shape}")

In [None]:
# Crop the mask and the image to positive mask area
crop_to_mask = True
if crop_to_mask:
    # Order matters: the image must be cut with the still uncropped mask
    # before the mask itself is replaced by its cropped version.
    image = cut_to_mask(image, mask)
    mask = cut_to_mask(mask, mask)

#### Inspect

In [None]:
# Show masked image
# Channels-first tensor -> channels-last uint8 array for imshow
image_np = image.permute(1, 2, 0).numpy().astype(np.uint8)
# First mask channel only, squeezed to drop the channel axis
mask_np = mask[:1, :, :].squeeze().numpy().astype(np.uint8)
# Alias, not a copy: the overlay below also changes image_np
masked_image = image_np
# Paint every mask pixel green
masked_image[mask_np >= 1] = [0, 255, 0]
plt.imshow(masked_image)

#### Scale

In [None]:
# Get scaling info
# The mask width corresponds to the covered time, the mask height to the
# covered voltage range.
sec_per_box = sec_per_pixel * mask.shape[2]
mV_per_box = mV_per_pixel * mask.shape[1]
x_frequency = X_FREQUENCY
total_seconds = round(sec_per_pixel.item() * mask.shape[2], 1)
# round() before int(): a float product such as 2.3 * 100 evaluates to
# 229.99999999999997, which a plain int() would truncate to 229 samples.
values_needed = int(round(total_seconds * x_frequency))

# Get mask values
# For every pixel column take the mean row index of the non-zero mask pixels;
# a column without mask pixels yields NaN (mean of an empty tensor).
non_zero_mean = torch.tensor(
    [
        torch.mean(torch.nonzero(mask[0, :, i]).type(torch.float32))
        for i in range(mask.shape[2])
    ]
)

# y-scale by shifting
# The vertical centre of the mask is used as the zero line; rows above it
# become positive values.
zero_pixel = mask.shape[1] / 2
predicted_signal = (zero_pixel - non_zero_mean) * mV_per_pixel

# x-scale by interpolation
# F.interpolate expects (batch, channel, length), hence the temporary reshape
n = predicted_signal.shape[0]
data_reshaped = predicted_signal.view(1, 1, n)
resampled_data = F.interpolate(
    data_reshaped, size=values_needed, mode="linear", align_corners=False
)
predicted_signal_sampled = resampled_data.view(-1)
print(
    f"Predicted signal length: {predicted_signal.shape[0]}, interpolated signal length: {predicted_signal_sampled.shape[0]}"
)

#### Plot

In [None]:
# Load the original signal
print(f"Using signal from {record_path} with lead {lead_name} from image {image_path}")
label_signal, label_fields = load_signals(record_path)
# Keep only the channel of the lead shown in the example image
mask_signal = reorder_signal(label_signal, label_fields["sig_name"], [lead_name])

# Remove nan values from the signal
original_signal = torch.tensor(mask_signal[:, 0])
original_signal = original_signal[~torch.isnan(original_signal)]
print(
    f"Original signal length: {len(mask_signal)}, after removing nan values: {len(original_signal)}"
)

# Calc difference between corrected and original signal
# NOTE(review): this subtraction presumably needs both signals to have the same
# number of samples -- no length check is done here.
difference_signal = predicted_signal_sampled - original_signal

In [None]:
# Plot the signal, the original and the difference
# Five panels side by side, all with the same fixed y-range
y_min = -0.6
y_max = 0.6
fig, ax = plt.subplots(1, 5, figsize=(30, 5))
panels = [
    ("Original signal", [original_signal]),
    ("Predicted signal not sampled", [predicted_signal]),
    ("Predicted signal sampled", [predicted_signal_sampled]),
    (
        "Original and predicted sampled signal",
        [original_signal, predicted_signal_sampled],
    ),
    ("Difference signal", [difference_signal]),
]
for axis, (panel_title, curves) in zip(ax, panels):
    for curve in curves:
        axis.plot(curve)
    axis.set_ylim(y_min, y_max)
    # Only the overlay panel shows two curves and therefore needs a legend
    if len(curves) > 1:
        axis.legend(["Original", "Predicted sampled"])
    axis.set_title(panel_title)
plt.show()

## 3. run_model

Here we simulate the model inference on a single record.

### 3.1 Prep from run_model.py

In [None]:
import argparse
from run_model import *
from team_code import *

In [None]:
# Type
# True: OPTION 1 below (bounding boxes + per-lead segmentation);
# False: OPTION 2 below (direct segmentation of the whole image)
binary_signal_masks = False

# Settings to test
# Each flag replaces one predicted step with the stored ground truth
use_true_rot = False
use_true_scale = True
use_true_bbox = True
use_true_mask = False

# Print settings
verbose = True

# If we want to crop to mask
crop_to_mask = False

# If nans should be interpolated
interpolate_nans = False

In [None]:
# Args
# Stand-in for the command-line arguments of run_model.py. A plain Namespace
# is the container for hand-set attributes; ArgumentParser is the parser, not
# the result object.
args = argparse.Namespace(
    model_folder="model",
    output_folder="data/test_outputs",
)

# The test images depend on the kind of masks the pipeline works with
if binary_signal_masks:
    args.data_folder = f"{ROOT}/data/ptb-xl/Dataset400_BinarySignals/imagesTs"
else:
    args.data_folder = f"{ROOT}/data/ptb-xl/Dataset500_Signals/imagesTs_original"

In [None]:
# Prep
# Load both models, list the records in the data folder and make sure the
# output folder exists.
# NOTE(review): the second argument of load_models is presumably the verbose
# flag -- confirm against its signature.
digitization_model, classification_model = load_models(args.model_folder, True)
records = find_records(args.data_folder)
num_records = len(records)
os.makedirs(args.output_folder, exist_ok=True)

In [None]:
# Simulate one record from loop
i = 2  # Record: Select one record
j = 1  # Image: Select first image, but there should also only be one per signal
# NOTE(review): with zero-based indexing j = 1 is the second entry, which does
# not match "first image" / "only one per signal" above -- confirm the index.
data_record = os.path.join(args.data_folder, records[i])
output_record = os.path.join(args.output_folder, records[i])
if verbose:
    print(f"Using record {data_record}")

### 3.2 run_models(data_record, digitization_model, classification_model, args.verbose) -> signals, labels

#### Prepare

In [None]:
# Convert
# Rebinds the notebook variables to the parameter names listed in the section
# heading, so the following steps read like the body of run_models.
record = data_record
digitization_model = digitization_model  # no-op, kept to mirror the signature
classification_model = classification_model  # no-op, kept to mirror the signature
# Read the record header and the signal layout it declares
header_file = get_header_file(record)
header = load_text(header_file)
num_samples = get_num_samples(header)
num_signals = get_num_signals(header)
signal_names = get_signal_names(header)
verbose = True  # overrides the value chosen in the settings step

In [None]:
# Load image
path = os.path.split(record)[0]
image_files = get_image_files(record)
image_file = image_files[j]
image_file_path = os.path.join(path, image_file)
# Keep the first three channels only
image = read_image(image_file_path)[:3]
if verbose:
    print(f"Using image {image_file_path}")
    plt.imshow(image.permute(1, 2, 0).numpy())

# Only for testing, load the json
# Ground truth is needed as soon as at least one step uses true values
needs_ground_truth = use_true_rot or use_true_scale or use_true_bbox or use_true_mask
if needs_ground_truth:
    label_signal, label_fields = load_signals(record)
    json_dict, _ = load_json(record)
    json_dict = json_dict[j]
    # Derive the mask path from the image path: labels folder, no channel
    # suffix, no "_original" marker
    mask_file_path = image_file_path.replace("/imagesT", "/labelsT")
    mask_file_path = mask_file_path.replace("_0000.png", ".png")
    mask_file_path = mask_file_path.replace("_original", "")
    mask = read_image(mask_file_path)
    if verbose:
        print("Loaded json and mask.")
        print("Mask distribution:")
        print(pd.Series(mask.numpy().flatten()).value_counts())

#### Rotate

In [None]:
# Rotate
# Take the angle from the ground truth or estimate it from the image, then
# rotate the image by that angle.
if use_true_rot:
    rot_source = "true"
    rot_angle = json_dict["rotate"]
else:
    rot_source = "predicted"
    rot_angle = get_rotation_angle(image.permute(1, 2, 0).numpy().astype(np.uint8))
if verbose:
    print(f"Using {rot_source} rotation angle: {rot_angle}")

image_rotated = rotate(image, rot_angle)
if verbose:
    plt.imshow(image_rotated.permute(1, 2, 0).numpy())

#### OPTION 1: Bounding boxes and individual signal segmentation

In [None]:
# OPTION 1, step 1: get one bounding box per lead, either from the ground
# truth json or from the bbox model. Skipped unless binary_signal_masks is set.
if binary_signal_masks:
    # Get bounding box
    if use_true_bbox:
        lead_bounding_box = json_dict["lead_bounding_box"]
        lead_bounding_box_filtered = filter_for_full_lead(
            lead_bounding_box,
            json_dict["full_mode_lead"]["val"],
            box_type="lead_bounding_box",
            image_path=image_file,
        )
        if verbose:
            print("Using true bboxes")
            print(
                f"Found boxes for {[b['lead_name'] for b in lead_bounding_box_filtered]}"
            )
    else:
        # Predict
        model = digitization_model["bbox_model"]["model"]
        transform = digitization_model["bbox_model"]["preprocess"]
        model.eval()
        image_rotated_transformed = transform(image_rotated)
        with torch.no_grad():
            # The model takes a list of images; there is only one here
            batch_output = model([image_rotated_transformed])
            output = batch_output[0]
            image_height = image_rotated.shape[1]

        # Reorder and get the bbox with the highest score
        bbox = select_highest_scored_box(output)
        bbox = convert_box_to_integer(bbox)
        bbox["image_height"] = image_height
        # Bring the prediction into the same list-of-boxes form as the true boxes
        lead_bounding_box_filtered = bbox_prediction_to_json(bbox)
        lead_bounding_box_filtered = lead_bounding_box_filtered[
            "lead_bounding_box_predicted"
        ]

        if verbose:
            print(
                f"Using predicted bboxes. Found boxes for {[b['lead_name'] for b in lead_bounding_box_filtered]}"
            )

    if verbose:
        # NOTE(review): json_dict is read here even when the boxes were
        # predicted; it is only loaded if at least one use_true_* flag is set.
        dict_test = {
            "image": image_rotated,
            "info_dict": {
                "lead_bounding_box": lead_bounding_box_filtered,
                "full_mode_lead": json_dict["full_mode_lead"]["val"],
                "image_path": image_file_path,
            },
        }
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax = plot_image_with_torch(ax, dict_test)

In [None]:
# OPTION 1, step 2: cut one image per lead box and get a mask for each of
# them. Skipped unless binary_signal_masks is set.
if binary_signal_masks:
    # Segment
    # Split images
    signal_images = {}
    signal_positions = {}
    for box in lead_bounding_box_filtered:
        # With the third argument set, the helper also returns y1, which is
        # stored as the position of the cut-out
        signal_img, y1 = cut_image_to_bbox(image_rotated, box, True)
        signal_positions[box["lead_name"]] = y1
        signal_images[box["lead_name"]] = signal_img
    if verbose:
        print(f"Split images for {signal_images.keys()}")

    # Run segmentation
    signal_masks = {}
    if use_true_mask:
        # Cut the ground-truth mask with the same boxes as the image
        for box in lead_bounding_box_filtered:
            signal_mask = cut_image_to_bbox(mask, box)
            signal_masks[box["lead_name"]] = signal_mask
        if verbose:
            print("Using true masks.")
    else:
        for lead, img in signal_images.items():  # TODO: Parallelise this
            # Predict at IMG_SIZE, then resize the mask back to the cut-out size
            image_rotated_resized = resize(img, IMG_SIZE)
            signal_mask_predicted_resized = predict_mask_nnunet(
                image_rotated_resized, "Dataset200_SingleSignals"
            )
            signal_mask_predicted = resize(
                signal_mask_predicted_resized, (img.shape[1], img.shape[2])
            )
            signal_masks[lead] = signal_mask_predicted
        if verbose:
            print("Using predicted masks.")

#### OPTION 2: Direct signal segmentation

In [None]:
# Segment
# OPTION 2: one mask for the whole image, taken from the ground truth or
# predicted, then split into per-lead masks, positions and images.
if not binary_signal_masks:
    if use_true_mask:
        if verbose:
            print("Using true, multiclass masks.")
        mask_to_use = mask
    else:
        if verbose:
            print("Using predicted, multiclass masks.")
        mask_to_use = predict_mask_nnunet(
            image_rotated, f"Dataset{X_FREQUENCY}_Signals"
        )

    # Use mask to cut into single, binary masks
    signal_masks, signal_positions, signal_images = cut_binary(mask_to_use, image_rotated, signal_names)
    if verbose:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax = plot_image(ax, image_rotated, mask_to_use)
        ax.set_title(image_file_path)
        plt.show()

#### Crop to mask

In [None]:
# Crop the mask and the image to positive mask area
# TODO: What to do if mask misses values at the beginning or at the end?
if crop_to_mask:
    signal_images_cropped = {}
    signal_masks_cropped = {}
    signal_positions_cropped = {}
    for lead, lead_image in signal_images.items():
        lead_mask = signal_masks[lead]
        # With the third argument set, the helper also returns y1, which is
        # added to the stored position of the lead
        signal_images_cropped[lead], y1 = cut_to_mask(lead_image, lead_mask, True)
        signal_masks_cropped[lead] = cut_to_mask(lead_mask, lead_mask)
        signal_positions_cropped[lead] = signal_positions[lead] + y1
else:
    # No cropping: the "cropped" names refer to the very same dictionaries
    signal_images_cropped = signal_images
    signal_masks_cropped = signal_masks
    signal_positions_cropped = signal_positions

#### Vectorisation

In [None]:
# Get scaling info
# Either derive seconds / millivolts per pixel from the true grid sizes in the
# json, or estimate one grid size from the rotated image and use it for both axes.
if use_true_scale:
    x_grid = json_dict["x_grid"]
    y_grid = json_dict["y_grid"]
    mm_per_pixel_x = get_mm_per_pixel(x_grid)
    mm_per_pixel_y = get_mm_per_pixel(y_grid)
    sec_per_pixel = get_sec_per_pixel(mm_per_pixel_x)
    mV_per_pixel = get_mV_per_pixel(mm_per_pixel_y)
    if verbose:
        print(
            f"Using true scaling info: x_grid: {x_grid}, y_grid: {y_grid}, sec_per_pixel: {round(sec_per_pixel,4)}, mV_per_pixel: {round(mV_per_pixel,4)}"
        )
else:
    pixels_per_grid, sec_per_pixel, mV_per_pixel = get_grid_info(image_rotated)
    x_grid = pixels_per_grid
    y_grid = pixels_per_grid
    if verbose:
        # Same fields and labels as the message of the true-scale branch
        print(
            f"Predicted scaling info: x_grid: {x_grid}, y_grid: {y_grid}, sec_per_pixel: {round(sec_per_pixel,4)}, mV_per_pixel: {round(mV_per_pixel,4)}"
        )

In [None]:
# Print
# Shows every lead image with its mask painted green and keeps the painted
# images in example_img_masked_dict for the signal plot further down.
if verbose:
    print("Plotting image and masks.")
    example_img_masked_dict = {}
    num_leads = len(signal_images_cropped)
    # squeeze=False keeps ax two-dimensional, so indexing also works for one lead
    fig, ax = plt.subplots(
        num_leads, 1, figsize=(10, 2 * num_leads), squeeze=False
    )
    for row, lead in enumerate(signal_images_cropped):
        example_img = signal_images_cropped[lead]
        example_mask = signal_masks_cropped[lead]
        if isinstance(example_img, torch.Tensor):
            # Channels-first tensor -> channels-last uint8 array
            example_img = example_img.permute(1, 2, 0).numpy().astype(np.uint8)
        if isinstance(example_mask, torch.Tensor):
            example_mask = example_mask[:1, :, :].squeeze().numpy().astype(np.uint8)
        example_img_masked = example_img
        example_img_masked[example_mask >= 1] = [0, 255, 0]
        example_img_masked_dict[lead] = example_img_masked
        ax[row, 0].imshow(example_img_masked)
        ax[row, 0].set_title(lead)

    # Lead II serves as the example; skip the summary if it is not available
    if "II" in signal_masks_cropped:
        mask_lead_ii = signal_masks_cropped["II"]
        if isinstance(mask_lead_ii, torch.Tensor):
            mask_lead_ii = mask_lead_ii.numpy()
        print("Mask distribution of example mask for lead II:")
        print(pd.Series(mask_lead_ii.flatten()).value_counts())

In [None]:
# Vectorise
# Turn every lead mask into a signal. The loop variable is deliberately not
# called "mask": that name holds the ground-truth mask of the whole image and
# must survive this step so the segmentation step can be re-run.
signals_predicted = {}
for lead, lead_mask in signal_masks_cropped.items():
    signals_predicted[lead] = vectorise(
        image_rotated,
        lead_mask,
        signal_positions_cropped[lead],
        sec_per_pixel,
        mV_per_pixel,
        Y_VALUES_PER_LEAD["full"],
        Y_VALUES_PER_LEAD[lead],
        num_samples,
        interpolate_nans=interpolate_nans,
    )

In [None]:
# Plot
# One row per lead: original vs. predicted signal, the masked image and the
# difference signal with its SNR. Needs the ground truth, so it only runs when
# at least one use_true_* flag is set.
if verbose and any([use_true_rot, use_true_scale, use_true_bbox, use_true_mask]):
    num_leads = len(signals_predicted)
    # squeeze=False keeps ax two-dimensional, so ax[i, col] also works for one lead
    fig, ax = plt.subplots(
        num_leads, 3, figsize=(18, 4 * num_leads), squeeze=False
    )
    snr = []
    for i, (lead_name, predicted_signal_sampled) in enumerate(
        signals_predicted.items()
    ):
        # Get data
        mask_signal = reorder_signal(
            label_signal, label_fields["sig_name"], [lead_name]
        )
        original_signal = torch.tensor(mask_signal[:, 0])
        original_signal = original_signal[~torch.isnan(original_signal)]

        # Calc difference
        if len(original_signal) != len(predicted_signal_sampled):
            raise ValueError("Lengths of original and predicted signal do not match.")
        difference = predicted_signal_sampled - original_signal
        signal_snr = compute_snr(original_signal, predicted_signal_sampled)

        # Plot
        masked_img = example_img_masked_dict[lead_name]
        snr.append(signal_snr)
        ax[i, 0].plot(original_signal)
        ax[i, 0].plot(predicted_signal_sampled)
        ax[i, 0].legend(["Original", "Predicted"])
        ax[i, 0].set_title(f"{lead_name}: Original and predicted signal")
        ax[i, 1].imshow(masked_img)
        ax[i, 1].set_title(f"{lead_name}: Masked image")
        ax[i, 2].plot(difference)
        ax[i, 2].set_title(
            f"{lead_name}: Difference signal (SNR: {round(signal_snr,2)})"
        )
    fig.suptitle(f"Average SNR: {round(np.mean(snr),2)}")
    fig.tight_layout(pad=4.3)
    plt.show()

#### Prepare saving

In [None]:
# Save correct window
# Leads shorter than the record are placed at their start offset inside a
# NaN-filled trace of full length; full-length leads are taken as they are.
signal_list = []
for signal_name in signal_names:
    lead_signal = signals_predicted[signal_name].numpy()
    if len(lead_signal) >= num_samples:
        signal_list.append(lead_signal)
        continue
    nan_signal = np.full(num_samples, np.nan)
    # SIGNAL_START is scaled by num_samples / 10 to get the start sample
    signal_start = SIGNAL_START[signal_name] * num_samples / 10
    nan_signal[int(signal_start):int(signal_start + len(lead_signal))] = lead_signal
    signal_list.append(nan_signal)

# Transpose
# From one row per lead to one column per lead
signal = np.array(signal_list).T

if verbose:
    print(f"Signal shape: {signal.shape} (should be (5000, 12))")

In [None]:
# Prep labels
# Classify the record and keep every class that reaches the highest probability.
model = classification_model["model"]
classes = classification_model["classes"]
features = extract_features(record)
features = features.reshape(1, -1)  # a single sample as one row
probabilities = model.predict_proba(features)
# NOTE(review): [:, 0, 1] presumably selects, per class, the positive-class
# probability of the single sample -- confirm against the classifier type.
probabilities = np.asarray(probabilities, dtype=np.float32)[:, 0, 1]
max_probability = np.nanmax(probabilities)  # NaN entries are ignored
# Ties are kept: every class with the maximum probability becomes a label
labels = [
    classes[i]
    for i, probability in enumerate(probabilities)
    if probability == max_probability
]

### 3.3 Save

In [None]:
# Save
# Write header, signals and labels of the record into the output folder
signals = signal
output_path = os.path.dirname(output_record)
os.makedirs(output_path, exist_ok=True)
data_header = load_header(data_record)
save_header(output_record, data_header)
# Pass the header's comment lines on to the saved signals
comments = [line for line in data_header.split("\n") if line.startswith("#")]
save_signals(output_record, signals, comments)
save_labels(output_record, labels)

### 3.4 Check if correctly saved for evaluation

In [None]:
# Test evaluation
# Reload the saved output next to the original record and bring it into the
# same channel order and length before comparing.
input_record = data_record
print(f"Comparing {input_record} and {output_record}")

# Load the signals
input_signal, input_fields = load_signals(input_record)
input_channels = input_fields["sig_name"]
input_num_samples = input_fields["sig_len"]
output_signal, output_fields = load_signals(output_record)
output_channels = output_fields["sig_name"]
channels = input_channels

# Reorder
# Output channels are arranged in the order of the input channels
output_signal = reorder_signal(output_signal, output_channels, input_channels)

# Trim official
output_signal_trimmed = trim_signal(output_signal, input_num_samples)

# Replace nan with 0
output_signal_trimmed[np.isnan(output_signal_trimmed)] = 0

In [None]:
# Calculate snr
# One value per channel, comparing matching columns of input and output
snr = [
    compute_snr(input_signal[:, col], output_signal_trimmed[:, col])
    for col in range(len(channels))
]

# Average over the channels, ignoring NaN values
mean_snr = np.nanmean(snr)
print(f"Mean SNR: {mean_snr}")

In [None]:
# Plot
# One row per channel: overlay of both signals (left), their difference (right)
fig, ax = plt.subplots(len(channels), 2, figsize=(12, 4 * len(channels)))
for j, channel in enumerate(channels):
    reference = input_signal[:, j]
    prediction = output_signal_trimmed[:, j]
    overlay_axis = ax[j, 0]
    difference_axis = ax[j, 1]

    overlay_axis.plot(reference)
    overlay_axis.plot(prediction)
    overlay_axis.set_xlim(0, 5000)
    overlay_axis.legend(["Original", "Predicted"])
    overlay_axis.set_title(f"{channel}: Original and predicted signal")

    difference_axis.plot(reference - prediction)
    difference_axis.set_xlim(0, 5000)
    difference_axis.set_title(f"{channel}: Difference signal (SNR: {round(snr[j],2)})")

record_name = os.path.split(data_record)[1]
fig.suptitle(f"Record: {record_name}, Average SNR: {round(mean_snr,2)}")
fig.tight_layout(pad=4.3)
plt.show()

## 4. evaluate_model

In [None]:
from evaluate_model import *

In [None]:
# Prep
# Stand-in for the command-line arguments of evaluate_model.py. A plain
# Namespace is the container for hand-set attributes; ArgumentParser is the
# parser, not the result object.
args = argparse.Namespace(
    input_folder="data/test_inputs",
    output_folder="data/test_outputs",
    score_file="data/evaluation/scores.csv",
    extra_scores=False,
)

In [None]:
# Compute scores and unpack them
# Compares the records in the input folder with those in the output folder.
scores = evaluate_model(args.input_folder, args.output_folder, args.extra_scores)
# The result is unpacked into six values; only snr and f_measure are reported
(
    snr,
    snr_median,
    ks_metric,
    asci_metric,
    mean_weighted_absolute_difference_metric,
    f_measure,
) = scores

In [None]:
# Convert
# Unpack the six scores and build the text that is printed and saved
(snr, snr_median, ks_metric, asci_metric,
 mean_weighted_absolute_difference_metric, f_measure) = scores
score_lines = [f"SNR: {snr:.3f}\n", f"F-measure: {f_measure:.3f}\n"]
output_string = "".join(score_lines)
print(output_string)

In [None]:
# Save
# Write the scores to the score file if one is configured, otherwise print them.
if args.score_file:
    score_folder = os.path.split(args.score_file)[0]
    # A bare file name has no folder part, and os.makedirs("") would raise
    if score_folder:
        os.makedirs(score_folder, exist_ok=True)
    save_text(args.score_file, output_string)
else:
    print(output_string)