In [None]:
# Put a local checkout of the SAGA package at the front of the import path so that
# the `import SAGA` below picks it up.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
import sys
sys.path.insert(0, "/home/yymao/Dropbox/Academia/Collaborations/SAGA-local/saga")

In [None]:
import SAGA
print(SAGA.__version__)

from SAGA import ObjectCuts as C
from easyquery import *
from SAGA.utils import *
from SAGA.targets import add_cut_scores
from SAGA.objects.object_catalog import calc_fiducial_p_sat
from astropy.table import vstack, Table, join
from astropy.coordinates import Distance

from collections import defaultdict
from itertools import chain

import numexpr as ne
from scipy.stats import binned_statistic
from scipy.stats import iqr
from scipy.stats import norm

In [None]:
from statsmodels.discrete.discrete_model import Logit

class MyModel(Logit):
    """Logit model whose predicted probability is capped at ``max_value``.

    The modelled probability is ``max_value * Lambda(X @ params)``, where
    ``Lambda`` is whatever ``Logit.cdf`` computes.

    Parameters
    ----------
    endog, exog
        Forwarded unchanged to ``Logit``.
    max_value : float, keyword-only via ``**kwargs``, optional
        Upper bound of the predicted probability; must satisfy
        ``0 < max_value <= 1``. Defaults to 1 (no cap).
    **kwargs
        All remaining keyword arguments are forwarded to ``Logit``.

    Raises
    ------
    ValueError
        If ``max_value`` is not in the interval ``(0, 1]``.
    """

    def __init__(self, endog, exog, **kwargs):
        # Pop (not get) so ``max_value`` is consumed here and is not forwarded
        # to the parent constructor as an unknown model keyword.
        self._max_value = float(kwargs.pop("max_value", 1))
        if not (0 < self._max_value <= 1):
            raise ValueError(
                "max_value must satisfy 0 < max_value <= 1, got {!r}".format(self._max_value)
            )
        super(MyModel, self).__init__(endog, exog, **kwargs)

    def _Lambda(self, X):
        """Capped probability for linear predictor ``X``; alias of :meth:`cdf`.

        The original stub returned ``None``; nothing visible in this notebook
        calls it.
        """
        return self.cdf(X)

    def cdf(self, X):
        """Return ``max_value`` times the parent ``Logit.cdf`` of ``X``."""
        return self._max_value * super(MyModel, self).cdf(X)

    def pdf(self, X):
        """Return ``max_value`` times the parent ``Logit.pdf`` of ``X``."""
        return self._max_value * super(MyModel, self).pdf(X)

    def loglike(self, params):
        """Total log-likelihood: the sum of :meth:`loglikeobs` over observations."""
        return np.sum(self.loglikeobs(params))

    def loglikeobs(self, params):
        """Per-observation log-likelihood ``y*log(p) + (1-y)*log(1-p)``.

        ``p`` is the capped probability ``self.cdf(exog @ params)``.
        """
        y = self.endog
        X = self.exog
        L = self.cdf(np.dot(X, params))
        return np.log(L) * y + np.log1p(-L) * (1-y)

    def hessian(self, params):
        """Return ``-X.T @ diag(max_value * L * (1 - L)) @ X`` with the uncapped ``L``.

        NOTE(review): for ``max_value < 1`` this is not the exact second
        derivative of :meth:`loglike`; it is the derivative of
        ``X.T @ (y - max_value * L)``. Confirm that this form (and the score it
        pairs with) is intended before relying on standard errors.
        """
        X = self.exog
        # Uncapped logistic value from the parent class, not self.cdf.
        L = super(MyModel, self).cdf(np.dot(X,params))
        return -np.dot(self._max_value*L*(1-L)*X.T,X)

In [None]:
def generate_label_arr(data, valid_query, *value_setters):
    """Build a float label array with one entry per row of ``data``.

    Rows not selected by ``valid_query`` are labelled -1 and selected rows 0.
    Then each ``(query, setter)`` pair in ``value_setters`` is applied in order
    to the rows matching both ``valid_query`` and ``query``: the label becomes
    ``setter`` itself, or ``setter(matching_rows)`` when ``setter`` is callable.
    Later pairs overwrite earlier ones.
    """
    labels = np.empty(len(data))
    labels[:] = -1.0
    labels[Query(valid_query).mask(data)] = 0

    for extra_query, setter in value_setters:
        combined = Query(valid_query, extra_query)
        # The right-hand side (and hence any setter call) is evaluated before the mask.
        labels[combined.mask(data)] = setter(combined.filter(data)) if callable(setter) else setter

    return labels

