In [None]:
# Prepend a local directory to the module search path.
# NOTE(review): presumably this makes the `import SAGA` in the next cell pick up a
# local checkout; the path is absolute and machine-specific, so confirm before reuse.
import sys
sys.path.insert(0, "/home/yymao/Dropbox/Academia/Collaborations/SAGA-local/saga")

In [None]:
import SAGA
print(SAGA.__version__)

from SAGA import ObjectCuts as C
from easyquery import *
from SAGA.utils import *
from SAGA.targets import add_cut_scores
from SAGA.objects.object_catalog import calc_fiducial_p_sat
from astropy.table import vstack, Table, join
from astropy.coordinates import Distance

import numexpr as ne
from scipy.stats import binned_statistic
from scipy.stats import iqr

In [None]:
from statsmodels.discrete.discrete_model import Logit

class MyModel(Logit):
    """Logit variant whose predicted probability is scaled by ``max_value``.

    ``cdf`` and ``pdf`` are the parent ``Logit`` versions multiplied by
    ``max_value``, so the modelled success probability never exceeds
    ``max_value``.

    Parameters
    ----------
    endog, exog
        Forwarded unchanged to ``Logit.__init__``.
    **kwargs
        Forwarded unchanged to ``Logit.__init__``.  The optional key
        ``max_value`` (default 1) is additionally read here and must satisfy
        ``0 < max_value <= 1``.

    Raises
    ------
    ValueError
        If ``max_value`` is outside ``(0, 1]`` (this includes NaN).
    """

    def __init__(self, endog, exog, **kwargs):
        self._max_value = float(kwargs.get("max_value", 1))
        if not (0 < self._max_value <= 1):
            raise ValueError(
                "max_value must satisfy 0 < max_value <= 1, got {!r}".format(self._max_value)
            )
        # kwargs (including max_value) are still passed on, exactly as before.
        super().__init__(endog, exog, **kwargs)

    def _Lambda(self, X):
        # Placeholder that returns None; nothing in this class calls it.
        return

    def cdf(self, X):
        """Return the parent ``Logit`` CDF scaled by ``max_value``."""
        return self._max_value * super().cdf(X)

    def pdf(self, X):
        """Return the parent ``Logit`` PDF scaled by ``max_value``."""
        return self._max_value * super().pdf(X)

    def loglike(self, params):
        """Return the total log-likelihood: the sum of ``loglikeobs(params)``."""
        return np.sum(self.loglikeobs(params))

    def loglikeobs(self, params):
        """Return the per-observation Bernoulli log-likelihood.

        The success probability is the scaled ``cdf`` evaluated at
        ``exog @ params``.
        """
        y = self.endog
        X = self.exog
        L = self.cdf(np.dot(X, params))
        # log1p(-L) is log(1 - L), written to stay accurate for small L.
        return np.log(L) * y + np.log1p(-L) * (1-y)

    def hessian(self, params):
        """Return ``-X.T @ diag(max_value * L * (1 - L)) @ X``.

        ``L`` is the *unscaled* parent CDF at ``exog @ params``.

        NOTE(review): this is the parent-style logistic Hessian multiplied by
        ``max_value``.  Whether it is meant to be the exact second derivative
        of ``loglike`` when ``max_value < 1`` cannot be told from this file;
        confirm against the inherited ``score``.
        """
        X = self.exog
        L = super().cdf(np.dot(X,params))
        return -np.dot(self._max_value*L*(1-L)*X.T,X)

In [None]:
def generate_label_arr(data, valid_query, *value_setters):
    """Build a float label array for the rows of ``data``.

    Every row starts at -1.0.  Rows selected by ``valid_query`` are set to 0.
    Then, in order, each ``(query, setter)`` pair in ``value_setters`` assigns
    a value to the rows passing both ``valid_query`` and ``query``: if
    ``setter`` is callable it is called with those filtered rows and its
    return value is assigned; otherwise ``setter`` itself is assigned.
    Later pairs overwrite earlier ones where they overlap.

    Returns the label array (length ``len(data)``).
    """
    labels = np.full(len(data), -1.0)
    labels[Query(valid_query).mask(data)] = 0

    for condition, setter in value_setters:
        selector = Query(valid_query, condition)
        # The right-hand side (the value) is evaluated before the mask.
        labels[selector.mask(data)] = (
            setter(selector.filter(data)) if callable(setter) else setter
        )

    return labels

