# Statistical Analysis

## Validation

In [2]:
"""
Complete Experimental Validation Framework for A/B Testing

Implements ALL required validation checks:
1. Sample Ratio Mismatch (SRM) Detection 
2. Covariate Balance Verification 
3. Temporal Stability Checks 
4. Multiple Testing Correction

When run directly, validates all 5 A/B tests with comprehensive reporting.
"""

import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats.multitest import multipletests
from typing import Dict, List, Tuple, Optional, Union
import warnings
from datetime import datetime
import os


class ExperimentValidator:
    """
    Complete validation framework for A/B tests.
    """
    
    def __init__(self, 
                 srm_threshold: float = 0.001,
                 balance_threshold: float = 0.2,
                 temporal_threshold: float = 0.2):
        self.srm_threshold = srm_threshold
        self.balance_threshold = balance_threshold
        self.temporal_threshold = temporal_threshold
    
    def sample_ratio_mismatch_test(self,
                                   df: pd.DataFrame,
                                   variant_col: str,
                                   expected_ratio: Optional[Dict[str, float]] = None) -> Dict:
        """Sample Ratio Mismatch detection."""
        
        observed = df[variant_col].value_counts().sort_index()
        total = len(df)
        n_variants = len(observed)
        
        if expected_ratio is None:
            expected = pd.Series([total / n_variants] * n_variants, index=observed.index)
        else:
            expected = pd.Series({k: v * total for k, v in expected_ratio.items()})
        
        chi2_stat = np.sum((observed - expected)**2 / expected)
        df_chi = n_variants - 1
        pvalue = 1 - stats.chi2.cdf(chi2_stat, df_chi)
        
        has_srm = pvalue < self.srm_threshold
        
        result = {
            'test': 'sample_ratio_mismatch',
            'chi2_statistic': chi2_stat,
            'degrees_of_freedom': df_chi,
            'pvalue': pvalue,
            'threshold': self.srm_threshold,
            'has_srm': has_srm,
            'observed_counts': observed.to_dict(),
            'expected_counts': expected.to_dict(),
            'observed_ratio': (observed / total).to_dict(),
            'expected_ratio': (expected / total).to_dict()
        }
        
        if has_srm:
            result['warning'] = f"CRITICAL: SRM detected (p={pvalue:.6f} < {self.srm_threshold}). Experiment is INVALID."
        else:
            result['message'] = f"No SRM detected (p={pvalue:.4f}). Allocation is as expected."
        
        return result
    
    def covariate_balance_check(self,
                                df: pd.DataFrame,
                                variant_col: str,
                                covariates: List[str],
                                threshold: Optional[float] = None) -> Dict:
        """Covariate balance verification using SMD."""
        
        if threshold is None:
            threshold = self.balance_threshold
        
        variants = df[variant_col].unique()
        
        if len(variants) < 2:
            return {'error': 'Need at least 2 variants for balance check'}
        
        balance_results = []
        imbalanced_covariates = []
        
        for covariate in covariates:
            if covariate not in df.columns:
                warnings.warn(f"Covariate '{covariate}' not found in dataframe")
                continue
            
            is_categorical = (
                df[covariate].dtype == 'object' or 
                df[covariate].dtype.name == 'category' or
                df[covariate].nunique() < 10
            )
            
            if is_categorical:
                for category in df[covariate].unique():
                    proportions = {}
                    for variant in variants:
                        variant_data = df[df[variant_col] == variant][covariate]
                        proportions[variant] = (variant_data == category).mean()
                    
                    variant_list = list(variants)
                    p1 = proportions[variant_list[0]]
                    p2 = proportions[variant_list[1]]
                    p_pooled = (p1 + p2) / 2
                    
                    if p_pooled > 0 and p_pooled < 1:
                        smd = abs(p1 - p2) / np.sqrt(p_pooled * (1 - p_pooled))
                    else:
                        smd = 0.0
                    
                    is_imbalanced = smd > threshold
                    
                    balance_results.append({
                        'covariate': f"{covariate}={category}",
                        'type': 'categorical',
                        'variant_1': variant_list[0],
                        'variant_2': variant_list[1],
                        'proportion_1': p1,
                        'proportion_2': p2,
                        'smd': smd,
                        'imbalanced': is_imbalanced
                    })
                    
                    if is_imbalanced:
                        imbalanced_covariates.append(f"{covariate}={category}")
            else:
                variant_stats = {}
                for variant in variants:
                    variant_data = df[df[variant_col] == variant][covariate]
                    variant_stats[variant] = {
                        'mean': variant_data.mean(),
                        'std': variant_data.std(),
                        'var': variant_data.var(),
                        'n': len(variant_data)
                    }
                
                variant_list = list(variants)
                v1, v2 = variant_list[0], variant_list[1]
                
                mean_diff = abs(variant_stats[v1]['mean'] - variant_stats[v2]['mean'])
                pooled_std = np.sqrt((variant_stats[v1]['var'] + variant_stats[v2]['var']) / 2)
                
                if pooled_std > 0:
                    smd = mean_diff / pooled_std
                else:
                    smd = 0.0
                
                is_imbalanced = smd > threshold
                
                balance_results.append({
                    'covariate': covariate,
                    'type': 'continuous',
                    'variant_1': variant_list[0],
                    'variant_2': variant_list[1],
                    'mean_1': variant_stats[v1]['mean'],
                    'mean_2': variant_stats[v2]['mean'],
                    'std_1': variant_stats[v1]['std'],
                    'std_2': variant_stats[v2]['std'],
                    'smd': smd,
                    'imbalanced': is_imbalanced
                })
                
                if is_imbalanced:
                    imbalanced_covariates.append(covariate)
        
        balance_df = pd.DataFrame(balance_results)
        max_smd = balance_df['smd'].max() if len(balance_df) > 0 else 0
        n_imbalanced = len(imbalanced_covariates)
        
        if max_smd < 0.1:
            message = f"Excellent balance (max SMD={max_smd:.3f} < 0.1)"
        elif max_smd < threshold:
            message = f"Good balance (max SMD={max_smd:.3f} < {threshold})"
        else:
            message = f"{n_imbalanced} covariate(s) imbalanced (max SMD={max_smd:.3f} ≥ {threshold})"
        
        return {
            'test': 'covariate_balance',
            'variants_compared': list(variants)[:2],
            'balance_results': balance_df,
            'imbalanced_covariates': imbalanced_covariates,
            'n_imbalanced': n_imbalanced,
            'max_smd': max_smd,
            'threshold': threshold,
            'message': message
        }
    
    def temporal_stability_check(self,
                                df: pd.DataFrame,
                                variant_col: str,
                                date_col: str,
                                threshold: Optional[float] = None) -> Dict:
        """Temporal stability verification."""
        
        if threshold is None:
            threshold = self.temporal_threshold
        
        df = df.copy()
        if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
            df[date_col] = pd.to_datetime(df[date_col])
        
        df['date'] = df[date_col].dt.date
        daily_counts = df.groupby(['date', variant_col]).size().unstack(fill_value=0)
        
        cv_results = {}
        for variant in daily_counts.columns:
            counts = daily_counts[variant]
            mean_count = counts.mean()
            std_count = counts.std()
            cv = std_count / mean_count if mean_count > 0 else 0.0
            cv_results[variant] = cv
        
        max_cv = max(cv_results.values())
        is_stable = max_cv < threshold
        
        message = (
            f"Stable allocation over time (max CV={max_cv:.3f} < {threshold})" if is_stable
            else f"Unstable allocation (max CV={max_cv:.3f} ≥ {threshold})"
        )
        
        return {
            'test': 'temporal_stability',
            'cv_by_variant': cv_results,
            'max_cv': max_cv,
            'threshold': threshold,
            'is_stable': is_stable,
            'daily_counts': daily_counts,
            'n_days': len(daily_counts),
            'message': message
        }
    
    def multiple_testing_correction(self,
                                    pvalues: List[float],
                                    method: str = 'holm',
                                    alpha: float = 0.05) -> Dict:
        """
        Adjust a family of p-values for multiple comparisons.

        Supported ``method`` names with a friendly label:
        - 'bonferroni': Most conservative (alpha/k)
        - 'holm': Holm-Bonferroni (recommended for 5-10 tests)
        - 'fdr_bh': Benjamini-Hochberg FDR (for >10 tests)
        Any other name is handed to ``multipletests`` unchanged and reported
        under its raw name.

        References:
        - Bonferroni (1936)
        - Holm (1979)
        - Benjamini & Hochberg (1995)

        Returns:
            Dict with the method label, number of tests, alpha, original and
            corrected p-values, per-test reject flags, the number of
            significant results before/after correction and a summary message.
        """

        display_names = {
            'bonferroni': 'Bonferroni',
            'holm': 'Holm-Bonferroni',
            'fdr_bh': 'Benjamini-Hochberg FDR'
        }
        method_label = display_names.get(method, method)

        pvalues_array = np.array(pvalues)
        n_tests = len(pvalues_array)

        # Only the reject flags and adjusted p-values are used downstream
        reject, pvals_corrected, _, _ = multipletests(
            pvalues_array,
            alpha=alpha,
            method=method
        )

        n_sig_before = sum(pvalues_array < alpha)
        n_sig_after = sum(reject)

        summary_lines = [
            f"✓ Multiple testing correction applied: {method_label}",
            f"  Original significant: {n_sig_before}/{n_tests}",
            f"  Corrected significant: {n_sig_after}/{n_tests}",
        ]

        report = {
            'test': 'multiple_testing_correction',
            'method': method_label,
            'n_tests': n_tests,
            'alpha': alpha,
            'original_pvalues': pvalues_array.tolist(),
            'corrected_pvalues': pvals_corrected.tolist(),
            'reject': reject.tolist(),
            'n_significant_original': n_sig_before,
            'n_significant_corrected': n_sig_after,
            'message': "\n".join(summary_lines)
        }
        return report
    
    def run_all_validations(self,
                           df: pd.DataFrame,
                           variant_col: str,
                           covariates: Optional[List[str]] = None,
                           date_col: Optional[str] = None,
                           metric_pvalues: Optional[List[float]] = None,
                           correction_method: str = 'holm') -> Dict:
        """
        Run complete validation suite including multiple testing correction.
        """
        
        results = {}
        
        print("=" * 80)
        print("EXPERIMENTAL VALIDATION SUITE")
        print("=" * 80)
        
        # 1. SRM Test
        print("\n1. Sample Ratio Mismatch Test")
        print("-" * 80)
        srm_result = self.sample_ratio_mismatch_test(df, variant_col)
        results['srm'] = srm_result
        print(srm_result.get('message', srm_result.get('warning', '')))
        
        if srm_result['has_srm']:
            print("\n" + "=" * 80)
            print("VALIDATION FAILED: SRM detected.")
            print("=" * 80)
            return results
        
        # 2. Covariate Balance
        if covariates:
            print("\n2. Covariate Balance Check")
            print("-" * 80)
            balance_result = self.covariate_balance_check(df, variant_col, covariates)
            results['balance'] = balance_result
            print(balance_result.get('message', ''))
        
        # 3. Temporal Stability
        if date_col:
            print("\n3. Temporal Stability Check")
            print("-" * 80)
            temporal_result = self.temporal_stability_check(df, variant_col, date_col)
            results['temporal'] = temporal_result
            print(temporal_result.get('message', ''))
        
        # 4. Multiple Testing Correction
        if metric_pvalues:
            print("\n4. Multiple Testing Correction")
            print("-" * 80)
            correction_result = self.multiple_testing_correction(
                metric_pvalues,
                method=correction_method,
                alpha=0.05
            )
            results['multiple_testing'] = correction_result
            print(correction_result.get('message', ''))
        
        # Summary
        print("\n" + "=" * 80)
        all_clear = (
            not srm_result['has_srm'] and
            (not covariates or balance_result.get('max_smd', 0) < self.balance_threshold) and
            (not date_col or temporal_result.get('is_stable', True))
        )
        
        if all_clear:
            print("ALL VALIDATION CHECKS PASSED")
        else:
            print("VALIDATION WARNINGS DETECTED")
        print("=" * 80)
        
        return results


