In [98]:
import torch
import numpy as np
from models.esvs import _MTRN
from datasets.pickle_dataset import MultiPickleDataset
from torch.utils.data import DataLoader

from frame_sampling import RandomSampler
from torchvideo.samplers import frame_idx_to_list
from attribution.online_shapley_value_attributor import OnlineShapleyAttributor
import pandas as pd

from subset_samplers import ConstructiveRandomSampler

In [57]:
# Notebook-wide configuration.
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float

# Number of frames sampled from each video.
n_frames = 4

def no_collate(args):
    """Identity collate function: hand the batch back exactly as received."""
    return args

# Sampler that chooses `n_frames` single-frame snippets from a video.
frame_sampler = RandomSampler(
    frame_count=n_frames,
    snippet_length=1,
    test=True,
)

In [128]:
# One MTRN per input length (1..8 frames), each restored from its checkpoint.
# models[k] is the network built with frame_count == k + 1, which matches the
# 1-based checkpoint file names.
models = [_MTRN(frame_count=i) for i in range(1, 9)]
for j, m in enumerate(models):
    # map_location='cpu' lets the checkpoint load regardless of the device it
    # was saved from; load_state_dict copies the values into the model's own
    # parameters.
    m.load_state_dict(
        torch.load(f'../datasets/epic/models/mtrn-frames={j+1}.pt', map_location='cpu')
    )

# The list is 0-based while frame counts are 1-based, so the network for
# `n_frames` frames is at index n_frames - 1 (models[n_frames] would be the
# model for n_frames + 1 frames).
model = models[n_frames - 1]

In [46]:
# Pickled features for this notebook, served one item per batch.
dataset = MultiPickleDataset('../datasets/epic/features/p01_features.pkl')
dataloader = DataLoader(dataset=dataset, batch_size=1)



In [47]:
# Fetch the first (input, label) batch from the loader.
data = iter(dataloader)

# Use the built-in next(): the Python 3 iterator protocol only provides
# __next__, so a `.next()` method is not guaranteed to exist on the iterator.
inp, lab = next(data)

In [60]:
def subsample_frames(video):
    """Select frames from ``video`` with the notebook-level ``frame_sampler``.

    Args:
        video: indexable sequence of frames; ``len(video)`` is used as the
            video length.

    Returns:
        A pair ``(indices, frames)``: the sampled frame indices as a numpy
        array, and ``video`` indexed by that array.

    Raises:
        ValueError: if the video has fewer than ``n_frames`` frames.
    """
    total_frames = len(video)
    if n_frames > total_frames:
        raise ValueError(f"Video too short to sample {n_frames} from")
    sampled = frame_sampler.sample(total_frames)
    chosen = np.array(frame_idx_to_list(sampled))
    return chosen, video[chosen]


# Drop the size-1 dimensions of the loaded batch, then sample frames from it.
sample_idx, sample_video = subsample_frames(torch.squeeze(inp))

In [85]:
# No visible cell moves the model to `device`, yet the input is moved there,
# so put both on the same device before the forward pass.
# eval() is needed as well: no_grad() only turns off gradient tracking, it
# does not disable dropout or batch-norm statistic updates.
# NOTE(review): assumes _MTRN is a torch.nn.Module (it is used with
# load_state_dict) - confirm.
model = model.to(device).eval()

with torch.no_grad():
    out = model(sample_video.to(device))

In [86]:
# Verb scores: softmax over the last axis of the model output, on the CPU.
verbs = out.cpu().softmax(dim=-1)

# Noun scores are placeholders: random values passed through a softmax.
nouns = torch.rand((1, 300)).softmax(dim=-1)

# Verb and noun scores side by side along the last axis.
result_scores = torch.cat([verbs, nouns], dim=-1)
result_scores.size()

torch.Size([1, 397])

