In [41]:
import os
import openai
import pandas as pd
from io import StringIO
from langchain_openai import ChatOpenAI
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# If you want to chunk PDF text, you can also import TextSplitter utilities:
# from langchain.text_splitter import RecursiveCharacterTextSplitter
import PyPDF2
from dotenv import load_dotenv
import os

## 1. Configure OpenAI Keys

In [42]:
# Call python-dotenv's loader first so that the lookup below can also see
# variables defined in a local .env file (if one is present).
load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")  # None when the variable is unset
# When using the `openai` client directly, the key could instead be set with:
# openai.api_key = os.getenv("OPENAI_API_KEY")

## 2. Helper Functions

In [43]:
def extract_text_from_pdf(pdf_path: str) -> str:
    """
    Return the text of every page in a PDF as one newline-joined string.

    Parameters:
        pdf_path (str): Path of the PDF file; it is opened in binary mode.

    Returns:
        str: The result of ``extract_text()`` for each page, in page order,
        separated by a single newline.
    """
    with open(pdf_path, 'rb') as pdf_file:
        # Extraction has to finish while the file handle is still open.
        page_texts = [
            pdf_page.extract_text()
            for pdf_page in PyPDF2.PdfReader(pdf_file).pages
        ]
    return "\n".join(page_texts)

# def extract_methodology_section(full_text: str) -> str:
#     """
#     A naive approach to extract only the 'Methodology' section from
#     the full PDF text. Adjust to your own needs. 
#     """
#     # Example: Find the text from 'Methodology' heading to next heading like 'Results'
#     # This is very simplistic and might need to handle text structures properly.
#     start_keyword = "Methods"
#     end_keywords = ["Results", "Analysis", "Discussion", "Conclusion"]
    
#     start_index = full_text.lower().find(start_keyword.lower())
#     if start_index == -1:
#         return ""
    
#     # Search for the earliest next section heading
#     subsequent_indices = []
#     for ek in end_keywords:
#         idx = full_text.lower().find(ek.lower(), start_index + len(start_keyword))
#         if idx != -1:
#             subsequent_indices.append(idx)
    
#     if not subsequent_indices:
#         # If we don't find any subsequent heading, take everything after 'Methodology'
#         return full_text[start_index:]
    
#     end_index = min(subsequent_indices)
#     return full_text[start_index:end_index]

# import xml.etree.ElementTree as ET

# Function to create XML structure
# def create_xml_summary(dataset_summary):
#     root = ET.Element("dataset")
#     sheets_elem = ET.SubElement(root, "sheets")
    
#     for sheet_name, sheet_data in dataset_summary.items():
#         sheet_elem = ET.SubElement(sheets_elem, "sheet", name=sheet_name)
        
#         description_elem = ET.SubElement(sheet_elem, "description")
#         description_elem.text = sheet_data.get("description", "No description")
        
#         columns_elem = ET.SubElement(sheet_elem, "columns")
#         for col_name, col_data in sheet_data["columns"].items():
#             column_elem = ET.SubElement(columns_elem, "column", name=col_name)
            
#             inferred_type_elem = ET.SubElement(column_elem, "inferred_type")
#             inferred_type_elem.text = col_data["inferred_type"]
            
#             if "summary_statistics" in col_data:
#                 summary_stats_elem = ET.SubElement(column_elem, "summary_statistics")
#                 for stat_name, stat_value in col_data["summary_statistics"].items():
#                     stat_elem = ET.SubElement(summary_stats_elem, stat_name)
#                     stat_elem.text = str(stat_value)
            
#             if "unique_values" in col_data:
#                 unique_values_elem = ET.SubElement(column_elem, "unique_values")
#                 unique_values_elem.text = str(col_data["unique_values"])
            
#             if "value_counts" in col_data:
#                 value_counts_elem = ET.SubElement(column_elem, "value_counts")
#                 for value, count in col_data["value_counts"].items():
#                     value_elem = ET.SubElement(value_counts_elem, str(value))
#                     value_elem.text = str(count)
                    
#             if "time_span" in col_data:
#                 time_span_elem = ET.SubElement(column_elem, "time_span")
#                 start_elem = ET.SubElement(time_span_elem, "start")
#                 start_elem.text = str(col_data["time_span"]["start"])
#                 end_elem = ET.SubElement(time_span_elem, "end")
#                 end_elem.text = str(col_data["time_span"]["end"])

#     return ET.ElementTree(root)


# def generate_dataset_summary(df: pd.DataFrame) -> dict:
#     """
#     Generate a summary for each column in the dataset:
#       - Numeric columns: min, max, mean, std, etc.
#       - Categorical columns: unique values, counts
#       - Date/Time columns: min date, max date
#     Return a dictionary containing the summary data.
#     """
#     summary_dict = {}
    
