In [29]:
import numpy as np

def generate_custom_distribution(weights, total_samples=10000):
    """
    Draw random samples between 0 and 1 from a piecewise-uniform distribution.

    The unit interval is split into ``len(weights)`` equal-width bins. Each bin
    receives a share of ``total_samples`` proportional to its weight, and the
    values inside a bin are drawn with ``np.random.uniform`` (global NumPy RNG).
    Fractional shares are resolved with the largest-remainder method, so the
    result always contains exactly ``total_samples`` values instead of losing
    samples to truncation.

    Args:
        weights (sequence of int/float): Relative weight of each bin; its
            length is the number of bins. Must be non-empty, finite,
            non-negative, and contain at least one positive value.
        total_samples (int): Total number of samples to generate.
            Defaults to 10000.

    Returns:
        np.ndarray: 1-D array of ``total_samples`` values, concatenated in bin
        order (the output is not shuffled).

    Raises:
        ValueError: If ``weights`` is empty, not 1-D, contains negative or
            non-finite entries, sums to zero, or if ``total_samples`` is
            negative.
    """
    weights = np.asarray(weights, dtype=float)
    if weights.ndim != 1 or weights.size == 0:
        raise ValueError("weights must be a non-empty 1-D sequence")
    if not np.all(np.isfinite(weights)) or np.any(weights < 0):
        raise ValueError("weights must be finite and non-negative")
    weight_total = weights.sum()
    if weight_total <= 0:
        raise ValueError("weights must contain at least one positive value")

    total_samples = int(total_samples)
    if total_samples < 0:
        raise ValueError("total_samples must be non-negative")

    prob = weights / weight_total  # normalise weights into probabilities
    bins = np.linspace(0, 1, len(weights) + 1)

    # Ideal (fractional) share per bin, then its integer part.
    expected = prob * total_samples
    samples_per_bin = np.floor(expected).astype(int)

    # Flooring can drop up to len(weights) - 1 samples overall; hand them back
    # one at a time to the bins with the largest fractional remainder.
    # Zero-weight bins have remainder 0 and therefore never receive one.
    shortfall = total_samples - int(samples_per_bin.sum())
    if shortfall > 0:
        remainders = expected - samples_per_bin
        neediest = np.argsort(-remainders, kind="stable")[:shortfall]
        samples_per_bin[neediest] += 1

    data_custom = [
        np.random.uniform(bins[i], bins[i + 1], n)
        for i, n in enumerate(samples_per_bin)
    ]

    return np.concatenate(data_custom)


In [38]:
import matplotlib.pyplot as plt
import numpy as np

# Seed NumPy's global RNG so every run draws the same samples
np.random.seed(0)

# Build four datasets with different shapes on the unit interval.
# The order of the random draws below matters for reproducibility.
data_a = np.random.beta(8, 0.01, 10000)   # Beta(8, 0.01): mass piled up near 1
data_b = np.random.beta(0.01, 8, 10000)   # Beta(0.01, 8): mass piled up near 0

weights = [100, 10, 5, 10, 0, 0, 0, 0, 0, 0]
data_c = generate_custom_distribution(weights)

weights = [100, 0, 1, 2, 3, 4, 5, 7, 5, 25]
data_d = generate_custom_distribution(weights)

datasets = [data_a, data_b, data_c, data_d]
titles = ['filter_bad_1', 'filter_bad_2', 'filter_bad_3', 'filter_good_1']

# Histogram bin edges (step 0.1 starting at 0) and the colormap for the bars
bins = np.arange(0, 1.1, 0.1)
bin_width = bins[1] - bins[0]
colormap = plt.cm.viridis

# Render one small histogram per dataset and write it to "<title>.png"
for data, title in zip(datasets, titles):
    fig, ax = plt.subplots(figsize=(3, 2))
    counts, bin_edges = np.histogram(data, bins=bins)

    # One bar per bin, centred on the bin midpoint and coloured by that
    # midpoint's position in the colormap.
    for left_edge, height in zip(bin_edges[:-1], counts):
        bin_center = left_edge + bin_width / 2
        ax.bar(
            bin_center,
            height,
            width=bin_width,
            edgecolor='black',
            color=colormap(bin_center),
            align='center',
        )

    # Axis cosmetics; no title is drawn on the saved figures.
    ax.set_xlim(0, 1)
    ax.set_xticks([0, 0.5, 1])
    ax.set_facecolor('white')
    ax.grid(False)

    # Write the figure to disk, then release it to keep memory flat
    fig.tight_layout()
    filename = title.strip('()') + ".png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Saved: {filename}")


Saved: filter_bad_1.png
Saved: filter_bad_2.png
Saved: filter_bad_3.png
Saved: filter_good_1.png
