In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from time import time_ns
from PIL import Image
from functools import partial
from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

from yoeo.main import get_dv2_model, get_upsampler_and_expr, get_hr_feats
from yoeo.utils import load_image, convert_image, to_numpy, closest_crop, Experiment, do_2D_pca, add_flash_attention

from yoeo.comparisons.lift import ViTLiFTExtractor
from yoeo.comparisons.strided import StridedDv2

from typing import Callable

# Fix the global NumPy and PyTorch RNGs up front so every stochastic step below is reproducible.
SEED = 10_672
np.random.seed(seed=SEED)
torch.manual_seed(seed=SEED)

<torch._C.Generator at 0x7f308c8ed030>

In [2]:
DEVICE = "cuda:0"  # every model and benchmark below runs on this device

# Return cached-but-unused CUDA allocator blocks before anything is measured.
torch.cuda.empty_cache()

In [3]:
def measure_mem_time(
    device: torch.device | str,
    featurise_fn: Callable
) -> tuple[float, float]:
    torch.cuda.reset_peak_memory_stats(device)  # s.t memory is accurate
    torch.cuda.synchronize(device)  # s.t time is accurate

    def _to_MB(x: int) -> float:
        return x / (1024**2)

    def _to_s(t: int) -> float:
        return t / 1e9

    start_m = torch.cuda.max_memory_allocated(device)
    start_t = time_ns()

    featurise_fn()

    end_m = torch.cuda.max_memory_allocated(device)
    torch.cuda.synchronize(device)
    end_t = time_ns()

    return _to_MB(end_m - start_m), _to_s(end_t - start_t)

def rescale(arr: np.ndarray, swap_channels: bool=True) -> np.ndarray:
    """Min-max scale every channel of an image-like array independently into [0, 1].

    Args:
        arr: ``(C, H, W)`` array when ``swap_channels`` is True, otherwise ``(H, W, C)``.
        swap_channels: move the leading channel axis to the end before scaling.

    Returns:
        ``(H, W, C)`` array where each channel is scaled (and clipped) by ``MinMaxScaler``.
    """
    hwc = arr.transpose(1, 2, 0) if swap_channels else arr
    height, width, n_ch = hwc.shape
    # Every pixel becomes a sample and every channel a feature, so scaling is per channel.
    scaler = MinMaxScaler(clip=True)
    pixels = scaler.fit_transform(hwc.reshape(height * width, n_ch))
    return pixels.reshape(height, width, n_ch)

In [4]:
# Checkpoint and config for our trained upsampler.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../yoeo/models/configs/combined_no_shift.json"

# Plain DINOv2 backbone (used for the low-res baseline panel) and the `fit_3d` variant
# (the cell output shows it pulls the FiT3D hub repo) that feeds our upsampler.
normal_dv2 = get_dv2_model(fit_3d=False, device=DEVICE)
dv2 = get_dv2_model(True, device=DEVICE)

upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


In [5]:
# FeatUp's JBU upsampler for DINOv2 from torch hub, moved to DEVICE and put in inference mode.
featup_jbu = (
    torch.hub.load("mhamilton723/FeatUp", "dinov2", use_norm=True)
    .to(DEVICE)
    .eval()
)

Using cache found in /home/ronan/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main


In [6]:
# Strided DINOv2 baseline (tag 'vits14_reg', presumably ViT-S/14 with registers) in eval mode;
# its inner backbone is then wrapped by `add_flash_attention`.
strided = (
    StridedDv2('dinov2', 'vits14_reg', 1)
    .to(DEVICE)
    .eval()
)
strided.model = add_flash_attention(strided.model)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main


In [7]:
# LiFT extractor on the 'dino_vits8' backbone, using 384-channel 'key' features.
lift_path = "../trained_models/lift/lift_dino_vits8.pth"
lift = ViTLiFTExtractor(
    'dino_vits8',
    lift_path=lift_path,
    channel=384,
    facet='key',
)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dino_main


Loaded Backbone: dino_vits8
Loaded LiFT module from: ../trained_models/lift/lift_dino_vits8.pth


In [None]:
# Benchmark image. Converting yields an independent RGB copy, so the source file can be closed.
PATH = "fig_data/perf_landscape/church_compare.png"
with Image.open(PATH) as _src:
    img = _src.convert("RGB")

