In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import glob

import plotly.express as px
import plotly.io as pio

pio.renderers.default = "iframe"

from saturation.utils import *

In [None]:
# Build (or reuse) a local-mode Spark session; tunables are listed in one dict.
n_cores = 28

spark_settings = {
    "spark.sql.shuffle.partitions": "500",
    "spark.driver.memory": "30g",
    "spark.driver.maxResultSize": "8g",
}

session_builder = SparkSession.builder.master(f"local[{n_cores}]").appName("Saturation")
for setting_name, setting_value in spark_settings.items():
    session_builder = session_builder.config(setting_name, setting_value)

spark = session_builder.getOrCreate()

In [None]:
# Root directory of the simulation run; the config and statistics paths used
# in later cells are all built from this value.
base_path = "/data/saturation/thesis_run_20240726"

In [None]:
# Read the run configurations, shape them into a Spark DataFrame, and keep a
# pandas copy alongside it.
raw_configs = read_configs(base_path, spark)
configs_df = create_configs_df(raw_configs)
configs_pdf = configs_df.toPandas()

In [None]:
# Load the per-simulation statistics, keep only the fields used in the plots,
# and add log10 versions of the count / distance columns.

# Fraction of rows to keep (0.125 == 12.5%); despite the name this is a
# fraction, not a percentage.
sample_percent = 0.125
# Seed shared by the sample and the shuffle so reruns over the same input
# partitioning draw the same rows in the same order.
random_seed = 0

stats_df = (
    spark.read.parquet(f"{base_path}/*/statistics_*.parquet")
    .sample(fraction=sample_percent, seed=random_seed)
    .select(
        "simulation_id",
        "ntot",
        "nobs",
        "mnnd",
        "z",
        "za",
        "radius_mean",
        "radius_stdev",
        F.col("areal_density").alias("ad"),
        F.log10("mnnd").alias("log_mnnd"),
        F.log10("nobs").alias("log_nobs"),
        F.log10("ntot").alias("log_ntot"),
    )
)

# Join with configs to get simulation parameters.
# `stats_df.columns` already contains "simulation_id", so it is not listed a
# second time (doing so produced two columns with the same name).
data_joined = join_configs(stats_df, configs_df, spark).select(
    *stats_df.columns,
    F.col("slope").alias("b"),
    "mrp",
    "erat",
    "rmult",
)

# Shuffle the rows once and cache, since every later cell reads from `data`.
data = data_joined.sort(F.rand(seed=random_seed)).cache()

In [None]:
# Merge every run configuration found in the YAML config files into one dict;
# later entries overwrite earlier ones on key collisions (dict.update).
configs_dict = {}
config_paths = glob.glob(f"{base_path}/config/config_*.yaml")
for config_path in config_paths:
    run_configurations = read_config(Path(config_path))["run_configurations"]
    for run_configuration in run_configurations:
        configs_dict.update(run_configuration)

In [None]:
# Compute both bounds of the slope `b` in a single aggregation job instead of
# launching one Spark job per bound.
b_bounds = data.agg(
    F.min("b").alias("min_b"),
    F.max("b").alias("max_b"),
).first()

# NOTE: rounding to 2 decimals can move a bound slightly inside the true
# range, so rows at the extreme edges may fall outside ranges derived from
# these values.
min_b = round(b_bounds["min_b"], 2)
max_b = round(b_bounds["max_b"], 2)
min_b, max_b

## Plots of each variable by ntot

### Nobs vs ntot

