# Generate animated GIFs of the recursive splitting of an input image's BPT (binary partition tree) or AA (axis-aligned) segment hierarchy

In [1]:
import numpy as np
import json
import sys, os, importlib, math
import matplotlib.pyplot as plt
import cv2
from tqdm.auto import tqdm
from skimage.segmentation import mark_boundaries
import io, imageio

from matplotlib import rc
# Typeset all figure text with LaTeX (this needs a LaTeX installation on the
# machine) and load the `color` package in the LaTeX preamble.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': '\\usepackage{color}',
})

import shap_bpt as shap_bpt
# Show which shap_bpt build this notebook is running against.
print(shap_bpt.__version__)

1.0


In [2]:
# Load the image to explain. OpenCV decodes to BGR channel order, so the
# channel axis is reversed to obtain RGB, then the result is cast to uint8.
image_path = 'flamingo4.png'
bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
if bgr is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than a TypeError below.
    raise FileNotFoundError(f'could not read image {image_path!r}')
image_to_explain = bgr[:, :, ::-1].astype(np.uint8)
print(image_to_explain.shape)

(224, 224, 3)


In [3]:
%%time
# Build the BPT of the image; the %%time magic on the cell's first line reports
# how long this takes. The frame-generation cell later reads bptree.N,
# bptree.width and bptree.height from this object.
bptree = shap_bpt.build_bpt_from_image(image_to_explain)

CPU times: user 127 ms, sys: 6.76 ms, total: 134 ms
Wall time: 135 ms


In [4]:
import matplotlib.colors as mcolors
# Palette converted with shap_bpt.hex_to_rgb from matplotlib's XKCD color table,
# in the table's iteration order. Only the hex values are needed, so iterate
# .values() instead of unpacking .items().
# NOTE(review): no uncommented code in this section reads `cmap`; confirm it is
# used elsewhere before relying on it (or removing it).
cmap = [shap_bpt.hex_to_rgb(hex_code) for hex_code in mcolors.XKCD_COLORS.values()]

def colorize(nodes, img, i):
    """Fill every segment in ``nodes`` with the mean color of ``img`` over it.

    Parameters
    ----------
    nodes : list
        Segments of a single kind: either all ``shap_bpt.AxisAlignedSegment``
        (pixel rectangles ``[ymin:ymax, xmin:xmax]``) or all segments exposing
        ``pixels_interval()`` and ``bpt.pixels`` (flat pixel indices). The kind
        is decided from ``nodes[0]`` only.
    img : np.ndarray
        Image indexed as ``(row, col, channel)``. Its values are averaged per
        segment and divided by 255.
    i : int
        Unused; kept so existing ``colorize(segments, image, 0)`` calls work.

    Returns
    -------
    np.ndarray
        float32 array with ``img``'s shape. Pixels of each segment hold that
        segment's mean color / 255; pixels not covered by any node stay 0
        (so an empty ``nodes`` list yields an all-zero array).
    """
    img = np.asarray(img)
    # Always C-contiguous, so the reshape below is a view and writes through
    # ``flat_colored`` land in ``colored``. (np.zeros_like on a non-contiguous
    # input could make reshape return a copy and silently drop those writes.)
    colored = np.zeros(img.shape, dtype=np.float32)
    if not nodes:
        return colored

    is_aa = isinstance(nodes[0], shap_bpt.AxisAlignedSegment)
    n_channels = img.shape[-1]
    flat_img = img.reshape(-1, n_channels)          # one row per pixel, row-major
    flat_colored = colored.reshape(-1, n_channels)  # view into ``colored``
    for node in nodes:
        if is_aa:
            rows = slice(node.ymin, node.ymax)
            cols = slice(node.xmin, node.xmax)
            colored[rows, cols] = img[rows, cols, :].mean(axis=(0, 1)) / 255.0
        else:
            s, e = node.pixels_interval()
            pixels = node.bpt.pixels[s:e]
            flat_colored[pixels] = flat_img[pixels].mean(axis=0) / 255.0
    return colored

def make_segments(nodes, img):
    """Rasterize ``nodes`` into an integer label map with ``img``'s height and width.

    Returns a ``(img.shape[0], img.shape[1])`` uint32 array in which every pixel
    covered by ``nodes[k]`` holds ``k``; pixels covered by no node stay 0, and
    where nodes overlap the later one wins. The node kind (axis-aligned
    rectangle vs. flat pixel-index interval) is decided from ``nodes[0]``.
    """
    axis_aligned = isinstance(nodes[0], shap_bpt.AxisAlignedSegment)
    height, width = img.shape[0], img.shape[1]
    labels = np.zeros((height, width), dtype=np.uint32)
    for label, seg in enumerate(nodes):
        if not axis_aligned:
            # ravel() of the freshly allocated array is a view, so this
            # assignment writes into ``labels`` through flat pixel indices.
            start, end = seg.pixels_interval()
            labels.ravel()[seg.bpt.pixels[start:end]] = label
            continue
        labels[seg.ymin:seg.ymax, seg.xmin:seg.xmax] = label
    return labels

