In [18]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to them are picked up
# without restarting the kernel. Re-running this cell in a live kernel only
# prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import statsmodels.formula.api as smf

from saturation.utils import *

In [20]:
# Number of local worker threads requested from Spark's local master.
n_cores = 26

# Assemble the session configuration step by step, then create the session
# (or reuse one that already exists in this process).
session_builder = SparkSession.builder.master(f"local[{n_cores}]").appName("Saturation")
session_builder = session_builder.config("spark.sql.shuffle.partitions", "500")
session_builder = session_builder.config("spark.driver.memory", "40g")
session_builder = session_builder.config("spark.driver.maxResultSize", "16g")
spark = session_builder.getOrCreate()

# Configuration variables

In [27]:
# Root directory of the simulation run; per-simulation outputs are read from
# "<BASE_PATH>/<simulation_id>/".
BASE_PATH = "/data/saturation/thesis_run_20250223/"

# Upper nstat bound passed as `max_nstat` when retrieving statistics.
MAX_NSTAT_FOR_STATISTICS = 1_000_000

# States dataset: number of nstat values to extract and the range they are
# taken from.
N_NSTATS = 1
MIN_NSTAT, MAX_NSTAT = 1_500_000, 2_500_000

# Load configurations

In [28]:
# Load the simulation configurations for this run. get_configs returns three
# objects, bound here as configs_pdf, configs_df and configs_dict.
configs_pdf, configs_df, configs_dict = get_configs(base_path=BASE_PATH, spark=spark)

25/03/03 17:52:31 WARN CacheManager: Asked to cache already cached data.


In [29]:
# Distinct simulation ids, in order of first appearance in configs_pdf.
simulation_ids = list(configs_pdf["simulation_id"].drop_duplicates())

In [30]:
# Take the study-region geometry from the first simulation's configuration.
# NOTE(review): this value is reused for every simulation below — presumably
# all simulations in the run share the same geometry; confirm.
first_sim_id = simulation_ids[0]
study_region_size, study_region_padding = (
    configs_dict[first_sim_id][setting]
    for setting in ("study_region_size", "study_region_padding")
)

# Write out states for all simulations to disk

In [31]:
# Choose the nstat values at which crater states are extracted.
#
# The branch must test the N_NSTATS setting. The earlier check read `nstats`
# itself, which is undefined on a fresh kernel (NameError) and is a list on a
# re-run, which silently selected the wrong branch.
if N_NSTATS == 1:
    # Single snapshot: taken at the upper bound of the range.
    nstats = [MAX_NSTAT]
else:
    # N_NSTATS evenly spaced values starting at MIN_NSTAT; the last value is
    # one step below MAX_NSTAT, so MAX_NSTAT itself is not included.
    step = int((MAX_NSTAT - MIN_NSTAT) / N_NSTATS)
    nstats = [MIN_NSTAT + x * step for x in range(N_NSTATS)]

In [32]:
# For each simulation: load its statistics, craters and crater-removal
# tables, extract the crater states at the selected nstat values, and write
# the result to a per-simulation parquet file indexed by simulation_id.
for simulation_id in simulation_ids:
    sim_dir = f"{BASE_PATH}/{simulation_id}"
    stats_df = spark.read.parquet(f"{sim_dir}/statistics_*.parquet")
    craters_df = spark.read.parquet(f"{sim_dir}/craters_*.parquet")
    removals_df = spark.read.parquet(f"{sim_dir}/crater_removals_*.parquet")

    states = get_states(
        stats_df=stats_df,
        craters_df=craters_df,
        removals_df=removals_df,
        nstats=nstats,
        study_region_size=study_region_size,
        study_region_padding=study_region_padding,
        spark=spark,
        result_columns=["crater_id", "radius", "nstat"],
    )

    # Tag every row with its simulation, then use that tag as the index.
    states["simulation_id"] = simulation_id
    states = states.set_index("simulation_id").sort_index()

    output_path = f"data/states_{simulation_id}_{N_NSTATS}.parquet"
    states.to_parquet(output_path)

                                                                                

# Write out statistics

In [10]:
# Columns requested from the statistics query. "simulation_id" is included
# because it becomes the index of the written file.
result_columns = [
    "radius",
    "lifespan",
    "simulation_id"
]

# Write one statistics parquet file per simulation, indexed by simulation_id.
# The query helper is given base_path and the simulation id and does not take
# pre-loaded tables, so the per-iteration parquet reads that were previously
# performed here (and never used) have been removed.
for simulation_id in simulation_ids:
    statistics = get_statistics_with_lifespans_for_simulations(
        simulation_ids=[simulation_id],
        base_path=BASE_PATH,
        configs_df=configs_df,
        spark=spark,
        result_columns=result_columns,
        max_nstat=MAX_NSTAT_FOR_STATISTICS,
    )
    statistics = statistics.set_index("simulation_id").sort_index()
    statistics.to_parquet(f"data/statistics_{simulation_id}.parquet")

                                                                                