In [None]:
def plot_overall(
    *,
    data: DataFrame,
    x_var: str,
    x_label: str,
    y_var: str,
    y_label: str,
    fig_name: str,
    x_axis_range: Tuple[float, float]=None,
    y_axis_range: Tuple[float, float]=None,
    n_target: int=None,
    color_by_var: str=None,
    color_by_label: str=None,
    color_bar_min: float=0.0,
    color_bar_max: float=0.0,
    min_ntot: int=100,
    title: str=None
):
    """Scatter-plot two columns of a Spark DataFrame and save the figure.

    Rows with ``ntot <= min_ntot`` are dropped, the remaining rows are
    optionally down-sampled to roughly ``n_target`` points, converted to
    pandas, and drawn as a scatter plot that is written to
    ``figures/{fig_name}.png`` and shown.

    Parameters
    ----------
    data : Spark DataFrame containing ``ntot``, ``x_var``, ``y_var`` and, if
        given, ``color_by_var``.
    x_var, y_var : names of the columns plotted on each axis.
    x_label, y_label : axis labels.
    fig_name : file stem of the PNG written under ``figures/``.
    x_axis_range, y_axis_range : optional ``(low, high)`` axis limits.
    n_target : approximate number of points to plot. When the filtered data
        has no more rows than this, all rows are plotted.
    color_by_var : optional column used to color the points.
    color_by_label : optional colorbar label; a colorbar is drawn only when
        this is given.
    color_bar_min, color_bar_max : color scale limits. They are applied to
        the points when ``color_bar_min < color_bar_max``; otherwise the
        points are autoscaled.
    min_ntot : exclusive lower bound on ``ntot``.
    title : optional axes title.

    Returns
    -------
    ``(ax, fig)`` — the matplotlib axes and figure, in that order.
    """
    import os

    FONT_SIZE = 16
    FIGURE_SIZE = (8, 4)
    COLOR_MAP = "cividis"

    # dict.fromkeys de-duplicates while keeping order, so coloring by one of
    # the axis variables does not yield two identically named columns.
    wanted_columns = [x_var, y_var]
    if color_by_var:
        wanted_columns.append(color_by_var)
    select_list = list(dict.fromkeys(wanted_columns))

    df_subset_spark = (
        data.where(F.col("ntot") > min_ntot)
            .select(*select_list)
    )
    if n_target:
        n_rows = df_subset_spark.count()
        # Only down-sample when there is something to remove: this avoids a
        # division by zero on an empty subset and a sampling fraction above
        # 1.0, which Spark rejects when sampling without replacement.
        if n_rows > n_target:
            df_subset_spark = df_subset_spark.sample(n_target / n_rows)
    df_subset = df_subset_spark.toPandas()

    fig = plt.figure(figsize=FIGURE_SIZE)
    ax = fig.add_subplot(111)

    if x_axis_range:
        ax.set_xlim(x_axis_range)

    if y_axis_range:
        ax.set_ylim(y_axis_range)

    if title:
        ax.set_title(title)

    # Color map and limits are set on the scatter itself so that the points
    # and the colorbar share one color scale.
    scatter_kwargs = {"s": .25}
    if color_by_var:
        scatter_kwargs["c"] = df_subset[color_by_var]
        scatter_kwargs["cmap"] = COLOR_MAP
        if color_bar_min < color_bar_max:
            scatter_kwargs["vmin"] = color_bar_min
            scatter_kwargs["vmax"] = color_bar_max
    points = ax.scatter(df_subset[x_var], df_subset[y_var], **scatter_kwargs)

    if color_by_label:
        if color_by_var:
            mappable = points
        else:
            # No color column: keep the earlier behavior of drawing a
            # standalone colorbar over the requested limits.
            mappable = plt.cm.ScalarMappable(cmap=plt.colormaps[COLOR_MAP])
            mappable.set_clim(vmin=color_bar_min, vmax=color_bar_max)
        cbar = fig.colorbar(mappable, ax=ax)
        cbar.set_label(label=color_by_label, size=FONT_SIZE)

    ax.set_xlabel(x_label, fontsize=FONT_SIZE)
    ax.set_ylabel(y_label, fontsize=FONT_SIZE)

    # Create the output directory on first use so a fresh checkout can run.
    output_path = f"figures/{fig_name}.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path)
    plt.show()

    return ax, fig

In [None]:
# Preview ten rows of `data` as a pandas frame to check the available columns.
data.limit(10).toPandas()

In [None]:
# Raw counts: ntot against nobs over all simulations, colored by slope b.
_ = plot_overall(
    data=data,
    fig_name="nobs_by_ntot_overall",
    n_target=50000,
    # axes
    x_var="nobs", x_label="$N_{obs}$",
    y_var="ntot", y_label="$N_{tot}$",
    # color scale
    color_by_var="b", color_by_label="$b$",
    color_bar_min=min_b, color_bar_max=max_b,
)

