In [None]:
# Create the metadata cache directory, including any missing parent directories
# (`--parents` also makes this a no-op when the directory already exists).
# NOTE(review): presumably the datamodule instantiated further down uses this cache path — confirm.
!mkdir --parents /tmp/cache/yaak-datasets/metadata

In [None]:
import math

import bertviz
import matplotlib.pyplot as plt
import more_itertools as mit
import pytorch_grad_cam
import torch
from deephouse.tools.camera import Camera
from einops import rearrange, reduce, repeat
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torchvision.transforms import Normalize


class Unnormalize(Normalize):
    """Inverse of a mean/std normalisation, expressed as another ``Normalize``.

    ``Normalize`` computes ``(x - mean) / std``. Configuring it with
    ``mean' = -mean / std`` and ``std' = 1 / std`` therefore yields
    ``x * std + mean``, i.e. it undoes a previous normalisation that used
    ``mean`` and ``std``.

    Args:
        mean: per-channel means that were used for the forward normalisation.
        std: per-channel standard deviations used for the forward normalisation.
        **kwargs: forwarded unchanged to ``Normalize.__init__``.
    """

    def __init__(self, mean, std, **kwargs):
        mean_t = torch.tensor(mean)
        std_t = torch.tensor(std)

        # Parameters of the inverse affine transform, converted back to plain lists.
        inverse_mean = (-mean_t / std_t).tolist()
        inverse_std = (1.0 / std_t).tolist()

        super().__init__(mean=inverse_mean, std=inverse_std, **kwargs)

        
def get_figure(images, batch_no):
    """Draw a (layers x heads) grid of images for one sample of the batch.

    Args:
        images: array-like with shape ``(layers, heads, height, width, channels)``;
            ``images[i, j]`` is handed to ``Axes.imshow`` for row ``i``, column ``j``.
        batch_no: index of the sample inside the batch, only used for the title.

    Returns:
        The matplotlib figure that was created.
    """
    layers, heads, height, width, _ = images.shape

    # Size the figure from the aspect ratio of a single image: the longer side
    # gets `base_size`, the shorter side is scaled proportionally.
    base_size = 12
    if width > height:
        fig_width = base_size
        fig_height = base_size * height / width
    else:
        fig_height = base_size
        fig_width = base_size * width / height

    # One subplot per (layer, head). `squeeze=False` keeps `axs` two-dimensional
    # even when there is a single row and/or a single column, so `axs[i, j]`
    # is always valid (a squeezed result is 1-D for one row/column and a bare
    # Axes object for a 1x1 grid).
    fig, axs = plt.subplots(
        layers,
        heads,
        figsize=(fig_width, fig_height),
        gridspec_kw={'wspace': 0, 'hspace': 0},
        squeeze=False,
    )

    # Thin white spines act as a border between neighbouring subplots.
    spine_kwargs = {'color': 'white', 'linewidth': 1.}
    for ax in axs.flat:
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set(**spine_kwargs)
        ax.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)

    fig.suptitle(f"Batch[{batch_no}]", fontsize=14)
    fig.text(0.5, 0.05, "Heads" if heads > 1 else "Avg heads", ha='center', fontsize=10)
    fig.text(0.1, 0.5, "Layers", va='center', rotation='vertical', fontsize=10)

    # Fill every subplot with its image; label rows on the first column and
    # columns on the last row only.
    for i in range(layers):
        for j in range(heads):
            ax = axs[i, j]
            ax.imshow(images[i, j])
            if j == 0:
                ax.set_ylabel(i, rotation=0)
                ax.yaxis.set_label_coords(-0.1, 0.5)
            if i == layers - 1:
                ax.set_xlabel(j)

    return fig
        
    
@torch.no_grad()
def get_state_and_frames(cilpp, batch):
    """Embed one batch with the model and return the embedding plus the frames.

    Args:
        cilpp: model object exposing ``_embed_state(frames=, speed=, camera=)``.
        batch: mapping whose ``"clips"`` entry holds exactly one clip dict
            (``mit.one`` raises otherwise). That dict provides ``"frames"``
            shaped ``(b, 1, c, h, w)``, ``"meta"`` with ``"VehicleMotion_speed"``
            and optionally ``"camera_params"``.

    Returns:
        Tuple ``(state, frames)`` where ``state`` is whatever ``_embed_state``
        returns and ``frames`` has shape ``(b, c, h, w)``.
    """
    clip_data = mit.one(batch["clips"].values())

    # Drop the singleton second axis of the clip tensor.
    frames = rearrange(clip_data["frames"], "b 1 c h w -> b c h w")
    speed = clip_data["meta"]["VehicleMotion_speed"].to(torch.float32)

    camera = None
    # Work on a copy because "model" is popped from the parameters below.
    camera_params = clip_data.get("camera_params", {}).copy()
    # any() over a dict inspects its keys: true when at least one key is truthy.
    if any(camera_params):
        # All entries of the batch must share one camera model (mit.one on the set).
        model_name = mit.one(set(camera_params.pop("model")))
        camera = Camera.from_params(model=model_name, params=camera_params)
        # NOTE(review): presumably aligns the camera with the frames' device/dtype — confirm.
        camera = camera.to(frames)

    state = cilpp._embed_state(frames=frames, speed=speed, camera=camera)
    return state, frames


