<a href="https://colab.research.google.com/github/yiyixuxu/TimeSformer/blob/main/visualizing_space_time_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# setup

In [9]:
# export
from pathlib import Path
from timesformer.models.vit import *
from timesformer.datasets import utils as utils
from timesformer.config.defaults import get_cfg
from einops import rearrange, repeat, reduce
import cv2
#from google.colab.patches import cv2_imshow
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import matplotlib.pyplot as plt

# Utilities

In [10]:
# export
# Per-channel mean / std used to normalise frames before they reach the model.
DEFAULT_MEAN = [0.45] * 3
DEFAULT_STD = [0.225] * 3


def _read_color_frame(path):
  """Load one frame image with OpenCV (3-channel colour, OpenCV's BGR order)."""
  return cv2.imread(str(path), cv2.IMREAD_COLOR)


def _chw_to_hwc_pixels(frame):
  """Scale a C x H x W tensor by 255 and return it as an H x W x C numpy array."""
  return rearrange(frame * 255, 'c h w -> h w c').numpy()


# One decoded frame (the array returned by cv2.imread) -> normalised C x H x W
# float tensor, shorter side resized to 224 and center-cropped to 224 x 224.
# This is the per-frame preprocessing for the model input.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(DEFAULT_MEAN, DEFAULT_STD),
  transforms.Resize(224),
  transforms.CenterCrop(224),
])

# One frame *path* -> un-normalised H x W x C array (roughly 0-255, same resize
# and crop as `transform`) used for drawing attention overlays.
transform_plot = transforms.Compose([
  _read_color_frame,
  transforms.ToTensor(),
  transforms.Resize(224),
  transforms.CenterCrop(224),
  _chw_to_hwc_pixels,
])