In [None]:
# Same relationship in log10 space with fixed axis limits, colored by slope b.
_ = plot_overall(
    data=data,
    fig_name="log_nobs_by_log_ntot_overall",
    n_target=50000,
    # axes
    x_var="log_nobs", x_label="$log_{10}(N_{obs})$", x_axis_range=(2.0, 4.5),
    y_var="log_ntot", y_label="$log_{10}(N_{tot})$", y_axis_range=(2.0, 5.7),
    # color scale
    color_by_var="b", color_by_label="$b$",
    color_bar_min=min_b, color_bar_max=max_b,
)

In [None]:
# Split the slope range into equal-width bins and draw one plot per bin.
steps = 15
b_delta = (max_b - min_b) / steps
b_ranges = []
for step in range(steps):
    b_ranges.append((min_b + step * b_delta, min_b + (step + 1) * b_delta))

for low, high in b_ranges:
    print((low, high))
    # `between` is inclusive at both ends.
    rows_in_bin = data.where(F.col("b").between(low, high))
    _ = plot_overall(
        data=rows_in_bin,
        fig_name=f"log_nobs_by_log_ntot_b_{low:.1f}_{high:.1f}",
        n_target=15000,
        # axes
        x_var="log_nobs", x_label="$log_{10}(N_{obs})$", x_axis_range=(2.0, 4.5),
        y_var="log_ntot", y_label="$log_{10}(N_{tot})$", y_axis_range=(2.0, 5.7),
        # color scale
        color_by_var="erat", color_by_label="$E_{rat}$",
        color_bar_min=3.0, color_bar_max=15.0,
    )

### Mnnd vs ntot

In [None]:
# ntot against mnnd over all simulations, colored by slope b.
_ = plot_overall(
    data=data,
    fig_name="mnnd_by_ntot_overall",
    n_target=50000,
    # axes
    x_var="mnnd", x_label="$\\overline{NN}_d$",
    y_var="ntot", y_label="$N_{tot}$",
    # color scale
    color_by_var="b", color_by_label="$b$",
    color_bar_min=min_b, color_bar_max=max_b,
)

In [None]:
# log10(ntot) against log10(mnnd) over all simulations, colored by slope b.
_ = plot_overall(
    data=data,
    fig_name="log_mnnd_by_log_ntot_overall",
    n_target=50000,
    # axes
    x_var="log_mnnd", x_label="$log_{10}(\\overline{NN}_d)$",
    y_var="log_ntot", y_label="$log_{10}(N_{tot})$",
    # color scale
    color_by_var="b", color_by_label="$b$",
    color_bar_min=min_b, color_bar_max=max_b,
)

In [None]:
# Split the slope range into equal-width bins and draw one plot per bin.
steps = 15
b_delta = (max_b - min_b) / steps
b_ranges = []
for step in range(steps):
    b_ranges.append((min_b + step * b_delta, min_b + (step + 1) * b_delta))

for low, high in b_ranges:
    print((low, high))
    # `between` is inclusive at both ends.
    rows_in_bin = data.where(F.col("b").between(low, high))
    _ = plot_overall(
        data=rows_in_bin,
        fig_name=f"log_mnnd_by_log_ntot_b_{low:.1f}_{high:.1f}",
        n_target=50000,
        # axes
        x_var="log_mnnd", x_label="$log_{10}(\\overline{NN}_d)$", x_axis_range=(0.8, 2.1),
        y_var="log_ntot", y_label="$log_{10}(N_{tot})$", y_axis_range=(2.0, 5.6),
        # color scale
        color_by_var="erat", color_by_label="$E_{rat}$",
        color_bar_min=3.0, color_bar_max=15.0,
    )

### Z and ZA by ntot

In [None]:
# z against ntot over all simulations, colored by slope b.
_ = plot_overall(
    data=data,
    fig_name="z_by_ntot_overall",
    n_target=50000,
    # axes
    x_var="ntot", x_label="$N_{tot}$",
    y_var="z", y_label="$Z$",
    # color scale
    color_by_var="b", color_by_label="$b$",
    color_bar_min=min_b, color_bar_max=max_b,
)

In [None]:
# Split the slope range into equal-width bins and draw one plot per bin.
steps = 15
b_delta = (max_b - min_b) / steps
b_ranges = []
for step in range(steps):
    b_ranges.append((min_b + step * b_delta, min_b + (step + 1) * b_delta))