In [121]:
# Read the 'prior' column of the verb class priors CSV (indexed by
# 'verb_class') as a plain array.
class_priors = (
    pd.read_csv('../datasets/epic/labels/verb_class_priors.csv', index_col='verb_class')
    .loc[:, 'prior']
    .values
)

In [122]:
# Shapley-value attributor built over the full list of per-frame-count
# models. The number of classes is taken from the length of the priors array.
attributor = OnlineShapleyAttributor(
    single_scale_models=models,
    priors=class_priors,
    n_classes=len(class_priors),
    subset_sampler=ConstructiveRandomSampler(max_samples=128, device=device),
    device=device,
)

In [131]:
# explain() returns a pair; keep both values instead of discarding the second.
# A later cell reads `scores`, which is otherwise never assigned in the
# notebook and would only exist as leftover kernel state.
# NOTE(review): the second value is presumably the class scores that the
# later cell compares against `verbs` - confirm against
# OnlineShapleyAttributor.explain.
esvs, scores = attributor.explain(sample_video.to(device))

In [140]:
# Placeholder noun attributions: random values passed through a softmax over
# the last axis.
noun_esvs = torch.rand((1, n_frames, 300)).softmax(dim=-1)

In [142]:
# Move the verb attributions to the CPU and add a leading axis of size 1.
verb_esvs = torch.unsqueeze(esvs.cpu(), 0)

In [143]:
# Join verb and noun attributions along the last axis.
result_esvs = torch.cat([verb_esvs, noun_esvs], dim=-1)

In [145]:
# Show the shape of the combined attribution tensor.
result_esvs.size()

torch.Size([1, 4, 397])

In [148]:
# Display `scores` (moved to the CPU, with a leading axis added) next to
# `verbs`. `scores` must already exist in the namespace when this cell runs.
(scores.cpu().unsqueeze(0), verbs)

(tensor([[2.2395e-07, 3.3908e-07, 5.8338e-10, 1.0000e+00, 3.4592e-07, 5.1184e-09,
          2.7350e-08, 8.3875e-19, 1.9700e-10, 9.0772e-14, 5.9277e-14, 1.1191e-09,
          5.7197e-22, 6.9220e-10, 1.5641e-10, 6.7009e-14, 4.1518e-19, 9.9002e-11,
          1.8142e-18, 1.1444e-30, 2.1583e-19, 3.3834e-13, 6.4151e-19, 4.7155e-15,
          1.5246e-11, 1.6208e-30, 9.3500e-13, 2.4585e-22, 1.0455e-12, 4.0832e-13,
          7.2530e-29, 4.7336e-27, 3.0561e-22, 1.4567e-14, 1.6909e-13, 1.7889e-18,
          1.9405e-27, 1.8237e-32, 1.4351e-29, 6.0800e-15, 1.8333e-13, 1.2956e-12,
          1.6992e-19, 1.0244e-13, 2.1771e-22, 2.0248e-26, 1.2540e-13, 5.9344e-26,
          2.7389e-12, 3.4341e-18, 8.5247e-13, 4.5339e-19, 3.4088e-21, 9.9019e-24,
          9.7781e-14, 1.2768e-13, 1.3115e-29, 1.1716e-13, 2.8644e-13, 3.1215e-29,
          5.6449e-13, 3.9466e-13, 1.8200e-13, 1.5892e-20, 3.9636e-21, 3.9318e-12,
          1.4236e-12, 4.3685e-13, 2.3514e-13, 4.7388e-13, 2.0958e-13, 1.2054e-16,
          9.7227

In [152]:
# Size of the loaded input along axis 1.
inp.size(1)

195

In [155]:
# Replace the batched 'narration_id' entry with its first element.
# The type guard keeps the cell idempotent: once the value is already a plain
# string, indexing it again would cut it down to its first character.
if not isinstance(lab['narration_id'], str):
    lab['narration_id'] = lab['narration_id'][0]

In [156]:
# Display the 'narration_id' entry of the label dict.
lab['narration_id']

'P01_01_0'