# In [4]:
import numpy as np
from ete3 import Tree
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import random

# ----------------------------
# Beta-splitting tree generation (leaves split around a Beta-distributed cut point)
# ----------------------------
def simulate_beta_splitting_tree_uniform(n, beta):
    """Simulate a random binary tree with ``n`` leaves by beta-splitting.

    ``n`` leaf positions are drawn uniformly on (0, 1). Starting from the
    whole unit interval, every interval that holds two or more leaves is cut
    at a point drawn from Beta(beta + 1, beta + 1) rescaled to that interval.
    Leaves at or left of the cut form the first child, the remaining leaves
    form the second child, and both sub-intervals are split the same way.

    When a cut separates no leaves, the leaf nearest to the cut is moved
    across so that both children are non-empty.

    Args:
        n: Number of leaves; must be at least 1.
        beta: Splitting parameter. ``beta + 1`` is used for both Beta shape
            parameters and is clamped to at least 1e-6 so they stay positive.

    Returns:
        The root ``Tree`` node. Each leaf is named with the string form of
        its position; internal nodes keep their default attributes.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n!r}")

    # np.random.beta needs strictly positive shape parameters.
    shape = max(beta + 1, 1e-6)

    # Sorting once lets every split be expressed as an index into this array,
    # and makes "nearest leaf to the cut" simply the first/last of a slice.
    leaves = np.sort(np.random.uniform(0, 1, n))

    root = Tree()
    # Explicit stack instead of recursion: comb-like trees can be about n
    # levels deep, which would exceed Python's default recursion limit.
    # Each entry: (node, first leaf index, one-past-last index, low, high).
    pending = [(root, 0, n, 0.0, 1.0)]
    while pending:
        node, start, stop, low, high = pending.pop()

        if stop - start == 1:
            node.name = str(leaves[start])
            continue

        # Rescale the Beta draw from (0, 1) to the interval being split.
        cut = low + (high - low) * np.random.beta(shape, shape)

        # side="right" keeps leaves equal to the cut on the left side.
        mid = start + int(np.searchsorted(leaves[start:stop], cut, side="right"))
        # Force both sides to be non-empty by moving the nearest leaf across.
        mid = min(max(mid, start + 1), stop - 1)

        left_child = node.add_child(Tree())
        right_child = node.add_child(Tree())
        # Push right first so the left subtree is expanded first.
        pending.append((right_child, mid, stop, cut, high))
        pending.append((left_child, start, mid, low, cut))

    return root

# ----------------------------
# Compute the root split ratio (not reduced to lowest terms)
# ----------------------------
def root_subtree_ratio(tree):
    """Describe how the leaves under ``tree`` are divided at its root.

    Args:
        tree: Node exposing ``get_children()``; each child must expose
            ``get_leaves()``.

    Returns:
        ``None`` when the node has no children. Otherwise the string
        ``"l:r"``, where ``l`` is the leaf count under the first child and
        ``r`` the leaf count under the second child (``0`` when there is only
        one child). Children beyond the second are not counted, and the
        ratio is not reduced.
    """
    kids = tree.get_children()
    if not kids:
        return None

    first_count = len(kids[0].get_leaves())
    second_count = len(kids[1].get_leaves()) if len(kids) > 1 else 0
    return f"{first_count}:{second_count}"

# ----------------------------
# Prune the tree down to n leaves
# ----------------------------
def prune_to_n_leaves(tree, n):
    """Prune ``tree`` in place so only ``n`` randomly chosen leaves remain.

    Args:
        tree: Tree to prune. It is modified in place.
        n: Number of leaves to keep, sampled without replacement from
            ``tree.get_leaves()``.

    Returns:
        The same ``tree`` object, after pruning.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``n`` is negative
            or larger than the number of leaves.
    """
    kept_leaves = random.sample(tree.get_leaves(), n)
    # Pass the node objects themselves rather than their names: a name lookup
    # is ambiguous when two nodes share a name, the objects never are.
    tree.prune(kept_leaves)  # modifies the original tree
    return tree

# ----------------------------
# GIF generation
# ----------------------------
def _canonical_split(ratio):
    """Return an ``"l:r"`` ratio with the smaller count first; None passes through."""
    if ratio is None:
        return None
    smaller, larger = sorted(int(part) for part in ratio.split(":"))
    return f"{smaller}:{larger}"


def _render_bar_frame(labels, heights, title):
    """Draw one bar chart of ``heights`` against ``labels`` and return it as a PIL image."""
    positions = range(len(labels))

    plt.figure(figsize=(6, 4))
    plt.bar(positions, heights, color='skyblue')
    plt.xticks(positions, labels, rotation=45)
    peak = max(heights)
    # 10% headroom above the tallest bar; fall back to 1 when every bar is
    # zero so the lower and upper axis limits are never identical.
    plt.ylim(0, peak * 1.1 if peak > 0 else 1)
    plt.ylabel("Proportion")
    plt.title(title)
    plt.tight_layout()

    buf = BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    return Image.open(buf)


def make_beta_prune_gif(n, beta_values=None, large_n=1000, num_samples=50, filename="beta_prune.gif"):
    """Animate how the root split of a pruned tree varies with beta.

    For every beta, ``num_samples`` trees with ``large_n`` leaves are
    simulated, each is pruned to ``n`` random leaves, and the root split of
    the pruned tree is tallied. One bar-chart frame is drawn per beta and
    all frames are written to ``filename`` as a looping GIF.

    Splits are counted regardless of child order: ``7:1`` and ``1:7`` both
    go into the ``1:7`` bin, so each bar is the proportion of samples with
    that unordered split. Samples whose root split is not ``i:(n-i)`` with
    both sides non-empty are not counted in any bin.

    Args:
        n: Number of leaves to keep after pruning; must be at least 2.
        beta_values: Iterable of beta parameters, one frame each. Defaults
            to 100 evenly spaced values from -0.9 to 2.
        large_n: Number of leaves in each simulated tree before pruning.
        num_samples: Number of trees simulated per beta.
        filename: Path of the GIF to write.

    Raises:
        ValueError: If ``n`` is smaller than 2 or ``beta_values`` is empty.
    """
    if n < 2:
        raise ValueError(f"n must be at least 2 to have a root split, got {n!r}")
    if beta_values is None:
        beta_values = np.linspace(-0.9, 2, 100)

    ratios = [f"{i}:{n-i}" for i in range(1, n//2+1)]
    images = []

    for beta in beta_values:
        ratio_counts = {r: 0 for r in ratios}
        for _ in range(num_samples):
            large_tree = simulate_beta_splitting_tree_uniform(large_n, beta)
            pruned_tree = prune_to_n_leaves(large_tree, n)
            # The bins are keyed smaller-count-first, so put the observed
            # child-order ratio in the same form before looking it up.
            r = _canonical_split(root_subtree_ratio(pruned_tree))
            if r in ratio_counts:
                ratio_counts[r] += 1

        ratio_values = [ratio_counts[r]/num_samples for r in ratios]
        title = f"Beta-splitting prune n={n}, beta={beta:.2f}"
        images.append(_render_bar_frame(ratios, ratio_values, title))

    # Checked after the loop so one-shot iterables are still accepted.
    if not images:
        raise ValueError("beta_values must contain at least one value")

    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=100,
        loop=0
    )

# ----------------------------
# Example run
# Beta grid: 300 evenly spaced values from -0.99 to 1, followed by 100 evenly
# spaced values from 1 to 100.
beta_values = np.concatenate(
    [np.linspace(-1 + 1e-2, 1, 300), np.linspace(1, 100, 100)]
)

# One GIF per pruned-tree size, each sampled from 256-leaf trees.
for _leaf_count, _gif_path in ((8, "beta_prune_n8.gif"), (16, "beta_prune_n16.gif")):
    make_beta_prune_gif(
        _leaf_count,
        beta_values,
        large_n=256,
        num_samples=100,
        filename=_gif_path,
    )