In [None]:
# Print the installed SAGA package version so outputs can be matched to a release.
import SAGA
print(SAGA.__version__)

In [None]:
from SAGA import ObjectCuts as C
from easyquery import *
from SAGA.utils import *
from astropy.table import vstack, Table, join
from astropy.coordinates import Distance

import numexpr as ne
from scipy.stats import binned_statistic

In [None]:
from statsmodels.discrete.discrete_model import Logit

class MyModel(Logit):
    """Logit variant whose success probability saturates at ``max_value``.

    Modeled probability: ``max_value * Logit.cdf(X @ params)``. With
    ``max_value == 1`` this is the ordinary logit model.

    Parameters
    ----------
    endog, exog
        Response and design matrix, forwarded to ``Logit``.
    max_value : float, keyword, default 1
        Upper asymptote of the probability; must satisfy 0 < max_value <= 1.
    **kwargs
        Any other keyword arguments are forwarded to ``Logit``.

    Raises
    ------
    ValueError
        If ``max_value`` is outside (0, 1].
    """

    def __init__(self, endog, exog, **kwargs):
        # pop() rather than get(): ``max_value`` is our own option, not a
        # statsmodels one, and recent statsmodels BinaryModel.__init__ rejects
        # unknown keyword arguments ("unknown kwargs ['max_value']").
        self._max_value = float(kwargs.pop("max_value", 1))
        if not 0 < self._max_value <= 1:
            raise ValueError(
                "max_value must satisfy 0 < max_value <= 1, got {!r}".format(self._max_value)
            )
        super(MyModel, self).__init__(endog, exog, **kwargs)

    def _Lambda(self, X):
        # Unused placeholder (always returns None); kept so the class
        # interface is unchanged.
        return None

    def cdf(self, X):
        """Scaled logistic CDF: ``max_value * Logit.cdf(X)``."""
        return self._max_value * super(MyModel, self).cdf(X)

    def pdf(self, X):
        """Scaled logistic density: ``max_value * Logit.pdf(X)``."""
        return self._max_value * super(MyModel, self).pdf(X)

    def loglike(self, params):
        """Total log-likelihood: the sum of ``loglikeobs(params)``."""
        return np.sum(self.loglikeobs(params))

    def loglikeobs(self, params):
        """Per-observation Bernoulli log-likelihood with p = ``cdf(X @ params)``."""
        y = self.endog
        X = self.exog
        L = self.cdf(np.dot(X, params))
        # L <= max_value, so 1 - L stays strictly positive whenever max_value < 1;
        # log1p keeps log(1 - L) accurate for small L.
        return np.log(L) * y + np.log1p(-L) * (1 - y)

    def hessian(self, params):
        """Return ``-X.T @ diag(max_value * s * (1 - s)) @ X`` with s = Logit.cdf(X @ params).

        NOTE(review): for max_value < 1 this is not the second derivative of
        ``loglike``; it is the Jacobian of ``X.T @ (y - max_value * s)``.
        ``score`` is not overridden here -- if the inherited ``Logit.score``
        evaluates ``X.T @ (y - self.cdf(X @ params))`` (TODO confirm for the
        installed statsmodels), Newton fitting solves that estimating equation
        rather than maximizing ``loglike``, and summary() standard errors are
        only approximate.
        """
        X = self.exog
        L = super(MyModel, self).cdf(np.dot(X, params))
        return -np.dot(self._max_value * L * (1 - L) * X.T, X)

In [None]:
def generate_label_arr(data, valid_query, *value_setters):
    """Build a per-row float label array for ``LogitFit``.

    Rows not matching ``valid_query`` are labelled -1 (excluded from the fit),
    matching rows default to 0, and each ``(query, setter)`` pair in
    ``value_setters`` overwrites the rows matching both ``valid_query`` and
    ``query``. ``setter`` is either a constant, or a callable that receives
    the matching sub-table and returns the value(s) to store. Later pairs
    take precedence over earlier ones.
    """
    labels = np.full(len(data), -1.0)
    labels[Query(valid_query).mask(data)] = 0

    for extra_query, setter in value_setters:
        combined = Query(valid_query, extra_query)
        new_value = setter(combined.filter(data)) if callable(setter) else setter
        labels[combined.mask(data)] = new_value

    return labels

