In [None]:
import numpy as np
import matplotlib.pyplot as plt

from itertools import combinations
from collections import Counter, defaultdict
from astropy.table import Table, join, vstack
from astropy.coordinates import SkyCoord, Distance
from SAGA.utils import add_skycoord
from easyquery import Query, QueryMaker
from astropy.io import ascii

import SAGA
from SAGA.database import FitsTable, GoogleSheets
from SAGA import ObjectCuts as C
from SAGA.utils import add_skycoord, fill_values_by_query
from SAGA.utils.distance import z2v, d2m, m2d
from SAGA.utils.display import show_images
from SAGA.objects.object_catalog import calc_fiducial_p_sat, calc_fiducial_p_sat_corrected
print(SAGA.__version__)  # log the installed SAGA package version alongside the notebook output

In [None]:
# SAGA QuickStart helper; `saga.database` and `saga.host_catalog` are used by later cells.
saga = SAGA.QuickStart()
# Use the "paper2" base-catalog version by default (presumably for all saga.* lookups below — confirm).
saga.set_default_base_version("paper2")

In [None]:
# Local Group galaxies from the McConnachie (2012) compilation.
lg = SAGA.database.FitsTable("../data/McConnachie2012.fit").read()
# Mr is taken as vmag_lc - 0.4 (a fixed V-to-r offset; vmag_lc assumed to be an absolute V mag — confirm).
lg["Mr"] = lg["vmag_lc"] - 0.4

# Reassign Pegasus dIrr to the M31 subgroup. Look the row up explicitly: np.argmax on an
# all-False mask returns 0, which would silently relabel the first galaxy instead.
pegasus_idx = np.flatnonzero(lg["Name"] == "Pegasus dIrr")
assert len(pegasus_idx) == 1, "expected exactly one 'Pegasus dIrr' row in McConnachie2012"
lg["SubG"][pegasus_idx[0]] = "M31 "

# Keep rows with a finite vmag_lc, drop Canis Major, and keep only Mr < -10.
lg = (Query((np.isfinite, "vmag_lc"), ~QueryMaker.equals("Name", "Canis Major"), "Mr < -10")).filter(lg)

# Ascending-sorted Mr of MW and M31 satellites (SubG codes appear to be space-padded to 4 chars).
mw_sat_mr = np.asarray(np.sort(QueryMaker.equals("SubG", "MW  ").filter(lg, "Mr")))
m31_sat_mr = np.asarray(np.sort(QueryMaker.equals("SubG", "M31 ").filter(lg, "Mr")))

In [None]:
# Overwrite the brightest MW entries of the (ascending-sorted) mw_sat_mr array:
# slots 0-3 take the four brightest Mr_o values from the Google Sheet, slot 4 is Leo I.
mw = SAGA.database.GoogleSheets("1O8tGgnHXRcAT8P78J3V2pWVb3CWE1Hvcowily79tdg4", 0).read()
brightest_mr_o = np.sort(mw["Mr_o"])[:4]
mw_sat_mr[:4] = brightest_mr_o
mw_sat_mr[4] = -12.3  # Leo I

