In [None]:
import SAGA
print(SAGA.__version__)

from SAGA import ObjectCuts as C
from easyquery import *
from SAGA.utils import *
from SAGA.utils.distance import d2m
from SAGA.targets import add_cut_scores
from SAGA.objects.object_catalog import calc_fiducial_p_sat
from astropy.table import vstack, Table, join
from astropy.coordinates import Distance

from collections import defaultdict
from itertools import chain

import numexpr as ne
from scipy.stats import binned_statistic
from scipy.stats import iqr
from scipy.stats import norm
from scipy.optimize import minimize_scalar
from functools import reduce

In [None]:
from statsmodels.discrete.discrete_model import Logit

class MyModel(Logit):
    """Logit model whose predicted probability is capped at ``max_value``.

    The success probability is ``max_value * Logit.cdf(X @ params)`` and the
    log-likelihood is the Bernoulli log-likelihood of that scaled probability.

    Parameters
    ----------
    endog, exog
        Forwarded unchanged to ``Logit.__init__``.
    **kwargs
        Forwarded unchanged to ``Logit.__init__``. The optional ``max_value``
        entry (default 1) is read here and must satisfy ``0 < max_value <= 1``.

    Raises
    ------
    ValueError
        If ``max_value`` is outside ``(0, 1]`` (NaN is rejected as well).

    NOTE(review): ``loglike``/``loglikeobs`` are overridden for the scaled
    model, but ``score`` is inherited from ``Logit`` and ``hessian`` below is
    built from the unscaled logistic cdf. Confirm that the gradient and
    Hessian used by the optimizer are consistent with this likelihood.
    """

    def __init__(self, endog, exog, **kwargs):
        # ``max_value`` is intentionally left inside kwargs so that the parent
        # constructor still receives exactly what the caller passed.
        self._max_value = float(kwargs.get("max_value", 1))
        if not (0 < self._max_value <= 1):
            raise ValueError(
                "max_value must satisfy 0 < max_value <= 1, got {!r}".format(self._max_value)
            )
        super().__init__(endog, exog, **kwargs)

    def _Lambda(self, X):
        # Unfinished placeholder: always returns None and is not used by the
        # other methods of this class.
        return

    def cdf(self, X):
        """Return the parent cdf scaled by ``max_value``."""
        return self._max_value * super().cdf(X)

    def pdf(self, X):
        """Return the parent pdf scaled by ``max_value``."""
        return self._max_value * super().pdf(X)

    def loglike(self, params):
        """Return the total log-likelihood (sum of ``loglikeobs``)."""
        return np.sum(self.loglikeobs(params))

    def loglikeobs(self, params):
        """Return the per-observation Bernoulli log-likelihood of the scaled model."""
        y = self.endog
        X = self.exog
        L = self.cdf(np.dot(X, params))
        # log1p(-L) evaluates log(1 - L) accurately when L is small.
        return np.log(L) * y + np.log1p(-L) * (1-y)

    def hessian(self, params):
        """Return ``-X.T @ diag(max_value * L * (1 - L)) @ X`` with L the unscaled cdf."""
        X = self.exog
        L = super().cdf(np.dot(X,params))
        return -np.dot(self._max_value*L*(1-L)*X.T,X)

In [None]:
def generate_label_arr(data, valid_query, *value_setters):
    """Build a float label array for ``data``.

    Every row starts at -1.0; rows selected by ``valid_query`` are set to 0.
    Then, for each ``(query, setter)`` pair in ``value_setters`` (applied in
    order, later pairs overwrite earlier ones), rows matching both
    ``valid_query`` and ``query`` are assigned ``setter`` itself, or
    ``setter(matching_rows)`` when ``setter`` is callable.
    """
    labels = np.full(len(data), -1.0)
    labels[Query(valid_query).mask(data)] = 0

    for extra_query, setter in value_setters:
        combined = Query(valid_query, extra_query)
        value = setter(combined.filter(data)) if callable(setter) else setter
        labels[combined.mask(data)] = value

    return labels

