In [None]:
import sys
from pathlib import Path
from warnings import warn

import cv2
import h5py as h5
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

sys.path.append('../lib')
from modeling import models, registry
from modeling.utils import make_layer_hook, recur_collapse_feats
from storage import get_storage_functions
from local_paths import stim_dir, cache_dir

# Parameters

In [None]:
#============================================================================
# image to process
#============================================================================
# image IDs (file stems of images in stim_dir), joined by `sep`
im_md5s    = 'md5_im1,md5_im2'
sep        = ','
# full image width / height at which the images were presented; unit: dva,
# but only the ratio (im_size/patch_size) really matters
im_w       = 16
im_h       = 16
# aspect ratio tolerance (between image file and provided size); images beyond it
# trigger a warning, but are resized to the provided aspect ratio regardless
ar_tol     = 3/4


#============================================================================
# patch size and resolution
#============================================================================
# side length of each (square) crop patch; same unit as im_w / im_h
patch_size =  2
# spacing between neighbouring patch lower edges; same unit as im_w / im_h
patch_step =  0.5


#============================================================================
# model params
#============================================================================
model_name    = 'vit_large_patch16_384'
# name of the module whose output is hooked
layer_name    = 'blocks.13.attn.qkv'
spatial_averaging = True  # over W, H for conv; over S for attention


#============================================================================
# paths
#============================================================================
# unlike other scripts, this one is intentionally unaware of subfolders
# (thereby requiring image IDs, e.g., MD5s, to truly be unique);
# all images are in [stim_dir]/Stimuli, so specify it explicitly.
# NOTE: this appends to the imported `stim_dir`, so re-running this cell without
# re-running the import cell would append 'Stimuli/' twice
stim_dir = stim_dir + 'Stimuli/'

# results go to [output_root]/[model_name]/[layer_name]/<grid spec>.h5
output_root = cache_dir + 'feats/'


#============================================================================
# misc
#============================================================================
# torch device for the model and its inputs
device = 'cuda:0'
# background color (RGB); used to pad images (fills patch regions outside the
# image) and as the uniform image whose features are saved as 'feats/bg'
bgc = (128,128,128)

# Check prereqs and params

In [None]:
print('Loading images from folder', stim_dir)
stim_dir = Path(stim_dir).expanduser()
# explicit error (not `assert`) so the check survives `python -O` and names the path
if not stim_dir.is_dir():
    raise NotADirectoryError(f'stimulus folder does not exist: {stim_dir}')

# one HDF5 file per (model, layer, image size, patch grid) combination; an existing
# file at this path is resumed later on, so keep the naming scheme stable
output_root = Path(output_root)
output_path = output_root / model_name / layer_name / \
    f'{im_w:.1f}x{im_h:.1f}_as_{patch_size}x{patch_size}_in_{patch_step:.2f}_steps.h5'
print('Saving results to', output_path)
output_path = output_path.expanduser()
output_path.parent.mkdir(exist_ok=True, parents=True)

# Prepare parameters; save config

In [None]:
# `im_md5s` arrives as one `sep`-separated string (a notebook parameter). Strip
# whitespace and drop empty entries: an empty ID would glob-match hidden files.
# The isinstance guard keeps re-running this cell from splitting an already-split list.
if isinstance(im_md5s, str):
    im_md5s = [md5.strip() for md5 in im_md5s.split(sep) if md5.strip()]
print('Processing', len(im_md5s), 'images')

In [None]:
def _patch_lower_edges(extent, size, step):
    """Lower edges of all grid patches (side `size`, spacing `step`) overlapping [0, extent).

    The grid is aligned to the lower end (0) of the extent. Returned edges are in
    increasing order and shifted by -extent/2, i.e. relative to the extent's center:
    first the partial patches sticking out below 0, then all patches whose lower
    edge lies inside the extent (full ones, and partial ones sticking out above).
    """
    inside = np.arange(int(np.ceil(extent / step))) * step
    # edges -step, -2*step, ... that are still > -size (so the patch overlaps the image)
    outside = np.arange(-1, -int(np.ceil(size / step)), -1)[::-1] * step
    return np.concatenate([outside, inside]) - extent / 2


im_size = (im_w, im_h)
ar_tol = min(ar_tol, 1/ar_tol)
patch_step = float(patch_step)
if not (patch_step > 0 and patch_size > 0):
    raise ValueError(f'patch_step and patch_size must be positive; got {patch_step} and {patch_size}')

