# Requirements

In [None]:
!pip3 install gradio



# Data Manager

## Config Loader

In [2]:
from yaml import safe_load
from typing import Union, Optional, List
import pandas as pd
from pathlib import Path
from json import load as json_load

class ConfigLoader:
    """
    Handles loading configuration files for data analysis.
    Supports YAML and JSON formats.

    Reads configuration data from a specified file path and exposes the parsed
    mapping through ``get_configs`` (or the ``config`` attribute).

    The loader's state (``config``, ``file_path``, ``file_type``) is replaced only
    after a file has been parsed and validated, so a failed ``load`` leaves the
    previously loaded configuration untouched.
    """

    def __init__(self):
        # Parsed content of the last successfully loaded file (None until then).
        self.config: Optional[dict] = None
        # Path of the last successfully loaded file.
        self.file_path: Optional[Union[str, Path]] = None
        # 'yaml' or 'json', matching the last successfully loaded file.
        self.file_type: Optional[str] = None

    def load(self, file_path: Union[str, Path]) -> None:
        """
        Load configuration from a file.

        Parameters
        ----------
        file_path : str or Path
            Path to the configuration file (YAML or JSON).

        Raises
        ------
        ValueError
            If the file extension is unsupported, or if the parsed content is
            not a non-empty dictionary.

        Notes
        -----
        Errors raised by ``open`` or by the YAML/JSON parser are not caught here
        and propagate unchanged.
        """
        candidate_path = Path(file_path)
        ext = candidate_path.suffix.lower()

        if ext in ('.yaml', '.yml'):
            file_type = 'yaml'
            parser = safe_load
        elif ext == '.json':
            file_type = 'json'
            parser = json_load
        else:
            raise ValueError(f"Unsupported file format: {ext}. Supported formats are YAML and JSON.")

        # Explicit UTF-8 so parsing does not depend on the platform's default encoding.
        with open(candidate_path, 'r', encoding='utf-8') as f:
            loaded = parser(f)

        # Validate BEFORE touching instance state: an invalid file must not
        # become the "current" configuration returned by get_configs().
        if not isinstance(loaded, dict):
            raise ValueError("Configuration file must contain a valid dictionary.")
        if not loaded:
            raise ValueError("Configuration file is empty.")

        self.file_path = candidate_path
        self.file_type = file_type
        self.config = loaded

    def get_configs(self, file_path: Optional[Union[str, Path]] = None) -> dict:
        """
        Get the loaded configuration data.

        Parameters
        ----------
        file_path : str or Path, optional
            If provided, will load the configuration from this file instead of the previously loaded one.

        Returns
        -------
        dict
            The loaded configuration data.

        Raises
        ------
        ValueError
            If no configuration has been loaded, or if loading ``file_path`` fails
            validation (see ``load``).
        """
        if file_path:
            self.load(file_path)

        if self.config is None:
            raise ValueError("No configuration has been loaded.")

        return self.config

## Data Loader

In [3]:
import pandas as pd
import os
import logging
from typing import List, Tuple, Union, Optional
from pathlib import Path

# Configure the root logger: INFO level, records formatted as
# "<timestamp> [<LEVEL>] <logger name>: <message>".
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers, so this call may have no effect if the notebook environment has
# configured logging beforehand — confirm if the format is not applied.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

class DataLoader:
    """
    Loads and manages reference and comparison datasets for analysis.

    After a successful ``load`` the instance holds both DataFrames plus the list
    of column names they share, ordered as they appear in the reference dataset.
    """

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.df_ref: Optional[pd.DataFrame] = None
        self.df_cmp: Optional[pd.DataFrame] = None
        self.common_columns: List[str] = []
        self.logger.info("Initialized DataLoader instance.")

    def load(
        self,
        ref_path: Union[str, Path],
        cmp_path: Union[str, Path],
        ref_sheet_name: Optional[str] = None,
        cmp_sheet_name: Optional[str] = None
    ) -> None:
        """
        Load datasets from specified file paths.

        Parameters
        ----------
        ref_path : str or Path
            Path to the reference dataset file (CSV, Parquet or Excel).
        cmp_path : str or Path
            Path to the comparison dataset file (CSV, Parquet or Excel).
        ref_sheet_name : str, optional
            Name of the sheet to read (for Excel files).
        cmp_sheet_name : str, optional
            Name of the sheet to read (for Excel files).

        Raises
        ------
        ValueError
            If a file cannot be read (including unsupported extensions), if a
            read result is not a DataFrame, or if there are no common columns.
        """
        self.logger.info(f"Starting load: ref_path={ref_path}, cmp_path={cmp_path}")

        # Clear earlier results first so a failure part-way through can never
        # leave a mix of old and new datasets (or an invalid object) behind.
        self.df_ref = None
        self.df_cmp = None
        self.common_columns = []

        self.df_ref = self._read_validated(ref_path, ref_sheet_name, "reference")
        self.df_cmp = self._read_validated(cmp_path, cmp_sheet_name, "comparison")

        # Determine common columns. Walk the reference columns in order (instead
        # of intersecting two sets) so the result is deterministic across runs;
        # dict.fromkeys drops duplicate labels while keeping first-seen order.
        cmp_columns = set(self.df_cmp.columns)
        self.common_columns = list(
            dict.fromkeys(col for col in self.df_ref.columns if col in cmp_columns)
        )
        if not self.common_columns:
            self.logger.error("No common columns found between reference and comparison datasets.")
            raise ValueError("No common columns between reference and comparison datasets.")

        self.logger.info(f"Common columns ({len(self.common_columns)}): {self.common_columns}")

    def _read_validated(
        self,
        path: Union[str, Path],
        sheet_name: Optional[str],
        label: str
    ) -> pd.DataFrame:
        """
        Read one dataset and make sure the result is a DataFrame.

        Parameters
        ----------
        path : str or Path
            Path to the dataset file.
        sheet_name : str, optional
            Name of the sheet to read (for Excel files).
        label : str
            Lower-case role of the dataset ('reference' or 'comparison'), used
            in log and error messages.

        Returns
        -------
        pd.DataFrame
            The loaded dataset.

        Raises
        ------
        ValueError
            If reading failed or the result is not a DataFrame.
        """
        self.logger.info(f"Reading {label} dataset from: {path}")
        df = self._read_file(path, sheet_name=sheet_name)
        if df is None:
            message = f"Failed to load {label} dataset from {path}"
            self.logger.error(message)
            raise ValueError(message)
        if not isinstance(df, pd.DataFrame):
            message = f"{label.capitalize()} dataset at {path} is not a valid DataFrame."
            self.logger.error(message)
            raise ValueError(message)
        self.logger.info(
            f"Successfully loaded {label} dataset ({len(df)} rows, {len(df.columns)} columns)."
        )
        return df

    def _read_file(
        self,
        path: Union[str, Path],
        sheet_name: Optional[str] = None
    ) -> Optional[pd.DataFrame]:
        """
        Read a DataFrame from a file path, inferring format from the extension.

        Parameters
        ----------
        path : str or Path
            Path to the dataset file.
        sheet_name : str, optional
            Name of the sheet to read (for Excel files).

        Returns
        -------
        pd.DataFrame or None
            Loaded DataFrame, or None if the extension is unsupported or reading failed.
        """
        ext = Path(path).suffix.lower()
        self.logger.info(f"Attempting to read file '{path}' with extension '{ext}'")

        is_csv = ext == '.csv'
        is_parquet = ext in ('.parquet', '.parq', '.pq')
        is_excel = ext in ('.xlsx', '.xls', '.xlsm')
        if not (is_csv or is_parquet or is_excel):
            # Report the unsupported format directly instead of raising an
            # exception only to catch it again a few lines later.
            self.logger.error(f"Unsupported file format: {ext}")
            return None

        # Best-effort read: any reader failure is logged and reported as None;
        # callers turn that into a ValueError.
        try:
            if is_csv:
                df = pd.read_csv(path)
                self.logger.info(f"Read CSV file: {path}")
            elif is_parquet:
                df = pd.read_parquet(path)
                self.logger.info(f"Read Parquet file: {path}")
            elif sheet_name is None:
                df = pd.read_excel(path)
                self.logger.info(f"Read Excel file (default sheet): {path}")
            else:
                df = pd.read_excel(path, sheet_name=sheet_name)
                self.logger.info(f"Read Excel file (sheet='{sheet_name}'): {path}")
            return df
        except Exception as e:
            self.logger.error(f"Error reading file '{path}': {e}")
            return None

    def get_data(
        self,
        ref_path: Optional[Union[str, Path]] = None,
        cmp_path: Optional[Union[str, Path]] = None,
        ref_sheet_name: Optional[str] = None,
        cmp_sheet_name: Optional[str] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame, List[str]]:
        """
        Return the loaded reference and comparison DataFrames (and common columns).

        Parameters
        ----------
        ref_path : str or Path, optional
            Path to the reference dataset file (if not already loaded).
        cmp_path : str or Path, optional
            Path to the comparison dataset file (if not already loaded).
        ref_sheet_name : str, optional
            Sheet to read from the reference file (for Excel files).
        cmp_sheet_name : str, optional
            Sheet to read from the comparison file (for Excel files).

        Returns
        -------
        tuple(pd.DataFrame, pd.DataFrame, List[str])
            (reference DataFrame, comparison DataFrame, list of common columns)

        Raises
        ------
        ValueError
            If loading fails, or if no datasets / common columns are available.

        Notes
        -----
        A (re)load only happens when BOTH paths are given. If just one is
        supplied it is ignored and a warning is logged.
        """
        if ref_path and cmp_path:
            self.logger.info("Paths provided to get_data; invoking load()")
            self.load(
                ref_path=ref_path,
                cmp_path=cmp_path,
                ref_sheet_name=ref_sheet_name,
                cmp_sheet_name=cmp_sheet_name
            )
        elif ref_path or cmp_path:
            # Make the otherwise silent "use previously loaded data" fallback visible.
            self.logger.warning(
                "Only one of ref_path/cmp_path was provided to get_data; "
                "ignoring it and using previously loaded data."
            )
        if self.df_ref is None:
            self.logger.error("Reference dataset is not loaded.")
            raise ValueError("Reference dataset is not loaded.")
        if self.df_cmp is None:
            self.logger.error("Comparison dataset is not loaded.")
            raise ValueError("Comparison dataset is not loaded.")
        if not self.common_columns:
            self.logger.error("No common columns identified; cannot proceed.")
            raise ValueError("No common columns between reference and comparison datasets.")

        self.logger.info("Returning loaded data and common columns.")
        return self.df_ref, self.df_cmp, self.common_columns


## DataFrame Schema Enforcer

In [None]:
import pandas as pd
import numpy as np
import re
import datetime
import logging
from decimal import Decimal
from dateutil.parser import parse as parse_date

# Configure the root logger: INFO level, records formatted as
# "<timestamp> [<LEVEL>] <logger name>: <message>".
# NOTE(review): logging.basicConfig does nothing once the root logger has
# handlers, so this call presumably only matters when this cell runs before any
# other logging setup — confirm against the intended execution order.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

