In [7]:
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt

from map_elites import MAPElites2DEvaluation

In [8]:
# Run to visualise. Assumes the history comes from the stories_genre_ending domain (a 2D archive).
# The selected run is the one with the median QD score out of 5 re-runs
# (for QDAIF runs, QD score was measured up to iteration 2000, out of 5000).
OUTPUT_TITLE = "QDAIF"  # figure title
N_BINS = (10, 10)  # expected archive grid size; not read by the cells visible here — TODO confirm still needed
DATA_PATH = (
    "data/histories_opinions_stories/qdaif/stories_genre_ending/"
    "lmx_near_seeded_init/1/history.jsonl"
)

In [9]:
def load_results(path: Path) -> pd.DataFrame:
    """Read a JSON Lines history file into a DataFrame.

    Args:
        path: location of the ``.jsonl`` file. A plain ``str`` path is also
            accepted, since it is handed straight to ``pandas.read_json``.

    Returns:
        A DataFrame with one row per JSON record (line) in the file.
    """
    return pd.read_json(path_or_buf=path, lines=True)

In [None]:
# Load the run history (one row per JSON record of DATA_PATH) and expose a positional "index" column.
X = load_results(DATA_PATH)
X = X.reset_index()

# Inner bin edges of the non-uniform grid on [0, 1] (finer near both ends); both behaviour
# dimensions share them. With start=0 and stop=1 these presumably give 10 bins per axis,
# matching the 0..10 tick positions used below — TODO confirm against MAPElites2DEvaluation.
BIN_EDGES = [0.005, 0.02, 0.05, 0.20, 0.50, 0.80, 0.95, 0.98, 0.995]
X_BINS = list(BIN_EDGES)
Y_BINS = list(BIN_EDGES)

FONT_SIZE = 15
# Tick positions are grid-edge indices: edge 0 -> 0.0, edge 5 -> 0.50 (BIN_EDGES[4]), edge 10 -> 1.0.
TICK_POSITIONS = [0, 5, 10]
TICK_LABELS = ['0.0', '0.5', '1.0']

# compute the state of the archive at each iteration
map_elites_evaluation = MAPElites2DEvaluation(history_length=len(X), x_bins=X_BINS, y_bins=Y_BINS, start=(0,0), stop=(1,1))
map_elites_evaluation.fit(phenotype_key="phenotype", data=X)

# Copy before zeroing so map_elites_evaluation.archive keeps its -inf entries (presumably the
# "no elite in this cell" marker) for any later cell that reads it.
archive = np.array(map_elites_evaluation.archive, copy=True)  # (iters, dim0, dim1)
archive[archive == -np.inf] = 0  # draw empty cells as 0 on the [0, 1] colour scale

sns.set_theme(style="white")
f, ax = plt.subplots(figsize=(9, 6))
g = sns.heatmap(
    np.flip(archive[-1], 0),  # archive after the last logged iteration, rows flipped
    cmap=sns.color_palette("rocket", as_cmap=True),
    square=True,
    linewidths=.5,
    vmin=0,
    vmax=1,
    ax=ax,
)
cbar = g.collections[0].colorbar
g.set_xticks(TICK_POSITIONS, labels=TICK_LABELS, fontsize=FONT_SIZE)
cbar.ax.tick_params(labelsize=FONT_SIZE)

# account for swapping of labels (and their corresponding range limits) from analysis scripts ran for baseline methods.
# Path(...).parts works whether DATA_PATH is a str or a Path, and regardless of the separator.
if "baselines" in Path(DATA_PATH).parts:
    g.set_yticks(TICK_POSITIONS, labels=TICK_LABELS[::-1], fontsize=FONT_SIZE)
else:
    g.invert_yaxis() # for QDAIF
    g.set_yticks(TICK_POSITIONS, labels=TICK_LABELS, fontsize=FONT_SIZE)

g.set_xlabel("Ending (Tragic to Happy)", fontsize=FONT_SIZE)
g.set_ylabel("Genre (Romance to Horror)", fontsize=FONT_SIZE)

g.set_title(OUTPUT_TITLE, size=40)
# plt.show() rather than f.show(): Figure.show() warns under Jupyter's non-GUI inline backend.
plt.show()