In [16]:
# export
from pathlib import Path
from timesformer.models.vit import *
from timesformer.datasets import utils as utils
from timesformer.config.defaults import get_cfg
from einops import rearrange, repeat, reduce
import cv2
#from google.colab.patches import cv2_imshow
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import matplotlib.pyplot as plt

In [17]:
# export
# Per-channel normalisation statistics. All three channel values are identical,
# so the normalised result does not depend on the channel order of the input.
DEFAULT_MEAN = [0.45, 0.45, 0.45]
DEFAULT_STD = [0.225, 0.225, 0.225]

# Per-frame preprocessing for the model. create_video_input applies this to the
# image array returned by cv2.imread for each frame (it does not take a path).
transform = transforms.Compose([
    # array -> tensor, then normalise with the statistics above
    transforms.ToTensor(),
    transforms.Normalize(DEFAULT_MEAN,DEFAULT_STD),
    # Resize(224) followed by a 224 centre crop
    transforms.Resize(224),
    transforms.CenterCrop(224),
])

# Turn the path of a single frame into an 'h w c' numpy image for plotting.
# Same Resize/CenterCrop geometry as `transform`, but without normalisation, so the
# attention heatmaps line up with the frames the model saw.
transform_plot = transforms.Compose([
    # read the frame from disk; str() because the frame paths are pathlib.Path objects
    lambda p: cv2.imread(str(p),cv2.IMREAD_COLOR),
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.CenterCrop(224),
    # scale back up by 255 (presumably undoing ToTensor's [0, 1] scaling -- see torchvision docs)
    # and move channels last; the result is a float array, not uint8
    lambda x: rearrange(x*255, 'c h w -> h w c').numpy()
])