class DataFrameSchemaEnforcer:
    """
    Enforces column types on a DataFrame based on a user-provided schema.

    schema: dict mapping column names to target types, where target types can be:
      - Python types: int, float, str, bool, Decimal, datetime.datetime
      - Strings (matched case-insensitively, surrounding whitespace ignored):
        'int', 'float', 'string', 'text', 'boolean', 'datetime', 'decimal', 'category'

    Columns whose target type has no handler, and schema columns missing from
    the DataFrame, are skipped with a warning.
    """
    # Optional sign, digits, optional fraction, optional magnitude suffix (k/m/b/t).
    _NUM_RE = re.compile(r"^\s*([+-]?\d+(?:\.\d+)?)([kmbtKMBT])?\s*$")
    _SUFFIX_FACTORS = {
        'k': 1e3,
        'm': 1e6,
        'b': 1e9,
        't': 1e12,
    }
    # Lower-cased tokens accepted as booleans. '1.0'/'0.0' cover 0/1 flags that
    # pandas stored as floats (e.g. because the column also contains NaN).
    _TRUE_TOKENS = frozenset(('true', '1', '1.0', 'yes', 'y', 't'))
    _FALSE_TOKENS = frozenset(('false', '0', '0.0', 'no', 'n', 'f'))

    def __init__(self, schema: dict):
        """
        Parameters
        ----------
        schema : dict
            Mapping of column name -> target type (see class docstring).
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.schema = schema
        # Dispatch table: target type (Python type or string alias) -> caster.
        self.handlers = {
            int: self._handle_int,
            'int': self._handle_int,
            float: self._handle_float,
            'float': self._handle_float,
            str: self._handle_str,
            'string': self._handle_str,
            'text': self._handle_str,
            bool: self._handle_bool,
            'boolean': self._handle_bool,
            datetime.datetime: self._handle_datetime,
            'datetime': self._handle_datetime,
            Decimal: self._handle_decimal,
            'decimal': self._handle_decimal,
            'category': self._handle_category,
        }
        self.logger.info(f"Initialized DataFrameSchemaEnforcer with schema: {self.schema}")

    def _resolve_handler(self, tgt):
        """Return the caster registered for ``tgt``, or None if there is none."""
        # String aliases usually come from YAML/JSON files; tolerate 'Int', ' FLOAT ', ...
        key = tgt.strip().lower() if isinstance(tgt, str) else tgt
        try:
            return self.handlers.get(key)
        except TypeError:
            # Unhashable target (e.g. a list parsed from a config file): no handler.
            return None

    def enforce(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a new DataFrame with columns cast according to the schema.

        The input DataFrame is not modified. Any exception raised by a caster is
        logged and re-raised.
        """
        self.logger.info("Starting schema enforcement on DataFrame.")
        result = df.copy()
        for col, tgt in self.schema.items():
            if col not in result.columns:
                self.logger.warning(f"Column '{col}' not found in DataFrame; skipping.")
                continue

            handler = self._resolve_handler(tgt)
            if handler is None:
                self.logger.warning(f"No handler for target type '{tgt}' on column '{col}'; skipping.")
                continue

            self.logger.info(f"Enforcing type '{tgt}' on column '{col}'.")
            try:
                result[col] = handler(result[col])
                self.logger.info(f"Column '{col}' successfully cast to '{tgt}'.")
            except Exception as e:
                self.logger.error(f"Failed to cast column '{col}' to '{tgt}': {e}")
                raise

        self.logger.info("Schema enforcement complete.")
        return result

    def _handle_numeric(self, series: pd.Series) -> pd.Series:
        """
        Parse numeric strings with commas and suffixes to floats.

        Thousands separators (',') are removed and k/m/b/t suffixes are expanded
        (e.g. '3k' -> 3000.0). Missing or unparseable values become NaN. The
        result always has a float dtype, including for empty input.
        """
        self.logger.info("Parsing numeric values.")
        def parse(val):
            if pd.isna(val):
                return np.nan
            s = str(val).replace(',', '').strip()
            m = self._NUM_RE.match(s)
            if m:
                num = float(m.group(1))
                suf = m.group(2)
                if suf:
                    num *= self._SUFFIX_FACTORS[suf.lower()]
                return num
            # Fall back to float() for forms the regex does not cover (e.g. '1e5', '.5').
            try:
                return float(s)
            except Exception:
                return np.nan

        # parse() only ever returns floats/NaN, so the cast is safe; it pins the
        # dtype for empty or all-missing input, where map() may keep 'object'.
        parsed_series = series.map(parse).astype(float)
        self.logger.info("Numeric parsing complete.")
        return parsed_series

    def _handle_int(self, series: pd.Series) -> pd.Series:
        """Cast to nullable 'Int64': parse as numeric, round to whole numbers, NaN -> <NA>."""
        self.logger.info("Casting series to integer ('Int64').")
        floats = self._handle_numeric(series)
        try:
            int_series = floats.round(0).astype('Int64')
            self.logger.info("Integer casting complete.")
            return int_series
        except Exception as e:
            self.logger.error(f"Error casting to integer: {e}")
            raise

    def _handle_float(self, series: pd.Series) -> pd.Series:
        """Cast to float via the shared numeric parser; unparseable values become NaN."""
        self.logger.info("Casting series to float.")
        try:
            float_series = self._handle_numeric(series).astype(float)
            self.logger.info("Float casting complete.")
            return float_series
        except Exception as e:
            self.logger.error(f"Error casting to float: {e}")
            raise

    def _handle_str(self, series: pd.Series) -> pd.Series:
        """
        Cast non-missing values to ``str`` while keeping missing values missing.

        A plain ``astype(str)`` can turn NaN/None into the literal strings
        'nan'/'None', which later NA-aware steps (``dropna``/``isna``) would
        treat as real text.
        """
        self.logger.info("Casting series to string.")
        try:
            str_series = series.astype(str).where(series.notna())
            self.logger.info("String casting complete.")
            return str_series
        except Exception as e:
            self.logger.error(f"Error casting to string: {e}")
            raise

    def _handle_bool(self, series: pd.Series) -> pd.Series:
        """Cast to nullable 'boolean'; missing or unrecognised values become <NA>."""
        self.logger.info("Casting series to boolean.")
        def parse(val):
            if pd.isna(val):
                return pd.NA
            s = str(val).strip().lower()
            if s in self._TRUE_TOKENS:
                return True
            if s in self._FALSE_TOKENS:
                return False
            return pd.NA

        try:
            bool_series = series.map(parse).astype('boolean')
            self.logger.info("Boolean casting complete.")
            return bool_series
        except Exception as e:
            self.logger.error(f"Error casting to boolean: {e}")
            raise

    def _handle_datetime(self, series: pd.Series) -> pd.Series:
        """Cast to datetime; values that cannot be parsed become NaT."""
        self.logger.info("Casting series to datetime.")
        try:
            # 'infer_datetime_format' is deprecated in pandas 2.x (format inference
            # is the default there), so it is intentionally not passed.
            dt_series = pd.to_datetime(series, errors='coerce')
            self.logger.info("Datetime casting complete.")
            return dt_series
        except Exception as e:
            self.logger.error(f"Error casting to datetime: {e}")
            raise

    def _handle_decimal(self, series: pd.Series) -> pd.Series:
        """Cast to ``Decimal`` objects (commas stripped); missing or invalid values become None."""
        self.logger.info("Casting series to Decimal.")
        def parse(val):
            if pd.isna(val):
                return None
            s = str(val).replace(',', '').strip()
            # Best-effort by design: anything Decimal rejects is treated as missing.
            try:
                return Decimal(s)
            except Exception:
                return None

        try:
            dec_series = series.map(parse)
            self.logger.info("Decimal casting complete.")
            return dec_series
        except Exception as e:
            self.logger.error(f"Error casting to Decimal: {e}")
            raise

    def _handle_category(self, series: pd.Series) -> pd.Series:
        """Cast to pandas 'category' dtype."""
        self.logger.info("Casting series to category.")
        try:
            cat_series = series.astype('category')
            self.logger.info("Category casting complete.")
            return cat_series
        except Exception as e:
            self.logger.error(f"Error casting to category: {e}")
            raise

## Data Manager

In [4]:
import pandas as pd
import logging
# from src.data_management.config_loader import ConfigLoader
# from src.data_management.data_loader import DataLoader
# from src.data_management.schema_enforcer import DataFrameSchemaEnforcer

# Configure the root logger: INFO level, records formatted as
# "<timestamp> [<LEVEL>] <logger name>: <message>".
# NOTE(review): logging.basicConfig does nothing once the root logger has
# handlers, so this call presumably only matters when this cell runs before any
# other logging setup — confirm against the intended execution order.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

class DataManager:
    """
    Single entry point for configuration loading, dataset loading and schema
    enforcement.

    The class owns one ``ConfigLoader`` and one ``DataLoader`` and delegates to
    them (and to ``DataFrameSchemaEnforcer``). Every public method logs what it
    is about to do, logs the outcome, and re-raises any exception after logging it.
    """

    def __init__(self):
        self.logger = logging.getLogger(type(self).__name__)
        self.logger.info("Initializing DataManager.")
        # Collaborators that do the actual file reading.
        self.config_loader = ConfigLoader()
        self.data_loader = DataLoader()

    def load_config(self, file_path: str) -> dict:
        """
        Read a configuration file and return its content.

        Parameters
        ----------
        file_path : str
            Location of the YAML or JSON configuration file.

        Returns
        -------
        dict
            The configuration held by the config loader after loading.

        Raises
        ------
        Exception
            Whatever the config loader raises; it is logged, then re-raised.
        """
        self.logger.info(f"Loading configuration from: {file_path}")
        try:
            self.config_loader.get_configs(file_path)
            loaded_config = self.config_loader.config
            self.logger.info(f"Configuration loaded successfully from: {file_path}")
        except Exception as exc:
            self.logger.error(f"Failed to load configuration from {file_path}: {exc}")
            raise
        return loaded_config

    def load_data(
        self,
        ref_path: str,
        cmp_path: str,
        ref_sheet_name: str = None,
        cmp_sheet_name: str = None
    ) -> tuple[pd.DataFrame, pd.DataFrame, list[str]]:
        """
        Read the reference and comparison datasets through the data loader.

        Parameters
        ----------
        ref_path : str
            Reference dataset file (CSV, Parquet or Excel).
        cmp_path : str
            Comparison dataset file (CSV, Parquet or Excel).
        ref_sheet_name : str, optional
            Excel sheet to read for the reference dataset.
        cmp_sheet_name : str, optional
            Excel sheet to read for the comparison dataset.

        Returns
        -------
        tuple[pd.DataFrame, pd.DataFrame, list[str]]
            Reference dataset, comparison dataset and the columns they share.

        Raises
        ------
        ValueError
            Propagated from the data loader (e.g. unreadable file or no common
            columns); logged before being re-raised.
        """
        self.logger.info(
            f"Loading data: ref_path={ref_path}, cmp_path={cmp_path}, "
            f"ref_sheet_name={ref_sheet_name}, cmp_sheet_name={cmp_sheet_name}"
        )
        try:
            loaded = self.data_loader.get_data(
                ref_path, cmp_path, ref_sheet_name, cmp_sheet_name
            )
            reference, comparison, shared_columns = loaded
            summary = (
                f"Data loaded successfully: "
                f"ref_rows={len(reference)}, ref_cols={len(reference.columns)}; "
                f"cmp_rows={len(comparison)}, cmp_cols={len(comparison.columns)}; "
                f"common_columns={shared_columns}"
            )
            self.logger.info(summary)
        except Exception as exc:
            self.logger.error(f"Failed to load data: {exc}")
            raise
        return reference, comparison, shared_columns

    def enforce_schema(self, df: pd.DataFrame, schema: dict) -> pd.DataFrame:
        """
        Cast the columns of ``df`` according to ``schema``.

        Parameters
        ----------
        df : pd.DataFrame
            Data to cast.
        schema : dict
            Schema definition handed to ``DataFrameSchemaEnforcer``.

        Returns
        -------
        pd.DataFrame
            Result of ``DataFrameSchemaEnforcer(schema).enforce(df)``.

        Raises
        ------
        Exception
            Any enforcement error; logged, then re-raised.
        """
        self.logger.info("Enforcing schema on DataFrame.")
        try:
            schema_enforcer = DataFrameSchemaEnforcer(schema)
            enforced = schema_enforcer.enforce(df)
            self.logger.info("Schema enforcement successful.")
        except Exception as exc:
            self.logger.error(f"Schema enforcement failed: {exc}")
            raise
        return enforced

    def pre_process_data(
        self,
        df: pd.DataFrame,
        schema: dict
    ) -> pd.DataFrame:
        """
        Pre-process ``df``; currently this is schema enforcement only.

        Parameters
        ----------
        df : pd.DataFrame
            Data to pre-process.
        schema : dict
            Schema definition to enforce.

        Returns
        -------
        pd.DataFrame
            The schema-enforced DataFrame.

        Raises
        ------
        Exception
            Any error from ``enforce_schema``; logged, then re-raised.
        """
        self.logger.info("Pre-processing DataFrame.")
        try:
            processed = self.enforce_schema(df, schema)
            self.logger.info("Pre-processing completed successfully.")
        except Exception as exc:
            self.logger.error(f"Pre-processing failed: {exc}")
            raise
        return processed

    def get_ref_and_cmp_data(
        self,
        ref_path: str,
        cmp_path: str,
        ref_data_schema_path: str,
        cmp_data_schema_path: str,
        ref_sheet_name: str = None,
        cmp_sheet_name: str = None,
    ) -> tuple[pd.DataFrame, pd.DataFrame, list[str], dict, dict]:
        """
        Run the full pipeline: load both datasets, load both schemas, enforce them.

        Parameters
        ----------
        ref_path : str
            Reference dataset file.
        cmp_path : str
            Comparison dataset file.
        ref_data_schema_path : str
            Schema file (JSON/YAML) for the reference dataset.
        cmp_data_schema_path : str
            Schema file (JSON/YAML) for the comparison dataset.
        ref_sheet_name : str, optional
            Sheet name for the reference dataset (if applicable).
        cmp_sheet_name : str, optional
            Sheet name for the comparison dataset (if applicable).

        Returns
        -------
        tuple[pd.DataFrame, pd.DataFrame, list[str], dict, dict]
            In order: schema-enforced reference DataFrame, schema-enforced
            comparison DataFrame, common columns (as computed on the raw data),
            reference schema dict, comparison schema dict.

        Raises
        ------
        Exception
            Any failure in a step; logged, then re-raised.
        """
        self.logger.info("Starting full data retrieval and schema enforcement process.")
        try:
            # Step 1: raw datasets and their shared columns.
            raw_ref, raw_cmp, shared_columns = self.load_data(
                ref_path, cmp_path, ref_sheet_name, cmp_sheet_name
            )

            # Step 2: one schema per dataset.
            self.logger.info(f"Loading reference schema from: {ref_data_schema_path}")
            schema_for_ref = self.load_config(ref_data_schema_path)
            self.logger.info(f"Loading comparison schema from: {cmp_data_schema_path}")
            schema_for_cmp = self.load_config(cmp_data_schema_path)

            # Step 3: cast each dataset with its own schema.
            self.logger.info("Enforcing schema on reference DataFrame.")
            enforced_ref = self.pre_process_data(raw_ref, schema_for_ref)
            self.logger.info("Enforcing schema on comparison DataFrame.")
            enforced_cmp = self.pre_process_data(raw_cmp, schema_for_cmp)

            self.logger.info(
                f"Finished processing reference and comparison data. "
                f"Returning data with {len(shared_columns)} common columns."
            )
        except Exception as exc:
            self.logger.error(f"get_ref_and_cmp_data failed: {exc}")
            raise
        return enforced_ref, enforced_cmp, shared_columns, schema_for_ref, schema_for_cmp

## Version 2

## Imports

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from collections import Counter
from itertools import combinations
from scipy.stats import gaussian_kde, entropy
from scipy.stats import ks_2samp, wasserstein_distance
from sklearn.feature_extraction.text import TfidfVectorizer
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA, NMF
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from abc import ABC, abstractmethod
import gradio as gr
import logging

# ─── Logging Configuration ────────────────────────────────────────────────────
# Root logger: INFO level, records formatted as
# "<timestamp> [<LEVEL>] <logger name>: <message>".
# NOTE(review): logging.basicConfig does nothing once the root logger has
# handlers, so this call presumably only matters when this cell runs before any
# other logging setup — confirm against the intended execution order.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
# Module-level logger used by the code in the cells that follow.
logger = logging.getLogger(__name__)

# Make sure DataManager is importable. If running in Colab, you can either just put the DataManager code in a cell above or do:
# from google.colab import drive
# drive.mount('/content/drive')
# import sys
# sys.path.append('/content/drive/MyDrive/your_project_folder')
# from data_manager import DataManager

from src.data_management.data_manager import DataManager

## Report Class