class LogitFit:
    """Fit ``MyModel`` on labelled rows of a table and predict probabilities.

    Parameters
    ----------
    data : table
        Table providing ``feature_cols``; may receive a probability column.
    label_arr : ndarray
        Per-row labels; only rows with 0 <= label <= 1 enter the fit.
    feature_cols : sequence of str
        Regressor columns; a constant column is appended automatically.
    max_prob : float
        Passed to ``MyModel`` as ``max_value``.
    add_prob_column : str, bool or None
        If truthy, write predicted probabilities for every row of ``data``
        into this column (``True`` selects the name ``"p_sat"``).
    silent : bool
        If False, display the statsmodels fit summary.
    """

    def __init__(self, data, label_arr, feature_cols=("r_mag", "sb_r", "gr"), max_prob=1, add_prob_column=None, silent=False):
        self._feature_cols = list(feature_cols)
        in_fit = (label_arr >= 0) & (label_arr <= 1)
        features = get_feature_array(data, self._feature_cols, True)
        self.logit_res = MyModel(
            label_arr[in_fit], features[in_fit], max_value=max_prob
        ).fit(disp=False)

        if add_prob_column:
            column = "p_sat" if add_prob_column is True else add_prob_column
            # Predictions cover every row of `data`, not only the fitted ones.
            data[column] = self.calc_prob(features, False)

        if not silent:
            display(self.logit_res.summary())

    def calc_prob(self, data, convert_arr_format=True):
        """Return fitted probabilities for ``data``.

        With ``convert_arr_format`` True, ``data`` is a table converted via
        ``get_feature_array`` (constant column appended); otherwise it must
        already be a design matrix in the same column order.
        """
        if convert_arr_format:
            data = get_feature_array(data, self._feature_cols, True)
        return self.logit_res.predict(data)
    
    
def get_feature_array(table, cols, add_constant=False):
    """Stack table columns into a 2-D float64 design matrix.

    Parameters
    ----------
    table
        Anything indexable by column name returning 1-D arrays
        (e.g. an astropy Table, or a dict of numpy arrays).
    cols : sequence of str
        Column names, in the order they appear as matrix columns.
    add_constant : bool
        If True, append a trailing column of ones (intercept term).

    Returns
    -------
    ndarray of shape (n_rows, len(cols) + int(add_constant)), dtype float64.
    """
    columns = [table[col].astype(np.float64) for col in cols]
    if add_constant:
        # Size the intercept from the data columns rather than len(table):
        # for a dict of arrays, len() counts keys, not rows.
        n_rows = len(columns[0]) if columns else len(table)
        columns.append(np.ones(n_rows, dtype=np.float64))
    return np.vstack(columns).T

In [None]:
# SAGA QuickStart handle; `saga.host_catalog` is used below to build the host selection.
saga = SAGA.QuickStart()

In [None]:
import os

# Location of the SAGA base catalog. Set the SAGA_BASE_CATALOG environment
# variable to use another copy without editing the notebook; the default is
# the path this analysis was originally run with.
BASE_CATALOG_PATH = os.environ.get("SAGA_BASE_CATALOG", "/home/yymao/Downloads/saga_base_all.fits")

base = SAGA.database.FitsTable(BASE_CATALOG_PATH).read()
hosts = np.unique(base["HOSTID"])
# Number of distinct hosts in the catalog.
print(len(hosts))

In [None]:
# gr_c: g-r colour with fallbacks, used in this order when the preferred band
# is not valid: (r-i) + 0.1, then (r-z) + 0.2, then 0.92 - 0.03 * r_mag.
# Built from the last fallback upward; same np.where calls as a nested form.
_gr_c = np.where(C.valid_z_mag.mask(base), base["rz"] + 0.2, 0.92 - 0.03 * base["r_mag"])
_gr_c = np.where(C.valid_i_mag.mask(base), base["ri"] + 0.1, _gr_c)
base["gr_c"] = np.where(C.valid_g_mag.mask(base), base["gr"], _gr_c)
del _gr_c