# ============================================================================
# COMPREHENSIVE VALIDATION FOR ALL 5 A/B TESTS
# ============================================================================

def validate_test(test_name, csv_file, validator):
    """
    Load one A/B test dataset and run the SRM, balance and temporal checks.

    The CSV is looked up in ``raw_dataset/``, then ``../raw_dataset/``, then
    ``/raw_dataset/``. The data is expected to have the columns ``variant``
    and ``timestamp``; the covariates ``device_type``, ``browser`` and
    ``region`` are checked for balance.

    Args:
        test_name: Human-readable test label used in the printed report.
        csv_file: File name of the dataset inside the ``raw_dataset`` folder.
        validator: Object providing ``sample_ratio_mismatch_test``,
            ``covariate_balance_check`` and ``temporal_stability_check``.

    Returns:
        A dict summarising the checks, or ``None`` when the file cannot be
        found or the data holds fewer than two variants.
    """

    print(f"\n{'='*80}")
    print(f"TEST: {test_name}")
    print('='*80)

    # Tried in order: the standard location, one directory up (in case the
    # script is run from a sub-folder), and an absolute path as a last resort.
    candidate_paths = [
        f'raw_dataset/{csv_file}',
        f'../raw_dataset/{csv_file}',
        f'/raw_dataset/{csv_file}',
    ]
    csv_path = next((path for path in candidate_paths if os.path.exists(path)), None)

    try:
        if csv_path is None:
            raise FileNotFoundError(f"Cannot find {csv_file}")
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"File not found: {csv_file}")
        print("   Please run data_generation.py first!")
        return None

    print(f"Loaded: {len(df):,} rows")

    # Variant split
    variant_counts = df['variant'].value_counts()
    print(f"Variants ({len(variant_counts)}):")
    for variant, count in variant_counts.items():
        pct = count / len(df) * 100
        print(f"   - {variant}: {count:,} ({pct:.1f}%)")

    # The balance check returns only an error dict for a single variant, so
    # stop here instead of failing on a missing 'max_smd' key below.
    if len(variant_counts) < 2:
        print("Need at least 2 variants to validate this test - skipping.")
        return None

    # Quick validation
    srm = validator.sample_ratio_mismatch_test(df, 'variant')
    balance = validator.covariate_balance_check(
        df, 'variant', ['device_type', 'browser', 'region']
    )
    temporal = validator.temporal_stability_check(df, 'variant', 'timestamp')

    max_smd = balance['max_smd']
    # Use the threshold the validator actually applied, not a hard-coded 0.2
    balance_threshold = balance['threshold']
    balance_ok = max_smd < balance_threshold

    # Status
    srm_status = "PASS" if not srm['has_srm'] else "FAIL"
    if not balance_ok:
        balance_status = "FAIL"
    elif max_smd < 0.1:
        balance_status = "OK"
    else:
        balance_status = "WARNING"
    temporal_status = "OK" if temporal['is_stable'] else "WARNING"

    print(f"\nValidation Results:")
    print(f"  SRM Test:        {srm_status} (p={srm['pvalue']:.4f})")
    print(f"  Balance:         {balance_status} (SMD={max_smd:.3f})")
    print(f"  Temporal:        {temporal_status} (CV={temporal['max_cv']:.3f})")

    return {
        'test': test_name,
        'n': len(df),
        'n_variants': len(variant_counts),
        'srm_pvalue': srm['pvalue'],
        'srm_passed': not srm['has_srm'],
        'balance_smd': max_smd,
        'balance_ok': balance_ok,
        'temporal_cv': temporal['max_cv'],
        'temporal_stable': temporal['is_stable'],
        'overall_valid': not srm['has_srm'] and balance_ok
    }