class LogitFit:
    """Fit a ``MyModel`` logit on the rows of ``data`` whose label lies in [0, 1].

    The fitted result is stored in ``self.logit_res``. When ``add_prob_column``
    is truthy, the predicted probability for every row of ``data`` is written
    to that column (``"p_sat"`` if ``add_prob_column is True``). Unless
    ``silent`` is set, the fit summary is displayed.
    """

    def __init__(self, data, label_arr, feature_cols=("r_mag", "sb_r", "gr"), max_prob=1, add_prob_column=None, silent=False):
        usable = (label_arr >= 0) & (label_arr <= 1)
        self._feature_cols = list(feature_cols)
        # The trailing True appends a constant column (the intercept term).
        features = get_feature_array(data, self._feature_cols, True)
        model = MyModel(label_arr[usable], features[usable], max_value=max_prob)
        self.logit_res = model.fit(disp=False)

        if add_prob_column:
            column_name = "p_sat" if add_prob_column is True else add_prob_column
            data[column_name] = self.calc_prob(features, False)

        if not silent:
            display(self.logit_res.summary())

    def calc_prob(self, data, convert_arr_format=True):
        """Return predicted probabilities for ``data``.

        ``data`` is converted to a feature array (with constant column) unless
        ``convert_arr_format`` is false, in which case it is used as given.
        """
        if convert_arr_format:
            feature_arr = get_feature_array(data, self._feature_cols, True)
        else:
            feature_arr = data
        return self.logit_res.predict(feature_arr)
    
def calc_prob_with_params(base, params, max_prob=1, feature_cols=("r_mag", "sb_r", "gr")):
    """Evaluate the capped logistic probability for each row of ``base``.

    ``params`` holds one coefficient per entry of ``feature_cols`` followed by
    the intercept as its last element. Rows whose linear predictor is NaN are
    treated as +inf, i.e. they receive probability ``max_prob``.
    """
    linear = np.zeros(len(base))
    for coeff, name in zip(params, feature_cols):
        linear += base[name] * coeff
    linear += params[-1]  # intercept
    linear[np.isnan(linear)] = np.inf
    return max_prob / (1 + np.exp(-linear))

In [None]:
# Create the SAGA session object used by later cells and select "paper2" as
# its default base version.
saga = SAGA.QuickStart()
saga.set_default_base_version("paper2")

In [None]:
def get_feature_array(table, cols, add_constant=False):
    """Return a float64 array of shape (len(table), n_columns).

    Column ``j`` of the result is ``table[cols[j]]`` cast to float64. When
    ``add_constant`` is true, a final column of ones is appended.
    """
    columns = []
    for name in cols:
        columns.append(table[name].astype(np.float64))
    if add_constant:
        columns.append(np.ones(len(table), dtype=np.float64))
    stacked = np.vstack(columns)  # one row per feature
    return stacked.T

In [None]:
def annotate_base(base):
    """Add derived columns to ``base`` in place and return it.

    Columns written:

    - ``Mr_``: ``r_mag - d2m(HOST_DIST)``.
    - ``Mr``: the existing value where it is finite and ``C.has_spec`` holds,
      otherwise ``Mr_``.
    - ``log_sm``: the existing value where finite, otherwise
      ``1.254 + 1.0976 * gr - 0.4 * Mr``.
    - ``human_selected``: for each score ``i`` in 1..3, ``i`` is added when the
      object's OBJID is listed with that score in the ``human_selected``
      database table (read through the module-level ``saga`` object).
    - ``gr_c``: ``gr`` when g is valid; else ``ri + 0.1`` when i is valid; else
      ``rz + 0.2`` when z is valid; else ``0.92 - 0.03 * r_mag``.
    - ``sb_c``: ``sb_r`` when ``C.valid_sb`` holds, else
      ``20.75 + 0.6 * (r_mag - 14)``.
    """
    base["Mr_"] = base["r_mag"] - d2m(base["HOST_DIST"])
    base["Mr"] = np.where(np.isfinite(base["Mr"]) & C.has_spec.mask(base), base["Mr"], base["Mr_"])
    base["log_sm"] = np.where(np.isfinite(base["log_sm"]), base["log_sm"], 1.254 + 1.0976 * base["gr"] - 0.4 * base["Mr"])

    base["human_selected"] = 0
    hs = saga.database["human_selected"].read()
    for i in range(1, 4):
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
        # ``np.isin`` replaces the deprecated ``np.in1d`` (identical for 1-D input).
        selected = np.isin(base["OBJID"], Query(f"score == {i}").filter(hs, "OBJID"))
        base["human_selected"] += selected.astype(int) * i

    base["gr_c"] = np.where(
        C.valid_g_mag.mask(base),
        base["gr"],
        np.where(
            C.valid_i_mag.mask(base),
            base["ri"]+0.1,
            np.where(C.valid_z_mag.mask(base), base["rz"]+0.2, 0.92-0.03*base["r_mag"]),
        ),
    )

    base["sb_c"] = np.where(
        C.valid_sb.mask(base),
        base["sb_r"],
        20.75+0.6*(base["r_mag"]-14),
    )

    return base

