In [None]:
# Create the metadata cache directory (including any missing parents) up front.
# NOTE(review): presumably the datamodule instantiated below expects this path — confirm.
!mkdir --parents /tmp/cache/yaak-datasets/metadata

In [None]:
from omegaconf import OmegaConf
import bertviz
import pytorch_grad_cam
from hydra.utils import instantiate
import torch

import matplotlib.pyplot as plt
import more_itertools as mit
from deephouse.tools.camera import Camera
from einops import rearrange, reduce, repeat
from torchvision.transforms import Normalize


class Unnormalize(Normalize):
    """Inverse of ``Normalize``: maps normalized tensors back to their original range.

    ``Normalize(mean, std)`` computes ``(x - mean) / std``.  Handing the parent
    class ``-mean / std`` as its mean and ``1 / std`` as its std therefore
    yields ``x * std + mean``, which undoes the forward transform.

    Args:
        mean: Per-channel means used by the forward normalization.
        std: Per-channel standard deviations used by the forward normalization.
        **kwargs: Forwarded unchanged to ``Normalize.__init__``.
    """

    def __init__(self, mean, std, **kwargs):
        mean_t, std_t = torch.tensor(mean), torch.tensor(std)

        # Parameters of the inverse affine map, expressed in Normalize's terms.
        inverse_mean = (-mean_t / std_t).tolist()
        inverse_std = (1.0 / std_t).tolist()

        super().__init__(mean=inverse_mean, std=inverse_std, **kwargs)

