In [1]:
import os
import sys
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms as pth_transforms

sys.path.append('../input/vision-transformer-scripts')
import vision_transformer as vits

In [2]:
def preprocess_image(image_path, tf, patch_size):
    # read image -> convert to RGB -> torch Tensor
    rgb_img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    img = tf(rgb_img)
    _, image_height, image_width = img.shape
    
    # make the image divisible by the patch size
    w, h = image_width - image_width % patch_size, image_height - image_height % patch_size
    img = img[:, :h, :w].unsqueeze(0)
    
    w_featmap = img.shape[-1] // patch_size
    h_featmap = img.shape[-2] // patch_size
    return rgb_img, img, w_featmap, h_featmap

In [3]:
def calculate_threshold_attention(attention, threshold, w_featmap, h_featmap, patch_size, mode = 'nearest'):
    # we keep only a certain percentage of the mass
    val, idx = torch.sort(attention)
    val /= torch.sum(val, dim=1, keepdim=True )
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - threshold)
    idx2 = torch.argsort(idx)

    # filter each head
    nh = attention.shape[0]
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]

    # interpolate    
    th_attn = th_attn.reshape(nh, h_featmap, w_featmap).float()
    th_attn = F.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
    return th_attn

In [4]:
def calculate_attentions(img, w_featmap, h_featmap, patch_size, calc_th_attn = True, threshold = 0.6, mode = 'nearest' ):
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]
    
    # we keep only the output patch attention
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    
    if(calc_th_attn == True):
        th_attn = calculate_threshold_attention(attentions, threshold, w_featmap, h_featmap, patch_size, mode = mode)
    else:
        th_attn = None

    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    attentions = F.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode=mode)[0].cpu().numpy()
    return attentions, th_attn

In [5]:
def get_attention_masks(image_path, model, transform, patch_size, calc_th_attn=True, threshold = 0.6, mode = 'bilinear'):
    rgb_img, img, w_featmap, h_featmap = preprocess_image(image_path, transform, patch_size)
    attentions, th_attn = calculate_attentions(img, w_featmap, h_featmap, patch_size, 
                               calc_th_attn = calc_th_attn, threshold = threshold, mode = mode)
    return rgb_img, attentions, th_attn

In [6]:
threshold = 0.6
patch_size = 8
arch = 'vit_small'
output_dir = '.'
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"

transform = pth_transforms.Compose([
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [7]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)

for p in model.parameters():
    p.requires_grad = False
model.eval();
model.to(device);

state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
model.load_state_dict(state_dict, strict=True)

Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall8_300ep_pretrain.pth


  0%|          | 0.00/82.7M [00:00<?, ?B/s]

<All keys matched successfully>

In [8]:
input_dir = '../input/image-segmentation/cityScapes_256_512/demoVideo/'
image_list = sorted(os.listdir(input_dir))
images_path = [os.path.join(input_dir, x) for x in image_list]

In [9]:
font = {'family' : 'normal', 'weight' : 'bold', 'size'   : 9}
plt.rc('font', **font)
plt.rcParams['text.color'] = 'white'

In [10]:
%matplotlib agg
fig, axes = plt.subplots(2,3, figsize=(20,6.71))
axes = axes.flatten()

for image_path in tqdm(images_path):
    image_name = image_path.split(os.sep)[-1].split('.')[0]
    rgb_img, attentions, th_attn = get_attention_masks(image_path, model, transform, 
                                           patch_size, calc_th_attn=False, mode = 'bilinear')
    
    for i in range(6):
        axes[i].clear()
        axes[i].imshow(rgb_img)
        axes[i].imshow(attentions[i], cmap='inferno', alpha=0.5)
        axes[i].axis('off')
        axes[i].set_title(f"ATTENTION HEAD {i+1}", x= 0.19, y=0.9, va="top");

    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(f'{image_name}.png', bbox_inches='tight')

100%|██████████| 599/599 [12:06<00:00,  1.21s/it]


In [11]:
def convert_images_to_video(images_dir, output_video_path, fps : int = 20):
    
    input_images = [os.path.join(images_dir, *[x]) for x in sorted(os.listdir(images_dir)) if x.endswith('png')]
    
    if(len(input_images) > 0):
        sample_image = cv2.imread(input_images[0])
        height, width, _ = sample_image.shape
        
        # handles for input output videos
        output_handle = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))

        # create progress bar
        num_frames = int(len(input_images))
        pbar = tqdm(total = num_frames, position=0, leave=True)

        for i in tqdm(range(num_frames), position=0, leave=True):
            frame = cv2.imread(input_images[i])
            output_handle.write(frame)
            pbar.update(1)

        # release the output video handler
        output_handle.release()
                
    else:
        pass

In [12]:
def createDir(dirPath):
    if(not os.path.isdir(dirPath)):
        os.mkdir(dirPath)

In [13]:
video_output_dir = os.path.join(output_dir, *['videos'])
createDir(video_output_dir)
output_video_path = os.path.join(video_output_dir, *["Vit_last_stage_demoVideo.mp4"])
print(output_video_path)

./videos/Vit_last_stage_demoVideo.mp4


In [14]:
convert_images_to_video('./', output_video_path, fps=20)

OpenCV: FFMPEG: tag 0x58564944/'DIVX' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'
100%|██████████| 599/599 [00:24<00:00, 24.49it/s]
100%|██████████| 599/599 [00:24<00:00, 24.49it/s]