In [6]:
def _figure_to_array(fig, dpi):
    """Rasterize ``fig`` at ``dpi`` and return its pixels as an (H, W, C) uint8 array.

    ``savefig(format='raw')`` writes the pixel buffer rendered at the requested
    ``dpi``, so the frame size is derived from the figure size in inches times
    that ``dpi``. Taking it from ``fig.bbox`` would only be correct when
    ``fig.dpi == dpi`` (not guaranteed, e.g. under notebook inline backends).
    NOTE(review): assumes rcParams['savefig.bbox'] is not 'tight', which would
    change the output size.
    """
    width_px, height_px = (fig.get_size_inches() * dpi).astype(int)
    with io.BytesIO() as buf:
        fig.savefig(buf, format='raw', dpi=dpi)
        raw = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    return raw.reshape((height_px, width_px, -1))


base_segment = shap_bpt.BaseSegment()

# Build visualization for AA or BPT (keep one)
# root_node, name, title = shap_bpt.AxisAlignedSegment(0, bptree.width, 0, bptree.height, base_segment), 'aa', 'AA Hierarchy'
root_node, name, title = shap_bpt.BPT_Segment(bptree, bptree.N-1, base_segment), 'bpt', 'BPT Hierarchy'

segments = [root_node]        # current partition, refined once per frame
all_nodes = [root_node]       # root plus every child produced by a split
prev_boundaries = None        # previous frame's cut lines, recolored to old_cut_color

K = 11                        # number of frames (refinement steps 0..K-1)
frame_dpi = 100               # resolution used to rasterize each frame
cut_color = (.5, 0, .25, 1)   # RGBA of the cuts introduced at the current step
old_cut_color = (0, 0, 0, 1)  # RGBA of cuts introduced at earlier steps
frames = []
leaves = np.zeros(K, dtype=int)  # leaves[ii]: segments that could not be split at step ii
height, width = image_to_explain.shape[:2]

fig, ax = plt.subplots(figsize=(3,3))
for ii in range(K):
    # Fill each segment with its mean color, brightened for contrast.
    img = colorize(segments, image_to_explain, 0)
    img = np.clip(0.2 + img * 1.1, 0, 1)
    # Fully transparent RGBA canvas with the current segment borders drawn in cut_color.
    sgm = make_segments(segments, image_to_explain)
    boundaries = mark_boundaries(np.zeros((height, width, 4)), sgm,
                                 mode='thick', color=cut_color)

    ax.set_xticks([]) ; ax.set_yticks([])
    ax.set_title(f'{title} ({ii}/{K-1})', fontsize=16)
    ax.imshow(img)
    if ii > 0:
        # Draw every current border in cut_color, then the previous frame's
        # borders (already recolored) on top, so only this step's cuts keep cut_color.
        ax.imshow(boundaries)
        if prev_boundaries is not None:
            ax.imshow(prev_boundaries)

    # Recolor a copy of this frame's borders for use as "old" cuts next frame.
    # Match on all four channels; a per-channel (any) match would also hit
    # background pixels whose channel happens to equal one of cut_color's.
    prev_boundaries = boundaries.copy()
    prev_boundaries[np.all(boundaries == cut_color, axis=-1)] = old_cut_color

    # Refine: replace every splittable segment by its children.
    new_segments = []
    for s in segments:
        # NOTE(review): the segment is passed as both arguments of split();
        # the meaning is defined by shap_bpt — confirm.
        split = s.split(s, s)
        if split is None:
            new_segments.append(s)
            leaves[ii] += 1
        else:
            new_segments.extend(split)
            all_nodes.extend(split)
    segments = new_segments

    frames.append(_figure_to_array(fig, frame_dpi))
    ax.clear()

# Hold every frame 1000 and the last (most refined) frame 3000 duration units.
# NOTE(review): imageio's Pillow GIF plugin reads `duration` in milliseconds,
# while its legacy GIF plugin used seconds — confirm for the installed version.
durations = [1000] * (len(frames) - 1) + [3000]
imageio.mimsave(f'sequence_{name}.gif', frames, duration=durations, loop=0)
plt.close(fig)
print('saved.')

saved.
