In [8]:
import ast
from itertools import chain
from pprint import pprint

import numpy as np
from clustertools import build_datacube
from matplotlib import pyplot as plt

from monuseg_params_and_scores_compare import plot_current_setup, get_metric, plt_with_std, COLORS



In [9]:
# Reference curve: validation Dice of the "monuseg-unet-baseline" experiment,
# aggregated over axis 0 (presumably the run/seed axis — TODO confirm).
baseline_cube = build_datacube("monuseg-unet-baseline")
baseline_dice = baseline_cube("val_dice")
bl_dice_avg = np.mean(baseline_dice, axis=0).squeeze()
bl_dice_std = np.std(baseline_dice, axis=0).squeeze()
# Only the aggregated mean/std curves are used by later cells.
del baseline_cube

# Cube of the "hard" experiments iterated over by the cells below.
cube = build_datacube("monuseg-unet-hard")

In [10]:
# Parameters whose combinations give one curve each; tuples yielded by
# ``iter_dimensions(*to_plot)`` are ordered like this list.
to_plot = [
    "monu_rr", "no_distillation", "weights_mode",
    "weights_consistency_fn", "weights_minimum",
    "weights_neighbourhood", "distil_target_mode"
]

# Parameters whose combinations give one figure each.
out_params = ["monu_rr", "monu_nc", "sparse_start_after", "n_calibration"]

# Every curve configuration appearing in any figure.
param_values = set()
for _, out_cube in cube.iter_dimensions(*out_params):
    for values, in_cube in out_cube.iter_dimensions(*to_plot):
        param_values.add(values)

# Colour index per configuration. Entries of ``to_plot`` that are also in
# ``out_params`` (currently monu_rr) are constant within a figure, so they are
# left out of the colour key: otherwise the same method would receive a
# different colour in figures with different monu_rr values.
# The keys of ``param_val_idxs`` remain the full ``to_plot`` tuples, because the
# plotting cell looks colours up with the tuple it iterates over.
_method_positions = [i for i, name in enumerate(to_plot) if name not in out_params]
_method_keys = sorted({tuple(v[i] for i in _method_positions) for v in param_values})
_method_idx = {key: i for i, key in enumerate(_method_keys)}
param_val_idxs = {
    v: _method_idx[tuple(v[i] for i in _method_positions)]
    for v in param_values
}

In [11]:
def readable_weights_mode(wm):
    """Return the four-letter legend abbreviation of a weights mode.

    Any mode not in the table below maps to ``"n/a"``.
    """
    abbreviations = dict(
        pred_entropy="entr",
        pred_merged="merg",
        constant="csnt",
        balance_gt="bala",
        pred_consistency="csty",
    )
    return abbreviations.get(wm, "n/a")

def make_label(wmode, params):
    """Build a compact ``"key=value, ..."`` legend label for one configuration.

    Args:
        wmode: weights mode name.
        params: mapping that always needs ``"distillation"`` and
            ``"distil_target_mode"``. For ``pred_consistency`` / ``pred_merged``
            it also needs ``"weights_neighbourhood"`` and
            ``"weights_consistency_fn"`` (only its first 4 characters are shown).
            For every mode except ``constant`` it also needs ``"weights_minimum"``.

    Returns:
        The label, e.g. ``"w=entr, d=1, m=soft, wmin=0.5"``.

    Raises:
        ValueError: if ``wmode`` is not one of the five supported modes.
    """
    pairs = [
        ("w", readable_weights_mode(wmode)),
        ("d", params['distillation']),
        ("m", params["distil_target_mode"]),
    ]
    if wmode in ("pred_consistency", "pred_merged"):
        neighbourhood = params["weights_neighbourhood"]
        consistency_fn = params["weights_consistency_fn"][:4]
        pairs += [("nh", neighbourhood), ("fn", consistency_fn)]
    elif wmode not in ("constant", "balance_gt", 'pred_entropy'):
        raise ValueError("unknown wmode '{}'".format(wmode))
    # A minimum weight is meaningless for constant weighting.
    if wmode != "constant":
        pairs.append(("wmin", params['weights_minimum']))
    return ", ".join("{}={}".format(key, value) for key, value in pairs)


