In [None]:
import numpy as np
import matplotlib.pyplot as plt

from itertools import combinations
from collections import Counter, defaultdict
from astropy.table import Table
from astropy.coordinates import SkyCoord, Distance
from SAGA.utils import add_skycoord
from easyquery import Query, QueryMaker

import SAGA
from SAGA.database import FitsTable
from SAGA import ObjectCuts as C
from SAGA.utils import add_skycoord, fill_values_by_query
from SAGA.utils.distance import z2v

print(SAGA.__version__)  # record which SAGA package version produced these outputs

In [None]:
# Entry point to the SAGA catalogs; subsequent loads default to the "paper2" base version.
saga = SAGA.QuickStart()
saga.set_default_base_version("paper2")

In [None]:
import os

# Full SAGA base catalog. The path can be overridden with the SAGA_BASE_CATALOG
# environment variable so the notebook runs outside the original machine.
base_catalog_path = os.environ.get("SAGA_BASE_CATALOG", "/home/yymao/Downloads/saga_base_all.fits")
base = FitsTable(base_catalog_path).read()

# Restrict to the Paper II "complete" host sample.
base = saga.host_catalog.construct_host_query("paper2_complete").filter(base)
n_hosts = len(np.unique(base["HOSTID"]))
assert n_hosts == 36, f"expected 36 paper2_complete hosts in the base catalog, found {n_hosts}"

# Keep only satellites; the rest of the base catalog is not used below, so free it.
sats = Query(C.is_sat).filter(base)
del base

In [None]:
# Sanity checks on the satellite sample (expected numbers for this data release).
assert 127 == len(sats)                                            # all satellites
assert 123 == C.r_abs_limit.count(sats)                            # satellites passing C.r_abs_limit
assert 18 == Query("EW_Halpha < 2").count(sats)                    # low-EW(Halpha) satellites
assert 18 == Query("EW_Halpha < 2", C.r_abs_limit).count(sats)    # ... of which passing C.r_abs_limit
assert 2 == Query("EW_Halpha < 2", "Mr < -16").count(sats)        # ... of which brighter than Mr = -16
assert 34 == len(np.unique(sats["HOSTID"]))                        # hosts with at least one satellite

In [None]:
# Per-host satellite statistics for the Paper II complete hosts.
hostlist = saga.host_catalog.load(include_stats=True, query="paper2_complete")[
    "HOSTID", "sats_Mr_limit", "sats_total", "COMMON_NAME"
]

# Mr of each host's brightest (most negative Mr) satellite; 0 for hosts without satellites.
brightest_mr = []
for host_id, n_total in zip(hostlist["HOSTID"], hostlist["sats_total"]):
    if not n_total:
        brightest_mr.append(0)
        continue
    brightest_mr.append(QueryMaker.equals("HOSTID", host_id).filter(sats, "Mr").min())
hostlist["brightest_sat"] = brightest_mr

# Sort keys: sats_Mr_limit descending, then sats_total descending (via negated
# helper columns), then brightest satellite ascending in Mr.
hostlist["sats_neg"] = -hostlist["sats_Mr_limit"]
hostlist["sats_neg_total"] = -hostlist["sats_total"]
hostlist.sort(["sats_neg", "sats_neg_total", "brightest_sat"])

In [None]:
# Number of loaded satellites per host.
counts = Counter(sats["HOSTID"])
per_host = counts.values()
print(len(counts), sum(per_host), max(per_host))

# The host catalog's satellite totals must agree with the satellites loaded above
# (Counter returns 0 for hosts that have no satellites).
for host_id, n_total in zip(hostlist["HOSTID"], hostlist["sats_total"]):
    assert n_total == counts[host_id]

In [None]:
# One colour per host, sampled evenly along the reversed rainbow colormap in hostlist order.
host_colors = plt.cm.rainbow_r(np.linspace(0, 1, len(hostlist)))
color_dict = {host_id: rgba for host_id, rgba in zip(hostlist["HOSTID"], host_colors)}
#np.savez("/home/yymao/Dropbox/Academia/Collaborations/SAGA/Temp/paper2-preliminary-data-products/color_dict.npz", **color_dict)

In [None]:
# Maximum number of thumbnails per figure row; every host's satellites must
# fit on a single row.
ncols = 12
assert ncols >= max(counts.values())

# Group host IDs by satellite count, skipping hosts with no satellites.
# Groups (and hosts within a group) appear in the sorted `hostlist` order.
d = defaultdict(list)
for host, count in zip(hostlist["HOSTID"], hostlist["sats_total"]):
    if count:
        d[count].append(host)
    