# coordinates: origin at image center, x to the right, y up;
# each edge is the left (x) / lower (y) edge of a patch
patches_ledge_x = _patch_lower_edges(im_size[0], patch_size, patch_step)
n_patches_x = len(patches_ledge_x)

patches_ledge_y = _patch_lower_edges(im_size[1], patch_size, patch_step)
n_patches_y = len(patches_ledge_y)

print('Patches step size:', patch_step)
print(f'Number of patches: {n_patches_x} x {n_patches_y} (x-by-y)')
print('Patches (bin lower edge):')
print('(The coordinates in degrees are with origin at image center)')
print('x:')
print('\t' + str(patches_ledge_x).replace('\n', '\n\t'))
print('y:')
print('\t' + str(patches_ledge_y).replace('\n', '\n\t'))

In [None]:
_preproc_defaults = registry.get_default_preproc(model_name)
model_imsize = _preproc_defaults['imsize']  # used below as a square side length (pixels)
print('Model input image size:', model_imsize)

In [None]:
def tqdm_(x, **kwargs):
    """Wrap iterable `x` in a rarely-refreshing tqdm progress bar, to avoid a bloated log file.

    Defaults are mininterval=300 and miniters=10; extra keyword arguments are
    forwarded to `tqdm` and override these defaults.
    """
    return tqdm(x, **{'mininterval': 300, 'miniters': 10, **kwargs})

In [None]:
# storage helpers bound to this run's output file
_storage_funcs = get_storage_functions(output_path)
(save_results, add_attr_to_dset, check_equals_saved,
 link_dsets, copy_group) = _storage_funcs

In [None]:
save_results('config/stimuli/size_dva', im_size)

# patch grid geometry (dva, origin at image center, y up).
# NOTE: older output files may contain the misspelled key 'eft_edges' instead of 'left_edges'
group = 'config/patch_grid/'
patch_grid_config = {
    'size': patch_size,
    'step': patch_step,
    'left_edges': patches_ledge_x,
    'right_edges': patches_ledge_x + patch_size,
    'lower_edges': patches_ledge_y,
    'upper_edges': patches_ledge_y + patch_size,
    'x_locs': patches_ledge_x + patch_size/2,  # patch centers
    'y_locs': patches_ledge_y + patch_size/2,
}
for key, value in patch_grid_config.items():
    save_results(group + key, value)

group = 'config/modelling/'
modelling_config = {
    'model_name': model_name,
    'layer_name': layer_name,
    'input_image_size': model_imsize,
    'spatial_averaging': spatial_averaging,
}
for key, value in modelling_config.items():
    save_results(group + key, value)

# Locate & load images

In [None]:
# resume support: image IDs already stored in the output file are skipped, and new
# rows are written after the `offset` rows already present
done_md5s, offset = None, 0
if output_path.is_file():
    with h5.File(output_path, 'r') as f:
        try:
            saved_md5s = f['md5'][()].astype(str)
        except KeyError:  # file exists, but has no 'md5' dataset yet
            saved_md5s = None
    if saved_md5s is not None:
        done_md5s, offset = saved_md5s, len(saved_md5s)

if done_md5s is not None:
    done_md5s = set(done_md5s)
    print(len(done_md5s), 'images already done')
    im_md5s = [md5 for md5 in im_md5s if md5 not in done_md5s]
    print('Processing', len(im_md5s), 'remaining images')

In [None]:
def _find_image_file(md5):
    """Return the file in `stim_dir` named `<md5>.<ext>`.

    Raises ValueError for an empty ID (it would glob-match hidden files) and
    FileNotFoundError if no file matches. If several files match, warns and uses
    the first one in sorted order, since image IDs are expected to be unique.
    """
    if not md5:
        raise ValueError('empty image ID')
    matches = sorted(p for p in stim_dir.glob(md5 + '.*') if p.is_file())
    if not matches:
        raise FileNotFoundError(f'no image file for {md5!r} in {stim_dir}')
    if len(matches) > 1:
        warn(f'multiple image files for {md5!r}: {[p.name for p in matches]}; using {matches[0].name}')
    return matches[0]


im_paths = [_find_image_file(md5) for md5 in im_md5s]

In [None]:
ar = im_size[0] / im_size[1]
print(f'Defined image aspect ratio: {ar:.2f}')


