In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras as tfk

from fast_tfai.utils.utils import get_all_images

In [None]:
def pad_and_resize(img, size=380):
    """Zero-pad an image to a centred square, then resize it to ``size`` x ``size``.

    The shorter axis is padded symmetrically (any odd leftover pixel goes after
    the image), so both landscape and portrait inputs keep their aspect ratio.

    Args:
        img: Image array of shape (H, W) or (H, W, C), e.g. as returned by
            ``cv2.imread``.
        size: Side length in pixels of the returned square image (default 380).

    Returns:
        The padded image resized by ``cv2.resize`` to ``size`` x ``size``, with the
        same dtype as ``img``.
    """
    height, width = img.shape[:2]
    side = max(height, width)
    # Keep the input's dtype and any trailing channel axes instead of assuming uint8/3 channels.
    square = np.zeros((side, side) + img.shape[2:], dtype=img.dtype)
    top = (side - height) // 2
    left = (side - width) // 2
    square[top : top + height, left : left + width] = img
    # cv2.resize takes dsize as (width, height); both are `size` here.
    return cv2.resize(square, (size, size))

In [None]:
def _wrap_with_heatmap(model: tf.keras.Model) -> tf.keras.Model:
    """Wrap a classifier so it also returns a class-activation heatmap for its argmax class.

    Assumes ``model.layers[-2]`` is a nested backbone model whose own ``layers[-2]``
    is the last conv block, and ``model.layers[-1]`` is a Dense head whose kernel is
    shaped (channels, n_classes) — TODO confirm against the architecture.

    Args:
        model: The trained classifier to wrap.

    Returns:
        A model taking the same input as ``model`` and returning ``[pred, heatmaps]``:
        ``pred`` is ``model``'s own output and ``heatmaps`` (batch, height, width, 1)
        is the last-conv feature map weighted by the head's kernel column for the
        argmax class, bilinearly resized to the input's spatial size (unnormalised).
    """
    backbone = model.layers[-2]
    last_conv_layer = backbone.layers[-2].get_output_at(0)
    weights = model.layers[-1].get_weights()[0]

    backbone_conv_model = tf.keras.Model(
        inputs=backbone.input, outputs=[last_conv_layer]
    )

    # Re-wrap the backbone's conv output behind a fresh Input matching the full model.
    new_inputs = tf.keras.layers.Input(shape=model.input.shape[1:])
    conv_layer = backbone_conv_model(new_inputs)
    new_model = tf.keras.Model(inputs=new_inputs, outputs=[conv_layer], name="castrato")

    new_input = tf.keras.layers.Input(shape=model.input.shape[1:])
    pred = model(new_input)
    conv = new_model(new_input)  # (batch, h, w, channels)
    pred_class = tf.math.argmax(pred, axis=1)

    # Pick the kernel column of each sample's predicted class: (batch, channels).
    class_weights = tf.gather(tf.transpose(weights), pred_class)
    # Shape as (batch, 1, channels, 1) so matmul's batch dims broadcast as
    # (batch, h) x (batch, 1). A (batch, channels, 1) operand aligns `batch`
    # against `h` instead, which breaks (or mixes samples) for batch sizes > 1.
    class_weights = tf.expand_dims(tf.expand_dims(class_weights, axis=1), axis=-1)

    output = tf.matmul(conv, class_weights)  # (batch, h, w, 1)
    heatmaps = tf.image.resize(
        output, model.input.shape[1:3], method="bilinear", name="heatmap"
    )

    return tf.keras.Model(new_input, [pred, heatmaps])