In [None]:
class Report:
    def __init__(
        self,
        ref_df: pd.DataFrame,
        cmp_df: pd.DataFrame,
        common_columns: list,
        ref_schema: dict,
        cmp_schema: dict
    ):
        logger.info("Initializing Report object.")
        self.ref_df = ref_df.loc[:, common_columns].reset_index(drop=True)
        self.cmp_df = cmp_df.loc[:, common_columns].reset_index(drop=True)
        self.common_columns = common_columns
        self.ref_schema = ref_schema
        self.cmp_schema = cmp_schema

        logger.info(f"Common columns: {common_columns}")
        # Determine effective column types using schema + heuristics
        self.column_types = {}
        for col in common_columns:
            declared = ref_schema.get(col)
            logger.info(f"Determining type for column '{col}' (declared: {declared}).")
            if declared == 'category':
                self.column_types[col] = 'categorical'
            elif declared in (bool, 'boolean'):
                self.column_types[col] = 'categorical'
            elif declared in (int, 'int', float, 'float', 'decimal'):
                self.column_types[col] = 'numeric'
            elif declared in (pd.Timestamp, 'datetime', 'datetime64[ns]'):
                self.column_types[col] = 'datetime'
            elif declared in (str, 'string', 'text'):
                non_na = self.ref_df[col].dropna().astype(str)
                if len(non_na) == 0:
                    self.column_types[col] = 'categorical'
                else:
                    unique_ratio = non_na.nunique() / len(non_na)
                    mean_tokens = non_na.str.split().apply(len).mean()
                    if unique_ratio <= 0.20 and mean_tokens <= 5:
                        self.column_types[col] = 'categorical'
                    else:
                        self.column_types[col] = 'text'
            else:
                dtype = self.ref_df[col].dtype
                if pd.api.types.is_numeric_dtype(dtype):
                    self.column_types[col] = 'numeric'
                elif pd.api.types.is_datetime64_any_dtype(dtype):
                    self.column_types[col] = 'datetime'
                else:
                    self.column_types[col] = 'categorical'
            logger.info(f"  → Column '{col}' determined to be '{self.column_types[col]}'.")

        # Placeholder for per‐column metrics
        self.per_column_metrics = {col: {} for col in common_columns}
        logger.info("Report initialization complete.")


    @staticmethod
    def _js_divergence(p: np.ndarray, q: np.ndarray) -> float:
        logger.info("Computing Jensen–Shannon divergence.")
        p = p + 1e-12
        q = q + 1e-12
        p = p / p.sum()
        q = q / q.sum()
        m = 0.5 * (p + q)
        result = 0.5 * entropy(p, m, base=2) + 0.5 * entropy(q, m, base=2)
        logger.info(f"  → JS divergence = {result:.6f}")
        return result


    def numeric_ks_similarity(self, col: str) -> float:
        logger.info(f"Computing KS similarity for numeric column '{col}'.")
        real_vals = self.ref_df[col].dropna().astype(float).values
        syn_vals  = self.cmp_df[col].dropna().astype(float).values
        logger.info(f"  → Samples: real={real_vals.size}, syn={syn_vals.size}")
        if real_vals.size < 2 or syn_vals.size < 2:
            logger.info("  → Not enough data points for KS: returning NaN.")
            return np.nan
        D_stat, _ = ks_2samp(real_vals, syn_vals)
        sim = 1.0 - D_stat
        logger.info(f"  → KS D-statistic = {D_stat:.6f}, similarity = {sim:.6f}")
        return sim


    def numeric_wasserstein(self, col: str) -> float:
        logger.info(f"Computing Wasserstein distance for numeric column '{col}'.")
        real_vals = self.ref_df[col].dropna().astype(float).values
        syn_vals  = self.cmp_df[col].dropna().astype(float).values
        logger.info(f"  → Samples: real={real_vals.size}, syn={syn_vals.size}")
        if real_vals.size < 2 or syn_vals.size < 2:
            logger.info("  → Not enough data points for Wasserstein: returning NaN.")
            return np.nan, np.nan

        Wraw = wasserstein_distance(real_vals, syn_vals)
        logger.info(f"  → Raw Wasserstein = {Wraw:.6f}")
        rmin, rmax = real_vals.min(), real_vals.max()
        rng = rmax - rmin
        if rng == 0:
            logger.info("  → Real range is zero; normalized Wasserstein is NaN.")
            return Wraw, np.nan

        Wnorm = Wraw / rng
        logger.info(f"  → Normalized Wasserstein = {Wnorm:.6f}")
        return Wraw, Wnorm


    def numeric_summary_stats(self, col: str) -> dict:
        logger.info(f"Computing detailed summary‐stats for numeric column '{col}'.")
        real_vals = self.ref_df[col].dropna().astype(float)
        syn_vals  = self.cmp_df[col].dropna().astype(float)
        if real_vals.empty or syn_vals.empty:
            logger.info("  → One side is empty: returning all NaN stats.")
            return {
                'mean_real': np.nan, 'mean_syn': np.nan, 'mean_diff': np.nan, 'mean_pct_of_range': np.nan,
                'median_real': np.nan, 'median_syn': np.nan, 'median_diff': np.nan, 'median_pct_of_range': np.nan,
                'std_real': np.nan, 'std_syn': np.nan, 'std_diff': np.nan, 'std_pct_of_std': np.nan
            }

        rmin, rmax = real_vals.min(), real_vals.max()
        data_range = rmax - rmin
        logger.info(f"  → Real min={rmin:.4f}, max={rmax:.4f}, range={data_range:.4f}")

        mean_r = real_vals.mean()
        mean_s = syn_vals.mean()
        diff_mean = mean_s - mean_r
        pct_mean = diff_mean / data_range if data_range != 0 else np.nan
        logger.info(f"  → Mean: real={mean_r:.4f}, syn={mean_s:.4f}, diff={diff_mean:.4f}, pct_range={pct_mean:.6f}")

        med_r = real_vals.median()
        med_s = syn_vals.median()
        diff_med = med_s - med_r
        pct_med = diff_med / data_range if data_range != 0 else np.nan
        logger.info(f"  → Median: real={med_r:.4f}, syn={med_s:.4f}, diff={diff_med:.4f}, pct_range={pct_med:.6f}")

        std_r = real_vals.std(ddof=1)
        std_s = syn_vals.std(ddof=1)
        diff_std = std_s - std_r
        pct_std = diff_std / std_r if std_r != 0 else np.nan
        logger.info(f"  → StdDev: real={std_r:.4f}, syn={std_s:.4f}, diff={diff_std:.4f}, pct_std={pct_std:.6f}")

        return {
            'mean_real': mean_r,
            'mean_syn': mean_s,
            'mean_diff': diff_mean,
            'mean_pct_of_range': pct_mean,
            'median_real': med_r,
            'median_syn': med_s,
            'median_diff': diff_med,
            'median_pct_of_range': pct_med,
            'std_real': std_r,
            'std_syn': std_s,
            'std_diff': diff_std,
            'std_pct_of_std': pct_std
        }


    def numeric_range_coverage(self, col: str) -> float:
        logger.info(f"Computing range coverage for numeric column '{col}'.")
        real_vals = self.ref_df[col].dropna().astype(float)
        syn_vals  = self.cmp_df[col].dropna().astype(float)
        if real_vals.empty or syn_vals.empty:
            logger.info("  → One side is empty: returning NaN.")
            return np.nan
        rmin, rmax = real_vals.min(), real_vals.max()
        rng = rmax - rmin
        if rng == 0:
            logger.info("  → Real range is zero: returning NaN.")
            return np.nan
        smin, smax = syn_vals.min(), syn_vals.max()
        srng = smax - smin
        cov = min(srng / rng, 1.0)
        logger.info(f"  → Synthetic range={srng:.4f}, coverage={cov:.6f}")
        return cov


    # ───── Univariate Methods (unchanged from earlier) ──────────────────────────

    def _compute_numeric_histogram_js(self, col: str) -> float:
        logger.info(f"Computing histogram‐based JS divergence for '{col}'.")
        ref_vals = self.ref_df[col].dropna().astype(float).values
        cmp_vals = self.cmp_df[col].dropna().astype(float).values
        if len(ref_vals) < 2 or len(cmp_vals) < 2:
            logger.info("  → Not enough data for histogram JS: returning NaN.")
            return np.nan

        n_ref = len(ref_vals)
        s_ref = np.std(ref_vals, ddof=1)
        h = 3.49 * s_ref * (n_ref ** (-1 / 3))
        if h <= 0:
            k = int(np.ceil(np.log2(n_ref) + 1))
        else:
            data_min = min(ref_vals.min(), cmp_vals.min())
            data_max = max(ref_vals.max(), cmp_vals.max())
            k = int(np.ceil((data_max - data_min) / h))
            if k < 2:
                k = int(np.ceil(np.log2(n_ref) + 1))
        k = max(min(k, 100), 2)
        logger.info(f"  → Number of bins for histogram: k={k}")

        data_min = min(ref_vals.min(), cmp_vals.min())
        data_max = max(ref_vals.max(), cmp_vals.max())
        bins = np.linspace(data_min, data_max, k + 1)

        p_ref, _ = np.histogram(ref_vals, bins=bins, density=True)
        p_cmp, _ = np.histogram(cmp_vals, bins=bins, density=True)
        bin_width = bins[1] - bins[0]
        p_ref_probs = p_ref * bin_width
        p_cmp_probs = p_cmp * bin_width

        js = Report._js_divergence(p_ref_probs, p_cmp_probs)
        return js


    def _compute_numeric_kde_metrics(self, col: str) -> (float, float):
        logger.info(f"Computing KDE‐based L2 & JS metrics for '{col}'.")
        ref_vals = self.ref_df[col].dropna().astype(float).values
        cmp_vals = self.cmp_df[col].dropna().astype(float).values
        if len(ref_vals) < 2 or len(cmp_vals) < 2:
            logger.info("  → Not enough data for KDE: returning (NaN, NaN).")
            return np.nan, np.nan

        pooled = np.concatenate([ref_vals, cmp_vals])
        bw = np.std(pooled, ddof=1) * (len(pooled) ** (-1/5))
        if bw <= 0:
            bw = 1.0
        logger.info(f"  → Bandwidth for KDE (Silverman estimate) = {bw:.6f}")

        try:
            kde_ref = gaussian_kde(ref_vals, bw_method=bw / np.std(ref_vals, ddof=1))
            kde_cmp = gaussian_kde(cmp_vals, bw_method=bw / np.std(cmp_vals, ddof=1))
        except np.linalg.LinAlgError:
            logger.info("  → KDE failed (singular covariance): returning (NaN, NaN).")
            return np.nan, np.nan

        data_min = pooled.min()
        data_max = pooled.max()
        grid = np.linspace(data_min, data_max, 512)
        pdf_ref = kde_ref(grid)
        pdf_cmp = kde_cmp(grid)

        dx = grid[1] - grid[0]
        pdf_ref_norm = pdf_ref / (np.trapz(pdf_ref, grid))
        pdf_cmp_norm = pdf_cmp / (np.trapz(pdf_cmp, grid))

        l2 = np.sqrt(np.trapz((pdf_ref_norm - pdf_cmp_norm) ** 2, grid))
        logger.info(f"  → KDE L2 distance = {l2:.6f}")

        p = pdf_ref_norm * dx
        q = pdf_cmp_norm * dx
        js = Report._js_divergence(p, q)
        return l2, js


    def categorical_tvd_similarity(self, col: str) -> float:
        logger.info(f"Computing TVD similarity for categorical column '{col}'.")
        real_counts = self.ref_df[col].dropna().value_counts(normalize=True)
        syn_counts  = self.cmp_df[col].dropna().value_counts(normalize=True)
        all_cats = set(real_counts.index).union(syn_counts.index)
        tvd = 0.0
        for cat in all_cats:
            p = real_counts.get(cat, 0.0)
            q = syn_counts.get(cat, 0.0)
            tvd += abs(p - q)
        tvd *= 0.5
        sim = 1.0 - tvd
        logger.info(f"  → TVD = {tvd:.6f}, similarity = {sim:.6f}")
        return sim


    def categorical_coverage(self, col: str) -> float:
        logger.info(f"Computing category coverage for '{col}'.")
        real_uniques = set(self.ref_df[col].dropna().unique())
        syn_uniques  = set(self.cmp_df[col].dropna().unique())
        if not real_uniques:
            logger.info("  → No real categories: returning NaN.")
            return np.nan
        covered = len(real_uniques.intersection(syn_uniques))
        coverage = covered / len(real_uniques)
        logger.info(f"  → Coverage = {coverage:.6f} ({covered}/{len(real_uniques)})")
        return coverage


    def _compute_categorical_js(self, col: str) -> float:
        logger.info(f"Computing JS divergence on categorical frequencies for '{col}'.")
        ref_counts = self.ref_df[col].dropna().value_counts(normalize=True)
        cmp_counts = self.cmp_df[col].dropna().value_counts(normalize=True)
        all_cats = sorted(set(ref_counts.index).union(cmp_counts.index))
        p = np.array([ref_counts.get(cat, 0.0) for cat in all_cats], dtype=float)
        q = np.array([cmp_counts.get(cat, 0.0) for cat in all_cats], dtype=float)
        js = Report._js_divergence(p, q)
        return js


    def _compute_text_token_js(self, col: str, top_k: int = 5000) -> float:
        logger.info(f"Computing token‐level JS divergence for text column '{col}'.")
        texts_ref = self.ref_df[col].dropna().astype(str).tolist()
        texts_cmp = self.cmp_df[col].dropna().astype(str).tolist()
        if not texts_ref or not texts_cmp:
            logger.info("  → One side is empty: returning NaN.")
            return np.nan

        tokens_ref = [tok for t in texts_ref for tok in t.split()]
        tokens_cmp = [tok for t in texts_cmp for tok in t.split()]
        ctr_ref = Counter(tokens_ref)
        ctr_cmp = Counter(tokens_cmp)
        combined = ctr_ref + ctr_cmp
        most_common = [tok for tok, _ in combined.most_common(top_k)]

        p = np.array([ctr_ref.get(tok, 0) for tok in most_common], dtype=float)
        q = np.array([ctr_cmp.get(tok, 0) for tok in most_common], dtype=float)
        if p.sum() == 0 or q.sum() == 0:
            logger.info("  → Zero counts on one side: returning NaN.")
            return np.nan
        p = p / p.sum()
        q = q / q.sum()
        js = Report._js_divergence(p, q)
        return js


    def _compute_text_tfidf_cosine(self, col: str, max_features: int = 5000) -> float:
        logger.info(f"Computing TF–IDF cosine for text column '{col}'.")
        texts_ref = self.ref_df[col].dropna().astype(str).tolist()
        texts_cmp = self.cmp_df[col].dropna().astype(str).tolist()
        if not texts_ref or not texts_cmp:
            logger.info("  → One side is empty: returning NaN.")
            return np.nan

        corpus = texts_ref + texts_cmp
        logger.info(f"  → Fitting TF–IDF on {len(corpus)} documents (max_features={max_features}).")
        vectorizer = TfidfVectorizer(max_features=max_features, stop_words='english', lowercase=True)
        tfidf_all = vectorizer.fit_transform(corpus)
        n_ref = len(texts_ref)
        A_ref = tfidf_all[:n_ref, :]
        A_cmp = tfidf_all[n_ref:, :]

        c_ref = np.asarray(A_ref.mean(axis=0)).ravel()
        c_cmp = np.asarray(A_cmp.mean(axis=0)).ravel()
        denom = np.linalg.norm(c_ref) * np.linalg.norm(c_cmp)
        cosine = float(np.dot(c_ref, c_cmp) / denom) if denom > 0 else np.nan
        logger.info(f"  → TF–IDF cosine = {cosine:.6f}")
        return cosine


    def _compute_text_vocab_jaccard(self, col: str) -> float:
        logger.info(f"Computing vocabulary Jaccard for text column '{col}'.")
        texts_ref = self.ref_df[col].dropna().astype(str).tolist()
        texts_cmp = self.cmp_df[col].dropna().astype(str).tolist()
        if not texts_ref or not texts_cmp:
            logger.info("  → One side is empty: returning NaN.")
            return np.nan

        tokens_ref = set(tok for t in texts_ref for tok in t.split())
        tokens_cmp = set(tok for t in texts_cmp for tok in t.split())
        if not tokens_ref and not tokens_cmp:
            logger.info("  → Both sets empty: returning NaN.")
            return np.nan
        if not tokens_ref or not tokens_cmp:
            logger.info("  → One set empty: returning 0.0.")
            return 0.0
        inter = tokens_ref.intersection(tokens_cmp)
        union = tokens_ref.union(tokens_cmp)
        jaccard = len(inter) / len(union)
        logger.info(f"  → Vocabulary Jaccard = {jaccard:.6f}")
        return jaccard


    def _compute_document_length_diff(self, col: str) -> float:
        logger.info(f"Computing document‐length difference for text column '{col}'.")
        lens_ref = self.ref_df[col].dropna().astype(str).str.split().apply(len)
        lens_cmp = self.cmp_df[col].dropna().astype(str).str.split().apply(len)
        if lens_ref.empty or lens_cmp.empty:
            logger.info("  → One side is empty: returning NaN.")
            return np.nan
        diff = abs(lens_ref.mean() - lens_cmp.mean())
        logger.info(f"  → Avg length real = {lens_ref.mean():.2f}, syn = {lens_cmp.mean():.2f}, Δ = {diff:.2f}")
        return diff


    def _compute_text_embedding_metrics(
        self,
        col: str,
        max_samples: int = 200
    ) -> (float, float):
        """
        Subsample up to max_samples texts from each side,
        compute embedding‐cosine and MMD on those.
        """
        logger.info(f"Computing embedding metrics for text column '{col}' (max_samples={max_samples}).")
        texts_ref = self.ref_df[col].dropna().astype(str).tolist()
        texts_cmp = self.cmp_df[col].dropna().astype(str).tolist()
        if not texts_ref or not texts_cmp:
            logger.info("  → One side empty: returning (NaN, NaN).")
            return np.nan, np.nan

        rng = np.random.default_rng(seed=42)
        if len(texts_ref) > max_samples:
            texts_ref = list(rng.choice(texts_ref, max_samples, replace=False))
        if len(texts_cmp) > max_samples:
            texts_cmp = list(rng.choice(texts_cmp, max_samples, replace=False))

        logger.info(f"  → Sampling: real={len(texts_ref)}, syn={len(texts_cmp)}.")
        try:
            model = SentenceTransformer("all-MiniLM-L6-v2")
            emb_ref = model.encode(texts_ref, convert_to_numpy=True)
            emb_cmp = model.encode(texts_cmp, convert_to_numpy=True)
        except Exception as e:
            logger.info(f"  → Error loading/encoding embeddings: {e}. Returning (NaN, NaN).")
            return np.nan, np.nan

        cent_ref = emb_ref.mean(axis=0)
        cent_cmp = emb_cmp.mean(axis=0)
        denom = np.linalg.norm(cent_ref) * np.linalg.norm(cent_cmp)
        emb_cosine = float(np.dot(cent_ref, cent_cmp) / denom) if denom > 0 else np.nan
        logger.info(f"  → Embedding cosine = {emb_cosine:.6f}")

        combined = np.vstack([emb_ref, emb_cmp])
        pdists = np.sqrt(np.sum((combined[:, None, :] - combined[None, :, :]) ** 2, axis=2))
        median_dist = np.median(pdists)
        gamma = 1.0 / (2 * (median_dist ** 2 + 1e-12))
        logger.info(f"  → MMD gamma parameter = {gamma:.6e}")

        def rbf_matrix(X, Y, gamma):
            dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=2)
            return np.exp(-gamma * dists)

        K_rr = rbf_matrix(emb_ref, emb_ref, gamma)
        K_cc = rbf_matrix(emb_cmp, emb_cmp, gamma)
        K_rc = rbf_matrix(emb_ref, emb_cmp, gamma)

        m = emb_ref.shape[0]
        n = emb_cmp.shape[0]
        mmd_sq = (np.sum(K_rr) / (m * m)) - (2 * np.sum(K_rc) / (m * n)) + (np.sum(K_cc) / (n * n))
        emb_mmd = float(np.sqrt(max(mmd_sq, 0.0)))
        logger.info(f"  → Embedding MMD = {emb_mmd:.6f}")

        return emb_cosine, emb_mmd


    def compute_column(self, col: str) -> dict:
        """
        Compute per‐column metrics using the new univariate methods.
        """
        logger.info(f"compute_column called for '{col}'.")
        results = {}
        # Missing-rate difference
        mr_diff = abs(self.ref_df[col].isna().mean() - self.cmp_df[col].isna().mean())
        results['missing_rate_diff'] = mr_diff
        logger.info(f"  → Missing‐rate difference = {mr_diff:.6f}")

        ctype = self.column_types.get(col)
        logger.info(f"  → Column type = {ctype}")

        if ctype == "numeric":
            # 1) KS similarity
            results['ks_similarity'] = self.numeric_ks_similarity(col)

            # 2) Wasserstein (raw + normalized)
            Wraw, Wnorm = self.numeric_wasserstein(col)
            results['wasserstein_raw']  = Wraw
            results['wasserstein_norm'] = Wnorm

            # 3) Detailed summary‐stats
            stats = self.numeric_summary_stats(col)
            results.update(stats)

            # 4) Range coverage
            results['range_coverage'] = self.numeric_range_coverage(col)

        elif ctype == "categorical":
            # 1) TVD similarity
            results['tvd_similarity'] = self.categorical_tvd_similarity(col)
            # 2) Category coverage
            results['category_coverage'] = self.categorical_coverage(col)

        elif ctype == "text":
            # 1) Token‐JS Divergence
            tok_js   = self._compute_text_token_js(col)

            # 2) Bigram Jaccard & Bigram‐JS
            bigram_j = self._compute_text_bigram_jaccard(col)
            bigram_js = self._compute_text_bigram_js(col)

            # 3) TF–IDF Cosine
            tfidf_cos = self._compute_text_tfidf_cosine(col)

            # 4) Vocabulary Jaccard (unigrams)
            vocab_j = self._compute_text_vocab_jaccard(col)

            # 5) OOV Rate
            oov_real = self._compute_text_oov_rate(col)

            # 6) Document Length Stats
            length_stats = self._compute_text_length_stats(col)
            len_real = length_stats['len_real']
            len_syn  = length_stats['len_syn']
            len_diff = length_stats['len_diff']

            # 7) Topic‐Distribution Cosine
            topic_cos = self._compute_text_topic_cosine(col)

            # 8) Embedding Cosine & MMD
            emb_cos, emb_mmd = self._compute_text_embedding_metrics(col)

            # Populate results
            results['text_tok_js']         = tok_js
            results['text_bigram_jaccard'] = bigram_j
            results['text_bigram_js']      = bigram_js
            results['text_tfidf_cosine']   = tfidf_cos
            results['text_vocab_jaccard']  = vocab_j
            results['text_oov_rate']       = oov_real
            results['len_real']            = len_real
            results['len_syn']             = len_syn
            results['len_diff']            = len_diff
            results['text_topic_cosine']   = topic_cos
            results['text_emb_cosine']     = emb_cos
            results['text_emb_mmd']        = emb_mmd

        else:
            # datetime or other types: only missing‐rate
            logger.info(f"  → No specialized univariate for type '{ctype}', returning only missing‐rate.")

        logger.info(f"  → Completed compute_column for '{col}'.")
        return {'dtype': ctype, **results}


    def plot_column_histogram(self, col: str):
        """
        Overlay density-normalised histograms of a numeric column for both datasets.

        Parameters
        ----------
        col : str
            Numeric column to plot.

        Returns
        -------
        plotly.graph_objects.Figure
            Figure with one semi-transparent histogram trace per dataset.
        """
        logger.info(f"Generating histogram plot for '{col}'.")
        series_by_label = (
            ('Reference', self.ref_df[col].dropna().astype(float)),
            ('Comparison', self.cmp_df[col].dropna().astype(float)),
        )

        fig = go.Figure()
        for label, values in series_by_label:
            trace = go.Histogram(
                x=values,
                histnorm='probability density',
                name=label,
                opacity=0.6,
                nbinsx=30
            )
            fig.add_trace(trace)

        # Overlay (rather than stack) so the two shapes can be compared directly.
        fig.update_layout(
            barmode='overlay',
            title=f"Histogram (Density) for '{col}'",
            xaxis_title=col,
            yaxis_title='Density'
        )
        logger.info(f"  → Histogram figure ready for '{col}'.")
        return fig


    def plot_column_kde(self, col: str):
        """
        Plot Gaussian KDE curves of a numeric column for both datasets.

        Parameters
        ----------
        col : str
            Numeric column to plot.

        Returns
        -------
        plotly.graph_objects.Figure
            Two line traces evaluated on a 200-point grid over the pooled value
            range. An empty figure is returned when either side has fewer than
            two distinct values or the KDE cannot be fitted.
        """
        logger.info(f"Generating KDE plot for '{col}'.")
        ref_sample = self.ref_df[col].dropna().astype(float).values
        cmp_sample = self.cmp_df[col].dropna().astype(float).values

        enough_spread = len(np.unique(ref_sample)) >= 2 and len(np.unique(cmp_sample)) >= 2
        if not enough_spread:
            logger.info("  → Not enough distinct values for KDE: returning empty figure.")
            return go.Figure()

        try:
            ref_density = gaussian_kde(ref_sample)
            cmp_density = gaussian_kde(cmp_sample)
        except np.linalg.LinAlgError:
            logger.info("  → KDE failed (singular covariance): returning empty figure.")
            return go.Figure()

        # Evaluate both curves on one grid so they share the x-axis.
        everything = np.concatenate([ref_sample, cmp_sample])
        grid = np.linspace(everything.min(), everything.max(), 200)
        ref_curve = ref_density(grid)
        cmp_curve = cmp_density(grid)

        fig = go.Figure()
        for label, curve in (('Ref KDE', ref_curve), ('Cmp KDE', cmp_curve)):
            fig.add_trace(go.Scatter(x=grid, y=curve, mode='lines', name=label))
        fig.update_layout(
            title=f"KDE Plot for '{col}'",
            xaxis_title=col,
            yaxis_title='Density'
        )
        logger.info(f"  → KDE figure ready for '{col}'.")
        return fig


    def plot_category_bar(self, col: str):
        """
        Grouped bar chart of category proportions in both datasets.

        Parameters
        ----------
        col : str
            Categorical column to plot.

        Returns
        -------
        plotly Figure
            One bar per (category, dataset) pair; heights are the normalised
            frequencies of the non-null values.
        """
        logger.info(f"Generating category bar chart for '{col}'.")

        def proportion_table(frame, label):
            # Normalised frequencies as a tidy frame tagged with the dataset name.
            table = frame[col].dropna().value_counts(normalize=True).reset_index()
            table.columns = [col, 'proportion']
            table['dataset'] = label
            return table

        tables = [
            proportion_table(self.ref_df, 'Reference'),
            proportion_table(self.cmp_df, 'Comparison'),
        ]
        stacked = pd.concat(tables, ignore_index=True)

        fig = px.bar(
            stacked,
            x=col,
            y='proportion',
            color='dataset',
            barmode='group',
            title=f"Category Proportions for '{col}'"
        )
        logger.info(f"  → Category bar chart ready for '{col}'.")
        return fig


    def plot_text_top_tokens(self, col: str, top_n: int = 20):
        """
        Grouped bar chart of the most frequent whitespace tokens.

        Parameters
        ----------
        col : str
            Text column to plot.
        top_n : int, default 20
            Number of tokens to show, ranked by combined count in both datasets.

        Returns
        -------
        plotly Figure
            Per-dataset counts for each selected token.
        """
        logger.info(f"Generating top‐token bar chart for '{col}' (top_n={top_n}).")
        ref_docs = self.ref_df[col].dropna().astype(str).tolist()
        cmp_docs = self.cmp_df[col].dropna().astype(str).tolist()

        ref_freq = Counter(token for doc in ref_docs for token in doc.split())
        cmp_freq = Counter(token for doc in cmp_docs for token in doc.split())

        # Rank tokens by how often they occur across both datasets together.
        leaders = [token for token, _ in (ref_freq + cmp_freq).most_common(top_n)]
        n = len(leaders)

        ref_column = [ref_freq[token] for token in leaders]
        cmp_column = [cmp_freq[token] for token in leaders]
        df_plot = pd.DataFrame({
            'token': leaders + leaders,
            'count': ref_column + cmp_column,
            'dataset': ['Reference'] * n + ['Comparison'] * n
        })

        fig = px.bar(
            df_plot,
            x='token',
            y='count',
            color='dataset',
            barmode='group',
            title=f"Top {n} Token Frequencies for '{col}'"
        )
        logger.info(f"  → Top‐token bar chart ready for '{col}'.")
        return fig


    def plot_text_length_hist(self, col: str):
        """
        Overlay density histograms of document length for both datasets.

        Length is the number of whitespace-separated tokens per non-null text.

        Parameters
        ----------
        col : str
            Text column to plot.

        Returns
        -------
        plotly Figure
            Overlaid, density-normalised histograms coloured by dataset.
        """
        logger.info(f"Generating document‐length histogram for '{col}'.")
        ref_lengths = self.ref_df[col].dropna().astype(str).str.split().apply(len)
        cmp_lengths = self.cmp_df[col].dropna().astype(str).str.split().apply(len)

        all_lengths = pd.concat([ref_lengths, cmp_lengths], ignore_index=True)
        origin = ['Reference'] * len(ref_lengths) + ['Comparison'] * len(cmp_lengths)
        df_plot = pd.DataFrame({'length': all_lengths, 'dataset': origin})

        fig = px.histogram(
            df_plot,
            x='length',
            color='dataset',
            histnorm='probability density',
            barmode='overlay',
            nbins=30,
            title=f"Document Length Distribution for '{col}'"
        )
        logger.info(f"  → Document‐length histogram ready for '{col}'.")
        return fig


    def plot_text_embedding_scatter(self, col: str):
        """
        Placeholder for an embedding scatter plot of a text column.

        No plot is implemented yet: the call is logged and an empty figure is
        returned.
        """
        logger.info(f"Generating (empty) embedding scatter for '{col}'.")
        placeholder_figure = go.Figure()
        return placeholder_figure


    # ───── PCA‐Based Methods ────────────────────────────────────────────────────

    def _get_numeric_matrix(self) -> (np.ndarray, np.ndarray, list):
        logger.info("Building numeric matrices for PCA.")
        numeric_cols = [col for col in self.common_columns if self.column_types[col] == 'numeric']
        logger.info(f"  → Numeric columns: {numeric_cols}")
        if not numeric_cols:
            return np.array([]), np.array([]), []

        ref_num = self.ref_df[numeric_cols].dropna(axis=0, how='any').astype(float)
        cmp_num = self.cmp_df[numeric_cols].dropna(axis=0, how='any').astype(float)
        logger.info(f"  → After dropping NaNs: ref_num shape={ref_num.shape}, cmp_num shape={cmp_num.shape}")

        scaler_ref = StandardScaler().fit(ref_num.values)
        X_ref = scaler_ref.transform(ref_num.values)

        scaler_cmp = StandardScaler().fit(cmp_num.values)
        X_cmp = scaler_cmp.transform(cmp_num.values)

        return X_ref, X_cmp, numeric_cols


    def compute_pca_eigen_cosines(self, n_components: int = 5) -> dict:
        logger.info(f"Computing PCA eigenvector cosines (n_components={n_components}).")
        X_ref, X_cmp, numeric_cols = self._get_numeric_matrix()
        if X_ref.size == 0 or X_cmp.size == 0:
            logger.info("  → Not enough numeric data: returning empty PCA result.")
            return {
                'cosine_similarities': [],
                'explained_variance_ref': [],
                'explained_variance_cmp': [],
                'weighted_cosine': np.nan
            }

        pca_ref = PCA(n_components=n_components)
        pca_cmp = PCA(n_components=n_components)
        pca_ref.fit(X_ref)
        pca_cmp.fit(X_cmp)

        eigvecs_ref = pca_ref.components_
        eigvecs_cmp = pca_cmp.components_

        cosines = []
        for i in range(len(eigvecs_ref)):
            v1 = eigvecs_ref[i]
            v2 = eigvecs_cmp[i]
            dot = np.dot(v1, v2)
            n1 = np.linalg.norm(v1)
            n2 = np.linalg.norm(v2)
            cos = dot / (n1 * n2) if (n1 > 0 and n2 > 0) else 0.0
            cosines.append(cos)
            logger.info(f"  → PC{i+1} cosine = {cos:.6f}")

        weights = pca_ref.explained_variance_ratio_
        weighted_cos = float(np.dot(cosines, weights[:len(cosines)]))
        logger.info(f"  → Weighted cosine = {weighted_cos:.6f}")

        return {
            'cosine_similarities': cosines,
            'explained_variance_ref': pca_ref.explained_variance_ratio_.tolist(),
            'explained_variance_cmp': pca_cmp.explained_variance_ratio_.tolist(),
            'weighted_cosine': weighted_cos
        }


    def get_pca_projection_dataframe(self, n_components: int = 2) -> pd.DataFrame:
        logger.info(f"Building PCA projection dataframe (n_components={n_components}).")
        X_ref, X_cmp, numeric_cols = self._get_numeric_matrix()
        if X_ref.size == 0 or X_cmp.size == 0:
            logger.info("  → Not enough numeric data: returning empty DataFrame.")
            return pd.DataFrame()

        X_all = np.vstack([X_ref, X_cmp])
        pca = PCA(n_components=n_components)
        coords_all = pca.fit_transform(X_all)

        n_ref = X_ref.shape[0]
        df_coords = pd.DataFrame(coords_all, columns=[f"PC{i+1}" for i in range(n_components)])
        df_coords['dataset'] = ['Reference'] * n_ref + ['Comparison'] * (coords_all.shape[0] - n_ref)
        logger.info("  → PCA projection dataframe ready.")
        return df_coords


    def plot_pca_scatter(self, n_components: int = 2):
        """
        Scatter plot of both datasets in the plane of the first two principal components.

        Parameters
        ----------
        n_components : int, default 2
            Number of components requested from ``get_pca_projection_dataframe``.

        Returns
        -------
        plotly Figure
            Points coloured and shaped by dataset; an empty figure when the
            projection is empty or lacks a ``PC1`` or ``PC2`` column.
        """
        logger.info(f"Generating PCA scatter plot (components={n_components}).")
        projection = self.get_pca_projection_dataframe(n_components=n_components)

        has_plane = 'PC1' in projection and 'PC2' in projection
        if projection.empty or not has_plane:
            logger.info("  → Not enough PCA components: returning empty figure.")
            return go.Figure()

        fig = px.scatter(
            projection,
            x="PC1",
            y="PC2",
            color="dataset",
            symbol="dataset",
            title="PCA Projection (PC1 vs PC2)"
        )
        marker_style = dict(size=5, opacity=0.75)
        fig.update_traces(marker=marker_style)
        logger.info("  → PCA scatter plot ready.")
        return fig


    def compute_correlation_similarity(self) -> float:
        """
        Score how well the comparison data preserves pairwise Pearson correlations.

        For every pair of numeric common columns the Pearson coefficient is
        computed separately on the reference and comparison frames, and the
        pair score is ``1 - |r_ref - r_cmp| / 2``. Two coefficients can differ
        by at most 2, so halving the gap keeps every score inside [0, 1]
        (1 = identical correlation, 0 = perfectly opposite correlations).

        Rows containing a missing value in any numeric column are dropped
        before the correlation matrices are computed.

        Returns
        -------
        float
            Mean pair score, or ``np.nan`` when there are fewer than two
            numeric columns, fewer than two complete rows on either side, or
            no pair whose correlation is defined on both sides.
        """
        logger.info("Computing global correlation similarity (numeric × numeric).")
        numeric_cols = [c for c in self.common_columns if self.column_types.get(c) == "numeric"]
        logger.info(f"  → Numeric columns: {numeric_cols}")
        if len(numeric_cols) < 2:
            logger.info("  → Fewer than 2 numeric columns: returning NaN.")
            return np.nan

        df_real = self.ref_df[numeric_cols].dropna()
        df_syn = self.cmp_df[numeric_cols].dropna()
        if df_real.shape[0] < 2 or df_syn.shape[0] < 2:
            logger.info("  → Not enough rows after dropping NaNs: returning NaN.")
            return np.nan

        corr_real = df_real.corr(method="pearson")
        corr_syn = df_syn.corr(method="pearson")
        sims = []
        for (i, j) in combinations(numeric_cols, 2):
            r_r = corr_real.at[i, j]
            r_s = corr_syn.at[i, j]
            # A constant column yields an undefined (NaN) correlation; skip the pair.
            if pd.isna(r_r) or pd.isna(r_s):
                continue
            # |r_r - r_s| lies in [0, 2]; divide by 2 so the similarity stays in [0, 1].
            sim_ij = 1.0 - abs(r_r - r_s) / 2.0
            sims.append(sim_ij)
            logger.info(f"  → Pair ({i},{j}): corr_real={r_r:.4f}, corr_syn={r_s:.4f}, sim={sim_ij:.6f}")

        if not sims:
            logger.info("  → No valid correlation pairs: returning NaN.")
            return np.nan
        mean_sim = float(np.mean(sims))
        logger.info(f"  → Mean correlation similarity = {mean_sim:.6f}")
        return mean_sim


    def compute_contingency_similarity(self) -> float:
        """
        Average joint-distribution similarity over all categorical column pairs.

        For each pair of categorical common columns (A, B):
          • count every observed (a, b) combination in the reference and in
            the comparison data (rows with a missing value in A or B dropped,
            remaining values compared as strings);
          • align both count tables on the union of their categories and
            normalise each to a probability table;
          • TVD(A,B) = 0.5 * sum |p_ref(a,b) - p_cmp(a,b)|, sim(A,B) = 1 - TVD(A,B).

        Returns the mean of sim(A,B) over all pairs, or np.nan when there are
        fewer than 2 categorical columns or every pair is empty on one side.
        """
        logger.info("Computing global contingency similarity (categorical × categorical).")
        categorical_cols = [
            name for name in self.common_columns
            if self.column_types.get(name) == "categorical"
        ]
        logger.info(f"  → Categorical columns: {categorical_cols}")
        if len(categorical_cols) < 2:
            logger.info("  → Fewer than 2 categorical columns: returning NaN.")
            return np.nan

        def joint_count_table(frame, row_col, col_col):
            # Two-way table of co-occurrence counts; unseen combinations become 0.
            counted = frame.groupby([row_col, col_col]).size().rename("count").reset_index()
            return counted.pivot(index=row_col, columns=col_col, values="count").fillna(0)

        pair_scores = []
        for c1, c2 in combinations(categorical_cols, 2):
            logger.info(f"  → Processing pair ({c1},{c2}).")
            ref_pair = self.ref_df[[c1, c2]].dropna().astype(str)
            cmp_pair = self.cmp_df[[c1, c2]].dropna().astype(str)

            if len(ref_pair) == 0 or len(cmp_pair) == 0:
                logger.info("    → One side empty after dropna: skipping pair.")
                continue

            ref_table = joint_count_table(ref_pair, c1, c2)
            cmp_table = joint_count_table(cmp_pair, c1, c2)

            # Put both tables on the same category grid before comparing cells.
            row_labels = ref_table.index.union(cmp_table.index)
            col_labels = ref_table.columns.union(cmp_table.columns)
            ref_grid = ref_table.reindex(index=row_labels, columns=col_labels, fill_value=0).values
            cmp_grid = cmp_table.reindex(index=row_labels, columns=col_labels, fill_value=0).values

            ref_prob = ref_grid / ref_grid.sum()
            cmp_prob = cmp_grid / cmp_grid.sum()

            tvd = 0.5 * np.abs(ref_prob - cmp_prob).sum()
            sim = 1.0 - tvd
            pair_scores.append(sim)
            logger.info(f"    → TVD for pair = {tvd:.6f}, similarity = {sim:.6f}")

        if not pair_scores:
            logger.info("  → No valid categorical pairs: returning NaN.")
            return np.nan
        mean_sim = float(np.mean(pair_scores))
        logger.info(f"  → Mean contingency similarity = {mean_sim:.6f}")
        return mean_sim


    def compute_distinguishability_auc(
        self,
        test_size: float = 0.3,
        random_state: int = 42,
        max_rows_per_side: int = 1000
    ) -> float:
        """
        Measure how easily a classifier can tell reference rows from comparison rows.

        Each side is subsampled to at most ``max_rows_per_side`` rows, numeric
        common columns are standardised (missing values imputed with 0),
        categorical common columns are one-hot encoded (top 50 categories per
        column, the rest collapsed to ``'OTHER'``), and a LogisticRegression
        is trained on a stratified train/test split.

        Parameters
        ----------
        test_size : float, default 0.3
            Fraction of rows held out for scoring.
        random_state : int, default 42
            Seed for subsampling, the split and the classifier.
        max_rows_per_side : int, default 1000
            Upper bound on the rows taken from each dataset.

        Returns
        -------
        float
            ROC AUC on the held-out rows. ``np.nan`` is returned when only
            one class is present, no numeric or categorical feature exists,
            the split fails, or the AUC cannot be computed.
        """
        logger.info(f"Computing distinguishability AUC (max_rows_per_side={max_rows_per_side}).")

        # Subsample first so only the rows that are actually used get copied.
        df_real = self.ref_df
        if len(df_real) > max_rows_per_side:
            df_real = df_real.sample(n=max_rows_per_side, random_state=random_state)
            logger.info(f"  → Subsampled real to {max_rows_per_side} rows.")
        df_syn = self.cmp_df
        if len(df_syn) > max_rows_per_side:
            df_syn = df_syn.sample(n=max_rows_per_side, random_state=random_state)
            logger.info(f"  → Subsampled synthetic to {max_rows_per_side} rows.")

        # assign() returns new frames, so the stored datasets are never mutated.
        df_all = pd.concat(
            [df_real.assign(_is_synthetic=0), df_syn.assign(_is_synthetic=1)],
            ignore_index=True,
        )
        logger.info(f"  → Concatenated dataset shape = {df_all.shape}")

        numeric_cols = [c for c in self.common_columns if self.column_types.get(c) == "numeric"]
        categorical_cols = [c for c in self.common_columns if self.column_types.get(c) == "categorical"]
        logger.info(f"  → Numeric columns: {numeric_cols}")
        logger.info(f"  → Categorical columns: {categorical_cols}")

        if numeric_cols:
            num_mat = df_all[numeric_cols].astype(float).fillna(0.0)
            scaler = StandardScaler()
            X_num = scaler.fit_transform(num_mat.values)
            logger.info(f"  → Numeric matrix shape = {X_num.shape}")
        else:
            X_num = np.empty((len(df_all), 0))
            logger.info("  → No numeric columns: X_num is empty array.")

        if categorical_cols:
            raw_cat = df_all[categorical_cols]
            # Record missing cells BEFORE stringifying: astype(str) turns NaN/None
            # into ordinary text, after which a fillna() would never match.
            df_cat = raw_cat.astype(str).mask(raw_cat.isna(), "NA")
            for col in categorical_cols:
                top_cats = df_cat[col].value_counts().nlargest(50).index
                df_cat[col] = df_cat[col].where(df_cat[col].isin(top_cats), other="OTHER")
                logger.info(f"  → For '{col}', collapsed rare categories to 'OTHER'.")
            encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
            X_cat = encoder.fit_transform(df_cat.values)
            logger.info(f"  → One-hot encoded shape = {X_cat.shape}")
        else:
            X_cat = np.empty((len(df_all), 0))
            logger.info("  → No categorical columns: X_cat is empty array.")

        X = np.hstack([X_num, X_cat])
        y = df_all["_is_synthetic"].values
        logger.info(f"  → Combined feature matrix shape = {X.shape}, labels length = {len(y)}")

        # Without any feature column the classifier cannot be fitted at all.
        if X.shape[1] == 0:
            logger.info("  → No usable feature columns: returning NaN.")
            return np.nan

        if len(np.unique(y)) < 2:
            logger.info("  → Only one class present: returning NaN.")
            return np.nan

        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, stratify=y, random_state=random_state
            )
            logger.info(f"  → Train/test split: X_train={X_train.shape}, X_test={X_test.shape}")
        except ValueError as e:
            logger.info(f"  → Train/test split error: {e}. Returning NaN.")
            return np.nan

        clf = LogisticRegression(
            solver="liblinear", max_iter=100, random_state=random_state
        )
        clf.fit(X_train, y_train)
        logger.info("  → LogisticRegression fitted on training data.")

        y_prob = clf.predict_proba(X_test)[:, 1]
        try:
            auc = roc_auc_score(y_test, y_prob)
            logger.info(f"  → Computed ROC AUC = {auc:.6f}")
        except ValueError as e:
            logger.info(f"  → ROC AUC computation error: {e}. Returning NaN.")
            return np.nan

        return float(auc)


    # ─── Text‐Bigram & OOV Helpers ────────────────────────────────────────────

    def _compute_text_bigram_jaccard(self, col: str) -> float:
        """
        Jaccard overlap between the sets of word bigrams found on each side.

        Texts are split on whitespace and every adjacent token pair counts as
        one bigram. Returns NaN when a side has no non-missing text or when
        neither side yields any bigram, and 0.0 when exactly one side yields
        none.
        """
        logger.info(f"Computing bigram‐level Jaccard for '{col}'.")
        ref_docs = self.ref_df[col].dropna().astype(str).tolist()
        cmp_docs = self.cmp_df[col].dropna().astype(str).tolist()
        if not (ref_docs and cmp_docs):
            logger.info("  → One side empty: returning NaN.")
            return np.nan

        def bigram_set(docs):
            # Pair each token with its successor; single-token docs add nothing.
            found = set()
            for doc in docs:
                words = doc.split()
                found.update(zip(words, words[1:]))
            return found

        ref_bigrams = bigram_set(ref_docs)
        cmp_bigrams = bigram_set(cmp_docs)
        if not ref_bigrams and not cmp_bigrams:
            logger.info("  → Both bigram sets empty: returning NaN.")
            return np.nan
        if not ref_bigrams or not cmp_bigrams:
            logger.info("  → One bigram set empty: returning 0.0.")
            return 0.0

        shared = ref_bigrams & cmp_bigrams
        everything = ref_bigrams | cmp_bigrams
        jaccard = len(shared) / len(everything)
        logger.info(f"  → Bigram Jaccard = {jaccard:.6f}")
        return jaccard


    def _compute_text_bigram_js(self, col: str, top_k: int = 5000) -> float:
        """
        Divergence between the bigram frequency distributions of both sides.

        Bigrams (adjacent whitespace-separated tokens) are counted per side;
        the ``top_k`` most frequent bigrams of the combined counts form the
        support, each side's counts over that support are normalised to sum
        to 1, and the two vectors are passed to ``Report._js_divergence``.
        Returns NaN when a side has no non-missing text or no counts on the
        selected support.
        """
        logger.info(f"Computing bigram‐level JS divergence for '{col}'.")
        ref_docs = self.ref_df[col].dropna().astype(str).tolist()
        cmp_docs = self.cmp_df[col].dropna().astype(str).tolist()
        if not (ref_docs and cmp_docs):
            logger.info("  → One side empty: returning NaN.")
            return np.nan

        def count_bigrams(docs):
            tally = Counter()
            for doc in docs:
                words = doc.split()
                tally.update(zip(words, words[1:]))
            return tally

        ref_counts = count_bigrams(ref_docs)
        cmp_counts = count_bigrams(cmp_docs)
        pooled = ref_counts + cmp_counts
        support = [bigram for bigram, _ in pooled.most_common(top_k)]
        logger.info(f"  → Top {len(support)} bigrams selected.")

        # Counter lookups return 0 for bigrams a side never produced.
        p = np.array([ref_counts[bigram] for bigram in support], dtype=float)
        q = np.array([cmp_counts[bigram] for bigram in support], dtype=float)
        p_total = p.sum()
        q_total = q.sum()
        if p_total == 0 or q_total == 0:
            logger.info("  → Zero counts on one side: returning NaN.")
            return np.nan

        return Report._js_divergence(p / p_total, q / q_total)


    def _compute_text_oov_rate(self, col: str) -> float:
        """
        Share of the reference vocabulary that never appears in the comparison texts.

        Vocabularies are the sets of whitespace-separated tokens on each side.
        Returns NaN when either side has no non-missing text or the reference
        vocabulary is empty.
        """
        logger.info(f"Computing OOV rate for '{col}'.")
        ref_docs = self.ref_df[col].dropna().astype(str).tolist()
        cmp_docs = self.cmp_df[col].dropna().astype(str).tolist()
        if not (ref_docs and cmp_docs):
            logger.info("  → One side empty: returning NaN.")
            return np.nan

        ref_vocab = {token for doc in ref_docs for token in doc.split()}
        cmp_vocab = {token for doc in cmp_docs for token in doc.split()}
        if not ref_vocab:
            logger.info("  → No tokens in real: returning NaN.")
            return np.nan

        unseen = ref_vocab.difference(cmp_vocab)
        oov_real = len(unseen) / len(ref_vocab)
        logger.info(f"  → OOV rate (real→syn) = {oov_real:.6f}")
        return oov_real


    def _compute_text_length_stats(self, col: str) -> dict:
        """
        Compare the average number of whitespace-separated tokens per text.

        Returns a dict with ``len_real`` (reference mean), ``len_syn``
        (comparison mean) and ``len_diff`` (comparison minus reference). All
        three are NaN when either side has no non-missing text.
        """
        logger.info(f"Computing text length stats for '{col}'.")

        def token_counts(series):
            # Missing entries are excluded; everything else is measured as text.
            return series.dropna().astype(str).str.split().map(len)

        ref_lengths = token_counts(self.ref_df[col])
        cmp_lengths = token_counts(self.cmp_df[col])
        if ref_lengths.empty or cmp_lengths.empty:
            logger.info("  → One side empty: returning NaN stats.")
            return {'len_real': np.nan, 'len_syn': np.nan, 'len_diff': np.nan}

        avg_r = ref_lengths.mean()
        avg_s = cmp_lengths.mean()
        diff = avg_s - avg_r
        logger.info(f"  → Avg length real={avg_r:.2f}, syn={avg_s:.2f}, diff={diff:.2f}")
        return {'len_real': avg_r, 'len_syn': avg_s, 'len_diff': diff}


    def _compute_text_topic_cosine(
        self,
        col: str,
        n_topics: int = 5,
        max_features: int = 1000,
        max_docs_per_side: int = 200
    ) -> float:
        """
        Cosine similarity between the average topic mixtures of both sides.

        Each side is subsampled to at most ``max_docs_per_side`` texts with a
        fixed seed, a TF–IDF matrix (English stop words removed, at most
        ``max_features`` terms) is built over the combined texts, and an NMF
        with ``n_topics`` components is fitted on it. The per-side mean of the
        document–topic weights is then compared by cosine.

        Returns NaN when a side has no non-missing text, when the TF–IDF or
        NMF fit raises, or when either mean topic vector has zero norm.
        """
        logger.info(f"Computing topic cosine for '{col}'.")
        ref_docs = self.ref_df[col].dropna().astype(str).tolist()
        cmp_docs = self.cmp_df[col].dropna().astype(str).tolist()
        if not (ref_docs and cmp_docs):
            logger.info("  → One side empty: returning NaN.")
            return np.nan

        # Fixed, side-specific seeds keep the subsample reproducible across runs.
        if len(ref_docs) > max_docs_per_side:
            picker = np.random.default_rng(seed=0)
            ref_docs = list(picker.choice(ref_docs, max_docs_per_side, replace=False))
            logger.info(f"  → Subsampled {len(ref_docs)} real docs.")
        if len(cmp_docs) > max_docs_per_side:
            picker = np.random.default_rng(seed=1)
            cmp_docs = list(picker.choice(cmp_docs, max_docs_per_side, replace=False))
            logger.info(f"  → Subsampled {len(cmp_docs)} synthetic docs.")

        all_docs = ref_docs + cmp_docs
        logger.info(f"  → TF–IDF on {len(all_docs)} docs (max_features={max_features}).")
        tfidf = TfidfVectorizer(max_features=max_features, stop_words='english')
        try:
            doc_term = tfidf.fit_transform(all_docs)
        except Exception as err:
            logger.info(f"  → TF–IDF fit error: {err}. Returning NaN.")
            return np.nan

        topic_model = NMF(n_components=n_topics, random_state=0, init='nndsvda', max_iter=100)
        logger.info(f"  → Fitting NMF (n_topics={n_topics}) on TF–IDF matrix.")
        try:
            doc_topic = topic_model.fit_transform(doc_term)
        except Exception as err:
            logger.info(f"  → NMF fit error: {err}. Returning NaN.")
            return np.nan

        # Reference docs come first in the combined corpus, comparison docs after.
        split_at = len(ref_docs)
        ref_profile = doc_topic[:split_at, :].mean(axis=0)
        cmp_profile = doc_topic[split_at:, :].mean(axis=0)

        norm_product = np.linalg.norm(ref_profile) * np.linalg.norm(cmp_profile)
        if norm_product > 0:
            cosine = float(np.dot(ref_profile, cmp_profile) / norm_product)
        else:
            cosine = np.nan
        logger.info(f"  → Topic cosine = {cosine:.6f}")
        return cosine