def _load_rgb_image(fp, size_dva, ar_tol):
    """Load image file `fp` as an 8-bit RGB array with the aspect ratio of `size_dva`.

    Args:
        fp: `Path` of the image file (its `.name` is used in messages).
        size_dva: (width, height) at which the image was presented.
        ar_tol: aspect-ratio tolerance (<= 1); a warning is issued when the file's
            aspect ratio differs from that of `size_dva` by more than this factor.

    If the aspect ratios differ by more than 1%, the image is resized regardless,
    because it would have been presented at the specified size: the relatively
    smaller dimension is kept and the other one is resized to match.

    Returns:
        (H, W, 3) uint8 array; alpha channels are dropped.
    """
    target_ar = size_dva[0] / size_dva[1]
    with Image.open(fp) as image:  # closes the file once pixels have been read
        # np.array() of these modes gives palette indices / ink amounts / luma-chroma
        # rather than RGB intensities, or non-uint8 / 2-channel arrays
        if image.mode in ('P', 'CMYK', 'YCbCr'):
            image = image.convert('RGB')
        elif image.mode in ('1', 'LA'):
            image = image.convert('L')

        # check image aspect ratio
        ar_ = image.size[0] / image.size[1]
        if not (0.99 < ar_/target_ar < 1.01):
            if not (ar_tol < ar_/target_ar < 1/ar_tol):
                warn(
                    f'image {fp.name} (size: {image.size}; AR = {ar_:.2f}) '
                    f'is very far from expected aspect ratio (size: {size_dva}; AR = {target_ar:.2f}) '
                    '(resizing it regardless, because it would have been presented at the '
                    'specified size)')

            # keep the relatively smaller dimension; resize the other one to match
            i = np.argmin(np.array(image.size[:2]) / np.array(size_dva))
            if i == 0:
                w = image.size[0]
                h = int(round(w / target_ar))
            else:
                h = image.size[1]
                w = int(round(h * target_ar))
            print(f'Resizing {fp.name} (size: {image.size}; AR = {ar_:.2f}) to size {(w, h)} (AR = {w/h:.2f})')
            image = image.resize((w, h))

        pixels = np.array(image)

    # make ims 8-bit RGB
    assert pixels.dtype == np.uint8, f'{fp.name}: expected an 8-bit image, got {pixels.dtype}'
    if pixels.ndim == 3:
        assert pixels.shape[-1] in (3, 4), f'{fp.name}: unexpected image shape {pixels.shape}'
        pixels = pixels[:, :, :3]
    else:
        assert pixels.ndim == 2, f'{fp.name}: unexpected image shape {pixels.shape}'
        pixels = np.repeat(pixels[:, :, None], 3, axis=-1)
    return pixels


n_ims = len(im_paths)
images = np.empty(n_ims, dtype=object)
for iim, fp in enumerate(im_paths):
    images[iim] = _load_rgb_image(fp, im_size, ar_tol)

print(len(images), 'images')
images.shape, images.dtype

# Prepare model

In [None]:
# when no images to process, save time by not loading model
# (unfortunately, I do not know how to early-stop an ipynb from within itself)
if len(im_md5s):

    model = models.get_model_by_name(model_name, dev=device)
    preprocessing_func = models.get_preprocessor_by_model_name(model_name, preproc_imsize=False, preproc_from='numpy')

    class Embedder:
        """Collect (optionally spatially averaged) activations of named layers via hooks.

        Constructor defaults bind to the notebook-level model, preprocessor and
        parameters at class-definition time. Hooks are registered on `model` in
        `__init__` (handles kept in `self.hdls`).
        """

        def __init__(
                self, model=model, preproc_fun=preprocessing_func,
                model_name=model_name, layer_names=layer_name,
                spatial_averaging=spatial_averaging,
                fwd_fun=None, device=device, pbar=tqdm_):
            """
            Args:
                model: network to hook.
                preproc_fun: maps one image to an input tensor (a batch dimension
                    is added by `.unsqueeze(0)`).
                model_name: only used to choose the default forward function
                    (`model.encode_image` if it contains 'CLIP', else `model.__call__`).
                layer_names: one module name, or an iterable of module names, to hook.
                spatial_averaging: passed to `recur_collapse_feats`.
                fwd_fun: forward callable; default chosen from `model_name`.
                device: device the input tensor is moved to.
                pbar: wrapper applied to the image iterable (e.g. a progress bar).
            """
            self.model = model
            self.preproc_fun = preproc_fun
            self.spatial_averaging = spatial_averaging
            self.device = device
            self.pbar = pbar

            if isinstance(layer_names, str):
                layer_names = (layer_names,)
            else:
                # materialize first: a generator would otherwise be consumed by the check
                # below and stored empty
                layer_names = tuple(layer_names)
                assert all(isinstance(n, str) for n in layer_names)
            self.layer_names = layer_names

            if fwd_fun is None:
                if model_name is not None and 'CLIP' in model_name:
                    fwd_fun = model.encode_image
                else:
                    fwd_fun = model.__call__
            self.fwd_fun = fwd_fun

            hooks = {}
            hdls = {}
            for n in layer_names:
                hooks[n], hdls[n] = make_layer_hook(model, n, return_handle=True)
            self.hooks = hooks
            self.hdls = hdls

        def extract_pooled_features(self, ims):
            """Run each image through the model, one at a time, and collect hooked features.

            Args:
                ims: iterable of images accepted by `preproc_fun`.

            Returns:
                dict mapping layer name -> np.ndarray stacking the per-image
                collapsed features along a new first axis.

            Raises:
                ValueError: if `recur_collapse_feats` does not yield a single tensor.
            """
            feats = {n: [] for n in self.layer_names}
            with torch.no_grad():
                for im in self.pbar(ims):
                    tim = self.preproc_fun(im).unsqueeze(0).to(self.device)
                    # output is discarded; features are read from the hooks afterwards
                    self.fwd_fun(tim)

                    for n, hook in self.hooks.items():
                        # `hook.o` presumably holds the layer's latest output — see make_layer_hook
                        feats_ = recur_collapse_feats(hook.o, spatial_averaging=self.spatial_averaging)
                        if not isinstance(feats_, torch.Tensor):
                            raise ValueError(f'unexpected feature type {type(feats_)} at layer {n}: {feats_}')
                        feats_ = feats_.cpu().numpy()
                        feats[n].append(feats_)

            return {n: np.array(v) for n, v in feats.items()}

