In [None]:
from pathlib import Path
import sys
sys.path.append('../src')

from matplotlib import pyplot as plt

from nlkda.data.base import DatasetEnum
from nlkda.models.utils import ModelEnum
from nlkda.utils import MLFlowClient, get_skyline


# Location of the local experiment store and an MLflow client bound to it.
data_root = Path("/tmp/data/")
exp = Path(data_root, 'experiments')
# NOTE(review): tracking_uri has no URL scheme (e.g. "http://") — confirm
# MLFlowClient accepts this form.
db = MLFlowClient(
    root=str(exp),
    tracking_uri="localhost:5000/",
)

In [None]:
# Plot, for every dataset, the CoP-tree baseline (a single black point) and the
# model-size vs. quality skyline of the remaining models (a line) on log-log
# axes, then save the figure to model-size-vs-cs.pdf.
markers = ['o', 'P', 'v', 's', 'D']
styles = ['solid', 'dotted', 'dashed', 'dashdot']

# Loop-invariant settings, hoisted so they exist even when DatasetEnum is empty
# (skyline_col is also used for the y-label after the loop).
out_dir_string = "params.out_dir"
aggregation_type = "combined"  # choose one of 'agg_p', 'agg_k' or 'combined'
size_col = f"metrics.size_{aggregation_type}"
skyline_col = f"metrics.cs_mean_mono_{aggregation_type}"

# rcParams['axes.axisbelow'] is only read when an Axes is created, so it must be
# set before the figure exists; setting it after plotting had no effect here.
plt.rcParams['axes.axisbelow'] = True
fig, ax = plt.subplots()

for index, ds in enumerate(c.value for c in DatasetEnum):
    # Cycle the style lists so more datasets than styles (styles has only 4
    # entries) do not raise IndexError.
    marker = markers[index % len(markers)]
    line_style = styles[index % len(styles)]

    # plot cop tree
    cop_info = db.unified_get_entries(
        ['params.dataset', 'params.model.model_type'],
        [True, True],
        [ds, ModelEnum.COP.value],
    )
    if len(cop_info):
        # .iloc takes the first row positionally; label-based [0] raises
        # KeyError when the returned index does not contain the label 0.
        cop_size = cop_info['metrics.model_size'].iloc[0]
        cop_mean = cop_info['metrics.cs_mean'].iloc[0]
        cop_median = cop_info['metrics.cs_median'].iloc[0]
        ax.plot(cop_size, cop_mean, color="black", marker=marker, zorder=100)
    else:
        print(f'No CoP info in mlflow for dataset {ds}')

    # get skyline
    # NOTE(review): the False flag presumably selects model_type != COP —
    # confirm against MLFlowClient.unified_get_entries.
    df = db.unified_get_entries(
        ['params.dataset', 'params.model.model_type'],
        [True, False],
        [ds, ModelEnum.COP.value],
    )
    # dropna on a missing column raises KeyError, and dropping NaNs can leave
    # nothing to build a skyline from; skip the dataset in both cases.
    has_metric = len(df) > 0 and skyline_col in df.columns
    if has_metric:
        df = df.dropna(subset=[skyline_col])
    if has_metric and len(df):
        skyline = get_skyline(df, [size_col, skyline_col], [True, True])
        ax.plot(
            skyline[size_col].to_numpy(),
            skyline[skyline_col].to_numpy(),
            marker=marker,
            ls=line_style,
            label=f'{ds.upper()}',
            zorder=99,
        )
    else:
        print(f'No models with {skyline_col} for dataset {ds}')

ax.set_yscale('log')
ax.set_xscale('log')
ax.set_ylabel(f'{skyline_col} [log]')
ax.set_xlabel('Size [log]')
ax.grid(which='minor', alpha=0.2)
ax.grid(which='major', alpha=1)
# Only draw a legend when something was labelled, avoiding matplotlib's
# "No artists with labels" warning if no skyline could be plotted.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
fig.tight_layout()
fig.savefig('model-size-vs-cs.pdf')