# sb_c: measured sb_r where valid, otherwise the relation 20 + 0.6 * (r_mag - 14).
base["sb_c"] = np.where(C.valid_sb.mask(base), base["sb_r"], 20 + 0.6 * (base["r_mag"] - 14))

In [None]:
# SPEC_REPEAT mentions AAT or MMT.
spec_by_aat_mmt = (
    QueryMaker.contains("SPEC_REPEAT", "AAT")
    | QueryMaker.contains("SPEC_REPEAT", "MMT")
)

# SPEC_REPEAT_ALL mentions AAT or MMT (object was targeted by either).
targeted_by_aat_mmt = (
    QueryMaker.contains("SPEC_REPEAT_ALL", "AAT")
    | QueryMaker.contains("SPEC_REPEAT_ALL", "MMT")
)

# "Failed" = NOT (AAT/MMT in SPEC_REPEAT AND C.has_spec).
failed_spec_aat_mmt = ~Query(spec_by_aat_mmt, C.has_spec)

In [None]:
# Training sample for the failure model: every condition below must hold.
valid = Query(
    # hosts in the "paper2_complete" selection
    saga.host_catalog.construct_host_query("paper2_complete"),
    # bright-end cut
    "r_mag > 17",
    # C.faint_end_limit,  (disabled in this analysis)
    # keep objects without a spectrum, or with SPEC_Z >= 0.003
    (~C.has_spec) | Query(C.has_spec, "SPEC_Z >= 0.003"),
    # must have been targeted by AAT or MMT
    targeted_by_aat_mmt,
)

In [None]:
# Regressors for the failure model: r magnitude plus the gap-filled
# surface brightness (sb_c) and colour (gr_c) columns.
feature_cols = ("r_mag", "sb_c", "gr_c")

In [None]:
# Full-sample fit: label 1 for `valid` rows matching failed_spec_aat_mmt,
# 0 for the other `valid` rows, -1 (excluded) elsewhere; the fitted
# probability is capped at max_prob = 0.35. The fit summary is displayed.
model = LogitFit(
    base, 
    generate_label_arr(
        base, 
        valid,
        (failed_spec_aat_mmt, 1),
    ), 
    max_prob=0.35,
    feature_cols=feature_cols,
    add_prob_column=False,
)

In [None]:
# Bootstrap the failure model: refit on resampled `valid` rows, predict for
# every row of `base`, and use the per-object median as p_failed.
N_BOOTSTRAP = 1000
BOOTSTRAP_SEED = 0  # fixed seed so p_failed is reproducible across runs
boot_rng = np.random.RandomState(BOOTSTRAP_SEED)

valid_base = valid.filter(base, list(feature_cols) + ["SPEC_REPEAT", "SPEC_REPEAT_ALL", "ZQUALITY"])
n_valid = len(valid_base)

# The full-catalog design matrix is the same for every iteration: build it once.
base_features = get_feature_array(base, feature_cols, True)

params = []
# Preallocated (N_BOOTSTRAP x len(base)); avoids holding both a list and an
# array copy. NOTE: this can be large for a big catalog.
prob_collect = np.empty((N_BOOTSTRAP, len(base)), dtype=np.float64)

for i in range(N_BOOTSTRAP):
    base_this = valid_base[boot_rng.randint(n_valid, size=n_valid)]
    # Separate name so the full-sample fit `model` is not overwritten.
    boot_model = LogitFit(
        base_this,
        generate_label_arr(
            base_this,
            Query(),
            (failed_spec_aat_mmt, 1),
        ),
        max_prob=0.35,
        feature_cols=feature_cols,
        add_prob_column=False,
        silent=True,
    )
    params.append(boot_model.logit_res.params.copy())
    prob_collect[i] = boot_model.calc_prob(base_features, False)

base["p_failed_med"] = np.median(prob_collect, axis=0)
# Alternative bootstrap summaries, kept for reference:
#base["p_failed_mean"] = np.mean(prob_collect, axis=0)
#base["p_failed_iqr"] = iqr(prob_collect, axis=0, scale="normal")
#base["p_failed_std"] = np.std(prob_collect, axis=0, ddof=1)