In [None]:
# Read the main base catalog from a local FITS file, add the derived columns
# from annotate_base, and collect the unique host IDs it contains.
# NOTE(review): the path is absolute and machine-specific.
base = SAGA.database.FitsTable("/home/yymao/Downloads/saga_base_all.fits").read()
base = annotate_base(base)
hosts = np.unique(base["HOSTID"])
print(len(hosts))

In [None]:
# Same as the previous cell, for the second ("other") base catalog file.
# NOTE(review): the path is absolute and machine-specific.
base_other = SAGA.database.FitsTable("/home/yymao/Downloads/saga_base_other.fits").read()
base_other = annotate_base(base_other)
hosts_other = np.unique(base_other["HOSTID"])
print(len(hosts_other))

In [None]:
# Short alias for the host-catalog query constructor; used below to build
# queries from host-level expressions.
HostQuery = saga.host_catalog.construct_host_query

In [None]:
# Warning suppressor applied as a decorator to plot_prob_hist; it filters
# warnings whose message matches the given text.
sup = np.testing.suppress_warnings()
sup.filter(message="invalid value encountered in true_divide")  # module must match exactly

@sup
def plot_prob_hist(ax, data, col_x, bins, additonal_label=None, label_query=C.is_sat):
    """Plot the binned fraction of ``label_query`` objects against ``col_x``.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    data : table with column ``col_x``; a ``p_sat`` column is optional.
    col_x : name of the column to bin on.
    bins : bin specification passed to ``scipy.stats.binned_statistic``.
    additonal_label : optional text appended to the x-axis label.
    label_query : query whose mask marks the "positive" objects.

    Draws a bar per bin for the positive fraction with an error bar clipped to
    [0, 1]; when ``data`` has a ``p_sat`` column, the per-bin mean of ``p_sat``
    is drawn as a line.
    """
    # Per-bin sums of: positives, object counts, and model probabilities.
    # (``np.int`` was removed in NumPy 1.24; the builtin ``int`` is equivalent.)
    ys = [label_query.mask(data), np.ones(len(data), dtype=int), (data['p_sat'] if 'p_sat' in data.colnames else np.zeros(len(data)))]
    s = binned_statistic(data[col_x], ys, bins=bins, statistic='sum')

    # Use 1 as the denominator for empty bins so the fraction there is 0.
    dn = np.where(s.statistic[1], s.statistic[1], 1)
    p = s.statistic[0] / dn
    e = np.maximum(np.sqrt(s.statistic[0] * (s.statistic[1] - s.statistic[0])), 1) * 0.84 / (dn**1.5)
    # Clip the error bars so that p - el >= 0 and p + eh <= 1.
    eh = np.where(e+p <= 1, e, 1-p)
    el = np.where(p-e >= 0, e, p)
    x = (s.bin_edges[1:] + s.bin_edges[:-1]) * 0.5

    ax.bar(x, p, s.bin_edges[1:] - s.bin_edges[:-1], alpha=0.6, color='C0');
    ax.errorbar(x, p, [el, eh], ls='', c='C0');

    if 'p_sat' in data.colnames:
        ax.plot(x, s.statistic[2] / s.statistic[1], c='C1')
    ax.set_xlabel(col_x + (", {}".format(additonal_label) if additonal_label else ""))

    # Empty bins give NaN in the model curve; nanmax ignores them.
    ax.set_ylim(0, max(np.nanmax(s.statistic[2] / s.statistic[1]), np.nanmax(p))*1.1)

