In [None]:
import os
import pathlib
import multiprocessing
from itertools import zip_longest

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

import anguilla.hypervolume as hv
from anguilla.fitness import benchmark
from anguilla.evaluation import load_logs
from anguilla.dominance import NonDominatedSet2D, NonDominatedSetKD

In [None]:
# Benchmark function names, split by the FNS_2D / FNS_3D naming
# (presumably two- vs three-objective problems — confirm against anguilla.fitness).
FNS_2D = [
    # ZDT family
    'ZDT1', 'ZDT2', 'ZDT3', 'ZDT4', 'ZDT6',
    # IHR family
    'IHR1', 'IHR2', 'IHR3', 'IHR4', 'IHR6',
    # ELLI / CIGTAB families
    'ELLI1', 'ELLI2',
    'CIGTAB1', 'CIGTAB2',
]

FNS_3D = [
    # DTLZ family
    'DTLZ1', 'DTLZ2', 'DTLZ3', 'DTLZ4', 'DTLZ5', 'DTLZ6', 'DTLZ7',
    # generalized ELLI
    'GELLI3',
]

In [None]:
def compute_reference_point_kd_chunk(logs):
    """Reduce one chunk of logs to the points retained by a NonDominatedSetKD.

    ``None`` entries are ignored; they are the padding that ``grouper``
    (default ``fillvalue=None``) appends to the last chunk.

    Returns:
        The accumulated set's ``non_dominated_points``.
    """
    front = NonDominatedSetKD()
    present = (entry for entry in logs if entry is not None)
    for entry in present:
        front.insert(entry.data)
    return front.non_dominated_points