base["p_failed"] = base["p_failed_med"]
#base["p_failed_err"] = base["p_failed_iqr"]

In [None]:
# Figure: AAT/MMT targets in (r, mu_eff) and (r, g-r). Background: objects
# not flagged as failed AAT/MMT spectra. Foreground: failed targets passing
# C.high_priority_cuts, coloured by log10(p_failed).
fig, ax = plt.subplots(ncols=2, figsize=(10.5, 4))

for ax_this in ax:
    ax_this.set_xlim(17, 21)
    ax_this.set_xlabel(r"$r_o$")
ax[0].set_ylabel(r"$\mu_{r_o, {\rm eff}}$")
ax[1].set_ylabel(r"$(g-r)_o$")

t = Query(valid, ~failed_spec_aat_mmt).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='C0', s=0.5, lw=0, rasterized=True, label=None)
ax[1].scatter(t["r_mag"], t["gr_c"], c='C0', s=0.5, lw=0, rasterized=True, label="Redshifts from AAT/MMT")

# Failed and still without any redshift. Sorting by p_failed draws the
# highest-probability points last (on top).
t = Query(valid, failed_spec_aat_mmt, C.high_priority_cuts, ~C.has_spec).filter(base)
t.sort("p_failed")
ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_failed"]), s=12, lw=0, cmap="autumn_r", vmin=np.log10(0.005), vmax=np.log10(0.4), rasterized=True, alpha=1, label="No redshifts yet", marker='X')
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_failed"]), s=12, lw=0, cmap="autumn_r", vmin=np.log10(0.005), vmax=np.log10(0.4), rasterized=True, alpha=1, label=None, marker='X')

# Failed at AAT/MMT but with a redshift from elsewhere (single-slit).
t = Query(valid, failed_spec_aat_mmt, C.high_priority_cuts, C.has_spec).filter(base)
t.sort("p_failed")
# The colorbar below is attached to this summer_r mapping (vmin = 0.01, hence
# the "< 0.01" tick label); the autumn_r points above use vmin = 0.005.
cs = ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_failed"]), s=12, lw=0, cmap="summer_r", vmin=np.log10(0.01), vmax=np.log10(0.4), rasterized=True, alpha=1, label="w/ single-slit redshifts")
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_failed"]), s=12, lw=0, cmap="summer_r", vmin=np.log10(0.01), vmax=np.log10(0.4), rasterized=True, alpha=1, label=None)

r = np.linspace(12, 21, 10)
ax[0].plot(r, 0.6*(r-14) + 17.8, c='grey', ls="--", label=None, lw=1)
ax[1].plot(r, -0.06*(r-14) + 0.9, c='grey', ls="--", label='Targeting cuts', lw=1)
ax[0].axvline(20.75, c='grey', ls="--", label=None, lw=1)
ax[1].axvline(20.75, c='grey', ls="--", label=None, lw=1)

# Legend handles are looked up by label, not position: matplotlib >= 3.5 orders
# handles by insertion (older versions put lines before collections), and
# `legendHandles` was renamed `legend_handles` in 3.7 (old name removed in 3.9).
lgnd = ax[0].legend(loc="lower right", scatterpoints=3, frameon=True, fontsize="small", title="Failed AAT/MMT targets")
lgnd_handles = lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
handle_by_label = {txt.get_text(): h for txt, h in zip(lgnd.get_texts(), lgnd_handles)}
handle_by_label["No redshifts yet"].set_sizes([25])
handle_by_label["w/ single-slit redshifts"].set_sizes([25])
handle_by_label["No redshifts yet"].set_color(plt.cm.autumn_r([0.8, 0.5, 0.2]))
handle_by_label["w/ single-slit redshifts"].set_color(plt.cm.summer_r([0.8, 0.5, 0.2]))

lgnd = ax[1].legend(loc="upper left", scatterpoints=3, frameon=True, fontsize="medium")
lgnd_handles = lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
handle_by_label = {txt.get_text(): h for txt, h in zip(lgnd.get_texts(), lgnd_handles)}
# Enlarge the tiny background markers in the legend only.
handle_by_label["Redshifts from AAT/MMT"].set_sizes([10])

