# NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering

This notebook shows how to upsample **Any Feature from Any VFM (vision foundation model) at Any Resolution** with NAF.

## Setup

Imports

In [None]:
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import PIL
import torch.nn.functional as F
import torchvision.transforms as T
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from IPython.display import clear_output

# Make the parent of the current working directory (normally the notebook's folder)
# importable, so the `utils.*` imports below resolve. Guard the append so re-running
# this cell does not keep adding duplicate sys.path entries.
project_root = str(Path().absolute().parent)
if project_root not in sys.path:
    sys.path.append(project_root)

from utils.training import get_batch, get_dataloaders, load_multiple_backbones
from utils.visualization import plot_feats

Load config

In [None]:
# Hydra keeps global state across cell re-runs; only initialize it the first time.
if not GlobalHydra.instance().is_initialized():
    initialize(config_path="../config", version_base=None)

# Each entry becomes a Hydra "key=value" override string.
override_values = {
    "val_dataloader.batch_size": 1,
    "train_dataloader.batch_size": 1,
    "model": "naf",
    "img_size": 448,
}
overrides = [f"{key}={value}" for key, value in override_values.items()]

cfg = compose(config_name="base", overrides=overrides)
clear_output()

Dataloader

In [None]:
# shuffle=True is requested, so the validation sample drawn below presumably varies between runs.
(train_dataloader, val_dataloader) = get_dataloaders(
    cfg,
    shuffle=True,
)
clear_output()

Model

In [None]:
# Prefer the GPU when available; later cells also read `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
# `.to(device)` rather than an unconditional `.cuda()`, which fails on CPU-only machines.
model.to(device)

clear_output()

Data: either a validation sample from the dataloader or a custom image (set `IMG_PATH`)

In [None]:
# Set IMG_PATH to None to visualize a (shuffled) validation sample instead of a custom image.
IMG_PATH = "../asset/dinov3.png"

if IMG_PATH is None:
    # Use the first batch of the validation dataloader.
    batch = get_batch(next(iter(val_dataloader)), device)
    img_batch = batch["image"]
else:
    # Run the custom image through the dataset's own transform, presumably giving it the
    # same preprocessing as the dataloader samples; add a batch dimension.
    with PIL.Image.open(IMG_PATH) as pil_img:
        img_batch = val_dataloader.dataset.transform(pil_img.convert("RGB")).unsqueeze(0).to(device)

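Backbone helper: extract features with a backbone, upsample them with NAF at several resolutions, and plot the results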

In [None]:
@torch.no_grad()
def upsample_backbone(backbone, img, model, mean_std_bck, mean_std_ups, sizes=(512,)):
    """Upsample one backbone's features with NAF at several resolutions and plot them.

    Args:
        backbone: callable mapping a normalized image batch to a feature map; its
            spatial size is read from ``shape[-1]`` (assumed square — TODO confirm).
        img: image batch to visualize; presumably (B, 3, H, W) in [0, 1] before
            normalization — TODO confirm. Only ``img[0]`` is plotted.
        model: NAF upsampler, called as ``model(image, features, (size, size))``.
        mean_std_bck: ``(mean, std)`` used to normalize ``img`` for the backbone.
        mean_std_ups: ``(mean, std)`` used to normalize ``img`` for the upsampler.
        sizes: square output resolutions to request from ``model``. The raw backbone
            features are resized to the last entry for side-by-side display.

    Raises:
        ValueError: if ``sizes`` is empty.
    """
    sizes = list(sizes)
    if not sizes:
        raise ValueError("sizes must contain at least one output resolution")

    torch.cuda.empty_cache()
    model.eval()
    mean_bck, std_bck = mean_std_bck
    mean_ups, std_ups = mean_std_ups

    # Backbone and upsampler may expect different normalization statistics.
    img_bck = T.functional.normalize(img, mean=mean_bck, std=std_bck)
    img_ups = T.functional.normalize(img, mean=mean_ups, std=std_ups)

    lr_feats = backbone(img_bck)
    lr_size = lr_feats.shape[-1]
    preds = [model(img_ups, lr_feats, (size, size)) for size in sizes]

    # Nearest-exact resize keeps the backbone's native patch grid visible next to the
    # upsampled maps; use the last requested size explicitly rather than a leaked loop variable.
    lr_feats_display = F.interpolate(lr_feats, size=sizes[-1], mode="nearest-exact")
    plot_feats(
        img[0],  # plot this call's own input, not a notebook global
        lr_feats_display[0],
        [p[0] for p in preds],
        legend=["Input", f"{lr_size}x{lr_size}"] + [f"{size}x{size}" for size in sizes],
        font_size=20,
    )
    plt.show()
    torch.cuda.empty_cache()

## Any Backbone

In [None]:
# Display name for each backbone identifier; dict insertion order fixes the load/plot order.
mapping = {
    "vit_base_patch16_dinov3.lvd1689m": "Dinov3-B",
    "radio_v2.5-b": "Radio-v2.5-B",
    "franca_vitb14": "Franca-B14",
    "vit_base_patch14_reg4_dinov2": "Dinov2-R-B",
    "vit_base_patch14_dinov2.lvd142m": "Dinov2-B",
}
# Derive the loader configs from the mapping so the identifiers are listed only once.
backbone_configs = [{"name": identifier} for identifier in mapping]
backbones, *_ = load_multiple_backbones(cfg, backbone_configs, device)
clear_output()

In [None]:
# The upsampler input is always normalized with these ImageNet mean/std values,
# independently of each backbone's own normalization statistics.
mean_ups, std_ups = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# backbone_configs[i] describes backbones[i]; pair them directly instead of an O(n^2) list.index lookup.
for backbone_cfg, backbone in zip(backbone_configs, backbones):
    # `.to(device)` instead of `.cuda()` so the notebook also runs on CPU-only machines.
    backbone.to(device).eval()

    mean_bck, std_bck = backbone.config["mean"], backbone.config["std"]

    print(f"BACKBONE: {mapping[backbone_cfg['name']].upper()}")
    # 448 matches the img_size=448 override used when composing the config.
    upsample_backbone(backbone, img_batch, model, (mean_bck, std_bck), (mean_ups, std_ups), sizes=[448])

## Any Resolution

In [None]:
# Reload just the DINOv3-B backbone for the multi-resolution demo.
backbone_configs = [dict(name="vit_base_patch16_dinov3.lvd1689m")]
backbones, *_ = load_multiple_backbones(
    cfg,
    backbone_configs,
    device,
)
clear_output()

In [None]:
# Multi-resolution demo with the single backbone loaded above.
# NOTE: mean_ups / std_ups are defined in the "Any Backbone" section; run that cell first.
backbone = backbones[0]
backbone.to(device).eval()  # `.cuda()` would fail on CPU-only machines

mean_bck, std_bck = backbone.config["mean"], backbone.config["std"]

# Only one backbone was loaded, so its config is simply backbone_configs[0].
print(f"BACKBONE: {backbone_configs[0]['name']}")
upsample_backbone(
    backbone,
    img_batch,
    model,
    (mean_bck, std_bck),
    (mean_ups, std_ups),
    sizes=[64, 128, 256, 512, 1024],
)