In [None]:
if len(im_md5s):
    embedder = Embedder()
    # uniform background-color image at the model's input size; its features are
    # used to size the datasets and are stored as 'feats/bg'
    test_im = np.full((model_imsize, model_imsize, 3), bgc, dtype=np.uint8)

    feats = embedder.extract_pooled_features([test_im])[layer_name][0]
    print('feats:', feats.shape, feats.dtype)
    sample_feats = feats
    feats_shape = sample_feats.shape
    save_results('config/modelling/pooled_feat_shape', sample_feats.shape)

    # only check for the dataset here, and close the file before writing, so that
    # save_results does not access the file while this cell still holds it open
    with h5.File(output_path, 'a') as f:
        bg_saved = 'feats/bg' in f
    if not bg_saved:
        save_results('feats/bg', sample_feats)

# Initialize result storage

In [None]:
def create_ignoring_existing(f, *args, attrs=None, **kwargs):
    """Create dataset `args[0]` in HDF5 group/file `f` unless it exists; return the dataset.

    Args:
        f: open h5py file or group.
        *args, **kwargs: forwarded to `f.create_dataset`; `args[0]` is the dataset name.
        attrs: optional dict of attributes. Written on creation; if the dataset
            already exists, each one is compared to the saved value with
            `check_equals_saved` instead.

    Returns:
        The newly created or the already existing dataset. For an existing dataset,
        the creation kwargs (shape, dtype, ...) are NOT checked.
    """
    name = args[0]
    assert isinstance(name, str)
    if attrs is not None:
        assert isinstance(attrs, dict)

    # test for existence directly instead of parsing the exception message of
    # create_dataset, which differs between h5py versions
    if name in f:
        dset = f[name]
        for k, v in (attrs or {}).items():
            check_equals_saved(v, dset.attrs[k], k)
    else:
        dset = f.create_dataset(*args, **kwargs)
        for k, v in (attrs or {}).items():
            dset.attrs[k] = v
    return dset

In [None]:
cache_opts = dict(compression='gzip', compression_opts=9)
if len(im_md5s):
    # (dataset name, dim labels, coord labels, per-image row shape)
    feat_dset_specs = [
        ('feats/full_image',
         ['image', 'feat_chan'],
         ['md5', 'feat_chans'],
         feats_shape),
        ('feats/patch_grid',
         ['image', 'rf_x', 'rf_y', 'feat_chan'],
         ['md5', 'patch_locs', 'patch_locs', 'feat_chans'],
         (n_patches_x, n_patches_y,) + feats_shape),
    ]
    with h5.File(output_path, 'a') as f:
        # one image ID per row; extendable along axis 0
        create_ignoring_existing(
            f, 'md5', shape=(0,), maxshape=(None,), chunks=(1,), dtype='S32', **cache_opts)

        # feature datasets: one chunk per image, extendable along axis 0
        for dset_name, dim_labels, coord_labels, row_shape in feat_dset_specs:
            create_ignoring_existing(
                f, dset_name,
                shape=(0,) + row_shape,
                maxshape=(None,) + row_shape,
                attrs=dict(dims=np.array(dim_labels, dtype=bytes),
                           coords=np.array(coord_labels, dtype=bytes)),
                chunks=(1,) + row_shape,
                dtype=sample_feats.dtype,
                **cache_opts)

