In [None]:
# Auto-reload imported modules when their source files change. This only has an
# effect for locally edited modules; the cells shown here import stdlib and
# third-party packages only.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from collections import defaultdict
from collections.abc import Callable

import numpy as np
from tqdm import tqdm

In [None]:
# Single seeded generator shared by every random draw below, so the
# scaling-law splits are reproducible across runs.
rseed = 99
rng = np.random.default_rng(seed=rseed)

In [None]:
# Input parameter dump; the split files are written into the same directory.
output_dir = "../../outputs"
output_path = os.path.join(
    output_dir,
    "filtered_params_dict.json",
)

In [None]:
# Load the per-system parameter dicts (system name -> list of param dicts).
# A context manager closes the file handle deterministically instead of leaking
# it; JSON is UTF-8 by spec, so decode it as such regardless of the locale.
with open(output_path, encoding="utf-8") as f:
    reloaded_params_dicts = json.load(f)

In [None]:
# Number of distinct systems in the reloaded parameter file.
len(reloaded_params_dicts)

In [None]:
# Total number of parameter samples across all systems (the flat index space).
tot_systems = sum(len(system_params) for system_params in reloaded_params_dicts.values())
print(f"tot_systems_reloaded: {tot_systems}")

In [None]:
# Build nested scaling-law splits: split 0 is every flat sample index, and each
# subsequent split halves the size (relative to tot_systems) by sampling
# without replacement from the split before it.
n_scalinglaw_splits = 7
split_sizes = [tot_systems]
scalinglaw_syssample_indices = [np.arange(tot_systems)]
for split_idx in range(1, n_scalinglaw_splits + 1):
    target_size = int(tot_systems // (2**split_idx))
    split_sizes.append(target_size)
    parent_indices = scalinglaw_syssample_indices[split_idx - 1]
    scalinglaw_syssample_indices.append(
        rng.choice(parent_indices, size=target_size, replace=False)
    )
print(split_sizes)

In [None]:
# Report each split's size and confirm every split is nested in its predecessor.
for split_idx, indices in enumerate(scalinglaw_syssample_indices):
    print(f"number of syssample_indices for split {split_idx}: {indices.shape[0]}")
    if split_idx == 0:
        continue
    parent_indices = scalinglaw_syssample_indices[split_idx - 1]
    assert np.isin(indices, parent_indices).all(), (
        "smaller splits must be a subset of the previous split"
    )

In [None]:
# Number of parameter samples per system, in the file's key order; this order
# defines how flat sample indices map onto systems below.
subdir_sample_counts_dict = {
    system_name: len(system_params)
    for system_name, system_params in reloaded_params_dicts.items()
}

In [None]:
# Per-system counts must add up to the flat index space size.
assert tot_systems == sum(subdir_sample_counts_dict.values())

In [None]:
def get_system_name_for_sample_idx(sample_idx, subdir_sample_counts_dict):
    """Return the name of the system that owns a flat (global) sample index.

    Samples are laid out contiguously in the iteration order of
    ``subdir_sample_counts_dict``: the first ``count`` indices belong to the
    first system, the next ``count`` to the second, and so on.

    Args:
        sample_idx: Flat index into the concatenation of all systems' samples.
        subdir_sample_counts_dict: Mapping of system name -> number of samples.

    Returns:
        The name of the system whose index range contains ``sample_idx``.

    Raises:
        ValueError: If ``sample_idx`` is outside ``[0, total_samples)``, where
            ``total_samples`` is the sum of the dict's counts.
    """
    # Derive the valid range from the mapping passed in rather than a notebook
    # global, so the helper is correct for any counts dict it is given.
    total_samples = sum(subdir_sample_counts_dict.values())
    if sample_idx < 0 or sample_idx >= total_samples:
        raise ValueError(f"sample_idx must be between 0 and {total_samples - 1}")

    cumulative_count = 0
    for system_name, count in subdir_sample_counts_dict.items():
        # cumulative_count is now the exclusive end of this system's range.
        cumulative_count += count
        if sample_idx < cumulative_count:
            return system_name

    return None  # Unreachable for a validated sample_idx


# Example usage: resolve one flat sample index to its system.
sample_idx = 100
system_name = get_system_name_for_sample_idx(
    sample_idx, subdir_sample_counts_dict=subdir_sample_counts_dict
)
print(f"Sample index {sample_idx} belongs to system: {system_name}")

In [None]:
def create_sample_idx_mapping(
    subdir_sample_counts_dict: dict[str, int],
) -> Callable[[np.ndarray | list[int]], tuple[np.ndarray, np.ndarray]]:
    # Create arrays for fast lookup
    system_names = []
    boundaries = [0]  # Start with 0

    # Build the boundaries and system names arrays
    for system_name, count in subdir_sample_counts_dict.items():
        system_names.append(system_name)
        boundaries.append(boundaries[-1] + count)

    # Convert to numpy arrays for faster operations
    boundaries = np.array(boundaries)
    system_names = np.array(system_names)

    def get_system_names_and_positions(
        sample_idxs: np.ndarray | list[int],
    ) -> tuple[np.ndarray, np.ndarray]:
        # Validate input
        sample_idxs = np.asarray(sample_idxs)
        if np.any((sample_idxs < 0) | (sample_idxs >= tot_systems)):
            raise ValueError(f"All sample_idxs must be between 0 and {tot_systems - 1}")

        # Find the index where each sample_idx would be inserted in boundaries
        # Subtract 1 to get the correct system index
        system_indices = np.searchsorted(boundaries, sample_idxs, side="right") - 1

        # Calculate relative positions within each system
        relative_positions = sample_idxs - boundaries[system_indices]

        # Return both the system names and relative positions
        return system_names[system_indices], relative_positions

    return get_system_names_and_positions

In [None]:
# Vectorized flat-index -> (system, position) lookup used by the split builders.
get_system_names_and_positions = create_sample_idx_mapping(
    subdir_sample_counts_dict=subdir_sample_counts_dict
)

In [None]:
sample_idx_lst = [0, 1, 43, 44, 200, 300]
system_names, positions = get_system_names_and_positions(sample_idx_lst)
# Report the (system, in-system position) resolved for each example index.
for k, idx in enumerate(sample_idx_lst):
    name, pos = system_names[k], positions[k]
    print(f"Sample index {idx} belongs to system: {name} at position {pos}")

In [None]:
# Materialize a params dict for each downsampled split. Split 0 (the full set)
# is skipped, so params_dicts_all_splits[k] corresponds to split k + 1.
# Iterating over a range (instead of enumerate) lets tqdm know the total.
params_dicts_all_splits = []
for i in tqdm(
    range(1, len(scalinglaw_syssample_indices)),
    desc="Splitting params dicts for scalinglaw splits",
):
    params_dict_split = defaultdict(list)
    # Sort the system sample indices to ensure consistent ordering
    curr_syssample_indices = np.sort(scalinglaw_syssample_indices[i])
    print(
        f"number of syssample_indices for split {i}: {curr_syssample_indices.shape[0]}"
    )
    # validate that the current split is a subset of the previous split
    assert np.all(
        np.isin(curr_syssample_indices, scalinglaw_syssample_indices[i - 1])
    ), "smaller splits must be a subset of the previous split"

    system_names, positions = get_system_names_and_positions(curr_syssample_indices)

    for system_name, pos in zip(system_names, positions):
        params_dict_split[system_name].append(reloaded_params_dicts[system_name][pos])
    params_dicts_all_splits.append(params_dict_split)

In [None]:
# Total samples in the smallest split.
sum(map(len, params_dicts_all_splits[-1].values()))

In [None]:
# Inspect which systems are represented in the smallest split.
params_dicts_all_splits[-1].keys()

In [None]:
# Sanity check: number of AtmosphericRegime_Hadley samples in an intermediate
# split (with the 7 downsampled splits, index -6 is the second entry, i.e.
# split 2). Use .get so the lookup does not insert an empty list into the
# defaultdict when the system was not sampled into this split.
len(params_dicts_all_splits[-6].get("AtmosphericRegime_Hadley", []))

In [None]:
# Number of downsampled splits built above (split 0, the full set, is excluded).
len(params_dicts_all_splits)

In [None]:
# Convert numpy arrays to lists for JSON serialization
def convert_numpy_to_list(obj):
    """Recursively convert numpy values into JSON-serializable Python objects.

    Arrays become (nested) lists and numpy scalars (``np.int64``, ``np.bool_``,
    ...) become the equivalent Python scalars. Dicts are rebuilt with converted
    values, and lists/tuples are rebuilt as lists of converted items. Anything
    else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        A structure that ``json.dump`` can serialize, provided the non-numpy
        leaves already are.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # numpy scalars such as np.int64 / np.bool_ are rejected by json.dump
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_numpy_to_list(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # tuples serialize as JSON arrays anyway; recurse so nested numpy values
        # inside them are converted too
        return [convert_numpy_to_list(item) for item in obj]
    return obj

In [None]:
# for i, param_dict_split in enumerate(params_dicts_all_splits):
#     curr_tot_systems = sum([len(v) for v in param_dict_split.values()])
#     print(f"number of systems in split: {curr_tot_systems}")

#     curr_output_path = os.path.join(
#         output_dir, f"params_dict_split_{curr_tot_systems}.json"
#     )
#     # Create a serializable version of the dictionary
#     serializable_params_dict = {}
#     for system_name, system_param_dicts in param_dict_split.items():
#         serializable_params_dict[system_name] = [
#             convert_numpy_to_list(param_dict) for param_dict in system_param_dicts
#         ]

#     with open(curr_output_path, "w") as f:
#         json.dump(serializable_params_dict, f, indent=2)

#     print(f"Saved filtered parameters to {curr_output_path}")

In [None]:
# Shape of split 0, i.e. every flat sample index (np.arange(tot_systems)).
scalinglaw_syssample_indices[0].shape

In [None]:
# Total number of splits, including split 0 (reversing doesn't change the count).
len(scalinglaw_syssample_indices)

In [None]:
# Build disjoint "increment" splits. Walking from the smallest split to the
# largest, each entry holds only the samples a split adds on top of the next
# smaller one: entry k is reversed split k + 1 minus reversed split k.
# NOTE(review): the smallest split itself (reversed index 0) is not emitted
# here; confirm it is saved elsewhere if it is needed.
params_dicts_all_splits_filtered = []
reversed_scalinglaw_syssample_indices = list(reversed(scalinglaw_syssample_indices))

# Iterating over a range (instead of enumerate) lets tqdm know the total.
for i in tqdm(
    range(1, len(reversed_scalinglaw_syssample_indices)),
    desc="Splitting params dicts for scalinglaw splits",
):
    params_dict_split = defaultdict(list)
    # Sort the system sample indices to ensure consistent ordering
    curr_syssample_indices = np.sort(reversed_scalinglaw_syssample_indices[i])
    print(
        f"number of syssample_indices for split {i}: {curr_syssample_indices.shape[0]}"
    )
    prev_syssample_indices = reversed_scalinglaw_syssample_indices[i - 1]

    # The previous (smaller) split must be nested in the current one; otherwise
    # the set difference below would not partition the larger split.
    assert np.all(np.isin(prev_syssample_indices, curr_syssample_indices)), (
        "smaller splits must be a subset of the previous split"
    )

    # Keep only the samples that are new relative to the smaller split.
    curr_syssample_indices = np.setdiff1d(
        curr_syssample_indices, prev_syssample_indices
    )
    print(
        f"number of syssample_indices for split {i} after filtering out subset in previous split: {curr_syssample_indices.shape[0]}"
    )

    system_names, positions = get_system_names_and_positions(curr_syssample_indices)

    for system_name, pos in zip(system_names, positions):
        params_dict_split[system_name].append(reloaded_params_dicts[system_name][pos])
    params_dicts_all_splits_filtered.append(params_dict_split)

In [None]:
# Write each incremental split to its own JSON file. The file is named by the
# half-open flat range [start_idx, end_idx) it extends the nested splits with:
# start_idx is the size of the next-smaller split, and end_idx adds this
# increment's sample count.
for split_pos, param_dict_split in enumerate(params_dicts_all_splits_filtered):
    curr_tot_systems = sum(map(len, param_dict_split.values()))
    prev_tot_systems = len(reversed_scalinglaw_syssample_indices[split_pos])
    print(f"number of systems in split: {curr_tot_systems}")
    print(f"number of systems in previous split: {prev_tot_systems}")

    start_idx, end_idx = prev_tot_systems, prev_tot_systems + curr_tot_systems
    print(f"start_idx: {start_idx}, end_idx: {end_idx}")
    curr_output_path = os.path.join(
        output_dir, f"params_dict_split_{start_idx}-{end_idx}.json"
    )

    # JSON-serializable copy of the split (numpy values converted to lists).
    serializable_params_dict = {
        name: [convert_numpy_to_list(params) for params in param_list]
        for name, param_list in param_dict_split.items()
    }

    print(f"check: {sum(len(v) for v in serializable_params_dict.values())}")
    with open(curr_output_path, "w") as f:
        json.dump(serializable_params_dict, f, indent=2)

    print(f"Saved filtered parameters to {curr_output_path}")