In [4]:
from copy import deepcopy
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import trange

from mlae.evaluate import load_model, get_n_samples_from_dataloader, all_conditions
from mlae.evaluate.surrogate_lot_det import (
    log_det_surrogate_latent_encoder_mixed,
    log_det_surrogate_data_encoder_mixed,
)

In [21]:
def collect_gradient(model):
    """Snapshot the current ``.grad`` of every parameter of ``model``.

    Returns a dict mapping parameter name -> deep copy of that parameter's
    gradient, so the snapshot is independent of the live ``.grad`` tensors.
    Parameters whose ``.grad`` is ``None`` are skipped, as are parameters whose
    *values* are all zero.
    """
    # NOTE(review): the second filter inspects the parameter values, not the
    # gradient — confirm this is intended (it keeps key sets identical across
    # snapshots of the same model, which grad_diff/grad_dot rely on).
    snapshot = {}
    for param_name, parameter in model.named_parameters():
        has_grad = parameter.grad is not None
        if has_grad and (parameter != 0).any():
            snapshot[param_name] = deepcopy(parameter.grad)
    return snapshot


def grad_diff(grad1, grad2):
    """Element-wise difference ``grad1 - grad2`` of two gradient dicts.

    Only parameter names present in *both* dicts are included; names found in
    just one of them are skipped. Keys follow ``grad1``'s insertion order, so
    the result (and any float reduction over it, e.g. ``grad_norm``) is
    reproducible across interpreter runs — iterating a ``set`` of string keys
    would make the order depend on hash randomisation.
    """
    return {key: value1 - grad2[key] for key, value1 in grad1.items() if key in grad2}


def grad_norm(grad):
    """Global L2 norm of a gradient dict: sqrt of the sum of all squared entries.

    Args:
        grad: mapping of parameter name -> gradient tensor.

    Returns:
        0-dim tensor holding the norm over all tensors in ``grad``.

    Raises:
        ValueError: if ``grad`` is empty (previously this surfaced as an opaque
            ``AttributeError`` from calling ``.sqrt()`` on the int 0).
    """
    if not grad:
        raise ValueError("grad_norm() requires at least one gradient tensor")
    squared_sum = sum((value**2).sum() for value in grad.values())
    return squared_sum.sqrt()


def grad_dot(grad1, grad2):
    """Inner product of two gradient dicts.

    Sums ``(grad1[k] * grad2[k]).sum()`` over the parameter names ``k`` present
    in both dicts; names found in only one dict are ignored. Returns the int 0
    when no names are shared.

    Terms are accumulated in ``grad1``'s insertion order, so the floating-point
    summation order (and hence the exact result) is reproducible across runs,
    unlike iterating a ``set`` of string keys.
    """
    return sum(
        (value1 * grad2[key]).sum() for key, value1 in grad1.items() if key in grad2
    )


def gradient_cosine_similarity(grad1, grad2):
    """Cosine of the angle between two gradient dicts.

    Computed as ``grad_dot(grad1, grad2) / (grad_norm(grad1) * grad_norm(grad2))``.
    There is no guard against a zero norm in either input.
    """
    norm_product = grad_norm(grad1) * grad_norm(grad2)
    return grad_dot(grad1, grad2) / norm_product

In [22]:
# Location of the trained model run to evaluate. The layout suggests a
# Lightning log version directory — TODO confirm what load_model expects.
# Placeholder path: edit `model_name` and `version_X` before running.
PATH_TO_MODEL = "./MLAE/lightning_logs/model_name/version_X"  # change model_name and version_X

model = load_model(PATH_TO_MODEL)

In [None]:
# Sweep batch size x trace space x number of Hutchinson samples and compare every
# surrogate gradient against a reference gradient computed with the latent-space
# surrogate at `max_hutchinson_samples` samples. One record per configuration.
experiment_reps = 2
batch_sizes = [1, 16, 64, 256, 512]
dim = model.data_dim
latent_dim = model.latent_dim
max_hutchinson_samples = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Explicit dispatch table instead of eval() on a formatted function name.
surrogate_by_trace_space = {
    "data": log_det_surrogate_data_encoder_mixed,
    "latent": log_det_surrogate_latent_encoder_mixed,
}


def _synchronize(target_device):
    """Wait for queued CUDA work so wall-clock timers measure the real cost."""
    if torch.device(target_device).type == "cuda":
        torch.cuda.synchronize(target_device)


def _timed_surrogate_gradient(surrogate, model, batch, c_batch, hutchinson_samples, **kwargs):
    """Run one surrogate forward + backward pass on fresh gradients.

    Returns ``(duration_seconds, gradient_dict)``. The duration covers the
    surrogate call and ``backward()`` only, not the gradient snapshot.
    """
    model.zero_grad()
    _synchronize(batch.device)
    start = time()
    result = surrogate(
        batch,
        c_batch,
        model.encode,
        model.decode,
        hutchinson_samples,
        **kwargs,
    )[-2]
    result.mean().backward()
    # CUDA kernels run asynchronously; without this the timer stops too early.
    _synchronize(batch.device)
    duration = time() - start
    return duration, collect_gradient(model)