In [None]:
def annotate_base(base):
    """Add derived columns to a SAGA base table (modified in place) and return it.

    Columns written:
      - ``Mr_``: ``r_mag`` minus ``d2m(HOST_DIST)``.
      - ``Mr``: the existing ``Mr`` where it is finite and ``C.has_spec`` holds,
        otherwise ``Mr_``.
      - ``log_sm``: the existing ``log_sm`` where finite, otherwise
        ``1.254 + 1.0976 * gr - 0.4 * Mr`` (using the ``Mr`` just updated above).
      - ``human_selected``: sum of ``score`` (1, 2 or 3) over the rows of the
        ``human_selected`` table in the global ``saga`` database whose OBJID matches;
        0 for objects that do not appear there.

    Parameters
    ----------
    base : table (astropy Table, presumably)
        Must provide r_mag, HOST_DIST, Mr, log_sm, gr and OBJID columns plus whatever
        ``C.has_spec`` reads.

    Returns
    -------
    The same ``base`` object, with the columns above added or overwritten.
    """
    base["Mr_"] = base["r_mag"] - d2m(base["HOST_DIST"])
    base["Mr"] = np.where(np.isfinite(base["Mr"]) & C.has_spec.mask(base), base["Mr"], base["Mr_"])
    base["log_sm"] = np.where(np.isfinite(base["log_sm"]), base["log_sm"], 1.254 + 1.0976 * base["gr"] - 0.4 * base["Mr"])

    base["human_selected"] = 0
    hs = saga.database["human_selected"].read()
    for score in range(1, 4):
        scored_ids = Query(f"score == {score}").filter(hs, "OBJID")
        # Builtin `int` instead of `np.int` (alias removed in NumPy 1.24);
        # `np.isin` instead of `np.in1d` (deprecated in NumPy 2.0).
        base["human_selected"] += np.isin(base["OBJID"], scored_ids).astype(int) * score

    return base

In [None]:
BASE_CATALOG_PATH = "/home/yymao/Documents/Research/SAGA/PaperII/data-archive/saga_base_all.fits"

# Full base catalog, restricted to the "paper2_complete" host sample (expected: 36 hosts).
base = SAGA.database.FitsTable(BASE_CATALOG_PATH).read()
host_cut = saga.host_catalog.construct_host_query("paper2_complete")
base = host_cut.filter(base)
n_hosts = len(np.unique(base["HOSTID"]))
assert n_hosts == 36

# Add derived columns, then take the confirmed-satellite subset.
base = annotate_base(base)
sats = C.is_sat.filter(base)

In [None]:
# Expected number of satellites among objects without a spectrum, rounded up:
# the summed p_sat_corrected down to the faint-end limit, and down to r_abs_limit.
nsat_missing_all, nsat_missing_limit = (
    int(np.ceil(Query(~C.has_spec, mag_cut).filter(base, "p_sat_corrected").sum()))
    for mag_cut in (C.faint_end_limit, C.r_abs_limit)
)
print(nsat_missing_all, nsat_missing_limit)

In [None]:
# Confirmed satellites, plus every object without a spectrum whose p_sat_corrected is nonzero.
sats = C.is_sat.filter(base)
no_spec_candidates = Query(~C.has_spec, "p_sat_corrected > 0")
sats_maybe_all = no_spec_candidates.filter(base)

In [None]:
# Draw one random realization of the satellites among objects without a spectrum:
# an object is kept when its p_sat_corrected exceeds a seeded uniform random number.
# Seeds are tried in increasing order from 1177 until the draw matches the expected totals.
seed = 1177  # earlier candidates: 923, 444, 177, 114, 47
while True:
    sats_maybe = Query(
        (lambda p: p > np.random.RandomState(seed).rand(len(p)), "p_sat_corrected"),  # must put in 1st place
        ~C.has_spec,
    ).filter(base)

    # Accept the seed when
    #   (a) the count down to C.faint_end_limit is within 1.5 of nsat_missing_all,
    #   (b) the count down to C.r_abs_limit equals nsat_missing_limit exactly, and
    #   (c) every drawn object passes C.high_priority_sb_tight and C.gr_cut_tight and has Mr > -15.5.
    # Bug fix: np.abs() used to wrap `(diff < 1.5 and count)`, so (a) never took |diff|
    # (any shortfall passed) and, when nsat_missing_limit == 0, np.abs(False) == 0 accepted
    # seeds that failed (a).
    if (np.abs(C.faint_end_limit.count(sats_maybe) - nsat_missing_all) < 1.5
            and C.r_abs_limit.count(sats_maybe) == nsat_missing_limit
            and (~Query(C.high_priority_sb_tight, C.gr_cut_tight, "Mr > -15.5")).count(sats_maybe) == 0):
        print(seed)
        break
    seed += 1

print(C.r_abs_limit.count(sats_maybe), len(sats_maybe))