In [None]:
# (query, value) pairs: objects matching the query are assigned that value as
# their label when fitting and as their probability in add_probs.
sat_prob_setter = (
    #(Query("human_selected > 0", C.paper2_targeting_cut, ~C.has_spec), 0.4),
    (C.is_sat, 1),
)

In [None]:
# Selection of objects used to fit the model: hosts with
# paper2_need_spec / paper2_total < 0.2, 13 < r_mag < 20.75, and objects that
# either match any query in sat_prob_setter or have a spectrum with
# SPEC_Z >= 0.003 (the reduce ORs all of these together).
valid = Query(
    HostQuery("paper2_need_spec / paper2_total < 0.2"),
    "r_mag > 13", 
    "r_mag < 20.75", 
    reduce(lambda a, b: (a | b), [q[0] for q in sat_prob_setter], Query(C.has_spec, "SPEC_Z >= 0.003")),
)

In [None]:
# Columns used as model features, and the training subset of `base`
# (rows passing `valid`, restricted to the listed columns).
feature_cols = ("r_mag", "sb_c", "gr_c")
valid_base = valid.filter(base, list(feature_cols) + ["Mr_", "SATS", "OBJID", "SPEC_Z", "ZQUALITY", "gr", "gr_err", "human_selected", "g_mag", "HOST_VHOST", "radius", "HOST_DIST", "sb_r", "sb_r_err"])

In [None]:
def obj_func(max_prob):
    """Return the negative log-likelihood of the logit fit on ``valid_base``
    when the probability cap is ``max_prob``."""
    labels = generate_label_arr(valid_base, Query(), *sat_prob_setter)
    fit = LogitFit(
        valid_base,
        labels,
        feature_cols=feature_cols,
        max_prob=max_prob,
        add_prob_column=False,
        silent=True,
    )
    return -fit.logit_res.llf

# Find the probability cap in [0.2, 1] that minimizes obj_func, and keep it
# rounded to 3 decimals as the module-level `max_prob`.
res = minimize_scalar(obj_func, bounds=(0.2, 1), method="bounded")
max_prob = np.round(res.x, 3)
print(res)
print(max_prob)

In [None]:
# Fit the model once on the full training subset with the chosen cap
# (silent is left at its default, so the fit summary is displayed) and keep a
# copy of the fitted parameters.
model_sat = LogitFit(
    valid_base, 
    generate_label_arr(
        valid_base, 
        Query(),
        *sat_prob_setter
    ), 
    max_prob=max_prob,
    feature_cols=feature_cols,
    add_prob_column=False,
)
params_best = model_sat.logit_res.params.copy()

In [None]:
# Bootstrap the fit: objects matched by `pick` are resampled with replacement
# (same count each time) while all remaining objects are kept as they are, and
# the model is refit on each resampled table. `params` ends up as an array with
# one row of fitted parameters per bootstrap iteration.
N_BOOTSTRAP = 100
BOOTSTRAP_SEED = 0  # fixed seed so a fresh "Restart & Run All" reproduces the same draws
bootstrap_rng = np.random.RandomState(BOOTSTRAP_SEED)

params = []

pick = reduce(lambda a, b: (a | b), [q[0] for q in sat_prob_setter], C.is_sat)

sats_idx = np.flatnonzero(pick.mask(valid_base))
nonsats_idx = np.flatnonzero((~pick).mask(valid_base))
n_sats = len(sats_idx)

for i in range(N_BOOTSTRAP):
    sats_idx_this = sats_idx[bootstrap_rng.randint(n_sats, size=n_sats)]
    base_this = valid_base[np.concatenate([sats_idx_this, nonsats_idx])]
    model_sat = LogitFit(
        base_this,
        generate_label_arr(
            base_this,
            Query(),
            *sat_prob_setter
        ),
        max_prob=max_prob,
        feature_cols=feature_cols,
        add_prob_column=False,
        silent=True,
    )
    params.append(model_sat.logit_res.params.copy())