class LogitFit:
    """Fit ``MyModel`` to the labelled rows of ``data`` and keep the fit result.

    Only rows whose label lies in ``[0, 1]`` are used for fitting. The feature
    matrix is built from ``feature_cols`` plus a constant column. The result is
    stored in ``self.logit_res``.

    If ``add_prob_column`` is truthy, predicted probabilities for all rows are
    written into ``data`` under that column name (``"p_sat"`` when it is
    ``True``). Unless ``silent`` is set, the fit summary is displayed.
    """

    def __init__(self, data, label_arr, feature_cols=("r_mag", "sb_r", "gr"), max_prob=1, add_prob_column=None, silent=False):
        usable = (label_arr >= 0) & (label_arr <= 1)
        self._feature_cols = list(feature_cols)
        features = get_feature_array(data, self._feature_cols, True)
        model = MyModel(label_arr[usable], features[usable], max_value=max_prob)
        self.logit_res = model.fit(disp=False)

        if add_prob_column:
            column_name = "p_sat" if add_prob_column is True else add_prob_column
            # The feature matrix is already built, so skip the conversion step.
            data[column_name] = self.calc_prob(features, False)

        if not silent:
            display(self.logit_res.summary())

    def calc_prob(self, data, convert_arr_format=True):
        """Predict probabilities with the fitted model.

        ``data`` is a table to be converted to a feature matrix, or, when
        ``convert_arr_format`` is false, an already-built feature matrix.
        """
        if convert_arr_format:
            features = get_feature_array(data, self._feature_cols, True)
        else:
            features = data
        return self.logit_res.predict(features)
    
def calc_prob_with_params(base, params, max_prob=1, feature_cols=("r_mag", "sb_r", "gr")):
    """Evaluate the capped logistic model for every row of ``base``.

    The linear predictor is the sum of ``base[col] * coefficient`` over the
    paired ``params`` / ``feature_cols`` entries, plus ``params[-1]`` as the
    constant term. Rows whose predictor is NaN are given ``+inf`` and so
    evaluate to ``max_prob``.

    Returns ``max_prob / (1 + exp(-predictor))``.
    """
    linear = np.zeros(len(base))
    for coef, name in zip(params, feature_cols):
        linear += base[name] * coef
    linear += params[-1]
    # NaN predictor -> +inf, i.e. probability equal to the cap.
    linear = np.where(np.isnan(linear), np.inf, linear)
    return max_prob / (1 + np.exp(-linear))

In [None]:
# Create the SAGA helper object and make "paper2" its default base-catalog version.
saga = SAGA.QuickStart()
saga.set_default_base_version("paper2")

In [None]:
def get_feature_array(table, cols, add_constant=False):
    """Return a 2-D float64 array with one row per table row and one column per ``cols`` entry.

    When ``add_constant`` is true, a trailing column of ones is appended.
    """
    columns = []
    for name in cols:
        columns.append(table[name].astype(np.float64))
    if add_constant:
        columns.append(np.ones(len(table), dtype=np.float64))
    # Stack as rows, then transpose so that features run along axis 1.
    stacked = np.vstack(columns)
    return stacked.T

In [None]:
def annotate_base(base):
    """Add the derived columns used by the satellite-rate model and return the table.

    After ``add_cut_scores`` has been applied, the following columns are set:

    * ``Mr_est``: ``r_mag`` minus the distance modulus of the row's own
      ``HOST_DIST`` (taken to be in Mpc).
    * ``gr_c``: ``gr`` where ``C.valid_g_mag`` holds; otherwise ``ri + 0.1``
      where ``C.valid_i_mag`` holds; otherwise ``rz + 0.2`` where
      ``C.valid_z_mag`` holds; otherwise ``0.92 - 0.03 * r_mag``.
    * ``sb_c``: ``sb_r`` where ``C.valid_sb`` holds; otherwise
      ``20.75 + 0.6 * (r_mag - 14)``.
    """
    #base = add_surface_brightness(base)
    base = add_cut_scores(base)

    # Use every row's own host distance. The previous code took only
    # ``base['HOST_DIST'][0]``, which applied the first host's distance modulus to
    # all objects of a multi-host table (and raised IndexError on an empty table).
    # For a single-host table the result is unchanged.
    host_dist = np.asarray(base['HOST_DIST'], dtype=np.float64)
    base['Mr_est'] = base['r_mag'] - Distance(host_dist, unit="Mpc").distmod.value

    # Colour with fallbacks for objects lacking a valid g (then i, then z) magnitude.
    base["gr_c"] = np.where(
        C.valid_g_mag.mask(base),
        base["gr"],
        np.where(
            C.valid_i_mag.mask(base),
            base["ri"]+0.1,
            np.where(C.valid_z_mag.mask(base), base["rz"]+0.2, 0.92-0.03*base["r_mag"]),
        ),
    )

    # Surface brightness with a magnitude-based fallback where sb_r is not valid.
    base["sb_c"] = np.where(
        C.valid_sb.mask(base),
        base["sb_r"],
        20.75+0.6*(base["r_mag"]-14),
    )

    return base

