In [None]:
#%%
from org3dresnet.main import (get_inference_utils,
                              generate_model, resume_model)
from org3dresnet.model import generate_model, make_data_parallel
import org3dresnet

import torch
import numpy as np
import cv2
from torch.backends import cudnn
import torchvision
from IPython.display import HTML
from torchvision.transforms.transforms import Normalize, ToPILImage
from torchvision.transforms import transforms
from PIL import Image

import os
import sys
import json
import glob
from argparse import Namespace
from pathlib import Path
from time import time
from tqdm import tqdm

sys.path.append("/workspace/src/")

from utils.visualization import visualize
from utils.utils import video_to_html, numericalSort
from utils.sensitivity_analysis import OcclusionSensitivityMap3D as OSM




In [None]:
#%%
# Load the training options saved next to the fine-tuned checkpoint and
# rebuild the model from them for inference.
opt_path = "/workspace/data/r3d_models/finetuning/ucf101/r3d50_K_fc/opts.json"
with open(opt_path, "r") as f:
    model_opt = Namespace(**json.load(f))

# If the saved options ask for CUDA but this kernel has none, fall back to CPU
# and keep `no_cuda` in sync so every later check on it agrees with `device`.
if not model_opt.no_cuda and not torch.cuda.is_available():
    model_opt.no_cuda = True
model_opt.device = torch.device('cpu' if model_opt.no_cuda else 'cuda')
if not model_opt.no_cuda:
    cudnn.benchmark = True
if model_opt.accimage:
    torchvision.set_image_backend('accimage')

model_opt.ngpus_per_node = torch.cuda.device_count()

model = generate_model(model_opt)
model = resume_model(model_opt.resume_path, model_opt.arch, model)
model = make_data_parallel(model, model_opt.distributed, model_opt.device)
model.eval()

model_opt.inference_batch_size = 1
# JSON has no Path type, so path-like options come back as plain values
# (presumably str). Convert every non-None option whose name contains "path"
# to pathlib.Path before building the inference loader. vars() restricts this
# to the options themselves (dir() would also walk class members and dunders);
# list() snapshots the items before setattr rewrites values.
for attribute, value in list(vars(model_opt).items()):
    if "path" in attribute and value is not None:
        setattr(model_opt, attribute, Path(value))
inference_loader, inference_class_names = get_inference_utils(model_opt)

# Reverse lookup: lower-cased class name -> class index.
class_labels_map = {v.lower(): k for k, v in inference_class_names.items()}



In [None]:
#%%
# Take one batch from the loader only to learn the input shape that the
# sensitivity maps must match.
inputs, targets = next(iter(inference_loader))
# Indexing with the list [0] keeps a leading batch axis of size 1; the run cell
# reads video_size[2] as the frame count (presumably (1, C, T, H, W)).
video_size = inputs[[0]].shape
transform = inference_loader.dataset.spatial_transform

# Find the Normalize step inside the loader's spatial transform so it can be inverted.
_transforms = transform.transforms
idx = next(
    (i for i, t in enumerate(_transforms)
     if isinstance(t, org3dresnet.spatial_transforms.Normalize)),
    None,
)
if idx is None:
    raise ValueError(
        "spatial_transform has no org3dresnet Normalize step; cannot build unnormalize")
normalize = _transforms[idx]
mean = torch.tensor(normalize.mean)
std = torch.tensor(normalize.std)

# Inverse of Normalize(mean, std):
#   x * std + mean == (x - (-mean / std)) / (1 / std)
# i.e. another Normalize with mean' = -mean/std and std' = 1/std, followed by
# conversion of the restored frame to a PIL image for display.
unnormalize = transforms.Compose(
    [
        Normalize((-mean / std).tolist(), (1 / std).tolist()),
        ToPILImage(),
    ]
)


