In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Put the sibling src/ directory first on the import path; presumably this is
# where the project-local packages imported below live — verify.
sys.path.insert(0, '../src')

In [5]:
import time
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from functools import lru_cache
from tqdm import tqdm
from devtools import debug
from ipywidgets import interact, widgets
from gulpio import GulpDirectory

In [6]:
from config.application import FeatureConfig
from config.jsonnet import load_jsonnet
from subset_samplers import ConstructiveRandomSampler
from attribution.online_shapley_value_attributor import OnlineShapleyAttributor
from torchvideo.samplers import FullVideoSampler

In [7]:
# Apply seaborn's default plot styling to the figures drawn later in the notebook.
sns.set()

In [8]:
# Load the experiment configuration from its jsonnet file and build a
# FeatureConfig from the resulting key/value pairs.
config_dict = load_jsonnet("../configs/feature_multiscale_trn.jsonnet")
config = FeatureConfig(**config_dict)

In [9]:
# Gulp directory for the validation split; plot_esvs indexes it by uid to fetch
# frames, and the frame slider reads per-video frame info from its metadata.
gulp_dir = GulpDirectory('../datasets/ssv2/gulp/validation/')

In [10]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook can still run (more slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [11]:
# Instantiate the dataset builder described by the config and ask it for the
# validation dataset, passing FullVideoSampler as the frame sampler
# (presumably so every frame of each video is kept — confirm against torchvideo).
# NOTE(review): the recorded output shows this cell raising OSError when
# '../datasets/ssv2/features/trn.hdf' does not exist, so that feature file must
# be present before running.
dataset_builder = config.dataset.instantiate()
dataset = dataset_builder.validation_dataset(sampler=FullVideoSampler())

OSError: Unable to open file (unable to open file: name = '../datasets/ssv2/features/trn.hdf', errno = 2, error message = 'No such file or directory', flags = 40, o_flags = 0)

In [44]:
# Mapping used below to turn a class id into its display name (it is iterated
# with .items() and indexed by predicted class); presumably int -> str — verify.
class2str = dataset_builder.class2str()

In [25]:
# Build the model from the config, switch it to eval mode and move it to `device`.
model = config.get_model().eval().to(device)

In [26]:
# Class priors passed to the attributor as `priors`, read from a CSV indexed by
# its 'class' column. `read_csv(..., squeeze=True)` was deprecated in pandas 1.4
# and removed in 2.0, so the single remaining column is squeezed to a Series
# explicitly before taking the underlying array.
class_priors = (
    pd.read_csv('../datasets/ssv2/class-priors.csv', index_col='class')
    .squeeze('columns')
    .values
)

In [27]:
@lru_cache(maxsize=1000)
def compute_approximate_esvs(uid: str, n_runs: int, max_samples: int):
    """Run the approximate Shapley attribution ``n_runs`` times for one video.

    Results are memoised on ``(uid, n_runs, max_samples)``, so the features are
    looked up from ``uid`` inside the function. Reading a global ``features``
    instead would let the cache key and the data it describes drift apart
    whenever the selected example changes.

    Args:
        uid: Identifier of the video, as stored in ``dataset.label_sets['uid']``.
        n_runs: Number of independent attribution runs to perform.
        max_samples: ``max_samples`` passed to ``ConstructiveRandomSampler``.

    Returns:
        A list with one dict per run, holding ``'esvs'`` and ``'scores'`` as
        NumPy arrays and ``'duration_ms'``, the wall-clock time of the
        ``explain`` call (including the host-to-device copy) in milliseconds.

    Raises:
        ValueError: If ``uid`` is not present in the dataset.
    """
    # Resolve the example from the uid so the result always matches the cache key.
    example_idx = list(dataset.label_sets['uid']).index(uid)
    video_features, _ = dataset[example_idx]

    subset_sampler = ConstructiveRandomSampler(max_samples=max_samples, device=device)
    attributor = OnlineShapleyAttributor(
        single_scale_models=model.single_scale_models,
        priors=class_priors,
        n_classes=config.dataset.class_count,
        device=device,
        subset_sampler=subset_sampler,
    )

    run_results = []
    for _ in range(n_runs):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted.
        start = time.perf_counter()
        esvs, scores = attributor.explain(torch.from_numpy(video_features).to(device))
        duration_ms = (time.perf_counter() - start) * 1000
        run_results.append({
            'esvs': esvs.cpu().numpy(),
            'scores': scores.cpu().numpy(),
            'duration_ms': duration_ms
        })
    return run_results

In [28]:
# Video uids as a plain list (so list.index() can be used on them below) and the
# dataset's 'action' labels; presumably the two are aligned by position — verify.
uids = list(dataset.label_sets['uid'])
labels = dataset.label_sets['action']

In [76]:
# Inverse of class2str: look up a class id from its display name.
str2class = {class_name: class_id for class_id, class_name in class2str.items()}

In [86]:
# Choose the example to explain: the first one whose label equals the chosen class.
cls = 'Pouring something into something'
# assumes `labels` is a 1-D NumPy array so .nonzero()[0] gives matching positions — TODO confirm
class_example_idxs = (labels == str2class[cls]).nonzero()[0]
uid = uids[class_example_idxs[0]]
# Position of the first entry in `uids` equal to that uid.
example_idx = uids.index(uid)
label = labels[example_idx]
# Display the selected uid as the cell output.
uid

'100765'

In [87]:
# Fetch the selected example's features and label dict, and check that the
# dataset row really belongs to the chosen uid.
features, label_dict = dataset[example_idx]
assert label_dict['uid'] == uid

In [90]:
def plot_esvs(n_runs: int, max_samples: int, n_iters: int, _cls: int, highlighed_frame: int = 1):
    """Draw a three-panel summary of the approximate ESVs for the selected video.

    Panels, left to right: the highlighted frame, a bar chart of the ten
    highest mean scores, and a line plot of the per-frame ESV for class
    ``_cls`` aggregated over runs. The video is the notebook-level ``uid``.

    Args:
        n_runs: Number of attribution runs to compute and aggregate.
        max_samples: Forwarded to ``compute_approximate_esvs``.
        n_iters: Accepted so the matching widget can be wired up, but
            currently not used by this function.
        _cls: Class id whose ESVs are plotted.
        highlighed_frame: 1-based frame number to show and to mark on the
            ESV plot.
    """
    run_results = compute_approximate_esvs(uid, n_runs, max_samples)

    # Long-form table with one row per (run, frame) so seaborn can aggregate
    # across runs; frame numbers are 1-based to match the slider.
    entries = [
        {'run_idx': run_idx, 'frame': frame_idx + 1, 'esv': esv}
        for run_idx, result in enumerate(run_results)
        for frame_idx, esv in enumerate(result['esvs'][..., _cls])
    ]
    df = pd.DataFrame(entries)

    fig, axs = plt.subplots(figsize=(17, 5), ncols=3, constrained_layout=True)
    sns.lineplot(x='frame', y='esv', data=df, zorder=2, ax=axs[2])

    # Left panel: the highlighted frame (slider value is 1-based, frames are 0-based).
    axs[0].imshow(gulp_dir[uid][0][highlighed_frame - 1])
    axs[0].grid(None)
    axs[0].axis('off')

    # Middle panel: mean score over runs, ten largest, reversed so the largest
    # bar ends up at the top of the horizontal bar chart.
    scores = np.stack([r['scores'] for r in run_results]).mean(axis=0)
    top_10_preds = scores.argsort()[::-1][:10]
    axs[1].barh([class2str[pred] for pred in top_10_preds][::-1], scores[top_10_preds][::-1])
    axs[1].set_title("Top-10 predictions")

    # Right panel: ESV curve with a zero baseline and a marker at the highlighted frame.
    axs[2].set_xlabel("Frame")
    axs[2].set_ylabel("ESV")
    axs[2].axhline(0, color='grey', zorder=1)
    axs[2].axvline(highlighed_frame, color='red')
    axs[2].set_title("Element Shapley Values")
    

# Shared widget style so the long control descriptions are given their natural width.
style = {'description_width': 'initial'}

# One control per plot_esvs parameter, keyed by the parameter name.
esv_controls = dict(
    n_runs=widgets.Dropdown(
        description="Number of times to run approximation",
        options=[1, 2, 3, 5, 10],
        value=3,
        style=style,
    ),
    max_samples=widgets.Dropdown(
        description="Max # samples/scale in approximation",
        options=[128, 256, 512, 1024],
        value=512,
        style=style,
    ),
    n_iters=widgets.Dropdown(
        description="Number of iterations in approximation",
        options=[1, 2, 4, 8],
        value=1,
        style=style,
    ),
    # Show class names in the dropdown but pass the class id to plot_esvs;
    # start on the selected example's own label.
    _cls=widgets.Dropdown(
        description="Class",
        options=[(class_name, class_id) for class_id, class_name in class2str.items()],
        value=label_dict['action'],
        style=style,
    ),
    # 1-based frame slider spanning every frame recorded for this video.
    highlighed_frame=widgets.IntSlider(
        description="Frame",
        value=1,
        min=1,
        max=len(gulp_dir.merged_meta_dict[uid]['frame_info']),
        style=style,
    ),
)
interact(plot_esvs, **esv_controls)

interactive(children=(Dropdown(description='Number of times to run approximation', index=2, options=(1, 2, 3, …

<function __main__.plot_esvs(n_runs: int, max_samples: int, n_iters: int, _cls: int, highlighed_frame: int = 1)>