class LogitFit:
    """Fit ``MyModel`` to selected feature columns and expose predictions.

    Only rows whose label lies in ``[0, 1]`` are used for the fit.  The fitted
    result is stored on ``self.logit_res``.

    Parameters
    ----------
    data
        Table providing the feature columns (read through ``get_feature_array``).
    label_arr
        Array of labels aligned with ``data``; values outside ``[0, 1]`` are
        excluded from the fit.
    feature_cols
        Names of the feature columns; a constant column is appended.
    max_prob
        Passed to ``MyModel`` as ``max_value``.
    add_prob_column
        If truthy, predictions for all rows of ``data`` are written to this
        column of ``data`` (``True`` means the column name ``"p_sat"``).
    silent
        If false, the fit summary is shown with ``display``.
    """

    def __init__(self, data, label_arr, feature_cols=("r_mag", "sb_r", "gr"), max_prob=1, add_prob_column=None, silent=False):
        usable = (label_arr >= 0) & (label_arr <= 1)
        self._feature_cols = list(feature_cols)
        features = get_feature_array(data, self._feature_cols, True)
        model = MyModel(label_arr[usable], features[usable], max_value=max_prob)
        self.logit_res = model.fit(disp=False)

        if add_prob_column:
            column_name = "p_sat" if add_prob_column is True else add_prob_column
            # Predict for every row, not just the rows used in the fit.
            data[column_name] = self.calc_prob(features, False)

        if not silent:
            display(self.logit_res.summary())

    def calc_prob(self, data, convert_arr_format=True):
        """Return the fitted model's prediction for ``data``.

        With ``convert_arr_format`` true, ``data`` is a table and is first
        converted with ``get_feature_array``; otherwise it is taken to be an
        already-built feature array.
        """
        if convert_arr_format:
            data = get_feature_array(data, self._feature_cols, True)
        return self.logit_res.predict(data)

In [None]:
# SAGA entry-point object; later cells read `saga.host_catalog` and `saga.paper2_complete` from it.
saga = SAGA.QuickStart()

In [None]:
# Read the base catalog from a local FITS file and print how many distinct HOSTID values it contains.
# NOTE(review): absolute, machine-specific path.
base = SAGA.database.FitsTable("/home/yymao/Downloads/saga_base_all.fits").read()
hosts = np.unique(base["HOSTID"])
print(len(hosts))

In [None]:
def get_feature_array(table, cols, add_constant=False):
    """Return the columns ``cols`` of ``table`` as a 2-D float64 array.

    The result has one row per table row and one column per entry in
    ``cols``, in that order.  If ``add_constant`` is true, a trailing column
    of ones is appended.
    """
    stacked_rows = []
    for name in cols:
        stacked_rows.append(table[name].astype(np.float64))
    if add_constant:
        stacked_rows.append(np.ones(len(table), dtype=np.float64))
    # Each feature was stacked as a row; transpose so that features become columns.
    return np.vstack(stacked_rows).T

In [None]:
# Run SAGA's `add_cut_scores` on the catalog and keep its return value as `base`.
base = add_cut_scores(base)
# Estimated absolute r magnitude: apparent magnitude minus the distance modulus of
# each row's own HOST_DIST (taken as Mpc). The previous version used only
# base['HOST_DIST'][0], i.e. the first row's distance for objects of every host.
base['Mr_est'] = base['r_mag'] - Distance(np.asarray(base['HOST_DIST'], dtype=np.float64), unit="Mpc").distmod.value

In [None]:
# "gr_c": a colour defined for every row, chosen by the first rule that applies:
#   1. `gr` where C.valid_g_mag holds;
#   2. `ri + 0.1` where C.valid_i_mag holds;
#   3. `rz + 0.2` where C.valid_z_mag holds;
#   4. otherwise the linear relation 0.92 - 0.03 * r_mag.
# NOTE(review): the offsets and the linear fallback are hard-coded here; their origin is not shown in this notebook.
base["gr_c"] = np.where(
    C.valid_g_mag.mask(base), 
    base["gr"],
    np.where(
        C.valid_i_mag.mask(base), 
        base["ri"]+0.1, 
        np.where(C.valid_z_mag.mask(base), base["rz"]+0.2, 0.92-0.03*base["r_mag"]),
    ),
)

