In [2]:
import torch
from collections import defaultdict

def batched_multivariate_hypergeometric(a_batch):
    """
    Draws samples from a multivariate hypergeometric distribution for a batch of input tensors.
    
    Parameters:
    a_batch (torch.Tensor): A 2-dimensional tensor of shape (batch_size, k) where each row
                            contains the number of objects in each category for a batch element.

    Returns:
    torch.Tensor: A tensor of shape (batch_size, k) containing the number of objects drawn from each category.
    """
    batch_size, k = a_batch.shape
    result_batch = torch.zeros_like(a_batch, dtype=torch.int)

    for batch_idx in range(batch_size):
        a = a_batch[batch_idx]
        total_items = a.sum().item()
        n = torch.randint(0, total_items + 1, (1,)).item()

        if n == 0:
            continue

        remaining_items = a.clone().float()
        remaining_draws = n

        for i in range(k):
            if remaining_draws <= 0:
                break

            if remaining_items[i] == 0:
                continue

            p = remaining_items[i] / remaining_items.sum()
            draw = torch.distributions.Binomial(remaining_draws, p).sample().item()
            draw = min(draw, remaining_items[i].item())

            result_batch[batch_idx, i] = draw
            remaining_draws -= draw
            remaining_items[i] -= draw

    return result_batch

# Function to check if n is uniform
def check_uniform_n(a_batch, num_samples=10000):
    n_counts = defaultdict(int)
    for _ in range(num_samples):
        batch_size, k = a_batch.shape
        for batch_idx in range(batch_size):
            a = a_batch[batch_idx]
            total_items = a.sum().item()
            n = torch.randint(0, total_items + 1, (1,)).item()
            n_counts[n] += 1

    n_values, counts = zip(*n_counts.items())
    total_counts = sum(counts)
    probabilities = [count / total_counts for count in counts]

    return n_values, probabilities

# Function to check if the distribution of the output is uniform
def check_uniform_output(a, num_samples=10000):
    output_counts = defaultdict(int)
    total_items = a.sum().item()
    k = len(a)

    for _ in range(num_samples):
        n = torch.randint(0, total_items + 1, (1,)).item()
        remaining_items = a.clone().float()
        remaining_draws = n

        result = torch.zeros_like(a, dtype=torch.int)
        for i in range(k):
            if remaining_draws <= 0:
                break

            if remaining_items[i] == 0:
                continue

            p = remaining_items[i] / remaining_items.sum()
            draw = torch.distributions.Binomial(remaining_draws, p).sample().item()
            draw = min(draw, remaining_items[i].item())

            result[i] = draw
            remaining_draws -= draw
            remaining_items[i] -= draw

        output_counts[tuple(result.tolist())] += 1

    outputs, counts = zip(*output_counts.items())
    total_counts = sum(counts)
    probabilities = [count / total_counts for count in counts]

    return outputs, probabilities

# Example usage
a_batch = torch.tensor([
    [8, 6, 6],
    [10, 5, 5],
    [7, 3, 10]
])

# Check if n is uniform
n_values, n_probabilities = check_uniform_n(a_batch)
print("n values:", n_values)
print("n probabilities:", n_probabilities)

# Check if the distribution of the output is uniform for the first element in the batch
outputs, output_probabilities = check_uniform_output(a_batch[0])
print("Possible outputs:", outputs)
print("Output probabilities:", output_probabilities)


n values: (12, 7, 5, 2, 14, 17, 19, 20, 13, 4, 16, 0, 9, 6, 8, 3, 10, 15, 11, 1, 18)
n probabilities: [0.047233333333333336, 0.0474, 0.045733333333333334, 0.04756666666666667, 0.04806666666666667, 0.049466666666666666, 0.04853333333333333, 0.0466, 0.044566666666666664, 0.0486, 0.04736666666666667, 0.0481, 0.0461, 0.04643333333333333, 0.0481, 0.0462, 0.048133333333333334, 0.04856666666666667, 0.05056666666666667, 0.047, 0.049666666666666665]
Possible outputs: ((5, 2, 2), (8, 6, 5), (7, 4, 3), (8, 4, 6), (0, 0, 0), (8, 6, 6), (3, 0, 0), (7, 5, 2), (3, 2, 2), (3, 6, 4), (6, 2, 3), (4, 4, 5), (2, 2, 0), (8, 2, 6), (0, 0, 2), (5, 0, 4), (5, 4, 4), (8, 3, 4), (2, 1, 2), (4, 2, 0), (5, 6, 5), (6, 6, 2), (3, 4, 6), (4, 5, 1), (6, 3, 3), (4, 3, 3), (6, 5, 3), (1, 1, 0), (4, 2, 2), (0, 1, 2), (0, 1, 1), (4, 1, 1), (4, 6, 6), (4, 5, 5), (3, 1, 0), (2, 1, 1), (3, 1, 1), (4, 1, 4), (8, 5, 2), (7, 1, 1), (4, 1, 3), (4, 0, 1), (7, 4, 1), (6, 5, 6), (0, 1, 0), (1, 1, 1), (8, 5, 6), (6, 5, 4), (1, 2, 1

In [146]:
torch.randint(0, 2, (1,)).item()

0

In [147]:
import scipy

In [157]:
scipy.stats.beta.sf(.5, *torch.tensor([[10,2,1],[2,4,4]]))

array([0.99414062, 0.1875    , 0.0625    ])

In [5]:
(torch.tensor([-1,2])<=0).any()

tensor(True)