params = np.array(params)

In [None]:
def add_probs(base, params):
    """Attach bootstrap probability columns to ``base`` and return
    ``(base, prob_sat_collect)``.

    ``prob_sat_collect`` has one row per parameter set in ``params`` and one
    column per object. The ``*_o`` columns hold percentiles of the raw model
    probabilities; the columns without the suffix are computed after objects
    matching each query in ``sat_prob_setter`` have been overridden with that
    query's value (plus a tiny random jitter when the value is below 1).
    Uses the module-level ``max_prob``, ``feature_cols`` and ``sat_prob_setter``.
    """
    # Percentile levels matching +1 sigma, median and -1 sigma of a normal.
    levels = norm.sf([-1, 0, 1]) * 100.0

    prob_sat_collect = np.array([
        calc_prob_with_params(base, params_this, max_prob, feature_cols)
        for params_this in params
    ])

    upper, median, lower = np.percentile(prob_sat_collect, levels, axis=0)
    base["p_sat_84_o"] = upper
    base["p_sat_50_o"] = median
    base["p_sat_16_o"] = lower
    base["p_sat_o"] = base["p_sat_50_o"]

    for q, p in sat_prob_setter:
        override = p + np.random.randn(q.count(base)) * 1e-5 if p < 1 else p
        prob_sat_collect[:, q.mask(base)] = override

    upper, median, lower = np.percentile(prob_sat_collect, levels, axis=0)
    base["p_sat_84"] = upper
    base["p_sat_50"] = median
    base["p_sat_16"] = lower
    base["p_sat"] = base["p_sat_50"]
    return base, prob_sat_collect

In [None]:
# Add the bootstrap probability columns to both catalogs and keep the
# per-bootstrap probability arrays for the per-host sums computed later.
base, prob_sat = add_probs(base, params)
base_other, prob_sat_other = add_probs(base_other, params)

In [None]:
# Probabilities from the single best-fit parameter set.
base["p_sat_best"] = calc_prob_with_params(base, params_best, max_prob, feature_cols)
base_other["p_sat_best"] = calc_prob_with_params(base_other, params_best, max_prob, feature_cols)

# Same, using the best-fit parameters rounded to 3 decimals; the rounded
# parameters followed by max_prob are printed.
params_approx = np.round(params_best, 3).tolist()
base["p_sat_approx_new"] = calc_prob_with_params(base, params_approx, max_prob, feature_cols)
base_other["p_sat_approx_new"] = calc_prob_with_params(base_other, params_approx, max_prob, feature_cols)
print(params_approx + [max_prob])

# Override the approximate probability for objects matching sat_prob_setter.
# NOTE(review): only `base` is overridden here, not `base_other` - confirm this is intended.
for q, p in sat_prob_setter:
    if p < 1:
        base["p_sat_approx_new"][q.mask(base)] = p + np.random.randn(q.count(base))*1e-5
    else:
        base["p_sat_approx_new"][q.mask(base)] = p

In [None]:
# Build a per-host summary table `X`: counts of objects in the main targeting
# region (total and without spectra), the number of satellites within the
# faint-end limit, and the summed satellite probability of objects without
# spectra under each probability estimate (plus bootstrap percentiles).
pt = norm.sf([-1, 0, 1])*100.0
X = []

high_priority_sb_new = Query("sb_r + abs(sb_r_err) - 0.7 * (r_mag - 14) > 18.5") | (~C.valid_sb)
main_targeting_cuts = high_priority_sb_new & C.high_priority_gr

# Row index into the matching prob_sat / prob_sat_other arrays; needed because
# group_by reorders the rows.
base["idx"] = np.arange(len(base))
base_other["idx"] = np.arange(len(base_other))

