In [1]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

from random import seed

import numpy as np
from time import time_ns

from vulture.utils import to_numpy, do_2D_pca, closest_crop, convert_image
from vulture.main import get_dv2_model, get_upsampler_and_expr, get_hr_feats
from vulture.feature_prep import get_lr_feats, project

# cuDNN is switched on explicitly and any cached CUDA allocations are released up front.
torch.backends.cudnn.enabled = True
torch.cuda.empty_cache()

# Target device handed to every model / tensor placement further down.
DEVICE = "cuda:1"

# One shared seed for each RNG in use, applied in order: numpy, torch, stdlib random.
SEED = 2
for _seed_rng in (np.random.seed, torch.manual_seed, seed):
    _seed_rng(SEED)
del _seed_rng

In [2]:
# Trained upsampler weights and the JSON architecture config they were trained with.
model_path = "../trained_models/e5000_full_fit_reg.pth"
cfg_path = "../vulture/models/configs/upsampler_FU_default.json"

# Half-precision backbone with flash attention, explicitly moved to DEVICE.
# (The first positional flag is passed as True; its meaning is defined in vulture.main.)
dv2 = get_dv2_model(True, to_half=True, add_flash=True, device=DEVICE).to(DEVICE)

# Upsampler network plus its experiment config (expr.n_ch_in is read when extracting features).
upsampler, expr = get_upsampler_and_expr(model_path, cfg_path, device=DEVICE)

Using cache found in /home/ronan/.cache/torch/hub/ywyue_FiT3D_main


[128, 128, 128, 128]


In [3]:
# Source images for the supplementary feature-visualisation figure.
# Their order here is the row order of the final figure.
PATH = "fig_data/supp_more_feat_vis"
paths = [
    f"{PATH}/{stem}.jpg"
    for stem in ["394", "sns2_anode_cropped", "biphase_steel_crop", "cells"]
]