def validate_all_tests():
    """
    Run comprehensive validation on all 5 A/B tests and print a report.

    Each dataset is validated with ``validate_test``; datasets that cannot be
    loaded are left out. The report contains a summary table, overall
    statistics and per-test recommendations. Nothing is returned.
    """

    print("="*80)
    print("COMPREHENSIVE VALIDATION SUITE")
    print("Validating All 5 A/B Tests")
    print("="*80)

    validator = ExperimentValidator(
        srm_threshold=0.001,
        balance_threshold=0.2,
        temporal_threshold=0.2
    )

    tests = [
        ('Test 1: Menu Design', 'test1_menu.csv'),
        ('Test 2: Novelty Slider', 'test2_novelty_slider.csv'),
        ('Test 3: Product Sliders', 'test3_product_sliders.csv'),
        ('Test 4: Customer Reviews', 'test4_reviews.csv'),
        ('Test 5: Search Engine', 'test5_search_engine.csv')
    ]

    results = []
    for test_name, csv_file in tests:
        result = validate_test(test_name, csv_file, validator)
        if result:
            results.append(result)

    # Summary table
    print(f"\n\n{'='*80}")
    print("SUMMARY TABLE")
    print('='*80)

    if results:
        summary_df = pd.DataFrame(results)

        print(f"\n{'Test':<30} {'N':>8} {'SRM':>8} {'Balance':>10} {'Temporal':>10} {'Valid':>8}")
        print('-'*80)

        for _, row in summary_df.iterrows():
            test = row['test'][:28]
            n = f"{int(row['n']):,}"
            srm = "PASS" if row['srm_passed'] else "FAIL"
            balance = "Good" if row['balance_ok'] else "Warning"
            temporal = "Stable" if row['temporal_stable'] else "Unstable"
            valid = "YES" if row['overall_valid'] else "CHECK"

            print(f"{test:<30} {n:>8} {srm:>8} {balance:>10} {temporal:>10} {valid:>8}")

        # Overall stats
        print('\n' + '='*80)
        print("OVERALL STATISTICS")
        print('='*80)

        n_results = len(results)
        n_total = summary_df['n'].sum()
        n_passed_srm = summary_df['srm_passed'].sum()
        n_valid = summary_df['overall_valid'].sum()
        n_stable = summary_df['temporal_stable'].sum()

        print(f"\nTotal samples across all tests: {n_total:,}")
        print(f"Tests passed SRM check: {n_passed_srm}/{n_results}")
        print(f"Tests fully valid: {n_valid}/{n_results}")

        # 'overall_valid' covers SRM and balance only, so temporal stability
        # is checked separately before declaring a clean bill of health.
        if n_passed_srm == n_results and n_valid == n_results and n_stable == n_results:
            print("\n ALL TESTS ARE VALID")
            print("\nAll experiments passed validation checks!")
            print("You can proceed with statistical analysis with full confidence.")
        elif n_passed_srm < n_results:
            print("\n CRITICAL ISSUES DETECTED")
            print("\nSome tests failed SRM check - DO NOT analyze those tests!")
        else:
            print("\nMINOR WARNINGS DETECTED")
            print("\nTests passed critical checks but have minor balance/temporal issues.")
            print("Proceed with caution and consider causal adjustment methods.")

        # Detailed recommendations
        print("\n" + "="*80)
        print("RECOMMENDATIONS BY TEST")
        print("="*80)

        for _, row in summary_df.iterrows():
            print(f"\n{row['test']}:")
            if row['overall_valid'] and row['temporal_stable']:
                print("All checks passed - proceed with analysis")
                continue

            if not row['srm_passed']:
                print("    SRM FAILED - DO NOT ANALYZE")
                print("     → Investigate randomization bug")
                print("     → Restart experiment after fix")
            elif not row['balance_ok']:
                print("    Balance issue detected")
                print("     → Use regression adjustment or CUPED")
                print("     → Check for selection bias")
            # Reported independently of the SRM/balance outcome
            if not row['temporal_stable']:
                print("    Temporal instability detected")
                print("     → Check for system changes during test")
                print("     → Consider excluding unstable periods")
    else:
        print("\nNo test data could be loaded - nothing to summarise.")

    print("\n" + "="*80)
    print("VALIDATION COMPLETE")
    print("="*80 + "\n")