In [None]:
def _wrap_with_heatmap_new(model: tf.keras.Model, threshold: float = 0.5) -> tf.keras.Model:
    """Wrap a classifier so it also returns a score-weighted class-activation heatmap.

    Every class whose score exceeds ``threshold`` contributes, weighted by its score
    renormalised to sum to 1 per sample. The per-class thresholding suggests
    independent (e.g. sigmoid) scores — TODO confirm.

    Assumes ``model.layers[-2]`` is a nested backbone model whose own ``layers[-2]``
    is the last conv block, and ``model.layers[-1]`` is a Dense head whose kernel is
    shaped (channels, n_classes) — TODO confirm against the architecture.

    Args:
        model: The trained classifier to wrap.
        threshold: Score above which a class contributes to the heatmap (default 0.5).

    Returns:
        A model taking the same input as ``model`` and returning ``[pred, heatmaps]``.
        ``heatmaps`` has shape (batch, height, width, 1) and is min-max scaled to
        [0, 255] independently per sample. It is all zeros for a sample whose map
        is constant (e.g. no class above ``threshold``).
    """
    backbone = model.layers[-2]
    last_conv_layer = backbone.layers[-2].get_output_at(0)
    weights = model.layers[-1].get_weights()[0]

    backbone_conv_model = tf.keras.Model(
        inputs=backbone.input, outputs=[last_conv_layer]
    )

    # Re-wrap the backbone's conv output behind a fresh Input matching the full model.
    new_inputs = tf.keras.layers.Input(shape=model.input.shape[1:])
    conv_layer = backbone_conv_model(new_inputs)
    new_model = tf.keras.Model(inputs=new_inputs, outputs=[conv_layer], name="castrato")

    new_input = tf.keras.layers.Input(shape=model.input.shape[1:])
    pred = model(new_input)
    conv = new_model(new_input)  # (batch, h, w, channels)

    # Keep only classes above threshold and renormalise their scores per sample.
    # A sample with no such class yields 0/0 = NaN, which is mapped to zero weights.
    mask = pred > threshold
    pred_proba = pred * tf.cast(mask, tf.float32)
    normalized_proba = pred_proba / tf.reduce_sum(pred_proba, axis=1, keepdims=True)
    normalized_proba = tf.where(
        tf.math.is_nan(normalized_proba),
        tf.zeros_like(normalized_proba),
        normalized_proba,
    )

    # (batch, n_classes) @ (n_classes, channels): one channel-weight vector per sample.
    weighted_weights = tf.matmul(normalized_proba, tf.transpose(weights, perm=[1, 0]))
    # Shape as (batch, 1, channels, 1) so matmul's batch dims broadcast as
    # (batch, h) x (batch, 1). A (batch, channels, 1) operand aligns `batch`
    # against `h` instead, which breaks (or mixes samples) for batch sizes > 1.
    channel_weights = tf.expand_dims(tf.expand_dims(weighted_weights, axis=1), axis=-1)
    heatmap = tf.matmul(conv, channel_weights)  # (batch, h, w, 1)

    # Min-max scale each sample on its own rather than across the whole batch.
    # divide_no_nan yields 0 instead of NaN when a map is constant.
    heatmap_min = tf.reduce_min(heatmap, axis=[1, 2, 3], keepdims=True)
    heatmap_max = tf.reduce_max(heatmap, axis=[1, 2, 3], keepdims=True)
    heatmap = tf.math.divide_no_nan(heatmap - heatmap_min, heatmap_max - heatmap_min) * 255

    heatmaps = tf.image.resize(
        heatmap, model.input.shape[1:3], method="bilinear", name="heatmap"
    )

    return tf.keras.Model(new_input, [pred, heatmaps])

In [None]:
# Root of the training outputs. The default is the original author's machine;
# set the FOGNA_OUTPUTS environment variable to run this notebook elsewhere.
OUTPUTS_DIR = os.environ.get("FOGNA_OUTPUTS", "/home/simone/workspace/fogna/outputs")
model_path = os.path.join(OUTPUTS_DIR, "ompi", "ST4", "")
model = tfk.models.load_model(model_path)

In [None]:
# Print the layer table. The heatmap helpers index model.layers[-2] (expected to be a
# nested backbone model) and model.layers[-1] (the head), so verify that structure here.
model.summary()

In [None]:
# Wrap the classifier so that calling it returns [pred, heatmaps] using the
# score-weighted variant (classes scoring above 0.5 contribute).
model_with_heatmap = _wrap_with_heatmap_new(model)

In [None]:
# Inspect the wrapper's symbolic outputs: [pred, heatmaps].
model_with_heatmap.outputs

In [None]:
# Exploratory copy of the graph-building steps inside _wrap_with_heatmap_new, so the
# intermediate pieces (`weights`, `new_model`, `new_input`) can be used cell by cell below.
# NOTE(review): duplicated logic; changes to the function are not reflected here.
# Assumes model.layers[-2] is a nested backbone model and model.layers[-1] the head.
last_conv_layer = model.layers[-2].layers[-2].get_output_at(0)
# Kernel of the final layer; its transpose is multiplied with class scores below,
# so presumably shaped (channels, n_classes) — TODO confirm.
weights = model.layers[-1].get_weights()[0]
backbone_conv_model = tf.keras.Model(
    inputs=model.layers[-2].input, outputs=[last_conv_layer]
)
# Re-wrap the backbone's conv output behind a fresh Input matching the full model.
new_inputs = tf.keras.layers.Input(shape=model.input.shape[1:])
conv_layer = backbone_conv_model(new_inputs)
new_model = tf.keras.Model(inputs=new_inputs, outputs=[conv_layer], name="castrato")