# "sb_c": `sb_r` where C.valid_sb holds, otherwise the linear relation 20 + 0.6 * (r_mag - 14).
base["sb_c"] = np.where(
    C.valid_sb.mask(base), 
    base["sb_r"],
    20+0.6*(base["r_mag"]-14),
)

In [None]:
# Short alias for the host catalog's query constructor; used below with host-level expressions and flags.
HostQuery = saga.host_catalog.construct_host_query

In [None]:
sup = np.testing.suppress_warnings()
# Match on the common prefix only: depending on the NumPy version this
# RuntimeWarning ends in "true_divide" or "divide".
sup.filter(message="invalid value encountered in")

@sup
def plot_prob_hist(ax, data, col_x, bins, additonal_label=None, label_query=C.is_sat):
    """Plot the binned fraction of labelled objects against ``col_x``.

    For each bin of ``data[col_x]`` the bar height is the fraction of rows
    passing ``label_query``, with an error bar clipped to stay within
    ``[0, 1]``.  If ``data`` has a ``p_sat`` column, its per-bin mean is
    overplotted as a line.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    data : table with ``colnames``; must contain ``col_x``.
    col_x : name of the column to bin.
    bins : passed to ``scipy.stats.binned_statistic``.
    additonal_label : optional text appended to the x-axis label.
    label_query : query whose ``mask(data)`` marks the positive rows.
    """
    has_p_sat = 'p_sat' in data.colnames
    ys = [
        label_query.mask(data),
        np.ones(len(data), dtype=int),  # builtin int: the np.int alias was removed in NumPy 1.24
        data['p_sat'] if has_p_sat else np.zeros(len(data)),
    ]
    s = binned_statistic(data[col_x], ys, bins=bins, statistic='sum')
    n_pos, n_all, p_sum = s.statistic

    # Guard against empty bins when forming the observed fraction.
    dn = np.where(n_all, n_all, 1)
    p = n_pos / dn
    e = np.maximum(np.sqrt(n_pos * (n_all - n_pos)), 1) * 0.84 / (dn**1.5)
    # Clip the error bars so that they never leave [0, 1].
    eh = np.where(e+p <= 1, e, 1-p)
    el = np.where(p-e >= 0, e, p)
    x = (s.bin_edges[1:] + s.bin_edges[:-1]) * 0.5

    ax.bar(x, p, s.bin_edges[1:] - s.bin_edges[:-1], alpha=0.6, color='C0')
    ax.errorbar(x, p, [el, eh], ls='', c='C0')

    # Per-bin mean of p_sat; NaN in empty bins (that warning is what `sup` silences).
    p_model = p_sum / n_all
    if has_p_sat:
        ax.plot(x, p_model, c='C1')
    ax.set_xlabel(col_x + (", {}".format(additonal_label) if additonal_label else ""))

    ax.set_ylim(0, max(np.nanmax(p_model), np.nanmax(p))*1.1)

In [None]:
# Selection of the sample used for the fit; all conditions are combined in one Query:
#   - host-level condition "really_need_spec < 20" (via HostQuery),
#   - r_mag > 12.15,
#   - C.faint_end_limit and C.high_priority_cuts,
#   - the object either has a spectrum with SPEC_Z >= 0.003, or has no spectrum.
valid = Query(
    HostQuery("really_need_spec < 20"), 
    "r_mag > 12.15", 
    C.faint_end_limit, 
    C.high_priority_cuts, 
    Query(C.has_spec, "SPEC_Z >= 0.003") | (~C.has_spec),
)

In [None]:
# Bootstrap the satellite-probability fit: resample the valid rows with
# replacement, refit, and predict for every row of `base` each time.
N_BOOTSTRAP = 1000
# Fixed seed so the resamples (and the fitted `params` inspected in later cells)
# are reproducible from a fresh kernel; the value itself is arbitrary.
BOOTSTRAP_SEED = 20200101
bootstrap_rng = np.random.RandomState(BOOTSTRAP_SEED)

prob_sat_collect = []
params = []

feature_cols = ("r_mag", "sb_c", "gr_c")
# Keep only the columns the fit needs: the features plus "SATS".
valid_base = valid.filter(base, list(feature_cols) + ["SATS"])
n_valid = len(valid_base)