for base_this in chain(base.group_by("HOSTID").groups, base_other.group_by("HOSTID").groups):
    hostid = base_this["HOSTID"][0]
    prob_this = (prob_sat if hostid in hosts else prob_sat_other)[:,base_this["idx"]]

    mask = Query(C.faint_end_limit, ~C.has_spec).mask(base_this)
    row = [
        hostid,
        Query(C.basic_cut2, main_targeting_cuts).count(base_this),
        Query(C.basic_cut2, main_targeting_cuts, ~C.has_spec).count(base_this),
        Query(C.is_sat, C.faint_end_limit).count(base_this),
        base_this["p_sat_approx"][mask].sum(),
        base_this["p_sat_approx_new"][mask].sum(),
        base_this["p_sat_best"][mask].sum(),
    ]
    # Percentiles (pt order: ~84th, 50th, ~16th) of the per-bootstrap summed probability.
    row.extend(np.percentile(prob_this[:,mask].sum(axis=1), pt))
    X.append(tuple(row))

# `np.int` / `np.float` were removed in NumPy 1.24; the builtins give the same dtypes.
X = Table(np.array(X, np.dtype([
    ('HOSTID', '<U10'),
    ('really_need_spec_total', int),
    ('really_need_spec', int),
    ('n_sats', int),
    ('sats_missed_approx', float),
    ('sats_missed_approx_new', float),
    ('sats_missed_best', float),
    ('sats_missed_84', float),
    ('sats_missed_pred', float),
    ('sats_missed_16', float),
])))

In [None]:
# Figure: predicted number of missed satellites per host versus the fraction
# of the primary targeting region that already has spectra.
plt.figure(figsize=(4.8, 4))

X = X[np.lexsort([X["sats_missed_pred"], X["really_need_spec"]])]

# Alternating multiplicative offsets (1.01 / 0.99) stored as a helper column.
dx = np.ones(len(X))
dx -= 0.01
dx[::2] += 0.02
X["really_need_spec_dx"] = np.maximum(X["really_need_spec"], 1) * dx

p_cut = 0.8

#complete_def_approx = Query(saga.paper2_complete_definition, QueryMaker.in1d("HOSTID", hosts))
#complete_def_approx = Query(f"sats_missed_approx < 0.33", f"really_need_spec < {n_cut}", QueryMaker.in1d("HOSTID", hosts))
# A host counts as complete when less than (1 - p_cut) of its primary
# targeting region still lacks spectra.
complete_def = Query(f"really_need_spec / really_need_spec_total < (1-{p_cut})")
#complete_def = saga.host_catalog.construct_host_query("paper2_complete")
p1_complete= HostQuery("paper1_complete")

#assert set(complete_def_approx.filter(X, "HOSTID")) == set(complete_def.filter(X, "HOSTID"))
#assert Query(p1_complete, ~complete_def).count(X) == 0

for q, color, label, marker in (
    (Query(~complete_def), 'C0', "Incomplete hosts ({})", 'o'),
    (Query(complete_def, ~p1_complete), 'C3', 'Newly complete hosts ({})', 's'),
    (Query(p1_complete), 'C1', 'Paper I complete hosts ({})', '^'),
):
    X1 = q.filter(X)
    n = 1#np.maximum(1,X1["n_sats"])
    plt.errorbar(1-X1["really_need_spec"]/X1["really_need_spec_total"], X1["sats_missed_pred"] / n, 
                 yerr=[(X1["sats_missed_pred"]-X1["sats_missed_16"]) / n, (X1["sats_missed_84"]-X1["sats_missed_pred"]) / n], marker=marker, ls='', 
                 label=label.format(len(X1)), ms=3, color=color)

    
#plt.xscale("log")
plt.yscale("log")
#plt.xlim(1, None)


xlim = plt.gca().get_xlim()
ylim = plt.gca().get_ylim()

#plt.axvspan(n_cut, xlim[-1], color="C7", alpha=0.2)
#y_cut = np.log(p_cut / ylim[0]) / np.log(ylim[1] / ylim[0])

#plt.axvspan(xlim[0], n_cut, ymax=y_cut, color="C2", alpha=0.25)
#plt.axvspan(xlim[0], n_cut, ymin=y_cut, color="C6", alpha=0.2)

#plt.text(xlim[0], p_cut*1.2, "    needs single-slit\n$\\uparrow$ follow-up", fontsize=12)
#plt.text(xlim[0], p_cut/1.1, r"$\downarrow$ complete", fontsize=12, va="top")
#plt.text(n_cut, ylim[1]*0.95, "$\\rightarrow$ needs more\n      pointings", fontsize=12, va="top")