def get_unnorm_frames(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Return the last frame of each sample with normalization undone, clamped to [0, 1].

    Args:
        batch: Mapping whose ``"clips"`` entry holds exactly one clip dict
            containing a ``"frames"`` tensor.
        mean: Per-channel means to undo (defaults are the usual ImageNet statistics).
        std: Per-channel standard deviations to undo.

    Returns:
        Tensor of un-normalized frames with values limited to ``[0.0, 1.0]``.
    """
    only_clip = mit.one(batch["clips"].values())
    # Index -1 on the second axis keeps only the final frame of each sample.
    last_frames = only_clip["frames"][:, -1, ...]

    restored = Unnormalize(mean, std)(last_frames)
    return restored.clamp(min=0.0, max=1.0)


def get_frames(batch):
    """Return the last context frame of each sample, without undoing any normalization.

    Args:
        batch: Mapping whose ``"clips"`` entry holds exactly one clip dict
            containing a ``"frames"`` tensor.
    """
    only_clip = mit.one(batch["clips"].values())
    # Index -1 on the second axis selects the final frame.
    return only_clip["frames"][:, -1, ...]

In [None]:
from cargpt.models.cilpp import CILpp


# Model under inspection, loaded from a W&B artifact.
# The commented-out lines are alternative artifacts / a local checkpoint.
# wandb_model = "yaak/cargpt/model-knmg1pgn:v0"
wandb_model = "yaak/cargpt/model-6hm5s715:v0"
cilpp = CILpp.load_from_wandb_artifact(name=wandb_model)
# cilpp = CILpp.load_from_checkpoint("artifacts/model-rc93mcrx:v7/model.ckpt")
# NOTE: eval mode is not set here; WrapperModel.__init__ calls cilpp.eval() later.
# cilpp.eval()

# Alternative data source: validation dataloader from the experiment config.
# cfg = OmegaConf.load("config/experiment/cilpp.yaml")
# datamodule = instantiate(cfg.datamodule)
# data = datamodule.val_dataloader()

# Active data source: training dataloader built from the data config.
cfg = OmegaConf.load("config/data/train.yaml")
datamodule = instantiate(cfg.datamodule)
data = datamodule.train_dataloader()

In [None]:
# Pull a single batch from the dataloader; the visualization cells below all reuse it.
batch = next(iter(data))

In [None]:
# Close the current figure (if any) before drawing a new one.
plt.close()
# Channels-last copy of the un-normalized last frames, as plt.imshow expects.
# `imgs` is a notebook global that the CAM overlay code below reads.
imgs = rearrange(get_unnorm_frames(batch), 'b c h w -> b h w c')
print(imgs.shape)
# Preview the sample at batch index 2 (the same index the CAM cells below use).
plt.imshow(imgs[2])

# Try with the simplest classification

Based on https://jacobgil.github.io/pytorch-gradcam-book/introduction.html

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet34, resnet50

# Stock torchvision ResNet-34 with ImageNet weights; CAMs are computed on its
# last block of layer4.
model1 = resnet34(weights="IMAGENET1K_V1")
target_layers = [model1.layer4[-1]]

# Notebook globals read by draw_only_resnet_vis below: which sample to show and
# its (still normalized) last frame, kept as a batch of one.
idx = 2
input_tensor = get_frames(batch)[idx:idx+1]

def draw_only_resnet_vis(cam_cls, tgt_cls_int, title):
    """Compute one CAM on ``model1`` for a classifier class and show it over the frame.

    Reads the notebook globals ``model1``, ``target_layers``, ``input_tensor``,
    ``imgs`` and ``idx`` at call time.

    Args:
        cam_cls: CAM class to instantiate (e.g. ``GradCAM``).
        tgt_cls_int: Index of the classifier output to explain.
        title: Title of the resulting plot.
    """
    class_target = ClassifierOutputTarget(tgt_cls_int)
    with cam_cls(model=model1, target_layers=target_layers, use_cuda=True) as cam:
        heatmaps = cam(input_tensor=input_tensor, targets=[class_target])

    # A single sample was passed in, so the first heatmap is the only one.
    heatmap = heatmaps[0, :]
    overlay = show_cam_on_image(imgs[idx].numpy(), heatmap, use_rgb=True)

    plt.imshow(overlay)
    plt.title(title)
    plt.show()

plt.close()
# Render the same sample with every CAM variant for one target class.
for cam_cls in [GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad]:
    print(cam_cls.__name__)
    # Alternative target classes (only one target is drawn per run):
    # draw_only_resnet_vis(cam_cls, tgt_cls_int=479, title=f"{cam_cls.__name__}: car wheel")
    # draw_only_resnet_vis(cam_cls, tgt_cls_int=475, title=f"{cam_cls.__name__}: car mirror")
    draw_only_resnet_vis(cam_cls, tgt_cls_int=920, title=f"{cam_cls.__name__}: traffic light, traffic signal, stoplight")

# Release cached GPU memory held by the CAM runs.
torch.cuda.empty_cache()

## Try with resnet from our architecture

The results are not the same because of the **running mean and variance of the Batch Normalization** layers! All the weights are the same between the two ResNet34 instances.

In [None]:
# `modules()` yields the module itself first, so `next(...)` is
# `cilpp.state_embedding`; from it we take the ResNet inside the frame backbone.
model2 = next(cilpp.state_embedding.modules())['frame']['backbone'].resnet
# Enable gradients on the backbone parameters.
# NOTE(review): presumably they are frozen in the loaded model — confirm.
model2.requires_grad_(True)
# NOTE: rebinds the globals `target_layers`, `idx` and `input_tensor` from the
# previous section; re-running older cells afterwards will see these values.
target_layers = [model2.layer4[-1]]

idx = 2
input_tensor = get_frames(batch)[idx:idx+1]

def draw_only_resnet_vis(cam_cls, tgt_cls_int, title):
    """Compute one CAM on ``model2`` for a classifier class and show it over the frame.

    This redefinition shadows the earlier function of the same name and uses
    ``model2`` instead of ``model1``.  It reads the notebook globals ``model2``,
    ``target_layers``, ``input_tensor``, ``imgs`` and ``idx`` at call time.

    Args:
        cam_cls: CAM class to instantiate (e.g. ``GradCAM``).
        tgt_cls_int: Index of the classifier output to explain.
        title: Title of the resulting plot.
    """
    class_target = ClassifierOutputTarget(tgt_cls_int)
    with cam_cls(model=model2, target_layers=target_layers, use_cuda=True) as cam:
        heatmaps = cam(input_tensor=input_tensor, targets=[class_target])

    # A single sample was passed in, so the first heatmap is the only one.
    heatmap = heatmaps[0, :]
    overlay = show_cam_on_image(imgs[idx].numpy(), heatmap, use_rgb=True)

    plt.imshow(overlay)
    plt.title(title)
    plt.show()

plt.close()
# Same sweep over CAM variants as before, now on the ResNet taken from CIL++.
for cam_cls in [GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad]:
    # Alternative target classes (only one target is drawn per run):
    # draw_only_resnet_vis(cam_cls, tgt_cls_int=479, title=f"{cam_cls.__name__}: car wheel")
    # draw_only_resnet_vis(cam_cls, tgt_cls_int=475, title=f"{cam_cls.__name__}: car mirror")
    draw_only_resnet_vis(cam_cls, tgt_cls_int=920, title=f"{cam_cls.__name__}: traffic light, traffic signal, stoplight")

# Release cached GPU memory held by the CAM runs.
torch.cuda.empty_cache()

# Naive try with the whole CIL++

In [None]:
import pytorch_lightning as pl
from einops import rearrange, reduce, repeat
from jaxtyping import Float, Shaped
from torch import Tensor

import numpy as np
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, RawScoresOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet34, resnet50

class WrapperModel(pl.LightningModule):
    """Adapter that exposes CIL++ through a single-tensor ``forward``.

    The CAM classes call ``model(input_tensor)`` with one tensor, whereas CIL++
    is called with ``frames``, ``speed`` and ``camera``.  ``prepare_input_tensor``
    packs frames and speed into one tensor (and stores the camera on the
    instance); ``forward`` unpacks them again and calls the wrapped model.

    Args:
        cilpp: The CIL++ model to wrap.  It is switched to eval mode and all
            of its parameters get ``requires_grad=True`` as a side effect.
    """

    def __init__(self, cilpp):
        super().__init__()
        self.cilpp = cilpp
        # Filled in by prepare_input_tensor(); the camera object cannot travel
        # inside the packed tensor, so forward() reads it from here.
        self._camera = None

        cilpp.requires_grad_(True)
        cilpp.eval()

        # self.cuda()
        self.requires_grad_(True)
        self.eval()

    def prepare_input_tensor(self, batch):
        """Pack frames and speed of ``batch`` into one ``(b, A, t, c, h, w)`` tensor.

        Slot 0 of axis ``A`` holds the frames; the remaining slots hold the
        speed values, tiled over every ``(t, c, h, w)`` position.  The camera
        built from ``camera_params`` (or ``None``) is stored on ``self._camera``.

        Args:
            batch: Mapping whose ``"clips"`` entry holds exactly one clip dict
                with ``"frames"`` (5-D), ``"meta"["VehicleMotion_speed"]`` (2-D)
                and optionally ``"camera_params"``.

        Returns:
            The packed input tensor.
        """
        clips = mit.one(batch["clips"].values())
        frames = clips["frames"]
        speed = clips["meta"]["VehicleMotion_speed"].to(frames)
        b, t, c, h, w = frames.shape

        # A missing, None or empty "camera_params" means "no camera".  The
        # previous `clips.get("camera_params", {}).copy()` always produced a
        # dict for a missing key and then failed on `pop("model")`, so the
        # camera-less branch could never be taken.
        raw_camera_params = clips.get("camera_params")
        if isinstance(raw_camera_params, dict) and raw_camera_params:
            # Copy so that popping "model" does not mutate the batch.
            camera_params = raw_camera_params.copy()
            camera_model = mit.one(set(camera_params.pop("model")))
            # need one camera per frame
            camera_params = {
                k: repeat(v, "b -> (b t)", t=t) for k, v in camera_params.items()
            }
            camera = Camera.from_params(model=camera_model, params=camera_params)
            camera = camera.to(frames)
        else:
            camera = None

        # NOTE(review): the camera covers the whole batch, while callers such as
        # vis_batch_idx slice the returned tensor to one sample — confirm that
        # CIL++ accepts that combination.
        self._camera = camera

        # Tile speed over every frame position so that it can share one tensor
        # with the frames (memory grows by a factor of t*c*h*w for the speed part).
        speed = repeat(speed, 'b sc -> b t c h w sc', t=t, c=c, h=h, w=w)
        # Give frames a trailing axis of size 1 to concatenate along.
        frames = rearrange(frames, "b t c h (w fc) -> b t c h w fc", fc=1)
        input_tensor = torch.concat([frames, speed], dim=-1)
        input_tensor = rearrange(input_tensor, "b t c h w A -> b A t c h w")
        return input_tensor

    def forward(self, input_tensor):
        """Unpack a tensor from ``prepare_input_tensor`` and run CIL++ on it.

        Args:
            input_tensor: Packed tensor of shape ``(b, A, t, c, h, w)``.

        Returns:
            The prediction with its singleton middle axis removed, shape ``(b, c)``.
        """
        input_tensor = rearrange(input_tensor, "b A t c h w -> b t c h w A")
        # Slot 0 of the last axis is the frames.
        frames = input_tensor[..., 0]
        # Speed was tiled over (t, c, h, w); reading one position recovers it.
        speed = input_tensor[:, 0, 0, 0, 0, 1:]
        camera = self._camera
        pred = self.cilpp(frames=frames, speed=speed, camera=camera)
        pred = rearrange(pred, "b 1 c -> b c")
        return pred

    
class NegPosGradCAM(BaseCAM):
    """Grad-CAM variant whose weights follow the sign of the model output.

    If the selected output is negative, only negative gradients are used (with
    their sign flipped); otherwise only non-negative gradients are used.  The
    weight is the mean of the selected gradients over axes (2, 3).
    Only a batch size of 1 is supported.
    """

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None):
        """Forward all arguments positionally to ``BaseCAM.__init__``."""
        super().__init__(model, target_layers, use_cuda, reshape_transform)

    def get_cam_weights(self, input_tensor, target_layers, targets, activations, grads):
        """Return sign-aware CAM weights for a single-sample batch.

        Args:
            input_tensor: Model input; it is run through the model again
                (without gradient tracking) to read the sign of the output.
            target_layers: Unused here.
            targets: Callables that map one model output to a scalar.
            activations: Unused here.
            grads: Gradient array that is reduced over axes (2, 3)
                (presumably height and width — confirm).

        Raises:
            NotImplementedError: If the model output has more than one row.
        """
        with torch.no_grad():
            outputs = self.model(input_tensor.detach())  # [B, 2]
            if outputs.shape[0] > 1:
                raise NotImplementedError("Only batch size == 1 for now")
            scalars_outputs = [tgt(single_output) for tgt, single_output in zip(targets, outputs)]

        grad_map = grads.copy()
        out = scalars_outputs[0].item()
        print("O", outputs, "SO", scalars_outputs, "out", out)

        if out < 0:
            # Negative output: keep the negative gradients and flip their sign.
            mask = grad_map < 0
            selected_grads = -1.0 * np.where(mask, grad_map, 0)
        else:
            # Non-negative output: keep the non-negative gradients.
            mask = grad_map >= 0
            selected_grads = np.where(mask, grad_map, 0)

        # Mean over the selected elements only; the epsilon guards an empty selection.
        selected_count = np.sum(mask, axis=(2, 3))
        return np.sum(selected_grads, axis=(2, 3)) / (selected_count + 1e-8)

    
class RawAccelerationTarget:
    """CAM target that uses element 0 of a model output as the score."""

    # Position of the value this target reads from the model output.
    _OUTPUT_INDEX = 0

    def __init__(self):
        """Take no configuration."""

    def __call__(self, model_output):
        """Return ``model_output[0]`` unchanged."""
        return model_output[self._OUTPUT_INDEX]


class RawSteeringTarget:
    """CAM target that uses element 1 of a model output as the score."""

    # Position of the value this target reads from the model output.
    _OUTPUT_INDEX = 1

    def __init__(self):
        """Take no configuration."""

    def __call__(self, model_output):
        """Return ``model_output[1]`` unchanged."""
        return model_output[self._OUTPUT_INDEX]


def plot_images(images_dict, size=4):
    """Show a grid of images: one column per outer key, one row per inner entry.

    Args:
        images_dict: Mapping ``outer_key -> {inner_key: image}``.  Inner dicts
            may differ in length; unused grid cells are left blank.
        size: Height of one grid cell in inches (the width is 1.5 times that).

    Does nothing when there is no image to show.
    """
    outer_keys = list(images_dict.keys())
    n_outer = len(outer_keys)
    # `default=0` keeps an empty mapping from raising inside max().
    n_inner = max((len(inner_dict) for inner_dict in images_dict.values()), default=0)
    if n_outer == 0 or n_inner == 0:
        # plt.subplots rejects a grid with zero rows or columns.
        return

    # squeeze=False always returns a 2-D array of axes, so `axes[j, i]` is
    # valid even for a single row and/or a single column.
    _, axes = plt.subplots(
        nrows=n_inner,
        ncols=n_outer,
        figsize=(int(size*1.5 * n_outer), size * n_inner),
        squeeze=False,
    )

    for i, outer_key in enumerate(outer_keys):
        inner_dict = images_dict[outer_key]

        for j, (inner_key, img) in enumerate(inner_dict.items()):
            ax = axes[j, i]
            ax.imshow(img)
            ax.set_title(f'{outer_key} - {inner_key}')
            ax.axis('off')

        # Hide the frames of grid cells this column does not fill.
        for j in range(len(inner_dict), n_inner):
            axes[j, i].axis('off')

    plt.tight_layout()
    plt.show()
    

In [None]:
# idx = 1 has positive outcomes, idx=2 negative
# idx = 2
def vis_batch_idx(batch, idx, imgs):
    """Plot CAM overlays of the whole CIL++ model for one sample of ``batch``.

    For each target (``RawAccelerationTarget``, ``RawSteeringTarget``) and each
    enabled CAM class, a heatmap is computed on the last block of the frame
    backbone's ``layer4`` and drawn over ``imgs[idx]``.  Reads the notebook
    global ``cilpp``.

    Args:
        batch: Batch passed to ``WrapperModel.prepare_input_tensor``.
        idx: Index of the sample inside the batch.
        imgs: Channels-last, un-normalized frames used as the overlay background.
    """
    wrapper = WrapperModel(cilpp)
    target_layers = [next(cilpp.state_embedding.modules())['frame']['backbone'].resnet.layer4[-1]]

    # Keep the batch axis so the model still sees a batch of one.
    input_tensor = wrapper.prepare_input_tensor(batch)[idx: idx+1]

    def get_vis(cam_cls, targets):
        """Run one CAM class and return the heatmap blended onto the frame."""
        with cam_cls(model=wrapper, target_layers=target_layers, use_cuda=True) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]
        return show_cam_on_image(imgs[idx].numpy(), grayscale_cam, use_rgb=True)

    plt.close()
    # Enabled CAM variants.  Others that were tried here: GradCAMPlusPlus,
    # GradCAM, HiResCAM, ScoreCAM, AblationCAM, XGradCAM, EigenCAM, FullGrad.
    classes = [
        NegPosGradCAM,
    ]
    to_plot = {}
    for target_cls in [RawAccelerationTarget, RawSteeringTarget]:
        target_to_plot = to_plot.setdefault(target_cls.__name__, {})
        for cam_cls in classes:
            try:
                target_to_plot[cam_cls.__name__] = get_vis(cam_cls, targets=[target_cls()])
            except (RuntimeError, AttributeError) as exc:
                # Best effort: report the failing combination and carry on.
                print(f"Can't do {cam_cls.__name__}: {target_cls.__name__}")
                print(exc)
            finally:
                print("-"*60)

    # Plot only when at least one visualization was produced.
    if any(to_plot.values()):
        plot_images(to_plot, 4)

    # Sanity check: number of identical pixels between the two targets' maps.
    # A failed CAM leaves its entry missing, so look the images up defensively
    # instead of indexing (which raised KeyError before).
    accel_vis = to_plot[RawAccelerationTarget.__name__].get(NegPosGradCAM.__name__)
    steer_vis = to_plot[RawSteeringTarget.__name__].get(NegPosGradCAM.__name__)
    if accel_vis is not None and steer_vis is not None:
        print((accel_vis == steer_vis).sum(), accel_vis.shape)
    torch.cuda.empty_cache()

print(imgs.shape)
# Visualize every sample of the batch in turn.
# NOTE: this loop rebinds the notebook global `idx`, which the earlier
# draw_only_resnet_vis functions read at call time.
for idx in range(imgs.shape[0]):
    print(">>>", idx)
    vis_batch_idx(batch, idx, imgs)
    print("="*80)