In [None]:
# Host rows: NSA-sourced objects with SATS == 3 (presumably the host-galaxy flag — confirm).
host_rows_query = Query(QueryMaker.equals("survey", "NSA"), "SATS == 3")
t = host_rows_query.filter(base, ["Mr", "HOST_MK", "HOSTID", "log_sm"])
print(len(t))

# Left: host Mr against HOST_MK with the reference line Mr = HOST_MK + 2.5.
# Right: histogram of host log_sm.
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
ax_mag, ax_sm = ax
mk_ends = [-24.6, -23]
ax_mag.scatter(t["HOST_MK"], t["Mr"])
ax_mag.plot(mk_ends, np.array(mk_ends) + 2.5, c="C1")
ax_sm.hist(t["log_sm"]);

In [None]:
hostlist = saga.host_catalog.load(include_stats=True, query="paper2_complete")["HOSTID", "sats_Mr_limit", "sats_total", "COMMON_NAME", "K_ABS"]
# Attach each host's own row from `t` (Mr, HOST_MK, log_sm); hosts without one get missing values.
hostlist = join(hostlist, t, "HOSTID", "left")
# Hosts without a row in `t` fall back to Mr = K_ABS + 2.5.
# np.ma.getmaskarray works whether or not join() produced a masked column: recent astropy
# only returns a MaskedColumn when some rows are unmatched, and a plain Column has no `.mask`.
mr_missing = np.ma.getmaskarray(hostlist["Mr"])
hostlist["Mr"] = np.where(mr_missing, hostlist["K_ABS"] + 2.5, hostlist["Mr"])
print("completeness limit =", np.median(-12.3 - (hostlist["Mr"])))

# Per-host plotting colours keyed by HOSTID. Read eagerly into a dict (same key order)
# so the .npz file handle is closed immediately instead of staying open for the session.
with np.load("/home/yymao/Documents/Research/SAGA/PaperII/data-archive/color_dict.npz") as color_npz:
    color_dict = dict(color_npz.items())

In [None]:
# Cumulative satellite luminosity functions, stored per host as step-function
# coordinates (every x and y value duplicated) ready for ax.plot / fill_between.
lf_data = dict()
# Grid of magnitude gaps (satellite Mr - host Mr) used for the 'dm' / 'dm_cor' counts.
dm_bins = np.linspace(0, 11, 15)

for base_this in base.group_by("HOSTID").groups:
    host_this = base_this["HOSTID"][0]
    host_mag = hostlist[hostlist["HOSTID"] == host_this]["Mr"][0]
    # Rows with p_sat_corrected > 0 that are confirmed satellites or lack a spectrum.
    base_this = Query("p_sat_corrected > 0", C.is_sat | (~C.has_spec)).filter(base_this, ["Mr", "p_sat_corrected", "SATS"])
    lf_data_this = dict()
    # Sorted by Mr so cumsum gives N(< Mr) and searchsorted below sees sorted input.
    base_this.sort('Mr')
    # x: each Mr duplicated, padded with -23 on the bright end and -9 on the faint end.
    lf_data_this['lf_x'] = np.append(np.insert(np.vstack([base_this['Mr']]*2).T.flatten(), 0, -23.0), -9.0)
    # Expected count: running sum of p_sat_corrected.
    lf_data_this['lf_y_est'] = np.vstack([np.insert(np.cumsum(base_this['p_sat_corrected']), 0, 0.0)]*2).T.flatten()
    # Confirmed count. Builtin `int` replaces `np.int`, which was removed in NumPy 1.24.
    lf_data_this['lf_y'] = np.vstack([np.insert(np.cumsum(C.is_sat.mask(base_this).astype(int)), 0, 0.0)]*2).T.flatten()
    # Number of confirmed satellites with (Mr - host_mag) < each dm_bins edge.
    lf_data_this['dm'] = np.searchsorted(C.is_sat.filter(base_this, "Mr") - host_mag, dm_bins)
    # Summed p_sat_corrected over all kept rows with (Mr - host_mag) < each dm_bins edge.
    lf_data_this['dm_cor'] = np.insert(np.cumsum(base_this['p_sat_corrected']),0,0)[np.searchsorted(base_this['Mr'] - host_mag, dm_bins)]
    lf_data[host_this] = lf_data_this