#plt.axhline(p_cut, color="C7", ls="--")
plt.axvline(p_cut, color="C7", ls="--")


plt.legend(loc="lower left", markerfirst=True, handletextpad=-0.25, fontsize=13, borderpad=0)
plt.ylim(0.05, 10)
plt.xlabel("Spec coverage in primary targeting region")
# Raw string: "\s", "\," and "\m" are invalid escape sequences in a normal
# string literal (the resulting text is unchanged).
plt.ylabel(r"Model incompleteness ($\sum\, \mathcal{R}_{\mathrm{sat},i}$)")

plt.tight_layout()

In [None]:
# Figure: observed satellite rate versus the bootstrap model band, binned in
# r_mag, sb_r and gr (one panel each), saved to model_demo.pdf.
data = valid.filter(base)

fig, ax = plt.subplots(ncols=3, figsize=(10, 3.5))

for i, ax_this in enumerate(ax):
    # Per-panel choice of binned column, axis label, bin edges and ticks.
    if i == 0:
        key = "r_mag"
        ax_this.set_xlabel(r'$r_o$')
        bins = np.linspace(13, 21, 10)
        ax_this.set_xlim(bins[0], bins[-1])
        # The axis limit keeps the original last edge; the last bin itself ends at 20.75.
        bins[-1] = 20.75
        #ax_this.set_xticks([13, 15, 17, 19, 21])
        #ax_this.axvline(20.75, c="grey", lw=1)
        #ax_this.axvspan(-12.3, -11, color="grey", alpha=0.2)
        
    elif i == 1:
        key = "sb_r"
        ax_this.set_xlabel(r'$\mu_{r_o,{\rm eff}}$')
        bins = np.linspace(18.75, 26.25, 11)
        ax_this.set_xlim(bins[0], bins[-1])
        ax_this.set_xticks([19, 20, 21, 22, 23, 24, 25, 26]) 
        #ax_this.axvline(18.75, c="grey", lw=1)
        #ax_this.axvspan(18, 18.75, color="grey", alpha=0.2)
        
    elif i == 2:
        key = "gr"
        ax_this.set_xlabel(r'$(g-r)_o$')
        bins = np.linspace(-0.05, 0.85, 10)
        ax_this.set_xlim(bins[0], bins[-1])
        ax_this.set_xticks([0, 0.2, 0.4, 0.6, 0.8])
        #ax_this.axvspan(0.9, 1, color="grey", alpha=0.2)
        
    # Left edges and widths for the bars (align='edge').
    x = bins[:-1]
    width = np.ediff1d(bins)
    bin_index = np.digitize(data[key], bins)
    
    
    # y: objects per bin ([1:-1] drops the below-range and above-range slots);
    # y1: per-bin sum of the sat_prob_setter values for matching objects.
    y = np.bincount(bin_index, minlength=len(bins)+1)[1:-1]
    y1 = np.zeros_like(y, dtype=np.float64)
    for q, p in sat_prob_setter:
        y1 += np.bincount(bin_index[q.mask(data)], minlength=len(bins)+1)[1:-1] * p    
    ax_this.bar(x, y1/y, width=width, align='edge', alpha=0.3, label="Data", color="C1")
    
    # Error bars are centered on (y1+1)/(y+2) and clipped to stay within [0, 1].
    p = (y1+1)/(y+2)
    e = np.sqrt(p*(1-p) / y)
    eh = np.where(e+p <= 1, e, 1-p)
    el = np.where(p-e >= 0, e, p)
    # From here on x holds the bin centers.
    x = (bins[1:] + bins[:-1]) * 0.5
    ax_this.errorbar(x, p, [el, eh], ls='', c='C1')

    # Model band: per-bin mean of the 16th and 84th percentile probabilities.
    s1 = binned_statistic(data[key], data["p_sat_16_o"], bins=bins, statistic='sum')
    s2 = binned_statistic(data[key], data["p_sat_84_o"], bins=bins, statistic='sum')

    ax_this.fill_between(x, s1.statistic/y, s2.statistic/y, label="Model", color="C9", alpha=0.8)
    ax_this.set_yscale("log")
    ax_this.set_ylim(1e-4, 1)