# Main loop

In [None]:
def get_image_patch(im, im_size_dva, patch_min_x_dva, patch_min_y_dva, wsize_dva, wsize_px, bgc=bgc):
    """Resample a square patch of `im` onto a `wsize_px` x `wsize_px` pixel grid.

    Coordinates are in dva with the origin at the image center, x to the right and
    y up. The patch covers [patch_min_x_dva, patch_min_x_dva + wsize_dva] in x and
    [patch_min_y_dva, patch_min_y_dva + wsize_dva] in y. Parts falling outside the
    image are filled with `bgc`; sampling is bilinear.

    Args:
        im: image array, rows x columns (x channels); assumed 3-channel here,
            matching the 3-value `bgc`.
        im_size_dva: (width, height) of the full image in dva.
        patch_min_x_dva, patch_min_y_dva: left / lower patch edge in dva.
        wsize_dva: patch side length in dva.
        wsize_px: output side length in pixels.
        bgc: fill color for out-of-image regions.

    Returns:
        (wsize_px, wsize_px[, C]) uint8 array.
    """
    assert isinstance(im, np.ndarray)
    assert isinstance(wsize_px, (int, np.integer))
    wsize_px = int(wsize_px)

    # pixels per dva along each axis: x spans the columns (width), y the rows (height)
    height_px, width_px = im.shape[:2]
    ppd_x = width_px / im_size_dva[0]
    ppd_y = height_px / im_size_dva[1]

    # centers of the output pixels, as dva offsets from the patch's left / top edge
    centers_dva = (np.arange(wsize_px) + 0.5) / wsize_px * wsize_dva

    # source positions in pixels measured from the image's left / top edge; the
    # trailing -0.5 converts to cv2.remap's convention of pixel centers at integers
    xs = ppd_x * (centers_dva + patch_min_x_dva + im_size_dva[0] / 2) - 0.5
    ys = ppd_y * (centers_dva - patch_min_y_dva - wsize_dva + im_size_dva[1] / 2) - 0.5

    map1 = np.repeat(xs.astype(np.float32)[None, :], wsize_px, 0)  # column coords
    map2 = np.repeat(ys.astype(np.float32)[:, None], wsize_px, 1)  # row coords
    wim = cv2.remap(
        im.astype(np.float32), map1, map2, interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT, borderValue=bgc)
    wim = np.round(wim).astype(np.uint8)
    return wim

In [None]:
def _write_row(dset, i, value):
    """Store `value` as row `i` of extendable HDF5 dataset `dset`, growing axis 0 if needed."""
    if dset.shape[0] < i + 1:
        dset.resize(i + 1, axis=0)
    dset[i] = value


def _pooled_feats(im):
    """Features of a single image array at `layer_name`, computed with `embedder`."""
    return embedder.extract_pooled_features([im])[layer_name][0]


# skip entirely when every image was already processed: the model (and `embedder`)
# is only created when there is work to do, so the loop would fail on `embedder`
if len(im_md5s):
    embedder.pbar = iter  # to avoid bilayer tqdm
    for iim, (im, md5) in enumerate(zip(tqdm_(images), im_md5s)):
        i_ = offset + iim  # output row; continues after rows saved by earlier runs

        # full image feats (resized to the model's square input; aspect ratio not kept)
        im_ = np.array(Image.fromarray(im).resize((model_imsize, model_imsize)))
        feats = _pooled_feats(im_)
        with h5.File(output_path, 'a') as f:
            _write_row(f['feats/full_image'], i_, feats)

        # patch grid feats, indexed [ix, iy, ...]
        featss = np.array([
            [_pooled_feats(get_image_patch(im, im_size, x0, y0, patch_size, model_imsize))
             for y0 in patches_ledge_y]
            for x0 in patches_ledge_x])

        with h5.File(output_path, 'a') as f:
            _write_row(f['feats/patch_grid'], i_, featss)
            # md5 is written last: resuming counts saved md5s, so an image interrupted
            # before this point is redone and its feature rows overwritten
            _write_row(f['md5'], i_, md5)

# Wrap up

In [None]:
# log run environment and package versions (watermark extension) for reproducibility
%load_ext watermark
%watermark
%watermark -vm --iversions -rbg