# MW and M31: every listed satellite counts as one confirmed satellite.
for host_this, mr in zip(("MW", "M31"), (mw_sat_mr, m31_sat_mr)):
    lf_data_this = dict()
    lf_data_this['lf_x'] = np.append(np.insert(np.vstack([mr]*2).T.flatten(), 0, -23.0), -9.0)
    lf_data_this['lf_y'] = np.vstack([np.insert(np.arange(1, len(mr)+1), 0, 0.0)]*2).T.flatten()
    lf_data[host_this] = lf_data_this

In [None]:
# Quenched fraction vs. log stellar mass for MW+M31 satellites (drawn in C1 and labelled
# "Wetzel+15" in the figure). fq_ue / fq_le are the upper / lower bounds of the shaded band.
fq_lg = ascii.read(format="fast_tab", names=["log_sm", "fq", "fq_ue", "fq_le"], data_start=0,
    table="""3.49	1.00	1.00	0.77
4.49	1.00	1.00	0.90
5.49	0.94	0.96	0.83
6.49	1.00	1.00	0.77
7.49	0.80	0.88	0.54
8.49	0.66	0.81	0.38
9.18	0.00	0.60	0.00
""")

# Field-galaxy quenched fraction curve (drawn as a C4 band). Each of fq, fq_ue and fq_le
# comes with its own x column (log_sm0/1/2 — presumably digitized separately; confirm).
fq_tinker = ascii.read(format="fast_tab", names=["log_sm0", "fq", "log_sm1",  "fq_ue", "log_sm2", "fq_le"], data_start=0,
    table="""9.7632	0.0625	9.7655	0.0455	9.7633	0.0788
9.8511	0.1149	9.8510	0.0986	9.8511	0.1312
9.9459	0.1340	9.9459	0.1170	9.9459	0.1503
10.0479	0.1835	10.0479	0.1637	10.0480	0.2019
10.1477	0.2514	10.1476	0.2295	10.1477	0.2734
10.2449	0.2790	10.2472	0.2563	10.2473	0.3010
10.3494	0.3618	10.3493	0.3363	10.3495	0.3859
10.4468	0.4433	10.4467	0.4142	10.4445	0.4709
10.5488	0.4666	10.5487	0.4340	10.5489	0.4985
10.6485	0.5381	10.6484	0.5033	10.6486	0.5728
10.7458	0.5812	10.7481	0.5437	10.7483	0.6188
10.8503	0.6471	10.8501	0.5982	10.8504	0.6938
10.9523	0.7079	10.9522	0.6526	10.9525	0.7632
11.0495	0.7249	11.0493	0.6575	11.0521	0.7929
11.1397	0.7517	11.1394	0.6695	11.1399	0.8346
11.2347	0.8211	11.2319	0.6943	11.2327	0.9472
""")
# Common x for the band: the mean of the three per-column x values.
fq_tinker["log_sm"] = (fq_tinker["log_sm0"] + fq_tinker["log_sm1"] + fq_tinker["log_sm2"]) / 3

# Field-galaxy quenched fraction (labelled "Geha+12" in the figure) with a symmetric error fq_e.
fq_geha = ascii.read(format="fast_tab", names=["log_sm", "fq","fq_e"], data_start=0,
    table="""9.9	0.084	0.0052
9.7	0.031	0.0032
9.5	0.017	0.0023
9.3	0.007	0.0016
9.1	0.002	0.0014
8.9	0.000	0.0019
8.7	0.000	0.0028
8.5	0.000	0.0043
8.3	0.000	0.0065
8.1	0.000	0.0091
7.8	0.000	0.0093
7.3	0.000	0.0193
""")
# Band bounds fq +/- fq_e; the lower bound is clipped at 0.
fq_geha["fq_ue"] = fq_geha["fq"] + fq_geha["fq_e"]
fq_geha["fq_le"] = np.maximum(0, fq_geha["fq"] - fq_geha["fq_e"])