In [None]:
# Read the main base catalog, add the derived model columns, and count distinct hosts.
# NOTE(review): absolute local path; the file must exist on the machine running this.
base = SAGA.database.FitsTable("/home/yymao/Downloads/saga_base_all.fits").read()
base = annotate_base(base)
hosts = np.unique(base["HOSTID"])
print(len(hosts))

In [None]:
# Same as above for the second ("other") base catalog.
# NOTE(review): absolute local path; the file must exist on the machine running this.
base_other = SAGA.database.FitsTable("/home/yymao/Downloads/saga_base_other.fits").read()
base_other = annotate_base(base_other)
hosts_other = np.unique(base_other["HOSTID"])
print(len(hosts_other))

In [None]:
# Number of objects matching C.is_sat but not C.high_priority_cuts, in each of the two catalogs.
Query(C.is_sat, ~C.high_priority_cuts).count(base), Query(C.is_sat, ~C.high_priority_cuts).count(base_other)

In [None]:
# Shorthand for the host catalog's query constructor (used below with host-level conditions).
HostQuery = saga.host_catalog.construct_host_query

In [None]:
sup = np.testing.suppress_warnings()
# The message is a regular expression. Older NumPy words this warning
# "... in true_divide", newer releases "... in divide"; match both.
sup.filter(message="invalid value encountered in (true_)?divide")  # module must match exactly

@sup
def plot_prob_hist(ax, data, col_x, bins, additonal_label=None, label_query=C.is_sat):
    """Plot the binned fraction of ``label_query`` matches against ``data[col_x]``.

    Bars show, per bin, the fraction of rows selected by ``label_query``, with
    error bars clipped to stay inside ``[0, 1]``. If ``data`` has a ``p_sat``
    column, its per-bin mean is drawn as a line.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    data : table providing ``col_x``, ``colnames`` and whatever ``label_query`` needs.
    col_x : name of the column to bin.
    bins : bin specification passed to ``scipy.stats.binned_statistic``.
    additonal_label : optional text appended to the x-axis label.
    label_query : query object whose ``mask(data)`` marks the positive rows.
    """
    n_rows = len(data)
    has_model = 'p_sat' in data.colnames
    ys = [
        label_query.mask(data),
        # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
        np.ones(n_rows, dtype=int),
        (data['p_sat'] if has_model else np.zeros(n_rows)),
    ]
    s = binned_statistic(data[col_x], ys, bins=bins, statistic='sum')

    # Per-bin sums: positives, all rows, model probabilities.
    n_pos, n_all, p_sum = s.statistic[0], s.statistic[1], s.statistic[2]

    dn = np.where(n_all, n_all, 1)  # avoid dividing by zero in empty bins
    p = n_pos / dn
    e = np.maximum(np.sqrt(n_pos * (n_all - n_pos)), 1) * 0.84 / (dn**1.5)
    eh = np.where(e+p <= 1, e, 1-p)  # clip upper error at 1
    el = np.where(p-e >= 0, e, p)    # clip lower error at 0
    x = (s.bin_edges[1:] + s.bin_edges[:-1]) * 0.5

    ax.bar(x, p, s.bin_edges[1:] - s.bin_edges[:-1], alpha=0.6, color='C0');
    ax.errorbar(x, p, [el, eh], ls='', c='C0');

    # Mean model probability per bin; NaN in empty bins (hence the warning filter).
    model_mean = p_sum / n_all
    if has_model:
        ax.plot(x, model_mean, c='C1')
    ax.set_xlabel(col_x + (", {}".format(additonal_label) if additonal_label else ""))

    ax.set_ylim(0, max(np.nanmax(model_mean), np.nanmax(p))*1.1)

In [None]:
# Selection of objects used to fit the model. All four conditions must hold:
#   - the host satisfies "really_need_spec < 40";
#   - r_mag > 13;
#   - C.faint_end_limit;
#   - the object either has a spectrum with SPEC_Z >= 0.003, or has no spectrum.
valid = Query(
    HostQuery("really_need_spec < 40"), 
    "r_mag > 13", 
    C.faint_end_limit, 
    Query(C.has_spec, "SPEC_Z >= 0.003") | (~C.has_spec),
)

