In [None]:
from oxford_pets_train_script import *
import torchvision.transforms.functional as F
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.io import read_image
from pathlib import Path
import pickle
from PIL import ImageFont
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision.ops import box_iou
from torchvision.transforms.functional import convert_image_dtype

# Drawing of bounding boxes and segmentation masks follows the torchvision
# visualization-utilities example:
# https://pytorch.org/vision/0.10/auto_examples/plot_visualization_utils.html

In [None]:
# Loss histories of the two runs: `a`/`b` are the augmented run's test/train
# losses, `c`/`d` the baseline run's test/train losses (per the file names).
# NOTE: unpickling can execute arbitrary code — only load files you produced.
a = pickle.loads(Path("aug_test_losses.pkl").read_bytes())
b = pickle.loads(Path("aug_train_losses.pkl").read_bytes())
c = pickle.loads(Path("base_test_losses.pkl").read_bytes())
d = pickle.loads(Path("base_train_losses.pkl").read_bytes())

In [None]:
# Gap between the lowest baseline test loss and the lowest augmented test
# loss; a positive value means the augmented run reached a lower minimum.
lowest_base_test_loss = np.min(c)
lowest_aug_test_loss = np.min(a)
print(lowest_base_test_loss - lowest_aug_test_loss)

In [None]:
# Train/test loss curves of the augmented and baseline runs on one axes.
fig, ax = plt.subplots()
for series, series_label in (
    (a, "Aug test losses"),
    (b, "Aug train losses"),
    (c, "Base test losses"),
    (d, "Base train losses"),
):
    ax.plot(range(len(series)), series, label=series_label)
# NOTE(review): the x-axis label assumes one recorded loss per epoch —
# confirm against the training script that wrote the pickles.
ax.set(xlabel="Epoch", ylabel="Loss", title="Train/test loss: augmented vs. baseline")
ax.legend()
plt.show()