# MAIN EXECUTION
# Run the full validation report only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    validate_all_tests()


COMPREHENSIVE VALIDATION SUITE
Validating All 5 A/B Tests

TEST: Test 1: Menu Design
Loaded: 7,000 rows
Variants (2):
   - A_horizontal_menu: 3,500 (50.0%)
   - B_dropdown_menu: 3,500 (50.0%)

Validation Results:
  SRM Test:        PASS (p=1.0000)
  Balance:         OK (SMD=0.026)
  Temporal:        OK (CV=0.057)

TEST: Test 2: Novelty Slider
Loaded: 16,000 rows
Variants (2):
   - A_manual_novelties: 8,000 (50.0%)
   - B_personalized_novelties: 8,000 (50.0%)

Validation Results:
  SRM Test:        PASS (p=1.0000)
  Balance:         OK (SMD=0.028)
  Temporal:        OK (CV=0.038)

TEST: Test 3: Product Sliders
Loaded: 18,000 rows
Variants (3):
   - A_selected_by_others_only: 6,000 (33.3%)
   - B_similar_products_top: 6,000 (33.3%)
   - C_selected_by_others_top: 6,000 (33.3%)

Validation Results:
  SRM Test:        PASS (p=1.0000)
  Balance:         OK (SMD=0.039)
  Temporal:        OK (CV=0.049)