for i in range(N_BOOTSTRAP):
    # Resample n_valid rows with replacement.
    base_this = valid_base[bootstrap_rng.randint(n_valid, size=n_valid)]
    model_sat = LogitFit(
        base_this,
        # Label 1 for rows passing C.is_sat, 0 for all other rows.
        generate_label_arr(
            base_this,
            Query(),
            (C.is_sat, 1),
        ),
        max_prob=0.5,
        feature_cols=feature_cols,
        add_prob_column=False,
        silent=True,
    )
    params.append(model_sat.logit_res.params.copy())
    prob_sat_collect.append(model_sat.calc_prob(base))

# Shape: (N_BOOTSTRAP, len(base)).
prob_sat_collect = np.array(prob_sat_collect)

# Per-object summaries across the bootstrap realizations.
base["p_sat_med"] = np.median(prob_sat_collect, axis=0)
base["p_sat_mean"] = np.mean(prob_sat_collect, axis=0)
base["p_sat_iqr"] = iqr(prob_sat_collect, axis=0, scale="normal")
base["p_sat_std"] = np.std(prob_sat_collect, axis=0, ddof=1)

# Adopt the median and the normal-scaled IQR as the working estimate and error.
base["p_sat"] = base["p_sat_med"]
base["p_sat_err"] = base["p_sat_iqr"]

In [None]:
from scipy.stats import norm
# norm.sf([-1, 1]) gives the normal upper-tail probabilities at -1 and +1 (about 0.84 and 0.16),
# so the two columns are roughly the 84th and 16th percentiles of the bootstrap predictions per object.
base["p_sat_84"], base["p_sat_16"] = np.percentile(prob_sat_collect, norm.sf([-1, 1])*100.0, axis=0)

In [None]:
# Stack the fitted parameter vectors (one row per bootstrap fit) and show their mean and median.
# NOTE: this rebinds `params` from a list to an array.
params = np.array(params)
params.mean(axis=0), np.median(params, axis=0)

In [None]:
# Scatter of the third vs. fourth fitted parameter over all bootstrap fits.
# With feature_cols = (r_mag, sb_c, gr_c) plus the appended constant, these are presumably
# the gr_c coefficient and the constant term — verify against the fit summary.
plt.scatter(params[:,2], params[:,3], s=3)
# Highlight fits whose first three parameters fall in narrow windows around hard-coded values.
# NOTE(review): the centres (-1.68, 1.25, -5.5) and widths are magic numbers tied to a particular run.
mask = (np.abs(params[:,0] + 1.68) < 0.04) & (np.abs(params[:,1] - 1.25) < 0.04) & (np.abs(params[:,2] + 5.5) < 0.05)
plt.scatter(params[mask,2], params[mask,3], s=3)
print(np.median(params[mask,:], axis=0))
#plt.xlim(-3, -0.5)
#plt.ylim(-10, 0)

In [None]:
# Per-host summary: observed satellite counts and the summed satellite
# probabilities of the targets that still lack a redshift.
# NOTE(review): the "p_sat_approx" column is not created in the cells visible
# here; presumably it comes with the catalog or from add_cut_scores — confirm.
X = []
for hostid in hosts:
    t = QueryMaker.equals("HOSTID", hostid).filter(base)
    # Targets within the cuts that have no spectrum yet.
    t1 = Query(C.high_priority_cuts, C.faint_end_limit, ~C.has_spec).filter(t)
    X.append((
        hostid,
        Query(C.is_sat, C.faint_end_limit).count(t),
        Query(C.has_spec, C.faint_end_limit).filter(t, "p_sat_med").sum(),
        len(t1),
        t1["p_sat_approx"].sum(),
        t1["p_sat_med"].sum(),
        t1["p_sat_84"].sum(),
        t1["p_sat_16"].sum(),
    ))

# Builtin int/float replace np.int/np.float, which were removed in NumPy 1.24;
# the resulting dtypes are the same.
X = Table(np.array(X, np.dtype([
    ('HOSTID', '<U10'),
    ('sats_r_limit', int),
    ('sats_r_limit_pred', float),
    ('really_need_spec', int),
    ('sats_missed_approx', float),
    ('sats_missed_pred', float),
    ('sats_missed_84', float),
    ('sats_missed_16', float),
])))

