In [233]:
#!/usr/bin/env python

import os

import ipywidgets as widgets
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual
from PIL import Image
from plotly.subplots import make_subplots
from torch.utils.data.dataloader import default_collate

from core.dataset import Rescale, TransferTensorDict, EpicClasses
from core.tools import initialize

# Switch torch's multiprocessing tensor-sharing strategy to "file_system".
# NOTE(review): presumably set to avoid "too many open files" errors when
# DataLoader workers share tensors — confirm this is still needed.
torch.multiprocessing.set_sharing_strategy("file_system")

In [234]:
# Visualisation config (relative path, resolved against the notebook's working directory).
CONFIG_PATH = "config/config_vis.yaml"

# Build the config object, model, dataset and target device from the config file.
cfg, model, dataset, device = initialize(CONFIG_PATH)
# Class lookup built from the "annotations" folder under cfg.data_dir.
epic_classes = EpicClasses(os.path.join(cfg.data_dir, "annotations"))

Initializing model...
Model initialized with imagenet weights
Freezing the batchnorms of Base Model RGB except first or new layers.
Model initialized with imagenet weights
Freezing the batchnorms of Base Model Audio except first or new layers.
Model initialized.
----------------------------------------------------------
Loading pre-trained weights /media/data/tridiv/epic/tbn_weights/attention/seen/epic_tbn_bninception_RGB_Audio_best.pth...
Done.
----------------------------------------------------------
Reading list of test videos...
Done.
----------------------------------------------------------
Creating the dataset using annotations/EPIC_train_action_labels.csv...
Done.
----------------------------------------------------------


In [236]:
# Plotly layout pinning the y-axis to [0, 1].
# NOTE(review): not referenced by `visualize` in this cell, which sets its own y-range.
layout = go.Layout(yaxis={"range": [0, 1]})

def visualize(model, dataset, index, device):
    """Plot each segment's RGB frame above the model's predicted weights for it.

    Designed to be driven by an ``ipywidgets`` slider, so ``index`` is 1-based.

    Args:
        model: network whose output dict contains a ``"weights"`` tensor.
        dataset: indexable dataset; ``dataset[i]`` yields a ``(data, target, _)`` triple.
        index: 1-based sample index into ``dataset``.
        device: device the input tensors are moved to before the forward pass.

    Side effects: puts ``model`` in eval mode, reads frames from disk using the
    global ``cfg`` (``cfg.data_dir`` / ``cfg.data.rgb.dir_prefix``), and shows a
    plotly figure.
    """
    dict_to_device = TransferTensorDict(device)
    data, _target, _ = default_collate([dataset[index - 1]])
    # Pull CPU-side values out before the dict is moved to `device`.
    # atleast_1d keeps a single-segment sample indexable after squeeze().
    rgb_indices = np.atleast_1d(data["indices"]["RGB"].numpy().squeeze())
    vid_id = data["vid_id"][0]
    data = dict_to_device(data)

    model.eval()
    with torch.no_grad():
        out = model(data)
    # The indexing below implies a (num_segments, 1, T) layout — TODO confirm.
    weights = out["weights"].cpu().numpy()
    num_segments = weights.shape[0]
    x = np.arange(weights.shape[2])  # same x-axis for every segment's curve

    frame_dir = os.path.join(cfg.data_dir, cfg.data.rgb.dir_prefix, vid_id)
    # One column per segment so the grid adapts to the model's segment count.
    fig = make_subplots(rows=2, cols=max(num_segments, 1))
    for idx in range(num_segments):
        col = idx + 1
        frame_path = os.path.join(frame_dir, "img_{:010d}.jpg".format(rgb_indices[idx]))
        # resize() returns a new, fully loaded image, so the file can be closed here.
        with Image.open(frame_path) as frame:
            img = frame.resize((128, 128))
        w, h = img.size  # PIL reports (width, height)

        # Anchor points so the top subplot's axes span the image extent.
        fig.add_trace(go.Scatter(x=[0, w], y=[0, h], mode='markers'), row=1, col=col)
        fig.add_layout_image(
            dict(
                source=img,
                x=40,  # hand-tuned horizontal offset; NOTE(review): confirm placement
                y=h,
                sizex=w,
                sizey=h,
                opacity=1,
                layer="above"),
            row=1, col=col,
        )
        fig.add_trace(go.Scatter(x=x, y=weights[idx].squeeze(0)), row=2, col=col)
        fig.update_yaxes(range=[0, 1], row=2, col=col)
    fig.show()

In [237]:
# The slider is 1-based to match `visualize`, which reads dataset[index - 1];
# the initial value must lie inside [min, max], hence value=1.
# Trailing ";" suppresses the echoed function repr.
interact(
    visualize,
    model=fixed(model),
    dataset=fixed(dataset),
    index=widgets.IntSlider(min=1, max=len(dataset), step=1, value=1),
    device=fixed(device),
);

interactive(children=(IntSlider(value=1, description='index', max=2398, min=1), Output()), _dom_classes=('widg…

<function __main__.visualize(model, dataset, index, device)>

In [45]:
# Tiny 2x3 RGB test image: row 0 is red, green, blue; row 1 is the same
# palette rotated one step (green, blue, red). Built from a scaled identity.
img_rgb = np.stack(
    [np.roll(np.eye(3, dtype=np.uint8) * 255, -shift, axis=0) for shift in range(2)]
)
img_rgb.shape

(2, 3, 3)