In [1]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from skimage.feature import blob_dog, blob_doh, blob_log

import torch
from torch import nn
from torchvision import transforms as pth_transforms

import vision_transformer as vits

## Load model

In [2]:
arch = 'vit_small'
patch_size = 8
pretrained_weights = 'vit_small'
checkpoint_key = None
image_path = 'surf.png'

image_size = (592, 1184)
output_dir = '.'
threshold = None

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# build model
model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
for p in model.parameters():
    p.requires_grad = False
model.eval()
model.to(device)
if os.path.isfile(pretrained_weights):
    state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        print(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
else:
    print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
    url = None
    if arch == "vit_small" and patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
    elif arch == "vit_small" and patch_size == 8:
        url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"  # model used for visualizations in our paper
    elif arch == "vit_base" and patch_size == 16:
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    elif arch == "vit_base" and patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
    if url is not None:
        print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
        model.load_state_dict(state_dict, strict=True)
    else:
        print("There is no reference weights available for this model => We use random weights.")

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.


In [4]:
def load_image(file, image_size):
    '''loads picture and applies torch transforms'''
    img = cv2.imread(file)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    transform = pth_transforms.Compose([
        pth_transforms.Resize(image_size), # Do we need this?
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = transform(img)
    
    return img

In [5]:
def patch_image(img, patch_size):
    '''Reshapes image to be divisible by the patch size'''
    w, h = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - img.shape[2] % patch_size
    img = img[:, :w, :h].unsqueeze(0)

    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    
    return img, w_featmap, h_featmap

In [6]:
def mask_image(model, img, patch_size):
    '''Passes image through model and returns attention mask'''
    img, w_featmap, h_featmap = patch_image(img, patch_size)
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1] # number of head

    # we keep only the output patch attention
    heatmap = attentions[0, :, 0, 1:].reshape(nh, -1)
    heatmap = heatmap.reshape(nh, w_featmap, h_featmap)
    
    # interpolate attention mask back to original image size
    heatmap = nn.functional.interpolate(heatmap.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
    
    return heatmap.sum(0).astype('double')

In [7]:
def get_num_blobs(heatimg):
    '''
    predicts the number of surfers as the number of 
    blobs found in the attention mask of an image
    '''
    blobs = blob_doh(heatimg*256, threshold=0.05, min_sigma = 10, max_sigma=50)
    
    return len(blobs)

In [8]:
def predict(row, image_size=(592, 1184), patch_size=8):
    img = load_img(row['filename'], image_size)
    global model
    heatimg = mask_image(model, img, patch_size)
    num_blobs = get_num_blobs(heatimg)
    row['pred_num_surfers'] = num_blobs
    
    return row

In [9]:
# Get predictions for each image like this?
df = df.apply(lambda row: predict(row), axis=1)

NameError: name 'df' is not defined