In [None]:
# Automatically reload imported modules before each cell executes, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
from os import environ
import torch

from hydra import initialize, compose
from hydra.utils import instantiate

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Compose the Hydra config named "config" from the current directory (config_path=".").
with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="config")

In [None]:
# Evaluate on fields the models were not trained on.
# The test-data directory defaults to the original scratch path but can be
# overridden with the PSF_TEST_DATA_PATH environment variable, so the notebook
# can run on machines where that absolute path does not exist.
test_path = environ.get("PSF_TEST_DATA_PATH", "/data/scratch/aakash/test_small")
# NOTE(review): the three "0:100" ranges presumably select the train/val/test
# portions of the cached data — confirm against cfg.cached_simulator.
test_dataset = instantiate(cfg.cached_simulator, cached_data_path=test_path, splits="0:100/0:100/0:100")
# logger=None overrides whatever logger cfg.train.trainer specifies.
trainer = instantiate(cfg.train.trainer, logger=None)

In [None]:
# Checkpoints for the three encoder variants being compared.
BASE_PATH = "../../output/PSF_MODELS/single_field_base/checkpoints/epoch15.ckpt"
UNAWARE_PATH = "../../output/PSF_MODELS/multi_field_psf_unaware/checkpoints/epoch19.ckpt"
PARAMS_ONLY_PATH = "../../output/PSF_MODELS/multi_field_psf_params_only/checkpoints/epoch15.ckpt"


def load_encoder(ckpt_path, concat_psf_params):
    """Instantiate an encoder from ``cfg.encoder`` and load weights from a checkpoint.

    Args:
        ckpt_path: path to a checkpoint file whose ``"state_dict"`` entry holds
            the encoder weights.
        concat_psf_params: forwarded to the image normalizer config as
            ``concat_psf_params``; must match how the checkpoint was trained.

    Returns:
        The encoder with weights loaded, switched to eval mode.
    """
    model = instantiate(cfg.encoder, image_normalizer={"concat_psf_params": concat_psf_params})
    # map_location="cpu" avoids requiring the device the checkpoint was saved on
    # (e.g. CUDA) at load time; load_state_dict copies the values into the
    # model's existing parameters.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model


base_model = load_encoder(BASE_PATH, concat_psf_params=False)
unaware_model = load_encoder(UNAWARE_PATH, concat_psf_params=False)
params_only_model = load_encoder(PARAMS_ONLY_PATH, concat_psf_params=True)

### Base model (single-field checkpoint)

In [None]:
# Evaluate the single-field baseline checkpoint on the held-out test fields.
base_results = trainer.test(model=base_model, datamodule=test_dataset)

### PSF-unaware model

In [None]:
# Evaluate the multi-field, PSF-unaware checkpoint on the same test fields.
unaware_results = trainer.test(model=unaware_model, datamodule=test_dataset)

### PSF-params-only model (`concat_psf_params=True`)

In [None]:
# Evaluate the checkpoint whose normalizer concatenates PSF parameters (concat_psf_params=True).
params_results = trainer.test(model=params_only_model, datamodule=test_dataset)

### Concatenate results into dataframe

In [None]:
models = {
    "base": (base_model, base_results),
    "unaware": (unaware_model, unaware_results),
    "params": (params_only_model, params_results),
}

# Results: metric keys look like "<prefix>/<metric>"; keep everything after the
# first "/" (or the whole key when there is no "/", instead of raising IndexError).
keys = list(base_results[0].keys())
metric_names = [key.split("/", 1)[-1] for key in keys]

# One row per model. A metric the base model logged but another model did not
# becomes NaN rather than raising KeyError.
data = {
    model_name: [results[0].get(key, np.nan) for key in keys]
    for model_name, (_, results) in models.items()
}
data_flat = (
    pd.DataFrame.from_dict(data, orient="index", columns=metric_names)
    .reset_index()
    .rename(columns={"index": "model"})
)
# Long format (model, metric, value) for seaborn faceting and CSV export.
data_melt = pd.melt(data_flat, id_vars="model", value_vars=metric_names, var_name="metric", value_name="value")
data_melt.to_csv("psf_model_results.csv")

from IPython.display import HTML
HTML(data_flat.to_html())

### Plot Results