@torch.no_grad()
def compute_labels(batch):
    """Build the control labels for a batch from its vehicle-motion metadata.

    Args:
        batch: mapping whose ``"clips"`` entry holds exactly one clip dict with
            a ``"meta"`` mapping containing the normalised gas pedal, brake
            pedal and steering angle signals.

    Returns:
        Dict with float32 tensors under ``"acceleration"`` (gas minus brake)
        and ``"steering_angle"``.
    """
    meta = mit.one(batch["clips"].values())["meta"]

    # NOTE: assuming (gas > 0) xor (brake > 0)
    acceleration = meta["VehicleMotion_gas_pedal_normalized"] - meta["VehicleMotion_brake_pedal_normalized"]
    steering = meta["VehicleMotion_steering_angle_normalized"]

    labels = {}
    labels["acceleration"] = acceleration.to(torch.float32)
    labels["steering_angle"] = steering.to(torch.float32)
    return labels


@torch.no_grad()
def get_attention_maps(encoder: torch.nn.modules.transformer.TransformerEncoder, x: torch.Tensor, mask=None):
    """Collect the per-head self-attention weights of every encoder layer.

    The input is pushed through the encoder layer by layer; before each layer
    runs, its ``self_attn`` module is queried on the layer's attention input
    to obtain the (non head-averaged) attention weights.

    Args:
        encoder: transformer encoder whose ``layers`` are iterated in order.
        x: input sequence fed to the first layer.
        mask: optional attention mask. It is applied both when extracting the
            attention weights and when propagating ``x`` through each layer,
            so the maps of deeper layers match what the masked encoder computes.

    Returns:
        List with one attention-weight tensor per layer, in layer order.
    """
    attention_maps = []
    for layer in encoder.layers:
        # Pre-norm layers (norm_first=True) run attention on norm1(x); post-norm
        # layers run it on x itself.
        attn_input = layer.norm1(x) if getattr(layer, "norm_first", False) else x
        _, attn_map = layer.self_attn(
            attn_input,
            attn_input,
            attn_input,
            attn_mask=mask,
            need_weights=True,
            average_attn_weights=False,
        )
        attention_maps.append(attn_map)
        # Call the module (not .forward) so hooks run, and keep the mask in effect.
        x = layer(x, src_mask=mask)
    return attention_maps