In [4]:
@torch.no_grad()
def get_lr_hr_feats(img: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute low-res backbone features and upsampled high-res features for one image.

    Relies on the notebook-level ``dv2`` backbone, ``upsampler`` / ``expr`` pair
    and ``DEVICE``.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image (the caller converts it to RGB).

    Returns
    -------
    (lr_feats, hr_feats) : tuple[torch.Tensor, torch.Tensor]
        lr_feats: features from ``get_lr_feats``, moved to DEVICE and
            L1-normalised along dim 1.
        hr_feats: ``upsampler(inp_img, lr_feats)`` evaluated under float16 autocast.
    """
    # Crop transform for the backbone input only; 14 is passed to closest_crop
    # (presumably the ViT patch size — see vulture.utils.closest_crop).
    tr = closest_crop(img.height, img.width, 14, True)

    # Uncropped guidance image for the upsampler.
    # NOTE(review): pil_to_tensor yields values in 0-255, while these mean/std are the
    # ImageNet statistics for [0, 1] inputs. Left as-is because the upsampler may have
    # been trained on exactly this scaling — confirm before adding a `/ 255`.
    inp_img = (
        TF.normalize(
            TF.pil_to_tensor(img).to(torch.float32),
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        )
        .unsqueeze(0)
        .to(DEVICE)
    )

    inp_img_dino = convert_image(img, tr, to_half=True, device_str=DEVICE)
    lr_feats, _ = get_lr_feats(dv2, [inp_img_dino], 50, fit3d=True, n_feats_in=expr.n_ch_in)
    lr_feats = F.normalize(lr_feats.to(DEVICE), p=1, dim=1)

    # torch.autocast expects a device *type* ("cuda"), not a device string such as
    # "cuda:1" (older torch releases reject the latter). Autocast state is per device
    # type, so this still covers tensors living on cuda:1.
    with torch.autocast(torch.device(DEVICE).type, dtype=torch.float16):
        hr_feats = upsampler(inp_img, lr_feats)
    return lr_feats, hr_feats

In [5]:
def rescale(a: np.ndarray) -> np.ndarray:
    """Min-max rescale an array to [0, 1] independently for every trailing channel.

    The min and max are taken over the first two axes, so for an (H, W, C)
    array each of the C channels is stretched to its own full range.

    Parameters
    ----------
    a : np.ndarray
        Array with at least two dimensions (and a non-empty first two axes).

    Returns
    -------
    np.ndarray
        Array of the same shape with values in [0, 1]. A channel that is constant
        over the first two axes maps to 0 rather than NaN.
    """
    a_min = a.min(axis=(0, 1), keepdims=True)
    a_max = a.max(axis=(0, 1), keepdims=True)
    span = a_max - a_min
    # A constant channel has span 0; dividing by it gave NaN (plus a RuntimeWarning).
    # Its numerator is also 0 everywhere, so dividing by 1 instead yields 0.
    safe_span = np.where(span == 0, 1, span)
    return (a - a_min) / safe_span

def truncate_and_rescale(x: torch.Tensor) -> np.ndarray:
    """Turn batch item 0 of a channel-first tensor into a displayable RGB array.

    Keeps the first three channels of the first batch item, converts them to
    float32 on the host, transposes to channel-last and min-max rescales each
    channel with ``rescale``.

    Parameters
    ----------
    x : torch.Tensor
        Rank-4 tensor indexed as (batch, channel, H, W).

    Returns
    -------
    np.ndarray
        float32 array of shape (H, W, min(C, 3)) with each channel rescaled to [0, 1].
    """
    # Slice on the device first so only 3 channels of one item are copied to the host,
    # instead of the whole batch at full channel width. detach() keeps .numpy() valid
    # even if the tensor carries grad; converting in torch also handles bf16, which
    # numpy cannot represent.
    rgb_chw = x[0, :3].detach().to(torch.float32).cpu().numpy()
    return rescale(rgb_chw.transpose(1, 2, 0))

In [6]:
# For every figure image, keep the RGB image and the first three LR / HR feature
# channels (rescaled to [0, 1]) ready for plotting.
imgs, lr_feats, hr_feats = [], [], []
for img_path in paths:
    # The context manager closes the file handle. convert() returns a new, fully
    # loaded image, so it stays usable after the file is closed.
    with Image.open(img_path) as img_file:
        img = img_file.convert("RGB")
    lr, hr = get_lr_hr_feats(img)

    imgs.append(img)
    lr_feats.append(truncate_and_rescale(lr))
    hr_feats.append(truncate_and_rescale(hr))

In [7]:
def center_crop(img, crop_h, crop_w):
    """
    Center crop a PIL Image or numpy array to (crop_h, crop_w).

    Parameters:
        img: PIL.Image.Image or np.ndarray (rows / columns on the first two axes)
        crop_h: target height in pixels, 1 <= crop_h <= image height
        crop_w: target width in pixels, 1 <= crop_w <= image width

    Returns:
        Cropped image (same type as input)

    Raises:
        TypeError: if img is neither a PIL.Image.Image nor a numpy.ndarray.
        ValueError: if the requested crop does not fit inside the image.
    """
    # The ndarray check comes first so the array path never needs PIL.
    if isinstance(img, np.ndarray):
        h, w = img.shape[:2]
    elif isinstance(img, Image.Image):
        w, h = img.size
    else:
        raise TypeError("Input must be a PIL.Image.Image or numpy.ndarray")

    # An oversized crop gives negative offsets. numpy slices then silently wrap to the
    # far edge, and PIL fills the out-of-range area instead of failing. Reject it.
    if not (0 < crop_h <= h and 0 < crop_w <= w):
        raise ValueError(
            f"crop ({crop_h}, {crop_w}) does not fit inside image of size ({h}, {w})"
        )

    top = (h - crop_h) // 2
    left = (w - crop_w) // 2

    if isinstance(img, np.ndarray):
        return img[top:top + crop_h, left:left + crop_w, ...]
    return img.crop((left, top, left + crop_w, top + crop_h))

In [8]:
def center_crop_aspect(img, aspect_ratio):
    """
    Center crop a PIL Image or numpy array to the given aspect ratio.

    The largest centred window with (as nearly as integer pixels allow) the
    requested aspect ratio is kept; one side always stays at full length.

    Parameters:
        img: PIL.Image.Image or np.ndarray (rows / columns on the first two axes)
        aspect_ratio: float (width / height), must be > 0

    Returns:
        Cropped image (same type as input)

    Raises:
        ValueError: if aspect_ratio is not a positive number.
        TypeError: if img is neither a PIL.Image.Image nor a numpy.ndarray.
    """
    # `not x > 0` also rejects NaN. Zero or negative ratios used to yield empty or garbage crops.
    if not aspect_ratio > 0:
        raise ValueError(f"aspect_ratio must be positive, got {aspect_ratio!r}")

    # The ndarray check comes first so the array path never needs PIL.
    if isinstance(img, np.ndarray):
        h, w = img.shape[:2]
    elif isinstance(img, Image.Image):
        w, h = img.size
    else:
        raise TypeError("Input must be a PIL.Image.Image or numpy.ndarray")

    # Round (not truncate) to the nearest pixel. Truncation both picked a less accurate
    # aspect and could drop a row/column when float error put e.g. w / (w / h) just
    # below h. Clamp to [1, full side] so the crop is never empty or larger than the input.
    if w / h > aspect_ratio:
        # Image is too wide: keep full height, crop width.
        new_w = min(w, max(1, int(round(h * aspect_ratio))))
        new_h = h
    else:
        # Image is too tall (or already matches): keep full width, crop height.
        new_w = w
        new_h = min(h, max(1, int(round(w / aspect_ratio))))

    left = (w - new_w) // 2
    top = (h - new_h) // 2

    if isinstance(img, np.ndarray):
        return img[top:top + new_h, left:left + new_w, ...]
    return img.crop((left, top, left + new_w, top + new_h))

In [21]:
%%capture
fig, axs = plt.subplots(nrows=len(imgs), ncols=3, figsize=(18, 12))

plt.rcParams["font.family"] = "serif"
TITLE_FS = 25
LABEL_FS = 23
TICK_FS = 21

titles = ('Image', 'Low-res DINOv2', 'HR ViT')
img_names = ('Iron alloy', 'SnS2 anode', 'Biphase steel', 'Plant cells')
h, w, _ = hr_feats[0].shape
for i, (img, lr, hr) in enumerate(zip(imgs, lr_feats, hr_feats)):
    axs[i, 0].imshow(center_crop_aspect(img, w / h))
    axs[i, 1].imshow(center_crop_aspect(lr, (w // 14) / (h // 14) ))
    axs[i, 2].imshow(center_crop_aspect(hr, w / h))

    axs[i, 0].set_ylabel(img_names[i], fontsize=LABEL_FS)

    for j, ax in enumerate(axs[i]):
        # ax.set_axis_off()
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        if i == 0:
            ax.set_title(titles[j], fontsize=TITLE_FS)
plt.tight_layout()
plt.savefig('fig_out/supp_feat_vis.png')