# NOTE(review): absolute, machine-specific output path.
X.write("/home/yymao/Downloads/saga_sat_prob.csv", overwrite=True)

In [None]:
# Per host: summed p_sat_med of the targets lacking a redshift (y, with 16th-84th percentile error bars)
# against the summed p_sat_approx of the same targets (x), on log-log axes with a 1:1 line.
plt.errorbar(X["sats_missed_approx"], X["sats_missed_pred"], yerr=[X["sats_missed_pred"]-X["sats_missed_16"], X["sats_missed_84"]-X["sats_missed_pred"]], marker='.', ls='')
plt.loglog([0.005, 3], [0.005, 3], lw=1)
# Reference lines at hard-coded thresholds (0.32 on x, 0.35 on y).
plt.axvline(0.32, c='grey', lw=1)
plt.axhline(0.35, c='grey', lw=1)


In [None]:
# Sort hosts by really_need_spec, breaking ties by sats_missed_pred (np.lexsort uses the last key as primary).
X = X[np.lexsort([X["sats_missed_pred"], X["really_need_spec"]])]

In [None]:
# Multiplicative jitter for plotting: alternate rows get factors 1.01 and 0.99.
dx = np.ones(len(X))
dx -= 0.01
dx[::2] += 0.02
# Counts are clamped to at least 1 before the jitter is applied, so zero counts become ~1.
X["really_need_spec_dx"] = np.maximum(X["really_need_spec"], 1) * dx

In [None]:
# Figure: expected number of missed satellites vs. number of targets without
# a redshift, per host, colour-coded by completeness status.
plt.figure(figsize=(5, 4))

# Thresholds defining a "complete" host.
p_cut = 0.35
n_cut = 35.5

complete_def_approx = saga.paper2_complete
complete_def = Query(f"sats_missed_84 < {p_cut}", f"really_need_spec < {n_cut}")
p1_complete= HostQuery("paper1_complete")
#complete_def=complete_def_approx

#assert set(complete_def_approx.filter(X, "HOSTID")) == set(complete_def.filter(X, "HOSTID"))
# Every host flagged paper1_complete must also satisfy the new definition.
assert Query(p1_complete, ~complete_def).count(X) == 0

for q, color, label in (
    (~complete_def, 'C0', "Hosts to be completed"),
    (Query(complete_def, ~p1_complete), 'C1', 'Newly complete hosts ({})'),
    (Query(p1_complete), 'C3', 'Paper I complete hosts ({})'),
):
    X1 = q.filter(X)
    plt.errorbar(X1["really_need_spec_dx"], X1["sats_missed_pred"],
                 yerr=[X1["sats_missed_pred"]-X1["sats_missed_16"], X1["sats_missed_84"]-X1["sats_missed_pred"]], marker='o', ls='',
                 label=label.format(len(X1)), ms=3, alpha=0.85, color=color)
    #plt.plot(X1["really_need_spec_dx"], X1["sats_missed_approx"], marker='x', ls='', ms=3, alpha=0.85, color=color)


plt.axhline(p_cut, lw=1, ls="--", c='grey')
plt.axvline(n_cut, lw=1, ls="--", c='grey')

plt.legend(loc="lower right", markerfirst=False, handletextpad=-0.25, bbox_to_anchor=(1.07, -0.02))
# A plt.xscale("symlog", linthreshx=..., subsx=..., linscalex=...) call used to
# precede this line. The "log" call overrode it immediately, and its keyword
# names were removed in Matplotlib 3.5, so the dead call has been dropped.
plt.xscale("log")
plt.xlim(1, None)
plt.yscale("log")
plt.xlabel("Number of targets without redshift")
plt.ylabel("Expected number of satellites\nmissed in incomplete targets")

plt.tight_layout()
# NOTE(review): absolute, machine-specific output path.
plt.savefig("/home/yymao/Downloads/host_completeness_def.pdf")

In [None]:
# Sum of sats_missed_84 over the hosts that satisfy complete_def.
complete_def.filter(X, "sats_missed_84").sum()