for low, high in b_ranges:
    print((low, high))
    # `between` is inclusive at both ends; axis limits are left to matplotlib.
    rows_in_bin = data.where(F.col("b").between(low, high))
    _ = plot_overall(
        data=rows_in_bin,
        fig_name=f"z_by_ntot_b_{low:.1f}_{high:.1f}",
        n_target=50000,
        # axes
        x_var="ntot", x_label="$N_{tot}$",
        y_var="z", y_label="$Z$",
        # color scale
        color_by_var="erat", color_by_label="$E_{rat}$",
        color_bar_min=3.0, color_bar_max=15.0,
    )

In [None]:
# za against ntot over all simulations, colored by slope b.
_ = plot_overall(
    data=data,
    fig_name="za_by_ntot_overall",
    n_target=50000,
    # axes
    x_var="ntot", x_label="$N_{tot}$",
    y_var="za", y_label="$Z_a$",
    # color scale
    color_by_var="b", color_by_label="$b$",
    color_bar_min=min_b, color_bar_max=max_b,
)

In [None]:
# Split the slope range into equal-width bins and draw one plot per bin.
steps = 15
b_delta = (max_b - min_b) / steps
b_ranges = []
for step in range(steps):
    b_ranges.append((min_b + step * b_delta, min_b + (step + 1) * b_delta))

for low, high in b_ranges:
    print((low, high))
    # `between` is inclusive at both ends; axis limits are left to matplotlib.
    rows_in_bin = data.where(F.col("b").between(low, high))
    _ = plot_overall(
        data=rows_in_bin,
        fig_name=f"za_by_ntot_b_{low:.1f}_{high:.1f}",
        n_target=50000,
        # axes
        x_var="ntot", x_label="$N_{tot}$",
        y_var="za", y_label="$Z_a$",
        # color scale
        color_by_var="erat", color_by_label="$E_{rat}$",
        color_bar_min=3.0, color_bar_max=15.0,
    )

### AD by ntot

In [None]:
# Areal density against log10(ntot), colored by the `rmult` column.
# The colorbar label previously read "$b$" although the points are colored by
# `rmult`; it now names the plotted variable.
_ = plot_overall(
    data=data,
    n_target=50000,
    x_var="log_ntot",
    x_label="$log_{10}(N_{tot})$",
    y_var="ad",
    y_label="$A_d$",
    fig_name="ad_by_log_ntot_overall_rmult",
    color_by_var="rmult",
    color_by_label="$r_{mult}$",
    color_bar_min=1.1,
    color_bar_max=1.9
)

In [None]:
# Areal density against log10(ntot), colored by the `mrp` column.
_ = plot_overall(
    data=data,
    fig_name="ad_by_log_ntot_overall_mrp",
    n_target=50000,
    # axes
    x_var="log_ntot", x_label="$log_{10}(N_{tot})$",
    y_var="ad", y_label="$A_d$",
    # color scale
    color_by_var="mrp", color_by_label="$M_r$",
    color_bar_min=0.25, color_bar_max=0.75,
)

In [None]:
# NOTE(review): this call uses the same arguments and fig_name as the call in
# the previous cell, so it redraws the same plot and overwrites
# figures/ad_by_log_ntot_overall_mrp.png. Confirm whether a different color
# variable was intended here, or remove the cell.
_ = plot_overall(
    data=data,
    n_target=50000,
    x_var="log_ntot",
    x_label="$log_{10}(N_{tot})$",
    y_var="ad",
    y_label="$A_d$",
    fig_name="ad_by_log_ntot_overall_mrp",
    color_by_var="mrp",
    color_by_label="$M_r$",
    color_bar_min=0.25,
    color_bar_max=0.75
)

## Radius mean by ntot

In [None]:
# Mean radius against log10(ntot), colored by slope b.
variable = "radius_mean"

_ = plot_overall(
    data=data,
    n_target=50000,
    y_var=variable,
    y_label="$\\overline{r}$",
    x_var="log_ntot",
    # The x axis shows the log10 column, so the label says so (it previously
    # read "$N_{tot}$").
    x_label="$log_{10}(N_{tot})$",
    fig_name=f"{variable}_by_ntot_overall",
    color_by_var="b",
    color_by_label="$b$",
    color_bar_min=min_b,
    color_bar_max=max_b
)