## Tab (Abstract)

In [None]:
class Tab(ABC):
    """
    Abstract base for a UI “tab.”

    Stores the two shared state objects handed to every tab. Concrete
    subclasses must provide both build_ui() and register_callbacks().
    """

    def __init__(self, report_state, common_cols_state):
        # Keep references to the shared state objects for use by subclasses.
        self.report_state, self.common_cols_state = report_state, common_cols_state

    @abstractmethod
    def build_ui(self):
        """Create the tab's components (must be overridden)."""

    @abstractmethod
    def register_callbacks(self):
        """Attach event handlers to the tab's components (must be overridden)."""

In [None]:
class GeneralStatsTab(Tab):
    """
    “General Stats” tab, organised into accordions for:
      • Missingness
      • Numeric Distributions
      • Categorical Frequencies
      • Text Metadata

    The tab also owns the data-loading controls: four path inputs, a load
    button, a status line, two preview tables and the column dropdown that
    drives the accordions.
    """
    def __init__(self, report_state, common_cols_state):
        super().__init__(report_state, common_cols_state)

    def build_ui(self):
        """Create every component of the tab; all result widgets start hidden."""
        with gr.TabItem("General Stats"):
            # ---- Data‐Load Controls ----
            with gr.Row():
                self.ref_path_input = gr.Textbox(
                    label="Reference Dataset Path", placeholder="e.g. data/reference.csv"
                )
                self.cmp_path_input = gr.Textbox(
                    label="Comparison Dataset Path", placeholder="e.g. data/comparison.csv"
                )
            with gr.Row():
                self.ref_schema_input = gr.Textbox(
                    label="Reference Schema Path (JSON/YAML)", placeholder="e.g. data/ref_schema.json"
                )
                self.cmp_schema_input = gr.Textbox(
                    label="Comparison Schema Path (JSON/YAML)", placeholder="e.g. data/cmp_schema.json"
                )

            self.load_button = gr.Button("Load Data")
            self.load_status = gr.Markdown("Awaiting data load...")

            # ---- Side‐by‐Side Dataframes of Common Columns ----
            with gr.Row():
                self.ref_table = gr.Dataframe(
                    value=None,
                    label="Reference Dataset (Common Columns Only)",
                    interactive=False
                )
                self.cmp_table = gr.Dataframe(
                    value=None,
                    label="Comparison Dataset (Common Columns Only)",
                    interactive=False
                )

            # ---- Common‐Columns Dropdown ----
            self.column_dropdown = gr.Dropdown(
                label="Select Common Column for Univariate Comparison",
                choices=[],
                interactive=False
            )

            # ---- Accordions & Placeholders ----

            # 1) Missingness Accordion
            self.missing_accordion = gr.Accordion("Missingness", visible=False, open=True)
            with self.missing_accordion:
                self.missing_md = gr.Markdown("", visible=False)

            # 2) Numeric Distributions Accordion
            self.numeric_accordion = gr.Accordion("Numeric Distributions", visible=False, open=False)
            with self.numeric_accordion:
                self.numeric_md        = gr.Markdown("", visible=False)
                self.numeric_cdf_plot  = gr.Plot(visible=False)
                self.numeric_hist_plot = gr.Plot(visible=False)
                self.numeric_kde_plot  = gr.Plot(visible=False)

            # 3) Categorical Frequencies Accordion
            self.categorical_accordion = gr.Accordion("Categorical Frequencies", visible=False, open=False)
            with self.categorical_accordion:
                self.categorical_md       = gr.Markdown("", visible=False)
                self.categorical_bar_plot = gr.Plot(visible=False)

            # 4) Text Metadata Accordion
            self.text_accordion = gr.Accordion("Text Metadata", visible=False, open=False)
            with self.text_accordion:
                self.text_md            = gr.Markdown("", visible=False)
                self.text_token_plot    = gr.Plot(visible=False)
                self.text_len_plot      = gr.Plot(visible=False)
                self.text_samples_real  = gr.Textbox(label="Sample Real Text", visible=False, interactive=False)
                self.text_samples_synth = gr.Textbox(label="Sample Synthetic Text", visible=False, interactive=False)

    # ---- Markdown builders used by the dropdown callback ----

    @staticmethod
    def _is_present(*values) -> bool:
        """Return True when every value is neither None nor NaN."""
        return all(v is not None and not pd.isna(v) for v in values)

    def _bullet(self, label, value, suffix=""):
        """Format one ``- label: **value**`` line, or '' when the value is missing."""
        if not self._is_present(value):
            return ""
        return f"- {label}: **{value:.4f}**{suffix}\n"

    @staticmethod
    def _missingness_markdown(report, col_name):
        """Summarise the share of missing values of one column on both sides."""
        mr_real = report.ref_df[col_name].isna().mean()
        mr_synth = report.cmp_df[col_name].isna().mean()
        # Signed difference (synthetic minus real) so the '+'/'-' shown is meaningful.
        mr_diff = mr_synth - mr_real
        return (
            "### Missingness (Real vs Synthetic)\n\n"
            f"- Real % missing = {mr_real*100:.1f}%  |  "
            f"Synth % missing = {mr_synth*100:.1f}%  |  "
            f"Δ = {mr_diff*100:+.1f} pts\n"
        )

    def _numeric_markdown(self, col_metrics):
        """Render the numeric metrics of one column as a markdown bullet list."""
        md = "### Numeric Distributions (Real vs Synthetic)\n\n"
        md += self._bullet("KS Similarity (1 – D)", col_metrics.get("ks_similarity"))

        w_raw = col_metrics.get("wasserstein_raw")
        if self._is_present(w_raw):
            w_norm = col_metrics.get("wasserstein_norm")
            norm_txt = f"{w_norm:.4f}" if self._is_present(w_norm) else "N/A"
            md += f"- Wasserstein Distance: **{w_raw:.4f} units**  (normalised = {norm_txt})\n"

        # (label, metric-key stem, key of the relative difference, its format)
        summaries = (
            ("Mean", "mean", "mean_pct_of_range", "{:+.2f}% of real‐range"),
            ("Median", "median", "median_pct_of_range", "{:+.2f}% of real‐range"),
            ("Std Dev", "std", "std_pct_of_std", "{:+.1f}% of real‐std"),
        )
        for label, stem, pct_key, pct_fmt in summaries:
            real_v = col_metrics.get(f"{stem}_real")
            syn_v = col_metrics.get(f"{stem}_syn")
            if not self._is_present(real_v, syn_v):
                continue
            diff_v = col_metrics.get(f"{stem}_diff")
            if not self._is_present(diff_v):
                diff_v = syn_v - real_v
            pct_v = col_metrics.get(pct_key)
            pct_txt = f" ({pct_fmt.format(pct_v * 100)})" if self._is_present(pct_v) else ""
            md += (
                f"- {label} → Real: **{real_v:.4f}**, Syn: **{syn_v:.4f}**, "
                f"Δ = **{diff_v:+.4f}**{pct_txt}\n"
            )

        md += self._bullet("Range Coverage", col_metrics.get("range_coverage"), " (1.0 = perfect)")
        return md

    def _categorical_markdown(self, col_metrics):
        """Render the categorical metrics of one column as a markdown bullet list."""
        md = "### Categorical Frequencies (Real vs Synthetic)\n\n"
        md += self._bullet("TVD Similarity", col_metrics.get("tvd_similarity"))
        md += self._bullet("Category Coverage", col_metrics.get("category_coverage"))
        return md

    def _text_markdown(self, col_metrics):
        """Render the text metrics of one column as a markdown bullet list."""
        md = "### Text Metadata (Real vs Synthetic)\n\n"
        md += self._bullet("Token‐JS Divergence (unigrams)", col_metrics.get("text_tok_js"))
        md += self._bullet("Bigram Jaccard", col_metrics.get("text_bigram_jaccard"))
        md += self._bullet("Bigram‐JS Divergence", col_metrics.get("text_bigram_js"))
        md += self._bullet("TF–IDF Cosine Similarity", col_metrics.get("text_tfidf_cosine"))
        md += self._bullet("Vocabulary Jaccard (unigrams)", col_metrics.get("text_vocab_jaccard"))

        oov_rate = col_metrics.get("text_oov_rate")
        if self._is_present(oov_rate):
            md += f"- OOV Rate (real→synth): **{oov_rate*100:.1f}%**\n"

        md += self._bullet("Topic‐Distribution Cosine", col_metrics.get("text_topic_cosine"))

        len_r = col_metrics.get("len_real")
        len_s = col_metrics.get("len_syn")
        if self._is_present(len_r, len_s):
            len_d = col_metrics.get("len_diff")
            if not self._is_present(len_d):
                len_d = len_s - len_r
            md += (
                f"- Average Length → Real: **{len_r:.1f}** tokens, "
                f"Synthetic: **{len_s:.1f}**, Δ = **{len_d:+.1f} tokens**\n"
            )

        md += self._bullet("Embedding Cosine", col_metrics.get("text_emb_cosine"))
        md += self._bullet("Embedding MMD", col_metrics.get("text_emb_mmd"))
        return md

    def register_callbacks(self):
        """Wire the load button, the report-state hook and the column dropdown."""

        # 1. Load Data callback
        def load_data(ref_path, cmp_path, ref_schema_path, cmp_schema_path):
            try:
                ref_df, cmp_df, common_columns, ref_schema, cmp_schema = (
                    DataManager()
                    .get_ref_and_cmp_data(
                        ref_path, cmp_path, ref_schema_path, cmp_schema_path
                    )
                )
                report = Report(ref_df, cmp_df, common_columns, ref_schema, cmp_schema)

                ref_sub = ref_df.loc[:, common_columns]
                cmp_sub = cmp_df.loc[:, common_columns]

                return (
                    report,
                    common_columns,
                    ref_sub,
                    cmp_sub,
                    f"✅ Data loaded successfully. Found {len(common_columns)} common columns."
                )
            except Exception as e:
                # UI boundary: any failure is reported in the status line
                # instead of crashing the app, and all outputs are reset.
                err = f"❌ Error loading data: {e}"
                return (None, None, None, None, err)

        self.load_button.click(
            load_data,
            inputs=[
                self.ref_path_input,
                self.cmp_path_input,
                self.ref_schema_input,
                self.cmp_schema_input,
            ],
            outputs=[
                self.report_state,
                self.common_cols_state,
                self.ref_table,
                self.cmp_table,
                self.load_status,
            ]
        )

        # 2. Enable Dataframes & Dropdown after data loads
        def enable_components(report, common_columns):
            ready = report is not None and bool(common_columns)
            return (
                gr.update(interactive=ready),
                gr.update(interactive=ready),
                gr.update(choices=common_columns if ready else [], interactive=ready),
            )

        self.report_state.change(
            enable_components,
            inputs=[self.report_state, self.common_cols_state],
            outputs=[self.ref_table, self.cmp_table, self.column_dropdown]
        )

        # 3. Show per‐column stats & plots when dropdown changes.
        #    Each component appears exactly once in `stats_outputs`; its value
        #    and visibility travel together in a single gr.update().
        stats_outputs = [
            # Missingness
            self.missing_accordion,
            self.missing_md,
            # Numeric
            self.numeric_accordion,
            self.numeric_md,
            self.numeric_cdf_plot,
            self.numeric_hist_plot,
            self.numeric_kde_plot,
            # Categorical
            self.categorical_accordion,
            self.categorical_md,
            self.categorical_bar_plot,
            # Text
            self.text_accordion,
            self.text_md,
            self.text_token_plot,
            self.text_len_plot,
            self.text_samples_real,
            self.text_samples_synth,
        ]

        def panel(visible):
            return gr.update(visible=visible)

        def content(value, visible):
            return gr.update(value=value, visible=visible)

        def show_column_stats(report, col_name):
            # Hide everything when there is no report, no selection, or the
            # selection is stale (e.g. left over from a previously loaded dataset).
            selectable = (
                report is not None
                and bool(col_name)
                and col_name in report.ref_df.columns
                and col_name in report.cmp_df.columns
            )
            if not selectable:
                return (
                    panel(False), content("", False),
                    panel(False), content("", False),
                    content(None, False), content(None, False), content(None, False),
                    panel(False), content("", False), content(None, False),
                    panel(False), content("", False),
                    content(None, False), content(None, False),
                    content("", False), content("", False),
                )

            # Compute univariate metrics
            ctype = report.column_types.get(col_name)
            col_metrics = report.compute_column(col_name)

            # (A) Missingness is shown for every column type.
            missing_md = self._missingness_markdown(report, col_name)

            # (B) Numeric
            show_numeric = (ctype == "numeric")
            numeric_md, cdf_fig, hist_fig, kde_fig = "", None, None, None
            if show_numeric:
                numeric_md = self._numeric_markdown(col_metrics)
                cdf_fig = report.plot_column_cdf(col_name)
                hist_fig = report.plot_column_histogram(col_name)
                kde_fig = report.plot_column_kde(col_name)

            # (C) Categorical
            show_cat = (ctype == "categorical")
            cat_md, cat_fig = "", None
            if show_cat:
                cat_md = self._categorical_markdown(col_metrics)
                cat_fig = report.plot_category_bar(col_name)

            # (D) Text
            show_text = (ctype == "text")
            text_md, tok_fig, len_fig = "", None, None
            samples_r, samples_s = "", ""
            if show_text:
                text_md = self._text_markdown(col_metrics)
                tok_fig = report.plot_text_top_tokens(col_name)
                len_fig = report.plot_text_length_hist(col_name)

                # Sample up to 3 real vs. synthetic texts
                real_texts = report.ref_df[col_name].dropna().astype(str)
                synth_texts = report.cmp_df[col_name].dropna().astype(str)
                if len(real_texts) > 0 and len(synth_texts) > 0:
                    samples_r = "\n\n".join(real_texts.sample(min(3, len(real_texts))).tolist())
                    samples_s = "\n\n".join(synth_texts.sample(min(3, len(synth_texts))).tolist())

            # One update per component, in the same order as `stats_outputs`.
            return (
                # Missingness
                panel(True),
                content(missing_md, True),
                # Numeric
                panel(show_numeric),
                content(numeric_md, show_numeric),
                content(cdf_fig, show_numeric),
                content(hist_fig, show_numeric),
                content(kde_fig, show_numeric),
                # Categorical
                panel(show_cat),
                content(cat_md, show_cat),
                content(cat_fig, show_cat),
                # Text
                panel(show_text),
                content(text_md, show_text),
                content(tok_fig, show_text),
                content(len_fig, show_text),
                content(samples_r, show_text and bool(samples_r)),
                content(samples_s, show_text and bool(samples_s)),
            )

        self.column_dropdown.change(
            show_column_stats,
            inputs=[self.report_state, self.column_dropdown],
            outputs=stats_outputs
        )