In [None]:
def _clip_and_label_bars(ax, clip_factor=5):
    """Clip outlier bars on one facet, label each bar with its true value, and set y-limits.

    Bars taller than ``|clip_factor * median|`` (median over the facet's finite bar
    heights) are shortened to that threshold so one outlier does not flatten the
    other bars; the text labels still show the unclipped values.

    Args:
        ax: matplotlib Axes whose ``containers`` hold the facet's bar containers.
        clip_factor: multiple of the median bar height above which bars are clipped.

    Returns:
        True if at least one bar was clipped (the y scale is then not faithful).
    """
    finite_heights = [
        rect.get_height()
        for container in ax.containers
        for rect in container.patches
        if np.isfinite(rect.get_height())
    ]
    if not finite_heights:
        # Nothing drawable on this facet; leave it untouched.
        return False

    threshold = np.abs(clip_factor * np.median(finite_heights))
    clipped = False
    for container in ax.containers:
        orig_heights = [rect.get_height() for rect in container.patches]
        # A zero threshold (median of 0) would flatten every positive bar, so only
        # clip when the threshold is strictly positive.
        if threshold > 0:
            for rect in container.patches:
                if rect.get_height() > threshold:
                    rect.set_height(threshold)
                    clipped = True
        ax.bar_label(container, labels=[f"{height:.3f}" for height in orig_heights], fontsize=6)

    new_heights = [
        rect.get_height()
        for container in ax.containers
        for rect in container.patches
        if np.isfinite(rect.get_height())
    ]
    lowest, highest = min(new_heights), max(new_heights)
    # Pad 10% + 0.1 away from zero; keep 0 as the floor unless some bar is negative,
    # in which case extend below it so negative bars are not cut off.
    bottom = lowest * 1.1 - 0.1 if lowest < 0 else 0
    top = max(highest, 0) * 1.1 + 0.1
    ax.set_ylim(bottom, top)
    return clipped


def plot_results(data, ncols=3, title=None, clip_factor=5):
    """Draw one bar-chart facet per metric comparing models.

    Args:
        data: long-format DataFrame with ``model``, ``metric`` and ``value`` columns,
            and optionally a ``bin`` column that is used as the hue.
        ncols: number of facets per row (seaborn ``col_wrap``).
        title: optional figure-level title.
        clip_factor: bars taller than ``clip_factor`` times the facet's median bar
            height are clipped; their labels still show the true value.
    """
    # set_theme (which sns.set aliases) resets any style it is not given, so style
    # and font scale must be passed together or "ticks" is lost.
    sns.set_theme(style="ticks", font_scale=0.8)

    hue = "bin" if "bin" in data.columns else None

    g = sns.catplot(
        data,
        kind="bar",
        x="model", y="value", col="metric", hue=hue,
        sharex=False, sharey=False, col_wrap=ncols,
        height=3, aspect=1.5,
        palette="dark", alpha=0.6,
        legend=True,
    )
    g.set_titles(template="{col_name}")

    # .flat yields individual Axes whether g.axes is 1-D (col_wrap set) or 2-D.
    for ax in g.axes.flat:
        clipped = _clip_and_label_bars(ax, clip_factor=clip_factor)
        ax.tick_params(axis="x", labelsize=6)
        # Clipped bars make the y scale misleading, so drop the y tick labels and
        # rely on the per-bar value labels instead.
        if clipped:
            ax.set(yticklabels=[])

    if title:
        g.figure.suptitle(title)
        g.tight_layout()

In [None]:
# Metrics to plot, mapped to their panel titles. The same keys drive both the
# filter and the relabelling, so the two cannot drift apart.
METRIC_LABELS = {
    "f1": "F1-score",
    "star_fluxes_r_mae": "Star flux, median estimation error",
    "gal_fluxes_r_mae": "Galaxy flux, median estimation error",
    "disk_hlr_mae": "Disk half-light radius, median estimation error",
    "bulge_hlr_mae": "Bulge half-light radius, median estimation error",
}
MODEL_LABELS = {"base": "Single Field", "unaware": "PSF-unaware", "params": "PSF Encoding"}

keep_keys = list(METRIC_LABELS)
data_to_plot = (
    data_melt[data_melt["metric"].isin(keep_keys)]
    # Column-scoped replacement: model names are only replaced in "model" and
    # metric names only in "metric".
    .replace({"model": MODEL_LABELS, "metric": METRIC_LABELS})
)
plot_results(data_to_plot, ncols=2, title="Results on simulated data with variable PSFs")