In [None]:
fig, ax = plt.subplots(ncols=2, figsize=(10, 4.3))

# Left panel: cumulative satellite luminosity function N_sat(< M_r) per host.
ax_this = ax[0]

# SAGA hosts: the solid line counts confirmed satellites (lf_y); the shaded band extends to the
# expected count including candidates without spectra (lf_y_est = running sum of p_sat_corrected).
for hostid, color in color_dict.items():
    host_this = hostlist[np.argmax(hostlist["HOSTID"] == hostid)]
    lf_data_this = lf_data[hostid]
    ax_this.plot(lf_data_this['lf_x'], lf_data_this['lf_y'], lw=1.5, c=color, label=host_this['COMMON_NAME'], alpha=0.95, rasterized=True)
    ax_this.fill_between(lf_data_this['lf_x'], lf_data_this['lf_y'], lf_data_this['lf_y_est'], lw=0, color=color, alpha=0.25, rasterized=True)

# MW and M31 drawn as thin black reference lines.
for hostid, ls in zip(("MW", "M31"), ('--', '-.')): 
    lf_data_this = lf_data[hostid]
    ax_this.plot(lf_data_this['lf_x'], lf_data_this['lf_y'], lw=1, c="k", label=hostid, ls=ls)
    
# Legend layout: 3 columns with nrow1 rows each. Column 1 holds nrow1 entries, column 2
# nrow2 real entries, column 3 the remaining nrow3. Invisible blank entries pad columns 2 and 3
# to nrow1 rows (relies on matplotlib filling legend columns top-to-bottom). The total of
# len(hostlist) + 2 assumes color_dict has one key per host plus the MW/M31 lines — confirm.
nrow1, nrow2 = 18, 13
nrow3 = len(hostlist) + 2 - nrow1 - nrow2
(lines, labels) = ax_this.get_legend_handles_labels()
l = plt.Line2D([0], [0], linestyle='none', marker='')
for _ in range(nrow1-nrow2):
    lines.insert(nrow1 + nrow2, l)
    labels.insert(nrow1 + nrow2, "")
for _ in range(nrow1-nrow3):
    lines.append(l)
    labels.append("")

ax_this.legend(lines, labels, fontsize=10, ncol=3, loc='upper left', handlelength=0.9, handletextpad=0.5, columnspacing=0.6, borderpad=0.1, labelspacing=0.25)
# Grey band over Mr in [-12.3, -10]; -12.3 is the same value used in the "completeness limit" printout.
ax_this.axvspan(-12.3, -10, color='k', alpha=0.15, rasterized=True, zorder=99)
ax_this.set_xlim(-22, -10)
ax_this.set_ylim(-0.1, 13)
ax_this.set_xlabel(r'$M_{r, o}$  [mag]')
ax_this.set_ylabel(r'$N_{\rm sat}(<M_{r, o})$')
ax_this.yaxis.set_ticks_position('both')

# Right panel: quenched fraction vs. stellar mass.
ax_this = ax[1]

# Field-galaxy comparison bands (both in C4; labelled "Geha+12" below).
ax_this.fill_between(fq_tinker["log_sm"], fq_tinker["fq_le"], fq_tinker["fq_ue"], color="C4", alpha=0.7, lw=0, rasterized=True)
ax_this.fill_between(fq_geha["log_sm"], fq_geha["fq_le"], fq_geha["fq_ue"], color="C4", alpha=0.7, rasterized=True)

# MW+M31 satellites: markers plus a shaded band.
ax_this.errorbar(fq_lg["log_sm"], fq_lg["fq"], marker='x', ls="", c="C1", ms=6)
ax_this.fill_between(fq_lg["log_sm"], fq_lg["fq_le"], fq_lg["fq_ue"], color="C1", alpha=0.25, lw=0, rasterized=True)