In [9]:
@torch.no_grad()
def _our_featurise(img: Image.Image, dv2: torch.nn.Module, upsampler: torch.nn.Module, expr: Experiment) -> np.ndarray:
    """High-resolution features from our upsampler, truncated to three channels for display.

    Args:
        img: input image.
        dv2: backbone forwarded to ``get_hr_feats``.
        upsampler: trained upsampler forwarded to ``get_hr_feats``.
        expr: experiment config; only ``expr.n_ch_in`` is read.

    Returns:
        The first three entries along the leading axis of ``to_numpy(hr_feats)``. These are
        presumably the first three channels of a ``(C, H, W)`` array, since callers pass the
        result to ``rescale``, which treats the leading axis as channels. TODO confirm.
    """
    return to_numpy(get_hr_feats(img, dv2, upsampler, DEVICE, n_ch_in=expr.n_ch_in))[:3]

@torch.no_grad()
def _original_featurise(img: Image.Image, dv2: torch.nn.Module, patch_size: int = 14) -> np.ndarray:
    """Low-resolution backbone patch tokens laid out as a spatial feature grid.

    Args:
        img: input image; cropped with ``closest_crop`` and converted by ``convert_image``.
        dv2: backbone exposing ``forward_features(x)['x_norm_patchtokens']``.
        patch_size: backbone patch size in pixels (14 for DINOv2). For a cropped input of
            ``h x w`` pixels, the grid is ``(h // patch_size) x (w // patch_size)``.

    Returns:
        ``to_numpy`` of the ``(1, C, h // patch_size, w // patch_size)`` feature grid.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    crop = closest_crop(img.height, img.width)
    tensor = convert_image(img, crop)
    _, _, h, w = tensor.shape
    with torch.autocast("cuda", torch.float16):
        tokens = dv2.forward_features(tensor)['x_norm_patchtokens']

    n_patch_h, n_patch_w = h // patch_size, w // patch_size
    # Assumes tokens are (1, N, C) with N = n_patch_h * n_patch_w in row-major order:
    # move channels first, then unflatten the token axis into the spatial grid.
    grid = tokens.permute((0, 2, 1)).reshape((1, -1, n_patch_h, n_patch_w))
    return to_numpy(grid)

@torch.no_grad()
def _jbu_featurise(img: Image.Image, jbu: torch.nn.Module) -> np.ndarray:
    """FeatUp-JBU features resized back to the image size, truncated to three channels.

    The image is cropped with ``closest_crop`` and passed to ``jbu`` as float32 under fp16
    autocast. The output is then resized with ``F.interpolate`` (default 'nearest' mode) to
    the original ``(height, width)``.

    Returns:
        The first three entries along the leading axis of ``to_numpy`` of the resized features.
    """
    out_size = (img.height, img.width)
    batch = convert_image(img, closest_crop(*out_size))
    with torch.autocast("cuda", torch.float16):
        upsampled = jbu(batch.to(torch.float32))
    upsampled = F.interpolate(upsampled, out_size)
    return to_numpy(upsampled)[:3]

@torch.no_grad()
def _lift_featurise(path: str, img: Image.Image, lift: ViTLiFTExtractor, n: int=3, patch_size: int=8) -> np.ndarray:
    """Dense LiFT features after ``n`` iterative upsampling steps, as a spatial grid.

    Only ``img``'s height and width are used here. The pixels reach LiFT through
    ``lift.preprocess(path, (h, w))``, so ``path`` and ``img`` should refer to the same picture.

    Args:
        path: image file path handed to ``lift.preprocess``.
        img: image whose ``(height, width)`` sets the preprocessing size.
        lift: LiFT extractor.
        n: number of LiFT iterations; each is assumed to halve the feature stride.
        patch_size: backbone patch size in pixels (default 8, presumably matching 'dino_vits8').

    Returns:
        ``to_numpy`` of a ``(C, h // sf, w // sf)`` tensor with ``sf = patch_size // 2**n``.

    Raises:
        ValueError: if ``n`` is negative, ``patch_size`` is not positive, or ``2**n`` does not
            divide ``patch_size`` (the output stride would not be a whole number of pixels).
    """
    if n < 0 or patch_size <= 0 or patch_size % (2 ** n) != 0:
        raise ValueError(
            f"need n >= 0 and patch_size > 0 divisible by 2**n; got n={n}, patch_size={patch_size}"
        )
    # Integer stride of the final feature map in input pixels (>= 1 given the check above).
    sf = patch_size // 2 ** n

    h, w = img.height, img.width
    image_batch, _ = lift.preprocess(path, (h, w))
    image_batch = image_batch.to(DEVICE)

    with torch.autocast("cuda", torch.float16):
        lift_feats = lift.extract_descriptors_iterative_lift(image_batch, lift_iter=n)
    _, _, c = lift_feats.shape
    # Assumes a (1, N, C) token sequence: drop the batch, put channels first, unflatten to 2D.
    reshaped = lift_feats.squeeze(0).T.reshape((c, h // sf, w // sf))
    return to_numpy(reshaped)

@torch.no_grad()
def _strided_featurise(img: Image.Image, strided: StridedDv2) -> np.ndarray:
    """Dense features from the strided DINOv2 baseline, resized to the image size.

    The image is cropped with ``closest_crop`` and run through ``strided`` under fp16
    autocast. The output is resized with ``F.interpolate`` (default 'nearest' mode) to the
    original ``(height, width)``. All channels are kept.
    """
    full_size = (img.height, img.width)
    batch = convert_image(img, closest_crop(*full_size))
    with torch.autocast("cuda", torch.float16):
        dense = strided(batch)
    return to_numpy(F.interpolate(dense, full_size))
    

In [10]:
# Bind every featuriser to the benchmark image and its model so each call takes no arguments.
original_featurise = partial(_original_featurise, img=img, dv2=normal_dv2)
our_featurise = partial(_our_featurise, img=img, dv2=dv2, upsampler=upsampler, expr=expr)
jbu_featurise = partial(_jbu_featurise, img=img, jbu=featup_jbu)
lift_featurise = partial(_lift_featurise, path=PATH, img=img, lift=lift)
strided_featurise = partial(_strided_featurise, img=img, strided=strided)

# Benchmarked methods: display names and callables, kept positionally aligned.
featuriser_names, fns = zip(
    ("Ours", our_featurise),
    ("FeatUp (JBU)", jbu_featurise),
    ("LiFT", lift_featurise),
    ("Strided", strided_featurise),
)

In [11]:
# Warm-up: run every benchmarked featuriser once so one-off first-call costs stay out of
# the measured runs. The measurements are discarded on purpose and not bound to globals
# (in particular not to `time`, which would shadow the stdlib module name).
for fn in fns:
    measure_mem_time(DEVICE, fn)

In [12]:
# Average time (s) and extra peak memory (MB) over N_REPEATS runs per method.
N_REPEATS = 5
# Author's note: LiFT & Strided upsample all 384 DINO channels, JBU & Ours upsample 128.
# For a fair comparison, the memory of the full-dimension methods is scaled by that ratio.
FULL_DIM_METHODS = ('Strided', 'LiFT')
CHANNEL_RATIO = 128 / 384

mem_time_results: dict[str, dict] = {}
for name, fn in zip(featuriser_names, fns):
    samples = [measure_mem_time(DEVICE, fn) for _ in range(N_REPEATS)]
    mems = [m * CHANNEL_RATIO if name in FULL_DIM_METHODS else m for m, _ in samples]
    times = [t for _, t in samples]
    mem_time_results[name] = {'time': np.mean(times), 'mem': np.mean(mems)}
print(mem_time_results)

{'Ours': {'time': 0.0975382752, 'mem': 293.193359375}, 'FeatUp (JBU)': {'time': 0.1310169948, 'mem': 1038.5126953125}, 'LiFT': {'time': 0.061607448200000005, 'mem': 115.27278645833333}, 'Strided': {'time': 2.1706915068, 'mem': 371.07373046875}}


In [13]:
# Square input side lengths from 224 px up to 768 px in 64 px steps.
lengths = list(range(224, 896 - 64, 64))

our_featurise_l = partial(_our_featurise, dv2=dv2, upsampler=upsampler, expr=expr)
jbu_featurise_l = partial(_jbu_featurise, jbu=featup_jbu)
# NOTE: _lift_featurise hands `path` (the original file) to lift.preprocess and only takes
# the target size from `img`.
lift_featurise_l = partial(_lift_featurise, path=PATH, lift=lift)
strided_featurise_l = partial(_strided_featurise, strided=strided)

fns_vs_l = (our_featurise_l, jbu_featurise_l, lift_featurise_l, strided_featurise_l)

mem_vs_l: dict[str, list[float]] = {}
time_vs_l: dict[str, list[float]] = {}

for name, fn in zip(featuriser_names, fns_vs_l):
    mems, times = [], []
    for length in lengths:
        resized = img.resize((length, length))  # resize outside the timed region
        mem_mb, secs = measure_mem_time(DEVICE, partial(fn, img=resized))
        if name in ('Strided', 'LiFT'):
            # LiFT & Strided upsample all 384 DINO channels vs 128 for JBU & Ours;
            # scale their memory by that ratio for a like-for-like comparison.
            mem_mb *= 128 / 384
        mems.append(mem_mb)
        times.append(secs)
    mem_vs_l[name] = mems
    time_vs_l[name] = times

print(mem_vs_l)
print(time_vs_l)

{'Ours': [130.697265625, 210.31103515625, 314.50341796875, 430.13525390625, 580.919921875, 727.86669921875, 926.1728515625, 1153.587890625, 1353.32470703125, 1626.40283203125], 'FeatUp (JBU)': [483.38427734375, 731.4267578125, 1122.95947265625, 1489.2509765625, 2036.0791015625, 2532.7001953125, 3222.9599609375, 4001.646484375, 4691.2041015625, 5627.0888671875], 'LiFT': [58.58740234375, 89.16471354166666, 125.30289713541666, 168.88297526041666, 218.64127604166666, 276.2392578125, 405.64225260416663, 595.2775065104166, 847.32568359375, 1174.4586588541665], 'Strided': [166.4130859375, 257.95914713541663, 402.5550130208333, 542.3623046875, 747.2275390625, 935.1884765625, 1200.6023763020833, 1499.5504557291665, 1762.6240234375, 2122.0133463541665]}
{'Ours': [0.083556023, 0.099365946, 0.129589985, 0.162557875, 0.214227614, 0.265708625, 0.332228743, 0.402966352, 0.490823776, 0.600757946], 'FeatUp (JBU)': [0.072197202, 0.103229142, 0.150156703, 0.20095606, 0.268559712, 0.387262323, 0.426323897

In [14]:
# Reference panel: plain DINOv2 patch grid reduced to 3 components with `do_2D_pca`,
# then min-max scaled per component so it can be shown as RGB.
original_pcaed = do_2D_pca(original_featurise(), 3)
im0 = rescale(arr=original_pcaed, swap_channels=False)
print(im0.shape)  # (24, 24, 3) for this image

(24, 24, 3)


In [15]:
# Full-resolution panels, one per method in `featuriser_names` order.
# Ours / JBU: their featurisers already return three channels.
im1, im2 = rescale(our_featurise()), rescale(jbu_featurise())

# LiFT / Strided: all channels are returned, so reduce to 3 with `do_2D_pca` first.
lift_feats = lift_featurise()
pcaed = do_2D_pca(lift_feats, 3)
# NOTE(review): the float32 cast presumably works around a float16 result from the
# autocast path that matplotlib cannot display — confirm (im4 is not cast).
im3 = rescale(arr=pcaed, swap_channels=False).astype(np.float32)

stride_feats = strided_featurise()
strided_pcaed = do_2D_pca(stride_feats, 3)
im4 = rescale(arr=strided_pcaed, swap_channels=False)

In [16]:
%%capture
# Side-by-side preview of the four full-resolution panels (cell output is captured).
imgs = (im1, im2, im3, im4)
fig, axs = plt.subplots(ncols=len(imgs), figsize=(16, 10))
for ax, panel in zip(axs, imgs):
    ax.imshow(panel)
    ax.set_axis_off()

In [30]:
%%capture
# Configure the font before any figure exists. Each Axes creates its first tick label on
# construction and later ticks copy its font properties, so setting `font.family` afterwards
# would leave tick labels in the previous family while titles/labels become serif.
plt.rcParams["font.family"] = "serif"

TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21
DV2_PATCH_PX = 14  # DINOv2 patch size: one low-res feature per 14x14 pixel patch

# Layout: image and low-res features stacked on the left, the time-vs-memory scatter
# spanning the middle column, and the two scaling curves stacked on the right.
gs = gridspec.GridSpec(2, 3, width_ratios=[1, 2, 1], height_ratios=[1, 1])
fig = plt.figure(figsize=(18, 10))
ax1 = fig.add_subplot(gs[0, 0])  # top-left: input image
ax2 = fig.add_subplot(gs[1, 0])  # bottom-left: low-res DINOv2 features
ax3 = fig.add_subplot(gs[:, 1])  # middle column, spanning both rows: time vs memory
ax4 = fig.add_subplot(gs[0, 2])  # top-right: time vs image length
ax5 = fig.add_subplot(gs[1, 2])  # bottom-right: memory vs image length

# --- (a) input image and its low-res DINOv2 features -----------------------------------
# Sizes come from the feature grid: the DINOv2 input crop is grid size x patch size.
lr_h, lr_w = im0.shape[:2]
ax1.imshow(img)
ax1.set_title(f'Image: ({lr_w * DV2_PATCH_PX}x{lr_h * DV2_PATCH_PX})', fontsize=TITLE_FS)
ax1.set_axis_off()

ax2.imshow(im0)
ax2.set_title(f'DINOv2: ({lr_w}x{lr_h})', fontsize=TITLE_FS)
ax2.set_axis_off()

# --- (b) time vs memory scatter, with an inset of each method's features ---------------
ot, om = 0.2, 150  # axis padding: seconds on x, MB on y
max_t = max(res['time'] for res in mem_time_results.values()) + ot
max_m = max(res['mem'] for res in mem_time_results.values()) + om

ax3.grid(True, linestyle="--", alpha=0.6)
ax3.set_xlim(-ot, max_t)
ax3.set_ylim(-om, max_m)
ax3.set_xlabel('Time (s)', fontsize=LABEL_FS)
ax3.set_ylabel('Memory (MB)', fontsize=LABEL_FS)
ax3.tick_params(axis='both', labelsize=TICK_FS)

# Inset centres as fractions of (max_t, max_m), one per method in `mem_time_results` order;
# only the first two entries of each tuple are used.
inset_locs = ((0.12, 0.44, 0.3, 0.3), (0.27, 0.82, 0.3, 0.3), (0.25, 0.05, 0.3, 0.3), (0.80, 0.5, 0.3, 0.3))
# `imgs` follows the same method order as `mem_time_results`; zip pairs them explicitly.
for i, ((name, value), panel, loc) in enumerate(zip(mem_time_results.items(), imgs, inset_locs)):
    # Each scatter takes the next default cycle colour, matched by the inset frame "C{i}".
    ax3.scatter(value['time'], value['mem'], lw=10, label=name)
    inset = AnnotationBbox(
        OffsetImage(panel, zoom=0.48),
        (loc[0] * max_t, loc[1] * max_m),
        frameon=True,
        bboxprops=dict(edgecolor=f"C{i}", linewidth=8),
        pad=0,
    )
    ax3.add_artist(inset)
ax3.legend(fontsize=TICK_FS)

# --- (c) scaling with image side length (log-scale y) ----------------------------------
ax4.set_xlabel('Image length (px)', fontsize=LABEL_FS)
ax4.set_ylabel('Time (s)', fontsize=LABEL_FS)
ax4.tick_params(axis='both', labelsize=TICK_FS)
for name, times in time_vs_l.items():
    ax4.plot(lengths, times, label=name, marker='.', lw=3, ms=15)
ax4.semilogy()
ax4.grid(True, linestyle="--", alpha=0.6)

ax5.set_xlabel('Image length (px)', fontsize=LABEL_FS)
ax5.set_ylabel('Memory (MB)', fontsize=LABEL_FS)
ax5.tick_params(axis='both', labelsize=TICK_FS)
ax5.grid(True, linestyle="--", alpha=0.6)
for name, mems in mem_vs_l.items():
    ax5.plot(lengths, mems, label=name, marker='.', lw=3, ms=15)
ax5.semilogy()

# Bold panel letters, placed in axes coordinates just above-left of each panel.
for key, ax in zip(('a', 'b', 'c'), (ax1, ax3, ax4)):
    y = 1.016 if key == 'b' else 1.05
    x = -0.26 if key == 'c' else -0.15
    ax.text(x, y, f"{key}.", transform=ax.transAxes,
            size=LABEL_FS + 4, weight='bold')

fig.tight_layout(pad=2.5)
fig.savefig('fig_out/perf_landscape.png')

In [18]:
# Close every figure this notebook opened: the preview grid and the saved landscape figure.
# A bare plt.close() would only close the current (last) figure.
plt.close("all")