In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the project's `src/` directory (relative to the notebook's working directory)
# at the front of the import path. Existing copies are removed first so re-running
# this cell leaves exactly one entry instead of piling up duplicates.
SRC_DIR = '../src'
while SRC_DIR in sys.path:
    sys.path.remove(SRC_DIR)
sys.path.insert(0, SRC_DIR)

In [3]:
import logging

import pandas as pd
import numpy as np
import torch

from gulpio import GulpDirectory
from torch.utils.data import Subset
from tqdm import tqdm
from torchvideo.samplers import frame_idx_to_list

from config.jsonnet import load_jsonnet
from config.application import FeatureConfig

from attribution.online_shapley_value_attributor import OnlineShapleyAttributor
from subset_samplers import ConstructiveRandomSampler, ExhaustiveSubsetSampler

from ipython_media import display_video

In [4]:
# Experiment configuration: samplers, dataset and model are all built from this file.
CONFIG_PATH = "../configs/feature_multiscale_trn.jsonnet"

config_dict = load_jsonnet(CONFIG_PATH)
config = FeatureConfig(**config_dict)

In [5]:
# Use the first GPU when one is available; otherwise fall back to CPU so that the
# later `.to(device)` calls do not fail on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
# Frame samplers from the config.
# NOTE(review): `train_sampler` is not used by any cell visible here.
train_sampler = config.frame_samplers.train.instantiate()
test_sampler = config.frame_samplers.test.instantiate()

# Validation split, sampled with the test-time sampler.
dataset = config.dataset.instantiate()
val = dataset.validation_dataset(sampler=test_sampler)

# Model on the target device, in eval mode.
model = (
    config.get_model()
    .to(device)
    .eval()
)
class2str = dataset.class2str()

In [7]:
GULP_VALIDATION_DIR = '../datasets/ssv2/gulp/validation/'

# Raw validation videos, used below to look at the clip being explained.
gulp_dir = GulpDirectory(GULP_VALIDATION_DIR)

In [8]:
# Per-class priors indexed by the `class` column and sorted by it; presumably so entry i
# lines up with class index i of the model output — confirm against the CSV.
# `read_csv(squeeze=True)` was deprecated in pandas 1.4 and removed in 2.0;
# `.squeeze("columns")` turns the single remaining column into a Series the same way.
class_priors = (
    pd.read_csv('../datasets/ssv2/class-priors.csv', index_col='class')
    .squeeze('columns')
    .sort_index()
    .to_numpy()
)
# If the CSV ever gains extra columns, squeeze silently returns a DataFrame; fail loudly instead.
assert class_priors.ndim == 1, "expected exactly one prior column in class-priors.csv"
class_priors.shape

(174,)

In [9]:
# Choose how coalitions are enumerated for the Shapley computation.
# True  -> ExhaustiveSubsetSampler (presumably every subset; exact but potentially costly)
# False -> ConstructiveRandomSampler capped at 128 samples (approximate)
USE_EXHAUSTIVE_SUBSETS = True

if USE_EXHAUSTIVE_SUBSETS:
    subset_sampler = ExhaustiveSubsetSampler(device=device)
else:
    subset_sampler = ConstructiveRandomSampler(max_samples=128, device=device)

In [10]:
# The attributor receives the multiscale model's individual single-scale sub-models.
# NOTE(review): presumably these sub-models are the "players" whose contributions are
# attributed — confirm in OnlineShapleyAttributor.
single_scale_models = list(model.single_scale_models)

attributor = OnlineShapleyAttributor(
    single_scale_models,
    priors=class_priors,
    n_classes=config.dataset.class_count,
    device=device,
    subset_sampler=subset_sampler,
)

In [11]:
SAMPLE_IDX = 8  # index of the validation clip to explain

frame_features, label_dict = val[SAMPLE_IDX]
uid, gt = label_dict['uid'], label_dict['action']

# Stack the clip's frames from the gulp record (first element of the record) into one array.
video = np.stack(gulp_dir[uid][0])

# Frame indices chosen by the test sampler for a clip of this length.
# NOTE(review): these only match the frames behind `frame_features` if the test
# sampler is deterministic — confirm.
sampled_frame_idxs = frame_idx_to_list(test_sampler.sample(len(video)))

print(frame_features.shape)
print(label_dict)

(8, 256)
{'action': 46, 'label': 46, 'uid': '100049'}


In [12]:
# Play every frame of the clip.
FULL_CLIP_FPS = 14
display_video(video, fps=FULL_CLIP_FPS)

                                                   

(29, 240, 426, 3)
Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4

Moviepy - Done !
Moviepy - video ready __temp__.mp4




In [13]:
# Only the frames picked by the test sampler, played much slower than the full clip.
sampled_video = video[sampled_frame_idxs]
display_video(sampled_video, fps=2)

                                                  

(8, 240, 426, 3)
Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4

Moviepy - Done !
Moviepy - video ready __temp__.mp4


In [14]:
# Forward pass on the clip's features; logits are indexed by class id below.
with torch.no_grad():
    logits = model(torch.from_numpy(frame_features).to(device))
    probs = logits.softmax(dim=-1)

pred = int(logits.argmax())
pred, logits[pred].item(), probs[pred].item()

(46, 10.634759902954102, 0.990464448928833)

In [15]:
# Ground-truth vs. predicted class name.
class2str[gt], class2str[pred]

('Opening something', 'Opening something')

In [16]:
np.shape(frame_features)

(8, 256)

In [17]:
# Run the Shapley attribution on the same features the model was evaluated on.
frame_features_t = torch.from_numpy(frame_features).to(device)
esvs, scores = attributor.explain(frame_features_t)

In [19]:
esvs.size()

torch.Size([8, 174])

In [22]:
# Attribution values for the ground-truth class, one per row of `esvs`.
esvs[:, gt].cpu().numpy()

array([0.11227177, 0.14852181, 0.14085133, 0.07698814, 0.1285285 ,
       0.12786841, 0.12285381, 0.12516265], dtype=float32)