In [None]:
# Reload imported modules automatically before each cell runs, so edits to the project's
# source package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib

# Move to the repository root so the relative paths used below (tb_logs/, visuals/, ...) resolve
# regardless of the directory the notebook was launched from. Unlike a bare `chdir("..")` loop,
# this stops with an error once the filesystem root is reached instead of looping forever.
REPO_ROOT_NAME = "aerial-disentangled-representations"

_cwd = pathlib.Path.cwd()
_repo_root = next((p for p in (_cwd, *_cwd.parents) if p.name == REPO_ROOT_NAME), None)
if _repo_root is None:
    raise FileNotFoundError(f"No directory named {REPO_ROOT_NAME!r} found at or above {_cwd}")
os.chdir(_repo_root)

In [None]:
import torch

# Prefer the Apple-silicon GPU (the original target), then CUDA, and fall back to the CPU so the
# notebook also runs on machines without MPS.
DEVICE = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Set before the dataset module (and with it albumentations) is imported; presumably this turns
# off albumentations' update check.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
from disentangled_representations.src.data_processing.aerial_dataset_instances import (
    A,
    Hi_UCD_Dataset,
    aerial_datasets_mapping,
)

# Test split used for computing embeddings and metrics.
Hi_UCD_test_dataset = aerial_datasets_mapping["Hi_UCD_Dataset_test"]

# Same split read in colour and with empty transform pipelines; used for the retrieval figures.
_empty_shared_transform = A.Compose([], additional_targets={})
_empty_unique_transform = A.Compose([])
Hi_UCD_test_dataset_visuals = Hi_UCD_Dataset(
    split="test",
    read_color=True,
    shared_transform=_empty_shared_transform,
    unique_transform=_empty_unique_transform,
)

### Model loading

Requires a trained model checkpoint (a `*.ckpt` file in the run's `checkpoints` directory under `tb_logs/disent_rep/`).

In [None]:
import pathlib
from loguru import logger

# Checkpoint directory of the deterministic-projector run (embedding dimension 128).
ckpt_dir = pathlib.Path("tb_logs", "disent_rep", "Deterministic projector | dim(z) = 128", "checkpoints")

In [None]:
from disentangled_representations.src.training_procedure import LitKapellmeister, LossWeights
from disentangled_representations.src.models.projectors import SimpleDeterministicProjector, SimpleVariationalProjector
from disentangled_representations.src.models.image_encoders import EfficientNetB0


def load_models_by_checkpoint_dir_path(ckpt_dir: pathlib.Path, out_dim: int, variational: bool):
    """Rebuild a ``LitKapellmeister`` and restore its weights from the checkpoint in ``ckpt_dir``.

    The module is assembled from an ``EfficientNetB0`` image encoder (``in_channels=1``) and a
    projector with one hidden layer of 512 units, then filled from the checkpoint (``strict=True``).

    Args:
        ckpt_dir: directory holding the run's ``*.ckpt`` file(s); a ``str`` is accepted too.
        out_dim: output width of the projector. For the variational projector the latent
            dimensionality is ``out_dim // 2``.
        variational: build a ``SimpleVariationalProjector`` instead of a
            ``SimpleDeterministicProjector``.

    Returns:
        The restored ``LitKapellmeister``, loaded with ``map_location=DEVICE``.

    Raises:
        NotADirectoryError: if ``ckpt_dir`` is not an existing directory.
        FileNotFoundError: if ``ckpt_dir`` contains no ``*.ckpt`` file.
    """
    ckpt_dir = pathlib.Path(ckpt_dir)
    # An explicit exception, unlike `assert`, is not stripped when Python runs with -O.
    if not ckpt_dir.is_dir():
        raise NotADirectoryError(f"Checkpoint directory not found: {ckpt_dir}")

    # Sorted so the selected checkpoint is reproducible; raw glob order depends on the filesystem.
    ckpt_paths = sorted(ckpt_dir.glob("*.ckpt"))
    if not ckpt_paths:
        raise FileNotFoundError(f"No *.ckpt file found in {ckpt_dir}")
    if len(ckpt_paths) > 1:
        logger.warning(f"Found {len(ckpt_paths)} checkpoints, using {ckpt_paths[0].name}: {ckpt_paths=}")
    ckpt_path = ckpt_paths[0]

    encoder = EfficientNetB0(in_channels=1)
    embedding_dim = int(encoder.feature_dim)

    if variational:
        projector = SimpleVariationalProjector(
            input_dimensionality=embedding_dim,
            hidden_features=[512],
            # The projector outputs mean and log-variance, each of this size.
            latent_dimensionality=out_dim // 2
        )
    else:
        projector = SimpleDeterministicProjector(
            input_dimensionality=embedding_dim,
            hidden_features=[512],
            output_dimensionality=out_dim
        )

    loss_weights = LossWeights(w_NTXent=1.0, w_KL=0.5)

    model = LitKapellmeister.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        image_encoder=encoder,
        projector=projector,
        loss_weights=loss_weights,
        map_location=DEVICE,
        strict=True
    )
    return model


deterministic_lit_kapellmeister = load_models_by_checkpoint_dir_path(
    ckpt_dir, out_dim=128, variational=False
)
# Direct handles to the two sub-modules of the loaded model.
encoder_det, projector_det = (
    deterministic_lit_kapellmeister.image_encoder,
    deterministic_lit_kapellmeister.projector,
)

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
import torch.nn.functional as F


def compute_all_projected_embeddings(dataset, lit_kapellmeister: LitKapellmeister, dim: int,
                                     device=None, batch_size: int = 8, num_workers: int = 4):
    """Embed both views of every item of ``dataset`` with ``lit_kapellmeister``.

    Args:
        dataset: map-style dataset whose items are ``(view_a, view_b)`` pairs of image tensors.
        lit_kapellmeister: model called on an image batch, returning one vector per image.
        dim: length of each returned vector (e.g. 128 for the deterministic projector, 256 for the
            variational one, whose output stacks 128 mean and 128 log-variance parameters).
        device: device to run inference on. Defaults to the device of the model's parameters,
            i.e. where ``load_from_checkpoint`` mapped them.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.

    Returns:
        ``(Z_A, Z_B)``: float32 arrays of shape ``(len(dataset), dim)``; row ``i`` belongs to item
        ``i``. Deterministic embeddings are L2-normalised; variational outputs are returned as-is.
    """
    if device is None:
        device = next(lit_kapellmeister.parameters()).device
    device = torch.device(device)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # row i of the outputs must correspond to dataset item i
        num_workers=num_workers,
        pin_memory=device.type == "cuda",  # pinned host memory only helps host->CUDA copies
    )

    n_items = len(dataset)
    Z_A = np.empty((n_items, dim), dtype=np.float32)
    Z_B = np.empty_like(Z_A)

    # NOTE: a variational embedding (mean and log-variance parameters) should not be projected.
    project: bool = not lit_kapellmeister.kapellmeister.is_projector_variational
    logger.info(f"{project=}")

    # Inference must run in eval mode so that any BatchNorm/Dropout layers behave deterministically.
    # `load_from_checkpoint` does not switch modes, so do it here and restore the caller's mode after.
    was_training = lit_kapellmeister.training
    lit_kapellmeister.eval()
    idx = 0
    try:
        with torch.no_grad():
            # `batch_a`/`batch_b` rather than `A`/`B` to avoid confusion with albumentations' `A`.
            for batch_a, batch_b in tqdm(loader, desc="Computing embeddings", unit="batch"):
                bsz = batch_a.size(0)

                z_a = lit_kapellmeister(batch_a.to(device))
                z_b = lit_kapellmeister(batch_b.to(device))

                if project:
                    # Unit-norm rows make dot products cosine similarities.
                    z_a = F.normalize(z_a, dim=1)
                    z_b = F.normalize(z_b, dim=1)

                Z_A[idx:idx + bsz] = z_a.cpu().numpy()
                Z_B[idx:idx + bsz] = z_b.cpu().numpy()
                idx += bsz
    finally:
        lit_kapellmeister.train(was_training)
    return Z_A, Z_B


Z_A_det, Z_B_det = compute_all_projected_embeddings(
    dataset=Hi_UCD_test_dataset,
    lit_kapellmeister=deterministic_lit_kapellmeister,
    dim=128,
)

In [None]:
# NOTE: exact (brute-force) nearest-neighbour search. The full query-by-gallery similarity matrix
# (rows: view-A embeddings, columns: view-B embeddings) is computed once, so the retrieval metrics
# are not degraded by an approximate search.
similarities_det = np.matmul(Z_A_det, Z_B_det.T)

#### Retrieval performance assessment

In [None]:
import matplotlib.pyplot as plt


def _plot_retrieval_curves(cmc_ks, cmc, ks, map_at_k, font_size, save_path):
    """Plot the CMC curve and mAP@k side by side, optionally saving the figure as a PDF."""
    small = font_size * 0.8
    rc = {
        "font.size": font_size,
        "axes.titlesize": font_size,
        "axes.labelsize": font_size,
        "xtick.labelsize": small,
        "ytick.labelsize": small,
        "legend.fontsize": small,
    }
    # rc_context keeps these overrides local to this figure instead of permanently changing the
    # global rcParams for every later plot in the notebook.
    with plt.rc_context(rc):
        fig, (ax_cmc, ax_map) = plt.subplots(1, 2, figsize=(14, 6))

        ax_cmc.plot(cmc_ks, cmc, lw=2)
        ax_cmc.set_xlabel("k", fontsize=font_size)
        ax_cmc.set_ylabel("CMC(k) / Recall@k", fontsize=font_size)
        ax_cmc.set_title("CMC Curve (Recall@k)", fontsize=font_size)

        ax_map.plot(ks, map_at_k, marker='o', lw=2)
        ax_map.set_xlabel("k", fontsize=font_size)
        ax_map.set_ylabel("mAP@k", fontsize=font_size)
        ax_map.set_title("Mean Average Precision @ K", fontsize=font_size)

        for ax in (ax_cmc, ax_map):
            ax.tick_params(axis='both', which='major', labelsize=small)
            ax.grid(True)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, format='pdf')

        plt.show()


def evaluate_retrieval_metrics(
        similarities,
        correct_indices,
        ks_map=(1, 5, 10, 20, 50),
        max_cmc_k=50,
        font_size=24,
        save_path=None,
        plot=True,
):
    """Rank-based retrieval metrics for queries that each have exactly one correct gallery item.

    Args:
        similarities: (Q, N) array-like (NumPy array or CPU tensor); entry ``[q, n]`` is the
            similarity of query ``q`` to gallery item ``n`` (higher = more similar).
        correct_indices: length-Q sequence with the gallery index of each query's correct match.
        ks_map: cut-offs k at which mAP@k is reported.
        max_cmc_k: the CMC curve is evaluated for k = 1..max_cmc_k.
        font_size: base font size of the plots.
        save_path: if given (and ``plot`` is true), the figure is also saved there as a PDF.
        plot: draw the CMC and mAP@k curves; pass False to only compute the metrics.

    Returns:
        dict with ``rank1_acc``, ``cmc_ks``, ``cmc``, ``ks_map`` (as an array), ``map_at_k`` and
        ``rank_positions`` (1-based rank of the correct item for every query).

    Raises:
        ValueError: if ``similarities`` is not 2-D or ``correct_indices`` does not have length Q.
    """
    sims = np.asarray(similarities)
    correct_idx = np.asarray(correct_indices, dtype=int)
    if sims.ndim != 2:
        raise ValueError(f"`similarities` must be 2-D, got shape {sims.shape}")
    n_queries = sims.shape[0]
    if correct_idx.shape != (n_queries,):
        raise ValueError("correct_indices must have length `Q`")

    # 1-based rank = 1 + number of gallery items scored strictly higher than the correct one.
    # Ties with the correct item are therefore resolved in its favour (optimistic ranking).
    true_sim = sims[np.arange(n_queries), correct_idx]
    rank_positions = 1 + np.sum(sims > true_sim[:, None], axis=1)

    rank1_acc = np.mean(rank_positions == 1)
    print(f"Rank-1 accuracy: {rank1_acc * 100:.2f}%")

    # CMC(k): fraction of queries whose correct item is within the top k (= Recall@k here).
    cmc_ks = np.arange(1, max_cmc_k + 1)
    cmc = np.mean(rank_positions[:, None] <= cmc_ks[None, :], axis=0)

    # With a single relevant item, AP@k is 1/rank if it appears within the top k, else 0.
    ks = np.array(ks_map)
    precisions = np.where(
        rank_positions[:, None] <= ks[None, :],
        1.0 / rank_positions[:, None],
        0.0
    )
    map_at_k = np.mean(precisions, axis=0)

    if plot:
        _plot_retrieval_curves(cmc_ks, cmc, ks, map_at_k, font_size, save_path)

    return {
        'rank1_acc': rank1_acc,
        'cmc_ks': cmc_ks,
        'cmc': cmc,
        'ks_map': ks,
        'map_at_k': map_at_k,
        'rank_positions': rank_positions
    }

In [None]:
_ = evaluate_retrieval_metrics(
    similarities_det,
    np.arange(similarities_det.shape[0]),  # query i's correct gallery item is item i
    ks_map=[1, 2, 3, 5, 10, 15, 20, 30, 50],
    save_path=None,
)

## Retrieval visualizations


In [None]:
def visualize_retrieval_batch(
        query_imgs,
        retrieved_imgs,
        k=5,
        figsize_per_query=(2.5, 2.5),
        font_size=14,
        save_path=None
):
    """Plot a grid with one query per row followed by its top-k retrieved images.

    Row 0 holds the column titles ("Query", "Rank 1", ..., "Rank k"); row ``i + 1`` shows
    ``query_imgs[i]`` followed by ``retrieved_imgs[i][0], ..., retrieved_imgs[i][k - 1]``.

    Args:
        query_imgs: sequence of n images displayable with ``Axes.imshow``.
        retrieved_imgs: sequence of n sequences, each holding at least k images.
        k: number of retrieved images shown per query.
        figsize_per_query: (width, height) in inches of one grid cell.
        font_size: font size of the column titles.
        save_path: if given, the figure is also saved there as a PDF.
    """
    n_queries = len(query_imgs)
    n_cols = k + 1
    column_titles = ["Query"] + [f"Rank {i + 1}" for i in range(k)]
    # Light cyan for the query column, alternating grey/white for the rank columns.
    col_colors = ["#e0f7fa"] + [
        "#f5f5f5" if (j % 2) == 0 else "#ffffff"
        for j in range(1, n_cols)
    ]

    # squeeze=False always yields a 2-D (n_queries + 1, n_cols) array of Axes, so axes[row, col]
    # also works for a single query (n == 1) or a single column (k == 0).
    fig, axes = plt.subplots(
        n_queries + 1,
        n_cols,
        figsize=(figsize_per_query[0] * n_cols, figsize_per_query[1] * (n_queries + 1)),
        dpi=200,
        squeeze=False,
    )

    # NOTE(review): Matplotlib does not draw the background patch of an Axes that has been turned
    # off, so the face colours set below are probably not visible — confirm whether they matter.

    # Header row: one bold title per column.
    for j, title in enumerate(column_titles):
        ax = axes[0, j]
        ax.axis("off")
        ax.set_facecolor(col_colors[j])
        ax.text(
            0.5, 0.5, title,
            ha="center", va="center",
            fontsize=font_size,
            weight="bold"
        )

    # One row per query: the query image, then its retrieved images in rank order.
    for i in range(n_queries):
        row_images = [query_imgs[i]] + [retrieved_imgs[i][j] for j in range(k)]
        for j, img in enumerate(row_images):
            ax = axes[i + 1, j]
            ax.imshow(img)
            ax.axis("off")
            ax.set_facecolor(col_colors[j])

    plt.tight_layout(pad=1.0)

    if save_path:
        fig.savefig(save_path, format="pdf", bbox_inches="tight")

    plt.show()

In [None]:
import numpy as np
import torch


def get_retrieved_images_by_indices(similarities, indices, k: int, dataset=None):
    """
    Returns query images and top-k retrieved images.

    Args:
        similarities: (Q, N) NumPy array or CPU tensor of query-to-gallery similarities
            (higher = more similar).
        indices: query rows to visualise; query i's correct match is gallery item i.
        k: number of retrieved images per query.
        dataset: indexable collection whose item i is a pair (query image, gallery image);
            defaults to ``Hi_UCD_test_dataset_visuals``.

    Returns:
        ``(queries_vis, retrieved_images)``: the query images, and for every query the list of its
        k gallery images ordered from most to least similar. Also prints, per query, whether the
        top-1 retrieval is the correct item.
    """
    if dataset is None:
        dataset = Hi_UCD_test_dataset_visuals

    # Converting up-front lets NumPy arrays and tensors follow the same NumPy ranking path
    # (np.argsort on a tensor would otherwise return tensor indices).
    sims = np.asarray(similarities)
    query_rows = [int(i) for i in indices]

    # Queries are the first element of each dataset item, gallery images the second.
    queries_vis = [dataset[i][0] for i in query_rows]

    # Gallery indices by decreasing similarity, truncated to the top k per query.
    top_k = np.argsort(-sims[query_rows, :], axis=1)[:, :k]
    print(f"Correct: {top_k[:, 0] == np.array(query_rows)}")

    retrieved_images = [[dataset[int(j)][1] for j in row] for row in top_k]
    return queries_vis, retrieved_images


In [None]:
# Query rows of the similarity matrix shown in the retrieval grids below.
indices_correct = [0, 11, 20, 266, 320, 323]
indices_correct_2 = [325, 462, 463, 470, 472, 502]
indices_incorrect = [5, 6, 8, 10, 12, 52]

# Twelve randomly chosen queries, shown as two grids of six.
random_indices_all = [1557, 1336, 1324, 1590, 1286, 1868, 1907, 1832, 1190, 1613, 1439, 1773]
random_indices_1, random_indices_2 = random_indices_all[:6], random_indices_all[6:]

# Number of retrieved images shown per query.
k = 3

In [None]:
_queries, _retrieved = get_retrieved_images_by_indices(similarities_det, indices_correct, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)

In [None]:
_queries, _retrieved = get_retrieved_images_by_indices(similarities_det, indices_correct_2, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)

In [None]:
_queries, _retrieved = get_retrieved_images_by_indices(similarities_det, indices_incorrect, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)

In [None]:
_queries, _retrieved = get_retrieved_images_by_indices(similarities_det, random_indices_1, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)

In [None]:
_queries, _retrieved = get_retrieved_images_by_indices(similarities_det, random_indices_2, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)

# Variational projector

In [None]:
ckpt_dir_path = pathlib.Path("tb_logs", "disent_rep", "Variational projector | dim(z) = 128", "checkpoints")

# NOTE: the MLP's output dimension is 256 because it returns 128 parameters for the Gaussian mean
# and 128 parameters for its (log-)variance.
variational_lit_kapellmeister = load_models_by_checkpoint_dir_path(
    ckpt_dir_path, out_dim=256, variational=True
)
encoder_var, projector_var = (
    variational_lit_kapellmeister.image_encoder,
    variational_lit_kapellmeister.projector,
)

In [None]:
Z_A_var, Z_B_var = compute_all_projected_embeddings(
    dataset=Hi_UCD_test_dataset,
    lit_kapellmeister=variational_lit_kapellmeister,
    dim=256,  # 128 mean + 128 (log-)variance parameters per image
)

In [None]:
from disentangled_representations.src.models.abstract_models import VariationalProjector

# Split every stored parameter vector into the Gaussian mean and log-variance.
_split_params = VariationalProjector.multivariate_params_from_vector
mu_A, log_variance_A = _split_params(torch.from_numpy(Z_A_var))
mu_B, log_variance_B = _split_params(torch.from_numpy(Z_B_var))

In [None]:
# NOTE: Log-root generalized variance, (1/d) log det Σ = log (det Σ)^(1/d). With one log-variance
# per latent dimension (diagonal Σ), log det Σ is the sum of the log-variances, so this is their
# mean over the d latent dimensions; no hard-coded d is needed. Computed for the query view (A),
# which is the side later used to select the most certain queries.
log_det_np = log_variance_A.mean(dim=1).detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(9, 6))
ax.hist(log_det_np, bins=30, edgecolor='black')

ax.set_xlabel("Avg. log det Σ", fontsize=20)
ax.set_ylabel("Count", fontsize=20)
ax.tick_params(labelsize=18)

ax.grid(True)
fig.tight_layout()
# Give the exported file an explicit .pdf extension matching its format.
fig.savefig("generalized_variance.pdf", format='pdf')
plt.show()

In [None]:
# Cosine similarity between the predicted means (the log-variances are not used for ranking).
# F.normalize matches the normalisation used for the deterministic embeddings, and converting to
# NumPy gives the same type as `similarities_det` for the NumPy-based evaluation helpers.
similarities_var = (F.normalize(mu_A, dim=1) @ F.normalize(mu_B, dim=1).T).detach().cpu().numpy()

In [None]:
_ = evaluate_retrieval_metrics(
    similarities_var,
    np.arange(similarities_var.shape[0]),  # query i's correct gallery item is item i
    ks_map=[1, 2, 3, 5, 10, 15, 20, 30, 50],
    save_path=None,
)

In [None]:
def most_certain_indices(values, quantile=0.2):
    """Return the indices of the entries of ``values`` at or below its ``quantile``-quantile.

    When ``values`` is a per-query uncertainty (lower = more certain), this selects roughly the
    ``quantile`` fraction of most certain queries (more if several values tie at the cut-off).
    """
    cutoff = np.quantile(values, quantile)
    return np.nonzero(values <= cutoff)[0]

Metrics on the 50% most certain queries (lowest predicted generalized variance):

In [None]:
certain_indices = most_certain_indices(log_det_np, 0.50)
# Restrict both the query rows and their correct gallery indices to the selected queries.
_ = evaluate_retrieval_metrics(
    similarities=similarities_var[certain_indices],
    correct_indices=np.arange(similarities_var.shape[0])[certain_indices],
    ks_map=[1, 2, 3, 5, 10, 15, 20, 30, 50],
    save_path="visuals/retrieval_metrics_V_50p.pdf",
)

Metrics on the 20% most certain queries (lowest predicted generalized variance):

In [None]:
certain_indices = most_certain_indices(log_det_np, 0.20)
# Restrict both the query rows and their correct gallery indices to the selected queries.
_ = evaluate_retrieval_metrics(
    similarities=similarities_var[certain_indices],
    correct_indices=np.arange(similarities_var.shape[0])[certain_indices],
    ks_map=[1, 2, 3, 5, 10, 15, 20, 30, 50],
    save_path="visuals/retrieval_metrics_V_20p.pdf",
)

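Next, the 50% most certain queries according to the variational projector (lowest predicted generalized variance) are evaluated with the deterministic model. If the deterministic model also retrieves this subset better than the full query set, the predicted generalized variance reflects genuine query difficulty rather than a quirk of the variational model.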


In [None]:
# Queries selected by the variational model's uncertainty, scored with the deterministic model.
certain_indices = most_certain_indices(log_det_np, 0.50)
_ = evaluate_retrieval_metrics(
    similarities=similarities_det[certain_indices],
    correct_indices=np.arange(similarities_det.shape[0])[certain_indices],
    ks_map=[1, 2, 3, 5, 10, 15, 20, 30, 50],
    save_path=None,
)

### Variational retrieval examples


Retrieval results of the variational model on some of the queries used in the deterministic examples:

In [None]:
indices = [0, 11, 20, 266, 320, 323, 325, 462]
_queries, _retrieved = get_retrieved_images_by_indices(similarities_var, indices, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)

Retrieval on more uncertain queries: the ten queries ranked 996th–1005th by predicted generalized variance (counting from the most uncertain). Most of these images contain little detail and are indeed hard to retrieve accurately.

In [None]:
# Ascending order: the last entries are the most uncertain queries.
_order_by_uncertainty = np.argsort(log_det_np)
uncertain_indices = _order_by_uncertainty[-1005:-995]
_queries, _retrieved = get_retrieved_images_by_indices(similarities_var, uncertain_indices, k=k)
visualize_retrieval_batch(_queries, _retrieved, k=k, save_path=None)