#     for col in df.columns:
#         col_info = {}
#         col_data = df[col].dropna()
        
#         # Try to convert to datetime - if works, treat as time column
#         try:
#             col_data_dt = pd.to_datetime(col_data, errors='raise')
#             # If conversion successful, assume time column
#             col_info["column_type"] = "datetime"
#             col_info["time_span_start"] = str(col_data_dt.min())
#             col_info["time_span_end"] = str(col_data_dt.max())
#         except ValueError:
#             # Not date/time, proceed to numeric or categorical logic
#             if pd.api.types.is_numeric_dtype(col_data):
#                 col_info["column_type"] = "numeric"
#                 col_info["count"] = int(col_data.count())
#                 col_info["mean"] = float(col_data.mean())
#                 col_info["std"] = float(col_data.std())
#                 col_info["min"] = float(col_data.min())
#                 col_info["max"] = float(col_data.max())
#             else:
#                 col_info["column_type"] = "categorical"
#                 uniques = col_data.unique()
#                 col_info["unique_values"] = [str(u) for u in uniques]
#                 col_info["unique_value_count"] = len(uniques)
        
#         summary_dict[col] = col_info
    
#     return summary_dict

# def dict_to_xml_summarization(methodology_summary: str,
#                               statistics_extraction: str,
#                               dataset_summary: dict) -> str:
#     """
#     Create an XML string combining methodology summary, 
#     extracted statistical analyses, and dataset summary.
#     """
#     import xml.etree.ElementTree as ET
    
#     root = ET.Element("SummaryOutput")
    
#     # Methodology part
#     methodology_el = ET.SubElement(root, "MethodologySummary")
#     methodology_el.text = methodology_summary
    
#     # Statistical analyses part
#     stats_el = ET.SubElement(root, "StatisticalAnalyses")
#     stats_el.text = statistics_extraction
    
#     # Dataset Summary
#     data_el = ET.SubElement(root, "DatasetSummary")
#     for col, info in dataset_summary.items():
#         col_el = ET.SubElement(data_el, "Column", name=col)
#         for key, val in info.items():
#             sub_el = ET.SubElement(col_el, key)
#             if isinstance(val, list):
#                 sub_el.text = ", ".join(val)
#             else:
#                 sub_el.text = str(val)
    
#     # Convert to string
#     return ET.tostring(root, encoding='unicode')

In [68]:
import pandas as pd
import json

# Statistics copied from ``Series.describe()`` for numeric columns, in output order.
_NUMERIC_STAT_KEYS = ("count", "mean", "std", "min", "25%", "50%", "75%", "max")


def _json_safe(value):
    """Turn a pandas/NumPy scalar into something ``json.dumps`` writes as valid JSON."""
    if value is None or pd.isna(value):
        # json.dumps would emit a bare NaN token, which strict JSON parsers reject.
        return None
    # NumPy scalars expose .item(), which yields the equivalent built-in number.
    return value.item() if hasattr(value, "item") else value


def _summarize_column(name, series):
    """Build the summary dict (name, inferred type, type-specific details) for one column."""
    column_summary = {"variable_name": str(name)}

    # Boolean columns satisfy is_numeric_dtype, but describe() gives them
    # count/unique/top/freq rather than mean/std/quantiles, so they are
    # reported as categorical instead.
    is_boolean = pd.api.types.is_bool_dtype(series)

    if pd.api.types.is_numeric_dtype(series) and not is_boolean:
        column_summary["variable_type"] = "numeric"
        stats = series.describe()
        column_summary["summary_statistics"] = {
            key: _json_safe(stats.get(key)) for key in _NUMERIC_STAT_KEYS
        }
    elif pd.api.types.is_datetime64_any_dtype(series):
        column_summary["variable_type"] = "datetime"
        times = series.dropna()
        if times.empty:
            column_summary["time_span"] = {"start": None, "end": None}
        else:
            column_summary["time_span"] = {
                "start": str(times.min()),
                "end": str(times.max()),
            }
    else:
        column_summary["variable_type"] = "categorical"
        # dropna=False keeps missing values as their own entry, labelled "NaN".
        column_summary["unique_values"] = [
            {"value": "NaN" if pd.isna(value) else str(value), "count": int(count)}
            for value, count in series.value_counts(dropna=False).items()
        ]

    return column_summary