In [None]:
# Model features; "sb_c" and "gr_c" are the fallback-filled columns added by annotate_base.
feature_cols = ("r_mag", "sb_c", "gr_c")
# Rows passing `valid`, restricted to the feature columns plus "SATS".
valid_base = valid.filter(base, list(feature_cols) + ["SATS"])

In [None]:
# Scan the probability cap from 0.30 to about 0.50 in steps of 0.01, recording the
# fitted log-likelihood for each value.
max_prob_list = np.arange(0.3, 0.51, 0.01)

# The labels do not depend on max_prob, so build them once instead of on every
# iteration (LogitFit with add_prob_column=False does not modify its inputs).
scan_labels = generate_label_arr(
    valid_base,
    Query(),
    (C.is_sat, 1),
)

llf = []
for max_prob in max_prob_list:
    model_sat = LogitFit(
        valid_base,
        scan_labels,
        max_prob=max_prob,
        feature_cols=feature_cols,
        add_prob_column=False,
        silent=True,
    )
    llf.append(model_sat.logit_res.llf)

In [None]:
# Adopt the cap with the largest log-likelihood (rounded to two decimals) and plot the scan,
# marking the adopted value with a vertical line.
max_prob = np.round(max_prob_list[np.argmax(llf)], 2)
plt.plot(max_prob_list, llf)
plt.axvline(max_prob, c="C3")
print(max_prob)

In [None]:
# Fit once more with the adopted cap (this time displaying the summary) and keep a copy
# of the best-fit parameters.
model_sat = LogitFit(
    valid_base, 
    generate_label_arr(
        valid_base, 
        Query(),
        (C.is_sat, 1),
    ), 
    max_prob=max_prob,
    feature_cols=feature_cols,
    add_prob_column=False,
)
params_best = model_sat.logit_res.params.copy()

In [None]:
# Bootstrap the fit: 1000 refits, each on a resampled satellite set plus all non-satellites.
params = []

sats_idx = np.flatnonzero(C.is_sat.mask(valid_base))
nonsats_idx = np.flatnonzero((~C.is_sat).mask(valid_base))
n_sats = len(sats_idx)

# Fixed seed so that a fresh-kernel "Run All" reproduces the same bootstrap draws.
# The previous code used the unseeded global NumPy random state.
bootstrap_rng = np.random.default_rng(1234)

for i in range(1000):
    # Resample only the satellites (with replacement); non-satellites are kept as is.
    sats_idx_this = sats_idx[bootstrap_rng.integers(n_sats, size=n_sats)]
    base_this = valid_base[np.concatenate([sats_idx_this, nonsats_idx])]
    model_sat = LogitFit(
        base_this,
        generate_label_arr(
            base_this,
            Query(),
            (C.is_sat, 1),
        ),
        max_prob=max_prob,
        feature_cols=feature_cols,
        add_prob_column=False,
        silent=True,
    )
    params.append(model_sat.logit_res.params.copy())

# One row per bootstrap realisation, one column per fitted parameter.
params = np.array(params)

In [None]:
# Scatter the bootstrap parameters (columns 0 vs 1 and 2 vs 3), then overplot the
# subset lying close to the best-fit values and print that subset's median.
fig, ax = plt.subplots(ncols=2, figsize=(9,4))
params_med = np.median(params, axis=0)

column_pairs = ((0, 1), (2, 3))
for panel, (col_a, col_b) in zip(ax, column_pairs):
    panel.scatter(params[:, col_a], params[:, col_b], s=3)

# Realisations within a tolerance of the best fit in the first three parameters.
mask = np.abs(params[:,0] - params_best[0]) < 0.02
mask = mask & (np.abs(params[:,1] - params_best[1]) < 0.02)
mask = mask & (np.abs(params[:,2] - params_best[2]) < 0.1)

for panel, (col_a, col_b) in zip(ax, column_pairs):
    panel.scatter(params[mask, col_a], params[mask, col_b], s=3)

print(np.median(params[mask,:], axis=0))
#plt.xlim(-3, -0.5)
#plt.ylim(-10, 0)