In [None]:
def show(imgs):
    """Display one image tensor, or a list of them, side by side in one row.

    Each tensor is detached and converted to a PIL image with
    ``torchvision.transforms.functional.to_pil_image`` before plotting; ticks
    and tick labels are hidden on every panel.

    Args:
        imgs: a single image tensor or a list of image tensors.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False, figsize=(15, 10))
    for col, tensor in enumerate(images):
        pil_image = F.to_pil_image(tensor.detach())
        panel = axes[0, col]
        panel.imshow(np.asarray(pil_image))
        panel.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


# Inverse of Normalize(mean=m, std=s), i.e. x -> x * s + m, done in two steps:
# Normalize(mean=0, std=1/s) multiplies by s, then Normalize(mean=-m, std=1)
# adds m. m and s are the standard ImageNet channel statistics — presumably
# the ones the training transforms used; confirm against oxford_pets_train_script.
_channel_mean = np.array([0.485, 0.456, 0.406])
_channel_std = np.array([0.229, 0.224, 0.225])
mean = -_channel_mean
std = 1 / _channel_std
inverse_norm = transforms.Compose(
    [
        transforms.Normalize(mean=[0.0] * 3, std=std),  # x * s
        transforms.Normalize(mean=mean, std=[1] * 3),  # ... + m
    ]
)


def draw_bb(normed_img_data, boxes):
    """Undo the input normalization of an image tensor and draw boxes on it.

    Args:
        normed_img_data: a normalized float image tensor as fed to the models
            (presumably ``(3, H, W)`` — confirm against the dataset).
        boxes: boxes accepted by ``torchvision.utils.draw_bounding_boxes``
            (``(N, 4)`` in ``(xmin, ymin, xmax, ymax)`` format).

    Returns:
        A ``uint8`` image tensor on the CPU with the boxes drawn on it.
    """
    reg_img = inverse_norm(normed_img_data)
    # convert_image_dtype assumes floats in [0, 1] and does not clamp, so clamp
    # first to keep values that drift during de-normalization in range.
    reg_img = convert_image_dtype(reg_img.clamp(0.0, 1.0), dtype=torch.uint8).to("cpu")
    return draw_bounding_boxes(reg_img, boxes)

In [None]:
# Build the datasets and loaders; `ts` serves as the test split below.
fs, ds, ts = gen_dataset()
# Reverse lookup class id -> breed name, used to label drawn predictions.
id_to_breed = {v: k for k, v in fs.breed_assoc.items()}
dl, tl = gen_loaders(ds, ts, batch_size=1)
test_dataset = ts
# The checkpoints hold whole pickled modules (they are called directly below).
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True, which
# rejects whole pickled modules; pass weights_only=False there (trusted files only).
base_model = torch.load("best_base_model", map_location=device).to(device)
aug_model = torch.load("best_aug_model", map_location=device).to(device)
# Example sample indices (breeds per the variable names — not verified here).
pitbull_idx = 1000
havanese_idx = 2000
tdata, tlabel = fs.get_unchanged_pic(havanese_idx)

In [None]:
def gen_mask_and_box_for_model_and_data(model, img_tensor, label):
    """Run ``model`` on one image and render its first detection next to the ground truth.

    Args:
        model: a detection/segmentation model called as ``model([img_tensor])``
            whose output dicts carry ``"labels"``, ``"boxes"`` and ``"masks"``.
        img_tensor: one normalized image tensor (inverted with ``inverse_norm``).
        label: the ground-truth target dict with ``"labels"``, ``"boxes"`` and
            ``"masks"``.

    Returns:
        ``(pred_boxes_img, label_boxes_img, masks_img)``, three ``uint8`` CPU images:
        the first predicted box (green if its class matches the first ground-truth
        class, red otherwise), the first ground-truth box in white, and the first
        predicted mask (green) overlaid with the first ground-truth mask (red).

    Raises:
        Exception: if the model produced no detections for the image.
    """
    font_path = "/usr/share/fonts/truetype/liberation/LiberationMono-Italic.ttf"
    font_size = 14
    model.eval()
    with torch.no_grad():
        preds = model([img_tensor])
    top = preds[0]

    if top["labels"].numel() == 0:
        raise Exception("Model produced no detections for this image")

    # convert_image_dtype assumes floats in [0, 1] and does not clamp.
    reg_image = convert_image_dtype(
        inverse_norm(img_tensor.clone()).clamp(0.0, 1.0), dtype=torch.uint8
    ).to("cpu")

    true_id = label["labels"][0].item()
    pred_id = top["labels"][0].item()
    animal_name = id_to_breed[true_id]
    pred_name = id_to_breed[pred_id]

    # Only the first detection (presumably the highest-scoring one) is drawn;
    # green marks a correct class, red a wrong one.
    reg_pred_image_w_boxes = draw_bounding_boxes(
        reg_image,
        top["boxes"][:1],
        colors=["green" if pred_id == true_id else "red"],
        labels=[pred_name],
        font=font_path,
        font_size=font_size,
    )
    reg_label_image_w_boxes = draw_bounding_boxes(
        reg_image,
        label["boxes"][:1],
        colors=["white"],
        labels=[animal_name],
        font=font_path,
        font_size=font_size,
    )

    # Binarize at 0.5 and move to the CPU: the masks are used to index
    # `reg_image`, which lives on the CPU, so they must share its device.
    pred_mask = top["masks"][0].to("cpu") > 0.5
    true_mask = label["masks"][0].to("cpu") > 0.5
    reg_image_w_mask = draw_segmentation_masks(
        reg_image, masks=pred_mask, alpha=0.7, colors=["green"]
    )
    reg_image_w_mask = draw_segmentation_masks(
        reg_image_w_mask, masks=true_mask, alpha=0.4, colors=["red"]
    )
    return reg_pred_image_w_boxes, reg_label_image_w_boxes, reg_image_w_mask

In [None]:
# Render the base model's prediction, the ground truth and the mask overlay
# for the sample image.
pred_boximg, label_boximg, maskimg = gen_mask_and_box_for_model_and_data(
    model=base_model, img_tensor=tdata, label=tlabel
)

In [None]:
# Predicted box | ground-truth box | mask overlay, side by side.
show(imgs=[pred_boximg, label_boximg, maskimg])

In [None]:
# Same rendering for the augmented model on the same sample image.
pred_boximg, label_boximg, maskimg = gen_mask_and_box_for_model_and_data(
    model=aug_model, img_tensor=tdata, label=tlabel
)

In [None]:
# Predicted box | ground-truth box | mask overlay, side by side.
show(imgs=[pred_boximg, label_boximg, maskimg])

In [None]:
def eval_model(model, dataset=None):
    """Run ``model`` on every sample of a dataset and collect ``(output, target)`` pairs.

    Args:
        model: a model called as ``model([x])`` on one sample at a time.
        dataset: an iterable of ``(x, y)`` pairs that supports ``len``; defaults to
            the test split ``ts`` built earlier in the notebook.

    Returns:
        A list of ``(output, y)`` tuples, one per sample, in dataset order.
    """
    if dataset is None:
        dataset = ts
    preds = []
    model.eval()
    with torch.no_grad():
        # Take the progress total from the object actually iterated; the old
        # len(tl) only matched because that loader used batch_size=1.
        for x, y in tqdm.tqdm(dataset, total=len(dataset), ncols=0):
            preds.append((model([x]), y))
    return preds

In [None]:
# base_eval_preds = eval_model(base_model)
# aug_eval_preds = eval_model(aug_model)

In [None]:
# with open("aug_eval_preds.pkl", "wb") as fd:
#     pickle.dump(aug_eval_preds, fd)

# with open("base_eval_preds.pkl", "wb") as fd:
#     pickle.dump(base_eval_preds, fd)

In [None]:
def _load_or_compute_eval_preds(cache_path, model):
    """Return cached ``eval_model(model)`` output, computing and caching it if absent.

    NOTE: unpickling can execute arbitrary code — only load caches you created.
    """
    cache = Path(cache_path)
    if cache.exists():
        with cache.open("rb") as fd:
            return pickle.load(fd)
    eval_preds = eval_model(model)
    with cache.open("wb") as fd:
        pickle.dump(eval_preds, fd)
    return eval_preds


# Uses the cached pickles when present, so a fresh checkout still runs end to end.
base_eval_preds = _load_or_compute_eval_preds("base_eval_preds.pkl", base_model)
aug_eval_preds = _load_or_compute_eval_preds("aug_eval_preds.pkl", aug_model)

In [None]:
# First (model output, target) pair from the base model's evaluation.
first_base_pair = base_eval_preds[0]
base_preds, base_labels = first_base_pair

In [None]:
# Display the class ids the base model predicted for the first evaluated sample.
base_preds[0]["labels"]

In [None]:
def gen_ious(preds, include_misses=True):
    """Compute per-image box IoU and mask IoU between predictions and ground truth.

    For each ``(output, target)`` pair, the first prediction whose class equals the
    target's first label is compared with the target's first box and first mask.

    Args:
        preds: iterable of ``(output, target)`` pairs as produced by ``eval_model``;
            ``output`` is a list whose first element is the prediction dict.
        include_misses: if True (default), an image where the model predicts no
            instance of the true class contributes 0.0 to both lists. If False such
            images are skipped (the previous behaviour, which inflates the means).

    Returns:
        ``(box_ious, mask_ious)``: two lists of floats. Images whose target has no
        labels are always skipped.
    """

    def calculate_mask_iou(pred_mask, label_mask):
        """IoU of two masks binarized at 0.5; 0.0 when both are empty."""
        pred_mask = (pred_mask > 0.5).float()
        label_mask = (label_mask > 0.5).float()
        intersection = torch.sum(pred_mask * label_mask)
        union = torch.sum(pred_mask) + torch.sum(label_mask) - intersection
        if union == 0:
            return 0.0
        return (intersection / union).item()

    box_ious = []
    mask_ious = []
    for pred, label in preds:
        pred = pred[0]
        if label["labels"].numel() == 0:
            continue
        breed_id = label["labels"][0].item()
        pred_idxs = torch.where(pred["labels"] == breed_id)[0]
        if pred_idxs.numel() == 0:
            # Missed the true class entirely: count it as zero overlap rather
            # than silently dropping the image from the averages.
            if include_misses:
                box_ious.append(0.0)
                mask_ious.append(0.0)
            continue
        pred_idx = pred_idxs[0]
        label_mask = label["masks"][0]
        label_box = label["boxes"][0].unsqueeze(0)
        pred_mask = pred["masks"][pred_idx]
        pred_box = pred["boxes"][pred_idx].unsqueeze(0)
        box_ious.append(box_iou(pred_box, label_box).item())
        mask_ious.append(calculate_mask_iou(pred_mask, label_mask))

    return box_ious, mask_ious

In [None]:
# Per-image box and mask IoUs for both models on the same evaluation set.
(base_box_ious, base_mask_ious), (aug_box_ious, aug_mask_ious) = (
    gen_ious(base_eval_preds),
    gen_ious(aug_eval_preds),
)

In [None]:
# Mean IoU of each metric, base model vs. augmented model.
for metric, base_vals, aug_vals in (
    ("box_iou", base_box_ious, aug_box_ious),
    ("mask_iou", base_mask_ious, aug_mask_ious),
):
    print(
        f"{metric} for base model:{np.mean(base_vals):0.4f}, {metric} for aug model:{np.mean(aug_vals):0.4f}"
    )