TEST: Test 4: Customer Reviews
Loaded: 42,000 rows
Variants (2):
   - A_no_featured_reviews

## Define Step

In [3]:
import numpy as np
import pandas as pd
from scipy import stats
from scipy.stats import chi2_contingency
from statsmodels.stats.multitest import multipletests
from typing import Dict, List, Tuple, Optional
import warnings

# Try to import validation module
try:
    from validation import ExperimentValidator
    VALIDATION_AVAILABLE = True
except ImportError:
    VALIDATION_AVAILABLE = False
    warnings.warn("Validation module not available. Skipping validation checks.")


class ABTestAnalyzer:
    
    def __init__(self, alpha: float = 0.05):
        """
        Create an analyzer with significance level ``alpha``.

        A validator is attached only when the validation module could be
        imported; otherwise ``self.validator`` is ``None``.
        """
        self.alpha = alpha
        # Stricter for SRM
        self.validator = (
            ExperimentValidator(srm_threshold=0.001)
            if VALIDATION_AVAILABLE
            else None
        )
    
    def calculate_sample_size(self,
                            baseline_rate: float,
                            mde: float,
                            alpha: float = 0.05,
                            power: float = 0.80,
                            two_tailed: bool = True) -> int:
        """
        Required sample size per group for comparing two proportions.

        Uses the normal-approximation formula
        ``n = (z_alpha + z_beta)^2 * (p1(1-p1) + p2(1-p2)) / (p2 - p1)^2``
        with ``p1 = baseline_rate`` and ``p2 = baseline_rate * (1 + mde)``
        (capped at 0.999).

        Args:
            baseline_rate: Control conversion rate, strictly between 0 and 1.
            mde: Minimum detectable effect as a relative change (0.2 = +20%).
            alpha: Significance level.
            power: Desired statistical power.
            two_tailed: Use a two-sided critical value when True.

        Returns:
            Sample size per group, rounded up.

        Raises:
            ValueError: If an argument is out of range or the implied
                treatment rate equals the baseline (no effect to detect).
        """

        if not 0 < baseline_rate < 1:
            raise ValueError("baseline_rate must be strictly between 0 and 1")
        if not 0 < alpha < 1:
            raise ValueError("alpha must be strictly between 0 and 1")
        if not 0 < power < 1:
            raise ValueError("power must be strictly between 0 and 1")

        if two_tailed:
            z_alpha = stats.norm.ppf(1 - alpha/2)
        else:
            z_alpha = stats.norm.ppf(1 - alpha)

        z_beta = stats.norm.ppf(power)

        p1 = baseline_rate
        # Keep the treatment rate a valid probability
        p2 = min(baseline_rate * (1 + mde), 0.999)

        if p2 <= 0:
            raise ValueError("mde implies a non-positive treatment rate")
        if p2 == p1:
            # Would otherwise divide by zero below
            raise ValueError("mde implies no difference between the two rates")

        numerator = (z_alpha + z_beta) ** 2 * (p1 * (1 - p1) + p2 * (1 - p2))
        denominator = (p2 - p1) ** 2

        n = numerator / denominator

        return int(np.ceil(n))
    
    def two_sample_ttest(self,
                        control: np.ndarray,
                        treatment: np.ndarray,
                        metric_name: str,
                        equal_var: bool = False) -> Dict:
        
        control = control[~np.isnan(control)]
        treatment = treatment[~np.isnan(treatment)]
        
        control_mean = control.mean()
        treatment_mean = treatment.mean()
        control_std = control.std(ddof=1)
        treatment_std = treatment.std(ddof=1)
        n_control = len(control)
        n_treatment = len(treatment)
        
        statistic, pvalue = stats.ttest_ind(treatment, control, equal_var=equal_var)
        
        pooled_std = np.sqrt((control_std**2 + treatment_std**2) / 2)
        cohens_d = (treatment_mean - control_mean) / pooled_std if pooled_std > 0 else 0
        
        se_diff = np.sqrt(control_std**2/n_control + treatment_std**2/n_treatment)
        
        if not equal_var:
            num = (control_std**2/n_control + treatment_std**2/n_treatment)**2
            denom = ((control_std**2/n_control)**2/(n_control-1) + 
                    (treatment_std**2/n_treatment)**2/(n_treatment-1))
            df = num / denom if denom > 0 else n_control + n_treatment - 2
        else:
            df = n_control + n_treatment - 2
        
        t_crit = stats.t.ppf(1 - self.alpha/2, df)
        diff = treatment_mean - control_mean
        ci_lower = diff - t_crit * se_diff
        ci_upper = diff + t_crit * se_diff
        
        relative_lift_pct = (diff / control_mean * 100) if control_mean != 0 else 0
        
        return {
            'metric': metric_name,
            'test_type': 't-test',
            'statistic': statistic,
            'pvalue': pvalue,
            'significant': pvalue < self.alpha,
            'control_mean': control_mean,
            'treatment_mean': treatment_mean,
            'control_std': control_std,
            'treatment_std': treatment_std,
            'absolute_diff': diff,
            'relative_lift_pct': relative_lift_pct,
            'cohens_d': cohens_d,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'n_control': n_control,
            'n_treatment': n_treatment,
            'degrees_of_freedom': df
        }
    
    def proportion_test(self,
                        control_successes: int,
                        control_total: int,
                        treatment_successes: int,
                        treatment_total: int,
                        metric_name: str) -> Dict:
        
        p_control = control_successes / control_total
        p_treatment = treatment_successes / treatment_total
        
        p_pooled = (control_successes + treatment_successes) / (control_total + treatment_total)
        
        se = np.sqrt(p_pooled * (1 - p_pooled) * (1/control_total + 1/treatment_total))
        
        z_stat = (p_treatment - p_control) / se if se > 0 else 0
        
        pvalue = 2 * (1 - stats.norm.cdf(abs(z_stat)))
        
        se_diff = np.sqrt(p_control*(1-p_control)/control_total + 
                        p_treatment*(1-p_treatment)/treatment_total)
        z_crit = stats.norm.ppf(1 - self.alpha/2)
        diff = p_treatment - p_control
        ci_lower = diff - z_crit * se_diff
        ci_upper = diff + z_crit * se_diff
        
        relative_lift_pct = (diff / p_control * 100) if p_control > 0 else 0
        
        return {
            'metric': metric_name,
            'test_type': 'proportion_test',
            'statistic': z_stat,
            'pvalue': pvalue,
            'significant': pvalue < self.alpha,
            'control_rate': p_control,
            'treatment_rate': p_treatment,
            'absolute_diff': diff,
            'relative_lift_pct': relative_lift_pct,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'n_control': control_total,
            'n_treatment': treatment_total
        }

    def chi_square_test(self,
                        control: np.ndarray,
                        treatment: np.ndarray,
                        metric_name: str) -> Dict:
        
        combined = np.concatenate([control, treatment])
        labels = np.concatenate([np.zeros(len(control)), np.ones(len(treatment))])
        
        contingency_table = pd.crosstab(combined, labels)
        
        chi2, pvalue, dof, expected = chi2_contingency(contingency_table)

        n = len(combined)
        min_dim = min(contingency_table.shape[0], contingency_table.shape[1]) - 1
        cramers_v = np.sqrt(chi2 / (n * min_dim)) if min_dim > 0 else 0
        
        return {
            'metric': metric_name,
            'test_type': 'chi_square',
            'statistic': chi2,
            'pvalue': pvalue,
            'significant': pvalue < self.alpha,
            'degrees_of_freedom': dof,
            'cramers_v': cramers_v,
            'n_control': len(control),
            'n_treatment': len(treatment)
        }
    
    def mann_whitney_u_test(self,
                            control: np.ndarray,
                            treatment: np.ndarray,
                            metric_name: str) -> Dict:

        control = control[~np.isnan(control)]
        treatment = treatment[~np.isnan(treatment)]
        
        statistic, pvalue = stats.mannwhitneyu(treatment, control, alternative='two-sided')
        
        n1 = len(control)
        n2 = len(treatment)
        rank_biserial = 1 - (2*statistic) / (n1 * n2)
        
        control_median = np.median(control)
        treatment_median = np.median(treatment)
        
        return {
            'metric': metric_name,
            'test_type': 'mann_whitney',
            'statistic': statistic,
            'pvalue': pvalue,
            'significant': pvalue < self.alpha,
            'control_median': control_median,
            'treatment_median': treatment_median,
            'rank_biserial': rank_biserial,
            'n_control': n1,
            'n_treatment': n2
        }
    
    def bootstrap_confidence_interval(self,
                                    control: np.ndarray,
                                    treatment: np.ndarray,
                                    metric_name: str,
                                    n_bootstrap: int = 10000,
                                    confidence_level: float = 0.95,
                                    random_state: Optional[int] = 42) -> Dict:
        """
        Percentile bootstrap confidence interval for the difference in means.

        Each group is resampled with replacement ``n_bootstrap`` times and
        the difference ``mean(treatment) - mean(control)`` is recorded; the
        interval is taken from the percentiles of those differences.

        Args:
            control: Control observations (any array-like); NaNs are dropped.
            treatment: Treatment observations (any array-like); NaNs are dropped.
            metric_name: Label copied into the result.
            n_bootstrap: Number of bootstrap resamples.
            confidence_level: Coverage of the interval (0.95 = 95%).
            random_state: Seed for the resampling; ``None`` draws fresh
                entropy. The default keeps results reproducible.

        Returns:
            Dict with the observed difference, the interval bounds, a
            ``significant`` flag (True when the interval excludes 0), the
            confidence level and the number of resamples.
        """

        # A private generator keeps the run reproducible without resetting
        # NumPy's global random state for the rest of the program.
        rng = np.random.default_rng(random_state)

        control = np.asarray(control, dtype=float)
        treatment = np.asarray(treatment, dtype=float)
        control = control[~np.isnan(control)]
        treatment = treatment[~np.isnan(treatment)]

        n_control = len(control)
        n_treatment = len(treatment)

        # One resample at a time keeps memory flat for large groups
        boot_diffs = np.empty(n_bootstrap)
        for i in range(n_bootstrap):
            control_boot = rng.choice(control, size=n_control, replace=True)
            treatment_boot = rng.choice(treatment, size=n_treatment, replace=True)
            boot_diffs[i] = treatment_boot.mean() - control_boot.mean()

        alpha_bootstrap = 1 - confidence_level
        ci_lower = np.percentile(boot_diffs, alpha_bootstrap/2 * 100)
        ci_upper = np.percentile(boot_diffs, (1 - alpha_bootstrap/2) * 100)

        observed_diff = treatment.mean() - control.mean()

        significant = not (ci_lower <= 0 <= ci_upper)

        return {
            'metric': metric_name,
            'test_type': 'bootstrap',
            'observed_diff': observed_diff,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'significant': significant,
            'confidence_level': confidence_level,
            'n_bootstrap': n_bootstrap
        }
    
    def multiple_testing_correction(self,
                                    p_values: List[float],
                                    method: str = 'holm') -> Dict:
        """
        Correct a family of p-values at ``self.alpha`` with ``multipletests``.

        Returns:
            Dict with the method name, the original and corrected p-values,
            the per-test reject flags, the family-wise error rate that would
            apply without correction (``1 - (1 - alpha)^k``), the number of
            tests and the number of significant results before/after
            correction.
        """
        n_tests = len(p_values)

        # Only the reject flags and adjusted p-values are used
        reject, adjusted, _, _ = multipletests(
            p_values,
            alpha=self.alpha,
            method=method
        )

        summary = {'method': method, 'original_pvalues': p_values}
        summary['corrected_pvalues'] = adjusted.tolist()
        summary['reject'] = reject.tolist()
        summary['fwer_uncorrected'] = 1 - (1 - self.alpha) ** n_tests
        summary['num_tests'] = n_tests
        summary['num_significant_uncorrected'] = sum(p < self.alpha for p in p_values)
        summary['num_significant_corrected'] = sum(reject)
        return summary



