# <center>Graphs for article (Try 3) (COPY 2)</center>

(1.02.21) Отличие от оригинала: в выборке тестовых квазаров DR16q взяты только объекты с IS_QSO_FINAL == 1, согласно рекомендациям составителей каталога

(5.02.21) Отличие от COPY 1: метрики сравниваем только на объектах у которых есть вся фотометрия

(13.02.21) Графики для DR16q строим только по `SOURCE_Z == 'VI' and Z_CONF == 3`

(23.03.21) Наконец-то!!! Предфинальная версия графиков! Удалены предыдущие дубли

In [None]:
import glob
import itertools
import json
import os
import pickle
import re
import sys
import multiprocessing
import functools
import warnings
from pprint import pprint
from collections import defaultdict
from time import sleep

import IPython.core.display as ipd
import matplotlib as mpl
import matplotlib.collections as mplc
import matplotlib.patches as patches
import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import astropy.table
from astropy.table import Table
from scipy.stats import gaussian_kde
from tqdm.notebook import tqdm, trange
from sklearn.model_selection import train_test_split

sys.path.append('/home/victor/scripts/')
import srgpz

# Notebook-wide display and plotting defaults.
np.set_printoptions(precision=3)
# Registers tqdm's progress-bar variants of pandas `apply` (e.g. `progress_apply`).
tqdm.pandas()
sns.set('talk', 'whitegrid', 'deep', font_scale=1.0,  # font='Ricty',
        rc={"lines.linewidth": 2, 'grid.linestyle': '--'})
pd.set_option('display.max_rows', 400, 'display.max_columns', None, 'display.max_colwidth', 100)
plt.rcParams.update({'font.size': 24})
# NOTE: this silences every warning for the rest of the session, deprecation warnings included.
warnings.filterwarnings('ignore')

# IPython magics: draw figures inline and request the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Functions

In [None]:
def add_colorbar(fig, vmin, vmax, label='', cmap='magma'):
    """Attach a stand-alone vertical colorbar to the right of ``fig``.

    The subplots are shrunk to leave room (``right=0.8``), a dedicated axes is
    added at a fixed figure position, and a colorbar spanning ``vmin``..``vmax``
    in colormap ``cmap`` is drawn into it. ``label`` becomes the title of the
    colorbar axes. Nothing is returned.
    """
    fig.subplots_adjust(right=0.8)

    colorbar_axes = fig.add_axes([1.01, 0.05, 0.025, 0.9])
    colorbar_axes.set_title(label)

    # The mappable is not tied to any plotted artist, hence the empty array.
    mappable = plt.cm.ScalarMappable(
        cmap=cmap,
        norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax),
    )
    mappable.set_array([])
    fig.colorbar(mappable, cax=colorbar_axes)
    

def scatter_photo_z(data, col_true, col_pred, col_conf=None,
                    xlim=None, ylim=None, figsize_factor=1,
                    cmap='rainbow', title='', marker='o', alpha=1, ax=None, s=1.0):
    """Draw a spec-z vs photo-z scatter plot.

    Parameters
    ----------
    data : table indexable by column name (``data[col]``).
    col_true, col_pred : columns plotted on the x and y axes.
    col_conf : optional column; when given, it is cast to float and drives
        both the point colour and the point size
        (``50 * conf ** 1.5 * figsize_factor``). ``alpha`` and ``s`` are not
        used in that case.
    xlim, ylim : axis limits. When both are given and no ``ax`` is passed,
        the new figure is sized as the axis extent times ``figsize_factor``;
        otherwise a 10x10 figure is used.
    ax : existing axes to draw on; a new figure is created when ``None``.

    A dashed one-to-one line and a shaded band between the lines
    ``y = 0.85 x - 0.15`` and ``y = 1.15 x + 0.15`` are overlaid.
    Nothing is returned.
    """
    figsize = (
        ((xlim[1] - xlim[0]) * figsize_factor, (ylim[1] - ylim[0]) * figsize_factor)
        if xlim and ylim
        else (10, 10)
    )
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    ax.set(
        title=title,
        xlabel='spec-z',
        ylabel='photo-z',
        xticks=np.linspace(0, 7, 15),
        yticks=np.linspace(0, 7, 15),
        xlim=xlim,
        ylim=ylim,
    )
    ax.tick_params(axis='x', rotation=90)

    if col_conf is None:
        ax.scatter(
            data[col_true], data[col_pred],
            cmap=cmap, marker=marker, vmin=0, vmax=1, alpha=alpha, s=s
        )
    else:
        confidence = data[col_conf].astype(float)
        ax.scatter(
            data[col_true], data[col_pred],
            c=confidence, s=50 * confidence ** (3 / 2) * figsize_factor,
            cmap=cmap, marker=marker, vmin=0, vmax=1
        )

    # Shaded band around the diagonal, then the diagonal itself.
    ax.fill_between([0, 99], [-0.15, 84], [0.15, 114], color='black', alpha=0.05)
    ax.plot([0, 7], [0, 7], c='black', linestyle='--')

    ax.grid(linestyle='--')
    

def pred_cols_format(mid):
    """Return the (prediction, confidence) column names for model id ``mid``."""
    prediction_col = f'zoo_{mid}_z_max'
    confidence_col = f'zoo_{mid}_z_maxConf'
    return prediction_col, confidence_col


# def metrics_by_bins(data, pred_col, true_col, bins_col, metric, bins, ax=None,
#                     diff=None, print_bins=False, **kwargs):
        
#     data = data.sort_values(by=bins_col)
#     x = list()
#     y = list()
    
#     if isinstance(bins, int):
#         bins_borders = list()
#         for bin_num in range(bins):
#             idx_start = bin_num * len(data) // bins
#             idx_end = (bin_num + 1) * len(data) // bins
#             bin_data = data.iloc[idx_start: idx_end]
#             bins_borders.append(bin_data[bins_col].min())
        
#         bins_borders.append(bin_data[bins_col].max())
    
#     else:
#         bins_borders = bins
        
#     if print_bins:
#         for v in bins_borders:
#             print(f'{v:.3f},')
        
#     for start, end in zip(bins_borders[:-1], bins_borders[1:]):
#         bin_data = data.loc[(start <= data[bins_col]) & (data[bins_col] < end)]
# #         print(len(bin_data))
#         x.append(bin_data[bins_col].mean())
#         y.append(metric(bin_data[pred_col], bin_data[true_col]))

#     if diff is not None and len(diff) == bins:
#         y = np.array(y) - diff
        
#     if ax is None:
#         l = None
#     else:
#         l, = ax.plot(x, y, **kwargs)
        
#     return l, x, y


