In [None]:
import cv2
import torch
import numpy as np
import math

from tqdm.notebook import tqdm

import helper_pytorch as H

from dataset import Dataset
from post_processing import post_processing

In [None]:
# Identity of the run whose checkpoint directory is evaluated in this notebook.
run_name = "uacanet"
run_id = 0

# Every artefact path is derived from the run identity above.
ckpt_path = f"./ckpts/{run_name}{run_id}"
model_valid_path = f"{ckpt_path}/model_valid.pt"
history_path = f"{ckpt_path}/history.csv"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

### Make model

In [None]:
# Architecture under evaluation. It must match the architecture that produced
# the checkpoint at `model_valid_path`. The commented-out constructors are the
# alternative architectures from `H`, kept so the notebook can be pointed at a
# different run.
# model = H.UNet(
#     capacity=64,
#     n_classes=4,
#     n_channels=1
# )
# model = H.resnet18(
#     capacity=64,
#     n_classes=4,
#     in_channels=1
# )
# model = H.CE_Net_(
#     num_channels=1,
#     num_classes=4
# )
model = H.UACANet(
    n_channels=1,
    n_classes=4,
    pretrained=False,  # the weights are loaded from the checkpoint below
)
# model = H.segmenter(
#     img_height=352,
#     img_width=math.ceil(1250*np.pi/16)*16,
# )

model.to(device)

# This notebook only runs inference: evaluation mode makes layers such as
# dropout and batch normalisation behave deterministically.
# NOTE(review): assumes the model's forward output has the same structure in
# eval mode as in train mode — confirm against helper_pytorch.
model.eval()

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when this notebook falls back to the CPU.
# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(torch.load(model_valid_path, map_location=device)['model'])

### Plot history

In [None]:
# Plot the history stored at `history_path` (the run's history.csv).
# NOTE(review): the meaning of the second argument (3) is defined inside
# H.plot_history and is not visible here — confirm what it selects.
H.plot_history(history_path, 3)

### Make test dataset

In [None]:
# Test split, requested with do_transform=False.
test_dataset = Dataset(split='test', do_transform=False)

In [None]:
# Sanity check: take one test sample, add a leading axis of size 1 (the same
# preparation the prediction cells apply before calling the model) and show
# the resulting shape.
sample_img, gt = test_dataset[2]
img = torch.unsqueeze(sample_img, 0)
print(img.shape)

### Predict

In [None]:
# One entry per test image; each entry is whatever H.dice_np returns for
# n_classes=4 on that image.
results = []

for i, img_gt in enumerate(tqdm(test_dataset)):
    # Only the image is taken from the dataset sample; the ground truth is
    # re-read from disk below.
    img = torch.unsqueeze(img_gt[0], 0).to(device)

    # Inference only: no_grad stops autograd from recording a graph for every
    # forward pass, which would otherwise cost memory for nothing.
    with torch.no_grad():
        pred = model(img)

    # First element of the output, reduced with argmax over axis 0.
    # NOTE(review): assumes the model output is (batch, classes, H, W), so this
    # yields a per-pixel class index map — confirm.
    pred = pred[0].detach().cpu().numpy().argmax(axis=0)
    pred = post_processing(pred)

    gt_path = f"./data/test/gts/{test_dataset.files[i]}"
    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    if gt is None:
        # cv2.imread reports an unreadable/missing file by returning None
        # rather than raising, which would otherwise surface as an obscure
        # TypeError on the assignments below.
        raise FileNotFoundError(f"Could not read ground-truth mask: {gt_path}")

    # Remap stored mask values onto class indices.
    # NOTE(review): presumably 11 and 9 are the raw label values for classes
    # 3 and 2 in the mask files — confirm against the data specification.
    gt[gt==11] = 3
    gt[gt==9] = 2

    dscs = H.dice_np(
        gts=gt,
        preds=pred,
        n_classes=4
    )
    results.append(dscs)

    # Drop per-iteration references so device tensors are not kept alive.
    del img, img_gt, pred

# Swap the axes so the first axis indexes the score and the second the image.
results = np.array(results).transpose(1, 0)

In [None]:
# Average every row of `results` over its second axis (the test images).
mean_scores = results.mean(axis=1)
print(mean_scores)

In [None]:
# Pick three representative test images from row 3 of `results`: the one whose
# score is closest to the row mean, the best-scoring one and the worst-scoring
# one.
tracked_scores = results[3]
distance_to_mean = np.abs(tracked_scores - tracked_scores.mean())

indexes = dict(
    closest_to_mean_index=distance_to_mean.argmin(axis=0),
    max_index=tracked_scores.argmax(),
    min_index=tracked_scores.argmin(),
)

In [None]:
# Re-run the model on each selected test image and plot the image, ground truth
# and prediction side by side, with the class-3 score in the figure title.
for item in indexes.values():
    img, gt = test_dataset[item]
    img = torch.unsqueeze(img, 0).to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred = model(img)

    # First element of the output, reduced with argmax over axis 0.
    # NOTE(review): assumes the model output is (batch, classes, H, W) — confirm.
    pred = pred[0].detach().cpu().numpy().argmax(axis=0)
    pred = post_processing(pred)

    # The ground truth and the image are re-read from disk for plotting; the
    # tensors returned by the dataset are only used as the model input.
    gt_path = f"./data/test/gts/{test_dataset.files[item]}"
    img_path = f"./data/test/images/{test_dataset.files[item]}"

    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    if gt is None:
        # cv2.imread returns None instead of raising when a file cannot be read.
        raise FileNotFoundError(f"Could not read ground-truth mask: {gt_path}")

    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # Scale 8-bit pixel values into [0, 1] for display.
    img = img/255

    # Remap stored mask values onto class indices.
    # NOTE(review): presumably 11 and 9 are the raw label values for classes
    # 3 and 2 in the mask files — confirm against the data specification.
    gt[gt==11] = 3
    gt[gt==9] = 2

    dscs = H.dice_np(
        gts=gt,
        preds=pred,
        n_classes=4
    )

    H.plot_single_data(
        img,
        gt,
        pred,
        monitor_class=3,
        figsize=(30, 10),
        opacity=0.2,
        suptitle=f"dsc: {dscs[3]}"
    )


    

In [None]:
# Release unused memory held by PyTorch's CUDA caching allocator now that
# evaluation is finished.
torch.cuda.empty_cache()