In [None]:
def add_probs(base, params):
    """Attach bootstrap percentile probabilities to ``base``.

    For every parameter set in ``params`` the model probability of each row is
    computed (using the module-level ``max_prob`` and ``feature_cols``). The
    percentiles ``norm.sf([-1, 0, 1]) * 100`` across parameter sets are stored
    as ``p_sat_84``, ``p_sat_50`` and ``p_sat_16``; ``p_sat`` is set to
    ``p_sat_50``.

    Returns ``(base, prob_sat_collect)``, where ``prob_sat_collect`` has one row
    per parameter set and one column per row of ``base``.
    """
    per_draw = []
    for params_this in params:
        per_draw.append(calc_prob_with_params(base, params_this, max_prob, feature_cols))
    prob_sat_collect = np.array(per_draw)

    levels = norm.sf([-1, 0, 1])*100.0
    upper, median, lower = np.percentile(prob_sat_collect, levels, axis=0)
    base["p_sat_84"] = upper
    base["p_sat_50"] = median
    base["p_sat_16"] = lower
    base["p_sat"] = base["p_sat_50"]
    return base, prob_sat_collect

In [None]:
# Add the percentile probability columns to both catalogs and keep the per-realisation arrays.
base, prob_sat = add_probs(base, params)
base_other, prob_sat_other = add_probs(base_other, params)

In [None]:
# Probabilities from the single best-fit parameter set.
base["p_sat_best"] = calc_prob_with_params(base, params_best, max_prob, feature_cols)
base_other["p_sat_best"] = calc_prob_with_params(base_other, params_best, max_prob, feature_cols)

# Same, using the best-fit parameters rounded to three decimals; the rounded
# parameters and the cap are printed as one list.
params_approx = np.round(params_best, 3).tolist()
base["p_sat_approx_new"] = calc_prob_with_params(base, params_approx, max_prob, feature_cols)
base_other["p_sat_approx_new"] = calc_prob_with_params(base_other, params_approx, max_prob, feature_cols)
print(params_approx + [max_prob])

In [None]:
# Extremes of the median probability among objects with r_mag > 13 that pass C.faint_end_limit:
# minimum over C.is_sat objects, minimum over objects passing C.high_priority_cuts,
# and maximum over objects failing C.high_priority_cuts.
print(
    Query(C.is_sat, C.faint_end_limit, 'r_mag > 13').filter(base, "p_sat_50").min(),
    Query(C.high_priority_cuts, C.faint_end_limit, 'r_mag > 13').filter(base, "p_sat_50").min(),
    Query(~C.high_priority_cuts, C.faint_end_limit, 'r_mag > 13').filter(base, "p_sat_50").max(),
)

In [None]:
# Per-host summary of objects without a spectrum that pass C.faint_end_limit:
# how many pass the high-priority cuts, and the summed satellite probability
# under each model variant.
# Percentile levels norm.sf([-1, 0, 1]) * 100, i.e. roughly 84, 50 and 16.
pt = norm.sf([-1, 0, 1])*100.0
X = []
for hostid in chain(hosts, hosts_other):
    base_this, prob_this = (base, prob_sat) if hostid in hosts else (base_other, prob_sat_other)
    mask = Query(QueryMaker.equals("HOSTID", hostid), C.faint_end_limit, ~C.has_spec).mask(base_this)
    row = [
        hostid,
        np.count_nonzero(mask & C.high_priority_cuts.mask(base_this)),
        # NOTE(review): "p_sat_approx" is not created in this notebook; it must
        # already be present in the loaded catalogs.
        base_this["p_sat_approx"][mask].sum(),
        base_this["p_sat_approx_new"][mask].sum(),
        base_this["p_sat_best"][mask].sum(),
    ]
    # Sum probabilities per bootstrap realisation, then take percentiles across realisations.
    row.extend(np.percentile(prob_this[:,mask].sum(axis=1), pt))
    X.append(tuple(row))

# Builtin int/float: the np.int / np.float aliases were removed in NumPy 1.24.
# They were plain aliases of the builtins, so the resulting dtypes are unchanged.
X = Table(np.array(X, np.dtype([
    ('HOSTID', '<U10'),
    ('really_need_spec', int),
    ('sats_missed_approx', float),
    ('sats_missed_approx_new', float),
    ('sats_missed_best', float),
    ('sats_missed_84', float),
    ('sats_missed_pred', float),
    ('sats_missed_16', float),
])))

# NOTE(review): absolute local output path.
X.write("/home/yymao/Downloads/saga_sat_prob.csv", overwrite=True)

In [None]:
# Compare the summed "p_sat_approx" per host with the bootstrap median (error bars span the
# 16th-84th percentile columns); the diagonal marks equality and the grey lines mark 0.32 and 0.35.
plt.errorbar(X["sats_missed_approx"], X["sats_missed_pred"], yerr=[X["sats_missed_pred"]-X["sats_missed_16"], X["sats_missed_84"]-X["sats_missed_pred"]], marker='.', ls='')
plt.loglog([0.01, 5], [0.01, 5], lw=1)
plt.axvline(0.32, c='grey', lw=1)
plt.axhline(0.35, c='grey', lw=1)