def metrics_by_bins(data, bins_col, metric, bins, ax=None,
                    diff=None, print_bins=False, override_first_point=None, **kwargs):
    """Evaluate ``metric`` in bins of ``bins_col`` and optionally plot the result.

    Parameters
    ----------
    data : pandas.DataFrame
        Input table; it is sorted by ``bins_col`` internally (the caller's
        frame is not modified).
    bins_col : str
        Column used for binning.
    metric : callable
        Called as ``metric(bin_frame)`` for every bin.
    bins : int or sequence
        An integer requests that many bins holding (almost) equal numbers of
        rows; the borders are the minimum of each row chunk plus the overall
        maximum. A sequence is used directly as the list of borders.
    ax : matplotlib axes, optional
        When given, ``x`` vs ``y`` is drawn with ``ax.plot(x, y, **kwargs)``.
    diff : sequence, optional
        Subtracted from ``y`` when its length equals the number of bins;
        otherwise it is ignored.
    print_bins : bool
        Print the borders, one per line.
    override_first_point : optional
        Replaces the x coordinate of the first bin.

    Returns
    -------
    (line, x, y, bins_borders)
        ``line`` is the plotted line or ``None``; ``x`` holds the mean of
        ``bins_col`` per bin, ``y`` the metric values.

    Notes
    -----
    Bins are half-open, ``start <= value < end``. In equal-count mode the last
    border is the maximum of the column, so the last bin is closed on the
    right; otherwise the rows carrying the maximum would fall into no bin.
    """
    data = data.sort_values(by=bins_col)
    values = data[bins_col]

    equal_count = isinstance(bins, (int, np.integer))
    if equal_count:
        if bins < 1:
            raise ValueError('bins must be a positive integer or a sequence of borders')
        n_rows = len(data)
        bins_borders = [
            values.iloc[i * n_rows // bins: (i + 1) * n_rows // bins].min()
            for i in range(bins)
        ]
        # The data are sorted, so the maximum of the last chunk is the overall maximum.
        bins_borders.append(values.iloc[(bins - 1) * n_rows // bins:].max())
    else:
        bins_borders = bins

    if print_bins:
        for v in bins_borders:
            print(f'{v:.3f},')

    x = []
    y = []
    last_bin = len(bins_borders) - 2
    for i, (start, end) in enumerate(zip(bins_borders[:-1], bins_borders[1:])):
        in_bin = (start <= values) & (values < end)
        if equal_count and i == last_bin:
            # Keep the rows sitting exactly on the upper border.
            in_bin |= (values == end)
        bin_data = data.loc[in_bin]
        x.append(bin_data[bins_col].mean())
        y.append(metric(bin_data))

    # Compare against the number of bins, not against `bins` itself, so that
    # `diff` also works when explicit borders were passed.
    if diff is not None and len(diff) == len(y):
        y = np.array(y) - diff

    if override_first_point is not None and x:
        x[0] = override_first_point

    if ax is None:
        line = None
    else:
        line, = ax.plot(x, y, **kwargs)

    return line, x, y, bins_borders


def sklearn2singlearg(metric, *cols):
    """Adapt a multi-argument metric to the single-table calling convention.

    The returned callable takes one table ``df`` and evaluates
    ``metric(df[cols[0]], df[cols[1]], ...)``.
    """
    def apply_to_frame(df):
        return metric(*[df[name] for name in cols])

    return apply_to_frame


def make_table(data, true_col, col, bins, ids, names, dataset_name, col_names):
    """Build a per-bin quality table for a set of models.

    For every ``(bin_start, bin_end)`` pair in ``bins`` the rows with
    ``bin_start < data[col] <= bin_end`` are selected, and for every model id
    in ``ids`` a cell ``'<nmad> / <catastrophic outlier fraction>'`` is
    computed with ``srgpz.metrics`` from ``zoo_<id>_z_max`` against
    ``true_col``, using only rows where the prediction is not NaN. Models
    whose prediction column is missing from ``data`` get ``'- / -'``.

    Column titles are derived from ``col_names[col]``; the sentinels ``-99``
    and ``99`` mark an open lower / upper edge, and the pair ``(-99, 99)`` is
    titled ``'All Objects'``.

    The returned frame has one row per model plus a final ``'nobjs'`` row with
    the number of objects in each bin.
    """
    def column_title(lower, upper):
        label = col_names[col]
        if lower == -99:
            return 'All Objects' if upper == 99 else '{} < {}'.format(label, upper)
        if upper == 99:
            return '{} < {}'.format(lower, label)
        return '{} < {} < {}'.format(lower, label, upper)

    pred_cols = [f'zoo_{mid}_z_max' for mid in ids]

    model_rows = defaultdict(list)
    model_rows['Model'] = [names[mid] for mid in ids]
    model_rows['ModelID'] = [mid for mid in ids]
    model_rows['Dataset'] = [dataset_name for _ in ids]

    count_row = defaultdict(list)
    count_row['Model'].append('nobjs')

    for bin_start, bin_end in bins:
        in_bin = data.loc[(bin_start < data[col]) & (data[col] <= bin_end)]
        title = column_title(bin_start, bin_end)

        for pred_col in pred_cols:
            if pred_col not in data.columns:
                model_rows[title].append('- / -')
                continue

            scored = in_bin.loc[in_bin[pred_col].notna()]
            pair = (scored[pred_col], scored[true_col])
            nmad = srgpz.metrics.nmad_z(*pair)
            catout = srgpz.metrics.catastrophic_outliers_z(*pair)
            model_rows[title].append(f'{nmad:.3f} / {catout:.3f}')

        count_row[title].append(len(in_bin))

    count_row['Dataset'].append(dataset_name)
    return pd.concat(
        [pd.DataFrame.from_dict(model_rows), pd.DataFrame.from_dict(count_row)],
        sort=False,
    )


def cleanup_make_table(table: pd.DataFrame):
    """Order and prune a table produced by ``make_table`` for presentation.

    Rows are sorted by ``ModelID``, then the two literature models are moved
    below the others, then the ``'nobjs'`` row is moved to the bottom. Rows
    whose last data column equals ``'- / -'`` are removed, and the helper
    column together with ``ModelID`` is dropped. The input frame is left
    unchanged.
    """
    literature_models = ['ANN (Brescia, 2019)', 'Templates (Ananna, 2017)']
    ordered = table.sort_values(by='ModelID').reset_index(drop=True)

    # Each pass sinks the flagged rows below the unflagged ones.
    for sinking_group in (literature_models, ['nobjs']):
        ordered['_litr'] = ordered['Model'].isin(sinking_group)
        ordered = ordered.sort_values(by='_litr')

    # '_litr' is the last column here, so columns[-2] is the last data column.
    last_data_col = ordered.columns[-2]
    ordered['_litr'] = ordered[last_data_col].isin(['- / -'])
    kept = ordered[~ordered['_litr']]
    return kept.drop(columns=['_litr', 'ModelID'])


def choose_best(table, columns):
    r"""Highlight the best (smallest) values in ``'a / b'`` cells with ``\textbf``.

    For each column in ``columns`` the cells of the form ``'<a> / <b>'`` are
    parsed as two floats. The minimum of the first and of the second number is
    found over the parsed rows, and every cell is rewritten with three
    decimals, wrapping values equal to the respective minimum in
    ``\textbf{...}`` (ties are all highlighted). Cells that do not match the
    pattern, or whose first part is exactly ``'-'``, are left untouched.

    A column without any parsable cell is skipped instead of raising.

    ``table`` is modified in place and also returned.
    """
    cell_pattern = re.compile(r'(.*) / (.*)')

    for col in columns:
        parsed = [cell_pattern.findall(str(cell)) for cell in table[col].values]
        mask = [bool(found) and found[0][0] != '-' for found in parsed]
        if not any(mask):
            # Nothing numeric in this column: no minimum to take.
            continue

        scores = np.array(
            [found[0] for found, keep in zip(parsed, mask) if keep]
        ).astype(float)
        best = scores.min(axis=0)

        formatted = [
            ' / '.join(
                f'\\textbf{{{value:.3f}}}' if value == best_value else f'{value:.3f}'
                for value, best_value in zip(row, best)
            )
            for row in scores
        ]
        table.loc[mask, col] = formatted

    return table


def latex_from_table(df: pd.DataFrame, title='', star=False, index=False, newline=9,
                     hline_every=None, hline_every_shift=4):
    r"""Render ``df`` as one or more LaTeX ``table`` environments.

    The frame is split into groups of at most ``newline`` columns; each group
    is converted with ``DataFrame.to_latex(index=index, escape=False)`` and
    wrapped in ``\begin{table}`` / ``\end{table}`` (``table*`` when ``star``)
    with ``title`` as the caption. The booktabs rules produced by pandas are
    replaced by ``\hline``, every ``%`` is escaped, and continuation lines are
    indented by 12 spaces.

    When ``hline_every`` is given, extra ``\hline`` lines are inserted into
    the final text: line indices are visited backwards from
    ``len(lines) - hline_every_shift`` in steps of ``hline_every``, and a rule
    is inserted before every visited index except the starting one and those
    smaller than ``hline_every``.

    Returns the LaTeX source as a single string.
    """
    environment = 'table*' if star else 'table'
    prefix = f'\\begin{{{environment}}}\n\t'
    suffix = f'\\caption{{{title}}}\n\\end{{{environment}}}'

    total_latex = ''
    n_cols = df.shape[1]
    for start in range(0, n_cols, newline):
        end = min(start + newline, n_cols)
        latex_code = df.iloc[:, start:end].to_latex(index=index, escape=False)
        for rule in ('\\toprule', '\\midrule', '\\bottomrule'):
            latex_code = latex_code.replace(rule, '\\hline')
        latex_code = latex_code.replace('%', '\\%')
        latex_code = latex_code.replace('\n', '\n' + ' ' * 12)
        total_latex += prefix + latex_code + suffix + '\n\n'

    latex_list = total_latex.splitlines()
    if hline_every is not None:
        start = len(latex_list) - hline_every_shift
        for i in range(start, 0, -hline_every):
            if i == start or i < hline_every:
                continue

            latex_list.insert(i, '\\hline')

    return '\n'.join(latex_list)


def print_header(s, lvl=3):
    """Display ``s`` in the notebook output as an HTML heading of level ``lvl``."""
    heading = ipd.HTML(f'<h{lvl}>{s}</h{lvl}>')
    ipd.display(heading)
    
    
def dz_norm(z_pred, z_true):
    """Return the redshift residual ``z_pred - z_true`` divided by ``1 + z_true``."""
    residual = z_pred - z_true
    scale = 1 + z_true
    return residual / scale
    
    
def additional_metric(z_pred, z_conf, z_true):
    """Fraction of objects that are outliers or have low confidence.

    An object is counted when ``|dz_norm| > 0.15`` or ``z_conf < 0.4``; the
    count is divided by ``len(z_true)``.
    """
    is_outlier = np.abs(dz_norm(z_pred, z_true)) > 0.15
    is_low_confidence = z_conf < 0.4
    n_flagged = (is_outlier | is_low_confidence).sum()
    return n_flagged / len(z_true)


def additional_metric2(z_pred, z_conf, z_true):
    """Fraction of objects that are outliers, low-confidence, or lack a prediction.

    An object is counted when ``|dz_norm| > 0.15``, or ``z_conf < 0.4``, or
    ``z_pred`` is NaN; the count is divided by ``len(z_true)``.

    ``z_pred`` must provide ``.isna()`` (e.g. a pandas Series).

    The outlier criterion uses the absolute normalised residual, matching
    ``additional_metric``; the signed comparison used previously counted only
    over-estimated redshifts.
    """
    dz_criterion = np.abs(dz_norm(z_pred, z_true)) > 0.15
    z_conf_criterion = z_conf < 0.4
    # NaN residuals compare False above, so missing predictions are flagged explicitly.
    wo_phot_criterion = z_pred.isna()
    total = dz_criterion | z_conf_criterion | wo_phot_criterion
    return total.sum() / len(z_true)


def wo_phot(z_pred):
    """Return the fraction of NaN entries in ``z_pred``."""
    n_missing = z_pred.isna().sum()
    return n_missing / len(z_pred)


def kstest(z_pred_a, z_pred_b, z_true):
    """Return the fraction of objects with ``z_pred_a <= z_true < z_pred_b``."""
    above_lower = z_pred_a <= z_true
    below_upper = z_true < z_pred_b
    n_inside = (above_lower & below_upper).sum()
    return n_inside / len(z_true)

In [None]:
def ab_mag(flux: pd.Series) -> pd.Series:
    """Convert fluxes to magnitudes as ``22.5 - 2.5 * log10(flux)``.

    NOTE(review): the 22.5 zero point presumably means the fluxes are in
    nanomaggies; confirm against the catalogue documentation.
    """
    log_flux = np.log10(flux)
    return 22.5 - 2.5 * log_flux

# DR16Q \ Train sample (DR16Q objects that are not in the training sample)

In [None]:
# Load the training sample, keep its first 30 columns, tag every row with a running
# index (`crosscorr_ix`) and save both the trimmed table and a coordinates-only FITS file.
train = pd.read_pickle(
    '/data/victor/srgz_prod/data/01_train_QSO_XbalancedGALAXY-sdss_unwise-wo_3XMM_XXLN_S82X_LH-with_VHzQs-peaks_cut-ebv_new.gz_pkl',
    compression='gzip')

train = train.iloc[:, :30]
# Row identifier; the cross-match cells further down read it back as `crosscorr_ix_input`.
train['crosscorr_ix'] = np.arange(len(train))
train.to_pickle('/data/victor/graphs4article/data/01_train_QSO_XbalancedGALAXY-sdss_unwise-wo_3XMM_XXLN_S82X_LH-with_VHzQs-peaks_cut-ebv_new.gz_pkl',
    compression='gzip', protocol=4)

# Only the coordinates and the row identifier go into the cross-match input.
train = train[['ra', 'dec', 'crosscorr_ix']]
train = Table.from_pandas(train)

train.write('/data/victor/graphs4article/data/01_coords.fits', overwrite=True)
train

In [None]:
# Same preparation for DR16Q: add a running row index, save the full table under the
# working directory, then save a coordinates-only table for the cross-match.
dr16q = Table.read('/data/victor/raw_data/DR16Q_v4.fits')
dr16q['crosscorr_ix'] = np.arange(len(dr16q))
dr16q.write('/data/victor/graphs4article/data/DR16Q_v4.fits', overwrite=True)

dr16q = dr16q[['RA', 'DEC', 'crosscorr_ix']]
dr16q.write('/data/victor/graphs4article/data/DR16Q_v4_coords.fits', overwrite=True)
dr16q

In [None]:
# commands I used for cross-match
# Both calls run the external `getaroundr.py` script on the coordinate files written
# by the two cells above, against the `sdss_second` catalogue.
# NOTE(review): `-r 1.0` is presumably the match radius in arcsec; confirm against the script.
# NOTE(review): `-asfx _input` presumably suffixes the input columns, which would explain
# why the results are read back as `crosscorr_ix_input`; confirm.

!/home/horungev/Catalogs/SRG/crossmatch/getaroundr.py \
    -i /data/victor/graphs4article/data/01_coords.fits \
    -o /data/victor/graphs4article/data/01_sdss.fits \
    -r 1.0 -cat sdss_second \
    -asfx _input -iRA ra -iDEC dec -iSEPNAME sep:sep1 -full
    
!/home/horungev/Catalogs/SRG/crossmatch/getaroundr.py \
    -i /data/victor/graphs4article/data/DR16Q_v4_coords.fits \
    -o /data/victor/graphs4article/data/DR16Q_v4_sdss.fits \
    -r 1.0 -cat sdss_second \
    -asfx _input -iRA RA -iDEC DEC -iSEPNAME sep:sep1 -full

In [None]:
# Show 10 decimal places when displaying floating-point values in the tables below.
pd.set_option('display.precision', 10)

In [None]:
# Cross-match result for the training sample: keep the row identifier and the SDSS OBJID.
train_sdss = Table.read('/data/victor/graphs4article/data/01_sdss.fits')
train_sdss = train_sdss[['crosscorr_ix_input', 'OBJID']].to_pandas()
ipd.display(train_sdss.head())
# Displayed: table shape and the number of rows without an OBJID.
train_sdss.shape, train_sdss['OBJID'].isna().sum()

In [None]:
# Cross-match result for DR16Q: keep the row identifier and the SDSS OBJID.
dr16q_sdss = Table.read('/data/victor/graphs4article/data/DR16Q_v4_sdss.fits')
dr16q_sdss = dr16q_sdss[['crosscorr_ix_input', 'OBJID']].to_pandas()
ipd.display(dr16q_sdss.head())
# Displayed: table shape and the number of rows without an OBJID.
dr16q_sdss.shape, dr16q_sdss['OBJID'].isna().sum()

In [None]:
# True for DR16Q cross-match rows whose OBJID does not occur in the training sample.
# NOTE(review): how rows with a missing OBJID behave here depends on the column dtype; verify.
mask = ~dr16q_sdss['OBJID'].isin(train_sdss['OBJID'])
mask.sum()

In [None]:
# Reduce the selection to one row per DR16Q object and keep only its identifier,
# renamed back to `crosscorr_ix` so that it can serve as the join key.
dr16q_not_in_train = dr16q_sdss.loc[mask]
dr16q_not_in_train = dr16q_not_in_train.drop_duplicates(subset='crosscorr_ix_input')
dr16q_not_in_train = dr16q_not_in_train.rename(columns={'crosscorr_ix_input': 'crosscorr_ix'})
print(dr16q_not_in_train.shape)
dr16q_not_in_train = Table.from_pandas(dr16q_not_in_train)[['crosscorr_ix']]
ipd.display(dr16q_not_in_train[:5])

In [None]:
# Reload the full DR16Q table (with `crosscorr_ix`) that was saved earlier.
dr16q = Table.read('/data/victor/graphs4article/data/DR16Q_v4.fits')
print(len(dr16q))
dr16q[:5]

In [None]:
# Keep only the DR16Q rows whose identifier is in `dr16q_not_in_train`
# (no `keys` given, so the join uses the columns the two tables have in common).
dr16q = astropy.table.join(left=dr16q, right=dr16q_not_in_train)

In [None]:
# Size and first rows of the joined table, for comparison with the full catalogue above.
print(len(dr16q))
dr16q[:5]

In [None]:
# Save the DR16Q objects that are not part of the training sample.
dr16q.write('/data/victor/graphs4article/data/DR16Q_v4-wo_01_train.fits', overwrite=True)

# validate

In [None]:
# Reload the original training sample for the visual check below.
train = pd.read_pickle(
    '/data/victor/srgz_prod/data/01_train_QSO_XbalancedGALAXY-sdss_unwise-wo_3XMM_XXLN_S82X_LH-with_VHzQs-peaks_cut-ebv_new.gz_pkl',
    compression='gzip')

In [None]:
# Reload the filtered DR16Q table written above and show its length.
dr16q = Table.read('/data/victor/graphs4article/data/DR16Q_v4-wo_01_train.fits')
len(dr16q)

In [None]:
# Only the coordinates are needed for the check; convert them to a pandas DataFrame.
dr16q = dr16q[['RA', 'DEC']].to_pandas()

In [None]:
# Visual check on a small sky patch (0 < RA < 0.1, 31 < Dec < 36): plot the filtered
# DR16Q objects and the training objects (crosses) on the same axes.
# x/y are passed by keyword: recent seaborn releases no longer accept them positionally.
mask_dr16q = (dr16q['RA'] > 0) & (dr16q['RA'] < 0.1) & (dr16q['DEC'] > 31) & (dr16q['DEC'] < 36)
sns.scatterplot(x=dr16q.loc[mask_dr16q, 'RA'], y=dr16q.loc[mask_dr16q, 'DEC'])

mask_train = (train['ra'] > 0) & (train['ra'] < 0.1) & (train['dec'] > 31) & (train['dec'] < 36)
sns.scatterplot(x=train.loc[mask_train, 'ra'], y=train.loc[mask_train, 'dec'], marker='x')

# Looks OK!

# Prepare other samples

In [None]:
# Prepare the Stripe82X-A17 test sample: save a trimmed copy (first 56 columns) and a
# second copy restricted by the quality mask `fl`.
os.makedirs('/data/victor/graphs4article/data/', exist_ok=True)

print('Stripe82X-A17')
data = pd.read_pickle('/data/victor/srgz_models/data/test-sdss_unwise-all30sec_S82X-A17-asinhmags.gz_pkl',
                      compression='gzip')
data = data.iloc[:, :56]
ipd.display(data.head())
print(data.shape)
data.to_pickle('/data/victor/graphs4article/data/test-sdss_unwise-all30sec_S82X-A17.gz_pkl',
               compression='gzip', protocol=4)

# Mask: both boolean flags set, QF < 2.1, and positive Fx, zspec, zph and zphML.
# NOTE(review): the meaning of these catalogue columns is not visible here; see the catalogue docs.
df = data
fl = df['cross-match']&df['secureO']&(df['QF']<2.1)&(df['Fx']>0)&(df['zspec']>0)&(df['zph']>0)&(df['zphML']>0)
data = data.loc[fl]
print(data.shape)
data.to_pickle('/data/victor/graphs4article/data/test-sdss_unwise-all30sec_S82X-A17-masked.gz_pkl',
               compression='gzip', protocol=4)

In [None]:
# Prepare the VHzQs *train* sample: keep the identification/coordinate columns and zspec.
# The label used to read 'VHzQs test' although this cell processes the train file.
print('VHzQs train')
data = pd.read_pickle('/data/victor/srgz_models/very_high_z_qso/train_VHzQs_lis8_r1s-wo_trQXbG_XXLN_S82X_LH.gz_pkl', compression='gzip')
# Explicit list instead of splitting a string on literal TAB characters, which
# editors and formatters can silently turn into spaces.
data = data[['nrow', 'ra_', 'dec_', '#na_', 'desig_', 'ra_hms_', 'dec_dms_', 'zspec']]
ipd.display(data.head())
print(data.shape)
data.to_pickle('/data/victor/graphs4article/data/train_VHzQs.gz_pkl', compression='gzip', protocol=4)

In [None]:
# Prepare the VHzQs test sample: keep the identification/coordinate columns and zspec.
print('VHzQs test')
data = pd.read_pickle('/data/victor/srgz_models/very_high_z_qso/test_VHzQs_lis8_r1s.gz_pkl', compression='gzip')
# Explicit list instead of splitting a string on literal TAB characters, which
# editors and formatters can silently turn into spaces.
data = data[['nrow', 'ra_', 'dec_', '#na_', 'desig_', 'ra_hms_', 'dec_dms_', 'zspec']]
ipd.display(data.head())
print(data.shape)
data.to_pickle('/data/victor/graphs4article/data/test_VHzQs.gz_pkl', compression='gzip', protocol=4)

# Run pzph1

In [None]:
# Run the external `pzph1dot1.py` pipeline on the three prepared samples. Each call writes
# to its own `--output` directory and requests model ids 19, 21, 22 and 35.
# The VHzQs run uses model series `x1exp0` with a custom models file; the other two use `x1a`.
# Stripe82X
!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/stripe82x-a17/ \
        --xrayCatalog /data/victor/graphs4article/data/test-sdss_unwise-all30sec_S82X-A17-masked.gz_pkl \
        --baseCatalog "ls" \
        --primaryRadius 1 \
        --xrayRaCol ra --xrayDecCol dec --njobs 18 \
        --modelsSeries x1a --modelsIds 19 21 22 35


# VHzQs test
!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/vhzqs_test/ \
        --xrayCatalog /data/victor/graphs4article/data/test_VHzQs.gz_pkl \
        --baseCatalog "ls" \
        --primaryRadius 1 \
        --xrayRaCol ra_ --xrayDecCol dec_ --njobs 18 \
        --modelsSeries x1exp0 --modelsIds 19 21 22 35 \
        --customModels /home/victor/pzphlib/custom_models.json


# DR16q wo train
!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/DR16Q_v4-wo_01_train \
        --xrayCatalog /data/victor/graphs4article/data/DR16Q_v4-wo_01_train.fits \
        --baseCatalog "ls" \
        --primaryRadius 1 --chunkSize 10000 \
        --xrayRaCol RA --xrayDecCol DEC --njobs 24 \
        --modelsSeries x1a --modelsIds 19 21 22 35

In [None]:
# Run the pipeline on the Stripe82X table-13 catalogue twice; the two calls differ only in
# `--baseCatalog` ("ls" vs "ps") and in the output directory.
!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/stripe82x-a17-table13_ls-base/ \
        --xrayCatalog /data/victor/graphs4article/data/stripe82x_J-ApJ-850-66-table13.gz_pkl \
        --baseCatalog "ls" \
        --primaryRadius 1 \
        --xrayRaCol RAcdeg --xrayDecCol DEcdeg --njobs 18 \
        --modelsSeries x1a --modelsIds 19 21 22 35


!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/stripe82x-a17-table13_ps-base/ \
        --xrayCatalog /data/victor/graphs4article/data/stripe82x_J-ApJ-850-66-table13.gz_pkl \
        --baseCatalog "ps" \
        --primaryRadius 1 \
        --xrayRaCol RAcdeg --xrayDecCol DEcdeg --njobs 18 \
        --modelsSeries x1a --modelsIds 19 21 22 35

In [None]:
# models with rodion photometry
# These runs use model series `x1apswf` (ids 19 and 34) with the custom models file and write
# into the same output directories as the `x1a` runs for these two samples.
!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/stripe82x-a17-table13_ls-base/ \
        --xrayCatalog /data/victor/graphs4article/data/stripe82x_J-ApJ-850-66-table13.gz_pkl \
        --baseCatalog "ls" \
        --primaryRadius 1 \
        --xrayRaCol RAcdeg --xrayDecCol DEcdeg --njobs 18 \
        --modelsSeries x1apswf --modelsIds 19 34 \
        --customModels /home/victor/pzphlib/custom_models.json


!python3 /home/victor/pzphlib/pzph1dot1.py \
        --output /data/victor/graphs4article/DR16Q_v4-wo_01_train \
        --xrayCatalog /data/victor/graphs4article/data/DR16Q_v4-wo_01_train.fits \
        --baseCatalog "ls" \
        --primaryRadius 1 --chunkSize 10000 \
        --xrayRaCol RA --xrayDecCol DEC --njobs 24 \
        --modelsSeries x1apswf --modelsIds 19 34 \
        --customModels /home/victor/pzphlib/custom_models.json

# Дубль 3

In [None]:
# Assemble one table per sample from the pipeline's chunked `buf/` output: for every
# `part-N.features.gz_pkl` the matching `part-N.preds.*.gz_pkl` files are joined column-wise.
for sample in ['stripe82x-a17-table13_ls-base']:
    sample_path = f'/data/victor/graphs4article/{sample}/buf/'
    data = list()
    for features_file in tqdm(glob.glob(
                os.path.join(sample_path, 'part-*.features.gz_pkl')
            ), desc=f'Reading files for {sample}'):

        # Raw string with escaped dots: the previous plain '...\d...' literal relied on an
        # invalid escape sequence. The glob above already guarantees the literal suffix.
        chunk_number = re.findall(r'^part-(\d*)\.features\.gz_pkl$', os.path.basename(features_file))[0]

        features = pd.read_pickle(features_file, compression='gzip')
        preds = [
            pd.read_pickle(file, compression='gzip')
            for file in glob.glob(os.path.join(sample_path, f'part-{chunk_number}.preds.*.gz_pkl'))
        ]
        data_chunk = [features] + preds
        # Keep the last occurrence of every duplicated index label so that the
        # column-wise concat below sees unique indices.
        data_chunk = [df.loc[~df.index.duplicated(keep='last')] for df in data_chunk]
        data_chunk = pd.concat(data_chunk, axis=1)
        data.append(data_chunk)

    data = pd.concat(data, axis=0)
    data = data.reset_index(drop=True)
    if sample == 'stripe82x-a17-table13_ls-base':
        data = data.rename(columns={'zsp': 'zspec'})
        df = data
        fl = (df['QF']<2.1)&(df['zspec']>0)&(df['zph']>0)&(df['zphML']>0)
        data = data.loc[fl]

    # Not reached with the sample list above; kept for re-running on DR16Q.
    if sample == 'DR16Q_v4-wo_01_train':
        data['zspec'] = data['Z']
        mask = data['SOURCE_Z'] == b'VI'
        mask &= data['Z_CONF'] == 3
        data = data.loc[mask]

        bins = np.linspace(0, 7, 71)
        col = 'zspec'
        # NOTE: `sns.distplot` is deprecated in recent seaborn; `histplot` is its replacement.
        sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

        # Flatten the redshift distribution: at most `thres` objects per redshift bin,
        # drawn with a fixed random seed.
        thres = 1800
        dst = []
        for start, end in zip(bins[:-1], bins[1:]):
            chunk = data.loc[(start < data[col]) & (data[col] <= end)]
            if len(chunk) > thres:
                _, chunk = train_test_split(chunk, test_size=thres/len(chunk), random_state=42)

            dst.append(chunk)

        data = pd.concat(dst)
        sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

In [None]:
# Column names to be removed from a table: identifiers, catalogue columns, photometric
# features and sample flags.
# NOTE(review): the loop below this list calls `.drop(columns=drop_cols)`, not
# `drop_cols_rodion`; confirm which name is intended.
drop_cols_rodion = [
    'nrow', 'objID', 'ra', 'dec', 'zspec', 'origin', '__workxid__', 'ls_sep_input', 'ls_release', 'ls_brickid',
    'ls_brickname', 'ls_objid', 'ls_brick_primary', 'ls_brightblob', 'ls_maskbits', 'ls_type', 'ls_ra', 'ls_dec',
    'ls_ra_ivar', 'ls_dec_ivar', 'ls_bx', 'ls_by', 'ls_ebv', 'ls_mjd_min', 'ls_mjd_max', 'ls_ref_cat',
    'ls_ref_id', 'ls_pmra', 'ls_pmdec', 'ls_parallax', 'ls_pmra_ivar', 'ls_pmdec_ivar', 'ls_parallax_ivar',
    'ls_ref_epoch', 'ls_gaia_pointsource', 'ls_gaia_phot_g_mean_mag', 'ls_gaia_phot_g_mean_flux_over_error',
    'ls_gaia_phot_g_n_obs', 'ls_gaia_phot_bp_mean_mag', 'ls_gaia_phot_bp_mean_flux_over_error',
    'ls_gaia_phot_bp_n_obs', 'ls_gaia_phot_rp_mean_mag', 'ls_gaia_phot_rp_mean_flux_over_error',
    'ls_gaia_phot_rp_n_obs', 'ls_gaia_phot_variable_flag', 'ls_gaia_astrometric_excess_noise',
    'ls_gaia_astrometric_excess_noise_sig', 'ls_gaia_astrometric_n_obs_al', 'ls_gaia_astrometric_n_good_obs_al',
    'ls_gaia_astrometric_weight_al', 'ls_gaia_duplicated_source', 'ls_gaia_a_g_val', 'ls_gaia_e_bp_min_rp_val',
    'ls_gaia_phot_bp_rp_excess_factor', 'ls_gaia_astrometric_sigma5d_max', 'ls_gaia_astrometric_params_solved',
    'ls_fiberflux_g', 'ls_fiberflux_r', 'ls_fiberflux_z', 'ls_fibertotflux_g', 'ls_fibertotflux_r',
    'ls_fibertotflux_z', 'ls_mw_transmission_g', 'ls_mw_transmission_r', 'ls_mw_transmission_z',
    'ls_mw_transmission_w1', 'ls_mw_transmission_w2', 'ls_mw_transmission_w3', 'ls_mw_transmission_w4',
    'ls_nobs_g', 'ls_nobs_r', 'ls_nobs_z', 'ls_nobs_w1', 'ls_nobs_w2', 'ls_nobs_w3', 'ls_nobs_w4', 'ls_rchisq_g',
    'ls_rchisq_r', 'ls_rchisq_z', 'ls_rchisq_w1', 'ls_rchisq_w2', 'ls_rchisq_w3', 'ls_rchisq_w4',
    'ls_fracflux_g', 'ls_fracflux_r', 'ls_fracflux_z', 'ls_fracflux_w1', 'ls_fracflux_w2', 'ls_fracflux_w3',
    'ls_fracflux_w4', 'ls_fracmasked_g', 'ls_fracmasked_r', 'ls_fracmasked_z', 'ls_fracin_g', 'ls_fracin_r',
    'ls_fracin_z', 'ls_anymask_g', 'ls_anymask_r', 'ls_anymask_z', 'ls_allmask_g', 'ls_allmask_r',
    'ls_allmask_z', 'ls_wisemask_w1', 'ls_wisemask_w2', 'ls_psfsize_g', 'ls_psfsize_r', 'ls_psfsize_z',
    'ls_psfdepth_g', 'ls_psfdepth_r', 'ls_psfdepth_z', 'ls_galdepth_g', 'ls_galdepth_r', 'ls_galdepth_z',
    'ls_psfdepth_w1', 'ls_psfdepth_w2', 'ls_psfdepth_w3', 'ls_psfdepth_w4', 'ls_wise_coadd_id', 'ls_fracdev',
    'ls_fracdev_ivar', 'ls_shapeexp_r', 'ls_shapeexp_r_ivar', 'ls_shapeexp_e1', 'ls_shapeexp_e1_ivar',
    'ls_shapeexp_e2', 'ls_shapeexp_e2_ivar', 'ls_shapedev_r', 'ls_shapedev_r_ivar', 'ls_shapedev_e1',
    'ls_shapedev_e1_ivar', 'ls_shapedev_e2', 'ls_shapedev_e2_ivar', 'ls___workcid__',
    'ls_healpix_id_log2nside17', 'ls_flux_g', 'ls_flux_ivar_g', 'ls_flux_r', 'ls_flux_ivar_r', 'ls_flux_z',
    'ls_flux_ivar_z', 'ls_flux_w1', 'ls_flux_ivar_w1', 'ls_flux_w2', 'ls_flux_ivar_w2', 'ls_flux_w3',
    'ls_flux_ivar_w3', 'ls_flux_w4', 'ls_flux_ivar_w4', 'ls_counterparts_number', 'ls_single_counterpart',
    'ls_counterparts_type', 'sdss_sep_input', 'sdss_objID', 'sdss_MODE', 'sdss_CLEAN', 'sdss_ra', 'sdss_dec',
    'sdss_RAERR', 'sdss_DECERR', 'sdss_cModelFlux_u', 'sdss_cModelFluxIvar_u', 'sdss_cModelFlux_g',
    'sdss_cModelFluxIvar_g', 'sdss_cModelFlux_r', 'sdss_cModelFluxIvar_r', 'sdss_cModelFlux_i',
    'sdss_cModelFluxIvar_i', 'sdss_cModelFlux_z', 'sdss_cModelFluxIvar_z', 'sdss_psfFlux_u',
    'sdss_psfFluxIvar_u', 'sdss_psfFlux_g', 'sdss_psfFluxIvar_g', 'sdss_psfFlux_r', 'sdss_psfFluxIvar_r',
    'sdss_psfFlux_i', 'sdss_psfFluxIvar_i', 'sdss_psfFlux_z', 'sdss_psfFluxIvar_z', 'sdss_counterparts_number',
    'sdss_single_counterpart', 'sdss_counterparts_type', 'ps_sep_input', 'ps_objID', 'ps_raBest', 'ps_decBest',
    'ps_raStack', 'ps_decStack', 'ps_raStackErr', 'ps_decStackErr', 'ps_raMean', 'ps_decMean', 'ps_raMeanErr',
    'ps_decMeanErr', 'ps_objInfoFlag', 'ps_qualityFlag', 'ps_primaryDetection', 'ps_bestDetection',
    'ps_duplicat', 'ps_d_to', 'ps_fitext', 'ps_devaucou', 'ps_star', 'ps_w1fit', 'ps_w1bad', 'ps_w1mag',
    'ps_dw1mag', 'ps_w2fit', 'ps_w2bad', 'ps_w2mag', 'ps_dw2mag', 'ps_gKronFlux', 'ps_gKronFluxErr',
    'ps_rKronFlux', 'ps_rKronFluxErr', 'ps_iKronFlux', 'ps_iKronFluxErr', 'ps_zKronFlux', 'ps_zKronFluxErr',
    'ps_yKronFlux', 'ps_yKronFluxErr', 'ps_gPSFFlux', 'ps_gPSFFluxErr', 'ps_rPSFFlux', 'ps_rPSFFluxErr',
    'ps_iPSFFlux', 'ps_iPSFFluxErr', 'ps_zPSFFlux', 'ps_zPSFFluxErr', 'ps_yPSFFlux', 'ps_yPSFFluxErr',
    'ps_w1flux', 'ps_dw1flux', 'ps_w2flux', 'ps_dw2flux', 'ps_counterparts_number', 'ps_single_counterpart',
    'ps_counterparts_type', 'gaiaedr3_sep_input', 'gaiaedr3_designation', 'gaiaedr3_source_id', 'gaiaedr3_ra',
    'gaiaedr3_ra_error', 'gaiaedr3_dec', 'gaiaedr3_dec_error', 'gaiaedr3_parallax', 'gaiaedr3_parallax_error',
    'gaiaedr3_pm', 'gaiaedr3_pmra', 'gaiaedr3_pmra_error', 'gaiaedr3_pmdec', 'gaiaedr3_pmdec_error',
    'gaiaedr3_astrometric_n_good_obs_al', 'gaiaedr3_astrometric_gof_al', 'gaiaedr3_astrometric_chi2_al',
    'gaiaedr3_astrometric_excess_noise', 'gaiaedr3_astrometric_excess_noise_sig', 'gaiaedr3_pseudocolour',
    'gaiaedr3_pseudocolour_error', 'gaiaedr3_visibility_periods_used', 'gaiaedr3_ruwe',
    'gaiaedr3_duplicated_source', 'gaiaedr3_phot_g_n_obs', 'gaiaedr3_phot_g_mean_mag',
    'gaiaedr3_phot_bp_mean_flux', 'gaiaedr3_phot_bp_mean_flux_error', 'gaiaedr3_phot_bp_mean_mag',
    'gaiaedr3_phot_rp_mean_flux', 'gaiaedr3_phot_rp_mean_flux_error', 'gaiaedr3_phot_rp_mean_mag',
    'gaiaedr3_dr2_radial_velocity', 'gaiaedr3_dr2_radial_velocity_error', 'gaiaedr3_l', 'gaiaedr3_b',
    'gaiaedr3_ecl_lon', 'gaiaedr3_ecl_lat', 'gaiaedr3_phot_g_mean_flux', 'gaiaedr3_phot_g_mean_flux_error',
    'gaiaedr3_counterparts_number', 'gaiaedr3_single_counterpart', 'gaiaedr3_counterparts_type',
    'sdssdr16_u_psf', 'sdssdr16_g_psf', 'sdssdr16_r_psf', 'sdssdr16_i_psf', 'sdssdr16_z_psf',
    'sdssdr16_u_cmodel', 'sdssdr16_g_cmodel', 'sdssdr16_r_cmodel', 'sdssdr16_i_cmodel', 'sdssdr16_z_cmodel',
    'psdr2_g_kron', 'psdr2_r_kron', 'psdr2_i_kron', 'psdr2_z_kron', 'psdr2_y_kron', 'psdr2_g_psf', 'psdr2_r_psf',
    'psdr2_i_psf', 'psdr2_z_psf', 'psdr2_y_psf', 'ls_flux_g_ebv', 'ls_flux_r_ebv', 'ls_flux_z_ebv',
    'ls_flux_w1_ebv', 'ls_flux_w2_ebv', 'ls_flux_w3_ebv', 'ls_flux_w4_ebv', 'decals8tr_g', 'decals8tr_r',
    'decals8tr_z', 'decals8tr_Lw1', 'decals8tr_Lw2', 'decals8tr_Lw3', 'decals8tr_Lw4', 'sdssdr16_u-g_psf',
    'sdssdr16_u-r_psf', 'sdssdr16_u-i_psf', 'sdssdr16_u-z_psf', 'sdssdr16_u_psf-cmodel', 'sdssdr16_g-r_psf',
    'sdssdr16_g-i_psf', 'sdssdr16_g-z_psf', 'sdssdr16_g_psf-cmodel', 'sdssdr16_r-i_psf', 'sdssdr16_r-z_psf',
    'sdssdr16_r_psf-cmodel', 'sdssdr16_i-z_psf', 'sdssdr16_i_psf-cmodel', 'sdssdr16_z_psf-cmodel',
    'psdr2_g-r_psf', 'psdr2_g-i_psf', 'psdr2_g-z_psf', 'psdr2_g-y_psf', 'psdr2_g_psf-kron', 'psdr2_r-i_psf',
    'psdr2_r-z_psf', 'psdr2_r-y_psf', 'psdr2_r_psf-kron', 'psdr2_i-z_psf', 'psdr2_i-y_psf', 'psdr2_i_psf-kron',
    'psdr2_z-y_psf', 'psdr2_z_psf-kron', 'psdr2_y_psf-kron', 'decals8tr_g-r', 'decals8tr_g-z', 'decals8tr_r-z',
    'sdssdr16_g_cmodel-decals8tr_g', 'sdssdr16_r_cmodel-decals8tr_r', 'sdssdr16_z_cmodel-decals8tr_z',
    'psdr2_g_kron-decals8tr_g', 'psdr2_r_kron-decals8tr_r', 'psdr2_z_kron-decals8tr_z',
    'sdssdr16_u_cmodel-decals8tr_Lw1', 'sdssdr16_u_cmodel-decals8tr_Lw2', 'sdssdr16_g_cmodel-decals8tr_Lw1',
    'sdssdr16_g_cmodel-decals8tr_Lw2', 'sdssdr16_r_cmodel-decals8tr_Lw1', 'sdssdr16_r_cmodel-decals8tr_Lw2',
    'sdssdr16_i_cmodel-decals8tr_Lw1', 'sdssdr16_i_cmodel-decals8tr_Lw2', 'sdssdr16_z_cmodel-decals8tr_Lw1',
    'sdssdr16_z_cmodel-decals8tr_Lw2', 'psdr2_g_kron-decals8tr_Lw1', 'psdr2_g_kron-decals8tr_Lw2',
    'psdr2_r_kron-decals8tr_Lw1', 'psdr2_r_kron-decals8tr_Lw2', 'psdr2_i_kron-decals8tr_Lw1',
    'psdr2_i_kron-decals8tr_Lw2', 'psdr2_z_kron-decals8tr_Lw1', 'psdr2_z_kron-decals8tr_Lw2',
    'psdr2_y_kron-decals8tr_Lw1', 'psdr2_y_kron-decals8tr_Lw2', 'decals8tr_g-Lw1', 'decals8tr_g-Lw2',
    'decals8tr_r-Lw1', 'decals8tr_r-Lw2', 'decals8tr_z-Lw1', 'decals8tr_z-Lw2', 'decals8tr_Lw1-Lw2',
    'phot_is_train_pz', 'phot_is_train_pmatch', 'phot_is_train_pclass', 'phot_is_train_star',
    'phot_is_train_gal', 'phot_is_test_xxln_m16', 'phot_is_test_s82x_l19', 'phot_is_test_s82x_a17',
    'phot_is_test_qso', 'phot_is_test_star', 'phot_is_test_gal', 'phot_is_spec_sdss', 'phot_test_field',
    '__nrow__'
]

# For each of the two CV folds, read the features and predictions from two result trees
# (`results_rodion/` and `results/`), alias the fold-specific prediction columns to the
# common `zoo_x1a<id>_*` names, and merge both tables column-wise.
# NOTE(review): `data_all` is not defined before this loop in the visible part of the
# notebook (a later cell creates it with `dict()`); it must exist before this runs.
for fold in trange(2, desc='Reading CV data'):
    # Features plus every prediction file of this fold, from the `results_rodion` tree.
    data_fold = pd.concat([
        pd.read_pickle(f'/data/victor/srgz_models/2-fold-cv/results_rodion/cv2_{fold}/buf/part-00000.features.gz_pkl',
                       compression='gzip'),
    ] + [
        pd.read_pickle(file, compression='gzip') for file in glob.glob(
            f'/data/victor/srgz_models/2-fold-cv/results_rodion/cv2_{fold}/buf/part-00000.preds.x1cv2_{fold}*.gz_pkl'
        )
    ], axis=1)
    
    # Models 19 and 34 come from the `pswf` series in this tree.
    for mid in [19, 34]:
        data_fold[f'zoo_x1a{mid}_z_max'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_z_max']
        data_fold[f'zoo_x1a{mid}_z_maxConf'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_z_maxConf']
        data_fold[f'zoo_x1a{mid}_ci1a_68'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_ci1a_68']
        data_fold[f'zoo_x1a{mid}_ci1b_68'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_ci1b_68']
        
    data_cv_rodion = data_fold
    
    # Same procedure for the `results` tree, which provides models 21, 22 and 35.
    data_fold = pd.concat([
        pd.read_pickle(f'/data/victor/srgz_models/2-fold-cv/results/cv2_{fold}/buf/part-00000.features.gz_pkl',
                       compression='gzip'),
    ] + [
        pd.read_pickle(file, compression='gzip') for file in glob.glob(
            f'/data/victor/srgz_models/2-fold-cv/results/cv2_{fold}/buf/part-00000.preds.x1cv2_{fold}*.gz_pkl'
        )
    ], axis=1)

    for mid in [21, 22, 35]:
        data_fold[f'zoo_x1a{mid}_z_max'] = data_fold[f'zoo_x1cv2_{fold}{mid}_z_max']
        data_fold[f'zoo_x1a{mid}_z_maxConf'] = data_fold[f'zoo_x1cv2_{fold}{mid}_z_maxConf']
        data_fold[f'zoo_x1a{mid}_ci1a_68'] = data_fold[f'zoo_x1cv2_{fold}{mid}_ci1a_68']
        data_fold[f'zoo_x1a{mid}_ci1b_68'] = data_fold[f'zoo_x1cv2_{fold}{mid}_ci1b_68']
        
    data_cv = data_fold
    
    # Align the two tables row by row: de-duplicate on (nrow, origin), then take over the
    # index of `data_cv`.
    # NOTE(review): this assumes both tables then have the same length and row order; verify.
    data_cv_rodion_wo_dups = data_cv_rodion.drop_duplicates(subset=['nrow', 'origin'])
    data_cv_rodion_wo_dups.index = data_cv.index
    # NOTE(review): `drop_cols` is not defined in the visible part of the notebook; the list
    # defined just above is named `drop_cols_rodion`. Confirm which one is meant.
    data_cv_rodion_ready = data_cv_rodion_wo_dups.drop(columns=drop_cols)
    data_fold = pd.concat([data_cv, data_cv_rodion_ready], axis=1)
    
    data_all[fold] = data_fold
    
# Combined table holding the rows of both folds.
data_all['0+1'] = pd.concat([data_all[0], data_all[1]])

In [None]:
# Remove the columns named in `drop_cols` from the de-duplicated "rodion" frame.
data_cv_rodion_ready = data_cv_rodion_wo_dups.drop(labels=drop_cols, axis='columns')

In [None]:
# Put both frames side by side; pandas aligns the rows on the index.
data = pd.concat((data_cv, data_cv_rodion_ready), axis='columns')

In [None]:
# Every loaded sample ends up here, keyed by CV fold index or by sample name.
data_all = {}

# for fold in trange(2, desc='Reading CV data'):
#     data_fold = pd.concat([
#         pd.read_pickle(f'/data/victor/srgz_models/2-fold-cv/results/cv2_{fold}/buf/part-00000.features.gz_pkl',
#                        compression='gzip'),
#     ] + [
#         pd.read_pickle(file, compression='gzip') for file in glob.glob(
#             f'/data/victor/srgz_models/2-fold-cv/results/cv2_{fold}/buf/part-00000.preds.x1cv2_{fold}*.gz_pkl'
#         )
#     ], axis=1)
    
#     for mid in [19, 34]:
#         data_fold[f'zoo_x1a{mid}_z_max'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_z_max']
#         data_fold[f'zoo_x1a{mid}_z_maxConf'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_z_maxConf']
#         data_fold[f'zoo_x1a{mid}_ci1a_68'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_ci1a_68']
#         data_fold[f'zoo_x1a{mid}_ci1b_68'] = data_fold[f'zoo_x1cv2_{fold}pswf{mid}_ci1b_68']

#     for mid in [21, 22, 35]:
#         data_fold[f'zoo_x1a{mid}_z_max'] = data_fold[f'zoo_x1cv2_{fold}{mid}_z_max']
#         data_fold[f'zoo_x1a{mid}_z_maxConf'] = data_fold[f'zoo_x1cv2_{fold}{mid}_z_maxConf']
#         data_fold[f'zoo_x1a{mid}_ci1a_68'] = data_fold[f'zoo_x1cv2_{fold}{mid}_ci1a_68']
#         data_fold[f'zoo_x1a{mid}_ci1b_68'] = data_fold[f'zoo_x1cv2_{fold}{mid}_ci1b_68']

#     data_all[fold] = data_fold
    
# data_all['0+1'] = pd.concat([data_all[0], data_all[1]])

drop_cols_rodion = [
    'nrow', 'objID', 'ra', 'dec', 'zspec', 'origin', '__workxid__', 'ls_sep_input', 'ls_release', 'ls_brickid',
    'ls_brickname', 'ls_objid', 'ls_brick_primary', 'ls_brightblob', 'ls_maskbits', 'ls_type', 'ls_ra', 'ls_dec',
    'ls_ra_ivar', 'ls_dec_ivar', 'ls_bx', 'ls_by', 'ls_ebv', 'ls_mjd_min', 'ls_mjd_max', 'ls_ref_cat',
    'ls_ref_id', 'ls_pmra', 'ls_pmdec', 'ls_parallax', 'ls_pmra_ivar', 'ls_pmdec_ivar', 'ls_parallax_ivar',
    'ls_ref_epoch', 'ls_gaia_pointsource', 'ls_gaia_phot_g_mean_mag', 'ls_gaia_phot_g_mean_flux_over_error',
    'ls_gaia_phot_g_n_obs', 'ls_gaia_phot_bp_mean_mag', 'ls_gaia_phot_bp_mean_flux_over_error',
    'ls_gaia_phot_bp_n_obs', 'ls_gaia_phot_rp_mean_mag', 'ls_gaia_phot_rp_mean_flux_over_error',
    'ls_gaia_phot_rp_n_obs', 'ls_gaia_phot_variable_flag', 'ls_gaia_astrometric_excess_noise',
    'ls_gaia_astrometric_excess_noise_sig', 'ls_gaia_astrometric_n_obs_al', 'ls_gaia_astrometric_n_good_obs_al',
    'ls_gaia_astrometric_weight_al', 'ls_gaia_duplicated_source', 'ls_gaia_a_g_val', 'ls_gaia_e_bp_min_rp_val',
    'ls_gaia_phot_bp_rp_excess_factor', 'ls_gaia_astrometric_sigma5d_max', 'ls_gaia_astrometric_params_solved',
    'ls_fiberflux_g', 'ls_fiberflux_r', 'ls_fiberflux_z', 'ls_fibertotflux_g', 'ls_fibertotflux_r',
    'ls_fibertotflux_z', 'ls_mw_transmission_g', 'ls_mw_transmission_r', 'ls_mw_transmission_z',
    'ls_mw_transmission_w1', 'ls_mw_transmission_w2', 'ls_mw_transmission_w3', 'ls_mw_transmission_w4',
    'ls_nobs_g', 'ls_nobs_r', 'ls_nobs_z', 'ls_nobs_w1', 'ls_nobs_w2', 'ls_nobs_w3', 'ls_nobs_w4', 'ls_rchisq_g',
    'ls_rchisq_r', 'ls_rchisq_z', 'ls_rchisq_w1', 'ls_rchisq_w2', 'ls_rchisq_w3', 'ls_rchisq_w4',
    'ls_fracflux_g', 'ls_fracflux_r', 'ls_fracflux_z', 'ls_fracflux_w1', 'ls_fracflux_w2', 'ls_fracflux_w3',
    'ls_fracflux_w4', 'ls_fracmasked_g', 'ls_fracmasked_r', 'ls_fracmasked_z', 'ls_fracin_g', 'ls_fracin_r',
    'ls_fracin_z', 'ls_anymask_g', 'ls_anymask_r', 'ls_anymask_z', 'ls_allmask_g', 'ls_allmask_r',
    'ls_allmask_z', 'ls_wisemask_w1', 'ls_wisemask_w2', 'ls_psfsize_g', 'ls_psfsize_r', 'ls_psfsize_z',
    'ls_psfdepth_g', 'ls_psfdepth_r', 'ls_psfdepth_z', 'ls_galdepth_g', 'ls_galdepth_r', 'ls_galdepth_z',
    'ls_psfdepth_w1', 'ls_psfdepth_w2', 'ls_psfdepth_w3', 'ls_psfdepth_w4', 'ls_wise_coadd_id', 'ls_fracdev',
    'ls_fracdev_ivar', 'ls_shapeexp_r', 'ls_shapeexp_r_ivar', 'ls_shapeexp_e1', 'ls_shapeexp_e1_ivar',
    'ls_shapeexp_e2', 'ls_shapeexp_e2_ivar', 'ls_shapedev_r', 'ls_shapedev_r_ivar', 'ls_shapedev_e1',
    'ls_shapedev_e1_ivar', 'ls_shapedev_e2', 'ls_shapedev_e2_ivar', 'ls___workcid__',
    'ls_healpix_id_log2nside17', 'ls_flux_g', 'ls_flux_ivar_g', 'ls_flux_r', 'ls_flux_ivar_r', 'ls_flux_z',
    'ls_flux_ivar_z', 'ls_flux_w1', 'ls_flux_ivar_w1', 'ls_flux_w2', 'ls_flux_ivar_w2', 'ls_flux_w3',
    'ls_flux_ivar_w3', 'ls_flux_w4', 'ls_flux_ivar_w4', 'ls_counterparts_number', 'ls_single_counterpart',
    'ls_counterparts_type', 'sdss_sep_input', 'sdss_objID', 'sdss_MODE', 'sdss_CLEAN', 'sdss_ra', 'sdss_dec',
    'sdss_RAERR', 'sdss_DECERR', 'sdss_cModelFlux_u', 'sdss_cModelFluxIvar_u', 'sdss_cModelFlux_g',
    'sdss_cModelFluxIvar_g', 'sdss_cModelFlux_r', 'sdss_cModelFluxIvar_r', 'sdss_cModelFlux_i',
    'sdss_cModelFluxIvar_i', 'sdss_cModelFlux_z', 'sdss_cModelFluxIvar_z', 'sdss_psfFlux_u',
    'sdss_psfFluxIvar_u', 'sdss_psfFlux_g', 'sdss_psfFluxIvar_g', 'sdss_psfFlux_r', 'sdss_psfFluxIvar_r',
    'sdss_psfFlux_i', 'sdss_psfFluxIvar_i', 'sdss_psfFlux_z', 'sdss_psfFluxIvar_z', 'sdss_counterparts_number',
    'sdss_single_counterpart', 'sdss_counterparts_type', 'ps_sep_input', 'ps_objID', 'ps_raBest', 'ps_decBest',
    'ps_raStack', 'ps_decStack', 'ps_raStackErr', 'ps_decStackErr', 'ps_raMean', 'ps_decMean', 'ps_raMeanErr',
    'ps_decMeanErr', 'ps_objInfoFlag', 'ps_qualityFlag', 'ps_primaryDetection', 'ps_bestDetection',
    'ps_duplicat', 'ps_d_to', 'ps_fitext', 'ps_devaucou', 'ps_star', 'ps_w1fit', 'ps_w1bad', 'ps_w1mag',
    'ps_dw1mag', 'ps_w2fit', 'ps_w2bad', 'ps_w2mag', 'ps_dw2mag', 'ps_gKronFlux', 'ps_gKronFluxErr',
    'ps_rKronFlux', 'ps_rKronFluxErr', 'ps_iKronFlux', 'ps_iKronFluxErr', 'ps_zKronFlux', 'ps_zKronFluxErr',
    'ps_yKronFlux', 'ps_yKronFluxErr', 'ps_gPSFFlux', 'ps_gPSFFluxErr', 'ps_rPSFFlux', 'ps_rPSFFluxErr',
    'ps_iPSFFlux', 'ps_iPSFFluxErr', 'ps_zPSFFlux', 'ps_zPSFFluxErr', 'ps_yPSFFlux', 'ps_yPSFFluxErr',
    'ps_w1flux', 'ps_dw1flux', 'ps_w2flux', 'ps_dw2flux', 'ps_counterparts_number', 'ps_single_counterpart',
    'ps_counterparts_type', 'gaiaedr3_sep_input', 'gaiaedr3_designation', 'gaiaedr3_source_id', 'gaiaedr3_ra',
    'gaiaedr3_ra_error', 'gaiaedr3_dec', 'gaiaedr3_dec_error', 'gaiaedr3_parallax', 'gaiaedr3_parallax_error',
    'gaiaedr3_pm', 'gaiaedr3_pmra', 'gaiaedr3_pmra_error', 'gaiaedr3_pmdec', 'gaiaedr3_pmdec_error',
    'gaiaedr3_astrometric_n_good_obs_al', 'gaiaedr3_astrometric_gof_al', 'gaiaedr3_astrometric_chi2_al',
    'gaiaedr3_astrometric_excess_noise', 'gaiaedr3_astrometric_excess_noise_sig', 'gaiaedr3_pseudocolour',
    'gaiaedr3_pseudocolour_error', 'gaiaedr3_visibility_periods_used', 'gaiaedr3_ruwe',
    'gaiaedr3_duplicated_source', 'gaiaedr3_phot_g_n_obs', 'gaiaedr3_phot_g_mean_mag',
    'gaiaedr3_phot_bp_mean_flux', 'gaiaedr3_phot_bp_mean_flux_error', 'gaiaedr3_phot_bp_mean_mag',
    'gaiaedr3_phot_rp_mean_flux', 'gaiaedr3_phot_rp_mean_flux_error', 'gaiaedr3_phot_rp_mean_mag',
    'gaiaedr3_dr2_radial_velocity', 'gaiaedr3_dr2_radial_velocity_error', 'gaiaedr3_l', 'gaiaedr3_b',
    'gaiaedr3_ecl_lon', 'gaiaedr3_ecl_lat', 'gaiaedr3_phot_g_mean_flux', 'gaiaedr3_phot_g_mean_flux_error',
    'gaiaedr3_counterparts_number', 'gaiaedr3_single_counterpart', 'gaiaedr3_counterparts_type',
    'sdssdr16_u_psf', 'sdssdr16_g_psf', 'sdssdr16_r_psf', 'sdssdr16_i_psf', 'sdssdr16_z_psf',
    'sdssdr16_u_cmodel', 'sdssdr16_g_cmodel', 'sdssdr16_r_cmodel', 'sdssdr16_i_cmodel', 'sdssdr16_z_cmodel',
    'psdr2_g_kron', 'psdr2_r_kron', 'psdr2_i_kron', 'psdr2_z_kron', 'psdr2_y_kron', 'psdr2_g_psf', 'psdr2_r_psf',
    'psdr2_i_psf', 'psdr2_z_psf', 'psdr2_y_psf', 'ls_flux_g_ebv', 'ls_flux_r_ebv', 'ls_flux_z_ebv',
    'ls_flux_w1_ebv', 'ls_flux_w2_ebv', 'ls_flux_w3_ebv', 'ls_flux_w4_ebv', 'decals8tr_g', 'decals8tr_r',
    'decals8tr_z', 'decals8tr_Lw1', 'decals8tr_Lw2', 'decals8tr_Lw3', 'decals8tr_Lw4', 'sdssdr16_u-g_psf',
    'sdssdr16_u-r_psf', 'sdssdr16_u-i_psf', 'sdssdr16_u-z_psf', 'sdssdr16_u_psf-cmodel', 'sdssdr16_g-r_psf',
    'sdssdr16_g-i_psf', 'sdssdr16_g-z_psf', 'sdssdr16_g_psf-cmodel', 'sdssdr16_r-i_psf', 'sdssdr16_r-z_psf',
    'sdssdr16_r_psf-cmodel', 'sdssdr16_i-z_psf', 'sdssdr16_i_psf-cmodel', 'sdssdr16_z_psf-cmodel',
    'psdr2_g-r_psf', 'psdr2_g-i_psf', 'psdr2_g-z_psf', 'psdr2_g-y_psf', 'psdr2_g_psf-kron', 'psdr2_r-i_psf',
    'psdr2_r-z_psf', 'psdr2_r-y_psf', 'psdr2_r_psf-kron', 'psdr2_i-z_psf', 'psdr2_i-y_psf', 'psdr2_i_psf-kron',
    'psdr2_z-y_psf', 'psdr2_z_psf-kron', 'psdr2_y_psf-kron', 'decals8tr_g-r', 'decals8tr_g-z', 'decals8tr_r-z',
    'sdssdr16_g_cmodel-decals8tr_g', 'sdssdr16_r_cmodel-decals8tr_r', 'sdssdr16_z_cmodel-decals8tr_z',
    'psdr2_g_kron-decals8tr_g', 'psdr2_r_kron-decals8tr_r', 'psdr2_z_kron-decals8tr_z',
    'sdssdr16_u_cmodel-decals8tr_Lw1', 'sdssdr16_u_cmodel-decals8tr_Lw2', 'sdssdr16_g_cmodel-decals8tr_Lw1',
    'sdssdr16_g_cmodel-decals8tr_Lw2', 'sdssdr16_r_cmodel-decals8tr_Lw1', 'sdssdr16_r_cmodel-decals8tr_Lw2',
    'sdssdr16_i_cmodel-decals8tr_Lw1', 'sdssdr16_i_cmodel-decals8tr_Lw2', 'sdssdr16_z_cmodel-decals8tr_Lw1',
    'sdssdr16_z_cmodel-decals8tr_Lw2', 'psdr2_g_kron-decals8tr_Lw1', 'psdr2_g_kron-decals8tr_Lw2',
    'psdr2_r_kron-decals8tr_Lw1', 'psdr2_r_kron-decals8tr_Lw2', 'psdr2_i_kron-decals8tr_Lw1',
    'psdr2_i_kron-decals8tr_Lw2', 'psdr2_z_kron-decals8tr_Lw1', 'psdr2_z_kron-decals8tr_Lw2',
    'psdr2_y_kron-decals8tr_Lw1', 'psdr2_y_kron-decals8tr_Lw2', 'decals8tr_g-Lw1', 'decals8tr_g-Lw2',
    'decals8tr_r-Lw1', 'decals8tr_r-Lw2', 'decals8tr_z-Lw1', 'decals8tr_z-Lw2', 'decals8tr_Lw1-Lw2',
    'phot_is_train_pz', 'phot_is_train_pmatch', 'phot_is_train_pclass', 'phot_is_train_star',
    'phot_is_train_gal', 'phot_is_test_xxln_m16', 'phot_is_test_s82x_l19', 'phot_is_test_s82x_a17',
    'phot_is_test_qso', 'phot_is_test_star', 'phot_is_test_gal', 'phot_is_spec_sdss', 'phot_test_field',
    '__nrow__'
]

# Read both cross-validation folds.  For every fold two runs are loaded and merged column-wise:
#   * `results_rodion/` -- models 19 and 34 (prediction columns carry a `pswf` infix),
#   * `results/`        -- models 21, 22 and 35.
# In both cases the fold-specific column names are copied to fold-independent `zoo_x1a{mid}_*` aliases.
for fold in trange(2, desc='Reading CV data'):
    rodion_tables = [
        pd.read_pickle(f'/data/victor/srgz_models/2-fold-cv/results_rodion/cv2_{fold}/buf/part-00000.features.gz_pkl',
                       compression='gzip'),
    ]
    rodion_tables.extend(
        pd.read_pickle(preds_file, compression='gzip') for preds_file in glob.glob(
            f'/data/victor/srgz_models/2-fold-cv/results_rodion/cv2_{fold}/buf/part-00000.preds.x1cv2_{fold}*.gz_pkl'
        )
    )
    data_cv_rodion = pd.concat(rodion_tables, axis=1)

    for mid in [19, 34]:
        data_cv_rodion[f'zoo_x1a{mid}_z_max'] = data_cv_rodion[f'zoo_x1cv2_{fold}pswf{mid}_z_max']
        data_cv_rodion[f'zoo_x1a{mid}_z_maxConf'] = data_cv_rodion[f'zoo_x1cv2_{fold}pswf{mid}_z_maxConf']
        data_cv_rodion[f'zoo_x1a{mid}_ci1a_68'] = data_cv_rodion[f'zoo_x1cv2_{fold}pswf{mid}_ci1a_68']
        data_cv_rodion[f'zoo_x1a{mid}_ci1b_68'] = data_cv_rodion[f'zoo_x1cv2_{fold}pswf{mid}_ci1b_68']

    cv_tables = [
        pd.read_pickle(f'/data/victor/srgz_models/2-fold-cv/results/cv2_{fold}/buf/part-00000.features.gz_pkl',
                       compression='gzip'),
    ]
    cv_tables.extend(
        pd.read_pickle(preds_file, compression='gzip') for preds_file in glob.glob(
            f'/data/victor/srgz_models/2-fold-cv/results/cv2_{fold}/buf/part-00000.preds.x1cv2_{fold}*.gz_pkl'
        )
    )
    data_cv = pd.concat(cv_tables, axis=1)

    for mid in [21, 22, 35]:
        data_cv[f'zoo_x1a{mid}_z_max'] = data_cv[f'zoo_x1cv2_{fold}{mid}_z_max']
        data_cv[f'zoo_x1a{mid}_z_maxConf'] = data_cv[f'zoo_x1cv2_{fold}{mid}_z_maxConf']
        data_cv[f'zoo_x1a{mid}_ci1a_68'] = data_cv[f'zoo_x1cv2_{fold}{mid}_ci1a_68']
        data_cv[f'zoo_x1a{mid}_ci1b_68'] = data_cv[f'zoo_x1cv2_{fold}{mid}_ci1b_68']

    # The rodion frame is de-duplicated on (nrow, origin) and then given the index of `data_cv`,
    # i.e. the two frames are matched purely by row position.
    # NOTE(review): this presumes both frames hold the same objects in the same order -- confirm.
    data_cv_rodion_wo_dups = data_cv_rodion.drop_duplicates(subset=['nrow', 'origin'])
    data_cv_rodion_wo_dups.index = data_cv.index
    # NOTE(review): `drop_cols` is used here, while the list defined right above this loop is
    # called `drop_cols_rodion` -- confirm which one is intended.
    data_cv_rodion_ready = data_cv_rodion_wo_dups.drop(columns=drop_cols)
    data_fold = pd.concat([data_cv, data_cv_rodion_ready], axis=1)

    data_all[fold] = data_fold

# Both folds stacked row-wise.
data_all['0+1'] = pd.concat([data_all[fold_id] for fold_id in (0, 1)])


# Load the two external test samples chunk by chunk and register them in `data_all`.
for sample in ['stripe82x-a17-table13_ls-base', 'DR16Q_v4-wo_01_train']:
    sample_path = f'/data/victor/graphs4article/{sample}/buf/'
    chunks = []
    for features_file in tqdm(glob.glob(
                os.path.join(sample_path, 'part-*.features.gz_pkl')
            ), desc=f'Reading files for {sample}'):

        # Raw string: '\d' inside a plain literal is an invalid escape sequence.
        # The dots are escaped so that they only match a literal '.'.
        chunk_number = re.findall(r'^part-(\d*)\.features\.gz_pkl$', os.path.basename(features_file))[0]

        features = pd.read_pickle(features_file, compression='gzip')
        preds = [
            pd.read_pickle(file, compression='gzip')
            for file in glob.glob(os.path.join(sample_path, f'part-{chunk_number}.preds.*.gz_pkl'))
        ]
        # Keep only the last row for every repeated index label so the column-wise concat can align.
        data_chunk = [features] + preds
        data_chunk = [df.loc[~df.index.duplicated(keep='last')] for df in data_chunk]
        data_chunk = pd.concat(data_chunk, axis=1)
        chunks.append(data_chunk)

    data = pd.concat(chunks, axis=0)
    data = data.reset_index(drop=True)
    if sample == 'stripe82x-a17-table13_ls-base':
        data = data.rename(columns={'zsp': 'zspec'})
        fl = (data['QF']<2.1)&(data['zspec']>0)&(data['zph']>0)&(data['zphML']>0)
        # Explicit copy: new columns are assigned to this frame further down.
        data = data.loc[fl].copy()

    if sample == 'DR16Q_v4-wo_01_train':
        data['zspec'] = data['Z']
        mask = data['SOURCE_Z'] == b'VI'
        mask &= data['Z_CONF'] == 3
        data = data.loc[mask]

        # Histogram of zspec in 70 bins of width 0.1 before the down-sampling ...
        # NOTE(review): `sns.distplot` is deprecated in recent seaborn releases.
        bins = np.linspace(0, 7, 71)
        col = 'zspec'
        sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

        # ... then every bin holding more than `thres` objects is randomly reduced to about `thres`.
        thres = 1800
        dst = []
        for start, end in zip(bins[:-1], bins[1:]):
            chunk = data.loc[(start < data[col]) & (data[col] <= end)]
            if len(chunk) > thres:
                _, chunk = train_test_split(chunk, test_size=thres/len(chunk), random_state=42)

            dst.append(chunk)

        data = pd.concat(dst)
        # ... and the histogram again after the down-sampling.
        sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

    # Copy the `pswf`-named prediction columns to the model ids used in the rest of the notebook.
    data['zoo_x1a19_z_max'] = data['zoo_x1apswf19_z_max']
    data['zoo_x1a19_z_maxConf'] = data['zoo_x1apswf19_z_maxConf']
    data['zoo_x1a34_z_max'] = data['zoo_x1apswf34_z_max']
    data['zoo_x1a34_z_maxConf'] = data['zoo_x1apswf34_z_maxConf']

    data['zoo_x1a19_ci1a_68'] = data['zoo_x1apswf19_ci1a_68']
    data['zoo_x1a19_ci1b_68'] = data['zoo_x1apswf19_ci1b_68']
    data['zoo_x1a34_ci1a_68'] = data['zoo_x1apswf34_ci1a_68']
    data['zoo_x1a34_ci1b_68'] = data['zoo_x1apswf34_ci1b_68']

    data_all[sample] = data
    
# Sub-samples of the combined CV data, selected by object class.
data = data_all['0+1']

# Quasars: everything classified as QSO plus all rows originating from 'VHzQs'.
mask = (data['class'] == 'QSO') | (data['origin'] == 'VHzQs')
data_all['0+1 (QSO only)'] = data.loc[mask]

# Galaxies.
mask = data['class'] == 'GALAXY'
data_all['0+1 (GALAXY only)'] = data.loc[mask]

In [None]:
# folds = ['0+1', 'stripe82x-a17-table13_ls-base', '0+1 (QSO only)', 'DR16Q_v4-wo_01_train']
# titles = ['2-fold CV', 'Stripe82X', '2-fold CV (QSO Only)', 'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)']

# Keys into `data_all` ...
folds = [
    '0+1 (QSO only)',
    'DR16Q_v4-wo_01_train',
]
# ... and the matching titles, in the same order.
titles = [
    '2-fold CV (QSO Only)',
    'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)',
]

In [None]:
# Full set of models: id -> list of photometric surveys the model uses.
models_ids_metrics = {
    19: 'Pan-STARRS + WISE',
    21: 'Pan-STARRS + DESI LIS + WISE',
    22: 'DESI LIS + WISE',
    34: 'SDSS + Pan-STARRS + WISE',
    35: 'SDSS + Pan-STARRS + DESI LIS + WISE',
}

# Same ids -> abbreviated label.
models_ids_metrics_short = {
    19: 'PW',
    21: 'PDW',
    22: 'DW',
    34: 'SPW',
    35: 'SPDW',
}

# Reduced selection (models 21 and 34 left out) with identical names / labels.
models_ids = {mid: models_ids_metrics[mid] for mid in (19, 22, 35)}
models_ids_short = {mid: models_ids_metrics_short[mid] for mid in (19, 22, 35)}

# metrics charts (new)

In [None]:
z_bins_cv2 = [9.827747557554806, 17.13419637985013, 17.357575120155055, 17.579100052699232, 17.758957867548197, 17.9082410475928, 18.036726268982886, 18.14583221651932, 18.244829930859414, 18.33652587079903, 18.419562410135384, 18.497557442776014, 18.571027900479034, 18.63781685183387, 18.698689421928663, 18.75542677143206, 18.8081912015041, 18.858127363223694, 18.904346386330406, 18.94794603593956, 18.9903881814915, 19.030046545247764, 19.06933391365944, 19.107757793091295, 19.144645646657338, 19.18189048105735, 19.2200396950378, 19.25739167256188, 19.29418073475197, 19.33192841343464, 19.369341545553247, 19.407212891772218, 19.445693442126636, 19.484164485968897, 19.523543927297332, 19.562803704409053, 19.6023351933546, 19.641194873381693, 19.679350762407818, 19.71798610806622, 19.755406653369835, 19.792903139172306, 19.830040899017312, 19.866648903584892, 19.90247699130969, 19.937493972774234, 19.971504133988486, 20.005410267649168, 20.039742180266142, 20.07425463440665, 20.10839592003747, 20.141836451816314, 20.175237760638385, 20.2082612279237, 20.240862190008166, 20.27363939970968, 20.307048266961374, 20.34001160666969, 20.373117967497404, 20.405324775585694, 20.43856996861692, 20.471458673635357, 20.504278503251264, 20.537111504619375, 20.57064919178948, 20.60408997540508, 20.63729785458057, 20.671598715674975, 20.70444011595687, 20.738151163020987, 20.77278935592166, 20.807208164029582, 20.842199924655635, 20.87736454554745, 20.912391661079493, 20.948799924949082, 20.985947290720336, 21.02350321886055, 21.061469305612608, 21.101228507751074, 21.140685720613337, 21.181512448109945, 21.224174977841194, 21.268813618713573, 21.314540729039173, 21.36350242052781, 21.415683003174635, 21.471887882382692, 21.533218758421445, 21.600345073306542, 21.676911313133786, 21.770678738521514, 21.893164416324517, 22.08596910146424, 100]
# What to bin the metrics by.  Each value is (x-axis label, bins for folds[0], bins for folds[1]);
# the plotting loop picks the bins entry by fold position and hands it to `metrics_by_bins`.
# The key '!zphot' is a placeholder: the loop replaces it with the current model's prediction column.
bins_for_metrics = {
    'zspec': ('spec-z', 100, 50),
    '!zphot': ('photo-z', 100, 50),
    'ls_mag_z': ('z_mag', z_bins_cv2, 50),
}
# Column holding the spectroscopic redshift.
zspec_col = 'zspec'

# Line style per model, indexed by the model's position in `models_ids_metrics`.
linestyles = [':', '--', 'dashdot', (0, (3, 1, 1, 1, 1, 1)), '-']

# Samples to plot (keys of `data_all`) and their titles.
folds = ['0+1', 'DR16Q_v4-wo_01_train']
titles = ['2-fold CV', 'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)']

# for bin_col, (col_name, *nbins) in bins_for_metrics.items():
# One figure per fold: a 2 x len(bins_for_metrics) grid with one row per metric (NMAD on top,
# the n>0.15 outlier metric below) and one column per binning variable.
for fold_idx, fold in enumerate(folds):
    print_header(fold)
    nrows=2
    ncols=len(bins_for_metrics)
    # NOTE(review): `figsize_factor` is not used anywhere in this cell; figsize hard-codes 10.
    figsize_factor = 10
    figsize=(10*ncols, 6*nrows)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

    # model name -> handle returned by metrics_by_bins, used for the shared legend
    legend = dict()
    for j, (bin_col, (col_name, *nbins)) in enumerate(bins_for_metrics.items()):
        data_fold = data_all[fold]
        # z-band magnitude from `ls_flux_z`, or from `flux_z_min_error` when that column is absent.
        # The column is written into the frame stored in `data_all` on every pass of this loop.
        try:
            data_fold['ls_mag_z'] = ab_mag(data_fold['ls_flux_z'])
        except KeyError:
            data_fold['ls_mag_z'] = ab_mag(data_fold['flux_z_min_error'])
            
        # Only the magnitude panel of the combined CV sample overrides its first point.
        if bin_col == 'ls_mag_z' and fold == '0+1':
            override_first_point = 17
        else:
            override_first_point = None
        
        # Rows with zspec >= 0 and a prediction from every model in `models_ids`
        # (note: `models_ids`, not the larger `models_ids_metrics` that is plotted below).
        notna_mask = (data_fold['zspec'] >= 0)
        for mid in models_ids:
            pred_col = f'zoo_x1a{mid}_z_max'
            notna_mask &= data_fold[pred_col].notna()
        
        for midx, (mid, model_name) in enumerate(tqdm(models_ids_metrics.items())):
            pred_col = f'zoo_x1a{mid}_z_max'
            conf_col = pred_col + 'Conf'
            # Models without predictions in this sample are skipped silently.
            if pred_col not in data_fold.columns:
                continue

            metrics = {
                "NMAD": sklearn2singlearg(srgpz.metrics.nmad_z, pred_col, zspec_col),
    #                         "w/o photometry": sklearn2singlearg(wo_phot, pred_col),
                "n>0.15": sklearn2singlearg(srgpz.metrics.catastrophic_outliers_z, pred_col, zspec_col),
#                 "n(dz_norm>0.15 | zConf<0.4)": sklearn2singlearg(additional_metric, pred_col, conf_col, zspec_col),
            }

            # Bins are chosen by the fold's position in `folds`.
            bins = nbins[fold_idx]
            for i, (m_name, m) in enumerate(metrics.items()):
                # Guard for an entry that is currently commented out of `metrics`.
                if m_name == "w/o photometry":
                    continue

                # `bins` is rebound to the value returned by metrics_by_bins, so the second metric
                # of the same model receives whatever the first call returned.
                l, _, _, bins = metrics_by_bins(data_fold.loc[notna_mask], bin_col if bin_col != '!zphot' else pred_col,
                                                m, bins, axs[i][j], linestyle=linestyles[midx],
                                               color=f'C{midx}', override_first_point=override_first_point)
                legend[model_name] = l
            
                axs[i][0].set_ylabel(m_name)
        
        axs[-1][j].set_xlabel(col_name)
        
        
#     ax[0].set(ylim=(0.001, 0.5), yscale='log')    
#     ax[0].yaxis.set_minor_formatter(FormatStrFormatter("%.3f"))
    # Common y-range for the top (NMAD) row: 5% above the largest upper limit in that row.
    ylim_nmad = (0, 1.05*max([axs[0][i].get_ylim()[1] for i in range(axs.shape[1])]))
    
    axs[0][1].set_ylim(ylim_nmad)
    axs[0][0].get_xaxis().set_ticklabels([])
    axs[0][1].get_yaxis().set_ticklabels([])
    
    # NOTE(review): this maximum runs over ALL rows and columns 1.., but the limit is applied below
    # to rows 1.. and ALL columns -- the two ranges look swapped; confirm whether leaving out
    # column 0 (and including the NMAD row) is deliberate.
    ylim_n015 = (0, 1.05*max(
        [axs[i][j].get_ylim()[1] for i, j in itertools.product(range(axs.shape[0]), range(1, axs.shape[1]))]
    ))
    for j in range(axs.shape[1]):
        axs[0][j].set_ylim(ylim_nmad)
        for i in range(1, axs.shape[0]):    
            axs[i][j].set_ylim(ylim_n015)
    
    # Third column (the last key of `bins_for_metrics`): x axis is drawn reversed.
    for i in range(axs.shape[0]):
        axs[i][2].invert_xaxis()
        
    # x tick labels only on the bottom row ...
    for i in range(axs.shape[0]-1):
        for j in range(axs.shape[1]):
            axs[i][j].get_xaxis().set_ticklabels([])
            
    # ... and y tick labels only on the first column.
    for i in range(axs.shape[0]):
        for j in range(1, axs.shape[1]):
            axs[i][j].get_yaxis().set_ticklabels([])
    

    # Shared legend anchored at the top edge of the figure.
    fig.legend(legend.values(), legend.keys(), loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
    fig.tight_layout()
    plt.show()
    plt.close()
#     break

In [None]:
z_bins_cv2 = [9.827747557554806, 17.13419637985013, 17.357575120155055, 17.579100052699232, 17.758957867548197, 17.9082410475928, 18.036726268982886, 18.14583221651932, 18.244829930859414, 18.33652587079903, 18.419562410135384, 18.497557442776014, 18.571027900479034, 18.63781685183387, 18.698689421928663, 18.75542677143206, 18.8081912015041, 18.858127363223694, 18.904346386330406, 18.94794603593956, 18.9903881814915, 19.030046545247764, 19.06933391365944, 19.107757793091295, 19.144645646657338, 19.18189048105735, 19.2200396950378, 19.25739167256188, 19.29418073475197, 19.33192841343464, 19.369341545553247, 19.407212891772218, 19.445693442126636, 19.484164485968897, 19.523543927297332, 19.562803704409053, 19.6023351933546, 19.641194873381693, 19.679350762407818, 19.71798610806622, 19.755406653369835, 19.792903139172306, 19.830040899017312, 19.866648903584892, 19.90247699130969, 19.937493972774234, 19.971504133988486, 20.005410267649168, 20.039742180266142, 20.07425463440665, 20.10839592003747, 20.141836451816314, 20.175237760638385, 20.2082612279237, 20.240862190008166, 20.27363939970968, 20.307048266961374, 20.34001160666969, 20.373117967497404, 20.405324775585694, 20.43856996861692, 20.471458673635357, 20.504278503251264, 20.537111504619375, 20.57064919178948, 20.60408997540508, 20.63729785458057, 20.671598715674975, 20.70444011595687, 20.738151163020987, 20.77278935592166, 20.807208164029582, 20.842199924655635, 20.87736454554745, 20.912391661079493, 20.948799924949082, 20.985947290720336, 21.02350321886055, 21.061469305612608, 21.101228507751074, 21.140685720613337, 21.181512448109945, 21.224174977841194, 21.268813618713573, 21.314540729039173, 21.36350242052781, 21.415683003174635, 21.471887882382692, 21.533218758421445, 21.600345073306542, 21.676911313133786, 21.770678738521514, 21.893164416324517, 22.08596910146424, 100]
# What to bin the metrics by.  Each value is (x-axis label, bins for folds[0], bins for folds[1]);
# the plotting loop picks the bins entry by fold position and hands it to `metrics_by_bins`.
# The key '!zphot' is a placeholder: the loop replaces it with the current model's prediction column.
bins_for_metrics = {
    'zspec': ('spec-z', 100, 50),
    '!zphot': ('photo-z', 100, 50),
    'ls_mag_z': ('z_mag', z_bins_cv2, 50),
}
# Column holding the spectroscopic redshift.
zspec_col = 'zspec'

# Line style per model, indexed by the model's position in `models_ids_metrics`.
linestyles = [':', '--', 'dashdot', (0, (3, 1, 1, 1, 1, 1)), '-']

# NOTE(review): `folds` has a single entry here, so the loop below reads the FIRST bins entry
# (100 / 100 / z_bins_cv2) for DR16Q and `titles` is one element longer than `folds` -- confirm
# that this is intended.
folds = ['DR16Q_v4-wo_01_train']
titles = ['2-fold CV', 'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)']

# for bin_col, (col_name, *nbins) in bins_for_metrics.items():
# Same chart as the previous cell, but every panel that receives a curve is switched to a
# logarithmic y axis.
for fold_idx, fold in enumerate(folds):
    print_header(fold)
    nrows=2
    ncols=len(bins_for_metrics)
    # NOTE(review): `figsize_factor` is not used anywhere in this cell; figsize hard-codes 10.
    figsize_factor = 10
    figsize=(10*ncols, 6*nrows)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

    # model name -> handle returned by metrics_by_bins, used for the shared legend
    legend = dict()
    for j, (bin_col, (col_name, *nbins)) in enumerate(bins_for_metrics.items()):
        data_fold = data_all[fold]
        # z-band magnitude from `ls_flux_z`, or from `flux_z_min_error` when that column is absent.
        # The column is written into the frame stored in `data_all` on every pass of this loop.
        try:
            data_fold['ls_mag_z'] = ab_mag(data_fold['ls_flux_z'])
        except KeyError:
            data_fold['ls_mag_z'] = ab_mag(data_fold['flux_z_min_error'])
            
        # Only the magnitude panel of the combined CV sample overrides its first point.
        if bin_col == 'ls_mag_z' and fold == '0+1':
            override_first_point = 17
        else:
            override_first_point = None
        
        # Rows with zspec >= 0 and a prediction from every model in `models_ids`
        # (note: `models_ids`, not the larger `models_ids_metrics` that is plotted below).
        notna_mask = (data_fold['zspec'] >= 0)
        for mid in models_ids:
            pred_col = f'zoo_x1a{mid}_z_max'
            notna_mask &= data_fold[pred_col].notna()
        
        for midx, (mid, model_name) in enumerate(tqdm(models_ids_metrics.items())):
            pred_col = f'zoo_x1a{mid}_z_max'
            conf_col = pred_col + 'Conf'
            # Models without predictions in this sample are skipped silently.
            if pred_col not in data_fold.columns:
                continue

            metrics = {
                "NMAD": sklearn2singlearg(srgpz.metrics.nmad_z, pred_col, zspec_col),
    #                         "w/o photometry": sklearn2singlearg(wo_phot, pred_col),
                "n>0.15": sklearn2singlearg(srgpz.metrics.catastrophic_outliers_z, pred_col, zspec_col),
#                 "n(dz_norm>0.15 | zConf<0.4)": sklearn2singlearg(additional_metric, pred_col, conf_col, zspec_col),
            }

            # Bins are chosen by the fold's position in `folds`.
            bins = nbins[fold_idx]
            for i, (m_name, m) in enumerate(metrics.items()):
                # Guard for an entry that is currently commented out of `metrics`.
                if m_name == "w/o photometry":
                    continue

                # `bins` is rebound to the value returned by metrics_by_bins, so the second metric
                # of the same model receives whatever the first call returned.
                l, _, _, bins = metrics_by_bins(data_fold.loc[notna_mask], bin_col if bin_col != '!zphot' else pred_col,
                                                m, bins, axs[i][j], linestyle=linestyles[midx],
                                               color=f'C{midx}', override_first_point=override_first_point)
                legend[model_name] = l
            
                axs[i][0].set_ylabel(m_name)
                axs[i][j].set_yscale('log')
        
        axs[-1][j].set_xlabel(col_name)
        
        
#     ax[0].set(ylim=(0.001, 0.5), yscale='log')    
#     ax[0].yaxis.set_minor_formatter(FormatStrFormatter("%.3f"))
    # Common y-range for the top (NMAD) row: 5% above the largest upper limit in that row.
    # NOTE(review): the axes are log-scaled above, where a lower limit of 0 is not a valid value --
    # confirm what matplotlib makes of it (warnings are silenced at the top of the notebook).
    ylim_nmad = (0, 1.05*max([axs[0][i].get_ylim()[1] for i in range(axs.shape[1])]))
    
    axs[0][1].set_ylim(ylim_nmad)
    axs[0][0].get_xaxis().set_ticklabels([])
    axs[0][1].get_yaxis().set_ticklabels([])
    
    # NOTE(review): this maximum runs over ALL rows and columns 1.., but the limit is applied below
    # to rows 1.. and ALL columns -- the two ranges look swapped; confirm whether leaving out
    # column 0 (and including the NMAD row) is deliberate.
    ylim_n015 = (0, 1.05*max(
        [axs[i][j].get_ylim()[1] for i, j in itertools.product(range(axs.shape[0]), range(1, axs.shape[1]))]
    ))
    for j in range(axs.shape[1]):
        axs[0][j].set_ylim(ylim_nmad)
        for i in range(1, axs.shape[0]):    
            axs[i][j].set_ylim(ylim_n015)
    
    # Third column (the last key of `bins_for_metrics`): x axis is drawn reversed.
    for i in range(axs.shape[0]):
        axs[i][2].invert_xaxis()
        
    # x tick labels only on the bottom row ...
    for i in range(axs.shape[0]-1):
        for j in range(axs.shape[1]):
            axs[i][j].get_xaxis().set_ticklabels([])
            
    # ... and y tick labels only on the first column.
    for i in range(axs.shape[0]):
        for j in range(1, axs.shape[1]):
            axs[i][j].get_yaxis().set_ticklabels([])
    

    # Shared legend anchored at the top edge of the figure.
    fig.legend(legend.values(), legend.keys(), loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
    fig.tight_layout()
    plt.show()
    plt.close()
#     break

In [None]:
# Rebind the notebook-level `data` to the combined '0+1' sample (both CV folds stacked).
data = data_all['0+1']

# advanced metrics

In [None]:
def xor(a, b):
    """Return a truthy result when `a` and `b` have the SAME truthiness.

    NOTE: despite the name, this is logical equivalence (XNOR), not an
    exclusive or: ``xor(True, True)`` is true and ``xor(True, False)`` is
    false.  A caller that needs a real exclusive or has to negate the result.
    """
    return b if a else not b


def change_value(src, src_min=None, src_max=None):
    """Fold a candidate bound into a running minimum or maximum.

    With `src_min` given, return the smaller of `src` and `src_min`; otherwise,
    with `src_max` given, return the larger of `src` and `src_max`.  A `src`
    of ``None`` means "no value yet" and yields the candidate itself.  When
    both candidates are passed only `src_min` is considered.

    Raises:
        ValueError: if neither `src_min` nor `src_max` is provided.
    """
    if src_min is None and src_max is None:
        raise ValueError("At least one of src_min or src_max must be not None")

    if src_min is not None:
        return src_min if src is None else min(src, src_min)

    return src_max if src is None else max(src, src_max)
        

def define_bins(data: pd.DataFrame, col: str, bins_size=None, bins_number=None, data_sorted=False):
    """Cut `data` into consecutive, equally populated bins along `col`.

    Exactly one of `bins_size` (rows per bin) and `bins_number` (number of
    bins) has to be given.  The rows are ordered by `col` (unless
    `data_sorted` says they already are) and split by position; the last bin
    additionally takes the remainder of the integer division.

    Args:
        data: frame to split.
        col: name of the column to bin by.
        bins_size: desired number of rows per bin.
        bins_number: desired number of bins; capped at the number of rows.
        data_sorted: set to True to skip sorting by `col`.

    Returns:
        dict mapping ``(bin_start, bin_end)`` to the rows of that bin, in
        ascending order.  `bin_start` of a bin is `bin_end` of the previous
        one (the smallest value for the first bin).  An empty frame gives an
        empty dict; fewer than two bins give a single bin with all rows.
        NOTE: bins with identical ``(bin_start, bin_end)`` share one dict key,
        so the later one replaces the earlier one.

    Raises:
        ValueError: if both or none of `bins_size` / `bins_number` are given,
            or if `bins_size` is not positive.
        KeyError: if `col` is not a column of `data`.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if (bins_size is None) == (bins_number is None):
        raise ValueError("Exactly one of bins_size or bins_number must be given")
    if col not in data.columns:
        raise KeyError(col)
    if bins_size is not None and bins_size <= 0:
        raise ValueError("bins_size must be a positive integer")

    if not data_sorted:
        data = data.sort_values(by=col)

    n_rows = len(data)
    if n_rows == 0:
        return {}

    if bins_size is not None:
        bins_number = n_rows // bins_size

    # More bins than rows would produce empty bins that have no end value.
    bins_number = min(bins_number, n_rows)
    if bins_number < 2:
        return {(data[col].iloc[0], data[col].iloc[-1]): data}

    bins_size = n_rows // bins_number

    bins = dict()
    bin_start = data[col].iloc[0]
    for ibin in range(bins_number):
        is_last = ibin == bins_number - 1
        stop = n_rows if is_last else (ibin + 1) * bins_size
        # `.iloc` keeps the slice positional for any kind of index.
        bin_data = data.iloc[ibin * bins_size: stop]

        bin_end = bin_data[col].iloc[-1]
        bins[(bin_start, bin_end)] = bin_data
        bin_start = bin_end

    return bins


def split_bins(data:pd.DataFrame, col: str, bins: list):
    """Select the rows of `data` that fall into each explicitly given interval.

    Args:
        data: frame to split; it is ordered by `col` first.
        col: name of the column the intervals refer to.
        bins: iterable of ``(start, end)`` pairs; both ends are inclusive.

    Returns:
        dict mapping ``(start, end)`` to the rows with ``start <= col <= end``,
        in the order the intervals were given.
    """
    ordered = data.sort_values(by=col)
    values = ordered[col]
    return {
        (start, end): ordered.loc[(values >= start) & (values <= end)]
        for start, end in bins
    }


def norm(src, min, max):
    """Rescale `src` linearly so that `min` maps to 0 and `max` maps to 1."""
    offset = src - min
    span = max - min
    return offset / span


def add_colorbar(fig, vmin, vmax, label='', cmap='magma'):
    """Attach a stand-alone colorbar to the right-hand side of `fig`.

    The subplots are confined to the left 80% of the figure and a new axes at
    ``[1.01, 0.05, 0.025, 0.9]`` (figure coordinates) hosts the colorbar.  The
    bar is not tied to any plotted artist: it only shows `cmap` over the range
    ``[vmin, vmax]`` with 11 evenly spaced ticks and `label` as its title.
    """
    fig.subplots_adjust(right=0.8)
    colorbar_axes = fig.add_axes([1.01, 0.05, 0.025, 0.9])
    colorbar_axes.set_title(label)

    mappable = plt.cm.ScalarMappable(
        cmap=cmap,
        norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax),
    )
    mappable.set_array([])

    tick_positions = np.linspace(vmin, vmax, 11)
    fig.colorbar(mappable, cax=colorbar_axes, ticks=tick_positions)
    

def _count_lines(lines):
    """Debug helper: print how many entries of `lines` start with the value 3.057."""
    matching = [entry for entry in lines if entry[0] == 3.057]
    print(len(matching))


def _bin_frame(frame, col, spec, spec_type):
    """Bin ``frame`` along ``col``: equal-count bins for an int ``spec``, explicit ranges otherwise."""
    if isinstance(spec, int):
        # 'n' -> ``spec`` is the number of bins, anything else -> the bin size.
        if spec_type == 'n':
            return define_bins(frame, col, None, spec)
        return define_bins(frame, col, spec, None)
    return split_bins(frame, col, spec)


def metric2d(
        data, zcol, xcol, metric=None, bins=(1, 1), bins_types=('n', 'n'),
        ax=None, figsize=(10, 10), title=''):
    """Plot ``metric`` as a function of ``xcol``, one line per bin of ``zcol``.

    Parameters
    ----------
    data : pd.DataFrame
        Source rows; rows with NaN in ``zcol`` or ``xcol`` are ignored.
    zcol : str
        Column whose bins become separate lines (and legend entries).
    xcol : str
        Column binned along the horizontal axis inside every line.
    metric : callable
        Called with the DataFrame of one bin, returns the plotted value.
    bins : tuple
        ``(line_bins, x_bins)``. An int is passed to ``define_bins``;
        anything else is treated as a list of ``(start, end)`` ranges for
        ``split_bins``.
    bins_types : tuple
        For int bins: ``'n'`` means "number of bins", any other value means
        "rows per bin".
    ax : matplotlib axes, optional
        Axes to draw on. When omitted a new figure of ``figsize`` is created,
        shown and closed.
    title : str
        Axes title; left untouched when empty.

    Returns
    -------
    dict
        ``{line_bin: {x_bin: metric value}}``.

    Raises
    ------
    ValueError
        If ``metric`` is not provided.
    """
    if metric is None:
        raise ValueError("metric must be a callable taking the DataFrame of one bin")

    linestyles = [
        ':', '--', 'dashdot', (0, (3,1,1,1,1,1)), '-',
    ]

    line_col, point_col = zcol, xcol
    valid = data[line_col].notna() & data[point_col].notna()

    # First level: one group of rows per line.
    line_groups = _bin_frame(data.loc[valid], line_col, bins[0], bins_types[0])
    print(line_groups.keys())

    # Second level: split every line along the horizontal axis.
    lines = {
        line_key: _bin_frame(line_data, point_col, bins[1], bins_types[1])
        for line_key, line_data in line_groups.items()
    }

    # calculate metrics
    metric_values = {
        line_key: {
            bin_key: metric(bin_data) for bin_key, bin_data in line_bins.items()
        } for line_key, line_bins in lines.items()
    }

    # define if need to create axis
    ax_given = ax is not None
    if not ax_given:
        fig, ax = plt.subplots(figsize=figsize)

    # draw one line per first-level bin; points sit at the centre of each x bin
    for i, (line_key, line_metrics) in enumerate(metric_values.items()):
        centred = {np.mean(bin_key): value for bin_key, value in line_metrics.items()}
        ax.plot(list(centred.keys()), list(centred.values()),
                c=f'C{i}', ls=linestyles[i % len(linestyles)],
                label=f'[{line_key[0]:.3f}, {line_key[1]:.3f})',
                marker='x')

    if title:
        ax.set_title(title)

    ax.legend()
    # if ax was not given as argument, finish chart construction
    if not ax_given:
        fig.tight_layout()
        plt.show()
        plt.close()

    return metric_values

In [None]:
# DR16Q (w/o train): grid of metric-vs-redshift line plots.
# Columns of the grid are the models in `models2show`, rows are the metrics;
# each figure corresponds to one (x column, line column) combination.
# Relies on names defined in other cells: data_all, ab_mag, additional_metric,
# sklearn2singlearg, print_header, models_ids.
data = data_all['DR16Q_v4-wo_01_train']
# Non-finite magnitudes are turned into NaN so the mask below drops them.
data['ls_mag_z'] = ab_mag(data['ls_flux_z']).replace([-np.inf, np.inf], np.nan)
models2show = [19, 22, 35]

# Keep only rows having zspec, ls_mag_z and a prediction of every shown model.
mask = data['zspec'].notna() & data['ls_mag_z'].notna()
for mid in models2show:
    pred_col = f'zoo_x1a{mid}_z_max'
    mask &= data[pred_col].notna()
    
data = data.loc[mask]

# Used below only for its length (number of subplot rows).
raw_metrics = {
    "NMAD": srgpz.metrics.nmad_z,
    "n>0.15": srgpz.metrics.catastrophic_outliers_z,
    "n(dz_norm>0.15 | zConf<0.4)": additional_metric,
}

# Keys starting with '!' are placeholders replaced per model inside the loop
# ('!zphot' -> prediction column, '!zconf' -> confidence column).
xbins = {'zspec': ('spec-z', 500), '!zphot': ('photo-z', 500)}
zbins = {
    '!zconf': [(0, 0.4), (0.4, 0.7), (0.7, 1.0)],
    'ls_mag_z': [(13.3, 19.5), (19.5, 20), (20, 20.5), (20.5, 21.5), (21.5, 24)],
}
zspec_col = 'zspec'

# NOTE(review): the loop target rebinds `xbins` and `zbins` to the per-iteration
# values; after the loop they no longer hold the dicts defined above.
for (xcol, (xcol_name, xbins)), (zcol, zbins) in itertools.product(xbins.items(), zbins.items()):
    print_header(f'{xcol} x {zcol}')
    nrows = len(raw_metrics)
    ncols = len(models2show)
    # The zConf-dependent metric is not drawn when lines are zConf bins.
    if zcol == '!zconf':
        nrows -= 1
    
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(10*ncols, 6*nrows), sharex=True)
    
    for j, mid in enumerate(tqdm(models2show)):
        pred_col = f'zoo_x1a{mid}_z_max'
        conf_col = pred_col + 'Conf'
        
        metrics = {
            "NMAD": sklearn2singlearg(srgpz.metrics.nmad_z, pred_col, zspec_col),
            "n>0.15": sklearn2singlearg(srgpz.metrics.catastrophic_outliers_z, pred_col, zspec_col),
            "n(dz_norm>0.15 | zConf<0.4)": sklearn2singlearg(additional_metric, pred_col, conf_col, zspec_col),
        }
        
        if zcol == '!zconf':
            metrics = {
                "NMAD": metrics["NMAD"],
                "n>0.15": metrics["n>0.15"],
            }
        
        axs[-1][j].set_xlabel(xcol_name)
        axs[0][j].set_title(models_ids[mid])
        for i, (m_name, m_func) in enumerate(metrics.items()):
            axs[i][0].set(ylabel=m_name)
            # Lines are bins of `zcol` (explicit ranges), points are bins of
            # `xcol` with a fixed number of rows per bin ('s').
            metric2d(
                data,
                zcol if zcol != '!zconf' else conf_col,
                xcol if xcol != '!zphot' else pred_col,
                m_func, (zbins, xbins), ('n', 's'), axs[i][j]
            )
            
    # Use a common y-range per row; rows 1 and 2 share the same upper limit.
    ymax_nmad = max([axs[0][i].get_ylim()[1] for i in range(axs.shape[1])])
    ymax_n015 = max([axs[1][i].get_ylim()[1] for i in range(axs.shape[1])])
    if nrows == 3:
        ymax_n015 = max([ymax_n015] + [axs[2][i].get_ylim()[1] for i in range(axs.shape[1])])
    
    for i in range(axs.shape[1]):
        axs[0][i].set_ylim(-0.05*ymax_nmad, 1.05*ymax_nmad)
        axs[1][i].set_ylim(-0.05*ymax_n015, 1.05*ymax_n015)
        if nrows == 3:
            axs[2][i].set_ylim(-0.05*ymax_n015, 1.05*ymax_n015)
    
            
    fig.tight_layout()
    plt.show()
    plt.close()

In [None]:
# Cross-validation sample ('0+1'): same grid of metric line plots as the
# DR16Q cell, with the magnitude taken from `flux_z_min_error` and 2000 rows
# per x bin.
# Relies on names defined in other cells: data_all, ab_mag, additional_metric,
# sklearn2singlearg, print_header, models_ids.
data = data_all['0+1']
# Non-finite magnitudes are turned into NaN so the mask below drops them.
data['ls_mag_z'] = ab_mag(data['flux_z_min_error']).replace([-np.inf, np.inf], np.nan)
models2show = [19, 22, 35]

# Keep only rows having zspec, ls_mag_z and a prediction of every shown model.
mask = data['zspec'].notna() & data['ls_mag_z'].notna()
for mid in models2show:
    pred_col = f'zoo_x1a{mid}_z_max'
    mask &= data[pred_col].notna()
    
data = data.loc[mask]

# Used below only for its length (number of subplot rows).
raw_metrics = {
    "NMAD": srgpz.metrics.nmad_z,
    "n>0.15": srgpz.metrics.catastrophic_outliers_z,
    "n(dz_norm>0.15 | zConf<0.4)": additional_metric,
}

# Keys starting with '!' are placeholders replaced per model inside the loop.
xbins = {'zspec': ('spec-z', 2000), '!zphot': ('photo-z', 2000)}
# zbins = {'!zconf': 5, 'ls_mag_z': 5}
zbins = {
    '!zconf': [(0, 0.4), (0.4, 0.7), (0.7, 1.0)],
    'ls_mag_z': [(13.3, 19.5), (19.5, 20), (20, 20.5), (20.5, 21.5), (21.5, 24)],
}

zspec_col = 'zspec'

# NOTE(review): the loop target rebinds `xbins` and `zbins` to the per-iteration
# values; after the loop they no longer hold the dicts defined above.
for (xcol, (xcol_name, xbins)), (zcol, zbins) in itertools.product(xbins.items(), zbins.items()):
    print_header(f'{xcol} x {zcol}')
    nrows = len(raw_metrics)
    ncols = len(models2show)
    # The zConf-dependent metric is not drawn when lines are zConf bins.
    if zcol == '!zconf':
        nrows -= 1
    
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(10*ncols, 6*nrows), sharex=True)
    
    for j, mid in enumerate(tqdm(models2show)):
        pred_col = f'zoo_x1a{mid}_z_max'
        conf_col = pred_col + 'Conf'
        
        metrics = {
            "NMAD": sklearn2singlearg(srgpz.metrics.nmad_z, pred_col, zspec_col),
            "n>0.15": sklearn2singlearg(srgpz.metrics.catastrophic_outliers_z, pred_col, zspec_col),
            "n(dz_norm>0.15 | zConf<0.4)": sklearn2singlearg(additional_metric, pred_col, conf_col, zspec_col),
        }
        
        if zcol == '!zconf':
            metrics = {
                "NMAD": metrics["NMAD"],
                "n>0.15": metrics["n>0.15"],
            }
        
        axs[-1][j].set_xlabel(xcol_name)
        axs[0][j].set_title(models_ids[mid])
        for i, (m_name, m_func) in enumerate(metrics.items()):
            axs[i][0].set(ylabel=m_name)
            # Lines are bins of `zcol` (explicit ranges), points are bins of
            # `xcol` with a fixed number of rows per bin ('s').
            metric2d(
                data,
                zcol if zcol != '!zconf' else conf_col,
                xcol if xcol != '!zphot' else pred_col,
                m_func, (zbins, xbins), ('n', 's'), axs[i][j]
            )
            
    # Use a common y-range per row; rows 1 and 2 share the same upper limit.
    ymax_nmad = max([axs[0][i].get_ylim()[1] for i in range(axs.shape[1])])
    ymax_n015 = max([axs[1][i].get_ylim()[1] for i in range(axs.shape[1])])
    if nrows == 3:
        ymax_n015 = max([ymax_n015] + [axs[2][i].get_ylim()[1] for i in range(axs.shape[1])])
    
    for i in range(axs.shape[1]):
        axs[0][i].set_ylim(-0.05*ymax_nmad, 1.05*ymax_nmad)
        axs[1][i].set_ylim(-0.05*ymax_n015, 1.05*ymax_n015)
        if nrows == 3:
            axs[2][i].set_ylim(-0.05*ymax_n015, 1.05*ymax_n015)
    
            
    fig.tight_layout()
    plt.show()
    plt.close()

# Stripe 82X scatterplots

In [None]:
# Models shown on the scatterplots: model id -> full panel title.
# The dict order is the panel order.
models_ids = {
    35: 'SDSS + Pan-STARRS + DESI LIS + WISE',
    19: 'Pan-STARRS + WISE',
#     21: 'Pan-STARRS + DESI LIS + WISE',
    22: 'DESI LIS + WISE',
}

# Same models: model id -> abbreviated title.
models_ids_short = {
    35: 'SPDW',
    19: 'PW',
#     21: 'PDW',
    22: 'DW',
}

In [None]:
# Stripe82X: spec-z vs photo-z scatterplots, one panel per model in
# `models_ids`, coloured by zConf.
# Relies on names defined in other cells: data_all, print_header,
# scatter_photo_z, models_ids.
sample = 'stripe82x-a17-table13_ls-base'
sample_name = 'Stripe82X'
data = data_all[sample]

lim = (-0.1, 5.1)

print_header(sample_name, 1)
models_series = 'x1a'

# columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()]
# if sample == 'stripe82x-a17-table13_ls-base':
#     columns_with_nas += ['zph', 'zphML']

# data = data.dropna(subset=columns_with_nas)
# if sample == 'DR16Q_v4-wo_01_train':
#     mask = data['zspec'] < 4
#     data = data.loc[mask]

print(len(data))

zspec_col = 'zspec'

ncols = 3
nrows = len(models_ids)
# if sample == 'stripe82x-a17-table13_ls-base':
#     nrows += 2

# Ceiling division: flooring dropped the last, partially filled row and made
# `ax[i]` run out of axes when the panel count is not a multiple of `ncols`.
nrows = -(-nrows // ncols)

figsize_factor = 10
figsize = (figsize_factor*ncols, figsize_factor*nrows)
# squeeze=False keeps `axs` a 2-D array for any grid size, so `.flat` always works.
fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=figsize,
                        squeeze=False)
ax = axs.flat

# fig.suptitle(f'{sample_name}', fontsize=24, y=1.01)

print(data['SpClass'].value_counts())

for i, mid in enumerate(models_ids):
    pred_col = f'zoo_{models_series}{mid}_z_max'
    z_conf_col = pred_col + 'Conf'
    print(data[pred_col].notna().sum())
    scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[i])
    ax[i].set(title=(models_ids[mid]), xlim=lim, ylim=lim)
    # Only the leftmost panel keeps its y label.
    if i:
        ax[i].set(ylabel='')

# Adds a column to the frame stored in `data_all[sample]` (same object as `data`).
data['PDZbest01'] = data['PDZbest'] / 100
# if sample == 'stripe82x-a17-table13_ls-base':
#     for j, mid in enumerate(['zph', 'zphML'], start=i+1):
#         pred_col = mid
#         z_conf_col = None
#         title = 'Template Model' if pred_col == 'zph' else 'NN Model'
#         if z_conf_col is not None:
#             scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[j],
#     #                        marker='x', s=15)
#                            )
#         else:
#             scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[j],
#                        marker='x', s=15)

#         ax[j].set(title=title, xlim=lim, ylim=lim)

add_colorbar(fig, 0, 1, 'zConf', 'rainbow')
fig.tight_layout()
plt.show()
# fig.savefig(os.path.join(charts_path, f'scatterplots_stripe82x.png'))
plt.close()

In [None]:
# Assemble the sample from per-chunk pickles: every `part-N.features.gz_pkl`
# is joined column-wise with all `part-N.best*.gz_pkl` prediction files, the
# chunks are stacked row-wise, then sample-specific filters are applied.
# The result is left in the global `data`.
# NOTE(review): pd.read_pickle executes arbitrary code from the file; only
# load pickles produced by a trusted pipeline.
for sample in ['stripe82x-a17-table13_ls-base']:
    sample_path = f'/data/victor/graphs4article/{sample}/'
    data = list()
    # sorted(): glob returns files in arbitrary order, which made the row
    # order (and the order of prediction columns) vary between runs.
    for features_file in tqdm(sorted(glob.glob(
                os.path.join(sample_path, 'part-*.features.gz_pkl')
            )), desc=f'Reading files for {sample}'):

        # Raw string with escaped dots: '\d' in a plain literal is an invalid
        # escape sequence and a bare '.' matched any character.
        chunk_number = re.findall(r'^part-(\d*)\.features\.gz_pkl$',
                                  os.path.basename(features_file))[0]

        features = pd.read_pickle(features_file, compression='gzip')
        preds = [
            pd.read_pickle(file, compression='gzip')
            for file in sorted(glob.glob(
                os.path.join(sample_path, f'part-{chunk_number}.best*.gz_pkl')))
        ]
        data_chunk = [features] + preds
        # Column-wise concat needs a unique index; the last duplicate is kept.
        data_chunk = [df.loc[~df.index.duplicated(keep='last')] for df in data_chunk]
        data_chunk = pd.concat(data_chunk, axis=1)
        data.append(data_chunk)

    data = pd.concat(data, axis=0)
    ipd.display(data.head())
    data = data.reset_index(drop=True)
    if sample == 'stripe82x-a17-table13_ls-base':
        data = data.rename(columns={'zsp': 'zspec'})
        df = data
        fl = (df['QF']<2.1)&(df['zspec']>0)&(df['zph']>0)&(df['zphML']>0)
        data = data.loc[fl]

    if sample == 'DR16Q_v4-wo_01_train':
        data['zspec'] = data['Z']
        mask = data['SOURCE_Z'] == b'VI'
        mask &= data['Z_CONF'] == 3
        data = data.loc[mask]

        bins = np.linspace(0, 7, 71)
        col = 'zspec'
        sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

        # Cap every redshift bin at `thres` objects by random subsampling.
        thres = 1800
        dst = []
        for start, end in zip(bins[:-1], bins[1:]):
            chunk = data.loc[(start < data[col]) & (data[col] <= end)]
            if len(chunk) > thres:
                _, chunk = train_test_split(chunk, test_size=thres/len(chunk), random_state=42)

            dst.append(chunk)

        data = pd.concat(dst)
        sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

In [None]:
# Stripe82X: three-panel comparison (template model | combined predictions |
# NN model) on the objects that have all three predictions and a spec-z.
# Uses the global `data` left by the loading cell, plus print_header,
# scatter_photo_z and models_ids from other cells.
sample = 'stripe82x-a17-table13_ls-base'
sample_name = 'Stripe82X'
# data = data_all[sample]

lim = (-0.1, 5.1)

print_header(sample_name, 1)
models_series = 'x1a'

# columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()]
# if sample == 'stripe82x-a17-table13_ls-base':
#     columns_with_nas += ['zph', 'zphML']

# data = data.dropna(subset=columns_with_nas)
# if sample == 'DR16Q_v4-wo_01_train':
#     mask = data['zspec'] < 4
#     data = data.loc[mask]

print(len(data))

zspec_col = 'zspec'

# fig.suptitle(f'{sample_name}', fontsize=24, y=1.01)

print(data['SpClass'].value_counts())
combined_name = 'Combined Predictions'

# NOTE(review): nothing in the body depends on the loop variables, so every
# iteration draws the same figure; kept to preserve the cell's output.
for i, mid in enumerate({a: b for a, b in models_ids.items() if a in [19, 22, 35]}):
    ncols = 3
    nrows = 1

    figsize_factor = 10
    figsize = (figsize_factor*ncols, figsize_factor*nrows)
    fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=figsize)
    ax = axs.flat

    # The combined predictions go to the middle panel.
    i = 1
    pred_col = 'zoo_best-x1a_z_max'
    z_conf_col = None

    # Build the mask BEFORE its first use: previously the scatter call relied
    # on a `mask` left over from another cell or from the previous iteration.
    mask = data[pred_col].notna() & data[zspec_col].notna() & data['zph'].notna() & data['zphML'].notna()

    scatter_photo_z(data.loc[mask], zspec_col, pred_col, z_conf_col, ax=ax[i],
                    marker='x', s=15)

    margs = (data.loc[mask, pred_col], data.loc[mask, zspec_col])
    metrics4title = '\n'
    metrics4title += f'N={mask.sum()} '
    metrics4title += f'NMAD={srgpz.metrics.nmad_z(*margs):.3f} '
    metrics4title += f'n>0.15={srgpz.metrics.catastrophic_outliers_z(*margs):.3f}'

    ax[i].set(title=(combined_name+metrics4title), xlim=lim, ylim=lim)

    data['PDZbest01'] = data['PDZbest'] / 100
    if sample == 'stripe82x-a17-table13_ls-base':
        # Reference models on the outer panels, using the same object mask.
        for j, mid in zip([0, 2], ['zph', 'zphML']):
            pred_col = mid
            z_conf_col = None
            title = 'Template Model' if pred_col == 'zph' else 'NN Model'
            scatter_photo_z(data.loc[mask], zspec_col, pred_col, z_conf_col, ax=ax[j],
                            marker='x', s=15)

#             mask = data[pred_col].notna() & data[zspec_col].notna()
            margs = (data.loc[mask, pred_col], data.loc[mask, zspec_col])
            print(len(margs[0]))
            metrics4title = '\n'
            metrics4title += f'N={mask.sum()} '
            metrics4title += f'NMAD={srgpz.metrics.nmad_z(*margs):.3f} '
            metrics4title += f'n>0.15={srgpz.metrics.catastrophic_outliers_z(*margs):.3f}'

            ax[j].set(title=title+metrics4title, xlim=lim, ylim=lim)

    ax[1].set_ylabel('')
    ax[2].set_ylabel('')
    # add_colorbar(fig, 0, 1, 'zConf', 'rainbow')
    fig.tight_layout()
    plt.show()
    plt.close()

# DR16 and CV scatterplots

In [None]:
# DR16Q (w/o train): spec-z vs photo-z scatterplots, one panel per model in
# `models_ids`, coloured by zConf. Only objects with predictions of all the
# models are shown.
# Relies on names defined in other cells: data_all, print_header,
# scatter_photo_z, models_ids.
sample = 'DR16Q_v4-wo_01_train'
sample_name = 'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)'
lim = (-0.1, 7.1)

data = data_all[sample]
print_header(sample_name, 1)
models_series = 'x1a'

columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()]
if sample == 'stripe82x-a17-table13_ls-base':
    columns_with_nas += ['zph', 'zphML']

data = data.dropna(subset=columns_with_nas)
# if sample == 'DR16Q_v4-wo_01_train':
#     mask = data['zspec'] < 4
#     data = data.loc[mask]

print(len(data))

zspec_col = 'zspec'

ncols = 3
nrows = len(models_ids)
if sample == 'stripe82x-a17-table13_ls-base':
    nrows += 2

# Ceiling division: flooring dropped the last, partially filled row and made
# `ax[i]` run out of axes when the panel count is not a multiple of `ncols`.
nrows = -(-nrows // ncols)

figsize_factor = 10
figsize = (figsize_factor*ncols, figsize_factor*nrows)
# squeeze=False keeps `axs` a 2-D array for any grid size, so `.flat` always works.
fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=figsize,
                        squeeze=False)
ax = axs.flat

# fig.suptitle(f'{sample_name}', fontsize=24, y=1.01)

for i, mid in enumerate(models_ids):
    pred_col = f'zoo_{models_series}{mid}_z_max'
    z_conf_col = pred_col + 'Conf'
    scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[i])
    ax[i].set(title=(models_ids[mid]), xlim=lim, ylim=lim)
    # Only the leftmost panel keeps its y label.
    if i:
        ax[i].set(ylabel='')

if sample == 'stripe82x-a17-table13_ls-base':
    for j, mid in enumerate(['zph', 'zphML'], start=i+1):
        pred_col = mid
        z_conf_col = None
        title = 'Template Model' if pred_col == 'zph' else 'NN Model'
        scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[j],
                       marker='x', s=15)
        ax[j].set(title=title, xlim=lim, ylim=lim)

add_colorbar(fig, 0, 1, 'zConf', 'rainbow')
fig.tight_layout()
plt.show()
plt.close()

In [None]:
# Cross-validation sample ('0+1'): spec-z vs photo-z scatterplots, one panel
# per model in `models_ids`, coloured by zConf. Only objects with predictions
# of all the models are shown.
# Relies on names defined in other cells: data_all, print_header,
# scatter_photo_z, models_ids.
sample = '0+1'
sample_name = 'Cross-Validation'
lim = (-0.1, 7.1)

data = data_all[sample]
print_header(sample_name, 1)
models_series = 'x1a'

columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()]
if sample == 'stripe82x-a17-table13_ls-base':
    columns_with_nas += ['zph', 'zphML']

data = data.dropna(subset=columns_with_nas)
# if sample == 'DR16Q_v4-wo_01_train':
#     mask = data['zspec'] < 4
#     data = data.loc[mask]

print(len(data))

zspec_col = 'zspec'

ncols = 3
nrows = len(models_ids)
if sample == 'stripe82x-a17-table13_ls-base':
    nrows += 2

# Ceiling division: flooring dropped the last, partially filled row and made
# `ax[i]` run out of axes when the panel count is not a multiple of `ncols`.
nrows = -(-nrows // ncols)

figsize_factor = 10
figsize = (figsize_factor*ncols, figsize_factor*nrows)
# squeeze=False keeps `axs` a 2-D array for any grid size, so `.flat` always works.
fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=figsize,
                        squeeze=False)
ax = axs.flat

# fig.suptitle(f'{sample_name}', fontsize=24, y=1.01)

for i, mid in enumerate(models_ids):
    pred_col = f'zoo_{models_series}{mid}_z_max'
    z_conf_col = pred_col + 'Conf'
    scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[i])
    ax[i].set(title=(models_ids[mid]), xlim=lim, ylim=lim)
    # Only the leftmost panel keeps its y label.
    if i:
        ax[i].set(ylabel='')

if sample == 'stripe82x-a17-table13_ls-base':
    for j, mid in enumerate(['zph', 'zphML'], start=i+1):
        pred_col = mid
        z_conf_col = None
        title = 'Template Model' if pred_col == 'zph' else 'NN Model'
        scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[j],
                       marker='x', s=15)
#         ax[j].set(title=title, xlim=lim, ylim=lim)

add_colorbar(fig, 0, 1, 'zConf', 'rainbow')
fig.tight_layout()
plt.show()
plt.close()

In [None]:
# Cross-validation, galaxies only: spec-z vs photo-z scatterplots, one narrow
# panel per model in `models_ids` (x axis limited to (-0.1, 2.6)), coloured by
# zConf. Only objects with predictions of all the models are shown.
# Relies on names defined in other cells: data_all, print_header,
# scatter_photo_z, models_ids, models_ids_short.
sample = '0+1 (GALAXY only)'
sample_name = 'Cross-Validation (GALAXY)'
lim = (-0.1, 7.1)

data = data_all[sample]
print_header(sample_name, 1)
models_series = 'x1a'

columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()]
if sample == 'stripe82x-a17-table13_ls-base':
    columns_with_nas += ['zph', 'zphML']

data = data.dropna(subset=columns_with_nas)
# if sample == 'DR16Q_v4-wo_01_train':
#     mask = data['zspec'] < 4
#     data = data.loc[mask]

print(len(data))

zspec_col = 'zspec'

ncols = 3
nrows = len(models_ids)
if sample == 'stripe82x-a17-table13_ls-base':
    nrows += 2

# Ceiling division: flooring dropped the last, partially filled row and made
# `ax[i]` run out of axes when the panel count is not a multiple of `ncols`.
nrows = -(-nrows // ncols)

figsize_factor = 10
# Panels are narrowed by 2.5/7 relative to the other scatterplot figures.
figsize = (figsize_factor*2.5/7*ncols, figsize_factor*nrows)
# squeeze=False keeps `axs` a 2-D array for any grid size, so `.flat` always works.
fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=figsize,
                        squeeze=False)
ax = axs.flat

fig.suptitle(f'{sample_name}', fontsize=24, y=1.01)

for i, mid in enumerate(models_ids):
    pred_col = f'zoo_{models_series}{mid}_z_max'
    z_conf_col = pred_col + 'Conf'
    scatter_photo_z(data, zspec_col, pred_col, z_conf_col, ax=ax[i])
    ax[i].set(title=models_ids_short[mid], xlim=(-0.1, 2.6), ylim=lim)
    # Only the leftmost panel keeps its y label.
    if i:
        ax[i].set(ylabel='')

add_colorbar(fig, 0, 1, 'zConf', 'rainbow')
fig.tight_layout()
plt.show()
plt.close()

# Stripe82X final table

In [None]:
def argminv(l, dst_val):
    """Return ``(index, value)`` of the element of ``l`` closest to ``dst_val``.

    Closeness is the absolute difference; on ties the earliest element wins.
    ``l`` must be non-empty.
    """
    # l[0] is read first so that an empty sequence fails on indexing.
    gaps = [abs(l[0] - dst_val)]
    gaps.extend(abs(item - dst_val) for item in l[1:])
    closest = min(range(len(gaps)), key=gaps.__getitem__)
    return closest, l[closest]

# Stripe82X: LaTeX table of photo-z metrics (NMAD, outlier fraction,
# C68 - 0.68) per model, for cumulative FSoft bins.
# Relies on names defined in other cells: data_all, ab_mag, print_header,
# kstest, latex_from_table, argminv.
sample_name = {
    'stripe82x-a17-table13_ls-base': 'Stripe82x-A17 sample',
    'vhzqs_test': 'VHzQs test sample',
    'DR16Q_v4-wo_01_train': 'Quasars of DR16q excluding training examples'
}

graph_type = {
    'stripe82x-a17-table13_ls-base': ('table',),
}

bins_for_metrics = {
    'stripe82x-a17-table13_ls-base': {
#         'zspec': ('spec-z', [-99, 0.5, 1, 1.5, 2, 99]),
#         'ls_mag_z': ('z_mag', [-99, 19, 20, 20.5, 21, 23]),
        'FSoft': ('FSoft', [1.0e-14, 4.0e-14, 1.0])
    },
}

# bins = bins_for_metrics['stripe82x-a17-table13_ls-base']['zspec'][1]
# bins_for_metrics['stripe82x-a17-table13_ls-base']['zspec'] = \
#     ('z_{spec}', [(-99, 99)] + list(zip(bins[:-1], bins[1:])))

# bins = bins_for_metrics['stripe82x-a17-table13_ls-base']['decals8tr_z'][1]
# bins_for_metrics['stripe82x-a17-table13_ls-base']['ls_mag_z'] = \
#     ('z_{mag}', list(zip(bins[:-1], bins[1:])))

# Cumulative bins: every FSoft threshold becomes the range (-99, threshold);
# -99 / 99 act as "no lower / no upper limit" markers in the headers below.
bins = bins_for_metrics['stripe82x-a17-table13_ls-base']['FSoft'][1]
bins_for_metrics['stripe82x-a17-table13_ls-base']['FSoft'] = \
    ('FSoft', [(-99, b) for b in bins])

print(bins_for_metrics)

models_used = {
    'stripe82x-a17-table13_ls-base': 'x1a',
}

models_ids = {
    19: 'Pan-STARRS + WISE',
    21: 'Pan-STARRS + DESI LIS + WISE',
    22: 'DESI LIS + WISE',
    35: 'SDSS + Pan-STARRS + DESI LIS + WISE',
}


models_ids_short = {
    19: '\\ref{model:pw}',
    21: '\\ref{model:pdw}',
    22: '\\ref{model:dw}',
    34: '\\ref{model:spw}',
    35: '\\ref{model:spdw}',
}

sample = 'stripe82x-a17-table13_ls-base'
data = data_all[sample]
print_header(sample_name[sample], 1)
models_series = models_used[sample]

columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()] + ['zspec']
if sample == 'stripe82x-a17-table13_ls-base':
    columns_with_nas += ['zph', 'zphML']

# The magnitude was previously written into `data_fold`, a variable that is
# not defined in this cell; it belongs to the frame processed here. The first
# available flux column is used.
for flux_col in ('ls_flux_z', 'flux_z_min_error'):
    if flux_col in data.columns:
        data['ls_mag_z'] = ab_mag(data[flux_col])
        break

print(columns_with_nas)
print(len(data))
data = data.dropna(subset=columns_with_nas)
print(len(data))


def _mark_best(values, best_value):
    """Render values as strings, wrapping those equal to ``best_value`` in LaTeX bold."""
    # '\\textbf': a single backslash would start the TAB escape '\t'.
    return [f'\\textbf{{{x}}}' if x == best_value else str(x) for x in values]


zspec_col = 'zspec'
for gt in graph_type[sample]:
    if gt == 'table':
        total_latex = ''
#             total_tables = {}

        for col, (col_name, bins) in bins_for_metrics[sample].items():
            metric_bins = bins

            total_tables = dict()
            for bin_start, bin_end in metric_bins:
                bin_mask = (data[col] >= bin_start) & (data[col] < bin_end)
                bin_data = data.loc[bin_mask]

                # LaTeX description of the bin; open ends are omitted.
                header = col_name
                if bin_start != -99:
                    header = f'{bin_start} \\leq {header}'
                if bin_end != 99:
                    header = f'{header} < {bin_end}'
                if bin_start == -99 and bin_end == 99:
                    header = 'All'

                print_header(f'{header} ({len(bin_data)} objects)')
                table = defaultdict(list)
                # Rows follow the short (LaTeX reference) names defined in this
                # cell; the loop used to depend on `models_ids_metrics*`, which
                # only exist after a later cell has been run.
                for mid, model_name in models_ids_short.items():
                    pred_col = f'zoo_{models_series}{mid}_z_max'
                    conf_col = pred_col + 'Conf'

                    model_mask = bin_data[pred_col].notna()
#                     print((~model_mask).sum())
                    model_data = bin_data.loc[model_mask]

#                     print(model_data[pred_col].isna().sum(), len(model_data))

                    table['Model'].append(model_name)
                    table['$NMAD$'].append(
                        round(srgpz.metrics.nmad_z(model_data[pred_col], model_data[zspec_col]), 3)
                    )
                    table['$n>0.15$'].append(
                        round(srgpz.metrics.catastrophic_outliers_z(model_data[pred_col], model_data[zspec_col]), 3)
                    )
                    # The *_ci1a_68 / *_ci1b_68 columns are added to the
                    # prediction before being passed to kstest.
                    table['$C_{68} - 0.68$'].append(
                         round(kstest(
                             model_data[pred_col.replace('_z_max', '_ci1a_68')] + model_data[pred_col],
                             model_data[pred_col.replace('_z_max', '_ci1b_68')] + model_data[pred_col],
                             model_data[zspec_col],
                         ) - 0.68, 3)
                    )
#                         table['$n(dz_{norm}>0.15 | zConf<0.4)$'].append(
#                             round(additional_metric(model_data[pred_col], model_data[conf_col], model_data[zspec_col]), 3)
#                         )
#                         table['w/o photometry'].append(
#                             round((~model_mask).sum() / len(model_mask), 3)
#                         )

                if sample == 'stripe82x-a17-table13_ls-base':
                    # Reference photo-z estimates shipped with the catalogue.
                    for pred_col in ['zph', 'zphML']:
                        model_name = 'Template Model' if pred_col == 'zph' else 'NN Model'

                        model_mask = bin_data[pred_col].notna()
                        model_data = bin_data.loc[model_mask]

                        table['Model'].append(model_name)
                        table['$NMAD$'].append(
                            round(srgpz.metrics.nmad_z(model_data[pred_col], model_data[zspec_col]), 3)
                        )
                        table['$n>0.15$'].append(
                            round(srgpz.metrics.catastrophic_outliers_z(model_data[pred_col], model_data[zspec_col]), 3)
                        )
                        table['$C_{68} - 0.68$'].append(
                             round(kstest(
                                 model_data[f'e_{pred_col}'],
                                 model_data[f'E_{pred_col}'],
                                 model_data[zspec_col],
                             ) - 0.68, 3)
                        )
#                             table['$n(dz_{norm}>0.15 | zConf<0.4)$'].append(
#                                 round(additional_metric(model_data[pred_col], model_data[conf_col], model_data[zspec_col]), 3)
#                             )
#                         table['w/o photometry'].append(
#                             round((~model_mask).sum() / len(model_mask), 3)
#                         )

#                     for metric in ['NMAD', 'n>0.15', 'n(dz_norm>0.15 | zConf<0.4)', 'w/o photometry']:
                # Lower is better for these two metrics.
                for metric in ['$NMAD$', '$n>0.15$']:
                    best_value = min(table[metric])
                    table[metric] = _mark_best(table[metric], best_value)

                # For the coverage metric the best value is the one closest to 0.
                metric = '$C_{68} - 0.68$'
                _, best_value = argminv(table[metric], 0)
                table[metric] = _mark_best(table[metric], best_value)


                table = pd.DataFrame.from_dict(table)  # .sort_values(by='$NMAD$')
#                 ipd.display(table)
                title=f"Objects of {sample_name[sample]} sample with ${header}$ ({len(bin_data)} objects)"
#                     total_latex += latex_from_table(table, title, star=True) + '\n\n'
#                     total_tables[f"${header}$"]
                full_header = f'${header}$ ({len(bin_data)} objects)'
                total_tables[full_header] = table

            # Put the per-bin tables side by side under a two-level column header.
            df = pd.DataFrame()
            for i, (header, table) in enumerate(total_tables.items()):
                table = table.set_index('Model')
                table.columns = [[header] * table.shape[1], table.columns]

                df = pd.concat([df, table], axis=1)

            ipd.display(df)
            total_latex += latex_from_table(df, star=True, index=True, newline=9) + '\n\n\n'

        print(total_latex)

    else:
        print('Wrong type of graph:', gt)

    else:
        print('Wrong type of graph:', gt)

# zConf calibration

In [None]:
# Samples used for the zConf distribution plots: sample key -> readable name.
sample_name = {
    'stripe82x-a17-table13_ls-base': 'Stripe82x-A17 sample',
    'DR16Q_v4-wo_01_train': 'Quasars of DR16q excluding training examples'
}

# Models compared on the zConf plots: model id -> legend label.
models_ids = {
    19: 'Pan-STARRS + WISE',
    21: 'Pan-STARRS + DESI LIS + WISE',
    22: 'DESI LIS + WISE',
    35: 'SDSS + Pan-STARRS + DESI LIS + WISE',
}

In [None]:
# For every sample draw the zConf distribution of each model: normalised
# histogram on the left, cumulative histogram on the right.
# NOTE(review): `linestyles[i]` is indexed by model position, so the list must
# have at least as many entries as `models_ids`.
linestyles = ['--', 'dashdot', (0, (2, 1, 1, 1)), '-']
for sample, s_name in sample_name.items():
    data = data_all[sample]
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20,7))
    for i, (mid, m_name) in enumerate(models_ids.items()):
        pred_col = f'zoo_x1a{mid}_z_max'
        conf_col = pred_col + 'Conf'
        # Only the left histogram gets a label so the legend has one entry per model.
        axs[0].hist(data[conf_col], bins=np.linspace(0, 1, 21), label=m_name, histtype='step', linestyle=linestyles[i],
                 density=True)
        axs[1].hist(data[conf_col], bins=np.linspace(0, 1, 21), histtype='step', linestyle=linestyles[i],
                 density=True, cumulative=True)
        
    axs[0].set(title='Density', xlabel='zConf')
    axs[1].set(title='Cumulative', xlabel='zConf')
    fig.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
    fig.tight_layout()
    plt.show()
    plt.close()

# Galaxies

In [None]:
 # Galaxies: metrics (NMAD, outlier fraction) as a function of the binned
 # column, one figure per entry of `bins_for_metrics`, one line per model.
 # Relies on names defined in other cells: data_all, models_ids,
 # sklearn2singlearg, metrics_by_bins, charts_path.
 # NOTE(review): the next line starts with a stray leading space; outside of a
 # notebook this is an "unexpected indent" error.
 bins_for_metrics = {
    'zspec': ('spec-z', 50),
    'decals8tr_z': ('z_mag', 50),
    'decals8tr_r': ('r_mag', 50),
}
zspec_col = 'zspec'

linestyles = [
    ':', '--', 'dashdot', '-',
]

folds = ['0+1 (GALAXY only)']
# NOTE(review): `titles` is referenced only by commented-out code below.
titles = ['2-fold CV (GALAXY Only)', 'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)']

# `*nbins` collects the rest of each tuple into a list (here a single value),
# which is then indexed by the fold position.
for bin_col, (col_name, *nbins) in bins_for_metrics.items():
    nrows=1
    ncols=2
    figsize_factor = 10
    figsize=(10*ncols, 6*nrows)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=False, figsize=figsize)
#                 for i in range(ncols):
#                     axs[-1][i].set(xlabel=col_name)

    
    for i in range(ncols):
        axs[i].set(xlabel=col_name)
        
    # Turn the row of axes into a column so that `axs[i][j]` addresses
    # metric i / fold j.
    axs = axs.reshape(-1, 1)

    legend = dict()
    for j, fold in enumerate(folds):
        # NOTE(review): `ax` is not defined in this cell; this branch is not
        # reached with the current `folds`.
        if fold == 'stripe82x-a17-table13_ls-base':
            ax[j].axis('off')
            continue
            
        data_fold = data_all[fold]
        # Keep objects with a non-negative zspec and predictions of all models.
        notna_mask = (data_fold['zspec'] >= 0)
        for mid in models_ids:
            pred_col = f'zoo_x1a{mid}_z_max'
            notna_mask &= data_fold[pred_col].notna()
        
        for midx, (mid, model_name) in enumerate(tqdm(models_ids.items())):
            pred_col = f'zoo_x1a{mid}_z_max'
            conf_col = pred_col + 'Conf'

            metrics = {
                "NMAD": sklearn2singlearg(srgpz.metrics.nmad_z, pred_col, zspec_col),
    #                         "w/o photometry": sklearn2singlearg(wo_phot, pred_col),
                "n>0.15": sklearn2singlearg(srgpz.metrics.catastrophic_outliers_z, pred_col, zspec_col),
#                 "n(dz_norm>0.15 | zConf<0.4)": sklearn2singlearg(additional_metric, pred_col, conf_col, zspec_col),
            }

            bins = nbins[j]
            for i, (m_name, m) in enumerate(metrics.items()):
                if m_name == "w/o photometry":
                    continue

                # `bins` is replaced by the value returned from metrics_by_bins
                # and passed on to the next metric of the same model.
                l, _, _, bins = metrics_by_bins(data_fold.loc[notna_mask], bin_col,
                                                m, bins, axs[i][j], linestyle=linestyles[midx])
                legend[model_name] = l
            
                axs[i][0].set_ylabel(m_name)
#             ax[j].grid(axis='y', which='minor')
            
#             break
        
#         axs[0][j].set_title(titles[j])
        
        
#     ax[0].set(ylim=(0.001, 0.5), yscale='log')    
#     ax[0].yaxis.set_minor_formatter(FormatStrFormatter("%.3f"))

    # Start both y axes at 0 and add 5% headroom above the automatic limit.
    ylim_nmad = (0, 1.05*axs[0][0].get_ylim()[1])
    axs[0][0].set_ylim(ylim_nmad)
    
    ylim_n015 = (0, 1.05*axs[1][0].get_ylim()[1])
    axs[1][0].set_ylim(ylim_n015)

    fig.legend(legend.values(), legend.keys(), loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
    fig.tight_layout()
    plt.show()
    fig.savefig(os.path.join(charts_path, f'metrics_cv+dr16_{bin_col}.png'))
    plt.close()
#     break

# Общая таблица со всеми метриками

In [None]:
def argminv(l, dst_val):
    """Find the element of ``l`` that lies closest to ``dst_val``.

    Returns a ``(index, value)`` pair. When several elements are equally
    close, the one with the smallest index is returned.
    """
    # Index the head explicitly so an empty sequence still fails on l[0].
    gaps = [abs(l[0] - dst_val)]
    gaps.extend(abs(item - dst_val) for item in l[1:])
    # min() keeps the first of several equal minima, like a strict '<' scan.
    closest = min(range(len(gaps)), key=gaps.__getitem__)
    return closest, l[closest]

# data_all key -> sample name printed in the table headers.
sample_name = {
    'stripe82x-a17-table13_ls-base': 'Stripe82x-A17 sample',
#     'vhzqs_test': 'VHzQs test sample',
    'DR16Q_v4-wo_01_train': 'Quasars of DR16q',
    '0+1': "Cross-Validation",
#     '0+1 (GALAXY only)', "Cross-Validation (Galaxies)"
}

# Outputs to produce for each driving sample.
graph_type = {
    'stripe82x-a17-table13_ls-base': ('table',),
}

# Binning column -> (label, bin edges); -99 / 99 stand for an open lower / upper end.
bins_for_metrics = {
    'stripe82x-a17-table13_ls-base': {
        'zspec': ('spec-z', [-99, 0.5, 1, 1.5, 2, 99]),
        'ls_mag_z': ('z_mag', [-99, 19, 20, 20.5, 21, 23]),
#         'FSoft': ('FSoft', [1.0e-14, 4.0e-14, 1.0])
    },
}

# Turn the edge lists into (start, end) pairs; spec-z additionally gets an
# all-inclusive (-99, 99) bin in front.
bins = bins_for_metrics['stripe82x-a17-table13_ls-base']['zspec'][1]
bins_for_metrics['stripe82x-a17-table13_ls-base']['zspec'] = \
    ('z_{spec}', [(-99, 99)] + list(zip(bins[:-1], bins[1:])))

# The magnitude edges are stored under 'ls_mag_z'; the previous lookup of
# 'decals8tr_z' raised KeyError because no such entry is defined above.
bins = bins_for_metrics['stripe82x-a17-table13_ls-base']['ls_mag_z'][1]
bins_for_metrics['stripe82x-a17-table13_ls-base']['ls_mag_z'] = \
    ('z_{mag}', list(zip(bins[:-1], bins[1:])))

# bins = bins_for_metrics['stripe82x-a17-table13_ls-base']['FSoft'][1]
# bins_for_metrics['stripe82x-a17-table13_ls-base']['FSoft'] = \
#     ('FSoft', [(-99, b) for b in bins])

print(bins_for_metrics)

# Prediction-column family of each sample: prediction columns are named
# f'zoo_{series}{model_id}_z_max'.
models_used = {
    'stripe82x-a17-table13_ls-base': 'x1a',
}

# Model id -> surveys whose photometry the model uses.
models_ids_metrics = {
    19: 'Pan-STARRS + WISE',
    21: 'Pan-STARRS + DESI LIS + WISE',
    22: 'DESI LIS + WISE',
    34: 'SDSS + Pan-STARRS + WISE',
    35: 'SDSS + Pan-STARRS + DESI LIS + WISE',
}

# Model id -> LaTeX cross-reference printed in the "Model" column.
models_ids_metrics_short = {
    19: '\\ref{model:pw}',
    21: '\\ref{model:pdw}',
    22: '\\ref{model:dw}',
    34: '\\ref{model:spw}',
    35: '\\ref{model:spdw}',
}

sample = 'stripe82x-a17-table13_ls-base'
data = data_all[sample]
print_header(sample_name[sample], 1)
models_series = models_used[sample]

columns_with_nas = [f'zoo_{models_series}{mid}_z_max' for mid in models_ids.keys()] + ['zspec']
if sample == 'stripe82x-a17-table13_ls-base':
    columns_with_nas += ['zph', 'zphML']

print(columns_with_nas)
print(len(data))
# NOTE(review): this filtered frame only feeds the row count printed below;
# the table loop re-reads every sample from data_all.
data = data.dropna(subset=columns_with_nas)
print(len(data))

zspec_col = 'zspec'
for gt in graph_type[sample]:
    if gt == 'table':
        # One LaTeX table per binning column; rows are (bin, model) pairs and
        # the columns hold the metrics of every sample in sample_name.
        for col, (col_name, bins) in bins_for_metrics[sample].items():
            total_table = pd.DataFrame()

            for bin_start, bin_end in bins:
                table = defaultdict(list)
                # NOTE: this loop rebinds the module-level ``sample``; the outer
                # iterables were already evaluated, so iteration is unaffected.
                for sidx, (sample, sname) in enumerate(sample_name.items()):
                    data = data_all[sample]

                    # The magnitude has to be written into the frame that is
                    # binned below; it used to be written into ``data_fold``,
                    # a leftover of a previous cell. The flux column name
                    # differs between samples.
                    try:
                        data['ls_mag_z'] = ab_mag(data['ls_flux_z'])
                    except KeyError:
                        data['ls_mag_z'] = ab_mag(data['flux_z_min_error'])

                    bin_mask = (data[col] >= bin_start) & (data[col] < bin_end)
                    bin_data = data.loc[bin_mask]

                    # Row label: -99 / 99 mark open ends and are not printed.
                    header = col_name
                    if bin_start != -99:
                        header = f'{bin_start} \\leq {header}'
                    if bin_end != 99:
                        header = f'{header} < {bin_end}'
                    if bin_start == -99 and bin_end == 99:
                        header = 'All'

                    calibration_calculated = True
                    print_header(f'{header} ({sname} {len(bin_data)} objects)')
                    for mid in models_ids_metrics:
                        pred_col = f'zoo_{models_series}{mid}_z_max'
                        ci_low_col = pred_col.replace('_z_max', '_ci1a_68')
                        ci_high_col = pred_col.replace('_z_max', '_ci1b_68')
                        # Once a model lacks the interval column, the calibration
                        # metric is skipped for the rest of this sample.
                        if ci_low_col not in bin_data.columns:
                            calibration_calculated = False

                        # Metrics use only the objects this model has a prediction for.
                        model_data = bin_data.loc[bin_data[pred_col].notna()]

                        model_name = models_ids_metrics_short[mid]
                        if sidx == 0:
                            # The (bin, model) index column is filled once, by the first sample.
                            table[('', 'Model')].append((f'${header}$', model_name))

                        table[(sname, '$NMAD$')].append(
                            round(srgpz.metrics.nmad_z(model_data[pred_col], model_data[zspec_col]), 3)
                        )
                        table[(sname, '$n>0.15$')].append(
                            round(srgpz.metrics.catastrophic_outliers_z(model_data[pred_col], model_data[zspec_col]), 3)
                        )
                        if calibration_calculated:
                            table[(sname, '$C_{68} - 0.68$')].append(
                                round(kstest(
                                    model_data[ci_low_col] + model_data[pred_col],
                                    model_data[ci_high_col] + model_data[pred_col],
                                    model_data[zspec_col],
                                ) - 0.68, 3)
                            )

                    # Bold the best model of this bin and sample: the smallest
                    # NMAD / outlier fraction ...
                    for metric in ['$NMAD$', '$n>0.15$']:
                        best_value = min(table[(sname, metric)])
                        table[(sname, metric)] = [
                            f'\\textbf{{{x}}}' if x == best_value else str(x)
                            for x in table[(sname, metric)]
                        ]

                    # ... and the calibration value closest to zero.
                    if calibration_calculated:
                        metric = '$C_{68} - 0.68$'
                        _, best_value = argminv(table[(sname, metric)], 0)
                        table[(sname, metric)] = [
                            f'\\textbf{{{x}}}' if x == best_value else str(x)
                            for x in table[(sname, metric)]
                        ]

                table = pd.DataFrame.from_dict(table).set_index(('', 'Model'))
                total_table = pd.concat([total_table, table], sort=False)

            total_table.index = pd.MultiIndex.from_tuples(total_table.index, names=("subsample", 'model'))

            print(latex_from_table(total_table, star=True, index=True, newline=9,
                                   title=f"Metrics by ${col}$", hline_every=len(models_ids_metrics),
                                   hline_every_shift=0) + '\n\n\n')

    else:
        print('Wrong type of graph:', gt)

# Data section

In [None]:
# Show the distinct values of the 'class' column in the '0+1' sample.
data_all['0+1']['class'].unique()

In [None]:
# Spec-z versus z-band magnitude for every sample. Top row: 2-D histograms of
# the train galaxies, the train quasars and DR16q. Bottom row: scatter plots
# of the three Stripe82X FSoft intervals.
xcol = 'zspec'
ycol = 'decals8tr_z'
xname = 'spec-z'
yname = 'z_mag (DESI LIS)'

xbins = np.linspace(0, 7, 71)
ybins = np.linspace(12, 24, 71)

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(33, 20), sharex=True, sharey=True)
ax = axs.flat

# Hollow orange circles for the galaxy markers.
kwargs = {
    'edgecolor': "C1",
    'linewidth': 1.0,
    'linestyle': '-',
    'facecolors': 'none',
}

# (flat axes index, FSoft lower bound, FSoft upper bound or None, title)
stripe_panels = [
    (3, 3.0e-15, 1.0e-14, 'Stripe82X (3.0e-15 < FSoft < 1.0e-14)'),
    (4, 1.0e-14, 4.0e-14, 'Stripe82X (1.0e-14 < FSoft < 4.0e-14)'),
    (5, 4.0e-14, None, 'Stripe82X (FSoft > 4.0e-14)'),
]
for ii, fx_low, fx_high, panel_title in stripe_panels:
    df = data_all['stripe82x-a17-table13_ls-base']
    fsoft_mask = df['FSoft'] > fx_low
    if fx_high is not None:
        fsoft_mask = fsoft_mask & (df['FSoft'] <= fx_high)

    mask = df['SpClass'].isin(['QSO', 'QSO(BA', 'QSO_BAL']) & fsoft_mask
    ax[ii].scatter(df.loc[mask, xcol], ab_mag(df.loc[mask, 'ls_flux_z']), label='QSO', marker='x')

    mask = df['SpClass'].isin(['GALAXY', 'AGN']) & fsoft_mask
    ax[ii].scatter(df.loc[mask, xcol], ab_mag(df.loc[mask, 'ls_flux_z']), label='Galaxy', marker='o', **kwargs)
    ax[ii].legend()
    ax[ii].set_title(panel_title)

    # Hidden side axes take the place the colorbars occupy in the top row.
    divider = make_axes_locatable(ax[ii])
    cax = divider.append_axes("right", size="5%", pad=0.1)
    cax.axis('off')

# (flat axes index, data_all key, class filter or None, flux column, colormap, title)
hist_panels = [
    (2, 'DR16Q_v4-wo_01_train', None, 'ls_flux_z', 'Blues', 'DR16q test sample'),
    (0, '0+1', 'GALAXY', 'flux_z_min_error', 'Oranges', 'Train sample (GALAXY)'),
    (1, '0+1', 'QSO', 'flux_z_min_error', 'Blues', 'Train sample (QSO)'),
]
for ii, sample_key, obj_class, flux_col, cmap_name, panel_title in hist_panels:
    df = data_all[sample_key]
    mask = df[xcol].notna() & df[ycol].notna()
    if obj_class is not None:
        mask = df['class'].isin([obj_class]) & mask

    h = ax[ii].hist2d(df.loc[mask, xcol], ab_mag(df.loc[mask, flux_col]),
                      bins=[xbins, ybins], cmap=cmap_name, zorder=0.9,
                      norm=mpl.colors.LogNorm())
    divider = make_axes_locatable(ax[ii])
    cax = divider.append_axes("right", size="5%", pad=0.1)
    fig.colorbar(h[3], cax=cax)

    ax[ii].set_title(panel_title)
    ax[ii].grid()
    ax[ii].set_axisbelow(False)

# Axis labels only on the left column and the bottom row.
for row in axs:
    row[0].set(ylabel=yname)
for bottom_axes in axs[1]:
    bottom_axes.set(xlabel=xname)

# Small margins around the binned range; the y axis is reversed by passing
# its limits in descending order.
dx = 0.12
dy = 0.07
axs[0][0].set(xlim=(0 - dx, 7 + dx), ylim=(24 + dy, 12 - dy))

fig.tight_layout()
plt.show()
plt.close()

In [None]:
# Cut the Stripe82X sample into consecutive FSoft intervals and register each
# part in data_all under the key 'stripe82x_{low}<Fx<{high}'.
fxs = [3e-15, 1e-14, 4e-14, 1]
data = data_all['stripe82x-a17-table13_ls-base']
for fx_min, fx_max in zip(fxs, fxs[1:]):
    sname = f'stripe82x_{fx_min}<Fx<{fx_max}'
    print(sname)
    # The lower edge belongs to the interval, the upper edge does not.
    in_interval = (data['FSoft'] >= fx_min) & (data['FSoft'] < fx_max)
    data_all[sname] = data.loc[in_interval]

In [None]:
def _photometry_mask(df, bands, primary, fallback):
    """Flag the rows of ``df`` that have every listed photometry column filled.

    ``primary`` and ``fallback`` are column-name templates with a ``{pb}``
    placeholder for the band. For each band the primary names are tried first;
    the first missing column (KeyError) switches that band to the fallback names.
    """
    mask = True
    for pb in bands:
        try:
            for template in primary:
                mask &= df[template.format(pb=pb)].notna()
        except KeyError:
            for template in fallback:
                mask &= df[template.format(pb=pb)].notna()
    return mask


table = defaultdict(list)

# data_all key -> row label of the table. Raw strings keep the LaTeX
# backslashes literal without relying on invalid escape sequences.
samples = {
    '0+1': 'Train',
    'stripe82x-a17-table13_ls-base': 'Stripe82X',
    'stripe82x_3e-15<Fx<1e-14': r'Stripe82X ($\expnumber{3}{-15} <= FSoft < \expnumber{1}{-14}$)',
    'stripe82x_1e-14<Fx<4e-14': r'Stripe82X ($\expnumber{1}{-14} <= FSoft < \expnumber{4}{-14}$)',
    'stripe82x_4e-14<Fx<1': r'Stripe82X ($FSoft > \expnumber{4}{-14}$)',
    'DR16Q_v4-wo_01_train': 'DR16q w/o train',
}

for s, sname in samples.items():
    print(s)
    df = data_all[s]

    # Only objects with a spectroscopic redshift are counted. Work on a copy
    # so the helper columns added below never touch the frames in data_all.
    mask = df['zspec'].notna()
    df = df.loc[mask].copy()

    # Every Stripe82X-derived frame (including the FSoft sub-samples) takes its
    # class from 'SpClass'; previously only the base key did, so the
    # sub-samples relied on an earlier cell having added 'class' already.
    if s.startswith('stripe82x'):
        df['class'] = df['SpClass']
    if s == 'DR16Q_v4-wo_01_train':
        df['class'] = 'QSO'
    if s == '0+1':
        df['ls_brickid'] = df['brickid']
        df['ps_objID'] = df['objID_psdr2']
        df['sdss_objID'] = df['objID_sdssdr16']

    mask_sdss = _photometry_mask(
        df, 'ugriz',
        primary=('sdss_psfFlux_{pb}', 'sdss_psfFluxIvar_{pb}',
                 'sdss_cModelFlux_{pb}', 'sdss_cModelFluxIvar_{pb}'),
        fallback=('psfFlux_{pb}', 'psfFluxIvar_{pb}',
                  'cModelFlux_{pb}', 'cModelFluxIvar_{pb}'),
    )
    mask_ps = _photometry_mask(
        df, 'grizy',
        primary=('ps_{pb}PSFFluxErr', 'ps_{pb}PSFFlux'),
        fallback=('{pb}PSFFluxErr_min_error', '{pb}PSFFlux_min_error'),
    )
    mask_ls = _photometry_mask(
        df, ['g', 'r', 'z', 'w1', 'w2'],
        primary=('ls_flux_ivar_{pb}', 'ls_flux_{pb}'),
        fallback=('flux_ivar_{pb}_min_error', 'flux_{pb}_min_error'),
    )

    table[('No. sources', 'Total')].append(len(df))
    table[('No. sources', 'Galaxy')].append(df['class'].isin(['GALAXY', 'AGN']).sum())
    table[('No. sources', 'QSO')].append(df['class'].isin(['QSO', 'QSO(BA', 'QSO_BAL']).sum())
    # The Pan-STARRS and SDSS counts are restricted to objects that also have
    # DESI LIS photometry.
    table[('With photometry from', 'DESI LIS')].append(f"{mask_ls.sum()}")
    table[('With photometry from', 'Pan-STARRS')].append(f"{(mask_ps & mask_ls).sum()}")
    table[('With photometry from', 'SDSS')].append(f"{(mask_sdss & mask_ls).sum()}")
    table[('With photometry from', 'All 3 surveys')].append(f"{(mask_sdss & mask_ps & mask_ls).sum()}")

# Build with one row per quantity, then transpose to one row per sample.
table = pd.DataFrame.from_dict(table, orient='index', columns=samples.values())
table = table.T
table.columns = pd.MultiIndex.from_tuples(table.columns, names=['', 'Sample'])
ipd.display(table)
print(latex_from_table(table, title='Data', index=True, star=True))

In [None]:
# Samples used in the data-description figures, with their display names.
_sample_labels = {
    '0+1': 'Train',
    'stripe82x-a17-table13_ls-base': 'Stripe82X',
    'DR16Q_v4-wo_01_train': 'DR16q w/o train',
}
samples = list(_sample_labels)
samples_names = list(_sample_labels.values())
dict(zip(samples, samples_names))

In [None]:
# Magnitude distributions of quasars and galaxies: one row of panels per
# sample, one column per band (g, z, w1).
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(21, 14))
plt.tick_params(axis='y', which='minor')

samples = ['0+1', 'stripe82x-a17-table13_ls-base', 'DR16Q_v4-wo_01_train']
samples_names = ['Train', 'Stripe82X', 'DR16q w/o train']
mag_bins = np.linspace(9, 32, 47)
# (class values, line style, colour, legend label) for each histogram curve
class_styles = [
    (['QSO', 'QSO(BA', 'QSO_BAL'], '-', 'C0', 'QSO'),
    (['GALAXY', 'AGN'], '--', 'C1', 'GALAXY'),
]

for i, s in enumerate(samples):
    df = data_all[s]
    print(s, len(df))
    # Bring the column names of the three samples to a common scheme.
    if s == 'stripe82x-a17-table13_ls-base':
        df['class'] = df['SpClass']
    if s == 'DR16Q_v4-wo_01_train':
        df['class'] = 'QSO'
    if s == '0+1':
        df['ls_brickid'] = df['brickid']
        df['ps_objID'] = df['objID_psdr2']
        df['sdss_objID'] = df['objID_sdssdr16']

    df['decals8tr_w1'] = df['decals8tr_Lw1']

    for j, pb in enumerate(['g', 'z', 'w1']):
        # The flux column is named differently depending on the sample.
        try:
            mags = df[f'ls_flux_{pb}']
        except KeyError:
            mags = df[f'flux_{pb}_min_error']
        mags = ab_mag(mags)

        panel = ax[i][j]
        for members, line_style, colour, curve_label in class_styles:
            mask = df['class'].isin(members)
            panel.hist(mags.loc[mask], mag_bins,
                       histtype='step', linestyle=line_style, color=colour, label=curve_label)

        panel.set(yscale='log')
        # Only the bottom row keeps x tick labels and gets an x label.
        if i < ax.shape[0] - 1:
            panel.get_xaxis().set_ticklabels([])
        else:
            panel.set_xlabel(f'{pb}_mag')

for i, row in enumerate(ax):
    # Give all panels of a row the same y range.
    lower_limits, upper_limits = zip(*(a.get_ylim() for a in row))
    ylim = (min(lower_limits), max(upper_limits))
    for j, a in enumerate(row):
        a.grid(ls=':')
        a.set_ylim(ylim)
        a.invert_xaxis()
        if j == 0:
            a.set_ylabel(f'No. sources\n{samples_names[i]}')
        else:
            a.get_yaxis().set_ticklabels([])

ax[0][2].legend()
fig.tight_layout()
plt.show()
plt.close()

In [None]:
# Maximum value of every magnitude column in each sample, as one LaTeX table
# per survey. The tables are collected in total_latex and printed at the end.
total_latex = ''
samples = ['0+1', 'stripe82x-a17-table13_ls-base', 'DR16Q_v4-wo_01_train']

# Bring the column names of the three samples to a common scheme (done once;
# the assignments modify the frames stored in data_all).
for s in samples:
    df = data_all[s]
    if s == 'stripe82x-a17-table13_ls-base':
        df['class'] = df['SpClass']
    if s == 'DR16Q_v4-wo_01_train':
        df['class'] = 'QSO'
    if s == '0+1':
        df['ls_brickid'] = df['brickid']
        df['ps_objID'] = df['objID_psdr2']
        df['sdss_objID'] = df['objID_sdssdr16']

# (row label, column name) pairs of every table, in output order.
sdss_columns = []
for pb in 'ugriz':
    sdss_columns.append((f'${pb}_{{model}}$', f'sdssdr16_{pb}_cmodel'))
    sdss_columns.append((f'${pb}_{{psf}}$', f'sdssdr16_{pb}_psf'))

ps_columns = []
for pb in 'grizy':
    ps_columns.append((f'${pb}_{{kron}}$', f'psdr2_{pb}_kron'))
    ps_columns.append((f'${pb}_{{psf}}$', f'psdr2_{pb}_psf'))

ls_columns = [(f'${pb}$', f'decals8tr_{pb}') for pb in ['g', 'r', 'z', 'Lw1', 'Lw2']]

# (LaTeX title, printed header, columns).
# NOTE(review): the first two tables differ only in their title; both are
# kept so the printed LaTeX stays the same - confirm whether 'Data' is needed.
table_specs = [
    ('Data', 'sdss table', sdss_columns),
    ('Data SDSS', 'sdss table', sdss_columns),
    ('Data PS', 'ps table', ps_columns),   # used to be announced as 'sdss table'
    ('Data LS', 'ls table', ls_columns),   # used to be announced as 'sdss table'
]

for table_title, header_text, column_pairs in table_specs:
    table = defaultdict(list)
    for s in samples:
        print(s)
        df = data_all[s]
        for row_label, column in column_pairs:
            table[row_label].append(f"{df[column].max():.3f}")

    # Build with one row per magnitude, then transpose to one row per sample.
    table = pd.DataFrame.from_dict(table, orient='index',
                                   columns=['Train sample', 'Stripe82X', 'DR16q test'])
    table = table.T
    print_header(header_text)
    ipd.display(table)
    total_latex += latex_from_table(table, title=table_title, index=True, star=True, newline=10) + '\n\n'

print(total_latex)

### Распределение объектов трейна до и после обрезания пиков

In [None]:
# train_new: the '0+1' sample already loaded in data_all.
# train_old: an earlier training sample, read from a gzip-compressed pickle.
train_new = data_all['0+1']

# NOTE: unpickling executes code stored in the file; load trusted files only.
train_old_path = (
    '/data/victor/srgz_models/data/'
    'train_QSO_XbalancedGALAXY-sdss_unwise-wo_3XMM_XXLN_S82X_LH-asinhmags.gz_pkl'
)
train_old = pd.read_pickle(train_old_path, compression='gzip')

In [None]:
# Remove repeated objects: first by sky position, then by object ID.
train_old = (
    train_old
    .drop_duplicates(subset=['ra', 'dec'])
    .drop_duplicates(subset=['objID'])
)

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition

# Spec-z histograms of train_old ("before cut") and train_new ("after cut"),
# with an inset that zooms in on the high-redshift tail.
bins = np.linspace(0, 7, 71)

# Redshift columns of the populations that are drawn.
qso_before = train_old.loc[train_old['class'] == 'QSO', 'zspec']
qso_after = train_new.loc[(train_new['class'] == 'QSO') & (train_new['origin'] != 'VHzQs'), 'zspec']
galaxy_before = train_old.loc[train_old['class'] == 'GALAXY', 'zspec']
galaxy_after = train_new.loc[train_new['class'] == 'GALAXY', 'zspec']
vhzqs = train_new.loc[train_new['origin'] == 'VHzQs', 'zspec']

# (values, legend label, line style, colour)
main_curves = [
    (qso_before, 'QSO before cut', '--', 'C0'),
    (qso_after, 'QSO after cut', '-', 'C0'),
    (galaxy_before, 'GALAXY before cut', '--', 'C1'),
    (galaxy_after, 'GALAXY after cut', '-', 'C1'),
    (vhzqs, 'VHzQs objects', '-', 'C2'),
]
zoom_curves = [
    (qso_before, 'QSO before cut', '--', 'C0'),
    (qso_after, 'QSO after cut', '-', 'C0'),
    (vhzqs, 'VHzQs objects', '-', 'C2'),
]

fig, ax = plt.subplots(figsize=(12, 7))
for values, curve_label, line_style, colour in main_curves:
    ax.hist(values, bins, histtype='step', label=curve_label, linestyle=line_style, color=colour)

ax.set_xlabel('spec-z')
ax.set_ylabel('No. sources')
ax.set_xlim(-0.1, 5.1)
ax.legend(ncol=3)
ax.grid(ls=':')

# The inset is placed in coordinates relative to the main axes.
zoom_ax = fig.add_axes([0, 0, 1, 1])
zoom_ax.set_axes_locator(InsetPosition(ax, [0.6, 0.48, 0.39, 0.35]))
for values, curve_label, line_style, colour in zoom_curves:
    zoom_ax.hist(values, bins, histtype='step', label=curve_label, linestyle=line_style, color=colour)

zoom_ax.set_xlim(4.4, 7.1)
zoom_ax.set_ylim(0, 150)
zoom_ax.grid(ls=':')
zoom_ax.set_xticks([4.5, 5.0, 5.5, 6.0, 6.5, 7.0])

fig.tight_layout()
plt.show()
plt.close()

# Rodion VS LIS

In [None]:
# Load the DR16q chunks (features plus all prediction files), keep the
# visually inspected redshifts and flatten the spec-z distribution.
sample = 'DR16Q_v4-wo_01_train'

sample_path = f'/data/victor/graphs4article/{sample}/buf/'
# Raw string with escaped dots: '\d' is not a valid escape in a normal string
# literal, and a bare '.' would match any character.
part_pattern = re.compile(r'^part-(\d*)\.features\.gz_pkl$')

chunks = []
for features_file in tqdm(glob.glob(
            os.path.join(sample_path, 'part-*.features.gz_pkl')
        ), desc=f'Reading files for {sample}'):

    chunk_number = part_pattern.findall(os.path.basename(features_file))[0]

    features = pd.read_pickle(features_file, compression='gzip')
    preds = [
        pd.read_pickle(file, compression='gzip')
        for file in glob.glob(os.path.join(sample_path, f'part-{chunk_number}.preds.*.gz_pkl'))
    ]
    data_chunk = [features] + preds
    # A repeated index label would break the column-wise concat; keep the last row.
    data_chunk = [df.loc[~df.index.duplicated(keep='last')] for df in data_chunk]
    data_chunk = pd.concat(data_chunk, axis=1)
    chunks.append(data_chunk)

data = pd.concat(chunks, axis=0)
data = data.reset_index(drop=True)

# Keep only rows with SOURCE_Z == b'VI' and Z_CONF == 3.
data['zspec'] = data['Z']
mask = data['SOURCE_Z'] == b'VI'
mask &= data['Z_CONF'] == 3
data = data.loc[mask]

bins = np.linspace(0, 7, 71)
col = 'zspec'
sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

# Cap every redshift bin at roughly `thres` objects by drawing a seeded random
# subset (the "test" part of the split) from the over-populated bins.
thres = 1800
dst = []
for start, end in zip(bins[:-1], bins[1:]):
    chunk = data.loc[(start < data[col]) & (data[col] <= end)]
    if len(chunk) > thres:
        _, chunk = train_test_split(chunk, test_size=thres/len(chunk), random_state=42)

    dst.append(chunk)

data = pd.concat(dst)
sns.distplot(data[col], kde=False, norm_hist=False, bins=bins)

In [None]:
# Models compared in this section: (model id, full name, short label).
# Disabled pair: '34' / 'pswf34' - 'SDSS + Pan-STARRS + WISE (DESI LIS)' /
# 'SDSS + Pan-STARRS + WISE (Rodion)', short labels 'SPW (LS)' / 'SPW (R)'.
_compared_models = (
    ('19', 'Pan-STARRS + WISE (DESI LIS)', 'PW (LS)'),
    ('pswf19', 'Pan-STARRS + WISE (Rodion)', 'PW (R)'),
)

models_ids_metrics = {mid: full_name for mid, full_name, _ in _compared_models}
models_ids_metrics_short = {mid: short_name for mid, _, short_name in _compared_models}

In [None]:
# Precomputed list of increasing bin edges; the last edge (100) is a catch-all.
# NOTE(review): not referenced anywhere in the visible part of this cell - presumably
# z-magnitude edges for the cross-validation sample; confirm before removing.
z_bins_cv2 = [9.827747557554806, 17.13419637985013, 17.357575120155055, 17.579100052699232, 17.758957867548197, 17.9082410475928, 18.036726268982886, 18.14583221651932, 18.244829930859414, 18.33652587079903, 18.419562410135384, 18.497557442776014, 18.571027900479034, 18.63781685183387, 18.698689421928663, 18.75542677143206, 18.8081912015041, 18.858127363223694, 18.904346386330406, 18.94794603593956, 18.9903881814915, 19.030046545247764, 19.06933391365944, 19.107757793091295, 19.144645646657338, 19.18189048105735, 19.2200396950378, 19.25739167256188, 19.29418073475197, 19.33192841343464, 19.369341545553247, 19.407212891772218, 19.445693442126636, 19.484164485968897, 19.523543927297332, 19.562803704409053, 19.6023351933546, 19.641194873381693, 19.679350762407818, 19.71798610806622, 19.755406653369835, 19.792903139172306, 19.830040899017312, 19.866648903584892, 19.90247699130969, 19.937493972774234, 19.971504133988486, 20.005410267649168, 20.039742180266142, 20.07425463440665, 20.10839592003747, 20.141836451816314, 20.175237760638385, 20.2082612279237, 20.240862190008166, 20.27363939970968, 20.307048266961374, 20.34001160666969, 20.373117967497404, 20.405324775585694, 20.43856996861692, 20.471458673635357, 20.504278503251264, 20.537111504619375, 20.57064919178948, 20.60408997540508, 20.63729785458057, 20.671598715674975, 20.70444011595687, 20.738151163020987, 20.77278935592166, 20.807208164029582, 20.842199924655635, 20.87736454554745, 20.912391661079493, 20.948799924949082, 20.985947290720336, 21.02350321886055, 21.061469305612608, 21.101228507751074, 21.140685720613337, 21.181512448109945, 21.224174977841194, 21.268813618713573, 21.314540729039173, 21.36350242052781, 21.415683003174635, 21.471887882382692, 21.533218758421445, 21.600345073306542, 21.676911313133786, 21.770678738521514, 21.893164416324517, 22.08596910146424, 100]
# Metric-versus-bin curves for the models in models_ids_metrics: one column of
# panels per binning quantity, one row per metric.

# Binning column -> (axis label, bin specification for each fold).
# '!zphot' means "bin by the model's own prediction column".
bins_for_metrics = {
    'zspec': ('spec-z', 50),
    '!zphot': ('photo-z', 50),
    'ls_mag_z': ('z_mag', 50),
#     'decals8tr_r': ('r_mag', 50, 100, 50),
}
zspec_col = 'zspec'

linestyles = [
    '--', ':', 'dashdot', (0, (3,1,1,1,1,1)), '-',
]

folds = [
    'DR16Q_v4-wo_01_train'
]
titles = [
    'DR16q w/o train (SOURCE_Z = VI and Z_CONF = 3)'
]

for fold_idx, fold in enumerate(folds):
    print_header(fold)
    nrows = 3
    ncols = len(bins_for_metrics)
    figsize = (10*ncols, 6*nrows)
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

    # The frame prepared in the loading cell is used for every fold.
    data_fold = data
    # The z-band flux column is named differently depending on the sample.
    try:
        data_fold['ls_mag_z'] = ab_mag(data_fold['ls_flux_z'])
    except KeyError:
        data_fold['ls_mag_z'] = ab_mag(data_fold['flux_z_min_error'])

    # Compare the models on one common set of objects.
    # NOTE(review): the mask is built from ``models_ids`` (defined in an earlier
    # cell), not from the ``models_ids_metrics`` plotted below - confirm that
    # this is intended.
    notna_mask = (data_fold['zspec'] >= 0)
    for mid in models_ids:
        pred_col = f'zoo_x1a{mid}_z_max'
        notna_mask &= data_fold[pred_col].notna()

    legend = dict()
    for j, (bin_col, (col_name, *nbins)) in enumerate(bins_for_metrics.items()):
        if bin_col == 'ls_mag_z' and fold == '0+1':
            override_first_point = 17
        else:
            override_first_point = None

        for midx, (mid, model_name) in enumerate(tqdm(models_ids_metrics.items())):
            pred_col = f'zoo_x1a{mid}_z_max'
            conf_col = pred_col + 'Conf'
            if pred_col not in data_fold.columns:
                continue

            metrics = {
                "NMAD": sklearn2singlearg(srgpz.metrics.nmad_z, pred_col, zspec_col),
                "n>0.15": sklearn2singlearg(srgpz.metrics.catastrophic_outliers_z, pred_col, zspec_col),
                "n(dz_norm>0.15 | zConf<0.4)": sklearn2singlearg(additional_metric, pred_col, conf_col, zspec_col),
            }

            # Every model starts from the configured bins; the value returned
            # by metrics_by_bins is passed on to the model's remaining metrics.
            bins = nbins[fold_idx]
            for i, (m_name, m) in enumerate(metrics.items()):
                l, _, _, bins = metrics_by_bins(data_fold.loc[notna_mask], bin_col if bin_col != '!zphot' else pred_col,
                                                m, bins, axs[i][j], linestyle=linestyles[midx],
                                                color=f'C{midx}', override_first_point=override_first_point)
                legend[model_name] = l

                axs[i][0].set_ylabel(m_name)

        axs[-1][j].set_xlabel(col_name)

    # Shared y range for the NMAD row ...
    ylim_nmad = (0, 1.05*max([axs[0][i].get_ylim()[1] for i in range(axs.shape[1])]))

    # ... and for the two outlier-fraction rows. The maximum is taken over
    # exactly the panels the limit is applied to (rows 1.., every column); it
    # used to skip the first column and to include the NMAD row.
    ylim_n015 = (0, 1.05*max(
        [axs[i][j].get_ylim()[1] for i, j in itertools.product(range(1, axs.shape[0]), range(axs.shape[1]))]
    ))
    for j in range(axs.shape[1]):
        axs[0][j].set_ylim(ylim_nmad)
        for i in range(1, axs.shape[0]):
            axs[i][j].set_ylim(ylim_n015)

    # Third column (ls_mag_z): reverse the x axis.
    for i in range(axs.shape[0]):
        axs[i][2].invert_xaxis()

    # Tick labels only on the bottom row and the left column.
    for i in range(axs.shape[0]-1):
        for j in range(axs.shape[1]):
            axs[i][j].get_xaxis().set_ticklabels([])

    for i in range(axs.shape[0]):
        for j in range(1, axs.shape[1]):
            axs[i][j].get_yaxis().set_ticklabels([])

    fig.legend(legend.values(), legend.keys(), loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
    fig.tight_layout()
    plt.show()
    # NOTE(review): ``bin_col`` is the last binning column here, so the file is
    # 'metrics_cv+dr16_ls_mag_z.png'; the cross-validation metrics cell saves
    # with the same name pattern - check that this figure does not overwrite it.
    fig.savefig(os.path.join(charts_path, f'metrics_cv+dr16_{bin_col}.png'))
    plt.close()

In [None]:
# Model 19 ('zoo_x1a19_z_max') against its 'pswf19' variant ('zoo_x1apswf19_z_max').
# scatter_photo_z is defined elsewhere in the notebook - presumably a scatter plot; verify there.
scatter_photo_z(data, 'zoo_x1a19_z_max', 'zoo_x1apswf19_z_max')

In [None]:
# Spec-z against the 'pswf19' prediction; the third argument names its '..._z_maxConf' column
# (presumably used to colour the points - confirm in scatter_photo_z).
scatter_photo_z(data, 'zspec', 'zoo_x1apswf19_z_max', 'zoo_x1apswf19_z_maxConf')

In [None]:
# Spec-z against the model 19 prediction; the third argument names its '..._z_maxConf' column
# (presumably used to colour the points - confirm in scatter_photo_z).
scatter_photo_z(data, 'zspec', 'zoo_x1a19_z_max', 'zoo_x1a19_z_maxConf')