# Greedy row packing. Each row starts with the next host of the current count
# group. The row is then repeatedly topped up with a host from the largest
# count group that still fits in the remaining space, until the row is full or
# nothing fits. Hosts are consumed (popped) from `d`. Note that `hosts` is the
# same list object as `d[count]`, so the top-up step may also pop from it.
# Each row ends up as a list of (HOSTID, n_satellites) pairs.
rows = []
for count, hosts in d.items():
    while hosts:
        row = [(hosts.pop(0), count)]
        space = ncols - count
        n_now = space
        while space and n_now:
            # Check `n_now in d` first: indexing a missing key would insert it
            # into the defaultdict while we are iterating over d.items().
            if n_now in d and d[n_now]:
                row.append((d[n_now].pop(0), n_now)) 
                space -= n_now
                n_now = space
            else:
                n_now -= 1    
        rows.append(row)
        
nrows = len(rows)
print(nrows)

In [None]:
from PIL import Image
import requests
from io import BytesIO
import os

# Legacy Survey cutout-service query parameters; ra/dec/layer are set per request.
params_dict = {"pixscale": 0.2, "size": 200, "layer": "dr8"}
# Cutout service endpoint.
base_url = "http://legacysurvey.org/viewer-dev/jpeg-cutout/"
# Local directory caching the downloaded JPEG cutouts.
cache_dir = "/home/yymao/Dropbox/Academia/Collaborations/SAGA/Images/PaperII/satellites"

def load_image(*args, **kwargs):
    """Fetch a URL with ``requests.get`` and return the raw response body (bytes).

    All arguments are forwarded to ``requests.get``. A default ``timeout`` of
    60 s is applied unless the caller passes one, so a stalled server cannot
    hang the notebook indefinitely.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status. Without this check, the
        error page would be returned as if it were image data.
    """
    kwargs.setdefault("timeout", 60)
    response = requests.get(*args, **kwargs)
    response.raise_for_status()
    return response.content

# Objects whose cutout is taken from the "des-dr1" layer instead of "dr8".
# NOTE(review): presumably DR8 imaging is unusable for these objects -- confirm.
DES_DR1_OBJIDS = {467510764}


def get_satellite_image(sat):
    """Return a PIL image cutout of one satellite, using an on-disk JPEG cache.

    The cache file is ``cache_dir/{survey}-{OBJID}.jpg``. On a cache miss, the
    cutout is downloaded from ``base_url`` and centred on the satellite's
    RA/DEC, using the pixel scale and size from ``params_dict``. It is written
    to the cache only after it has been recognised as an image.

    Parameters
    ----------
    sat : table row
        Must provide ``survey``, ``OBJID``, ``RA`` and ``DEC``.
    """
    fname = f"{sat['survey']}-{sat['OBJID']}.jpg"
    fpath = os.path.join(cache_dir, fname)
    try:
        with open(fpath, "rb") as f:
            img = f.read()
    except FileNotFoundError:
        # Build per-request parameters on a copy; the shared params_dict is left untouched.
        layer = "des-dr1" if int(sat["OBJID"]) in DES_DR1_OBJIDS else params_dict["layer"]
        params = dict(params_dict, ra=sat["RA"], dec=sat["DEC"], layer=layer)
        img = load_image(base_url, params)
        # Open before caching so a non-image response (e.g. an HTML error page)
        # raises here instead of being stored permanently as a .jpg.
        image = Image.open(BytesIO(img))
        os.makedirs(cache_dir, exist_ok=True)
        with open(fpath, "wb") as f:
            f.write(img)
        return image
    return Image.open(BytesIO(img))

In [None]:
from matplotlib.patches import FancyBboxPatch, Polygon

def add_box(ax, x, y, width, height, color):
    """Draw an unfilled, slightly rounded rectangle outline on ``ax``.

    ``(x, y)`` is the lower-left corner and ``width``/``height`` the size, in
    ``ax``'s default data coordinates. The outline uses ``color`` with
    linewidth 2 and zorder 10.
    """
    outline_style = dict(
        boxstyle="round,pad=0,rounding_size=0.01",
        ec=color,
        fc="none",
        zorder=10.,
        linewidth=2,
    )
    ax.add_patch(FancyBboxPatch((x, y), width, height, **outline_style))
    
def add_dogear(ax, pos, color, width=0.25):
    """Fill a right-triangle "dog-ear" in a top corner of ``ax``.

    ``pos == "left"`` marks the top-left corner; any other value marks the
    top-right corner. ``width`` is the length of the triangle's legs in axes
    fraction. The triangle is drawn in axes coordinates at zorder 10 with no
    edge line.
    """
    corners = (
        [[0, 1 - width], [0, 1], [width, 1]]
        if pos == "left"
        else [[1 - width, 1], [1, 1], [1, 1 - width]]
    )
    ax.add_patch(
        Polygon(np.array(corners), fc=color, zorder=10., linewidth=0, transform=ax.transAxes)
    )

