In [None]:
from IPython.display import display, Image
import matplotlib as mpl
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy.io import fits
import numpy as np
import scipy as sp
import os
import subprocess

In [None]:
# Notebook-wide matplotlib defaults: large figures, and images drawn without
# interpolation, with the "hot" colormap and with row 0 at the bottom.
mpl.rcParams.update(
    {
        "figure.figsize": [30, 30],
        "image.interpolation": "nearest",
        "image.cmap": "hot",
        "image.origin": "lower",
    }
)

In [None]:
# Directory holding the simulation products inspected by this notebook.
# Alternative (scratch area): os.path.join(os.getenv("SCRATCH"), "tests")
outdir = "/global/cfs/cdirs/m4943/grismsim/visual_inspection"

# Reference version that the currently checked-out version is compared against.
baseline_tag = "v0.0.1-22-gfe96f98"

In [None]:
# Pointing values embedded in the product file names (ra / dec / pa).
# NOTE(review): presumably degrees - confirm against the simulation driver.
wfi_cen_ra, wfi_cen_dec, wfi_cen_pa = 10, 10, 60

In [None]:
def parse_times(f, tag, det):
    """Parse one timings text file into a flat record.

    Three kinds of line are recognised; everything else is ignored:

    * a line containing ``NStars``, ``NGals`` or ``NPSFs``: the text between
      the first and second ``:`` is stored under that name;
    * a line starting with ``Split``: ``"<key>: <value>"``, stored under the
      text before the colon;
    * any other line of the form ``"<key> - <value>"``.  The value may itself
      contain ``-`` only when it is a valid float (e.g. ``5e-05`` or ``-0.1``),
      so separator lines such as ``-----`` are still skipped.

    Parameters
    ----------
    f : iterable of str
        Open text file (or any iterable of lines) to read.
    tag : str
        Stored under the ``"tag"`` key of the record.
    det : str
        Stored under the ``"det"`` key of the record.

    Returns
    -------
    dict
        ``{"tag": tag, "det": det, ...}`` plus one entry per recognised line.
        All parsed values are kept as stripped strings.

    Raises
    ------
    IndexError
        If a line mentions one of the metadata names but contains no ``:``.
    ValueError
        If a ``Split`` line does not contain exactly one ``:``.
    """
    times = {"tag": tag, "det": det}
    meta = ("NStars", "NGals", "NPSFs")

    for line in f:
        for m in meta:
            if m in line:
                times[m] = line.split(":")[1].strip()

        if line.startswith("Split"):
            # Split keys contain a dash themselves ("Split 0-1"), so they are
            # delimited by the colon instead.
            key, val = line.split(":")
            times[key] = val.strip()
            continue

        key, dash, val = line.partition("-")
        if not dash:
            continue

        val = val.strip()
        if "-" in val:
            # More than one dash on the line: only accept it when the remainder
            # is a number (scientific notation or a sign), otherwise a small
            # timing such as "5e-05" would be dropped silently.
            try:
                float(val)
            except ValueError:
                continue

        times[key.strip()] = val

    return times

In [None]:
# Versions to compare: the pinned baseline first, then whatever `git describe`
# reports for the current checkout. An argument list is used instead of
# shell=True so no shell is involved in running the command.
tag_list = [
    baseline_tag,
    subprocess.check_output(["git", "describe", "--tags"]).decode().strip(),
]

# fn_dict[tag]["SCAnn"] -> {"grism": ..., "refimage": ..., "timings": ...}
# holding the file names (relative to outdir) for detectors SCA01..SCA18.
fn_dict = {}
for tag in tag_list:
    fn_dict[tag] = {}
    for det_num in range(1, 19):
        # Part of the file name shared by all three products of one detector.
        suffix = f"ra{wfi_cen_ra}_dec{wfi_cen_dec}_pa{wfi_cen_pa}_detSCA{det_num:02}_{tag}"
        fn_dict[tag][f"SCA{det_num:02}"] = {
            "grism": f"grism_{suffix}.fits",
            "refimage": f"refimage_{suffix}.fits",
            "timings": f"timings_for_grism_{suffix}.txt",
        }

In [None]:
# Fail fast if any expected product is absent, listing every missing file at
# once. Keys are looked up outside the f-string so the cell also parses on
# Python versions older than 3.12, which reject reusing the enclosing quote
# character inside an f-string.
missing = [
    det_fn[kind]
    for tag_fns in fn_dict.values()
    for det_fn in tag_fns.values()
    for kind in ("grism", "refimage", "timings")
    if not os.path.exists(os.path.join(outdir, det_fn[kind]))
]
assert not missing, f"{len(missing)} file(s) not found in {outdir}: " + ", ".join(missing)

In [None]:
# One row per detector; each tag owns three adjacent columns:
# grism SCI image, grism model image, reference image.
fig, ax = plt.subplots(18, 6, figsize=(40, 100), constrained_layout=True)

for tag_idx, (tag, tag_fns) in enumerate(fn_dict.items()):
    first_col = 3 * tag_idx

    for det, det_fn in tag_fns.items():
        # "SCAnn" -> zero-based row index.
        row = int(det[3:]) - 1
        sci_ax = ax[row][first_col]
        model_ax = ax[row][first_col + 1]
        ref_ax = ax[row][first_col + 2]

        if row == 0:
            # Label the tag once, above the middle column of its group.
            model_ax.text(0.5, 1.075, tag, transform=model_ax.transAxes, ha="center", va="bottom", fontsize=14)

        grism_path = os.path.join(outdir, det_fn["grism"])
        with fits.open(grism_path) as f:
            sci_ax.imshow(f["SCI"].data, vmin=0, vmax=1)
            sci_ax.set_title(det)
            model_ax.imshow(f["model"].data, vmin=0, vmax=0.25)
            model_ax.set_title(det)

        refimage_path = os.path.join(outdir, det_fn["refimage"])
        with fits.open(refimage_path) as f:
            ref_ax.imshow(f["IMAGE"].data, vmin=0, vmax=1)
            ref_ax.set_title(det)

