In [None]:
import cv2
import json
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
def is_gray(color: tuple, tolerance: int = 10) -> bool:
    """Return True when a colour is (near-)achromatic.

    A colour counts as gray when no two of its R, G, B channels differ by more
    than ``tolerance``. That is equivalent to ``max(r, g, b) - min(r, g, b) <= tolerance``.

    Args:
        color: an (r, g, b) sequence. Extra trailing channels (e.g. the alpha of an
            RGBA pixel) are ignored, so RGBA colours no longer raise.
        tolerance: maximum allowed channel spread; defaults to the original 10.

    Returns:
        True if the colour is gray within ``tolerance``.
    """
    r, g, b = tuple(color)[:3]
    # The largest pairwise channel difference is exactly max - min.
    return max(r, g, b) - min(r, g, b) <= tolerance


def is_useful_tile(tile: Image.Image) -> bool:
    """Return True when a tile's dominant colour is not gray.

    The tile is shrunk to at most 32x32 and converted to RGB. Its most frequent
    colour is then tested with ``is_gray``. Tiles dominated by a gray/white
    background are reported as not useful.

    Args:
        tile: a PIL image. It is not modified; a shrunken copy is analysed.

    Returns:
        False if the most frequent colour of the thumbnail is gray, else True.
    """
    thumb = tile.copy()
    thumb.thumbnail((32, 32))
    # getcolors() yields ints for 'L'/'P' images and 4-tuples for RGBA. Normalising to
    # RGB always gives (r, g, b) tuples, which is what is_gray expects.
    thumb = thumb.convert('RGB')
    # A thumbnail cannot have more distinct colours than pixels, so this never returns None.
    colors = thumb.getcolors(maxcolors=thumb.width * thumb.height)
    _, most_frequent_color = max(colors, key=lambda count_color: count_color[0])
    return not is_gray(most_frequent_color)


def prepare_data(attn_file: str, annot_file: str, drop_white: bool = True, image_dir: str = 'new_val'):
    """Load attention maps and COCO-style annotations and join them per image.

    Args:
        attn_file: JSON list of records, each with an 'image' file name and an
            'attention_px' nested list. The list is stored as an np.uint8 array
            (presumably a stack of 2-D maps, one per attention map; TODO confirm).
        annot_file: COCO-style JSON with 'categories', 'images' and 'annotations'.
        drop_white: if True, drop images whose tile ``is_useful_tile`` rejects
            (its most frequent colour is gray).
        image_dir: directory containing the files named by each image's
            'file_name'. It is only read when ``drop_white`` is True.

    Returns:
        (images, categories):
          images: dict image id -> image record from annot_file. Each record has an
            'annotations' list (empty when the image has no annotations). When
            attn_file holds a map for its file name, it also has an 'attention' array.
          categories: category names in annot_file order.

    The image records are the dicts parsed from annot_file, extended in place.
    """
    with open(attn_file) as file:
        attention = json.load(file)
    with open(annot_file) as file:
        annotations = json.load(file)

    attention_maps = {
        record['image']: np.array(record['attention_px'], dtype=np.uint8)
        for record in attention
    }
    categories = [cat['name'] for cat in annotations['categories']]
    images = {im['id']: im for im in annotations['images']}

    # Give every image an annotation list (possibly empty) and attach its attention
    # map once. Images without annotations are then still usable downstream.
    for image in images.values():
        image.setdefault('annotations', [])
        if image['file_name'] in attention_maps:
            image['attention'] = attention_maps[image['file_name']]
    for annotation in annotations['annotations']:
        images[annotation['image_id']]['annotations'].append(annotation)

    if drop_white:
        kept = {}
        for image_id, image in images.items():
            # Close each file after inspecting it, so no handle leaks per image.
            with Image.open(f"{image_dir}/{image['file_name']}") as tile:
                if is_useful_tile(tile):
                    kept[image_id] = image
        images = kept
    return images, categories

In [None]:
def binary_mask_from_annotations(image_size: tuple, annotations: list):
    """Rasterise the ellipses inscribed in annotation bounding boxes into one boolean mask.

    Args:
        image_size: (height, width) of the output mask.
        annotations: annotation dicts whose 'bbox' is read as [x, y, width, height],
            with x along columns and y along rows.

    Returns:
        Boolean np.ndarray of shape ``image_size``. A pixel is True when its centre lies
        inside (or on) the ellipse inscribed in at least one bbox. Boxes with a
        non-positive width or height are skipped.
    """
    mask = np.zeros(image_size, dtype=bool)
    # Evaluate the ellipse at pixel centres (index + 0.5). An inscribed ellipse then
    # covers exactly the rows/columns of its bbox, and 1-pixel boxes stay visible
    # instead of producing a zero semi-axis.
    rows, cols = np.ogrid[0:image_size[0], 0:image_size[1]]
    rows = rows + 0.5
    cols = cols + 0.5
    for annotation in annotations:
        x, y, width, height = annotation['bbox']
        if width <= 0 or height <= 0:
            continue  # degenerate box: no area, and its semi-axis would be zero
        semi_i, semi_j = height / 2, width / 2
        center_i, center_j = y + semi_i, x + semi_j
        mask |= ((rows - center_i) / semi_i) ** 2 + ((cols - center_j) / semi_j) ** 2 <= 1
    return mask