In [None]:
# Figure: summed satellite probability of untargeted objects versus the number of
# high-priority targets without redshift, per host, split by completeness class.
plt.figure(figsize=(4.8, 4))

X = X[np.lexsort([X["sats_missed_pred"], X["really_need_spec"]])]

# Alternate a -1% / +1% horizontal offset between consecutive (sorted) hosts so that
# points with the same count do not overlap; counts of 0 are drawn at 1 (log axis).
dx = np.ones(len(X))
dx -= 0.01
dx[::2] += 0.02
X["really_need_spec_dx"] = np.maximum(X["really_need_spec"], 1) * dx

p_cut = 0.35
n_cut = 40

complete_def_approx = Query(saga.paper2_complete_definition, QueryMaker.in1d("HOSTID", hosts))
#complete_def_approx = Query(f"sats_missed_approx < 0.33", f"really_need_spec < {n_cut}", QueryMaker.in1d("HOSTID", hosts))
complete_def = Query(f"sats_missed_84 < {p_cut}", f"really_need_spec < {n_cut}")
p1_complete = HostQuery("paper1_complete")

#assert set(complete_def_approx.filter(X, "HOSTID")) == set(complete_def.filter(X, "HOSTID"))
# Every host selected by "paper1_complete" must also satisfy the new definition.
assert Query(p1_complete, ~complete_def).count(X) == 0

for q, color, label, marker in (
    (Query(~complete_def), 'C0', "Incomplete hosts ({})", 'o'),
    (Query(complete_def, ~p1_complete), 'C3', 'Newly complete hosts ({})', 's'),
    (Query(p1_complete), 'C1', 'Paper I complete hosts ({})', '^'),
):
    X1 = q.filter(X)
    plt.errorbar(X1["really_need_spec_dx"], X1["sats_missed_pred"],
                 yerr=[X1["sats_missed_pred"]-X1["sats_missed_16"], X1["sats_missed_84"]-X1["sats_missed_pred"]], marker=marker, ls='',
                 label=label.format(len(X1)), ms=3, color=color)


plt.xscale("log")
plt.yscale("log")
plt.xlim(1, None)


xlim = plt.gca().get_xlim()
ylim = plt.gca().get_ylim()

#plt.axvspan(n_cut, xlim[-1], color="C7", alpha=0.2)
# Position of p_cut as a fraction of the logarithmic y-range (only used by the
# commented-out axvspan calls below).
y_cut = np.log(p_cut / ylim[0]) / np.log(ylim[1] / ylim[0])

#plt.axvspan(xlim[0], n_cut, ymax=y_cut, color="C2", alpha=0.25)
#plt.axvspan(xlim[0], n_cut, ymin=y_cut, color="C6", alpha=0.2)

plt.text(xlim[0], p_cut*1.2, "    needs single-slit\n$\\uparrow$ follow-up", fontsize=12)
plt.text(xlim[0], p_cut/1.1, r"$\downarrow$ complete", fontsize=12, va="top")
plt.text(n_cut, ylim[1]*0.95, "$\\rightarrow$ needs more\n      pointings", fontsize=12, va="top")

plt.axhline(p_cut, color="C7", ls="--")
plt.axvline(n_cut, color="C7", ls="--")


plt.legend(loc="lower right", markerfirst=False, handletextpad=-0.25, bbox_to_anchor=(1.07, -0.02), fontsize=13)

plt.xlabel("# primary targets without redshift")
# Raw string: "\s", "\," and "\m" are invalid escape sequences in a normal string
# literal (a warning today, an error in future Python). The text itself is unchanged.
plt.ylabel(r"Model incompleteness ($\sum\, \mathcal{R}_{\mathrm{sat},i}$)")

plt.tight_layout()
# NOTE(review): absolute local output path.
plt.savefig("/home/yymao/Downloads/host_completeness_def.pdf")

In [None]:
# Do the two completeness definitions select the same hosts?
print(set(complete_def_approx.filter(X, "HOSTID")) == set(complete_def.filter(X, "HOSTID")))
# Smallest "sats_missed_approx" among hosts that are not complete but have fewer than n_cut
# targets left, and largest "sats_missed_approx" among complete hosts.
Query(~complete_def, f"really_need_spec < {n_cut}").filter(X, "sats_missed_approx").min(), complete_def.filter(X, "sats_missed_approx").max()

In [None]:
# Host IDs and common names of hosts that are not complete and have really_need_spec < n_cut + 5.
saga.host_catalog.load(Query(~complete_def, f"really_need_spec < {n_cut}+5").filter(X, "HOSTID").tolist())["HOSTID", "COMMON_NAME"]