def merge_frames_with_maps(
    frames,
    maps,
    map_weight=0.5,
    boost_channel=0,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    out_height=40,
    norm_maps=False,
    global_norm=True,
    no_attn_itself=True,
    avg_heads=False,
):
    """Overlay per-token attention maps on the (un-normalised) input frames.

    Args:
        frames: normalised frames shaped ``(batch, channels, height, width)``.
        maps: attention maps shaped ``(batch, layers, heads, tokens, tokens)``.
            A softmax is taken over the last axis and the result is averaged
            over the second-to-last axis, giving one value per token.
        map_weight: blending weight of the map (see ``boost_channel``).
        boost_channel: if smaller than the number of frame channels, the map is
            blended into that channel; otherwise it is appended as an extra
            channel (e.g. alpha when the frames are RGB).
        mean, std: statistics used to normalise ``frames``; used to undo it.
        out_height: height of the returned images, the width follows the
            frame aspect ratio.
        norm_maps: min-max normalise the token values to ``[0, 1]``.
        global_norm: with ``norm_maps``, normalise per sample over all layers,
            heads and tokens (``True``) or per layer/head over tokens (``False``).
        no_attn_itself: exclude each token's attention to itself.
        avg_heads: average the maps over heads (the head axis becomes 1).

    Returns:
        Tensor shaped ``(batch, layers, heads, out_h, out_w, C)`` where ``C``
        is ``channels`` or ``channels + 1`` (map appended as extra channel).

    Raises:
        ValueError: if ``tokens`` cannot be arranged as a grid matching the
            frame aspect ratio.
    """
    batch_size, layers, heads, tokens, _ = maps.shape
    _, channels, height, width = frames.shape

    if no_attn_itself:
        # Mask the diagonal *before* the softmax so self-attention receives zero
        # weight. (Multiplying by a diagonal matrix instead would wipe out every
        # off-diagonal entry.) The mask lives on the maps' device and broadcasts.
        diagonal = torch.eye(tokens, dtype=torch.bool, device=maps.device)
        maps = maps.masked_fill(diagonal, torch.finfo(maps.dtype).min)
    # NOTE(review): maps produced by get_attention_maps come from
    # nn.MultiheadAttention with need_weights=True, i.e. they are already
    # softmax-normalised; this softmax then re-normalises probabilities —
    # confirm that is intended.
    maps = torch.nn.functional.softmax(maps, dim=-1)
    maps = maps.mean(dim=-2)  # (batch, layers, heads, tokens)

    if avg_heads:
        maps = maps.mean(dim=2, keepdim=True)
        heads = 1

    if norm_maps:
        reduce_dims = (1, 2, 3) if global_norm else (-1,)
        lowest = maps.amin(dim=reduce_dims, keepdim=True)
        highest = maps.amax(dim=reduce_dims, keepdim=True)
        span = highest - lowest
        # A constant map has zero range; divide by 1 there instead of producing NaN.
        span = torch.where(span > 0, span, torch.ones_like(span))
        maps = (maps - lowest) / span

    # Token grid with the frames' aspect ratio: grid_h = floor(sqrt(tokens * height / width)),
    # computed in integer arithmetic so float rounding cannot drop a row.
    grid_h = math.isqrt((tokens * height) // width)
    if grid_h == 0 or tokens % grid_h != 0:
        raise ValueError(
            f"cannot arrange {tokens} tokens as a grid for frames of size {height}x{width}"
        )
    grid_w = tokens // grid_h

    # Undo the normalisation and resize the frames to the requested height.
    imgs = Unnormalize(mean, std)(frames)
    imgs = torch.nn.functional.interpolate(imgs, scale_factor=out_height / height, mode='bicubic')
    imgs = imgs.clamp(min=0., max=1.)
    out_h, out_w = imgs.shape[-2:]

    # Resize the maps to exactly the resized frame size so both always line up.
    maps = maps.reshape(batch_size * layers * heads, 1, grid_h, grid_w)
    maps = torch.nn.functional.interpolate(maps, size=(out_h, out_w), mode='nearest-exact')
    maps = maps.reshape(batch_size, layers, heads, out_h, out_w)

    # Channel-first copy of the frames, repeated for every (layer, head):
    # (C, batch, layers, heads, out_h, out_w). clone() makes it writable.
    imgs = imgs.permute(1, 0, 2, 3)[:, :, None, None]
    imgs = imgs.expand(-1, -1, layers, heads, -1, -1).clone()

    if boost_channel < channels:
        imgs[boost_channel] = (1. - map_weight) * imgs[boost_channel] + map_weight * maps
    else:
        if norm_maps:
            maps = (1. - map_weight) + map_weight * maps
        imgs = torch.cat([imgs, maps.unsqueeze(0)], dim=0)

    # Channels last: (batch, layers, heads, out_h, out_w, C).
    return imgs.permute(1, 2, 3, 4, 5, 0)

In [None]:
from cargpt.models.cilpp import CILpp

# Load the model from a W&B artifact and call .eval() on it before extracting attention maps.
wandb_model = "yaak/cargpt/model-uv8xv088:v0"
cilpp = CILpp.load_from_wandb_artifact(name=wandb_model)
# Alternative: load the model from a local checkpoint file instead.
# cilpp = CILpp.load_from_checkpoint("artifacts/model-rc93mcrx:v7/model.ckpt")
cilpp.eval()

# Build the datamodule from the `datamodule` section of the experiment config.
cfg = OmegaConf.load("config/experiment/cilpp.yaml")
datamodule = instantiate(cfg.datamodule)

# Batches are drawn from the validation dataloader; swap in the commented line to use training data.
# data = datamodule.train_dataloader()
data = datamodule.val_dataloader()

In [None]:
# Take the first batch yielded by the dataloader; the cells below work on this batch.
batch = next(iter(data))

In [None]:
# Embed the batch, derive its control labels and collect one attention map per encoder layer.
states, frames = get_state_and_frames(cilpp, batch)
labels = compute_labels(batch)
per_layer_maps = get_attention_maps(cilpp.transformer_encoder, states)
# Stacking the per-layer maps on dim=1 gives (b, layers, heads, t1, t2) directly.
maps = torch.stack(per_layer_maps, dim=1)

In [None]:
def show_attention_figures(merged, labels):
    """Print the labels and show the attention grid for every sample in `merged`.

    Args:
        merged: output of `merge_frames_with_maps`; iterating it yields one
            (layers, heads, H, W, C) stack per sample.
        labels: dict of per-sample label tensors, indexed by sample position.
    """
    for sample_idx, sample_images in enumerate(merged):
        # A dict keeps the sorted key order (a set of tuples would discard it).
        print({name: labels[name][sample_idx].item() for name in sorted(labels)})
        fig = get_figure(sample_images, sample_idx)
        # plt.show() takes no figure argument; close each figure explicitly
        # afterwards so figures do not pile up in memory.
        plt.show()
        plt.close(fig)
        print("-" * 60)


# Settings shared by the per-head and the head-averaged visualisation.
merge_kwargs = dict(
    boost_channel=3,
    out_height=120,
    map_weight=1,
    norm_maps=True,
    global_norm=True,
    no_attn_itself=True,
)

# One column per attention head.
out = merge_frames_with_maps(frames, maps, avg_heads=False, **merge_kwargs)
show_attention_figures(out, labels)

# Attention averaged over heads.
out_avg = merge_frames_with_maps(frames, maps, avg_heads=True, **merge_kwargs)
show_attention_figures(out_avg, labels)