In [None]:
import sys
sys.path.append("../")

import os
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
from timm.data import create_transform
import PIL.Image as Image
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from src.data.transforms import transforms_imagenet_eval
from src.models.entropy_utils import select_patches_by_threshold, visualize_selected_patches_cv2

# Configuration: ImageNet split to load, evaluation resolution, and which sample to inspect.
split = "val"
image_size = 512
# The dataset root can be overridden with IMAGENET_ROOT instead of editing this hard-coded path.
imagenet_root = os.environ.get("IMAGENET_ROOT", "/edrive1/rchoudhu/ILSVRC2012")
data_dir = os.path.join(imagenet_root, split)
sample_index = 28582

val_transform = transforms_imagenet_eval(img_size=image_size)
data_val = datasets.ImageFolder(root=data_dir, transform=val_transform)

# Inverse of Normalize(mean, std): Normalize(-mean/std, 1/std) maps y back to y * std + mean.
# Assumes the eval transform normalizes with the standard ImageNet statistics below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
unnorm = transforms.Normalize(
    mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
    std=[1 / s for s in imagenet_std],
)

img, label = data_val[sample_index]
# Undo ImageNet normalization, then CHW -> HWC for display and patch unfolding downstream.
img = unnorm(img).permute(1, 2, 0)

print("Image shape: ", img.shape)
# Clamp before the uint8 cast: interpolation and the float round-trip can leave values
# slightly outside [0, 1], and an unclamped cast wraps around instead of saturating.
img_vis = (img.clamp(0, 1) * 255).round().to(torch.uint8).numpy()
img_vis = Image.fromarray(img_vis)
img_vis

In [None]:
H, W, C = img.shape

# Build low-pass versions of the image by downsampling 2x / 4x and upsampling back;
# the per-pixel error against the original measures how much detail a region loses.
img_tensor = img.permute(2, 0, 1).unsqueeze(0)  # HWC -> (1, C, H, W)
img_down2 = F.interpolate(img_tensor, scale_factor=0.5, mode='bilinear', align_corners=False)
img_down4 = F.interpolate(img_tensor, scale_factor=0.25, mode='bilinear', align_corners=False)
# Upsample to the exact original size (not scale_factor=2/4) so the reconstruction always
# matches img_tensor, even when H or W is not divisible by 4; identical when they are.
img_up2 = F.interpolate(img_down2, size=(H, W), mode='bilinear', align_corners=False)
img_up4 = F.interpolate(img_down4, size=(H, W), mode='bilinear', align_corners=False)

# Preview of the 2x reconstruction; clamped so the uint8 cast cannot wrap around.
img_vis = img_up2.squeeze(0).permute(1, 2, 0).clamp(0, 1)
img_vis = Image.fromarray((img_vis * 255).round().to(torch.uint8).numpy())


def unfold_patches(error_map, size):
    """Split an (H, W, C) map into non-overlapping ``size`` x ``size`` tiles.

    Returns a (H // size, W // size, C, size, size) view of ``error_map``; rows/columns
    that do not fill a whole tile are dropped by ``Tensor.unfold``.
    """
    return error_map.unfold(0, size, size).unfold(1, size, size)


# Per-pixel squared error of each reconstruction, laid out as (H, W, C).
mse_up2 = F.mse_loss(img_tensor, img_up2, reduction='none').squeeze(0).permute(1, 2, 0)
mse_up4 = F.mse_loss(img_tensor, img_up4, reduction='none').squeeze(0).permute(1, 2, 0)

# 16- and 32-pixel patches are scored against the 2x reconstruction, 64-pixel ones against 4x.
patches16 = unfold_patches(mse_up2, 16)
patches = unfold_patches(mse_up2, 32)
patches64 = unfold_patches(mse_up4, 64)

# Same final values of these names as before: last patch size used and last error map built.
patch_size = 64
mse_map = mse_up4

print("MSE map shape: ", mse_map.shape)
print("patches shape: ", patches.shape)
print("patch16 shape: ", patches16.shape)
print("patch64 shape: ", patches64.shape)

In [None]:
# Score each patch by its mean reconstruction error: average over the two in-patch pixel
# dims, then over channels, and prepend a singleton leading dim for the selector.
responses16, responses, responses64 = (
    grid.mean(dim=(3, 4)).mean(-1).unsqueeze(0)
    for grid in (patches16, patches, patches64)
)
response_map = {16: responses16, 32: responses, 64: responses64}

# Per-size selection masks; drop singleton dims in place on the returned dict.
masks = select_patches_by_threshold(response_map, thresholds=[0.0001, 0.0001])
masks.update({size: sel.squeeze() for size, sel in masks.items()})

# Tokens kept across all patch sizes, relative to tiling the whole image with 16-pixel patches.
num_tokens = sum(m.sum().int().item() for m in masks.values())
num_base_tokens = masks[16].numel()

print(f"Fraction retained: {num_tokens / num_base_tokens}")

In [None]:
# Overlay the selected patches on the image. Pixel values are clamped to [0, 1] before
# scaling to 0-255 because they can fall slightly out of range (interpolation / unnormalize
# round-trip), which would wrap if converted to uint8.
# NOTE(review): confirm how visualize_selected_patches_cv2 converts its input to pixels.
vis_img = visualize_selected_patches_cv2(
    img_tensor[0].clamp(0, 1) * 255,
    masks=masks,
    patch_sizes=list(masks.keys()),
    color=(128, 128, 255),
)

In [None]:
# Display the overlay returned by visualize_selected_patches_cv2 (last expression in the cell).
vis_img