ax[0].set_ylim(19, 25)
ax[1].set_ylim(0, 1)
fig.tight_layout()

cbar = fig.colorbar(cs, ax=ax, pad=0.01, ticks=np.log10([0.01, 0.02, 0.05, 0.1, 0.2, 0.4]))
cbar.ax.set_yticklabels(['$< 0.01$', '$0.02$', '$0.05$', '$0.1$', '$0.2$', '$0.4$'])
cbar.ax.minorticks_off()
cbar.ax.set_xlabel(r"      $p_{\rm fail}$")
fig.savefig('/home/yymao/Downloads/spec-failure.pdf', dpi=200)

In [None]:
# Figure (constant-colour version): same selections as the p_failed-coloured
# figure, but each group is drawn in a single colour/marker.
fig, ax = plt.subplots(ncols=2, figsize=(10.5, 4))

for ax_this in ax:
    ax_this.set_xlim(17, 21)
    ax_this.set_xlabel(r"$r_o$")
ax[0].set_ylabel(r"$\mu_{r_o, {\rm eff}}$")
ax[1].set_ylabel(r"$(g-r)_o$")

t = Query(valid, ~failed_spec_aat_mmt).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='C0', s=1, lw=0, rasterized=True, alpha=0.8, label=None)
ax[1].scatter(t["r_mag"], t["gr_c"], c='C0', s=1, lw=0, rasterized=True, alpha=0.8, label="Redshifts from AAT/MMT")

# With a single colour per group, cmap/vmin/vmax would be ignored by
# matplotlib (with a warning), so they are not passed.
t = Query(valid, failed_spec_aat_mmt, C.high_priority_cuts, ~C.has_spec).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='C3', s=10, lw=0, rasterized=True, alpha=0.8, label="No redshifts yet", marker='X')
ax[1].scatter(t["r_mag"], t["gr_c"], c='C3', s=10, lw=0, rasterized=True, alpha=0.8, label=None, marker='X')

t = Query(valid, failed_spec_aat_mmt, C.high_priority_cuts, C.has_spec).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='C2', s=15, lw=0, rasterized=True, alpha=1, label="w/ single-slit redshifts", marker="^")
ax[1].scatter(t["r_mag"], t["gr_c"], c='C2', s=15, lw=0, rasterized=True, alpha=1, label=None, marker="^")

r = np.linspace(12, 21, 10)
ax[0].plot(r, 0.6*(r-14) + 17.8, c='grey', ls="--", label=None, lw=1)
ax[1].plot(r, -0.06*(r-14) + 0.9, c='grey', ls="--", label='Targeting cuts', lw=1)
ax[0].axvline(20.75, c='grey', ls="--", label=None, lw=1)
ax[1].axvline(20.75, c='grey', ls="--", label=None, lw=1)

# Legend handles are looked up by label, not position: matplotlib >= 3.5 orders
# handles by insertion (older versions put lines before collections), and
# `legendHandles` was renamed `legend_handles` in 3.7 (old name removed in 3.9).
lgnd = ax[0].legend(loc="lower right", scatterpoints=3, frameon=True, fontsize="small", title="Failed AAT/MMT targets")
lgnd_handles = lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
handle_by_label = {txt.get_text(): h for txt, h in zip(lgnd.get_texts(), lgnd_handles)}
handle_by_label["No redshifts yet"].set_sizes([25])
handle_by_label["w/ single-slit redshifts"].set_sizes([25])

lgnd = ax[1].legend(loc="upper left", scatterpoints=3, frameon=True, fontsize="medium")
lgnd_handles = lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
handle_by_label = {txt.get_text(): h for txt, h in zip(lgnd.get_texts(), lgnd_handles)}
# Enlarge the tiny background markers in the legend only.
handle_by_label["Redshifts from AAT/MMT"].set_sizes([10])

ax[0].set_ylim(19, 25)
ax[1].set_ylim(0, 1)
fig.tight_layout()

# NOTE(review): same output path as the p_failed-coloured figure cell; running
# both keeps only this version. Rename one file if both are needed.
fig.savefig('/home/yymao/Downloads/spec-failure.pdf', dpi=200)