In [None]:
# Show the saved footprint plot of every tag, in tag_list order.
for tag in tag_list:
    footprint_png = os.path.join(outdir, f"{tag}_footprint.png")
    display(Image(filename=footprint_png))

In [None]:
# Pixel-value histograms, laid out like the image grid: one row per detector,
# three columns per tag (grism SCI, grism model, reference image).
img_list = []
fig, ax = plt.subplots(18, 6, figsize=(30, 35), constrained_layout=True)

for tag_idx, (tag, tag_fns) in enumerate(fn_dict.items()):
    first_col = 3 * tag_idx

    for det, det_fn in tag_fns.items():
        # Strictly positive pixel values of this detector, kept for the
        # summary statistics computed from img_list.
        record = {"tag": tag, "det": det}

        # "SCAnn" -> zero-based row index.
        row = int(det[3:]) - 1
        sci_ax = ax[row][first_col]
        model_ax = ax[row][first_col + 1]
        ref_ax = ax[row][first_col + 2]

        if row == 0:
            # Label the tag once, above the middle column of its group.
            model_ax.text(0.5, 1.25, tag, transform=model_ax.transAxes, ha="center", va="bottom", fontsize=12)

        with fits.open(os.path.join(outdir, det_fn["grism"])) as f:
            sci = f["SCI"].data.ravel()
            record["SCI"] = sci[sci > 0]

            # The SCI histogram is binned over all pixels; only the view is
            # restricted to [0, 1].
            sci_ax.hist(sci, bins=1000)
            sci_ax.set_xlim(0, 1)
            sci_ax.set_title(det)

            model = f["model"].data.ravel()
            model_positive = model > 0
            record["model"] = model[model_positive]

            # Histogram only the positive values below 0.1.
            model_ax.hist(model[model_positive & (model < 0.1)], bins=100)
            model_ax.set_xlim(0, 0.05)
            model_ax.set_title(det)

        with fits.open(os.path.join(outdir, det_fn["refimage"])) as f:
            ref = f["IMAGE"].data.ravel()
            ref_positive = ref > 0
            record["IMAGE"] = ref[ref_positive]

            # Histogram only the positive values below 0.2.
            ref_ax.hist(ref[ref_positive & (ref < 0.2)], bins=100)
            ref_ax.set_xlim(0, 0.1)
            ref_ax.set_title(det)

    # NOTE(review): this append sits outside the detector loop, so only the
    # record of the last detector processed for each tag is kept - confirm
    # that this is intended.
    img_list.append(record)

plt.show()

In [None]:
# Summary statistics of the stored pixel values, per tag and per image type.
img_keys = ["SCI", "model", "IMAGE"]

img_tbl = Table(img_list)

for tag in tag_list:
    tag_rows = img_tbl[img_tbl["tag"] == tag]

    # Tag header in cyan, colour reset after the rule.
    print("\033[0;36m" + tag)
    print("-----------" + "\033[0m")

    for key in img_keys:
        # Only the first row stored for this tag is described.
        pixels = np.array(tag_rows[key][0], dtype=np.float64)
        summary = sp.stats.describe(pixels)

        report = (
            ("NObs", summary.nobs),
            ("min/max", summary.minmax),
            ("mean", summary.mean),
            ("variance", summary.variance),
        )
        for label, value in report:
            print(key, label, value)
        print("-----------")

In [None]:
# Parse every timings file into one record per (tag, detector) and collect
# the records in a table.
timing_files = [
    (tag, det, os.path.join(outdir, det_fn["timings"]))
    for tag, tag_fns in fn_dict.items()
    for det, det_fn in tag_fns.items()
]

times = []
for tag, det, timings_path in timing_files:
    with open(timings_path) as f:
        times.append(parse_times(f, tag, det))

timing_table = Table(times)
timing_table

In [None]:
# Timing columns to summarise; each is described across all rows of a tag.
time_keys = [
    "Split 0-1",
    "Split 1-2",
    "Split 2-3",
    "Split 3-4",
    "Split 4-5",
    "Split 5-6",
    "Split 6-7",
    "Split 7-8",
    "PSF_grid_load",
    "star_PSF_eval",
    "star_placement",
    "star_spec_prep",
    "star_grism_sim",
    "gal_PSF_eval",
    "gal_PSF_conv",
    "gal_placement",
    "gal_spec_prep",
    "gal_grism_sim",
]

for tag in tag_list:
    tag_rows = timing_table[timing_table["tag"] == tag]

    # Tag header in cyan, colour reset after the rule.
    print("\033[0;36m" + tag)
    print("-----------" + "\033[0m")

    for key in time_keys:
        # Parsed timings are strings; convert the column to floats first.
        values = np.array(tag_rows[key], dtype=np.float64)
        summary = sp.stats.describe(values)

        report = (
            ("NObs", summary.nobs),
            ("min/max", summary.minmax),
            ("mean", summary.mean),
            ("variance", summary.variance),
        )
        for label, value in report:
            print(key, label, value)
        print("-----------")