In [None]:
"""
Look at runtimes on instances where Catfish fails
i.e. instances where k_catfish > k_opt
"""

import matplotlib.pyplot as plt
import numpy as np
import collections
import table_generator

import scipy.stats as scistat



def get_acc(true_pos, true_neg, false_pos, false_neg):
    """Return the accuracy for the given confusion-matrix counts.

    Accuracy is the number of correct predictions (true positives plus
    true negatives) divided by the total number of predictions.

    Raises ZeroDivisionError when all four counts are zero.
    """
    correct = true_pos + true_neg
    total = correct + false_pos + false_neg
    return correct / total

def effone(true_pos, true_neg, false_pos, false_neg):
    """Return (precision, recall, F1) for the given confusion-matrix counts.

    precision = true_pos / (true_pos + false_pos)
    recall    = true_pos / (true_pos + false_neg)
    F1        = harmonic mean of precision and recall

    Any metric whose denominator is zero is reported as 0.0 instead of
    raising ZeroDivisionError; F1 is likewise 0.0 whenever precision or
    recall is zero.

    `true_neg` is accepted so the signature parallels `get_acc`, but it
    does not enter any of the three metrics.
    """
    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg

    precision = true_pos / predicted_pos if predicted_pos else 0.0
    recall = true_pos / actual_pos if actual_pos else 0.0

    # The harmonic mean needs 1/precision and 1/recall, so it is only
    # defined when both are non-zero.
    if precision and recall:
        f1 = 2 / ((1 / precision) + (1 / recall))
    else:
        f1 = 0.0
    return precision, recall, f1



# CHANGE froot TO THE DATASET YOU WANT: zebra, human, mouse
# for froot in ['zebra', 'mouse', 'human']:
for froot in ['mouse']:
    timeoutexceed = 800
    which_alg = "toboggan"

    inputfile = "all-" + froot + ".txt"

    # [1] Get all dataset and toboggan info
    datadict, datamatrix, dict_cat_to_tob, dict_tob_to_cat = table_generator.make_tables(inputfile)
    info_dict = table_generator.get_toboggan_timing_info(datadict, datamatrix)
    nontrivials_dict = info_dict['nontrivials_dict']
    toboggan_completed = info_dict['toboggan_completed']
    toboggan_timeouts = info_dict['toboggan_timeouts']
    toboggan_num_paths_dict = info_dict['toboggan_num_paths_dict']

    catinputfile = 'catfish-' + froot + '-output.txt'
    catfish_dict, catfish_matrix = table_generator.get_catfish_tables('../catfish-comparison/', catinputfile)

    # [2] Now look at Toboggan runtime on instances where Catfish fails
    # First, find instances where k_catfish > k_opt
    _, _, catfish_numpaths, _ = table_generator.get_catfish_timing_info(catfish_dict, catfish_matrix)

    tob_on_cat_failures = {}
    catfish_fail_dict = {}
    # gap_info[d] counts the failing instances where catfish used d more
    # paths than toboggan.
    gap_info = collections.defaultdict(int)
    for key, val in catfish_dict.items():
        key_tob = dict_cat_to_tob[key]
        num_tob_paths = toboggan_num_paths_dict[key_tob]
        # Instances without a toboggan path count cannot be compared.
        if num_tob_paths is None:
            continue
        num_cat_paths = int(catfish_numpaths[key])
        num_tob_paths = int(num_tob_paths)
        if num_tob_paths < num_cat_paths:
            tob_on_cat_failures[key_tob] = toboggan_completed[key_tob]
            catfish_fail_dict[key] = val
            gap_info[num_cat_paths - num_tob_paths] += 1

    toboggan_times = list(tob_on_cat_failures.values())
    # First return value is used below as the collection of catfish runtimes
    # on the failing instances (fed to the t-test and np.mean).
    catfish_fail_times, _, _, _ = table_generator.get_catfish_timing_info(catfish_fail_dict, catfish_matrix)

    print(froot)
    print(gap_info)
    print("total instances that catfish fails: {}".format(sum(gap_info.values())))
    print("")

    # [3] Now check how similar catfish runtimes are from non-optimal to optimal
    # intersect catfish keys with the keys toboggan completed
    nontrivial_catfish = {
        key: val
        for key, val in catfish_dict.items()
        if dict_cat_to_tob[key] in toboggan_completed
    }
    catfish_nontrivial_times, _, _, _ = table_generator.get_catfish_timing_info(nontrivial_catfish, catfish_matrix)

    # Welch's t-test (equal_var=False) comparing the two runtime samples.
    # https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.stats.ttest_ind.html
    ttest_stat, pvalue = scistat.ttest_ind(catfish_fail_times,
                                           catfish_nontrivial_times,
                                           axis=0,
                                           equal_var=False,
                                           nan_policy='propagate')
    print("T-test stat: {:5.8f}\n"
          "and pvalue:  {:5.8f}\n".format(ttest_stat, pvalue))
    print("Mean of catfish runtime on tough instances: {:5.5f}\n"
          "mean of catfish runtime on all nontrivials: {:5.5f}\n".format(np.mean(catfish_fail_times),
                                                                         np.mean(catfish_nontrivial_times)))

    # Now check how good a classifier runtime is:
    # an instance is predicted to be a catfish failure when its catfish
    # runtime is at least the midpoint of the two mean runtimes.
    catfishlongtime = np.mean(catfish_fail_times)
    catfishshorttime = np.mean(catfish_nontrivial_times)
    catfishthreshold = np.mean([catfishlongtime, catfishshorttime])
    # catfishthreshold = np.sqrt(catfishlongtime * catfishshorttime)

    _, catfish_times_dict, catfish_paths_dict, _ = table_generator.get_catfish_timing_info(nontrivial_catfish,
                                                                                           catfish_matrix)

    true_pos = 0
    false_pos = 0
    true_neg = 0
    false_neg = 0
    none_counter = 0

    for key in nontrivial_catfish:
        key_toboggan = dict_cat_to_tob[key]
        num_tob_paths = toboggan_num_paths_dict[key_toboggan]

        if num_tob_paths is None:
            none_counter += 1
            continue
        num_tob_paths = int(num_tob_paths)

        # Convert exactly as in step [2] so both counts are compared as
        # ints (presumably the table values may be strings -- confirm
        # against table_generator).
        num_cat_paths = int(catfish_paths_dict[key])
        cat_time = catfish_times_dict[key]

        catfish_failed = num_tob_paths < num_cat_paths
        predicted_fail = cat_time >= catfishthreshold
        if catfish_failed and predicted_fail:
            true_pos += 1
        elif catfish_failed:
            false_neg += 1
        elif predicted_fail:
            false_pos += 1
        else:
            true_neg += 1

    print("Number of Nones is: {}".format(none_counter))
    print("Acc is: {:5.5f}".format(get_acc(true_pos, true_neg, false_pos, false_neg)))
    # Store the score under its own name: assigning it to `effone` would
    # replace the function and break the next dataset iteration.
    prec, recall, f1_score = effone(true_pos, true_neg, false_pos, false_neg)
    print("F1, prec, rec: {:5.5f}, {:5.5f}, {:5.5f}".format(f1_score, prec, recall))
    print("truepos  falsepos  trueneg  falseneg")
    print([true_pos, false_pos, true_neg, false_neg])