# Shared placeholder input fed to both `model` and `new_model` in the next cell.
new_input = tf.keras.layers.Input(shape=model.input.shape[1:])

In [None]:
# Symbolic (graph-mode) calls on the placeholder Input; no image data flows here.
# NOTE(review): `conv` is overwritten by the eager call on a real image further down,
# and `pred` is not used afterwards, so this cell looks like leftover exploration.
pred = model(new_input)
conv = new_model(new_input)

In [None]:
# The evaluation images live in the "st4_ompi" folder inside the model directory;
# derive the path from `model_path` instead of repeating the absolute path.
images_list = get_all_images(os.path.join(model_path, "st4_ompi", ""))

In [None]:
k = 11  # index of the sample image to inspect
img_path = images_list[k]
img = cv2.imread(str(img_path))
# cv2.imread signals failure by returning None rather than raising.
if img is None:
    raise FileNotFoundError(f"cv2.imread could not read {img_path}")
img = pad_and_resize(img)
# cv2 loads BGR while matplotlib expects RGB: convert for display only and leave
# `img` (the model input below) untouched. Whether the model expects BGR is not
# visible here — TODO confirm against the training pipeline.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB));

In [None]:
# Prepend a batch axis, then run the classifier and the conv-feature extractor eagerly.
img_batch = np.expand_dims(img, axis=0)
pred_simulated = model(img_batch)
conv = new_model(img_batch)

In [None]:
# Zero out classes at or below 0.5, then rescale the remaining scores to sum to 1.
mask = pred_simulated > 0.5
pred_proba = pred_simulated * tf.cast(mask, tf.float32)
proba_total = tf.reduce_sum(pred_proba, axis=1, keepdims=True)
normalized_proba = pred_proba / proba_total
# No class above threshold gives 0/0 = NaN; replace those entries with 0.
nan_entries = tf.math.is_nan(normalized_proba)
normalized_proba = tf.where(nan_entries, tf.zeros_like(normalized_proba), normalized_proba)

print(pred_proba)
print(normalized_proba)

In [None]:
# Score-weighted sum of the last-conv channels for the single image -> (1, h, w, 1).
weighted_weights = tf.matmul(normalized_proba, tf.transpose(weights, perm=[1, 0]))
heatmap = tf.matmul(conv, tf.expand_dims(weighted_weights, axis=2))
# Min-max scale to [0, 255]. divide_no_nan returns 0 instead of NaN when the map is
# constant (e.g. no class passed the threshold, so every weight is zero).
heatmap_min = tf.reduce_min(heatmap)
heatmap_range = tf.reduce_max(heatmap) - heatmap_min
heatmap = tf.math.divide_no_nan(heatmap - heatmap_min, heatmap_range) * 255
weighted_weights.shape

In [None]:
# Upsample the activation map to the displayed image's size instead of a hard-coded 380.
# cv2.resize takes dsize as (width, height).
heatmaps = cv2.resize(
    heatmap[0].numpy(), (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR
)

In [None]:
# cv2.imread gives BGR but matplotlib expects RGB: convert for display only.
# `img` itself (the model input) is left unchanged.
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(1, 2, figsize=(16, 9))
ax[0].imshow(img_rgb)
ax[0].set_title("Input")
ax[1].imshow(img_rgb)
ax[1].imshow(np.array(heatmaps, np.uint8), cmap="jet", alpha=0.5)
ax[1].set_title("Heatmap overlay (step-by-step computation)")
plt.show()

In [None]:
outputs = model_with_heatmap(img[np.newaxis, ...])
heat = outputs[1]  # heatmaps output: (1, height, width, 1)

# cv2.imread gives BGR but matplotlib expects RGB: convert for display only.
# The model above was fed the unconverted `img`.
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(1, 2, figsize=(16, 9))
ax[0].imshow(img_rgb)
ax[0].set_title("Input")
ax[1].imshow(img_rgb)
# Drop the batch and channel axes so imshow gets a 2-D array to colour-map.
ax[1].imshow(np.array(heat[0, ..., 0], np.uint8), cmap="jet", alpha=0.5)
ax[1].set_title("Heatmap overlay (model_with_heatmap)")
plt.show()