In [None]:
# 5-1 sample generation
import numpy as np
import polars as pl

num_of_generating_samples = 10000  # number of samples to generate
# Target total for features whose sum is constrained; e.g. set this to 100 to
# make the constrained features add up to 100.
desired_components_sum = 1

# Load the generation settings. The first column has an empty header and holds
# the label of each settings row; every remaining column is one feature.
generation_settings = pl.read_csv("../test_data/setting_of_generation.csv")
setting_items = generation_settings.get_column("")
generation_settings = generation_settings.drop(setting_items.name)
x_var_names = generation_settings.columns
# Rows are read by position: 0 = upper bounds, 1 = lower bounds,
# 2 = group numbers.
x_upper = tuple(map(float, generation_settings.row(0)))
x_lower = tuple(map(float, generation_settings.row(1)))
group_setting = tuple(map(int, generation_settings.row(2)))
# The rounding row is optional: the rounding step further down only runs when a
# row labelled "rounding" exists. Reading it unconditionally by position would
# fail on a settings file without that row, so look it up by its label and
# fall back to an empty tuple when it is absent.
if "rounding" in setting_items:
    rounding = tuple(
        map(int, generation_settings.row(setting_items.to_list().index("rounding")))
    )
else:
    rounding = ()

# Draw each feature independently from a uniform distribution between its own
# lower and upper bound. The generator seed is fixed, so repeated runs produce
# the same samples.
rng = np.random.default_rng(99)
x_generated = pl.DataFrame(
    dict(
        zip(
            x_var_names,
            (
                rng.uniform(low=low, high=high, size=num_of_generating_samples)
                for low, high in zip(x_lower, x_upper)
            ),
        )
    )
)

# When any feature is assigned to a sum-constrained group (non-zero group
# number), rescale each group so its members add up to desired_components_sum.
if sum(group_setting) != 0:
    # One row per group number with the list of feature names in that group;
    # group number 0 is excluded from rescaling.
    var_groups = (
        pl.DataFrame({"vars": x_var_names, "group_no": group_setting})
        .group_by("group_no", maintain_order=True)
        .agg(pl.col("vars").implode())
        .filter(pl.col("group_no") != 0)
    )

    for grouped_vars in var_groups.get_column("vars"):
        # Divide every member by the row-wise group total, then scale to the
        # requested sum.
        x_generated = x_generated.with_columns(
            pl.col(grouped_vars)
            / pl.sum_horizontal(pl.col(grouped_vars))
            * desired_components_sum
        )
    # Keep only the rows in which every feature satisfies
    # lower <= value < upper for its own bounds.
    mask = x_generated.select(
        [
            pl.col(name).is_between(low, high, closed="left")
            for name, low, high in zip(x_var_names, x_lower, x_upper)
        ]
    ).min_horizontal()
    x_generated = x_generated.filter(mask)

# Optional rounding, applied only when the settings contain a row labelled
# "rounding".
if "rounding" in setting_items:
    # Per feature: multiply by 10**digits, round to an integer with ties going
    # away from zero, then divide by the same factor.
    x_generated = x_generated.with_columns(
        [
            (pl.col(name) * scale).round(0, "half_away_from_zero") / scale
            for name, scale in zip(x_var_names, (10**digits for digits in rounding))
        ]
    )

# Write the generated samples to a CSV file.
x_generated.write_csv(file="../output/generated_samples.csv")