In [None]:
# Hosts that fail only the probability threshold (sats_missed_84 >= p_cut) while passing the count threshold...
hid = Query(f"sats_missed_84 >= {p_cut}", f"really_need_spec < {n_cut}").filter(X, "HOSTID")
# ...and, within them, the targets without a spectrum that have p_sat_med > 0.05 and pass the listed cuts.
Query(QueryMaker.in1d("HOSTID", hid), ~C.has_spec, "p_sat_med > 0.05", C.basic_cut2, C.high_priority_cuts).filter(base, ["OBJID", "RA", "DEC", "HOSTID", "r_mag", "p_sat_med", "p_sat_approx"])

In [None]:
# Colour (g-r) against projected host distance for satellites and likely satellites in complete hosts.
hid = complete_def.filter(X, "HOSTID")
# Sample: Mr_est < -12.3, valid g magnitude, in a complete host, and either a confirmed satellite
# or an object without a spectrum that has p_sat_approx > 0.1.
sats_gr = Query('Mr_est < -12.3', C.basic_cut2, C.valid_g_mag, (C.is_sat | Query("p_sat_approx > 0.1", ~C.has_spec)), QueryMaker.in1d("HOSTID", hid)).filter(base)
# Sorting by radius lets np.searchsorted below turn the bin edges into row-index ranges.
sats_gr.sort('RHOST_KPC')
rbins = np.linspace(0, 300, 5)

bin_ed = np.searchsorted(sats_gr['RHOST_KPC'], rbins)
# 16th / 50th / 84th percentile of gr in each radial bin.
gr = np.array([np.percentile((sats_gr['gr'][i:j]), [16, 50, 84]) for i, j in zip(bin_ed[:-1], bin_ed[1:])])

sats_gr_real = C.is_sat.filter(sats_gr)
plt.errorbar(sats_gr_real['RHOST_KPC'], sats_gr_real['gr'], sats_gr_real['gr_err'], marker='o', ls='', ms=6, mec='None', label='SAGA satellites')

sats_gr_maybe = (~C.is_sat).filter(sats_gr)
plt.errorbar(sats_gr_maybe['RHOST_KPC'], sats_gr_maybe['gr'], sats_gr_maybe['gr_err'], marker='o', ls='', ms=6, mec='None', alpha=0.4, label='maybe satellites')

# Binned median with the 16th-84th percentile range as asymmetric error bars.
plt.errorbar(0.5*(rbins[:-1]+rbins[1:]), gr[:,1], (gr[:,1]-gr[:,0], gr[:,2]-gr[:,1]), ls='-', c='C2', marker='', lw=3, alpha=0.5, label='median')
plt.ylabel(r'$(g-r)_o$')
plt.xlabel(r'$d_{\rm proj}$ [kpc]')
plt.xlim(0, 305)
plt.ylim(0, 1)
plt.legend(frameon=True)
#plt.yscale('log')
#plt.axhline(0, c='k', lw=0.5)
plt.tight_layout()
#plt.savefig('/home/yymao/Downloads/sat_ew.pdf', bbox_inches='tight')

In [None]:
# 5x3 grid comparing the observed satellite fraction with the mean p_sat per bin,
# for objects with a spectrum in complete hosts.
hid = complete_def.filter(X, "HOSTID")
fig, ax = plt.subplots(nrows=5, ncols=3, figsize=(12, 15))
ax = ax.flatten()
# NOTE: this overwrites base["p_sat"] in place with p_sat_approx; any later cell
# that reads base["p_sat"] sees p_sat_approx, not the bootstrap median.
base["p_sat"] = base["p_sat_approx"]
data = Query(C.has_spec, QueryMaker.in1d("HOSTID", hid)).filter(base)

# Panels 0-8: the full sample, one (column, bins) pair per panel.
panels_all = (
    ('r_mag', np.linspace(12.5, 20.75, 11)),
    ('sb_r', np.linspace(19, 25, 11)),
    ('gr', np.linspace(0.1, 0.9, 11)),
    ('score_sb_r', np.linspace(17.5, 22.5, 11)),
    ('score_gr_r', np.linspace(0.1, 1, 11)),
    ('p_sat', np.linspace(0, 0.5, 11)),
    ('score_ri_r', np.linspace(-0.05, 0.7, 11)),
    ('score_rz_r', np.linspace(-0.25, 1, 11)),
    ('ug', np.linspace(0.5, 2, 11)),
)
for panel_ax, (col_name, col_bins) in zip(ax, panels_all):
    plot_prob_hist(panel_ax, data, col_name, col_bins)