## Radius stdev by ntot

In [None]:
# Radius standard deviation against log10(ntot), colored by slope b.
variable = "radius_stdev"

_ = plot_overall(
    data=data,
    n_target=50000,
    y_var=variable,
    y_label="$\\sigma_r$",
    x_var="log_ntot",
    # The x axis shows the log10 column, so the label says so (it previously
    # read "$N_{tot}$").
    x_label="$log_{10}(N_{tot})$",
    fig_name=f"{variable}_by_ntot_overall",
    color_by_var="b",
    color_by_label="$b$",
    color_bar_min=min_b,
    color_bar_max=max_b
)

In [None]:
# Scratch plot: radius stdev against b for rows with b < -2.01, uncolored.
steep_rows = F.col("b") < -2.01
positive_ntot = F.col("ntot") > 0
_ = plot_overall(
    data=data.where(steep_rows).where(positive_ntot),
    fig_name="delete_me",
    n_target=5000,
    # axes
    x_var="b", x_label="$b$",
    y_var="radius_stdev", y_label="$\\sigma_r$",
)

In [None]:
# Count the rows in a fraction-0.001 random sample of `data`.
data.sample(0.001).count()

In [None]:
# Scratch plot: mean radius against b for rows with b < -2.01, ntot > 5000 and
# nobs/ntot ("ir") above 0.5, colored by ir. Roughly 5,000 points.
_ = plot_overall(
    data=(
        data.withColumn("ir", F.col("nobs") / F.col("ntot"))
        .where(F.col("b") < -2.01)
        .where(F.col("ntot") > 5000)
        .where(F.col("ir") > 0.5)
    ),
    fig_name="delete_me",
    n_target=5000,
    # axes
    x_var="b", x_label="$b$",
    y_var="radius_mean", y_label="$\\overline{r}$",
    # color scale
    color_by_var="ir", color_by_label="IR",
    color_bar_min=0.5, color_bar_max=1.0,
)

In [None]:
# Scratch plot: mean radius against b for rows with b < -2.01, ntot > 5000 and
# nobs/ntot ("ir") above 0.5, colored by ir. Roughly 50,000 points.
_ = plot_overall(
    data=(
        data.withColumn("ir", F.col("nobs") / F.col("ntot"))
        .where(F.col("b") < -2.01)
        .where(F.col("ntot") > 5000)
        .where(F.col("ir") > 0.5)
    ),
    fig_name="delete_me",
    n_target=50000,
    # axes
    x_var="b", x_label="$b$",
    y_var="radius_mean", y_label="$\\overline{r}$",
    # color scale
    color_by_var="ir", color_by_label="IR",
    color_bar_min=0.5, color_bar_max=1.0,
)

In [None]:
# Scratch plot: radius stdev against b, colored by log10(ntot).
_ = plot_overall(
    data=data,
    n_target=50000,
    y_var="radius_stdev",
    y_label="$\\sigma_r$",
    x_var="b",
    x_label="$b$",
    fig_name="delete_me",
    color_by_var="log_ntot",
    # The points are colored by the log10 column, so the colorbar label says
    # so (it previously read "$N_{tot}$").
    color_by_label="$log_{10}(N_{tot})$",
    color_bar_min=data.select(F.min("log_ntot")).collect()[0][0],
    color_bar_max=data.select(F.max("log_ntot")).collect()[0][0]
)

In [None]:
# Scratch plot: radius stdev against ntot for rows with b in [-3.0, -2.8] and
# ntot < 1000, colored by log10(ntot).
# The labels previously described other variables (y as mean radius, x as b,
# colorbar as raw ntot); each label now matches the column it annotates.
_ = plot_overall(
    data=data[(data.b.between(-3.0, -2.8)) & (data.ntot < 1000)],
    n_target=50000,
    y_var="radius_stdev",
    y_label="$\\sigma_r$",
    x_var="ntot",
    x_label="$N_{tot}$",
    fig_name="delete_me",
    color_by_var="log_ntot",
    color_by_label="$log_{10}(N_{tot})$",
    color_bar_min=data.select(F.min("log_ntot")).collect()[0][0],
    color_bar_max=data.select(F.max("log_ntot")).collect()[0][0]
)