In [None]:
class MultivariateTab(Tab):
    """
    "Multivariate Analysis" tab.

    On button click it asks the report for:
      1. PCA eigenvector-cosine metrics (plus a PCA scatter figure)
      2. Correlation Similarity (numeric × numeric)
      3. Contingency Similarity (categorical × categorical)
      4. Distinguishability AUC (classifier test)

    Each result is rendered as Markdown; a metric that comes back as NaN is
    reported as "N/A" and its component is hidden.
    """
    def __init__(self, report_state, common_cols_state):
        super().__init__(report_state, common_cols_state)

    def build_ui(self):
        """Create the slider, the compute button and the (initially hidden) outputs."""
        with gr.TabItem("Multivariate Analysis"):
            # Slider & button for PCA
            self.ncomp_slider = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=2,
                label="Number of PCA Components to Compare"
            )
            self.compute_button = gr.Button("Compute PCA & Multivariate Metrics")

            # PCA outputs
            self.pca_metrics_output = gr.Markdown("", visible=False)
            self.pca_scatter_plot = gr.Plot(visible=False)

            # Multivariate relationship outputs
            self.corr_sim_output = gr.Markdown("", visible=False, label="Correlation Similarity")
            self.contingency_sim_output = gr.Markdown("", visible=False, label="Contingency Similarity")
            self.distinguish_auc_output = gr.Markdown("", visible=False, label="Distinguishability (ROC AUC)")

    def register_callbacks(self):
        """Wire the report state and the compute button to their handlers."""

        # 1) Enable controls once report is ready
        def enable_controls(report):
            """Return interactivity updates for (slider, button) based on report presence."""
            logger.info("MultivariateTab: enable_controls called.")
            if report is not None:
                logger.info("  → Enabling PCA & multivariate controls.")
                return gr.update(interactive=True), gr.update(interactive=True)
            else:
                logger.info("  → Disabling PCA & multivariate controls.")
                return gr.update(interactive=False), gr.update(interactive=False)

        self.report_state.change(
            enable_controls,
            inputs=[self.report_state],
            outputs=[self.ncomp_slider, self.compute_button]
        )

        # 2) When the button is clicked:
        def compute_multivariate(report, n_components):
            """
            Compute every multivariate metric and return 10 values.

            The tuple layout must match the `outputs` list bound below:
            PCA markdown value, scatter value, PCA markdown visibility,
            scatter visibility, then (value, visibility) for correlation,
            contingency and AUC.
            """
            logger.info(f"MultivariateTab: compute_multivariate called (n_components={n_components}).")
            if report is None:
                logger.info("  → No report available: clearing all outputs.")
                # Exactly 10 values, in the same slot order as the success path.
                return (
                    "",                        # 1: PCA Markdown
                    None,                      # 2: PCA scatter
                    gr.update(visible=False),  # 3: PCA Markdown visibility
                    gr.update(visible=False),  # 4: PCA scatter visibility
                    "",                        # 5: correlation Markdown
                    gr.update(visible=False),  # 6
                    "",                        # 7: contingency Markdown
                    gr.update(visible=False),  # 8
                    "",                        # 9: AUC Markdown
                    gr.update(visible=False),  # 10
                )

            # --- (A) PCA eigenvector cosines ---
            pca_res = report.compute_pca_eigen_cosines(n_components=n_components)
            cosines = pca_res['cosine_similarities']
            ev_ref = pca_res['explained_variance_ref']
            ev_cmp = pca_res['explained_variance_cmp']
            weighted_cos = pca_res['weighted_cosine']

            md_pca = f"## PCA Eigenvector Cosines (first {len(cosines)} components)\n\n"
            for i, cos in enumerate(cosines, start=1):
                md_pca += (
                    f"- PC{i} cosine: **{cos:.4f}**  "
                    f"(exp_var_ref: {ev_ref[i-1]:.4f}, exp_var_cmp: {ev_cmp[i-1]:.4f})\n"
                )
            md_pca += f"\n**Weighted Cosine** (weights = ref exp_var_ratio): **{weighted_cos:.4f}**\n"
            logger.info("  → PCA Markdown prepared.")

            # The scatter is always requested with at least 2 components, but it
            # is only shown when at least 2 cosines were actually returned.
            scatter_fig = report.plot_pca_scatter(n_components=max(n_components, 2))
            show_scatter = len(cosines) >= 2
            if show_scatter:
                logger.info("  → PCA scatter figure prepared.")
            else:
                logger.info("  → Not enough PCA components for scatter.")

            # --- (B) Correlation Similarity ---
            corr_sim = report.compute_correlation_similarity()
            show_corr = not np.isnan(corr_sim)
            if show_corr:
                md_corr = f"**Correlation Similarity (numeric × numeric)**: **{corr_sim:.4f}**\n"
                logger.info(f"  → Correlation similarity = {corr_sim:.6f}")
            else:
                md_corr = "Correlation Similarity: **N/A** (need ≥ 2 numeric columns)"
                logger.info("  → Correlation similarity: N/A.")

            # --- (C) Contingency Similarity ---
            cont_sim = report.compute_contingency_similarity()
            show_cont = not np.isnan(cont_sim)
            if show_cont:
                md_cont = f"**Contingency Similarity (cat × cat)**: **{cont_sim:.4f}**\n"
                logger.info(f"  → Contingency similarity = {cont_sim:.6f}")
            else:
                md_cont = "Contingency Similarity: **N/A** (need ≥ 2 categorical columns)"
                logger.info("  → Contingency similarity: N/A.")

            # --- (D) Distinguishability AUC ---
            auc = report.compute_distinguishability_auc()
            show_auc = not np.isnan(auc)
            if show_auc:
                md_auc = f"**Distinguishability (ROC AUC)**: **{auc:.4f}**  \n"
                md_auc += (
                    "- AUC near 0.5 ⇒ synthetic is hard to distinguish from real  \n"
                    "- AUC near 1.0 ⇒ classifier easily separates real vs. synthetic\n"
                )
                logger.info(f"  → Distinguishability AUC = {auc:.6f}")
            else:
                md_auc = "Distinguishability AUC: **N/A** (insufficient mix of numeric/categorical data)"
                logger.info("  → Distinguishability AUC: N/A.")

            return (
                # PCA outputs
                md_pca,                            # 1: PCA Markdown
                scatter_fig,                       # 2: PCA scatter
                gr.update(visible=True),           # 3: PCA Markdown visibility
                gr.update(visible=show_scatter),   # 4: PCA scatter visibility
                # Correlation similarity
                md_corr,                           # 5: correlation Markdown
                gr.update(visible=show_corr),      # 6
                # Contingency similarity
                md_cont,                           # 7: contingency Markdown
                gr.update(visible=show_cont),      # 8
                # Distinguishability AUC
                md_auc,                            # 9: AUC Markdown
                gr.update(visible=show_auc),       # 10
            )

        # Bind the button to 10 outputs: each component appears twice, once for
        # its value and once for its visibility update.
        self.compute_button.click(
            compute_multivariate,
            inputs=[self.report_state, self.ncomp_slider],
            outputs=[
                self.pca_metrics_output,      # 1 (value)
                self.pca_scatter_plot,        # 2 (value)
                self.pca_metrics_output,      # 3 (visibility)
                self.pca_scatter_plot,        # 4 (visibility)
                self.corr_sim_output,         # 5 (value)
                self.corr_sim_output,         # 6 (visibility)
                self.contingency_sim_output,  # 7 (value)
                self.contingency_sim_output,  # 8 (visibility)
                self.distinguish_auc_output,  # 9 (value)
                self.distinguish_auc_output   # 10 (visibility)
            ]
        )