def get_frames(path_to_video, num_frames=8):
  """Return paths to `num_frames` frames sampled evenly from a frame directory.

  Args:
    path_to_video: directory (``Path`` or ``str``) holding one image file per
      extracted frame. Each file stem must end in a 6-digit frame index
      (e.g. ``frame_000017.jpg``); that index determines the frame order.
    num_frames: number of frames to return (default 8).

  Returns:
    A list of ``Path`` objects in frame order. If the directory holds exactly
    `num_frames` frames, all of them are returned. Otherwise the span is cut
    into `num_frames` equal segments and the midpoint frame of each is kept.

  Raises:
    ValueError: if `num_frames` is not positive.
    AssertionError: if `num_frames` exceeds the number of available frames.
  """
  if num_frames < 1:
    raise ValueError(f"num_frames must be positive, got {num_frames}")
  # Directory listing order is arbitrary, so order by the numeric suffix.
  path_to_frames = sorted(
      Path(path_to_video).iterdir(),
      key=lambda f: int(f.with_suffix('').name[-6:]))
  video_length = len(path_to_frames)
  assert num_frames <= video_length, "num_frames can't exceed the number of frames extracted from videos"
  if video_length == num_frames:
    return path_to_frames

  seg_size = float(video_length - 1) / num_frames
  seq = []
  for i in range(num_frames):
    # Built-in round() rounds half to even, matching np.round for floats.
    start = round(seg_size * i)
    end = round(seg_size * (i + 1))
    seq.append((start + end) // 2)
  # Build the result once, after the indices are known (it was previously
  # rebuilt on every loop iteration).
  return [path_to_frames[p] for p in seq]

def create_video_input(path_to_video, num_frames=8):
  """Create the input tensor for a TimeSformer model from a frame directory.

  Args:
    path_to_video: directory of extracted frames (sampled with `get_frames`).
    num_frames: number of frames to sample (default 8, as before).

  Returns:
    A float tensor of shape (1, C, num_frames, 224, 224): a batch of one clip,
    channels first, then time. Each frame goes through `transform`.

  Raises:
    OSError: if OpenCV cannot read one of the frame files.
  """
  frames = []
  for p in get_frames(path_to_video, num_frames):
    img = cv2.imread(str(p), cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising; catch
    # it here instead of letting `transform` fail with an unrelated error.
    if img is None:
      raise OSError(f"cv2.imread could not read frame: {p}")
    frames.append(transform(img))
  frames = torch.stack(frames, dim=0)               # (T, C, H, W)
  frames = rearrange(frames, 't c h w -> c t h w')  # channels before time
  return frames.unsqueeze(dim=0)                    # add batch dimension

def show_mask_on_image(img, mask):
    """Blend a JET-coloured heatmap of `mask` onto `img`.

    `img` is an image array with values on a 0-255 scale; `mask` is an array
    of the same height/width, expected to lie in [0, 1] (it is scaled by 255
    and cast to uint8 before colouring). The sum of the two is rescaled so its
    maximum maps to 255, and the result is returned as a uint8 array.
    """
    scaled_img = np.float32(img) / 255
    mask_u8 = np.uint8(255 * mask)
    colored = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    blended = colored + np.float32(scaled_img)
    blended = blended / np.max(blended)
    return np.uint8(255 * blended)

def create_masks(masks_in, np_imgs):
  """Overlay each per-frame attention mask on its matching frame image.

  Args:
    masks_in: sequence of 2-D masks, one per frame (e.g. 8 masks of 14 x 14).
    np_imgs: sequence of frame images, one per mask (e.g. 224 x 224 x 3).

  Returns:
    A list of overlay images, each mask upsampled to its frame's size first.
  """
  overlays = []
  for frame_mask, frame in zip(masks_in, np_imgs):
    height, width = frame.shape[:2]
    # cv2.resize takes the target size as (width, height).
    upsampled = cv2.resize(frame_mask, (width, height))
    overlays.append(show_mask_on_image(frame, upsampled))
  return overlays

# Combining time (T) and space (S) attention

In [11]:
# export
def combine_divided_attention(attn_t, attn_s):
  """Fuse one block's temporal and spatial attention into a joint map.

  Example shapes are those recorded for the 8-frame, 224px SSv2 model
  (196 patches + 1 cls token, 12 heads). The code only relies on how the
  dimensions relate, so any frame count works.

  Args:
    attn_t: temporal attention, (patches, heads, T, T), e.g. (196, 12, 8, 8).
    attn_s: spatial attention, (T, heads, P, P) with P = patches + 1 and
      index 0 the cls token, e.g. (8, 12, 197, 197).

  Returns:
    Tensor of shape (P, T, P, T), e.g. (197, 8, 197, 8). Entry [p, t, k, q] is
    S[t, p, k] * T'[k, t, q], where S and T' are the head-averaged,
    residual-augmented, row-renormalised spatial and temporal maps. Each
    [p, t] slice therefore sums to 1. The cls-token row (p = 0) is replaced
    by its average over frames, repeated for every frame.
  """
  num_frames = attn_t.size(-1)

  ## time attention
  # average over heads: (patches, T, T)
  attn_t = attn_t.mean(dim=1)
  eye_t = torch.eye(num_frames, dtype=attn_t.dtype, device=attn_t.device)
  # add the cls token as an identity matrix since it only attends to itself:
  # (patches + 1, T, T)
  attn_t = torch.cat([eye_t.unsqueeze(0), attn_t], dim=0)
  # add the identity for the skip connection, then renormalise each row
  attn_t = attn_t + eye_t[None, ...]
  attn_t = attn_t / attn_t.sum(-1)[..., None]

  ## space attention
  # average over heads: (T, P, P)
  attn_s = attn_s.mean(dim=1)
  eye_s = torch.eye(attn_s.size(-1), dtype=attn_s.dtype, device=attn_s.device)
  # add the identity for the skip connection, then renormalise each row
  attn_s = attn_s + eye_s[None, ...]
  attn_s = attn_s / attn_s.sum(-1)[..., None]

  ## combine: (P, T, P, T)
  attn_ts = torch.einsum('tpk, ktq -> ptkq', attn_s, attn_t)

  ## share the cls-token attention across frames
  # average over frames: (P, T)
  attn_cls = attn_ts[0].mean(dim=0)
  # repeat for every frame; the count comes from the input rather than being
  # hard-coded, so models with other frame counts work: (T, P, T)
  attn_cls = attn_cls.unsqueeze(0).expand(num_frames, -1, -1)
  return torch.cat([attn_cls.unsqueeze(0), attn_ts[1:]], dim=0)

class DividedAttentionRollout():
  """Attention rollout for a TimeSformer model with divided space-time attention.

  Calling the object on a directory of video frames:

  1. runs one forward pass, capturing each block's temporal and spatial
     attention via forward hooks;
  2. fuses each block's pair with `combine_divided_attention`;
  3. multiplies the per-block matrices together (attention rollout);
  4. returns the cls-token attention over the patch grid for every frame.

  NOTE(review): the hooks read the outputs of modules whose names contain
  'temporal_attn.attn_drop' / 'attn.attn_drop'. These are presumably dropout
  layers, so call `model.eval()` first so the captured weights are not
  randomly dropped — confirm against the model definition.
  """

  def __init__(self, model, **kwargs):
    # **kwargs is accepted for backward compatibility; it is not used.
    self.model = model
    self.hooks = []

  def get_attn_t(self, module, input, output):
    # Forward hook: keep a detached CPU copy of a temporal attention map.
    self.time_attentions.append(output.detach().cpu())

  def get_attn_s(self, module, input, output):
    # Forward hook: keep a detached CPU copy of a spatial attention map.
    self.space_attentions.append(output.detach().cpu())

  def remove_hooks(self):
    """Remove every registered forward hook and forget the handles."""
    for h in self.hooks:
      h.remove()
    self.hooks = []

  def _register_hooks(self):
    """Attach the temporal/spatial capture hooks to the model's modules."""
    for name, m in self.model.named_modules():
      # Test the temporal name first: 'attn.attn_drop' is also a substring of
      # 'temporal_attn.attn_drop'.
      if 'temporal_attn.attn_drop' in name:
        self.hooks.append(m.register_forward_hook(self.get_attn_t))
      elif 'attn.attn_drop' in name:
        self.hooks.append(m.register_forward_hook(self.get_attn_s))

  def __call__(self, path_to_video):
    """Return the rollout attention mask for the frames in `path_to_video`.

    Returns:
      numpy array of shape (h, w, T), e.g. (14, 14, 8): for every frame, the
      cls token's rolled-out attention to each patch, scaled so that the
      maximum over the whole array is 1.

    Raises:
      RuntimeError: if no attention maps were captured (no module names
        matched the hook patterns).
    """
    input_tensor = create_video_input(path_to_video)
    self.time_attentions = []
    self.space_attentions = []
    self.attentions = []

    self._register_hooks()
    try:
      # Only the hooked attention maps are needed, so skip building a graph.
      with torch.no_grad():
        self.model(input_tensor)
    finally:
      # Always detach the hooks, even if the forward pass fails, so repeated
      # calls never stack duplicate hooks on the model.
      self.remove_hooks()

    if not self.time_attentions or not self.space_attentions:
      raise RuntimeError(
          "no attention maps captured; expected modules named "
          "'*temporal_attn.attn_drop' and '*attn.attn_drop'")

    # One fused (P, T, P, T) map per block, computed once per block.
    self.attentions = [combine_divided_attention(attn_t, attn_s)
                       for attn_t, attn_s in zip(self.time_attentions, self.space_attentions)]

    p, t = self.attentions[0].shape[:2]
    result = torch.eye(p * t, dtype=self.attentions[0].dtype)
    for attention in self.attentions:
      attention = rearrange(attention, 'p1 t1 p2 t2 -> (p1 t1) (p2 t2)')
      result = torch.matmul(attention, result)

    mask = rearrange(result, '(p1 t1) (p2 t2) -> p1 t1 p2 t2', p1=p, p2=p)
    mask = mask.mean(dim=1)    # average over query frames: (P, P, T)
    mask = mask[0, 1:, :]      # cls-token query, patch keys: (P - 1, T)
    width = int(mask.size(0) ** 0.5)
    mask = rearrange(mask, '(h w) t -> h w t', w=width)
    return (mask / mask.max()).numpy()



# load the pretrained model

download the pre-trained model

In [None]:
#! wget https://dl.dropboxusercontent.com/s/tybhuml57y24wpm/TimeSformer_divST_8_224_SSv2.pyth

load the model

In [12]:
# Local path of the SSv2 divided space-time checkpoint.
# NOTE(review): the commented-out wget above saves into the current directory,
# not into timesformer/ — move the file there (or adjust this path) if needed.
model_file = 'timesformer/TimeSformer_divST_8_224_SSv2.pyth'
# Left as the cell's last expression so the notebook displays whether it exists.
Path(model_file).exists()

True

In [13]:
# Start from get_cfg() and layer the SSv2 divided space-time config on top.
cfg = get_cfg()
cfg.merge_from_file('configs/SSv2/TimeSformer_divST_8_224.yaml')
# Inference only: set the checkpoint path (presumably loaded by the model
# constructor) and disable training.
cfg.TIMESFORMER.PRETRAINED_MODEL = model_file
cfg.TRAIN.ENABLE = False
vit_builder = MODEL_REGISTRY.get('vit_base_patch16_224')
model = vit_builder(cfg)

read the labels

In [14]:
# SSv2 class names; a name's position in this list is used as its class index.
# NOTE(review): this relies on labels.json listing its keys in class-index
# order — confirm against the file's values.
with open('example_data/labels.json') as f:
  ssv2_labels = list(json.load(f))

inference

In [17]:
# Directory of extracted frames for example video 74225.
path_to_video = Path('example_data/74225/')
# Left as the cell's last expression so the notebook displays whether it exists.
path_to_video.exists()

True

In [18]:
# Seeded, gradient-free inference pass on the example video.
with torch.no_grad():
  torch.manual_seed(cfg.RNG_SEED)
  np.random.seed(cfg.RNG_SEED)
  model.eval()
  video_input = create_video_input(path_to_video)
  pred = model(video_input).detach().cpu()



In [None]:
topk_scores, topk_label = torch.topk(pred, k=5, dim=-1)
# Print the five highest-scoring classes with their scores.
ranked = zip(topk_label.squeeze(), topk_scores.squeeze())
for i, (label_idx, score) in enumerate(ranked):
  pred_name = ssv2_labels[label_idx.item()]
  print(f"Prediction index {i}: {pred_name:<25}, score: {score.item():.3f}")

# visualizing the learned space-time attention

Create a `DividedAttentionRollout` object (`att_roll`) and call it to get a mask for a given video

In [None]:
# Build the rollout helper around the loaded model and compute the attention
# mask for the example video: a (height, width, frames) numpy array scaled so
# that its maximum is 1.
att_roll = DividedAttentionRollout(model)
masks = att_roll(path_to_video)

plot

In [None]:
# Overlay each frame's attention mask on the frame and lay the results side by side.
np_imgs = [transform_plot(p) for p in get_frames(path_to_video)]
print(len(np_imgs), np_imgs[0].shape)
print(masks.shape)
masks_output = create_masks(list(rearrange(masks, 'h w t -> t h w')), np_imgs)
overlay_strip = np.hstack(masks_output)
# Save before displaying: cv2.imshow needs a GUI backend and may fail in
# headless sessions (e.g. Colab), which previously meant nothing was written.
cv2.imwrite('divided_space_time_att.jpg', overlay_strip)
# Optionally also save the un-annotated frames:
# cv2.imwrite('original_imgs.jpg', np.hstack(np_imgs))
cv2.imshow('', overlay_strip)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
import matplotlib.pyplot as plt

# Percent steps 0, 10, ..., 100.
color_list = list(range(0, 101, 10))

def visualize_colormap(mask):
  """Draw a vertical 'jet' colour bar spanning the value range of `mask`.

  The bar has 101 rows running linearly from ``min(mask)`` (bottom) to
  ``max(mask)`` (top). The y tick labels are row indices (0, 20, ..., 100),
  i.e. percent of that range rather than the mask values themselves. The
  figure is written to 'color_map.png' and then shown.
  """
  min_val, max_val = np.min(mask), np.max(mask)
  print(min_val, max_val)
  gradient = np.linspace(min_val, max_val, 101)[:, None]  # (101, 1) column

  fig, ax = plt.subplots(figsize=(1, 10))
  ax.imshow(gradient, aspect='auto', cmap=plt.get_cmap('jet'), origin='lower')
  ax.set_xticks([])
  ax.set_yticks(np.arange(0, 101, 20))
  ax.tick_params(axis='y', labelsize=40)
  # Save BEFORE show(): once a blocking show() returns, the figure is closed,
  # and plt.savefig would then write a new, empty figure. bbox_inches='tight'
  # keeps the 40pt tick labels from being clipped off the 1-inch-wide canvas.
  fig.savefig('color_map.png', bbox_inches='tight')
  plt.show()

# Draw the colour bar for the range of values in the rollout mask.
mask_values = masks.flatten()
visualize_colormap(mask_values)