In [None]:
import os
import re
import json
import base64
import logging
import tempfile
import itertools
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import gradio as gr

# PDF/table tools (you already used these in your original script)
from pdf2image import convert_from_path

# OpenAI client - placeholder, keep your existing initialization
from openai import OpenAI

# --- CONFIG --- values are read from environment variables
# Model/deployment name passed as `model=` to the chat-completion calls;
# falls back to "gpt-4o" when DEPLOYMENT_NAME is not set.
AZURE_DEPLOYMENT = os.getenv("DEPLOYMENT_NAME", "gpt-4o")
# API key taken from the environment; this is None when OPENAI_API_KEY is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # never hard-code the key here

# Module-level client shared by the LLM calls in extract_ai().
client = OpenAI(api_key=OPENAI_API_KEY)

# ------------------------
# Utilities
# ------------------------
def encode_image_to_base64(path: str) -> str:
    """Read the file at ``path`` and return its raw bytes as base64 text."""
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()

def make_pdf_iframe(path: str) -> str:
    """Return an HTML ``<iframe>`` that embeds the PDF at ``path`` as a data URI.

    Args:
        path: Filesystem path of the PDF to embed.

    Returns:
        An iframe tag (600x800, no border) whose ``src`` is the
        base64-encoded file content.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(path, "rb") as fh:
        b64 = base64.b64encode(fh.read()).decode()
    return f'<iframe src="data:application/pdf;base64,{b64}" width="600" height="800" style="border:none;"></iframe>'

# "mean ± sd" (also "+/-"): two signed decimal numbers around a plus-minus sign.
RE_MEAN_SD = re.compile(r'^\s*([+-]?\d+(?:\.\d+)?)\s*(?:±|\+/-|\u00B1)\s*([+-]?\d+(?:\.\d+)?)\s*$')
# "mean (sd)": a signed decimal number followed by a second one in parentheses.
RE_PAREN = re.compile(r'^\s*([+-]?\d+(?:\.\d+)?)\s*\(\s*([+-]?\d+(?:\.\d+)?)\s*\)\s*$')

def parse_mean_sd_cell(cell):
    """Return ``(mean, sd)`` floats for 'mean ± sd' or 'mean(sd)' text, else None.

    Missing values (as judged by ``pd.isna``) and text matching neither
    pattern yield None.
    """
    if pd.isna(cell):
        return None
    text = str(cell).strip()
    # The plus-minus form is tried first, then the parenthesised form.
    for pattern in (RE_MEAN_SD, RE_PAREN):
        hit = pattern.match(text)
        if hit is not None:
            mean_txt, sd_txt = hit.groups()
            return float(mean_txt), float(sd_txt)
    return None

# ------------------------
# f_range and ANOVA-from-summaries (simplified & robust)
# (We implement a version suited to 1D (k groups) or 2D (a x b) cells.)
# ------------------------
def _anova_oneway_from_summaries(m, s, n):
    m = np.asarray(m, dtype=float)
    s = np.asarray(s, dtype=float)
    n = np.asarray(n, dtype=float)
    k = len(m)
    N = np.sum(n)
    if k < 2:
        return np.nan, np.nan, np.nan
    grand_mean = np.sum(n * m) / N
    ss_between = np.sum(n * (m - grand_mean) ** 2)
    ss_within = np.sum((n - 1.0) * (s ** 2))
    df_between = k - 1
    df_within = int(np.sum(n) - k)
    if df_between <= 0 or df_within <= 0:
        return np.nan, df_between, df_within
    ms_between = ss_between / df_between
    ms_within = ss_within / df_within
    F = ms_between / ms_within if ms_within != 0 else np.nan
    return F, df_between, df_within

def _anova_twoway_from_summaries(m, s, n):
    m = np.asarray(m, dtype=float)
    s = np.asarray(s, dtype=float)
    n = np.asarray(n, dtype=float)
    if m.ndim != 2:
        raise ValueError("m must be 2D for two-way ANOVA")
    a, b = m.shape
    N = np.sum(n)
    grand_mean = np.sum(n * m) / N
    ss_within = np.sum((n - 1.0) * (s ** 2))
    ss_treatment = np.sum(n * (m - grand_mean) ** 2)
    n_row = np.sum(n, axis=1)
    mean_row = np.sum(n * m, axis=1) / n_row
    n_col = np.sum(n, axis=0)
    mean_col = np.sum(n * m, axis=0) / n_col
    ss_A = np.sum(n_row * (mean_row - grand_mean) ** 2)
    ss_B = np.sum(n_col * (mean_col - grand_mean) ** 2)
    ss_AB = ss_treatment - ss_A - ss_B
    df_A = a - 1
    df_B = b - 1
    df_AB = (a - 1) * (b - 1)
    df_within = int(np.sum(n) - a * b)
    if df_within <= 0:
        return [np.nan, np.nan, np.nan], [df_A, df_B, df_AB, df_within]
    ms_A = ss_A / df_A if df_A > 0 else np.nan
    ms_B = ss_B / df_B if df_B > 0 else np.nan
    ms_AB = ss_AB / df_AB if df_AB > 0 else np.nan
    ms_within = ss_within / df_within
    F_A = ms_A / ms_within if ms_within != 0 else np.nan
    F_B = ms_B / ms_within if ms_within != 0 else np.nan
    F_AB = ms_AB / ms_within if ms_within != 0 else np.nan
    return [F_A, F_B, F_AB], [df_A, df_B, df_AB, df_within]

def f_range(m, s, n, title=None, show_t=False, dp_p=-1, labels=None, max_enumeration=2**16):
    """
    Bracket the F (or t) statistic implied by rounded summary statistics.

    Every reported mean is shifted by plus or minus half a unit in the last
    reported decimal place, and every SD is taken at its lower and upper
    rounding bound; the largest and smallest resulting statistics are kept
    next to the nominal one.

    Args:
        m, s, n: means, SDs and sizes; 1D for one-way, 2D for two-way cells.
        title: free text copied into the result.
        show_t: when True, report the square root of each F instead of F.
        dp_p: decimal places of the reported values; -1 detects 0, 1 or 2
            from the means and SDs.
        labels: optional effect labels; used only if their count matches.
        max_enumeration: above this many sign combinations, a fixed-seed
            random sample of that size is evaluated instead of all of them.

    Returns:
        dict with keys "title", "labels", "nominal", "min", "max", "dp",
        "enumeration_mode" and "total_combinations". NaN statistics are
        reported as None.

    Raises:
        ValueError: if ``m`` is neither 1D nor 2D.
    """
    m_arr, s_arr, n_arr = np.array(m), np.array(s), np.array(n)
    is_one_way = m_arr.ndim == 1

    # Decimal places: explicit, or inferred (capped at 2) from the inputs.
    if dp_p != -1:
        dp = dp_p
    else:
        dp = 0
        reported = np.concatenate([np.ravel(m_arr).astype(float), np.ravel(s_arr).astype(float)])
        for value in reported:
            if np.isclose(value, np.round(value, 0)):
                continue
            dp = max(dp, 1)
            if not np.isclose(value * 10, np.round(value * 10, 0)):
                dp = max(dp, 2)
    delta = (0.1 ** dp) / 2.0

    # Nominal statistic(s) from the values exactly as reported.
    if is_one_way:
        nominal_f, _, _ = _anova_oneway_from_summaries(m_arr, s_arr, n_arr)
        useFs = [nominal_f]
        default_labels = ["t" if show_t else "F"]
    elif m_arr.ndim == 2:
        useFs, _ = _anova_twoway_from_summaries(m_arr, s_arr, n_arr)
        default_labels = ["A F", "B F", "A:B F"]
    else:
        raise ValueError("m must be 1D or 2D array-like")

    m_flat = np.ravel(m_arr)
    s_flat = np.ravel(s_arr)
    n_flat = np.ravel(n_arr)
    n_cells = len(m_flat)

    total_combinations = 2 ** n_cells
    if total_combinations > max_enumeration:
        enumeration_mode = "sampled"
        rng = np.random.default_rng(12345)
        sampled_codes = rng.choice(total_combinations, size=max_enumeration, replace=False)
    else:
        enumeration_mode = "full"
        sampled_codes = None

    # Smaller SDs push F up, larger SDs push it down.
    sd_floor = 1e-8
    s_for_max = np.maximum(s_flat - delta, sd_floor)
    s_for_min = s_flat + delta

    f_hi = np.array(useFs, dtype=float)
    f_lo = np.array(useFs, dtype=float)

    def sign_vectors():
        """Yield one +/-delta offset vector per evaluated combination."""
        if enumeration_mode == "full":
            for combo in itertools.product([-delta, delta], repeat=n_cells):
                yield np.array(combo)
        else:
            for code in sampled_codes:
                # Bit i of the sampled code selects +delta for cell i.
                yield np.array([delta if ((code >> i) & 1) == 1 else -delta for i in range(n_cells)])

    for offsets in sign_vectors():
        shifted = m_flat + offsets
        if is_one_way:
            hi_f, _, _ = _anova_oneway_from_summaries(shifted, s_for_max, n_flat)
            lo_f, _, _ = _anova_oneway_from_summaries(shifted, s_for_min, n_flat)
            hi_list, lo_list = [hi_f], [lo_f]
        else:
            grid = m_arr.shape
            shifted_grid = shifted.reshape(grid)
            hi_list, _ = _anova_twoway_from_summaries(shifted_grid, s_for_max.reshape(grid), n_arr)
            lo_list, _ = _anova_twoway_from_summaries(shifted_grid, s_for_min.reshape(grid), n_arr)
        f_hi = np.maximum(f_hi, np.array(hi_list, dtype=float))
        f_lo = np.minimum(f_lo, np.array(lo_list, dtype=float))

    f_nom_out = np.array(useFs, dtype=float)
    if show_t:
        # Negative values are clipped to zero before the square root.
        f_nom_out = np.sqrt(np.clip(f_nom_out, a_min=0.0, a_max=None))
        f_hi = np.sqrt(np.clip(f_hi, a_min=0.0, a_max=None))
        f_lo = np.sqrt(np.clip(f_lo, a_min=0.0, a_max=None))

    if labels is not None and len(labels) == len(f_nom_out):
        labels_out = labels
    else:
        labels_out = default_labels

    def nan_to_none(values):
        return [None if np.isnan(v) else float(v) for v in values]

    return {
        "title": title,
        "labels": labels_out,
        "nominal": nan_to_none(f_nom_out),
        "min": nan_to_none(f_lo),
        "max": nan_to_none(f_hi),
        "dp": int(dp),
        "enumeration_mode": enumeration_mode,
        "total_combinations": total_combinations
    }

def extract_ai(pdf_path):
    """Extract tables from a PDF through two chat-completion calls.

    The PDF pages are rendered to PNG images and sent to the model with a
    table-extraction prompt. The model's answer is then sent back with a
    second prompt asking it to split mean/SD values into separate columns.
    The second answer is parsed as JSON, written to ``extracted_tables.json``
    in the current working directory, and returned as raw text.

    Args:
        pdf_path: Path of the PDF to process.

    Returns:
        The raw text content of the second completion (it may still carry
        Markdown code fences).

    Raises:
        ValueError: if the second completion is not valid JSON once the
            code fences are removed.
    """
    pil_pages = convert_from_path(pdf_path, dpi=200)

    image_blocks = []
    for img in pil_pages:
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp.close()
        try:
            img.save(tmp.name, format="PNG")
            b64 = encode_image_to_base64(tmp.name)
        finally:
            # Remove the scratch file even when saving or encoding fails.
            os.unlink(tmp.name)
        image_blocks.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{b64}"}
        })
    messages = []
    content = [
            {"type": "text",
             "text":f'''Please extract all exist tables from the following image and return them in json format with datas and headers in exactly same
             format,make sure that all numerical values are extracted with full accuracy,do not change any format or name,include keys: table title and table note,is_baseline indicate whether the table is baseline table.
             Example JSON output(please do not include this one):
             ```json
            {{
            "tables": [
                {{
                "table_title": "Patient Characteristics",
                "is_baseline": true,
                "group_ns": {{ "Study Group": 50, "Control Group": 48 }}  // optional: integer N or null
                "headers": ["Feature", "Study Group", "Control Group", "P-value"],
                "data": [
                    ["Age (years)", "30 ± 3.51", "29.2 ± 2.93", ".76"],
                    ["Height (cm)", "158 ± 12.33", "164 ± 14.55", ".54"],
                    ["Body weight (kg)", "68 ± 3.20", "65 ± 4.21", ".12"]
                ],
                "table_note": "Baseline characteristics of patients."
                }},
                {{
                "table_title": "Outcome 8 months after therapy",
                "is_baseline": false,
                "group_ns": {{ "Study Group": 50, "Control Group": 48 }}  // optional: integer N or null
                "headers": ["Outcome", "Study Group", "Control Group", "P-value"],
                "data": [
                    ["Menstruating", "35 (89.6%)", "13 (33.3%)", "<.001*"],
                    ["Ovulating", "27 (69.2%)", "10 (25.6%)", "<.001*"],
                    ["POF", "4 (11.4%)", "21 (66.6%)", "<.001*"]
                ],
                "table_note": "P-value < .05 was considered statistically significant."
                }}
                ]
            }}
            ```
             if no table,return No table provided:\n\n'''
            }]+ image_blocks
    messages.append({"role": 'user', "content": content})
    # First call: extract the tables from the page images.
    completion = client.chat.completions.create(
        model=AZURE_DEPLOYMENT,
        messages=messages,
        max_tokens=8000,
        temperature=0.0,
        top_p=0.9,
    )
    print("\n📋 Extracted Table:\n")
    extract_tables = completion.choices[0].message.content
    print(extract_tables)
    chat_prompt_split=[
        {"role": "system", "content": "You are a data analysis expert."},
        {"role": "user", "content": f'''Please separate the Mean Standard Deviation (SD) into different columns inside all the following tables,
        return all tables even they remain unchanged,only return data,no explaintions,if no table,return No table provided,here is a example
        Example(DO NOT include this one):
        ```json
            {{
            "tables": [
                {{
                "table_title": "Patient Characteristics",
                "is_baseline": true,
                "group_ns": {{ "Study Group": 50, "Control Group": 48 }}  // optional: integer N or null
                "headers": ["Feature", "Study Group","Study Group Mean","Study Group SD", "Control Group","Control Group Mean","Control Group SD","P-value"],
                "data": [
                    ["Age (years)", "30 ± 3.51","30", "3.51","29.2 ± 2.93","29.2","2.93", ".76"],
                    ["Height (cm)", "158 ± 12.33","158", "12.33","164 ± 14.55","164","14.55", ".54"],
                    ["Body weight (kg)", "68 ± 3.20","68", "3.20","65 ± 4.21","65", "4.21",".12"]
                ],
                "table_note": "Baseline characteristics of patients."
                }},
                {{
                "table_title": "Outcome 8 months after therapy",
                "group_ns": {{ "Study Group": 50, "Control Group": 48 }}  // optional: integer N or null
                "is_baseline": false,
                "headers": ["Outcome", "Study Group", "Control Group", "P-value"],
                "data": [
                    ["Menstruating", "35 (89.6%)", "13 (33.3%)", "<.001*"],
                    ["Ovulating", "27 (69.2%)", "10 (25.6%)", "<.001*"],
                    ["POF", "4 (11.4%)", "21 (66.6%)", "<.001*"]
                ],
                "table_note": "P-value < .05 was considered statistically significant."
                }}
                ]
            }}
        ```
         :\n\n{extract_tables}'''}
    ]
    messages = chat_prompt_split
    # Second call: split "mean ± SD" cells into separate Mean / SD columns.
    completion = client.chat.completions.create(
        model=AZURE_DEPLOYMENT,
        messages=messages,
        max_tokens=8000,
        temperature=0.0,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
        stream=False
    )
    split_tables = completion.choices[0].message.content
    print("\n📋 Splited Table:\n")
    print(split_tables)
    # Drop Markdown code fences before parsing.
    json_str = split_tables.replace("```json", "").replace("```", "").strip()
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        # Chain the original error so the parse position stays in the traceback.
        raise ValueError(f"Not JSON：{e}\n：\n{json_str}") from e

    output_path = "extracted_tables.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    return split_tables


def extract_ai_data(files):
    """Run the image -> LLM extraction on the first uploaded PDF and build UI state.

    Only ``files[0]`` is processed; any further uploads are ignored.

    Args:
        files: list of PDF file paths (may be empty or None).

    Returns:
        ``(pdf_names, table_names, iframe_html, state)`` where ``state`` holds
        one entry per processed PDF under the keys "pdf_names",
        "table_names", "tables" (DataFrames), "table_ns" (group -> N maps)
        and "iframes". With no files the result is ``([], [], "", {})``.

    Raises:
        ValueError: propagated from ``extract_ai`` / ``json.loads`` when the
            model output is not valid JSON.
    """
    if not files:
        return [], [], "", {}
    pdf_path = files[0]
    # Reuse the shared helper (same markup, and it closes the file handle).
    html_iframe = make_pdf_iframe(pdf_path)

    json_str = extract_ai(pdf_path)  # expects same output schema with group_ns
    # Strip Markdown fences with replace(); str.strip("```json") would strip a
    # *set of characters*, not the fence prefix, so it is not used here.
    cleaned = json_str.replace("```json", "").replace("```", "").strip()
    data = json.loads(cleaned)

    tables = (data.get("tables") or []) if isinstance(data, dict) else []
    names = [t.get("table_title", f"table {i}") for i, t in enumerate(tables)]
    dfs = [pd.DataFrame(t["data"], columns=t["headers"]) for t in tables]
    # The prompt allows "group_ns" to be null; normalise that to an empty map
    # so later lookups can treat it as a dict.
    table_ns = [t.get("group_ns") or {} for t in tables]

    pdf_name = Path(pdf_path).name
    state = {
      "pdf_names":    [pdf_name],
      "table_names":  [names],
      "tables":       [dfs],
      "table_ns":     [table_ns],
      "iframes":      [html_iframe],
    }
    return [pdf_name], names, html_iframe, state

# ------------------------
# Parsing helper (pair group mean/sd columns)
# ------------------------
def detect_group_mean_sd_headers(headers):
    """
    Pair up mean and SD columns that belong to the same group.

    Each header is lower-cased with brackets replaced by spaces. A header
    containing a mean token (mean, avg, average, m, µ, mu) is a mean column;
    otherwise one containing an SD token (sd, s.d, std, stddev, se, stderr)
    is an SD column. Removing those tokens leaves the group prefix used to
    pair the columns, e.g. '24-h group Mean' and '24-h group SD'.

    Returns a list of (group_label, mean_col, sd_col) tuples in first-seen
    order; mean_col/sd_col are the original header strings or None. Prefixes
    with neither a mean nor an SD column are left out.
    """
    mean_tokens = r'\b(mean|avg|average|m|µ|mu)\b'
    sd_tokens = r'\b(sd|s\.d|std|stddev|se|stderr)\b'
    any_role_token = r'\b(mean|avg|average|m|µ|mu|sd|s\.d|std|stddev|se|stderr)\b'

    by_prefix = {}
    for header in headers:
        original = str(header).strip()
        lowered = re.sub(r'[\(\)\[\]\{\}]', ' ', original).strip().lower()

        # Mean tokens take precedence when both kinds appear in one header.
        if re.search(mean_tokens, lowered):
            role = 'mean'
        elif re.search(sd_tokens, lowered):
            role = 'sd'
        else:
            role = None

        # A header made only of role tokens falls back to its full text.
        prefix = re.sub(any_role_token, '', lowered).strip() or lowered

        entry = by_prefix.setdefault(prefix, {'mean': None, 'sd': None, 'raw_prefix': prefix})
        if role is not None:
            entry[role] = original

    return [
        (prefix, entry['mean'], entry['sd'])
        for prefix, entry in by_prefix.items()
        if entry['mean'] is not None or entry['sd'] is not None
    ]

# ------------------------
# Analysis: compute per-row t (and show ranges with f_range)
# ------------------------
def _parse_ns_json(text):
    """Parse a JSON object of ``{group: N}``; unusable entries are dropped."""
    if not text or not str(text).strip():
        return {}
    try:
        parsed = json.loads(text)
    except (TypeError, ValueError):
        return {}
    if not isinstance(parsed, dict):
        return {}
    ns = {}
    for key, value in parsed.items():
        # A single null / non-numeric value must not discard the other Ns.
        try:
            ns[str(key)] = int(value)
        except (TypeError, ValueError):
            continue
    return ns


def _match_group_n(label, ns_source):
    """Look up ``label`` in ``ns_source``: exact key first, then substring match."""
    if not isinstance(ns_source, dict):
        return None
    candidates = []
    if label in ns_source:
        candidates.append(ns_source[label])
    low_label = label.lower()
    for key, value in ns_source.items():
        low_key = str(key).lower()
        if not low_key:
            continue  # an empty key would "match" every label
        if low_key in low_label or low_label in low_key:
            candidates.append(value)
    for value in candidates:
        if value is None:
            continue
        try:
            return int(value)
        except (TypeError, ValueError):
            continue
    return None


def _infer_n_from_cells(df, columns):
    """Return the most frequent leading count in 'n (x%)' cells of ``columns``, or None."""
    count_pattern = re.compile(r'^\s*(\d+)\s*\(\s*[0-9\.]+%?\s*\)\s*$')
    for col in columns:
        if col is None or col not in df.columns:
            continue
        counts = []
        for cell in df[col].astype(str):
            hit = count_pattern.match(cell.strip())
            if hit:
                counts.append(int(hit.group(1)))
        if counts:
            uniques, freq = np.unique(counts, return_counts=True)
            return int(uniques[np.argmax(freq)])
    return None


def _row_group_stats(df, ridx, groups, ns_map):
    """Collect (means, sds, ns) for one table row, one entry per group."""
    means, sds, ns = [], [], []
    for label, mcol, scol in groups:
        mean_val = np.nan
        sd_val = np.nan
        embedded_sd = np.nan  # SD found inside a combined 'mean ± sd' cell
        if mcol and mcol in df.columns:
            cell = df.at[ridx, mcol]
            parsed = parse_mean_sd_cell(cell)
            if parsed:
                mean_val, embedded_sd = parsed
            else:
                try:
                    mean_val = float(str(cell).strip())
                except ValueError:
                    mean_val = np.nan
        # A dedicated SD column wins over an SD embedded in the mean cell.
        if scol and scol in df.columns:
            scell = df.at[ridx, scol]
            if not pd.isna(scell) and str(scell).strip() != "":
                found = re.search(r'([+-]?\d+(?:\.\d+)?)', str(scell))
                if found:
                    sd_val = float(found.group(1))
        if np.isnan(sd_val):
            sd_val = embedded_sd
        means.append(mean_val)
        sds.append(sd_val)
        ns.append(ns_map[label])
    return means, sds, ns


def _format_stats(values):
    """Join statistics into 'a, b, c', rounding to 6 places and using NA for None."""
    return ", ".join("NA" if x is None else str(round(x, 6)) for x in values)


def analyze_selected_table(selected_table: str, selected_file: str, state: dict,
                           detected_ns_json: str, assume_n_json: str, dp_p: int, show_t_flag: bool):
    """
    Parse the currently selected table and compute per-row stats (t or F).

    Sample sizes are resolved per group in this order:
      a) the JSON in ``detected_ns_json`` (the editable "detected Ns" box);
         if that yields no usable entries, the JSON in ``assume_n_json``
      b) the LLM-provided map in ``state["table_ns"]``
      c) a heuristic that reads leading counts from 'n (x%)' cells

    Args:
        selected_table: title of the table to analyse.
        selected_file: name of the PDF the table belongs to.
        state: extraction state with "pdf_names", "table_names", "tables"
            and optionally "table_ns".
        detected_ns_json: JSON dict of group -> N (preferred source).
        assume_n_json: JSON dict of group -> N (fallback source).
        dp_p: decimal places for rounding (-1 = auto); None is treated as -1.
        show_t_flag: True => output t values (sqrt of F).

    Returns:
        ``(analysis_text, result_df)``: a text summary (or an explanatory
        message when the analysis cannot run) and a DataFrame with one row
        per table row holding the nominal/min/max statistics, or an "error"
        column entry for rows that failed.
    """
    if not state or "pdf_names" not in state:
        return "No state present. Run extraction first.", pd.DataFrame()

    try:
        file_idx = state["pdf_names"].index(selected_file)
    except ValueError:
        return "Selected file not in state.", pd.DataFrame()

    try:
        tbl_idx = state["table_names"][file_idx].index(selected_table)
    except ValueError:
        return "Selected table not found in state.", pd.DataFrame()

    # The UI number box can deliver None or a float.
    dp_p = -1 if dp_p is None else int(dp_p)

    df = state["tables"][file_idx][tbl_idx].copy()
    headers = list(df.columns)

    # 1) detect groups (mean/sd column pairs)
    groups = detect_group_mean_sd_headers(headers)
    if len(groups) < 2:
        return "Could not detect at least two group mean/sd pairs in headers. Ensure LLM split mean and sd into separate columns.", pd.DataFrame()

    # 2) per-group labels and column names
    group_labels = [g[0] for g in groups]
    mean_cols = [g[1] for g in groups]
    sd_cols = [g[2] for g in groups]

    # 3) resolve Ns (see docstring for the priority order)
    override_ns = _parse_ns_json(detected_ns_json) or _parse_ns_json(assume_n_json)

    llm_ns_map = {}
    if "table_ns" in state:
        try:
            llm_ns_map = state["table_ns"][file_idx][tbl_idx] or {}
        except (KeyError, IndexError, TypeError):
            llm_ns_map = {}

    ns_map = {}
    for label, mcol, scol in groups:
        n_val = _match_group_n(label, override_ns)
        if n_val is None:
            n_val = _match_group_n(label, llm_ns_map)
        if n_val is None:
            n_val = _infer_n_from_cells(df, (mcol, scol))
        ns_map[label] = n_val

    missing_ns = [k for k, v in ns_map.items() if v is None]
    if missing_ns:
        # Ask the user for the missing Ns and show what was parsed so far.
        parsed_preview = {
            "means_headers": mean_cols,
            "sd_headers": sd_cols,
            "detected_ns": ns_map,
            "labels": group_labels
        }
        return ("Missing sample sizes for groups: " + ", ".join(missing_ns) +
                ". Provide Ns in the 'Assume N' box as JSON map (e.g. {\"24-h group\":56, \"72-h group\":51}).\n\n"
                "Parsed preview:\n" + json.dumps(parsed_preview, indent=2)), pd.DataFrame()

    # 4) per row: statistic and its plausible range via f_range
    rows_out = []
    for ridx in df.index:
        row_label = str(df.at[ridx, df.columns[0]]) if df.shape[1] > 0 else f"row{ridx}"
        try:
            mvals, svals, nvals = _row_group_stats(df, ridx, groups, ns_map)
            r = f_range(m=mvals, s=svals, n=nvals, title=f"{selected_file} - {selected_table} - row {ridx}",
                        show_t=show_t_flag, dp_p=dp_p)
        except Exception as e:
            # Per-row boundary: report the failure in the table and keep going.
            rows_out.append({
                "row_index": int(ridx),
                "row_label": row_label,
                "error": str(e)
            })
            continue
        rows_out.append({
            "row_index": int(ridx),
            "row_label": row_label,
            "nominal": _format_stats(r['nominal']),
            "min": _format_stats(r['min']),
            "max": _format_stats(r['max']),
            "dp": r['dp'],
            "enumeration_mode": r['enumeration_mode'],
            "total_combinations": r['total_combinations']
        })

    result_df = pd.DataFrame(rows_out)

    summary_lines = [
        f"Analysis for table: {selected_table} (file: {selected_file})",
        f"Groups detected (label order): {', '.join(group_labels)}",
        "Note: nominal/min/max may contain multiple values (one per tested effect).",
    ]
    summary_text = "\n".join(summary_lines)

    return summary_text, result_df

# ------------------------
# Gradio UI
# ------------------------
def combined_extract(mode, method, files):
    """Dispatch the Extract button to the extractor for the chosen mode/method.

    Args:
        mode: value of the "Mode" radio.
        method: value of the "File Upload Method" radio.
        files: uploaded PDF paths.

    Returns:
        Four values matching the button's outputs: an update for the file
        dropdown, an update for the table dropdown, the PDF preview HTML and
        the new state dict.
    """
    if mode == "File Upload" and method == "AI":
        file_names, table_names, html_iframe, state = extract_ai_data(files)
    else:
        # Unsupported combination: clear the outputs instead of falling off
        # the end and returning None, which cannot be mapped to four outputs.
        file_names, table_names, html_iframe, state = [], [], "", {}
    return (
        gr.update(choices=file_names, value=file_names[0] if file_names else None),
        gr.update(choices=table_names, value=table_names[0] if table_names else None),
        html_iframe,
        state,
    )


# UI layout and event wiring. Components are created in display order, so the
# statement order below is part of the layout.
with gr.Blocks() as app:
    gr.Markdown("## 📑 table extraction + t/F analysis (LLM-detected Ns with editable override)")

    mode    = gr.Radio(["File Upload"], value="File Upload", label="Mode")
    method = gr.Radio(["AI"], value="AI", label="File Upload Method")
    files   = gr.File(file_count="multiple", type="filepath", label="Upload PDF(s)")

    file_selector  = gr.Dropdown(label="Select PDF",   choices=[], value=None)
    table_selector = gr.Dropdown(label="Select Table", choices=[], value=None)
    state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            pdf_preview = gr.HTML(label="PDF Preview")
            # show detected Ns for the selected table (LLM)
            detected_ns_box = gr.Textbox(label="Detected Ns (from LLM) - click 'Load detected Ns' then edit if needed", lines=3)
            load_ns_btn = gr.Button("Load detected Ns for selected table")
        with gr.Column(scale=1):
            table_view = gr.Dataframe(label="Extracted Table (select a row to inspect)", interactive=True)
            dl_trigger = gr.Button("Download CSV")
            # Hidden until a CSV has actually been generated by on_download.
            download_btn = gr.File(label="", file_count="single", type="filepath", visible=False)

    # Analysis panel
    gr.Markdown("### Analysis Controls")
    assume_n_input = gr.Textbox(label="Manual Ns override (JSON dict) e.g. {\"24-h group\":56, \"72-h group\":51}", placeholder='{"24-h group":56,"72-h group":51}', lines=2)
    dp_input = gr.Number(value=-1, label="dp (decimal places) (-1 = auto)", precision=0)
    show_t_checkbox = gr.Checkbox(value=True, label="Show t (sqrt of F)")
    analyze_btn = gr.Button("Analyze selected table")
    analysis_output = gr.Textbox(label="Analysis Summary", lines=6)
    analysis_df = gr.Dataframe(label="Per-row results", interactive=False)

    def _locate(selected_table, selected_file, st):
        """Return (file_idx, table_idx) for the selection, or None if it is not in state."""
        try:
            file_idx = st["pdf_names"].index(selected_file)
            tbl_idx = st["table_names"][file_idx].index(selected_table)
        except (KeyError, IndexError, ValueError):
            return None
        return file_idx, tbl_idx

    def on_download(selected_table, selected_file, st):
        """Write the selected table to a CSV file and reveal it in the download widget."""
        located = _locate(selected_table, selected_file, st) if st else None
        if located is None:
            return gr.update(value=None, visible=False)
        fidx, tidx = located
        df = st["tables"][fidx][tidx]
        # Table titles come from the LLM and may contain path separators or
        # other characters that are invalid in file names.
        safe_stem = re.sub(r'[^\w.\-]+', '_', f"{selected_file}_{selected_table}").strip("._") or "table"
        csv_path = os.path.join(tempfile.mkdtemp(), f"{safe_stem}.csv")
        df.to_csv(csv_path, index=False)
        return gr.update(value=csv_path, visible=True)

    dl_trigger.click(fn=on_download, inputs=[table_selector, file_selector, state], outputs=[download_btn])

    # Wire extraction button
    extract_btn = gr.Button("Extract")
    extract_btn.click(fn=combined_extract,
                      inputs=[mode, method, files],
                      outputs=[file_selector, table_selector, pdf_preview, state])

    # file/table selection callbacks
    def on_file_change(selected_file: str, st: dict):
        """Refresh the table dropdown, PDF preview and table view for the chosen PDF."""
        # The dropdown can fire with None or a stale value; reset in that case.
        if not st or selected_file not in st.get("pdf_names", []):
            return gr.update(choices=[], value=None), "", pd.DataFrame()
        file_idx = st["pdf_names"].index(selected_file)
        iframe_html = st["iframes"][file_idx]
        names = st["table_names"][file_idx]
        df0 = st["tables"][file_idx][0] if names else pd.DataFrame()
        return gr.update(choices=names, value=names[0] if names else None), iframe_html, df0

    file_selector.change(fn=on_file_change, inputs=[file_selector, state], outputs=[table_selector, pdf_preview, table_view])

    def on_table_change(selected_table: str, selected_file: str, st: dict):
        """Show the DataFrame of the chosen table, or an empty one if it cannot be found."""
        located = _locate(selected_table, selected_file, st) if st else None
        if located is None:
            return pd.DataFrame()
        file_idx, tbl_idx = located
        return st["tables"][file_idx][tbl_idx]

    table_selector.change(fn=on_table_change, inputs=[table_selector, file_selector, state], outputs=[table_view])

    # Load detected Ns into the editable textbox
    def load_detected_ns(selected_table, selected_file, st):
        """Return the LLM-detected Ns of the chosen table as indented JSON ('' if unavailable)."""
        if not st:
            return ""
        try:
            file_idx = st["pdf_names"].index(selected_file)
            tbl_idx = st["table_names"][file_idx].index(selected_table)
            ns_map = st.get("table_ns", [[{}]])[file_idx][tbl_idx]
            return json.dumps(ns_map, indent=2)
        except (KeyError, IndexError, ValueError, TypeError):
            return ""

    load_ns_btn.click(fn=load_detected_ns, inputs=[table_selector, file_selector, state], outputs=[detected_ns_box])

    # Analysis button wires to analyzer
    analyze_btn.click(fn=analyze_selected_table,
                  inputs=[table_selector, file_selector, state, detected_ns_box, assume_n_input, dp_input, show_t_checkbox],
                  outputs=[analysis_output, analysis_df])

    app.launch(share=True,debug=True)


* Running on local URL:  http://127.0.0.1:7863

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.



📋 Extracted Table:

```json
{
    "tables": [
        {
            "table_title": "Demographic characteristics of the participants (n = 145)",
            "is_baseline": true,
            "group_ns": { "Growth hormone group": 72, "Microflare only group": 73 },
            "headers": ["Characteristic", "Growth hormone group", "Microflare only group", "P value"],
            "data": [
                ["Age, y", "34.9 ± 4.8", "34.8 ± 5.6", "0.812"],
                ["Duration of infertility, y", "7.3 ± 3.5", "7.5 ± 3.4", "0.636"],
                ["Anti-Müllerian hormone level, ng/mL", "0.4 ± 0.2", "0.5 ± 0.2", "0.744"],
                ["Day 3 follicle-stimulating hormone level, IU/L", "10.2 ± 2.9", "9.9 ± 2.3", "0.548"],
                ["BMI", "23.3 ± 3.6", "23.3 ± 3.6", "0.939"],
                ["Antral follicle count", "5.9 ± 1.6", "5.9 ± 1.7", "0.761"],
                ["No. of previous poor responses", "2.4 ± 1.5", "2.7 ± 1.5", "0.277"]
            ],
            "table_note": "