In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from modules.DS_models_stat import cut_cat, make_histogram, simple_recall, calc_corr_b
from modules.DS_data_transformation import inter_cats
from tqdm.notebook import tqdm
%config InlineBackend.figure_format = 'retina'

In [2]:
def recall_mz(z1=False, corr=False):
    """Plot binned recall of several cluster catalogs against z, M500 and flux.

    Draws a 1x3 figure with a shared y axis:

    * left   -- recall w.r.t. the PSZ2(z)+MCXC+ACT united catalog, binned in ``z``;
    * middle -- recall w.r.t. the same catalog, binned in ``M500`` (log x axis);
    * right  -- recall w.r.t. eROSITA, binned in ``flux`` (log x axis).

    Catalogs are read from CSV files under ``./Data`` (relative to the current
    working directory). Per-bin values come from the project helpers
    ``simple_recall`` and, optionally, ``calc_corr_b``.

    Parameters
    ----------
    z1 : bool, default False
        If True, restrict the PSZ2(z)+MCXC+ACT reference catalog used by the
        ``z`` and ``M500`` panels to ``z`` in [0, 1] via ``cut_cat``.
    corr : bool, default False
        If True, multiply each non-zero binned recall by ``calc_corr_b(dc, tc)``.

    Returns
    -------
    None
        The figure is created as a side effect.
    """
    from itertools import cycle

    cache = {}

    def load(path):
        """Read ``path`` once per ``recall_mz`` call; return an independent copy each time."""
        if path not in cache:
            cache[path] = pd.read_csv(path)
        return cache[path].copy()

    def make_pic(ax, true_cat, det_cats, bins, prm, add_text='', switch=False,
                 lsts=('-', '--', '-.', ':')):
        """Draw one recall-vs-``prm`` curve per entry of ``det_cats`` on ``ax``.

        ``bins`` holds bin edges; bin i is the half-open interval
        [bins[i], bins[i + 1]) in column ``prm``. With ``switch=False`` the
        reference ``true_cat`` is binned and compared with each whole detected
        catalog. With ``switch=True`` each detected catalog is binned and
        compared with the whole ``true_cat``. Line styles in ``lsts`` are
        cycled, so no catalog is dropped when there are more catalogs than styles.
        """
        edges = np.asarray(bins, dtype=float)
        x = (edges[:-1] + edges[1:]) / 2
        styled = zip(det_cats.items(), cycle(lsts))
        for (name, det_cat), lst in tqdm(styled, total=len(det_cats)):
            y = []
            for st, en in zip(edges[:-1], edges[1:]):
                src = det_cat if switch else true_cat
                binned = src[src[prm].ge(st) & src[prm].lt(en)]
                tc, dc = (true_cat, binned) if switch else (binned, det_cat)
                if len(binned) == 0:
                    # Nothing to evaluate in this bin, so recall is undefined. NaN
                    # leaves a gap instead of a misleading drop to 0 (e.g. z > 1 when z1=True).
                    y.append(np.nan)
                elif len(tc) == 0 or len(dc) == 0:
                    y.append(0)
                elif corr:
                    rec = simple_recall(dc, tc)
                    # The correction factor is only evaluated for a non-zero recall.
                    corr_c = calc_corr_b(dc, tc) if rec > 0 else 1
                    y.append(rec * corr_c)
                else:
                    y.append(simple_recall(dc, tc))
            ax.plot(x, y, ls=lst, label=name + add_text)

    united_path = './Data/united_cats/PSZ2(z)_MCXC_ACT_united.csv'
    mcxc_path = './Data/original_cats/MCXC.csv'

    def sz_cats():
        """Fresh copies of the two SZ-detected catalogs, in plotting order."""
        return {'SZcat(AL)': load('./Data/detected_cats/SZcatAL.csv'),
                'SZcat(gen)': load('./Data/detected_cats/SZcatgen.csv')}

    erosita = load('./Data/original_cats/eROSITA.csv')
    reference = load(united_path)
    if z1:
        reference = cut_cat(reference, dict_cut={'z': [0, 1]})

    fig, ax = plt.subplots(1, 3, figsize=(7 * 3, 5), sharey=True)

    # z and M500 panels: recall w.r.t. PSZ2(z)+MCXC+ACT (both panels share the same catalogs).
    zm_cats = {**sz_cats(), 'PSZ2': load('./Data/original_cats/PSZ2.csv')}
    zm_label = ' recall PSZ2(z)+MCXC+ACT'
    make_pic(ax[0], reference, zm_cats, prm='z',
             bins=np.arange(0, 1.4, 0.1), add_text=zm_label)
    make_pic(ax[1], reference, zm_cats, prm='M500',
             bins=2.0 ** np.arange(0, 4, 0.2), add_text=zm_label)

    # flux panel: recall w.r.t. eROSITA.
    flux_cats = {
        **sz_cats(),
        'MCXC': inter_cats(erosita.copy(), load(mcxc_path)),
        'PSZ2(z)+MCXC+ACT': pd.concat([load(mcxc_path),
                                       load('./Data/original_cats/ACT.csv'),
                                       load(united_path)]),
    }
    # Log-spaced edges over the positive eROSITA fluxes (non-positive values have
    # no log). One extra step pushes the last edge past the maximum, because
    # np.arange excludes its stop value and would otherwise drop the brightest objects.
    log_step = 0.8
    flux = erosita['flux']
    flux = flux[flux > 0]
    log_lo, log_hi = np.log(flux.min()), np.log(flux.max())
    flux_bins = np.exp(np.arange(log_lo, log_hi + log_step, log_step))
    make_pic(ax[2], erosita, flux_cats, prm='flux', bins=flux_bins,
             add_text=' recall eROSITA')

    for a in ax:
        a.grid(True, axis='both', which='major', linestyle=':')
        a.grid(True, axis='both', which='minor', alpha=0.2, linestyle=':')
        a.legend()
    ax[0].set_ylabel('recall')  # the y axis is shared by all three panels
    for a, label in zip(ax, ('z', 'M500', 'flux')):
        a.set_xlabel(label)
    ax[1].set_xscale('log')
    ax[2].set_xscale('log')
    # Fixed lower display limit; the upper limit is the largest 'flux' in the MCXC inter_cats result.
    ax[2].set_xlim(1e-14, flux_cats['MCXC']['flux'].max())
    fig.tight_layout()
    


In [None]:
# Default run: full redshift range of the reference catalog, uncorrected recall.
recall_mz(z1=False, corr=False)