# Panels 9-11 and 12-14: the same three quantities for the bright and faint subsamples.
panels_split = (
    ('p_sat', np.linspace(0, 0.5, 11)),
    ('sb_r', np.linspace(19, 25, 11)),
    ('gr', np.linspace(0.1, 0.9, 11)),
)
for first_panel, mag_cut, mag_label in (
    (9, "r_mag < 17.7", "r < 17.7"),
    (12, "r_mag >= 17.7", "r >= 17.7"),
):
    data_ = Query(mag_cut).filter(data)
    for panel_ax, (col_name, col_bins) in zip(ax[first_panel:first_panel + 3], panels_split):
        plot_prob_hist(panel_ax, data_, col_name, col_bins, mag_label)

fig.tight_layout()

In [None]:
# Two-panel figure: surface brightness (left) and colour (right) against r
# magnitude, with targets coloured by log10(p_sat).
fig, ax = plt.subplots(ncols=2, figsize=(10.5,4))

for ax_this in ax:
    ax_this.set_xlim(12, 21)
    ax_this.set_xlabel(r"$r_o$")
ax[0].set_ylabel(r"$\mu_{r_o, {\rm eff}}$")
ax[1].set_ylabel(r"$(g-r)_o$")

cut = Query(C.faint_end_limit, C.high_priority_cuts)
# Objects outside the target selection: small grey background points.
t = Query(~cut).filter(base)
ax[0].scatter(t["r_mag"], t["sb_c"], c='grey', s=0.1, lw=0, rasterized=True, label="Non-target galaxies")
ax[1].scatter(t["r_mag"], t["gr_c"], c='grey', s=0.1, lw=0, rasterized=True, label="Non-target galaxies")


# Targets without a redshift; sorted so that high-p_sat points are drawn last.
t = Query(cut, ~C.has_spec).filter(base)
t.sort("p_sat")
# NOTE(review): this is the only scatter using cmap="winter_r"; every other one,
# and the shared colorbar, use "winter". Looks unintended — confirm before changing.
cs = ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_sat"]), s=10, lw=0, cmap="winter_r", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.85, label="Target galaxies", marker='X')
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_sat"]), s=10, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.8, label="Targets (no redshift)", marker='X')


# Targets with a redshift.
t = Query(cut, C.has_spec).filter(base)
t.sort("p_sat")
# This `cs` (cmap "winter") is the mappable the colorbar below is built from.
cs = ax[0].scatter(t["r_mag"], t["sb_c"], c=np.log10(t["p_sat"]), s=7, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.85, label="Target galaxies")
ax[1].scatter(t["r_mag"], t["gr_c"], c=np.log10(t["p_sat"]), s=7, lw=0, cmap="winter", vmin=-3, vmax=np.log10(0.5), rasterized=True, alpha=0.8, label="Targets (w/ redshift)")


#ax[0].plot([14, 21], [14*0.6+12, 21*0.6+12])
#ax[0].axhline(24.75)
#ax[0].axvline(17.77)

lgnd = ax[1].legend(loc="upper left", scatterpoints=3, frameon=True)
# `Legend.legendHandles` was renamed `legend_handles` in Matplotlib 3.7 and the
# old name removed in 3.9; use whichever this Matplotlib provides.
legend_handles = lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
# Enlarge the legend markers and give the colour-mapped entries three sample colours.
legend_handles[0]._sizes = [10]
legend_handles[1]._sizes = [25]
legend_handles[2]._sizes = [20]
legend_handles[1].set_color(plt.cm.winter([0.8, 0.5, 0.2]))
legend_handles[2].set_color(plt.cm.winter([0.8, 0.5, 0.2]))
ax[0].set_ylim(16, 26)
ax[1].set_ylim(-0.3, 2)
fig.tight_layout()

cbar = fig.colorbar(cs, ax=ax, pad=0.01, ticks=np.log10([0.001, 0.01, 0.1, 0.5]))
cbar.ax.set_yticklabels(['$< 10^{-3}$', '$0.01$', '$0.1$', '$0.5$'])
cbar.ax.set_xlabel(r"  $p_{\rm sat}$")
# NOTE(review): absolute, machine-specific output path.
plt.savefig('/home/yymao/Downloads/sat-prob.pdf', dpi=200)