def summarize_excel_to_json(file_path):
    """
    Reads an Excel (.xlsx) file and creates a JSON summary of each sheet.
    For each column it provides:
      1. The variable (column) name.
      2. The inferred variable type (numeric, datetime, categorical).
      3. For numeric columns: standard summary statistics from ``describe()``.
         Statistics that are undefined (NaN) are written as JSON ``null``.
      4. For categorical columns (including boolean columns): unique values
         and their counts; missing values are reported as the string "NaN".
      5. For datetime columns: the time span (min and max dates).

    Parameters:
        file_path (str): The path to the Excel file.

    Returns:
        str: A JSON string (indent=2) mapping each sheet name to a list of
        per-column summaries, in column order.
    """
    workbook_summary = {}

    # The context manager closes the workbook handle even if a sheet fails to parse.
    with pd.ExcelFile(file_path) as xls:
        for sheet in xls.sheet_names:
            df = xls.parse(sheet)
            workbook_summary[sheet] = [
                _summarize_column(col, df[col]) for col in df.columns
            ]

    return json.dumps(workbook_summary, indent=2)

In [69]:
def excel_to_json_dict(file_path):
    """
    Load every sheet of an Excel workbook as a list of row records.

    Parameters:
        file_path (str): The path to the Excel file.

    Returns:
        dict: Maps each sheet name to the sheet's rows, one dict per row
        keyed by column name (``DataFrame.to_dict(orient="records")``).
    """
    records_by_sheet = {}
    # sheet_name=None makes read_excel return a {sheet name: DataFrame} mapping.
    for name, frame in pd.read_excel(file_path, sheet_name=None).items():
        records_by_sheet[name] = frame.to_dict(orient="records")
    return records_by_sheet

In [70]:
# Excel workbook to summarize; the path is relative to the notebook's working directory.
xlsx_path = "../data/poc_example_data/lazaro_et_al_2021.xlsx"

In [71]:
# `dataset` holds the per-sheet, per-column summary as a JSON *string* (not a dict);
# it is later passed verbatim to the planner prompt as `dataset_summary`.
dataset = summarize_excel_to_json(xlsx_path)

In [72]:
dataset

'{\n  "WildBeeDiversit&Abundance": [\n    {\n      "variable_name": "Island",\n      "variable_type": "categorical",\n      "unique_values": [\n        {\n          "value": "Tinos",\n          "count": 5\n        },\n        {\n          "value": "Paros",\n          "count": 4\n        },\n        {\n          "value": "Santorini",\n          "count": 4\n        },\n        {\n          "value": "Anafi",\n          "count": 3\n        },\n        {\n          "value": "Folegandros",\n          "count": 3\n        },\n        {\n          "value": "Ios",\n          "count": 3\n        },\n        {\n          "value": "Kea",\n          "count": 3\n        },\n        {\n          "value": "Kythnos",\n          "count": 3\n        },\n        {\n          "value": "Milos",\n          "count": 3\n        },\n        {\n          "value": "Mykonos",\n          "count": 3\n        },\n        {\n          "value": "Serifos",\n          "count": 3\n        },\n        {\n          "value": 

## 3. LangChain LLM Setup

In [84]:
# Summarization model: temperature 0 to keep the extracted report as
# repeatable as possible.
summarization_llm = ChatOpenAI(
    temperature=0.0,
    model_name="gpt-4o"
)

# o1-series models only accept the default temperature of 1; the API rejects
# any other value, so both o1-mini clients use 1.
planner_llm = ChatOpenAI(
    temperature=1,
    model_name="o1-mini"
)

# Was temperature=0.0, which o1-mini does not support (see note above).
executor_llm = ChatOpenAI(
    temperature=1,
    model_name="o1-mini"
)

In [85]:
# Prompt for the summarization step. Its single input, `full_text`, is inserted
# between triple backticks; the model is asked to report the hypotheses,
# collected variables, example values/ranges, and statistical methods it finds.
statistical_methods_prompt = PromptTemplate(
    input_variables=["full_text"],
    template=(
        "You are a model specialized in ecological data analysis.\n"
        "Read the following text delimited by triple backticks and list all information on the dataset and statistical analyses used.\n"
        "Focus ONLY on:\n"
        "- Extracting all of the tested hypotheses.\n"
        "- Listing all variables that were collected.\n"
        "- Providing example values for categorical variables, and ranges for continuous variables (if available).\n"
        "- Listing all statistical methods with every detail of the performed analysis (specifically which functions, settings, and variables were used).\n"
        "Return only the report.\n"
        "```\n{full_text}\n```"
    )
)

# data_summarization_prompt = PromptTemplate(
#     input_variables=["dataset"],
#     template=(
#         "Summarize the dataset like this:\n"
#         "1. Look at each sheet in the dataset.\n"
#         "2. Describe briefly what it contains\n"
#         "3. For each column in every sheet provide:\n"
#         "a) Short description, what do you think this variable is about based on its name and values.\n"
#         "b) Inferred type.\n"
#         "c) Standard summary statistics if it is a numeric column.\n"
#         "d) Unique values and unique value counts if it is a categorical variable.\n"
#         "e) Time span for time data.\n"
#         "4. Format your response as XML.\n"
#         "Dataset: \n{dataset}\n"
#     )
# )