data = []
for attempt in trange(experiment_reps):
    for batch_size in batch_sizes:
        for trace_space in ["data", "latent"]:
            model.to(device)

            batch = get_n_samples_from_dataloader(
                model.val_dataloader(), batch_size, all_conditions(model)[0]
            )
            batch, _, _, c_batch = model.apply_conditions(batch)
            batch = batch.to(device)
            c_batch = c_batch.to(device)

            # Reference gradient: latent-space surrogate with the maximum sample count.
            reference_duration, reference_gradient = _timed_surrogate_gradient(
                log_det_surrogate_latent_encoder_mixed,
                model,
                batch,
                c_batch,
                max_hutchinson_samples,
            )
            reference_norm = grad_norm(reference_gradient)

            surrogate = surrogate_by_trace_space[trace_space]
            for hutchinson_samples in range(1, latent_dim + 1):
                for orthogonalize_hutchinson_samples in [True, False]:
                    try:
                        duration, gradient = _timed_surrogate_gradient(
                            surrogate,
                            model,
                            batch,
                            c_batch,
                            hutchinson_samples,
                            orthogonalize=orthogonalize_hutchinson_samples,
                        )
                    except TypeError:
                        # NOTE(review): presumably raised by the surrogate for
                        # unsupported (samples, orthogonalize) combinations — confirm.
                        # Only the surrogate call is guarded so metric bugs surface.
                        continue

                    data.append(
                        {
                            "trace_space": trace_space,
                            "dim": dim,
                            "latent_dim": latent_dim,
                            "attempt": attempt,
                            "batch_size": batch_size,
                            "reference_duration": reference_duration,
                            "reference_hutchinson_samples": max_hutchinson_samples,
                            "hutchinson_samples": hutchinson_samples,
                            "orthogonalize_hutchinson_samples": orthogonalize_hutchinson_samples,
                            # ||g - g_ref|| / ||g_ref||
                            "dist": (
                                grad_norm(grad_diff(gradient, reference_gradient))
                                / reference_norm
                            ).item(),
                            # 1 - cos(g, g_ref)
                            "neg_dot": (
                                1 - gradient_cosine_similarity(gradient, reference_gradient)
                            ).item(),
                            # ||g|| / ||g_ref||
                            "relative_norm": (grad_norm(gradient) / reference_norm).item(),
                            "duration": duration,
                        }
                    )

df = pd.DataFrame(data)

In [None]:
# Relative gradient distance to the reference vs. number of Hutchinson samples:
# one panel per trace space, one curve (mean ± std over attempts) per batch size.


def _plot_dist_mean_std(axis, runs, label, **line_kwargs):
    """Plot mean ± std of ``runs["dist"]`` across attempts, per Hutchinson sample count.

    Grouping by ``hutchinson_samples`` takes the x values from the data itself
    (rather than assuming 1..max_hutchinson_samples) and keeps attempts aligned
    even when some sample counts are missing from an attempt. An empty subset is
    skipped instead of crashing the figure.
    """
    if runs.empty:
        return
    dist_by_samples = runs.groupby("hutchinson_samples")["dist"]
    dist_mean = dist_by_samples.mean()
    dist_std = dist_by_samples.std(ddof=0)  # population std, as np.std
    samples = dist_mean.index.to_numpy()
    axis.plot(samples, dist_mean.to_numpy(), label=label, **line_kwargs)
    axis.fill_between(
        samples,
        (dist_mean + dist_std).to_numpy(),
        (dist_mean - dist_std).to_numpy(),
        alpha=0.1,
    )


_, ax = plt.subplots(1, 2, figsize=(12, 3))

# Boolean indexing (not df.where(...).dropna()) so rows with NaN in unrelated
# columns such as neg_dot are not silently dropped.
orthogonalized = df["orthogonalize_hutchinson_samples"].astype(bool)
for trace_space, axis in zip(["data", "latent"], ax):
    for batch_size in batch_sizes:
        _plot_dist_mean_std(
            axis,
            df[
                orthogonalized
                & (df["trace_space"] == trace_space)
                & (df["batch_size"] == batch_size)
            ],
            label=f"{batch_size}",
        )

# Non-orthogonalized Hutchinson samples (latent space, batch size 1), dashed.
_plot_dist_mean_std(
    ax[1],
    df[~orthogonalized & (df["trace_space"] == "latent") & (df["batch_size"] == 1)],
    label=f"{1}",
    linestyle="--",
)

# Symlog y-axis: ticks every 0.1 in [0, 1], every 1 in [1, 10], every 10 in
# [10, 100]; only 0 and the decades are labelled. The set drops duplicate ticks.
tick_locations = sorted(
    {*np.linspace(0, 1, 11), *np.linspace(1, 10, 10), *np.linspace(10, 100, 10)}
)
tick_labels_defined = {0: "0", 1: r"$10^0$", 10: r"$10^1$", 100: r"$10^2$"}
tick_labels = [tick_labels_defined.get(loc, "") for loc in tick_locations]

for axis in ax:
    axis.set_xlabel("number of Hutchinson samples")
    axis.set_ylabel("relative gradient distance")
    axis.set_yscale("symlog")
    axis.set_yticks(tick_locations, tick_labels)

ax[0].legend(
    title="batch size", handletextpad=0.2, ncols=5, columnspacing=0.6, loc="lower left"
)
ax[1].legend(
    title="batch size", handletextpad=0.2, ncols=6, columnspacing=0.6, loc="upper right"
)

plt.tight_layout()