In [None]:
class GlobalSummaryTab(Tab):
    """
    "Global Summary" tab (Step 5):
      - Button to "Compute Global Summary"
      - Dataframe: one row per common column, listing all metrics
      - Bar chart: top-5 most divergent columns by primary_divergence
    """

    # Single source of truth for the table layout. The same ordered tuple feeds
    # the Dataframe headers and the per-row dicts, so the list-of-lists handed
    # to the table can never drift out of step with its headers.
    SUMMARY_COLUMNS = (
        "column",
        "dtype",

        # Missing-rate (real, syn, Δ)
        "missing_rate_real",
        "missing_rate_syn",
        "missing_rate_diff",

        # Numeric: KS & Wasserstein
        "ks_similarity",
        "wasserstein_raw",
        "wasserstein_norm",

        # Numeric: mean/median/std
        "mean_real",
        "mean_syn",
        "mean_diff",
        "mean_pct_of_range",
        "median_real",
        "median_syn",
        "median_diff",
        "median_pct_of_range",
        "std_real",
        "std_syn",
        "std_diff",
        "std_pct_of_std",

        # Numeric: range coverage
        "range_coverage",

        # Categorical: TVD & coverage
        "tvd_similarity",
        "category_coverage",

        # Text: token-JS, bigram Jaccard/JS
        "text_tok_js",
        "text_bigram_jaccard",
        "text_bigram_js",

        # Text: TF-IDF cosine, vocab Jaccard, OOV rate
        "text_tfidf_cosine",
        "text_vocab_jaccard",
        "text_oov_rate",

        # Text: topic cosine
        "text_topic_cosine",

        # Text: length (real, syn, Δ)
        "len_real",
        "len_syn",
        "len_diff",

        # Text: embedding metrics
        "text_emb_cosine",
        "text_emb_mmd",

        # Global multivariate (same value for every row)
        "correlation_similarity",
        "contingency_similarity",
        "distinguishability_auc",

        # The chosen "primary divergence" for ranking
        "primary_divergence",
    )

    # How many columns the divergence bar chart shows at most.
    TOP_N = 5

    def __init__(self, report_state, common_cols_state):
        super().__init__(report_state, common_cols_state)

    def build_ui(self):
        """Create the compute button, the metrics table and the hidden bar chart."""
        with gr.TabItem("Global Summary"):
            self.compute_summary_btn = gr.Button("Compute Global Summary")

            self.summary_table = gr.Dataframe(
                headers=list(self.SUMMARY_COLUMNS),
                row_count="dynamic",
                interactive=False,
                label="Per-Column Metrics & Divergence"
            )

            # Bar chart for top-5 columns by primary_divergence
            self.divergence_bar = gr.Plot(visible=False)

    @staticmethod
    def _primary_divergence(row):
        """
        Pick the single ranking score for one row (a mapping keyed by
        SUMMARY_COLUMNS).

        - numeric: `wasserstein_norm`, falling back to `1 - ks_similarity`
        - categorical: `1 - tvd_similarity`
        - text: `1 - text_tfidf_cosine`, falling back to `1 - text_emb_cosine`
        - any other dtype, or all inputs missing: NaN
        """
        def complement(similarity):
            return np.nan if pd.isna(similarity) else 1.0 - similarity

        dtype = row["dtype"]
        if dtype == "numeric":
            wnorm = row["wasserstein_norm"]
            if not pd.isna(wnorm):
                return wnorm
            return complement(row["ks_similarity"])
        if dtype == "categorical":
            return complement(row["tvd_similarity"])
        if dtype == "text":
            tfidf = row["text_tfidf_cosine"]
            if not pd.isna(tfidf):
                return 1.0 - tfidf
            return complement(row["text_emb_cosine"])
        return np.nan

    def register_callbacks(self):
        """Wire the report state and the compute button to their handlers."""

        # Enable the "Compute" button only once report is loaded
        def enable_button(report):
            logger.info("GlobalSummaryTab: enable_button called.")
            return gr.update(interactive=report is not None)

        self.report_state.change(
            enable_button,
            inputs=[self.report_state],
            outputs=[self.compute_summary_btn]
        )

        # When "Compute Global Summary" is clicked:
        def compute_global_summary(report):
            """
            Build the per-column metrics table and the top-divergence chart.

            Returns (table rows as list of lists, plotly figure or None,
            visibility update for the chart).
            """
            logger.info("GlobalSummaryTab: compute_global_summary called.")
            if report is None:
                logger.info("  → No report: clearing table & hiding bar chart.")
                return ([], None, gr.update(visible=False))

            rows = []
            # 1) Gather univariate metrics per column
            for col in report.common_columns:
                logger.info(f"  → Computing univariate for column '{col}'.")
                col_info = report.compute_column(col)

                # Start with every metric missing, then copy over the keys this
                # column reported. `col_info` is only read, never mutated, so a
                # report that hands back a cached dict is not corrupted.
                row = dict.fromkeys(self.SUMMARY_COLUMNS, np.nan)
                row["column"] = col
                row.update({k: v for k, v in col_info.items() if k in row})
                row["primary_divergence"] = self._primary_divergence(row)
                rows.append(row)

            # Passing `columns=` pins the column order to the table headers and
            # keeps the frame well-formed even when there are no rows at all.
            df_metrics = pd.DataFrame(rows, columns=list(self.SUMMARY_COLUMNS))
            logger.info("  → DataFrame of univariate metrics constructed.")
            logger.info("  → Computed primary_divergence for each column.")

            # 2) Compute global multivariate metrics once
            corr_sim = report.compute_correlation_similarity()
            cont_sim = report.compute_contingency_similarity()
            auc_sim = report.compute_distinguishability_auc()
            logger.info(f"  → Multivariate: corr={corr_sim}, cont={cont_sim}, auc={auc_sim}")

            # 3) Broadcast those values to every row
            df_metrics["correlation_similarity"] = corr_sim
            df_metrics["contingency_similarity"] = cont_sim
            df_metrics["distinguishability_auc"] = auc_sim

            # 4) Rank by primary_divergence (descending). Columns without a
            #    score are left out of the chart rather than padding the "top N".
            df_top = (
                df_metrics.dropna(subset=["primary_divergence"])
                .sort_values(by="primary_divergence", ascending=False)
                .head(self.TOP_N)
            )
            top_n = len(df_top)
            logger.info(f"  → Top {top_n} columns by primary_divergence: {df_top['column'].tolist()}")

            # 5) Build bar chart for top_n columns
            scores = df_top["primary_divergence"].astype(float)
            fig = go.Figure()
            fig.add_trace(
                go.Bar(
                    x=df_top["column"],
                    y=scores,
                    text=scores.round(4),
                    textposition="auto",
                    name="Primary Divergence"
                )
            )
            fig.update_layout(
                title=f"Top {top_n} Divergent Columns (by Primary Divergence)",
                xaxis_title="Column",
                yaxis_title="Divergence",
                margin=dict(l=40, r=40, t=50, b=40)
            )
            logger.info("  → Bar chart for top divergences ready.")

            # 6) Return the full table; show the chart only if it has bars
            return (
                df_metrics.values.tolist(),
                fig,
                gr.update(visible=top_n > 0)
            )

        self.compute_summary_btn.click(
            compute_global_summary,
            inputs=[self.report_state],
            outputs=[
                self.summary_table,   # table data (list of lists)
                self.divergence_bar,  # bar chart figure
                self.divergence_bar   # bar chart visibility toggle
            ]
        )