ax[0].legend()
ax[0].set_ylabel("Satellite rate")
plt.tight_layout(w_pad=0.1)
# NOTE(review): the output path is absolute and machine-specific.
plt.savefig('/home/yymao/Downloads/model_demo.pdf', bbox_inches='tight')

In [None]:
# Figure: objects in the (r_mag, sb_c) and (r_mag, gr_c) planes, colored by
# log10(p_sat) where p_sat >= 1e-3 and drawn in grey otherwise, with the
# primary targeting boundaries overlaid. Saved to sat-prob.pdf.
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

for ax_this in ax:
    ax_this.set_xlim(11.8, 21)
    ax_this.set_xlabel(r"$r_o$")
ax[0].set_ylabel(r"$\mu_{r_o, {\rm eff}}$")
ax[1].set_ylabel(r"$(g-r)_o$")

cut = Query("p_sat >= 1e-3")

t = (~cut).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='grey', s=0.1, lw=0, rasterized=True)
# Raw string: "\m" is an invalid escape sequence in a normal string literal
# (the resulting text is unchanged).
ax[1].scatter(t["r_mag"], t["gr_c"], c='grey', s=0.1, lw=0, rasterized=True, label=r"Targets ($\mathcal{R}_\mathrm{sat}<10^{-3}$)")

t = Query(cut).filter(base)
# Sort so that the highest-probability points are drawn last (on top).
t.sort("p_sat")
cs = ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_sat"]), s=5, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.85,)
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_sat"]), s=5, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.8, label=r"Targets ($\mathcal{R}_\mathrm{sat}\geq10^{-3}$)")

r = np.linspace(11.8, 21, 10)
#ax[0].plot(r, 0.6*(r-14) + 18.55, color='C3',  ls="--",alpha=0.9)
#ax[1].plot(r, -0.06*(r-14) + 0.9, color='C3', alpha=0.9, ls="--", label='Targeting cuts')
#ax[1].plot(r, -0.05*(r-14) + 0.85, color='C3', alpha=0.4, ls="--", label='Targeting cuts')
ax[0].plot(r, 0.7*(r-14) + 18.5, color='C4', ls="--", label='Primary targeting region')
ax[1].plot(r, -0.06*(r-14) + 0.9, color='C4', ls="--",)
#ax[1].plot(r, -0.06*(r-14) + 1.1, color='C3', ls="--", alpha=0.4)

#ax[0].plot(r, 0.7*(r-14) + 18.7, color='C3',  ls="--",alpha=0.9)
#ax[1].plot(r, -0.05*(r-14) + 1, color='C3', alpha=0.7, ls="--")

ax[0].legend(loc="lower right", scatterpoints=3, frameon=True, fontsize=12)

lgnd = ax[1].legend(loc="upper left", scatterpoints=3, frameon=True, fontsize=12)
# `Legend.legendHandles` was renamed to `legend_handles` (deprecated in
# Matplotlib 3.7 and later removed); support both spellings.
legend_handles = lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
# Enlarge the legend markers and give the second entry a three-color sample
# of the colormap.
#legend_handles[0]._sizes = [10]
legend_handles[0]._sizes = [20]
legend_handles[1]._sizes = [20]
legend_handles[1].set_color(plt.cm.winter([0.8, 0.5, 0.2]))


ax[0].set_ylim(17, 27)
ax[1].set_ylim(-0.3, 2)
fig.tight_layout()

cbar = fig.colorbar(cs, ax=ax, pad=0.01, ticks=np.log10([0.001, 0.01, 0.1, 0.5]), label="Satellite Rate ($\\mathcal{R}_{\\rm sat}$)")
cbar.ax.set_yticklabels(['$10^{-3}$', '$0.01$', '$0.1$', '$0.5$']) 
cbar.ax.set_rasterized(True)
# NOTE(review): the output path is absolute and machine-specific.
plt.savefig('/home/yymao/Downloads/sat-prob.pdf', dpi=200)