In [None]:
# Grid geometry, in units of one (square) thumbnail:
#   xpad / ypad -- horizontal / vertical gap between neighbouring cells
#                  (ypad is larger to leave room for the host-name labels);
#   lpad        -- margin between a host's outline box and its thumbnails.
xpad = 0.18
ypad = 0.38
lpad = 0.06

# Thumbnail markers; these match the legend cells drawn at the end.
quenched_ew_halpha = 2      # EW_Halpha below this -> "Quenched" (C3, left dog-ear)
faint_mr_limit = -12.295    # Mr at or above this -> "Below M_r,o limit" (C9, right dog-ear)

total_width = ncols + xpad * (ncols+1)
total_height = nrows + ypad * (nrows+1)


def calc_location(row, col):
    """Return the lower-left (x, y) of grid cell (row, col) in layout units; row 0 is the top row."""
    return xpad + col*(1+xpad), ypad + (nrows-row-1)*(1+ypad)


fig = plt.figure(figsize=(10, 10*total_height/total_width), linewidth=0, constrained_layout=False, tight_layout=False)

# Invisible full-figure axes that carries the per-host outline boxes; add_box
# is given coordinates as fractions of the figure.
ax_f = plt.Axes(fig, [0,0,1,1], label="all")
ax_f.set_axis_off()
ax_f.set_facecolor("none")
fig.add_axes(ax_f)


def add_cell_axes(row, col):
    """Create an axis-less Axes covering grid cell (row, col), add it to ``fig`` and return it."""
    x, y = calc_location(row, col)
    ax = plt.Axes(fig, [x/total_width, y/total_height, 1/total_width, 1/total_height], label=f"{row}-{col}")
    ax.set_axis_off()
    fig.add_axes(ax)
    return ax


def outline_host(row, col, n_cells, host):
    """Draw ``host``'s coloured outline around ``n_cells`` adjacent cells starting at (row, col)."""
    x, y = calc_location(row, col)
    dx = n_cells + (n_cells-1)*xpad + lpad*2
    dy = 1 + lpad*2
    add_box(ax_f, (x - lpad)/total_width, (y - lpad)/total_height, dx/total_width, dy/total_height, color=color_dict[host])


# One thumbnail per satellite, laid out following the packed `rows`;
# each host's satellites are ordered by r_mag.
for irow, row in enumerate(rows):
    icol = 0
    for host, count in row:
        first_col = icol
        sats_this = QueryMaker.equals("HOSTID", host).filter(sats)
        sats_this.sort("r_mag")
        for isat, sat in enumerate(sats_this):
            ax = add_cell_axes(irow, icol)
            ax.imshow(get_satellite_image(sat), aspect="auto")
            if sat["EW_Halpha"] < quenched_ew_halpha:
                add_dogear(ax, "left", "C3")
            if sat["Mr"] >= faint_mr_limit:
                add_dogear(ax, "right", "C9")
            if isat == 0:
                # Host name above the first thumbnail; offsets are in image pixels
                # (200 px cutouts), where y increases downward.
                ax.text(-5, -ypad*50, sat["HOST_COMMON_NAME"], fontweight="heavy", fontsize=13)
            icol += 1
        outline_host(irow, first_col, count, host)

# Hosts without satellites: an empty, outlined single cell carrying just the
# host name, appended to the last row after the final thumbnails.
no_sats = hostlist["sats_total"] == 0
# The last row must hold these cells plus one spacer column and two legend cells;
# anything beyond ncols would fall outside ax_f and have its outline clipped.
assert icol + int(np.count_nonzero(no_sats)) + 3 <= ncols, "last row has no room for the empty hosts and the legend"
for host, host_name in zip(hostlist["HOSTID"][no_sats], hostlist["COMMON_NAME"][no_sats]):
    ax = add_cell_axes(irow, icol)
    # Same label offset as for thumbnails, converted from pixels of a 200 px
    # cutout to this image-less axes' 0-1 coordinates (y increases upward here).
    ax.text(-5/200, 1+ypad*50/200, host_name, fontweight="heavy", fontsize=13)
    outline_host(irow, icol, 1, host)  # exactly one cell wide
    icol += 1

# Legend: a blank spacer column, then one example cell per marker.
icol += 1
legend_cells = (
    ("Quenched", "left", "C3"),
    ("Below\n$M_{r,o}$ limit", "right", "C9"),
)
for text, corner, color in legend_cells:
    ax = add_cell_axes(irow, icol)
    ax.imshow([[0]], cmap="gray_r", vmin=-10, vmax=50)  # uniform light-grey square as a blank thumbnail
    ax.text(0, 0, text, ha="center", va="center", fontsize=13)
    add_dogear(ax, corner, color)
    icol += 1

fig.savefig("/home/yymao/Downloads/saga-all-sats.pdf", bbox_inches="tight")