def get_frames(path_to_video, num_frames=8):
  """Return a list of `num_frames` frame paths sampled evenly from a video directory.

  Args:
    path_to_video: `pathlib.Path` of a directory holding one file per extracted frame.
      The last six characters of each file stem must be the frame number
      (e.g. ``frame_000012.jpg``); frames are ordered by that number.
    num_frames: how many frames to return (default 8).

  Returns:
    A list of `Path` objects in temporal order. When the directory holds exactly
    `num_frames` files, all of them are returned. Otherwise the video is cut into
    `num_frames` equal segments and the middle frame of each segment is taken.
    `num_frames=0` yields an empty list.

  Raises:
    AssertionError: if `num_frames` is negative or exceeds the number of frames.
    ValueError: if a file stem does not end in a six-digit frame number.
  """
  path_to_frames = sorted(path_to_video.iterdir(), key=lambda f: int(f.with_suffix('').name[-6:]))
  video_length = len(path_to_frames)
  assert num_frames >= 0, "num_frames must be non-negative"
  assert num_frames <= video_length, "num_frames can't exceed the number of frames extracted from videos"
  if num_frames == 0:
    # nothing to sample; also avoids dividing by zero below
    return []
  if video_length == num_frames:
    return path_to_frames

  seg_size = float(video_length - 1) / num_frames
  path_to_frames_new = []
  for i in range(num_frames):
    # built-in round() rounds half to even, exactly like np.round, and already returns an int
    start = round(seg_size * i)
    end = round(seg_size * (i + 1))
    path_to_frames_new.append(path_to_frames[(start + end) // 2])
  return path_to_frames_new

def create_video_input(path_to_video):
  """Build the model input tensor for the video stored under `path_to_video`.

  Frames are chosen by `get_frames` (default frame count), read with OpenCV,
  preprocessed with `transform`, stacked along time and reordered to
  'c t h w' with a leading batch axis of size 1.
  """
  clip = torch.stack(
      [transform(cv2.imread(str(frame_path), cv2.IMREAD_COLOR))
       for frame_path in get_frames(path_to_video)],
      dim=0,
  )
  # time-major stack -> channel-major clip, then add the batch dimension
  return rearrange(clip, 't c h w -> c t h w').unsqueeze(dim=0)

def show_mask_on_image(img, mask):
    """Overlay a JET-coloured heatmap of `mask` on `img` and return a uint8 image.

    `img` is divided by 255 and `mask` is multiplied by 255 before the uint8 cast,
    so presumably `img` is in [0, 255] and `mask` in [0, 1] -- confirm with callers.
    The sum of image and heatmap is rescaled by its maximum before conversion.
    """
    image01 = np.float32(img) / 255
    jet = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(jet) / 255 + np.float32(image01)
    overlay = overlay / np.max(overlay)
    return np.uint8(255 * overlay)

def create_masks(masks_in, np_imgs):
  """Return one heatmap overlay per (mask, image) pair.

  Each mask is resized to its image's (width, height) with `cv2.resize` and then
  blended onto the image by `show_mask_on_image`. Pairs are formed with `zip`,
  so the shorter of the two inputs determines the output length.
  """
  return [
      show_mask_on_image(frame, cv2.resize(frame_mask, (frame.shape[1], frame.shape[0])))
      for frame_mask, frame in zip(masks_in, np_imgs)
  ]

In [18]:
# export
def space_only_attention_masks(attn_s):
    """Average attention over dim 1, add the residual identity and row-normalise.

    Args:
        attn_s: 4-D attention tensor. Dim 1 is averaged away (presumably the heads
            axis); the remaining axes are treated as 't1 p1 p2', i.e. one square
            p x p attention matrix per frame.

    Returns:
        Tensor laid out as 'p1 p2 t1' in which every row (over p2) sums to 1.
        It lives on the same device as `attn_s`.
    """
    attn_s = attn_s.mean(dim=1)
    # Add the residual connection as an identity matrix. Build it on the input's
    # device so that non-CPU attention tensors do not hit a device mismatch.
    identity = torch.eye(attn_s.size(-1), device=attn_s.device)
    attn_s = attn_s + identity[None, ...]
    # renormalise so each row sums to 1 again
    attn_s = attn_s / attn_s.sum(-1, keepdim=True)
    # pure axis permutation, equivalent to einops 't1 p1 p2 -> p1 p2 t1'
    return attn_s.permute(1, 2, 0)

class SpaceAttentionRollout():
    """Attention rollout for a space-only attention model.

    Forward hooks capture the output of every sub-module whose qualified name
    contains 'attn.attn_drop'. The captured attention maps are processed by
    `space_only_attention_masks` and multiplied layer by layer into one
    rollout matrix, from which a per-frame spatial mask is extracted.

    Args:
        model: the torch module to inspect.
        verbose: when True (default) the shape diagnostics are printed during a call.
    """
    def __init__(self, model, verbose=True):
        self.model = model
        self.verbose = verbose
        self.hooks = []
        self.space_attentions = []
        self.attentions = []

    def _log(self, *args):
        "Print diagnostics only when `verbose` is enabled."
        if self.verbose:
            print(*args)

    def get_attn_s(self, module, input, output):
        "Forward hook: keep a detached CPU copy of the hooked module's output."
        self.space_attentions.append(output.detach().cpu())

    def remove_hooks(self):
        "Remove every registered hook and forget the handles."
        for h in self.hooks:
            h.remove()
        # drop the stale handles so the list does not grow from call to call
        self.hooks.clear()

    def __call__(self, path_to_video):
        """Return the rollout mask for the video at `path_to_video`.

        Returns:
            numpy array laid out as 'h w t' (one spatial map per frame), divided
            by its maximum value.
        """
        input_tensor = create_video_input(path_to_video)
        self.model.zero_grad()
        self.space_attentions = []
        self.attentions = []

        # clear anything left over from an earlier call before registering again
        self.remove_hooks()
        try:
            for name, m in self.model.named_modules():
                if 'attn.attn_drop' in name:
                    self.hooks.append(m.register_forward_hook(self.get_attn_s))
            # only the hooked activations are needed, so skip building an autograd graph
            with torch.no_grad():
                self.model(input_tensor)
        finally:
            # always detach the hooks, even when the forward pass raises
            self.remove_hooks()

        self._log("space_attentions:", len(self.space_attentions))
        for attn_s in self.space_attentions:
            attention = space_only_attention_masks(attn_s)
            self._log(attention.shape)
            self.attentions.append(attention)

        # each processed attention map is laid out as 'p1 p2 t'
        p,t = self.attentions[0].shape[0], self.attentions[0].shape[2]

        self._log("self_att:", self.attentions[0].shape)
        self._log("p:", p)
        self._log("t:", t)

        result = torch.eye(p*t)

        for attention in self.attentions:
            # replicate along a new frame axis; use the actual frame count rather than a fixed 8
            attention = attention.unsqueeze(1).repeat(1, t, 1, 1)
            self._log(attention.shape)
            attention = rearrange(attention, 'p1 t1 p2 t2-> (p1 t1) (p2 t2)')
            result = torch.matmul(attention, result)

        mask = rearrange(result, '(p1 t1) (p2 t2) -> p1 t1 p2 t2', p1 = p, p2=p)
        mask = mask.mean(dim=1)
        # keep row 0 and drop column 0 (presumably the class token -- confirm against the model)
        mask = mask[0,1:,:]
        self._log("mask:", mask.shape)
        self._log(mask.size(0))
        # the remaining tokens are assumed to form a square grid
        width = int(mask.size(0)**0.5)
        mask = rearrange(mask, '(h w) t -> h w t', w = width).numpy()
        self._log("mask rearrange:", mask.shape)
        mask = mask / mask.max()
        return mask


In [19]:
# Forward slashes work on every platform and avoid the invalid "\T" escape sequence
# that the backslash form contained.
model_file = 'timesformer/TimeSformer_spaceOnly_8x32_224.pyth'
Path(model_file).exists()

True

In [20]:
# Start from the project defaults and overlay the space-only experiment config.
cfg = get_cfg()
cfg.merge_from_file('configs/SSv2/TimeSformer_spaceOnly_8x32_224.yaml')
# These overrides are applied after the merge, presumably so the YAML values cannot replace them.
cfg.TRAIN.ENABLE = False
cfg.TIMESFORMER.PRETRAINED_MODEL = model_file
# NOTE(review): the config lives under configs/SSv2, but a later cell decodes the
# predictions with the Kinetics-400 label file -- confirm checkpoint, config and labels agree.
model = MODEL_REGISTRY.get('vit_base_patch16_224')(cfg)

In [21]:
# Show the module tree. SpaceAttentionRollout hooks the modules whose qualified
# name contains 'attn.attn_drop', so this is where to check that they exist.
model

vit_base_patch16_224(
  (model): VisionTransformer(
    (dropout): Dropout(p=0.0, inplace=False)
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=Fals

In [22]:
import pandas as pd

# Forward slashes work on every platform and avoid the invalid "\k" escape sequences
# that the backslash form contained.
read_data = pd.read_csv('example_data/kinetics/kinetics_400_labels.csv')

In [23]:
# Class names from the 'name' column; the list position is used as the class index
# when predictions are decoded.
kinetics_labels = read_data['name'].tolist()

In [24]:
# Sanity check: number of class names read from the label file.
len(kinetics_labels)

400

In [25]:
# Directory holding the extracted frames of the example video (one file per frame,
# as expected by get_frames); confirm it exists before running the model.
path_to_video = Path('example_data/74225/')
path_to_video.exists()

True

In [26]:
# Seed numpy and torch from the config, switch the model to eval mode, then run
# one forward pass without gradient tracking.
np.random.seed(cfg.RNG_SEED)
torch.manual_seed(cfg.RNG_SEED)
model.eval()
with torch.no_grad():
  pred = model(create_video_input(path_to_video)).detach().cpu()



In [27]:
# Five highest-scoring classes along the last axis of the prediction.
topk_scores, topk_label = torch.topk(pred, k=5, dim=-1)

print(topk_scores)
print(topk_label)

# Walk the (label index, score) pairs in rank order and look up each class name.
ranked = zip(topk_label.squeeze().tolist(), topk_scores.squeeze().tolist())
for i, (label_idx, score) in enumerate(ranked):
  pred_name = kinetics_labels[label_idx]
  print(f"Prediction index {i}: {pred_name:<25}, score: {score:.3f}")

tensor([[1.4680, 1.4605, 1.1831, 1.1716, 1.1499]])
tensor([[ 31, 218, 325, 235,  37]])
Prediction index 0: bowling                  , score: 1.468
Prediction index 1: playing badminton        , score: 1.460
Prediction index 2: somersaulting            , score: 1.183
Prediction index 3: playing ice hockey       , score: 1.172
Prediction index 4: brushing teeth           , score: 1.150


In [28]:
# Run attention rollout on the example video. `masks` is the numpy array returned by
# SpaceAttentionRollout.__call__, laid out as 'h w t' (one spatial map per frame).
att_roll = SpaceAttentionRollout(model)
masks = att_roll(path_to_video)

space_attentions: 12
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
torch.Size([197, 197, 8])
self_att: torch.Size([197, 197, 8])
p: 197
t: 8
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
torch.Size([197, 8, 197, 8])
mask: torch.Size([196, 8])
196
mask rearrange: (14, 14, 8)


In [29]:
# Load the sampled frames as plain images and blend the matching rollout mask onto each one.
frame_paths = get_frames(path_to_video)
np_imgs = [transform_plot(frame_path) for frame_path in frame_paths]
per_frame_masks = list(rearrange(masks, 'h w t -> t h w'))
masks_output = create_masks(per_frame_masks, np_imgs)

# Put the overlays side by side once, then both display and save that strip.
overlay_strip = np.hstack(masks_output)
cv2.imshow('', overlay_strip)
cv2.imwrite('space_only_att.jpg', overlay_strip)
# block until a key is pressed, then close the OpenCV window
cv2.waitKey(0)
cv2.destroyAllWindows()