def binary_masks_from_image(image: dict):
    """Build one boolean mask per category bucket for an annotated image.

    Annotations are grouped into 5 buckets by ``category_id - 1``. Every category id
    of 5 or more lands in the last bucket (index 4). Each bucket is then rasterised
    with ``binary_mask_from_annotations`` at the image's (height, width).

    Returns:
        A list of 5 boolean masks, one per bucket.
    """
    grouped = [[] for _ in range(5)]
    for ann in image['annotations']:
        bucket = ann['category_id'] - 1
        grouped[min(4, bucket)].append(ann)

    masks = []
    for bucket_annotations in grouped:
        size = (image['height'], image['width'])
        masks.append(binary_mask_from_annotations(size, bucket_annotations))
    return masks


def metrics_from_attention_map(attention_map: np.ndarray, bin_image: np.ndarray, method: str):
    """Return the (numerator, denominator) of an overlap score between a map and a mask.

    The attention map is resized to the mask's size with bicubic interpolation.
    With ``inter = sum(map * mask)``, the result is:
      - 'Recall':    (inter, sum(mask))
      - 'Precision': (inter, sum(map))
      - any other method (used for 'IoU'): (inter, sum(max(map, mask)))
    Numerator and denominator are returned separately so callers can aggregate them
    over many images before dividing.

    NOTE(review): the scores are only bounded by 1 if the attention values are on the
    same 0/1 scale as the mask (the source file name suggests thresholded maps). TODO confirm.
    """
    height, width = bin_image.shape[:2]
    # cv2 expects dsize as (width, height), the reverse of numpy's (rows, cols) shape.
    # Passing bin_image.shape directly transposes the map for non-square images.
    scaled_mask = cv2.resize(attention_map, dsize=(width, height), interpolation=cv2.INTER_CUBIC)
    intersection = np.sum(scaled_mask * bin_image)
    if method == 'Recall':
        return intersection, np.sum(bin_image)
    if method == 'Precision':
        return intersection, np.sum(scaled_mask)
    return intersection, np.sum(np.maximum(scaled_mask, bin_image))


def _metric_table(images: dict, categories: list, method: str) -> pd.DataFrame:
    """Aggregate a metric over all images into a (attention maps x categories + 'All cells') table."""
    if not images:
        raise ValueError('plot_heatmap needs at least one image')
    # One row per attention map; take the count from the data instead of assuming 6.
    n_maps = max(len(image['attention']) for image in images.values())
    numerator = np.zeros((n_maps, len(categories) + 1))
    denominator = np.zeros_like(numerator)
    for image in tqdm(images.values(), total=len(images)):
        bin_masks = binary_masks_from_image(image)
        if len(bin_masks) > len(categories):
            raise ValueError(
                f'{len(bin_masks)} category masks but only {len(categories)} category names; '
                'the per-category columns would collide with the "All cells" column')
        bin_mask_all_cells = binary_mask_from_annotations((image['height'], image['width']), image['annotations'])
        for i, attention_map in enumerate(image['attention']):
            for j, bin_mask in enumerate(bin_masks):
                num, den = metrics_from_attention_map(attention_map, bin_mask, method)
                numerator[i, j] += num
                denominator[i, j] += den
            num, den = metrics_from_attention_map(attention_map, bin_mask_all_cells, method)
            numerator[i, -1] += num
            denominator[i, -1] += den
    # A cell whose denominator stayed 0 (e.g. a category that never occurs) is undefined: NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = numerator / denominator
    return pd.DataFrame(ratio, columns=[*categories, 'All cells'])


def plot_heatmap(images: dict, categories: list, method: str = 'Recall'):
    """Plot a heatmap of ``method`` ('Recall', 'Precision' or 'IoU') per attention map and category.

    Args:
        images: dict as returned by ``prepare_data``. Every image needs 'attention',
            'annotations', 'height' and 'width'.
        categories: category names for the per-category columns; an 'All cells'
            column is appended.
        method: metric name passed to ``metrics_from_attention_map``.
    """
    heatmap = _metric_table(images, categories, method)
    fig, ax = plt.subplots(figsize=(8, 8))
    sns.heatmap(heatmap, ax=ax, cbar=False, annot=True, linewidths=0.5, cmap='YlGnBu_r', square=True)
    ax.set_title(f'{method} of attention maps', fontdict={'size': 20}, pad=20)
    ax.set_ylabel('Attention maps', fontdict={'size': 16})
    ax.set_xlabel('Cell categories', fontdict={'size': 16})
    ax.tick_params(axis='x', labelrotation=45)
    # Lay out this figure only once its artists exist; calling tight_layout before
    # subplots() would act on a different (possibly empty) figure.
    fig.tight_layout()

In [None]:
# Load attention maps and annotations, then plot recall, precision and IoU heatmaps.
images, categories = prepare_data(
    'oneshot_original_th0_5_chkpt10.json',
    'new_val/instances_newval.json',
)
# NOTE(review): presumably removed so the category names line up with the 5 mask
# buckets built by binary_masks_from_image (ids >= 5 share one bucket). Confirm.
categories.remove('Spermatozoa')
for method in ('Recall', 'Precision', 'IoU'):
    plot_heatmap(images, categories, method)

In [None]:
# Visualise the per-category masks of one image: pixels outside each mask are dimmed.
i = 5  # image id to show; it must have survived prepare_data's drop_white filtering
# Decode the tile once (and close the file) instead of reopening it for every mask.
with Image.open('new_val/' + images[i]['file_name']) as tile:
    base_image = np.array(tile)
masks = binary_masks_from_image(images[i])
res = []
for mask in masks:
    im = base_image.copy()
    im[~mask] //= 2  # halve the pixel values outside this category's mask
    res.append(im)
strip = np.hstack(res)
plt.figure(figsize=(30, 5))
ax = plt.axes([0, 0, 1, 1], frameon=False)
ax.set_axis_off()
ax.imshow(strip)
plt.imsave('cell_types.png', strip, dpi=800)