In [None]:
#%%
def load_jpg(l, g, c, n, clip_len=16, root="/workspace/data/ucf101/jpg"):
    """Load one fixed-length clip of pre-extracted UCF101 JPEG frames.

    Frames are read from ``<root>/<name>/v_<name>_g<GG>_c<CC>/``, where ``name``
    is ``inference_class_names[l]``. They are sorted with ``numericalSort``, and
    each frame is passed through the notebook-level ``transform``.

    Args:
        l: class index, used as a key into ``inference_class_names``.
        g: group number (zero-padded to 2 digits in the folder name).
        c: clip number within the group (zero-padded to 2 digits).
        n: which ``clip_len``-frame window to take, i.e. sorted frames
            ``[n * clip_len, (n + 1) * clip_len)``.
        clip_len: frames per clip; defaults to 16, the original hard-coded value.
        root: directory holding the per-class frame folders.

    Returns:
        ``torch.stack`` of the transformed frames, with the frame axis first.
        Returns ``False`` (after printing why) when ``n`` is negative or the
        folder has fewer frames than the requested window needs.
    """
    name = inference_class_names[l]
    video_dir = os.path.join(
        root, name,
        "v_{}_g{}_c{}".format(name, str(g).zfill(2), str(c).zfill(2)),
    )
    frame_paths = sorted(glob.glob(os.path.join(video_dir, "*")), key=numericalSort)

    start = n * clip_len
    target_paths = frame_paths[start:start + clip_len]
    # A negative n would make the slice count from the end of the folder.
    if n < 0 or len(target_paths) < clip_len:
        print("not exist: frames {}-{} requested from {}, but only {} frames found".format(
            start, start + clip_len - 1, video_dir, len(frame_paths)))
        return False

    video = []
    for frame_path in target_paths:
        # Image.open is lazy and keeps the file open; the context manager closes
        # it once the frame is transformed. convert("RGB") forces 3 channels
        # even if a JPEG is stored as grayscale/CMYK.
        with Image.open(frame_path) as img:
            video.append(transform(img.convert("RGB")))

    return torch.stack(video)
    


In [None]:
#%%
# Occlusion-window settings shared by both sensitivity-map generators.
spatial_crop_sizes = [16]
temporal_crop_sizes = [16]
spatial_stride = 8
temporal_stride = 2
# Passed to OSM as `batchsize`; presumably the number of occluded clips per
# forward pass. TODO confirm against OcclusionSensitivityMap3D.
osm_batchsize = 400

# The two generators were identical except for N_stack_mask, so build the
# shared arguments once instead of copy-pasting the constructor call.
osm_common_kwargs = dict(
    net=model,
    video_size=video_size,
    device=model_opt.device,
    spatial_crop_sizes=spatial_crop_sizes,
    temporal_crop_sizes=temporal_crop_sizes,
    spatial_stride=spatial_stride,
    temporal_stride=temporal_stride,
    transform=transform,
    batchsize=osm_batchsize,
)

aosa_single = OSM(**osm_common_kwargs, N_stack_mask=1)
aosa = OSM(**osm_common_kwargs, N_stack_mask=3)




In [None]:
#%%
# Which clip to analyse: class index, group, clip-within-group, window index.
l = 21
g = 1  # > 0
c = 1  # > 0
n = 1

clip = load_jpg(l, g, c, n)
# load_jpg signals a missing or too-short frame folder by returning False rather
# than a tensor; stop with a clear message instead of an AttributeError below.
if not isinstance(clip, torch.Tensor):
    raise RuntimeError(
        "could not load clip n={} of class {} (g={}, c={})".format(n, l, g, c))
# load_jpg stacks frames on axis 0; swap so frames sit on axis 1
# (presumably (T, C, H, W) -> (C, T, H, W)).
video = clip.transpose(0, 1)
target = l
with torch.inference_mode():
    # Move the batch to the model's device explicitly so this works for CPU,
    # DataParallel and single-device CUDA models alike.
    logits = model(video.unsqueeze(0).to(model_opt.device))
pred = int(logits.cpu().numpy().argmax())

# Un-normalized frames for display; the transpose is loop-invariant, so do it once.
frames = video.transpose(0, 1)
video_orgimg = np.array([np.array(unnormalize(frame)) for frame in frames])

start = time()
aosa_single_map = aosa_single.run(video, target)
print("aosa_single.run: {:.2f} s".format(time() - start))

start = time()
aosa_map = aosa.run(video, target)
print("aosa.run: {:.2f} s".format(time() - start))


print("{0} temporal crop sizes, {1} spatial crop sizes, {2} dummy list".format(
    len(aosa_single_map), len(aosa_single_map[0]), len(aosa_single_map[0][0])))
print("heatmap size: ", aosa_single_map[0][0][0].shape)

title = "{} (pred: {})".format(
    inference_class_names[l], inference_class_names[pred])


s = visualize([aosa_single_map[0][0][0], aosa_map[0][0][0]], video_orgimg,
              title=title,
              )

HTML(s)



In [None]:
#%%