In [None]:
# For hosts that are not complete but have fewer than n_cut targets left, list the objects
# without a spectrum, with p_sat > 0.01, that pass C.basic_cut2.
hid = Query(~complete_def, f"really_need_spec < {n_cut}").filter(X, "HOSTID")
Query(QueryMaker.in1d("HOSTID", hid), ~C.has_spec, "p_sat > 0.01", C.basic_cut2).filter(base, ["OBJID", "RA", "DEC", "HOSTID", "r_mag", "p_sat_50", "p_sat_approx", "OBJID_decals", "REMOVE_decals"])

In [None]:
# Figure: observed satellite fraction versus model prediction in bins of three
# quantities, for objects with spectra in complete hosts.
hid = complete_def.filter(X, "HOSTID")
data = Query(C.has_spec, QueryMaker.in1d("HOSTID", hid), C.basic_cut2).filter(base)

fig, ax = plt.subplots(ncols=3, figsize=(10, 3.5))

# One entry per panel. "edges" are the np.linspace arguments for the bin edges;
# "last_edge", when given, replaces the final bin edge after the x-limits are set.
panel_specs = (
    dict(
        key="r_mag",
        xlabel=r'$r_o$',
        edges=(13, 21, 11),
        last_edge=20.75,
        ticks=[13, 15, 17, 19, 21],
        vline=20.75,
    ),
    dict(
        key="score_sb_r",
        xlabel=r'$\mu_{r_o,{\rm eff}} + \sigma_\mu + 0.6\,(r_o-14)$',
        edges=(18, 25, 11),
        last_edge=None,
        ticks=[18, 19, 20, 21, 22, 23, 24, 25],
        vline=18.75,
    ),
    dict(
        key="score_gr_r",
        xlabel=r'$(g-r)_o - \sigma_{gr} - 0.06\,(r_o - 14) $',
        edges=(0.1, 1, 10),
        last_edge=None,
        ticks=[0.1, 0.3, 0.5, 0.7, 0.9],
        vline=0.9,
    ),
)

for i, (ax_this, spec) in enumerate(zip(ax, panel_specs)):
    key = spec["key"]
    ax_this.set_xlabel(spec["xlabel"])
    bins = np.linspace(*spec["edges"])
    ax_this.set_xlim(bins[0], bins[-1])
    if spec["last_edge"] is not None:
        bins[-1] = spec["last_edge"]
    ax_this.set_xticks(spec["ticks"])
    ax_this.axvline(spec["vline"], c="grey", lw=1)

    # Per-bin counts: all objects (y) and satellites (y1). Index 0 and the last
    # index of np.digitize are out-of-range values and are dropped.
    bin_index = np.digitize(data[key], bins)
    n_slots = len(bins) + 1
    y = np.bincount(bin_index, minlength=n_slots)[1:-1]
    y1 = np.bincount(bin_index[C.is_sat.mask(data)], minlength=n_slots)[1:-1]

    width = np.ediff1d(bins)
    x = bins[:-1]
    ax_this.bar(x, y1/y, width=width, align='edge', alpha=0.3, label="Data", color="C1")

    # Point estimate (y1+1)/(y+2) with error bars clipped to [0, 1].
    p = (y1+1)/(y+2)
    e = np.sqrt(p*(1-p) / y)
    eh = np.where(e+p <= 1, e, 1-p)
    el = np.where(p-e >= 0, e, p)
    x = (bins[1:] + bins[:-1]) * 0.5
    ax_this.errorbar(x, p, [el, eh], ls='', c='C1')

    # Band between the per-bin means of the p_sat_16 and p_sat_84 columns.
    s1 = binned_statistic(data[key], data["p_sat_16"], bins=bins, statistic='sum')
    s2 = binned_statistic(data[key], data["p_sat_84"], bins=bins, statistic='sum')

    ax_this.fill_between(x, s1.statistic/y, s2.statistic/y, label="Model", color="C9", alpha=0.8)
    ax_this.set_yscale("log")
    ax_this.set_ylim(1e-4, 1)

ax[0].legend()
ax[0].set_ylabel("Satellite rate")
plt.tight_layout(w_pad=0.1)
plt.savefig('/home/yymao/Downloads/model_demo.pdf', bbox_inches='tight')

In [None]:
# Figure: distribution of log10 of the median model probability for non-satellites with
# spectra in complete hosts, split at SPEC_Z = 0.03.
fig, ax = plt.subplots(figsize=(4.8,3.5))
hid = complete_def.filter(X, "HOSTID")
data = Query(C.has_spec, QueryMaker.in1d("HOSTID", hid), C.is_clean2, C.is_galaxy2, ~C.is_sat, C.faint_end_limit).filter(base)