# Prompt for the planning step. Inputs: `methodology_summary` and
# `dataset_summary`. The model is asked for an XML plan made of <Step> elements.
planner_prompt = PromptTemplate(
    input_variables=["methodology_summary", "dataset_summary"],
    template=(
        "You are a planning model. Based on the methodology summary and dataset summary, "
        "plan a step-by-step routine (as programmatic pseudocode or structured steps) "
        "to execute the identified statistical analyses on the full dataset.\n\n"
        "Methodology Summary:\n{methodology_summary}\n\n"
        "Dataset Summary:\n{dataset_summary}\n\n"
        "Provide your plan in XML format, with each <Step> containing a structured explanation "
        "of how to implement it programmatically."
    )
)


# Prompt for the execution step. Its single input, `analysis_plan_xml`, is the
# plan text; the model is asked for one R script per step plus a master script.
executor_prompt = PromptTemplate(
    input_variables=["analysis_plan_xml"],
    template=(
        "You are an executor model specialized in generating R scripts for each step. "
        "Given the plan in XML, do the following:\n"
        "1. Generate separate R scripts for each analysis step.\n"
        "2. Generate a single master R script that runs them all in a structured manner.\n"
        "Output your results clearly, indicating how the scripts should be saved.\n\n"
        "Plan XML:\n{analysis_plan_xml}"
    )
)

In [78]:
# --- Step 1: Read the PDF and extract its full text ---
# No methodology-only extraction happens here (that helper is commented out),
# so `full_text` contains the text of every page of the paper.
pdf_path = "../data/poc_example_data/lazaro_et_al_2021_accessible.pdf"
full_text = extract_text_from_pdf(pdf_path)

In [79]:
# --- Step 2: Extract hypotheses, variables and statistical methods from the paper ---
# The whole PDF text (not only a methodology section) is sent to the model.
# NOTE(review): LLMChain and Chain.run appear to be deprecated in recent
# LangChain releases — confirm against the installed version.
summarization_chain = LLMChain(llm=summarization_llm, prompt=statistical_methods_prompt)
summary_result = summarization_chain.run(full_text=full_text)

In [80]:
summary_result

"**Report on Dataset and Statistical Analyses**\n\n**Tested Hypotheses:**\n1. High honey bee visitation is negatively related to wild bee richness and abundance, suggesting interspecific competition.\n2. High honey bee visitation increases competition for resources with wild bees, with varying effects between bee families.\n3. High honey bee visitation influences the structure of wild bee pollination networks, directly or indirectly through changes in pollinator diversity.\n\n**Variables Collected:**\n- Honey bee visitation rate (visits to flowers per 2-hour survey)\n- Wild bee richness (number of species)\n- Wild bee abundance (number of individuals)\n- Flower abundance (number of flowers per m²)\n- Flower richness (number of plant species with flowers)\n- Landscape heterogeneity (Shannon's diversity index)\n- Percentage of natural and semi-natural habitats\n- Island area (km²)\n\n**Example Values and Ranges:**\n- Honey bee visitation rate: 0.0 to 486.3 visits per 2 hours\n- Wild bee 

In [86]:
# Plan statistical analyses based on the methodology summary and the dataset summary.
# `summary_result` is the report produced by the summarization chain and
# `dataset` is the JSON string returned by summarize_excel_to_json.
# NOTE(review): the displayed result starts with a ```xml Markdown fence, so it
# is not raw XML — strip the fence before handing it to an XML parser.
planner_chain = LLMChain(llm=planner_llm, prompt=planner_prompt)
plan_analyses_xml = planner_chain.run(
        methodology_summary=summary_result,
        dataset_summary=dataset
    )

In [87]:
plan_analyses_xml

'```xml\n<AnalysisPlan>\n    <Step number="1">\n        <Description>\n            Load all necessary R libraries required for the analysis.\n        </Description>\n        <Actions>\n            <Action>Load libraries: lme4, MuMIn, car, r2glmm, bipartite, emmeans, piecewiseSEM, PAC, etc.</Action>\n            <ExampleCode>\n                <![CDATA[\n                library(lme4)\n                library(MuMIn)\n                library(car)\n                library(r2glmm)\n                library(bipartite)\n                library(emmeans)\n                library(piecewiseSEM)\n                library(PAC)\n                ]]>\n            </ExampleCode>\n        </Actions>\n    </Step>\n\n    <Step number="2">\n        <Description>\n            Import and merge all relevant datasets into a single data frame for analysis.\n        </Description>\n        <Actions>\n            <Action>Read datasets: WildBeeDiversit&Abundance, PAC, PiecewiseSEM_WildBeeNetworks, With_WithoutApis, D