# In [1]:
import numpy as np
import scipy.stats as stats

# In [3]:

def check_srm(n_AA, n_A, n_B, ratio_AA, ratio_A, ratio_B, alpha=0.05):
    """
    Check for Sample Ratio Mismatch (SRM) using a chi-square goodness-of-fit test.

    Parameters:
    - n_AA: Observed user count in AA group
    - n_A: Observed user count in A group
    - n_B: Observed user count in B group
    - ratio_AA: Intended allocation weight of the AA group
    - ratio_A: Intended allocation weight of the A group
    - ratio_B: Intended allocation weight of the B group
      The three weights are normalised by their sum, so either fractions
      (e.g. 0.2, 0.4, 0.4) or relative weights (e.g. 1, 2, 2) may be passed.
    - alpha: Significance level; an SRM is reported when the p-value is
      strictly below it (default 0.05)

    Returns:
    - Dictionary containing observed counts, expected counts, chi-square
      statistic, p-value, and SRM status

    Raises:
    - ValueError: if any ratio is negative, the ratios do not sum to a
      positive value, or the total observed user count is not positive
    """
    ratios = [ratio_AA, ratio_A, ratio_B]
    if any(r < 0 for r in ratios):
        raise ValueError("allocation ratios must be non-negative")
    total_ratio = sum(ratios)
    if total_ratio <= 0:
        raise ValueError("allocation ratios must sum to a positive value")

    observed = [n_AA, n_A, n_B]
    # Total users
    N = sum(observed)
    if N <= 0:
        # With no users every expected count is 0 and the statistic is 0/0 (NaN).
        raise ValueError("total observed user count must be positive")

    # Expected counts under the intended allocation. Normalising the ratios
    # guarantees sum(expected) == N, which scipy.stats.chisquare requires
    # (recent SciPy raises ValueError when the totals disagree).
    expected = [N * r / total_ratio for r in ratios]

    # Perform chi-square test
    chi2_stat, p_value = stats.chisquare(f_obs=observed, f_exp=expected)

    # Determine if there is a significant SRM
    srm_detected = bool(p_value < alpha)

    return {
        'Observed Counts': {'AA': n_AA, 'A': n_A, 'B': n_B},
        'Expected Counts': {'AA': expected[0], 'A': expected[1], 'B': expected[2]},
        'Chi-Square Statistic': float(chi2_stat),
        'p-value': float(p_value),
        'SRM Detected': srm_detected
    }