def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks or blocks.

    The final chunk is padded with ``fillvalue`` when ``iterable`` does not
    divide evenly into groups of ``n``.

    Adapted from the itertools recipes:
    https://docs.python.org/3/library/itertools.html
    """
    # The same iterator is handed to zip_longest n times, so each output
    # tuple consumes n consecutive items.
    shared = iter(iterable)
    return zip_longest(*(shared for _ in range(n)), fillvalue=fillvalue)

def compute_reference_point_kd(logs, n_processes=1):
    """Reference point for logs with an arbitrary number of objectives.

    The logs are split into ``n_processes`` roughly equal chunks. Each chunk is
    reduced by ``compute_reference_point_kd_chunk``, in worker processes when
    more than one process is used. The partial results are then merged into a
    single ``NonDominatedSetKD``. The returned reference point is that set's
    ``upper_bound`` shifted by +1.0.

    Args:
        logs: Sequence of log objects exposing a ``data`` attribute.
        n_processes: Number of worker processes. ``None`` or values below 1
            fall back to serial execution. The count is capped at
            ``len(logs)`` because extra workers would receive no logs.

    Raises:
        ValueError: if ``logs`` is empty (there is no upper bound to take).
    """
    if len(logs) == 0:
        raise ValueError("compute_reference_point_kd requires at least one log")
    # os.cpu_count() may return None; treat missing/invalid counts as serial.
    if n_processes is None or n_processes < 1:
        n_processes = 1
    n_processes = min(n_processes, len(logs))

    if n_processes == 1:
        # Serial path: no pool start-up and no pickling of every log.
        partial_fronts = [compute_reference_point_kd_chunk(logs)]
    else:
        chunksize = -(-len(logs) // n_processes)  # ceiling division
        # NOTE(review): Pool pickles worker functions by reference. Under the
        # "spawn" start method (macOS/Windows default), functions defined in a
        # notebook cell may not be importable by workers — consider moving
        # them into a module if this fails there.
        with multiprocessing.Pool(processes=n_processes) as pool:
            partial_fronts = pool.map(compute_reference_point_kd_chunk, grouper(logs, chunksize))

    point_set = NonDominatedSetKD()
    for points in partial_fronts:
        point_set.insert(points)
    return point_set.upper_bound + 1.0

def compute_reference_point_2d(logs):
    """Reference point for two-objective logs.

    Every log's ``data`` is inserted, in order, into one NonDominatedSet2D.
    The result is that set's ``upper_bound`` plus 1.0.
    """
    front = NonDominatedSet2D()
    for objectives in (entry.data for entry in logs):
        front.insert(objectives)
    shift = 1.0
    return front.upper_bound + shift

def compute_reference_point(logs):
    """Compute one hypervolume reference point shared by all ``logs``.

    Dispatch is based on the number of objectives, i.e. the column count of
    ``logs[0].data``:

    * two objectives use the sequential 2-D routine;
    * any other count uses the multiprocess k-D routine, with one process
      per available CPU.

    Raises:
        ValueError: if ``logs`` is empty.
    """
    if len(logs) == 0:
        raise ValueError("compute_reference_point requires at least one log")
    n_objectives = logs[0].data.shape[1]
    if n_objectives == 2:
        return compute_reference_point_2d(logs)
    # os.cpu_count() returns None when the CPU count cannot be determined.
    return compute_reference_point_kd(logs, n_processes=os.cpu_count() or 1)

In [None]:
def _median_hypervolume_by_optimizer(logs, reference):
    """Map optimizer -> {n_evaluations: median hypervolume over trials}, keys sorted."""
    samples = {}
    for log in logs:
        indicator = hv.calculate(log.data, reference, ignore_dominated=True)
        samples.setdefault(log.optimizer, {}).setdefault(log.n_evaluations, []).append(indicator)
    return {
        optimizer: {n_evals: np.median(values) for n_evals, values in sorted(by_evals.items())}
        for optimizer, by_evals in sorted(samples.items())
    }


def plot_hypervolume_history(logs, reference, title=""):
    """Plot the median hypervolume per evaluation budget, one line per optimizer.

    For every log the hypervolume of ``log.data`` w.r.t. ``reference`` is
    computed. The values are grouped by ``(optimizer, n_evaluations)``, and the
    median over each group (i.e. over trials) is plotted.

    Args:
        logs: Non-empty sequence of logs, presumably for a single benchmark
            function (only ``logs[0].fn`` is used in the title).
        reference: Reference point passed to ``hv.calculate``; its coordinates
            are also printed in the title.
        title: Prefix for the figure title.

    Returns:
        The created matplotlib Figure (the caller is responsible for closing it).

    Raises:
        ValueError: if ``logs`` is empty.
    """
    if len(logs) == 0:
        raise ValueError("plot_hypervolume_history requires at least one log")
    medians = _median_hypervolume_by_optimizer(logs, reference)

    fig, ax = plt.subplots(figsize=(6, 4))
    for optimizer, history in medians.items():
        ax.plot(list(history.keys()), list(history.values()), marker='s', label=optimizer)
    # Tick every evaluation budget used by any optimizer, not only the first one's.
    ax.set_xticks(sorted({n_evals for history in medians.values() for n_evals in history}))

    reference_str = ",".join(f"{coord:.2E}" for coord in reference)
    ax.set_title(f"{title} with {logs[0].fn}\nReference: [{reference_str}]")
    ax.set_ylabel("Hypervolume (median)")
    ax.set_xlabel("Function evaluations")
    ax.legend()
    fig.tight_layout()
    return fig

In [None]:
def _save_figure(fig, directory, stem):
    """Write ``fig`` to ``directory`` as ``<stem>.pdf`` and ``<stem>.png``, then close it."""
    for extension in ("pdf", "png"):
        fig.savefig(directory.joinpath(f"{stem}.{extension}"), bbox_inches="tight")
    # Close explicitly: the loop below creates up to two figures per function.
    plt.close(fig)


# Output directories do not depend on the loop variable; create them once.
path1 = pathlib.Path("./plots/hypervolume/anguilla")
path2 = pathlib.Path("./plots/hypervolume/shark")
for directory in (path1, path2):
    directory.mkdir(parents=True, exist_ok=True)

fns = FNS_2D + FNS_3D
for i, fn in enumerate(fns):
    clear_output(wait=True)
    display(f"[{i+1}/{len(fns)} - {(i+1)/len(fns):.2%}]Processing {fn}...")

    logs1 = load_logs("./data/anguilla", fns=[fn], observations=["fitness"], search_subdirs=True)
    # Shark's NSGAII runs are added to the Anguilla plot (presumably as a baseline).
    logs1_extra = load_logs("./data/shark", fns=[fn], opts=['NSGAII'], observations=["fitness"], search_subdirs=False)
    logs2 = load_logs("./data/shark", fns=[fn], observations=["fitness"], search_subdirs=False)

    if len(logs1) == 0 and len(logs2) == 0:
        # No runs for this function: nothing to derive a reference point from.
        continue

    # One shared reference point per function keeps both libraries' plots comparable.
    reference = compute_reference_point(logs1 + logs2)
    if len(logs1) > 0:
        fig = plot_hypervolume_history(logs1 + logs1_extra, reference, "Anguilla")
        _save_figure(fig, path1, fn)
    if len(logs2) > 0:
        fig = plot_hypervolume_history(logs2, reference, "Shark")
        _save_figure(fig, path2, fn)