# SAGA satellites: 7 stellar-mass bins with equal numbers of confirmed satellites (edges at
# 8 evenly spaced percentiles); each point sits at the percentile midway through its bin.
pbins = np.linspace(0, 100, 8)
log_sm = np.concatenate([sats['log_sm']])
bins = np.percentile(log_sm, pbins)
x = np.percentile(log_sm, (pbins[1:]+pbins[:-1])*0.5)


def interloper_corr(d):
    """Per-object weight 0.95 - 0.001 * d, with d = RHOST_KPC.

    Presumably the fraction of objects at that projected distance that are true
    satellites rather than interlopers — confirm.
    """
    return 0.95 - d*0.001


# Quenched = EW_Halpha < 2.
sats_q = Query('EW_Halpha < 2').filter(sats, ["log_sm", "RHOST_KPC"])
n_q = np.histogram(sats_q['log_sm'], bins)[0]
# All candidates without spectra (no EW cut applied) are added to the weighted quenched count.
sats_maybe_q = Query().filter(sats_maybe_all, ["log_sm", "RHOST_KPC", "p_sat_corrected"])
n_q_corr = np.histogram(sats_q['log_sm'], bins, weights=interloper_corr(sats_q["RHOST_KPC"]))[0]
n_q_corr += np.histogram(sats_maybe_q['log_sm'], bins, weights=sats_maybe_q["p_sat_corrected"]*interloper_corr(sats_maybe_q["RHOST_KPC"]))[0]

n_all = np.histogram(sats['log_sm'], bins)[0]
n_all_corr = np.histogram(sats['log_sm'], bins, weights=interloper_corr(sats["RHOST_KPC"]))[0]
n_all_corr += np.histogram(sats_maybe_all['log_sm'], bins, weights=sats_maybe_all["p_sat_corrected"]*interloper_corr(sats_maybe_all["RHOST_KPC"]))[0]

# Observed fraction p with a binomial error computed from the smoothed fraction (n_q+1)/(n_all+2),
# clipped to [0, 1]. p_corr is how far the weighted, candidate-inclusive fraction exceeds p.
p = n_q/n_all
pb = (n_q+1)/(n_all+2)
perr = np.sqrt(pb*(1-pb) / n_all)
p1 = np.minimum(1, p + perr)
p2 = np.maximum(0, p - perr)
p_corr = np.maximum(0, (n_q_corr/n_all_corr) - p)

ax_this.errorbar(x, p, yerr=(p-p2, p1-p), ls='', marker='o', lw=2, capsize=2, ms=5, c="C2")
# Fainter, upward-only error bar extended by p_corr.
ax_this.errorbar(x, p, yerr=(np.zeros_like(p), p1-p+p_corr), ls='', marker='', alpha=0.6, c="C2", lw=2, rasterized=True)

# Previously written as `"This work" and "Mao+20"`, which always evaluates to "Mao+20".
this_work_label = "Mao+20"
ax_this.set_xlim(5.3, 11)
ax_this.text(6.75, 0.2, f"SAGA sat.\n({this_work_label})", ha="right", va="bottom", color="C2", fontsize=13, fontweight="bold")
ax_this.text(7.2, 0.91, "MW+M31 sat. (Wetzel+15)", ha="left", va="bottom", color="C1", fontsize=13, fontweight="bold")
ax_this.text(10.7, 0.6, "Field gal.\n(Geha+12)", ha="right", va="bottom", color="C4", fontsize=13, fontweight="bold")
# Raw string: "\;" is an invalid Python escape (SyntaxWarning on 3.12+); the label text is unchanged.
ax_this.set_xlabel(r"$\log\;[M_*/M_\odot]$")
ax_this.set_ylabel("Quenched fraction")
ax_this.set_ylim(-0.05, 1.05)
ax_this.axhline(0, lw=0.5, color="grey")
ax_this.axhline(1, lw=0.5, color="grey")

fig.tight_layout()
fig.savefig('/home/yymao/Downloads/sat_lf_qf.pdf', bbox_inches='tight')