In [12]:
# One figure per combination of ``out_params``: the baseline curve plus one
# curve per fully-available configuration of ``to_plot``, saved as a PDF.
for (monu_rr, monu_nc, ssa, n_calib), out_cube in cube.iter_dimensions(*out_params):
    fig, ax = plt.subplots(figsize=[12.8, 4.8])
    for_params = {
        "monu_rr": str(monu_rr),
        "monu_nc": str(monu_nc),
        "sparse_start_after": str(ssa),
        "n_calibration": n_calib,
    }

    # Derive x from the data instead of assuming exactly 50 epochs, so a
    # baseline of a different length cannot silently mismatch its x axis.
    bl_x = np.arange(bl_dice_avg.shape[0])
    plt_with_std(ax, bl_x, bl_dice_avg, bl_dice_std, label="baseline", color=COLORS[0])

    # Axis limits grow to cover every curve drawn in this figure.
    dice_ymin, dice_ymax = np.min(bl_dice_avg), np.max(bl_dice_avg)
    x_max = bl_x.shape[0]

    print(monu_rr, monu_nc, ssa, n_calib)

    for values, in_cube in out_cube.iter_dimensions(*to_plot):
        rr, nd, wm, wfn, wmin, wneigh, tmode = values
        # Consistency-based weighting modes are not shown in these figures.
        if wm in ("pred_merged", "pred_consistency"):
            continue
        # Skip configurations whose sub-cube reports any missing results.
        if in_cube.diagnose()["Missing ratio"] > 0.0:
            continue

        label = make_label(wm, {
            "monu_rr": rr,
            # ``nd`` is the textual no_distillation flag (e.g. "True"/"False");
            # literal_eval parses it without executing arbitrary code like eval.
            "distillation": int(not ast.literal_eval(nd)),
            "weights_consistency_fn": wfn,
            "weights_minimum": wmin, "weights_neighbourhood": wneigh,
            "distil_target_mode": tmode
        })

        print("> ", label)
        val_dice = np.array(get_metric("val_dice", in_cube))
        dice_mean = np.mean(val_dice, axis=0)
        dice_std = np.std(val_dice, axis=0)
        x = np.arange(dice_mean.shape[0])

        # +1 because COLORS[0] is reserved for the baseline.
        color = COLORS[(param_val_idxs[values] + 1) % len(COLORS)]
        plt_with_std(ax, x, dice_mean, dice_std, label, color, do_std=True, alpha=0.2)

        dice_ymin = min(dice_ymin, np.min(dice_mean))
        dice_ymax = max(dice_ymax, np.max(dice_mean))
        x_max = max(x_max, x.shape[0])

    title = "_".join("{}={}".format(name, value) for name, value in for_params.items())
    ax.set_title(title)

    ax.set_ylim(dice_ymin * 0.95, dice_ymax * 1.05)
    ax.set_xlim(0, x_max)
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    ax.set_ylabel("val dice (opt)")
    ax.set_xlabel("epoch")
    fig.tight_layout()

    fig.savefig("hard_" + title + ".pdf")
    # Close this exact figure to avoid accumulating open figures in the loop.
    plt.close(fig)
    

0.9 1 15 0
>  w=bala, d=1, m=soft, wmin=0.0
>  w=csnt, d=1, m=soft
>  w=entr, d=1, m=soft, wmin=0.5
>  w=csnt, d=0, m=soft
0.9 1 15 1
0.9 2 15 0
>  w=bala, d=1, m=soft, wmin=0.0
>  w=csnt, d=1, m=soft
>  w=entr, d=1, m=soft, wmin=0.5
>  w=csnt, d=0, m=soft
0.9 2 15 1
>  w=bala, d=1, m=hard_dice, wmin=0.0
>  w=bala, d=1, m=soft, wmin=0.0
>  w=csnt, d=1, m=hard_dice
>  w=csnt, d=1, m=soft
>  w=entr, d=1, m=hard_dice, wmin=0.5
>  w=entr, d=1, m=soft, wmin=0.5
>  w=csnt, d=0, m=soft