In [None]:


# ─── “Build & Launch” with Three Tabs ─────────────────────────────────

with gr.Blocks() as demo:
    logger.info("Launching Gradio Blocks app with three tabs.")

    # Shared state holders; they have to be created inside the Blocks context.
    report_state = gr.State(None)
    common_cols_state = gr.State(None)

    # One instance per tab, every one wired to the same pair of states.
    general_tab = GeneralStatsTab(report_state, common_cols_state)
    multivar_tab = MultivariateTab(report_state, common_cols_state)
    global_tab = GlobalSummaryTab(report_state, common_cols_state)
    _tabs = (general_tab, multivar_tab, global_tab)

    # All tab UIs live under a single Tabs container, in display order.
    with gr.Tabs():
        for _tab in _tabs:
            _tab.build_ui()

    # Callbacks are attached only after every component exists.
    for _tab in _tabs:
        _tab.register_callbacks()

# share=True exposes a public URL (needed in Colab); debug=True keeps the cell
# running so errors and logs stay visible.
demo.launch(share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://6ed079be3f3bef6124.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://d32197b955f5131ecf.gradio.live
Killing tunnel 127.0.0.1:7861 <> https://6ffb363646dfcc34a0.gradio.live
Killing tunnel 127.0.0.1:7862 <> https://6d806f1c19404b72ec.gradio.live
Killing tunnel 127.0.0.1:7862 <> https://6ed079be3f3bef6124.gradio.live


