In [None]:
# Enable the ipywidgets notebook extension so the interactive `widgets.interact` controls below render.
# NOTE(review): `jupyter nbextension` targets the classic Notebook (< 7); on JupyterLab / Notebook 7
# this subcommand may be unavailable and widgets work without it — confirm for your environment.
!jupyter nbextension enable --py widgetsnbextension

In [None]:
import sys

# Make the sibling repositories importable without installing them as packages
# (alternatively, install each module). Relative entries are resolved against the
# kernel's current working directory.
sys.path.extend([
    '../fuzzy-torch',
    '../fuzzy-tools',
    '../astro-lightcurves-handler',
])

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.results.utils import get_model_names

rootdir = 'save/paper_v3'
set_name = 'test'
method = 'spm-mcmc-estw'
cfilename = f'survey=alerceZTFv7.1~bands=gr~mode=onlySNe~method={method}'
kf = '.'

model_names = get_model_names(rootdir, cfilename, kf, set_name)
print(f'model_names (#{len(model_names)}):')
for model_name in model_names:
    print(model_name)
print('/'*100)
bypass_prob = 0.0
ds_prob = 0.1
new_model_names = []
for model_name in model_names:
    if not 'pb=.' in model_name:
        continue
    if f'bypass_synth=0~bypass_prob={bypass_prob}~ds_prob={ds_prob}' in model_name:
        new_model_names += [model_name]
for model_name in new_model_names:
    print(model_name)

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from lcclassifier.results.performances import plot_metric
import ipywidgets as widgets


def interact_f(dict_name, metric_name, target_class):
    plot_metric(rootdir, cfilename, kf, set_name, new_model_names, metric_name,
        std_prop=1 / 2,
        target_class=target_class,
        dict_name=dict_name,
        )
widgets.interact(interact_f,
    dict_name=['thdays_class_metrics', 'thdays_class_metrics_all_bands'],
    metric_name=['aucroc', 'precision', 'recall', 'f1score', 'aucpr'],
    target_class=[None, ' SNIbc', 'SNIIbn', 'SNIa', 'SLSN'],
    )

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.results.cms import plot_cm
import ipywidgets as widgets


def interact_f(dict_name, export_animation, alphabet_count_offset):
    new_model_names = []
    for model_name in model_names:
        if not 'b=202' in model_name:
            #continue
            pass
        if 'RNN' in model_name:
            continue
            pass
        if 'Attn' in model_name:
            #continue
            pass
        new_model_names += [model_name]
    plot_cm(rootdir, cfilename, kf, set_name, new_model_names,
        export_animation=export_animation,
        dict_name=dict_name,
        alphabet_count_offset=alphabet_count_offset,
        )
widgets.interact(interact_f,
    dict_name=['thdays_class_metrics', 'thdays_class_metrics_all_bands'],
    export_animation=False,
    alphabet_count_offset=[1, 0],
    )

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.results.tables import get_ps_performance_df
from fuzzytools.latex.latex_tables import LatexTable
import ipywidgets as widgets

metric_names = [
    'precision',
    'recall',
    'f1score',
    'aucroc',
    'aucpr',
    ]

def interact_f(uses_avg, dict_name, target_class):
    bypass_prob = 0.0
    ds_prob = 0.1
    new_model_names = []
    for model_name in model_names:
        if not 'pb=.' in model_name:
            continue
        if f'bypass_synth=0~bypass_prob={bypass_prob}~ds_prob={ds_prob}' in model_name:
            new_model_names += [model_name]
    info_df = get_ps_performance_df(rootdir, cfilename, kf, set_name, new_model_names, metric_names,
        uses_avg=uses_avg,
        dict_name=dict_name,
        target_class=target_class,
        #'override_model_name':False, # False True
        )
    for k in range(0, len(info_df)):
        info_df.indexs[k] = info_df.indexs[k].replace('=', '***')
        info_df.indexs[k] = info_df.indexs[k].replace('Model***', 'Model=')
    display(info_df())

    latex_table = LatexTable(info_df(),
        centered=True,
        repr_replace_dict={
            '***':'=',
            '-999.000±.000':'--',
            },
        )
    print(latex_table)

widgets.interact(interact_f,
    uses_avg=False,
    dict_name=['thdays_class_metrics', 'thdays_class_metrics_all_bands'],
    target_class=[None, ' SNIbc', 'SNIIbn', 'SNIa', 'SLSN'],
    )

In [None]:
%load_ext autoreload
%autoreload 2
from fuzzytools.datascience import statistical_tests as statistical_tests
from lcclassifier import _C

def get_dict(info_df, metric_name, str_to_ignore, ds_prob):
    df = info_df.get_df()
    df = df[[c for c in df.columns if _C.METRICS_D[metric_name]['mn'] in c]]
    values_dict = df.to_dict()
    values_dict = values_dict[list(values_dict.keys())[0]]
    new_values_dict = {}
    for k in values_dict.keys():
        if 'BRF' in k:
            new_values_dict[k] = values_dict[k]
        else:
            if str_to_ignore in k:
                continue
            if f'bypass_synth***1' in k:
                continue
            if f'ds_prob***{ds_prob}' in k:
                new_values_dict[k] = values_dict[k]
            
    return new_values_dict

metric_names = [
    'precision',
    'recall',
    'f1score',
    'aucroc',
    'aucpr',
    ]

# Pairwise test passed to gridtest_greater (options in statistical_tests:
# ttest, welchtest, permutationtest). Loop-invariant, so chosen once here.
test = statistical_tests.permutationtest

# info_df must come from the performance-table cell; fail with a clear message
# rather than a bare NameError when that table has not been built yet.
if 'info_df' not in globals():
    raise NameError(
        'info_df is not defined: run the cell that builds the performance table first.'
    )

# Alternative configuration kept for reference:
#   ds_prob = 0.1 with str_to_ignore in ['RNN']
# NOTE(review): the performance-table cell keeps only ds_prob=0.1 models, so with
# ds_prob = 0.0 get_dict may keep only 'BRF' rows — confirm which table is intended.
# NOTE(review): the labels have '=' masked as '***' (except 'Model='), so the
# '='-containing str_to_ignore values below may never match — confirm.
ds_prob = 0.0
for str_to_ignore in ['Models=P', 'Models=S']:
    for metric_name in metric_names:
        print(f'str_to_ignore={str_to_ignore}; metric_name={metric_name}; ds_prob={ds_prob}')
        values_dict = get_dict(info_df, metric_name, str_to_ignore, ds_prob)
        df = statistical_tests.gridtest_greater(values_dict, test,
            test_kwargs={'num_rounds': 100_000},  # integer round count (10_000 also used); was the float 1e5
            th_pvalue_txt=0,
            n_decimals=4,
            check_samples=False,
            )
        display(df)