In [4]:
import pandas as pd
from scipy import stats

def _is_binary_metric(values):
    """Return True when every non-missing value in ``values`` is 0 or 1."""
    return set(values.dropna().unique()).issubset({0, 1, 0.0, 1.0})


def drive_ab_analysis(df, analyzer):
    """
    Run the per-metric significance tests for one experiment and print a report.

    Every column to the right of ``'variant'`` is treated as a metric. For
    more than two variants a chi-square test (binary metrics) or a
    Kruskal-Wallis test (other metrics) is used. For exactly two variants the
    first variant in sorted order is treated as control, and the analyzer's
    proportion test or Mann-Whitney U test is run together with a bootstrap
    confidence interval. All raw p-values are then Holm-corrected as one
    family and printed as a summary table followed by the raw result dicts.

    Args:
        df: DataFrame containing a ``'variant'`` column followed by the
            metric columns.
        analyzer: Object providing ``proportion_test``,
            ``mann_whitney_u_test``, ``bootstrap_confidence_interval`` and
            ``multiple_testing_correction``. The two test methods must return
            a dict with a ``'pvalue'`` key; the correction must return
            ``'corrected_pvalues'`` and ``'reject'``.

    Returns:
        None. Results are printed; setup problems are reported and end the run.
    """
    print("DRIVER STARTING...\n")

    # 1. Setup data
    try:
        col_list = df.columns.tolist()
        var_idx = col_list.index('variant')
        metric_cols = col_list[var_idx+1:]

        variants = sorted(df['variant'].unique())
        n_variants = len(variants)

        print(f"Control/Variants: {variants}")
        print(f"Metrics to analyze: {metric_cols}\n")

    except Exception as e:
        print(f"Error Setup: {e}")
        return

    # A comparison needs at least two groups; without this guard the
    # two-variant path below would fail on variants[1].
    if n_variants < 2:
        print(f"Error Setup: need at least 2 variants, found {n_variants}")
        return

    # Temporary holders, filled per metric and consumed after the correction
    temp_results = []
    raw_pvalues = []

    # 2. PHASE 1: compute the raw statistics (no table printed yet)
    for metric in metric_cols:
        is_binary = _is_binary_metric(df[metric])

        # --- PATH A: more than 2 variants ---
        if n_variants > 2:
            if is_binary:
                test_name = 'Chi-Square (K>2)'
                contingency = pd.crosstab(df['variant'], df[metric])
                chi2, p_val, dof, _ = stats.chi2_contingency(contingency)

                result = {
                    'metric': metric,
                    'test_type': 'chi_square_k_variants',
                    'statistic': chi2,
                    'pvalue': p_val,
                    'dof': dof
                }
            else:
                test_name = 'Kruskal-Wallis'
                groups = [df[df['variant'] == v][metric].dropna() for v in variants]
                stat, p_val = stats.kruskal(*groups)

                result = {
                    'metric': metric,
                    'test_type': 'kruskal_wallis',
                    'statistic': stat,
                    'pvalue': p_val
                }

            # Store the intermediate result
            temp_results.append({
                'metric': metric,
                'test_name': test_name,
                'p_raw': p_val,
                'main_result': result,
                'boot_result': None
            })
            raw_pvalues.append(p_val)

        # --- PATH B: exactly 2 variants (A/B) ---
        else:
            control_name, treat_name = variants[0], variants[1]
            # Drop missing observations up front: Series.sum() skips NaN but
            # len() counts it, which would otherwise inflate the denominators
            # passed to the proportion test.
            data_c = df.loc[df['variant'] == control_name, metric].dropna()
            data_t = df.loc[df['variant'] == treat_name, metric].dropna()

            if is_binary:
                test_name = 'Proportion Test'
                result = analyzer.proportion_test(
                    control_successes=data_c.sum(),
                    control_total=len(data_c),
                    treatment_successes=data_t.sum(),
                    treatment_total=len(data_t),
                    metric_name=metric
                )
            else:
                test_name = 'Mann-Whitney U'
                result = analyzer.mann_whitney_u_test(
                    control=data_c,
                    treatment=data_t,
                    metric_name=metric
                )

            # Bootstrap
            boot_res = analyzer.bootstrap_confidence_interval(data_c, data_t, metric)

            # Store the intermediate result
            temp_results.append({
                'metric': metric,
                'test_name': test_name,
                'p_raw': result['pvalue'],
                'main_result': result,
                'boot_result': boot_res
            })
            raw_pvalues.append(result['pvalue'])

    # 3. PHASE 2: p-value correction (multiple testing correction)
    # All metrics of this dataset are corrected together as one family.
    correction_res = analyzer.multiple_testing_correction(raw_pvalues, method='holm')
    corrected_pvals = correction_res['corrected_pvalues']
    reject_decisions = correction_res['reject']  # True = significant, False = not

    # 4. PHASE 3: print the summary table
    print("="*100)
    print(f"{'METRIC':<25} | {'TEST TYPE':<18} | {'RAW P-VAL':<10} | {'CORR P-VAL':<10} | {'SIG (CORR)?'}")
    print("-" * 100)

    for i, item in enumerate(temp_results):
        metric = item['metric']
        test_type = item['test_name']
        p_raw = item['p_raw']

        # Look up the corrected values for this metric
        p_corr = corrected_pvals[i]
        is_sig = "YES" if reject_decisions[i] else "NO"

        print(f"{metric:<25} | {test_type:<18} | {p_raw:.5f}    | {p_corr:.5f}     | {is_sig}")

    print("="*100 + "\n")

    # 5. PHASE 4: print the detailed return values (raw dictionaries)
    print("DETAILED RETURN VALUES (RAW DICTIONARY)")
    print("="*80)

    for item in temp_results:
        print(f"METRIC: {item['metric']}")
        print("-" * 40)

        print("Function Return (Main Test):")
        for key, value in item['main_result'].items():
            print(f"   {key:<20} : {value}")

        if item['boot_result'] is not None:
            print("\nFunction Return (Bootstrap):")
            for key, value in item['boot_result'].items():
                print(f"   {key:<20} : {value}")
        else:
            print("\nFunction Return (Bootstrap):")
            print("   Skipped (Not applicable for >2 variants)")

        print("\n" + "-"*80 + "\n")