# The histogrammed quantity is log10(p), so the upper bin edge must be log10 as well.
# The previous natural log, np.log(0.4) ~ -0.92, silently dropped every object with
# p above ~0.12 from the histogram.
bins = np.linspace(-9, np.log10(0.4), 31)

plt.hist(np.log10(Query("SPEC_Z >= 0.003", "SPEC_Z < 0.03").filter(data, "p_sat_50")), bins, density=True, alpha=0.5, label=r"$z < 0.03$, non-satellite");
plt.hist(np.log10(Query("SPEC_Z >= 0.03").filter(data, "p_sat_50")), bins, density=True, alpha=0.3, color="C3", label=r"$z > 0.03$");


plt.xlabel("$\\log (\\mathcal{R}_{\\rm sat})$")
plt.ylabel("Normed density")
plt.legend()
plt.tight_layout()
# NOTE(review): absolute local output path.
plt.savefig('/home/yymao/Downloads/model_separation.pdf', bbox_inches='tight')

In [None]:
# Figure: all objects in the (r, surface brightness) and (r, colour) planes, coloured by
# model probability where p_sat >= 1e-3 and grey otherwise.
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

for ax_this in ax:
    ax_this.set_xlim(12, 21)
    ax_this.set_xlabel(r"$r_o$")
ax[0].set_ylabel(r"$\mu_{r_o, {\rm eff}}$")
ax[1].set_ylabel(r"$(g-r)_o$")

# An earlier alternative, Query(C.faint_end_limit, C.high_priority_cuts), was assigned
# here and immediately overwritten; only the probability cut below was ever in effect.
cut = Query("p_sat >= 1e-3")

t = (~cut).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='grey', s=0.1, lw=0, rasterized=True)
# Raw string: "\m" is an invalid escape sequence in a normal string literal; text unchanged.
ax[1].scatter(t["r_mag"], t["gr_c"], c='grey', s=0.1, lw=0, rasterized=True, label=r"Targets ($\mathcal{R}_\mathrm{sat}<10^{-3}$)")

"""
t = Query(cut, ~C.has_spec).filter(base)
t.sort("p_sat")
cs = ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_sat"]), s=10, lw=0, cmap="winter_r", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.85, marker='X')
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_sat"]), s=10, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.8, label="Primary targets (no redshift)", marker='X')
"""

t = Query(cut).filter(base)
t.sort("p_sat")  # draw the highest-probability objects last, i.e. on top
cs = ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_sat"]), s=5, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.85,)
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_sat"]), s=5, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.8, label=r"Targets ($\mathcal{R}_\mathrm{sat}\geq10^{-3}$)")

r = np.linspace(12, 21, 10)
ax[0].plot(r, 0.6*(r-14) + 18.55, color='C4',  ls="--",alpha=0.9)
ax[1].plot(r, -0.06*(r-14) + 0.9, color='C4', alpha=0.9, ls="--", label='Targeting cuts')

lgnd = ax[1].legend(loc="upper left", scatterpoints=3, frameon=True, fontsize=12)
# `Legend.legendHandles` was renamed `legend_handles` in newer Matplotlib and the old
# name later removed; support both.
legend_handles = getattr(lgnd, "legend_handles", None)
if legend_handles is None:
    legend_handles = lgnd.legendHandles
# Pick the two scatter entries by capability rather than by position, because the
# position of the line entry in the handle list may differ between Matplotlib versions.
# The grey scatter is added before the coloured one, so the coloured one is last.
scatter_handles = [h for h in legend_handles if hasattr(h, "get_sizes")]
for handle in scatter_handles:
    handle.set_sizes([20])
if scatter_handles:
    # Show three colormap samples for the coloured-scatter legend entry.
    scatter_handles[-1].set_color(plt.cm.winter([0.8, 0.5, 0.2]))


ax[0].set_ylim(17, 27)
ax[1].set_ylim(-0.3, 2)
fig.tight_layout()

cbar = fig.colorbar(cs, ax=ax, pad=0.01, ticks=np.log10([0.001, 0.01, 0.1, 0.5]), label="Satellite Rate ($\\mathcal{R}_{\\rm sat}$)")
cbar.ax.set_yticklabels(['$10^{-3}$', '$0.01$', '$0.1$', '$0.5$'])
# NOTE(review): absolute local output path.
plt.savefig('/home/yymao/Downloads/sat-prob.pdf', dpi=200)