In [1]:
import torch
import numpy as np
from PIL import Image
from random import seed
import matplotlib.pyplot as plt
from sklearn.preprocessing import minmax_scale

from yoeo.main import get_dv2_model
from yoeo.utils import load_image, convert_image, to_numpy, closest_crop, do_2D_pca
from yoeo.feature_prep import PCAUnprojector, project, JitteredImage, DataLoader


SEED = 10672
# Seed Python's `random` and NumPy's global RNG, then torch, so the stochastic steps
# later in the notebook are repeatable from a fresh kernel.
for _seed_fn in (seed, np.random.seed):
    _seed_fn(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f669741d210>

In [2]:
# DEVICE is passed to the model loader, the image conversion and the PCA below.
DEVICE = "cuda:0"
# Release cached-but-unused CUDA memory (handy when re-running in a live kernel).
torch.cuda.empty_cache()

In [3]:
@torch.no_grad()
def original_featurise(img: Image.Image, dv2: torch.nn.Module, to_numpy: bool=False) -> "torch.Tensor | np.ndarray":
    """Extract a DINOv2 patch-feature map for a single PIL image.

    The image is cropped/converted with ``closest_crop`` + ``convert_image``. It is then
    passed through ``dv2.forward_features`` under float16 CUDA autocast. The
    ``'x_norm_patchtokens'`` output is reshaped into a channel-first map of shape
    ``(1, -1, h // 14, w // 14)``, where ``h``/``w`` are the spatial dims of the converted
    tensor.

    Args:
        img: input image (callers in this notebook pass RGB-converted images).
        dv2: backbone exposing ``forward_features`` returning a dict with
            ``'x_norm_patchtokens'``.
        to_numpy: if True, convert the result with ``yoeo.utils.to_numpy``; otherwise
            return the torch tensor.

    Returns:
        The feature map as a torch tensor, or a numpy array when ``to_numpy`` is True.
    """
    tr = closest_crop(img.height, img.width)
    tensor = convert_image(img, tr)
    _, _, h, w = tensor.shape

    with torch.autocast("cuda", torch.float16):
        dino_feats = dv2.forward_features(tensor)['x_norm_patchtokens']
    # 14 is presumably the ViT patch size; assumes convert_image yields h, w divisible by 14 — TODO confirm
    n_patch_w, n_patch_h = w // 14, h // 14

    # Presumably (1, n_patches, C) -> (1, C, n_patches) -> (1, C, n_patch_h, n_patch_w).
    dino_feats = dino_feats.permute((0, 2, 1))
    dino_feats = dino_feats.reshape((1, -1, n_patch_h, n_patch_w))

    if to_numpy:
        # The boolean parameter `to_numpy` shadows the module-level helper of the same
        # name inside this function (calling it would call `True(...)`), so import the
        # helper under a different local name.
        from yoeo.utils import to_numpy as _tensor_to_numpy
        return _tensor_to_numpy(dino_feats)
    return dino_feats

def vis(tensor: torch.Tensor, k: int=3) -> np.ndarray:
    """Render the leading channels of a batched feature map as an ``(H, W, c)`` image.

    Only batch element 0 is used and at most the first ``k`` channels are kept
    (``c = min(k, C)``). Each kept channel is min-max scaled independently:
    ``minmax_scale`` scales per column, and the columns here are channels. With
    ``k == 3`` the result can be passed straight to ``imshow``.
    """
    first = tensor.cpu().numpy()[0, :k]
    n_ch, height, width = first.shape
    # (C, H, W) -> (H*W, C): pixels become rows, channels become features
    per_pixel = first.reshape((n_ch, height * width)).transpose()
    normalised = minmax_scale(per_pixel)
    return normalised.reshape((height, width, n_ch))


def get_pca(img: torch.Tensor, n_imgs: int, model: torch.nn.Module, fit3d: bool=False,
            proj_dim: int = 128, use_flips: bool = True, max_zoom: float = 1.8,
            max_pad: int = 30) -> "PCAUnprojector":
    """Fit a ``PCAUnprojector`` on features of ``n_imgs`` jittered copies of ``img``.

    A ``JitteredImage`` dataset built from ``img`` is loaded with batch size ``n_imgs``.
    Every batch is featurised with ``project(batch, model, fit3d)``. The first ``n_imgs``
    concatenated feature maps are then used to fit the unprojector on ``img.device``.

    Args:
        img: image tensor to jitter (its ``.device`` is also where the PCA is placed).
        n_imgs: number of jittered views, also used as the loader batch size.
        model: feature extractor passed through to ``project``.
        fit3d: forwarded to ``project``.
        proj_dim: second argument to ``PCAUnprojector`` (presumably the number of
            PCA components). Defaults to the previously hard-coded 128.
        use_flips, max_zoom, max_pad: jitter settings forwarded to ``JitteredImage``.
            Defaults match the previously hard-coded values.

    Returns:
        The fitted ``PCAUnprojector``.

    Raises:
        ValueError: if ``n_imgs`` < 1 (there would be no features to fit the PCA on).
    """
    if n_imgs < 1:
        raise ValueError(f"n_imgs must be >= 1, got {n_imgs}")

    dataset = JitteredImage([img], n_imgs, use_flips, max_zoom, max_pad)
    # Batch size == n_imgs; assuming the dataset has n_imgs items, this is a single batch.
    loader = DataLoader(dataset, n_imgs)
    with torch.no_grad():
        # The second item yielded by the loader (the transform parameters) is not needed.
        jit_features = torch.cat(
            [project(batch, model, fit3d) for batch, _ in loader], dim=0
        )
        unprojector = PCAUnprojector(
            jit_features[:n_imgs],
            proj_dim,
            img.device,
            use_torch_pca=True,
        )
    return unprojector


In [4]:
# Load the DINOv2 backbone (torch-hub cache) onto DEVICE.
# NOTE(review): the meaning of the three positional False flags is defined by
# yoeo.main.get_dv2_model — confirm against its signature.
dv2 = get_dv2_model(False, False, False, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main


In [5]:
PATH = "fig_data/perf_landscape/fish_compare.png"
img = Image.open(PATH).convert("RGB")
# PIL reports size as (width, height)
_w, _h = img.size
tr = closest_crop(_h, _w)

# Full-precision tensor on DEVICE; this is what the jittered-PCA fits below consume.
img_tensor = convert_image(img, tr, to_half=False, device_str=DEVICE)

In [None]:
# Baseline ("N = 1"): features of the fish image and a PCA fitted on that single image alone.
feats = original_featurise(img, dv2)
# NOTE(review): 128 mirrors get_pca's proj_dim, and the positional True appears to
# correspond to get_pca's use_torch_pca=True — confirm against PCAUnprojector's signature.
unprojector = PCAUnprojector(feats, 128, DEVICE, True)
single_img_proj = unprojector.project(feats)

In [111]:
N_PREVIEW_IMGS = (1, 2, 3, 4, 5, 10, 25, 50)
# N = 1 reuses the single-image projection above; every other N fits a PCA on N jittered views.
preview_proj_feats = [single_img_proj]
for n_imgs in N_PREVIEW_IMGS[1:]:
    # Reset all RNGs before each fit so every PCA starts from an identical RNG state.
    SEED = 10672
    for _seed_fn in (seed, np.random.seed, torch.manual_seed):
        _seed_fn(SEED)
    pca = get_pca(img_tensor, n_imgs, dv2, False)
    preview_proj_feats.append(pca.project(feats))

In [112]:
%%capture
N_COLS = len(N_PREVIEW_IMGS)
fig, axs = plt.subplots(1, N_COLS, figsize=(20, 20))

# One projection per preview PCA size; if the lists drift apart the titles would be wrong.
assert len(preview_proj_feats) == N_COLS, (len(preview_proj_feats), N_COLS)
for ax, proj, n_pca in zip(axs, preview_proj_feats, N_PREVIEW_IMGS):
    ax.imshow(vis(proj))
    ax.set_axis_off()
    ax.set_title(f"N: {n_pca}")
plt.tight_layout()

In [74]:
from os import listdir
from scipy.spatial.distance import cdist

PATH = '../data/imagenet_reduced/splits/7'
# sorted(): listdir order is arbitrary, so without it the 100-image subset (and the
# resulting figure) is not reproducible.
img_names = sorted(listdir(PATH))
img_paths = [f"{PATH}/{name}" for name in img_names][:100]

# PCA sizes to compare; N = 1 (index 0) is the single-image reference itself.
N_IMGS = list(range(1, 50, 1))

all_diffs = []
for i, img_path in enumerate(img_paths):
    print(f"{i}")
    _img = Image.open(img_path).convert("RGB")
    # Everything below must be derived from *this* image (_img). Previously the crop size,
    # the jittered-PCA input and the projected features all came from the global fish image.
    _h, _w = _img.height, _img.width
    tr = closest_crop(_h, _w)
    _img_tensor = convert_image(_img, tr, to_half=False, device_str=DEVICE)

    _feats = original_featurise(_img, dv2)
    _unprojector = PCAUnprojector(_feats, 128, DEVICE, True)
    _single_img_proj = _unprojector.project(_feats)

    proj_feats = [_single_img_proj]
    for n_imgs in N_IMGS[1:]:
        SEED = 10672 + i
        seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)
        pca = get_pca(_img_tensor, n_imgs, dv2, False)
        proj_feats.append(pca.project(_feats))

    # (128, h, w) -> (h*w, 128): one row per patch
    ref_feat_flat = to_numpy(_single_img_proj).reshape((128, -1)).T

    diffs = []
    for feats_ in proj_feats[1:]:
        feats_flat = to_numpy(feats_).reshape((128, -1)).T
        # NOTE(review): components of independently fitted PCAs are only defined up to
        # sign/order, so this absolute difference may partly reflect sign flips rather than
        # feature drift — consider aligning components or a cosine distance (cdist).
        diffs.append(np.mean(np.abs(feats_flat - ref_feat_flat)))
    all_diffs.append(diffs)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [75]:
# Rows = images, columns = PCA sizes N_IMGS[1:] (one `diffs` list per image).
diff_arr = np.asarray(all_diffs)
# Average curve over images.
mean_diffs = diff_arr.mean(axis=0)

In [172]:
%%capture
import os
import matplotlib.gridspec as gridspec

plt.rcParams["font.family"] = "serif"

TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21
OUT_PATH = 'fig_out/supp_dist_shift.png'

# Create the figure
fig = plt.figure(figsize=(20, 8))

# 2 x 6 grid: the first four columns hold the 2 x 4 preview thumbnails; the last two
# columns (each 1.75x as wide as a thumbnail column) are merged for the line plot.
gs = gridspec.GridSpec(2, 6, width_ratios=[1, 1, 1, 1, 1.75, 1.75])

# Panel b: mean absolute difference between single-image and N-image PCA projections vs N.
# NOTE(review): the label reads f - PCA_i(f), while the plotted quantity compares two PCA
# projections of f — confirm the notation is intended.
ax_large = fig.add_subplot(gs[:, -2:])
ax_large.plot(N_IMGS[1:], mean_diffs, lw=3)
ax_large.tick_params(axis='both', labelsize=TICK_FS)
ax_large.set_xlabel('N images in PCA', fontsize=LABEL_FS)
ax_large.set_ylabel(r'$<|\boldsymbol{f} - PCA_{i}(\boldsymbol{f})|>$', fontsize=LABEL_FS)
# Target the axis explicitly instead of relying on pyplot's "current axes".
plt.setp(ax_large.get_yticklabels(), rotation=45, ha='right')

# Panel a: one thumbnail per preview PCA size, laid out row-major over the 2 x 4 block.
axs = []
for i, preview in enumerate(preview_proj_feats):
    row, col = divmod(i, 4)
    ax = fig.add_subplot(gs[row, col])
    ax.imshow(vis(preview))
    ax.set_axis_off()
    ax.set_title(f"N PCA: {N_PREVIEW_IMGS[i]}", fontsize=TITLE_FS)
    axs.append(ax)

# Panel letters: 'a' is placed higher (y=1.6) than 'b' (y=1.05); both at x=-0.15.
for key, ax, y in (('a', axs[0], 1.6), ('b', ax_large, 1.05)):
    ax.text(-0.15, y, f"{key}.", transform=ax.transAxes,
            size=LABEL_FS + 4, weight='bold')

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
plt.savefig(OUT_PATH, bbox_inches='tight')