## Execution

In [None]:
import pandas as pd
import os

analyzer = ABTestAnalyzer()

# Datasets to analyze, one CSV per A/B test.
files_to_test = [
    'raw_dataset/test1_menu.csv',
    'raw_dataset/test2_novelty_slider.csv',
    'raw_dataset/test3_product_sliders.csv',
    'raw_dataset/test4_reviews.csv',
    'raw_dataset/test5_search_engine.csv'
]

print("BATCH ANALYSIS STARTED...\n")

# Run the driver on each dataset; files missing from disk are reported and skipped.
for filepath in files_to_test:
    # Three-line banner identifying the file being processed.
    print("#" * 100, f"PROCESSING FILE: {filepath}", "#" * 100, sep="\n")

    if not os.path.exists(filepath):
        print(f"File not found: {filepath}. Skipping...")
    else:
        # Load the dataset and hand it to the analysis driver.
        df_loop = pd.read_csv(filepath)
        drive_ab_analysis(df_loop, analyzer)

    print("\n" + " "*40 + "--- END OF FILE ---\n\n")

print("ALL ANALYSES COMPLETED.")

BATCH ANALYSIS STARTED...

####################################################################################################
PROCESSING FILE: raw_dataset/test1_menu.csv
####################################################################################################
DRIVER STARTING...

Control/Variants: ['A_horizontal_menu', 'B_dropdown_menu']
Metrics to analyze: ['pages_viewed', 'added_to_cart', 'bounced', 'revenue']

METRIC                    | TEST TYPE          | RAW P-VAL  | CORR P-VAL | SIG (CORR)?
----------------------------------------------------------------------------------------------------
pages_viewed              | Mann-Whitney U     | 0.06748    | 0.13497     | NO
added_to_cart             | Proportion Test    | 0.00000    | 0.00000     | YES
bounced                   | Proportion Test    | 0.33544    | 0.33544     | NO
revenue                   | Mann-Whitney U     | 0.00000    | 0.00000     | YES

DETAILED RETURN VALUES (RAW DICTIONARY)
METRIC: pages_viewed
---