# In [0]:

import sys
import os
import re
import logging
from enum import Enum
import time
import configparser
from datetime import datetime, timedelta
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from functools import partial
import pandas as pd
import dateparser
import json
from Bio import Entrez, Medline
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import (
    col, concat_ws, regexp_replace, to_date, trim, lit, when, length, udf, explode)
from pyspark.sql.types import (
    StructType, StructField, StringType, TimestampType, 
    IntegerType, FloatType, ArrayType)
from dateutil.relativedelta import relativedelta

# New imports for relationship extraction
import spacy
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType
import mlflow

# Make project-local packages importable. modules/ (which provides meshmap) is
# prepended to sys.path; config/ (which provides pubmed_config) is appended below.
# Both paths assume the repo is checked out at /Workspace/Repos/firedb/fire-db.
PROJECT_ROOT = '/Workspace/Repos/firedb/fire-db'
MODULES_PATH = os.path.join(PROJECT_ROOT, 'modules')
if MODULES_PATH not in sys.path:
    sys.path.insert(0, MODULES_PATH)
# MeSH map and outcome keywords used for data extraction and mapping
from meshmap import mesh_mapping, outcome_keywords 
# Pipeline settings; PubmedConfig.EMAIL / INCREMENTAL / BATCH_SIZE / TARGET_COUNT
# are read by the __main__ driver.
# NOTE(review): this path repeats PROJECT_ROOT literally instead of joining it.
sys.path.append("/Workspace/Repos/firedb/fire-db/config/")
from pubmed_config import PubmedConfig
# NOTE(review): config_parser is not referenced anywhere else in the code shown here.
config_parser = configparser.ConfigParser()

'''
Note that this ETL job is designed to run on a single-node cluster with the following specifications,
    per Databricks Free Edition's limitations:
        - Limited to serverless compute only (no custom Spark configurations)
        - Small cluster size
        - Max of 5 concurrent tasks
'''

# Load the SpaCy English pipeline once at import time so every call to the
# spaCy relationship extractor reuses the same model instance.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load raises OSError when the model package cannot be found.
    # Install it with the *current* interpreter (sys.executable) so the package
    # lands in the environment this process imports from; a bare "python" on
    # PATH may belong to a different environment. check=True surfaces a failed
    # download here instead of as a confusing error on the reload below.
    import importlib
    import subprocess

    print("SpaCy model not found. Installing...")
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True)
    # The package was installed after the interpreter started; refresh the
    # import system's finder caches so spacy.load can locate it.
    importlib.invalidate_caches()
    nlp = spacy.load("en_core_web_sm")

# configuration & data class definitions
@dataclass
class ETLConfig:
    """Runtime settings for the PubMed ETL job.

    Field order is part of the positional constructor interface; only
    ``email`` is required.
    """
    email: str  # contact address; assigned to Entrez.email at run start
    max_tries: int = 10  # assigned to Entrez.max_tries (fetch_with_retry uses its own defaults)
    sleep_between_tries: int = 20  # assigned to Entrez.sleep_between_tries (presumably seconds)
    batch_size: int = 1000  # PMIDs per esearch page and per efetch call
    target_count: int = 20000  # upper bound on PMIDs collected per run
    incremental: bool = False  # True: window starts at the last recorded run; False: past year
    max_records_per_query: int = 9999  # PubMed's limit

@dataclass
class RunMetadata:
    """Summary of one ETL run; persisted as a row of script_run_metadata."""
    script_name: str
    start_time: datetime
    end_time: datetime  # also stored as last_run_timestamp
    incremental_run: bool  # stored as its string form ("True"/"False")
    total_pubs: int  # publications remaining after normalization
    new_records: int  # rows written to the publications table by this run
    status: str  # an ETLStatus value string
    duration_seconds: float  # end_time - start_time

@dataclass
class RelationshipTriplet:
    """Represents an exercise-outcome relationship triplet.

    NOTE(review): not referenced elsewhere in the code shown here; the
    extractors build plain dicts with these fields plus a ``method`` key.
    """
    exercise_modality: str
    relationship_type: str  # e.g., "decreases", "increases", "improves"
    outcome: str
    confidence_score: float  # spaCy extraction uses a fixed base score; LLaMA supplies its own
    source_sentence: str

class ETLStatus(Enum):
    """Outcome of a pipeline run; the ``.value`` string is what gets persisted."""
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
    PARTIAL = "PARTIAL"  # NOTE(review): not produced anywhere in the code shown here
    NO_DATA = "NO_DATA"  # the search window matched no PMIDs

# BUSINESS LOGIC FUNCTION DEFINITIONS
def setup_logging():
    """Configure root logging (INFO, console + ``pubmed_etl.log``) and return this module's logger.

    Note that ``logging.basicConfig`` is a no-op when the root logger already
    has handlers, in which case the handlers built here are not attached.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler('pubmed_etl.log')
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[console_handler, file_handler])
    return logging.getLogger(__name__)

def get_date_range_past_year() -> Tuple[str, str]:
    """Return ``(start, end)`` as ``YYYY/MM/DD`` strings covering the past year.

    ``end`` is today's local date and ``start`` is the same calendar day one
    year earlier (relativedelta clamps e.g. Feb 29 to Feb 28).
    """
    date_format = "%Y/%m/%d"
    end = datetime.today()
    # relativedelta(years=1) is the normalized form of relativedelta(months=12).
    start = end - relativedelta(years=1)
    return start.strftime(date_format), end.strftime(date_format)

def build_mesh_query() -> str:
    """Return the OR-joined ``"<heading>"[MeSH Terms]`` clause for exercise-related headings."""
    headings = (
        "Exercise",
        "Physical Conditioning, Human",
        "Resistance Training",
        "Aerobic Exercise",
        "Circuit-Based Exercise",
        "Periodization",
        "Weight Lifting",
        "High-Intensity Interval Training",
        "Plyometric Exercise",
        "Endurance Training",
        "Muscle Stretching Exercises",
        "Physical Fitness",
        "Cardiorespiratory Fitness",
        "Exercise Therapy",
        "Walking",
        "Swimming",
        "Gymnastics",
        "Bicycling",
    )
    return " OR ".join(f'"{heading}"[MeSH Terms]' for heading in headings)

def build_search_term(start_date: str, end_date: str) -> str:
    """Combine the MeSH clause with an inclusive publication-date range into one PubMed query."""
    date_clause = f'("{start_date}"[Date - Publication] : "{end_date}"[Date - Publication])'
    return f"({build_mesh_query()}) AND {date_clause}"

def partition_date_range(start_date_str: str, end_date_str: str, num_partitions: int) -> List[Tuple[str, str]]:
    """Split an inclusive ``YYYY/MM/DD`` date range into at most ``num_partitions`` chunks.

    Used to keep each PubMed query under the esearch retrieval limit.
    Chunks are contiguous, non-overlapping and inclusive on both ends. Every
    chunk but the last spans ``total_days // num_partitions + 1`` days, and
    the last absorbs the remainder. A range with fewer days than
    ``num_partitions`` is split into single-day chunks, so fewer chunks than
    requested may be returned.

    Raises:
        ValueError: if ``num_partitions`` is less than 1.
    """
    if num_partitions < 1:
        raise ValueError(f"num_partitions must be >= 1, got {num_partitions}")

    start_date = datetime.strptime(start_date_str, "%Y/%m/%d")
    end_date = datetime.strptime(end_date_str, "%Y/%m/%d")

    total_days = (end_date - start_date).days
    # Offset from a chunk's first day to its last day. An offset of 0 is valid
    # (single-day chunk). Clamping it to at least 1 would make every chunk span
    # two or more days, so a 2-day range could never be split and a caller that
    # re-partitions an over-limit range would recurse forever.
    days_per_partition = max(0, total_days // num_partitions)

    partitions = []
    current_start = start_date

    for i in range(num_partitions):
        if i == num_partitions - 1:
            # last partition gets any remaining days...
            current_end = end_date
        else:
            current_end = min(current_start + timedelta(days=days_per_partition), end_date)

        partitions.append((
            current_start.strftime("%Y/%m/%d"),
            current_end.strftime("%Y/%m/%d")))

        current_start = current_end + timedelta(days=1)
        if current_start > end_date:
            break

    return partitions

def extract_doi(record: Dict) -> Optional[str]:
    """Return the DOI from a MEDLINE record's ``AID`` list, or None.

    The first ``AID`` entry containing "doi" (case-insensitive) wins, with its
    literal " [doi]" suffix removed. Records without ``AID`` yield None.
    Used for matching in the ETL.
    """
    if "AID" not in record:
        return None
    doi_entries = (aid for aid in record["AID"] if "doi" in aid.lower())
    first_doi = next(doi_entries, None)
    return None if first_doi is None else first_doi.replace(" [doi]", "")

def parse_date_fallback(date_str: Optional[str]) -> Optional[str]:
    """Normalize a PubMed publication date string (``DP``) to ``YYYY-MM-DD``.

    PubMed ``DP`` values typically look like "2024 Mar 15", "2024 Mar",
    "2024 Mar-Apr" or "2024". Those shapes are parsed deterministically, and
    a missing month or day becomes 1 (January / first of the month).

    Anything else falls back to ``dateparser`` with ``PREFER_DAY_OF_MONTH``
    set to "first". dateparser's default fills a missing day with *today's*
    day, which would make the stored date depend on when the job ran.

    Returns None for empty input or when no date can be recognized. This
    function runs as a Spark UDF, so it is kept self-contained.
    """
    if not date_str:
        return None

    months = {"jan": 1, "feb": 2, "mar": 3, "apr": 4, "may": 5, "jun": 6,
              "jul": 7, "aug": 8, "sep": 9, "oct": 10, "nov": 11, "dec": 12}
    text = date_str.strip()
    match = re.match(r"(\d{4})(?:\s+([A-Za-z]{3,})(?:\s+(\d{1,2})(?!\d))?)?", text)
    if match:
        year, month_word, day = match.groups()
        month = 1 if month_word is None else months.get(month_word[:3].lower())
        # A leading year with no month only counts when nothing else follows it;
        # otherwise strings such as "2024-03-15" would silently lose month/day.
        shape_ok = month_word is not None or text == year
        if month is not None and shape_ok:
            try:
                return datetime(int(year), month, int(day) if day else 1).strftime("%Y-%m-%d")
            except ValueError:
                pass  # impossible date such as "2023 Feb 30": let dateparser decide

    parsed = dateparser.parse(text, settings={"PREFER_DAY_OF_MONTH": "first"})
    return parsed.strftime("%Y-%m-%d") if parsed else None

# Compiled once; applied to every abstract.
# - \b before each cue stops matches inside words ("mean = 45" -> "n = 45").
# - Numbers are either comma-grouped (1,234 / 12,345,678) or 2-5 plain digits.
# - (?!,?\d) rejects truncated longer numbers, e.g. "N = 123456" -> no match
#   rather than 12345.
_SAMPLE_SIZE_RE = re.compile(
    r"(?:\bN\s*=\s*|\bsample size of\s+|\btotal of\s+)"
    r"(\d{1,3}(?:,\d{3}){1,2}|\d{2,5})(?!,?\d)",
    re.IGNORECASE)


def extract_sample_size(text: str) -> Optional[int]:
    """Extract the study/cohort sample size from abstract text.

    Looks for the first "N = <num>", "sample size of <num>" or
    "total of <num>" (case-insensitive). Single-digit values are ignored.

    Returns:
        The sample size as an int, or None when no cue is found or ``text``
        is empty/None.
    """
    match = _SAMPLE_SIZE_RE.search(text or "")
    return int(match.group(1).replace(",", "")) if match else None

def build_synonym_lookup(mesh_map: dict) -> Dict[str, str]:
    """Flatten a nested MeSH map into ``{lower-cased term or synonym: canonical term}``.

    Each node may carry ``"synonyms"`` (list of strings) and ``"subterms"``
    (dict of child term -> child node). Nodes are visited depth-first in
    pre-order, so when the same lower-cased string appears more than once
    the last visit wins.
    """
    lookup: Dict[str, str] = {}
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they pop in their original order, giving the same pre-order walk.
    pending = list(mesh_map.items())[::-1]
    while pending:
        term, node = pending.pop()
        lookup[term.lower()] = term
        lookup.update((synonym.lower(), term) for synonym in node.get("synonyms", []))
        pending.extend(list(node.get("subterms", {}).items())[::-1])
    return lookup

def extract_outcomes(text: str) -> List[str]:
    """Return every entry of ``outcome_keywords`` that occurs in ``text``.

    Matching is a case-insensitive substring test. Results keep the order
    (and any repeats) of ``outcome_keywords``; None is treated as "".
    """
    haystack = (text or "").lower()
    matches = []
    for keyword in outcome_keywords:
        if keyword.lower() in haystack:
            matches.append(keyword)
    return matches

def extract_modalities(text: Optional[str], lookup: Optional[Dict[str, str]] = None) -> List[str]:
    """Return the canonical exercise modalities mentioned in ``text``.

    Args:
        text: free text (title + abstract); None is treated as empty.
        lookup: ``{lower-cased term or synonym: canonical term}`` as built by
            build_synonym_lookup. Defaults to the module-level
            ``synonym_lookup`` built by the ``__main__`` driver.

    Returns:
        Sorted, de-duplicated canonical terms whose term or any synonym
        occurs as a substring of the lower-cased text.
    """
    if lookup is None:
        lookup = synonym_lookup
    haystack = (text or "").lower()
    # Match on the keys (surface forms) and report the values (canonical
    # terms). Iterating the characters of the canonical string instead would
    # match almost any text. Empty keys are skipped because "" is a substring
    # of everything.
    return sorted({canonical for surface, canonical in lookup.items()
                   if surface and surface in haystack})

# RELATIONSHIP EXTRACTION FUNCTIONS
# Cache for the MLflow deployments client; filled on the first successful
# get_llama_client() call and reused afterwards.
_LLAMA_CLIENT_CACHE: Dict[str, object] = {}


def get_llama_client():
    """Return an MLflow deployments client for Databricks model serving, or None.

    The LLaMA extractor asks for a client once per publication, so the
    client is created once and reused instead of being rebuilt per record.
    Failed initializations are not cached, so a later call can still
    succeed. Each failure is logged as a warning and returns None.
    """
    cached = _LLAMA_CLIENT_CACHE.get("client")
    if cached is not None:
        return cached
    try:
        # Databricks Foundation Models API; submodule imported on demand so any
        # import problem is handled by the except below.
        import mlflow.deployments
        client = mlflow.deployments.get_deploy_client("databricks")
    except Exception as e:
        logger.warning(f"Could not initialize LLaMA client: {e}")
        return None
    if client is not None:
        _LLAMA_CLIENT_CACHE["client"] = client
    return client

# Base (lemma) forms of verbs that signal an exercise -> outcome effect. They
# are compared against spaCy's token.lemma_, so inflected forms are not needed.
# "curve" is kept from the earlier list; "curb" is likely what was intended.
_RELATIONSHIP_LEMMAS = frozenset({
    "increase", "improve", "decrease", "reduce", "enhance", "boost",
    "lower", "raise", "affect", "influence", "lift", "elevate", "support",
    "encourage", "extend", "multiply", "suppress", "upregulate", "abate",
    "curb", "curve", "weaken", "strengthen", "lessen", "heighten", "diminish",
})


def extract_relationships_spacy(text: str, modalities: List[str], outcomes: List[str],
                                nlp_model=None) -> List[Dict]:
    """Extract exercise-outcome relationships with spaCy sentence and lemma analysis.

    For every sentence that mentions at least one of ``modalities`` and one
    of ``outcomes`` (case-insensitive substring), each token whose lemma is a
    known effect verb yields one relationship per (modality, outcome) pair
    in that sentence.

    Args:
        text: text to analyse.
        modalities: canonical modality names to look for.
        outcomes: outcome keywords to look for.
        nlp_model: callable returning a spaCy-like Doc (with ``.sents``).
            Defaults to the module-level ``nlp`` pipeline.

    Returns:
        Dicts with exercise_modality, relationship_type (lower-cased verb
        lemma), outcome, confidence_score (fixed 0.7), source_sentence and
        method="spacy". Empty when any input is empty.
    """
    if not text or not modalities or not outcomes:
        return []

    model = nlp if nlp_model is None else nlp_model
    doc = model(text)
    relationships = []

    # Lower-case the search terms once rather than per sentence.
    lowered_modalities = [(mod, mod.lower()) for mod in modalities]
    lowered_outcomes = [(out, out.lower()) for out in outcomes]

    for sent in doc.sents:
        sent_text = sent.text.lower()
        sent_modalities = [mod for mod, mod_l in lowered_modalities if mod_l in sent_text]
        sent_outcomes = [out for out, out_l in lowered_outcomes if out_l in sent_text]
        if not (sent_modalities and sent_outcomes):
            continue

        source_sentence = sent.text.strip()
        for token in sent:
            lemma = token.lemma_.lower()
            if lemma not in _RELATIONSHIP_LEMMAS:
                continue
            for modality in sent_modalities:
                for outcome in sent_outcomes:
                    relationships.append({
                        'exercise_modality': modality,
                        'relationship_type': lemma,
                        'outcome': outcome,
                        'confidence_score': 0.7,  # Base confidence for SpaCy extraction
                        'source_sentence': source_sentence,
                        'method': 'spacy'
                    })

    return relationships

_LLAMA_ENDPOINT = "databricks-meta-llama-3-3-70b-instruct"
# Fields downstream code lower-cases; a relationship missing any is unusable.
_LLAMA_REQUIRED_KEYS = ("exercise_modality", "relationship_type", "outcome")


def _llama_response_text(response) -> str:
    """Return the generated text from a model-serving response ("" if none).

    Handles chat-completion shaped responses (``choices[0].message.content``)
    as well as a ``predictions[0]`` string.
    """
    if not response:
        return ""
    choices = response.get("choices")
    if choices:
        first = choices[0] or {}
        message = first.get("message") or {}
        return message.get("content") or first.get("text") or ""
    predictions = response.get("predictions")
    if predictions and isinstance(predictions[0], str):
        return predictions[0]
    return ""


def _parse_llama_relationships(content: str) -> List[Dict]:
    """Parse one-JSON-object-per-line model output into relationship dicts.

    Skips lines that are not JSON objects or that lack a non-empty string for
    each required key. ``confidence_score`` is coerced to float (None when it
    is missing or not numeric). Each kept dict gets method="llama".
    """
    relationships = []
    for raw_line in content.splitlines():
        # Tolerate objects emitted as elements of a JSON array ("{...},").
        line = raw_line.strip().rstrip(",")
        if not line.startswith("{"):
            continue
        try:
            rel = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not all(isinstance(rel.get(key), str) and rel[key].strip() for key in _LLAMA_REQUIRED_KEYS):
            continue
        try:
            rel["confidence_score"] = float(rel.get("confidence_score"))
        except (TypeError, ValueError):
            rel["confidence_score"] = None
        rel["method"] = "llama"
        relationships.append(rel)
    return relationships


def extract_relationships_llama(text: str, modalities: List[str], outcomes: List[str],
                                client=None) -> List[Dict]:
    """Extract exercise-outcome relationships using a LLaMA serving endpoint.

    Args:
        text: abstract text to analyse.
        modalities: modality names offered to the model.
        outcomes: outcome names offered to the model.
        client: object with ``predict(endpoint=..., inputs=...)``. Defaults
            to get_llama_client().

    Returns:
        Relationship dicts (see _parse_llama_relationships). Empty when any
        input is empty, no client is available, or the call fails; failures
        are logged as warnings.
    """
    if not text or not modalities or not outcomes:
        return []

    if client is None:
        client = get_llama_client()
    if not client:
        return []

    try:
        # Prepare prompt for LLaMA
        prompt = f"""
        Extract exercise-outcome relationships from the following research abstract.
        
        Available exercise modalities: {', '.join(modalities)}
        Available outcomes: {', '.join(outcomes)}
        
        Text: {text}
        
        For each relationship found, return a JSON object with:
        - exercise_modality: the specific exercise type
        - relationship_type: how the exercise affects the outcome (increases, decreases, improves, etc.)
        - outcome: the health/fitness outcome
        - confidence_score: confidence level (0.0-1.0)
        - source_sentence: the sentence containing the relationship
        
        Return only valid JSON objects, one per line. If no relationships found, return empty response.
        """

        response = client.predict(
            endpoint=_LLAMA_ENDPOINT,
            inputs={"messages": [{"role": "user", "content": prompt}]}
        )
        return _parse_llama_relationships(_llama_response_text(response))

    except Exception as e:
        logger.warning(f"LLaMA extraction failed: {e}")
        return []

def combine_relationship_extractions(spacy_rels: List[Dict], llama_rels: List[Dict]) -> List[Dict]:
    """Merge spaCy and LLaMA relationships, de-duplicating on a normalized key.

    The key is (modality, outcome, relationship_type), lower-cased.
    - spaCy relationships are kept first, one per key.
    - A LLaMA relationship whose key matches a spaCy one boosts that entry's
      confidence by 0.2 (capped at 1.0) and marks it method="spacy+llama".
      The boost is applied at most once per key, and repeated LLaMA outputs
      never boost or relabel one another.
    - Other LLaMA relationships are added, one per key.
    - Entries missing a key field (or with non-string values) are skipped.

    Inputs are not mutated; returned dicts are shallow copies. Output order
    is first-seen order (spaCy first, then new LLaMA keys).
    """
    def _key(rel: Dict) -> Optional[Tuple[str, str, str]]:
        try:
            return (rel['exercise_modality'].lower(),
                    rel['outcome'].lower(),
                    rel['relationship_type'].lower())
        except (KeyError, AttributeError, TypeError):
            return None

    merged: Dict[Tuple[str, str, str], Dict] = {}
    for rel in spacy_rels:
        key = _key(rel)
        if key is not None and key not in merged:
            merged[key] = dict(rel)
    spacy_keys = set(merged)

    boosted = set()
    for rel in llama_rels:
        key = _key(rel)
        if key is None:
            continue
        if key in spacy_keys:
            if key not in boosted:
                existing = merged[key]
                existing['confidence_score'] = min(1.0, existing['confidence_score'] + 0.2)
                existing['method'] = 'spacy+llama'
                boosted.add(key)
        elif key not in merged:
            merged[key] = dict(rel)

    return list(merged.values())

def extract_all_relationships(text: str, modalities: List[str], outcomes: List[str]) -> List[Dict]:
    """Run the spaCy extractor, then the LLaMA extractor, and return their merged result."""
    return combine_relationship_extractions(
        extract_relationships_spacy(text, modalities, outcomes),
        extract_relationships_llama(text, modalities, outcomes),
    )

# DATA PROCESSING/TRANSFORMATION FUNCTIONS
def transform_record(record: Dict) -> Dict:
    """Flatten one MEDLINE record into the pipeline's enriched publication dict.

    Title (``TI``) and abstract (``AB``) are joined with a space and mined for
    modalities, outcomes and exercise-outcome relationships. The sample size
    is taken from the abstract alone. A missing title or abstract becomes "",
    other missing scalar fields None, and missing list fields [].
    """
    title = record.get("TI", "")
    abstract = record.get("AB", "")
    text_blob = f"{title} {abstract}"

    found_modalities = extract_modalities(text_blob)
    found_outcomes = extract_outcomes(text_blob)
    triplets = extract_all_relationships(text_blob, found_modalities, found_outcomes)

    return dict(
        pmid=record.get("PMID"),
        title=title,
        abstract=abstract,
        journal=record.get("JT"),
        date=record.get("DP"),
        doi=extract_doi(record),
        mesh_terms=record.get("MH", []),
        publication_types=record.get("PT", []),
        keywords=record.get("OT", []),
        n_size=extract_sample_size(abstract),
        outcomes=found_outcomes,
        modalities=found_modalities,
        relationships=triplets,  # exercise-outcome relationship triplets
    )

def normalize_dataframe(df):
    """Clean text columns, derive ``pub_date`` and apply basic data-quality filters.

    - Runs of line breaks (CR and/or LF) in ``abstract`` and ``title`` become a
      single space, so words on either side of a break are not glued
      together. ``title`` is trimmed afterwards.
    - ``pub_date`` is derived from the raw ``date`` string by
      parse_date_fallback (a Python UDF), and ``date`` is dropped.
    - Rows without a non-blank ``doi`` or without a parseable ``pub_date`` are
      removed, then rows are de-duplicated on ``doi``.
    """
    parse_date_udf = udf(parse_date_fallback, StringType())
    line_breaks = r"[\r\n]+"

    return (df
            .withColumn("abstract", regexp_replace(col("abstract"), line_breaks, " "))
            .withColumn("title", trim(regexp_replace(col("title"), line_breaks, " ")))
            .withColumn("pub_date", parse_date_udf(col("date")))
            .drop("date")
            .filter((col("doi").isNotNull()) & (trim(col("doi")) != "") & (col("pub_date").isNotNull()))
            .dropDuplicates(["doi"]))

def filter_new_records(normalized_df, existing_table_name: str, spark):
    """Drop records whose DOI already exists in ``existing_table_name`` (incremental loading).

    If the target table does not exist yet (e.g. the first run), every record
    is new and ``normalized_df`` is returned unchanged. Any other failure
    (unreadable table, missing ``doi`` column, ...) propagates. Swallowing it
    would treat all records as new and append duplicates.
    """
    if not spark.catalog.tableExists(existing_table_name):
        return normalized_df
    existing_df = spark.table(existing_table_name)
    return normalized_df.join(existing_df.select("doi"), on="doi", how="left_anti")

# ENTREZ API/BIOPYTHON EXTRACTION FUNCTIONS
def fetch_with_retry(fetch_func, max_tries: int = 5, sleep_time: int = 20):
    """Call ``fetch_func`` with retries and return its result.

    HTTP 429 (rate-limit) errors back off exponentially
    (``sleep_time * 2**(attempt - 1)`` seconds). Any other exception waits a
    flat ``sleep_time``. After the final attempt the last exception is
    re-raised with its original traceback.

    Args:
        fetch_func: zero-argument callable performing the request.
        max_tries: total attempts. Values below 1 are treated as 1, so the
            function is always called at least once.
        sleep_time: base delay in seconds between attempts.
    """
    attempts = max(1, max_tries)
    for attempt in range(1, attempts + 1):
        try:
            return fetch_func()
        except Exception as e:
            error_msg = str(e)
            if attempt == attempts:
                logger.error(f"All retry attempts failed! Last error: {error_msg}")
                raise
            if "HTTP Error 429" in error_msg:
                wait_time = sleep_time * (2 ** (attempt - 1))  # Exponential backoff
                logger.warning(f"Rate limit hit! Retrying in {wait_time}s (attempt {attempt}/{attempts})")
            else:
                wait_time = sleep_time
                logger.warning(f"Request failed! {error_msg}. Retrying... (attempt {attempt}/{attempts})")
            time.sleep(wait_time)

def get_total_count(search_term: str) -> int:
    """Return the number of PubMed records matching ``search_term`` (esearch ``Count``).

    The Entrez handle is closed after each attempt, including failed ones.
    """
    def _fetch():
        handle = Entrez.esearch(db="pubmed", term=search_term, retmax=0)
        try:
            return int(Entrez.read(handle)["Count"])
        finally:
            handle.close()

    return fetch_with_retry(_fetch)

def fetch_pmid_batch(search_term: str, start: int, batch_size: int) -> List[str]:
    """Return up to ``batch_size`` PMIDs for ``search_term``, starting at result index ``start``.

    The Entrez handle is closed after each attempt, including failed ones.
    """
    def _fetch():
        handle = Entrez.esearch(
            db="pubmed",
            term=search_term,
            retmax=batch_size,
            retstart=start)
        try:
            return Entrez.read(handle)["IdList"]
        finally:
            handle.close()

    return fetch_with_retry(_fetch)

def fetch_records_batch(pmids: List[str]) -> List[Dict]:
    """Fetch MEDLINE-format records for ``pmids`` and return them as parsed dicts.

    The Entrez handle is closed after each attempt, including failed ones.
    """
    def _fetch():
        handle = Entrez.efetch(db="pubmed", id=pmids, rettype="medline", retmode="text")
        try:
            # Medline.parse is lazy; materialize it before the handle closes.
            return list(Medline.parse(handle))
        finally:
            handle.close()

    return fetch_with_retry(_fetch)

# ORCHESTRATION FUNCTIONS
def collect_pmids_for_date_range(start_date: str, end_date: str, config: ETLConfig, remaining_target: int = None) -> List[str]:
    """Collect PMIDs published in ``[start_date, end_date]``, working around the per-query limit.

    If the range matches more than ``config.max_records_per_query`` records,
    it is split with partition_date_range and each chunk is collected
    recursively until ``remaining_target`` (default ``config.target_count``)
    PMIDs are gathered. If the range cannot be split any further (e.g. a
    single day over the limit), only the first ``max_records_per_query``
    results are paged through instead of recursing on the same range forever.

    Returns:
        PMIDs in the order returned by esearch; at most the effective target.
    """
    search_term = build_search_term(start_date, end_date)
    total_count = get_total_count(search_term)

    print(f"Date range {start_date} to {end_date}: {total_count} records")
    effective_target = remaining_target if remaining_target is not None else config.target_count

    if total_count > config.max_records_per_query:  # check if limit hit
        num_partitions = (total_count // config.max_records_per_query) + 1
        partitions = partition_date_range(start_date, end_date, num_partitions)

        if len(partitions) > 1:
            print(f"Records ({total_count}) exceed limit ({config.max_records_per_query}). Partitioning date range...")
            all_pmids = []
            for partition_start, partition_end in partitions:
                remaining_needed = effective_target - len(all_pmids)
                if remaining_needed <= 0:
                    break
                all_pmids.extend(collect_pmids_for_date_range(
                    partition_start, partition_end, config, remaining_needed))
            return all_pmids

        # A single chunk is the whole range again; recursing would never end.
        print(f"Range {start_date} to {end_date} cannot be partitioned further; "
              f"collecting at most {config.max_records_per_query} of {total_count} records.")
        total_count = config.max_records_per_query

    # under the limit (or capped to it): page through the results
    target_count = min(effective_target, total_count)
    pmid_list = []

    for start in range(0, target_count, config.batch_size):
        remaining_in_batch = min(config.batch_size, target_count - len(pmid_list))
        batch_pmids = fetch_pmid_batch(search_term, start, remaining_in_batch)
        pmid_list.extend(batch_pmids)
        print(f"Fetched {len(pmid_list)} PMIDs so far for range {start_date} to {end_date}...")

        if len(pmid_list) >= target_count:
            break

    return pmid_list

def collect_all_pmids(search_term: str, config: ETLConfig, start_date: str, end_date: str) -> List[str]:
    """Collect up to ``config.target_count`` PMIDs for the search window.

    ``search_term`` is only used for the up-front count. Collection itself
    rebuilds the query per (possibly partitioned) date range.
    """
    total_available = get_total_count(search_term)
    to_collect = min(config.target_count, total_available)

    print(f"Total records available: {total_available}")
    print(f"Target count configured: {config.target_count}")
    print(f"Will collect: {to_collect} records")

    return collect_pmids_for_date_range(start_date, end_date, config, remaining_target=to_collect)

def extract_all_records(pmid_list: List[str], config: ETLConfig) -> List[Dict]:
    """Fetch and transform the records for ``pmid_list`` in batches of ``config.batch_size``.

    Best-effort:
    - A batch whose fetch fails is reported and skipped.
    - A record whose transform fails is reported and skipped on its own, so
      one malformed record does not discard the rest of its batch.

    Returns:
        Transformed record dicts (see transform_record).
    """
    all_records = []

    for i in range(0, len(pmid_list), config.batch_size):
        batch_pmids = pmid_list[i:i + config.batch_size]
        try:
            records = fetch_records_batch(batch_pmids)
        except Exception as e:
            print(f"Error processing batch starting at {i}: {e}")
            continue

        for rec in records:
            try:
                all_records.append(transform_record(rec))
            except Exception as e:
                print(f"Error transforming record PMID {rec.get('PMID')} in batch starting at {i}: {e}")
        print(f"Processed {len(all_records)} records so far...")

    return all_records

def create_spark_dataframe(records: List[Dict], spark) -> 'DataFrame':
    """Convert transformed records into a PySpark DataFrame with an explicit schema.

    The same schema is used whether or not ``records`` is empty, so column
    types never depend on Spark's type inference. This matters for two cases:
    - Relationship dicts would otherwise be inferred as maps with
      mixed-type values.
    - A column that is null in every row cannot be inferred at all.

    Relationship entries are conformed to the struct before conversion,
    because model-generated entries are not guaranteed to be well-typed:
    - missing keys become null;
    - non-string text fields are stringified;
    - ``confidence_score`` is coerced to float (null when not numeric).

    Args:
        records: dicts as produced by transform_record.
        spark: active SparkSession.
    """
    # Define relationship schema
    relationship_schema = StructType([
        StructField("exercise_modality", StringType(), True),
        StructField("relationship_type", StringType(), True),
        StructField("outcome", StringType(), True),
        StructField("confidence_score", FloatType(), True),
        StructField("source_sentence", StringType(), True),
        StructField("method", StringType(), True)
    ])

    schema = StructType([
        StructField("pmid", StringType(), True),
        StructField("title", StringType(), True),
        StructField("abstract", StringType(), True),
        StructField("journal", StringType(), True),
        StructField("date", StringType(), True),
        StructField("doi", StringType(), True),
        StructField("mesh_terms", ArrayType(StringType()), True),
        StructField("publication_types", ArrayType(StringType()), True),
        StructField("keywords", ArrayType(StringType()), True),
        StructField("n_size", IntegerType(), True),
        StructField("outcomes", ArrayType(StringType()), True),
        StructField("modalities", ArrayType(StringType()), True),
        StructField("relationships", ArrayType(relationship_schema), True)
    ])

    if not records:
        return spark.createDataFrame([], schema)

    relationship_fields = relationship_schema.fieldNames()
    record_fields = schema.fieldNames()

    def _to_float(value):
        try:
            return None if value is None else float(value)
        except (TypeError, ValueError):
            return None

    def _conform_relationship(rel: Dict) -> Tuple:
        # Build the struct positionally in relationship_schema order.
        values = []
        for name in relationship_fields:
            value = rel.get(name)
            if name == "confidence_score":
                values.append(_to_float(value))
            else:
                values.append(value if value is None or isinstance(value, str) else str(value))
        return tuple(values)

    def _to_row(record: Dict) -> Tuple:
        # Build the row positionally in schema order; absent keys become null.
        return tuple(
            [_conform_relationship(rel) for rel in (record.get(name) or [])]
            if name == "relationships" else record.get(name)
            for name in record_fields)

    return spark.createDataFrame([_to_row(record) for record in records], schema)

def save_data(df, table_name: str, mode: str = "overwrite"):
    """Write ``df`` to the Delta table ``table_name`` using save ``mode``.

    When overwriting, ``overwriteSchema`` is enabled so the stored schema is
    replaced along with the data.
    """
    writer = df.write.format("delta").mode(mode)
    extra_options = {"overwriteSchema": "true"} if mode == "overwrite" else {}
    for option_key, option_value in extra_options.items():
        writer = writer.option(option_key, option_value)
    writer.saveAsTable(table_name)

def create_relationships_table(df, spark, mode: str = "overwrite",
                               table_name: str = "firedb_exercise_outcome_relationships"):
    """Explode per-publication relationships into one row per relationship and save them.

    Args:
        df: publications DataFrame with ``pmid``, ``doi`` and an
            array-of-struct ``relationships`` column.
        spark: unused; kept for signature compatibility.
        mode: save mode passed to save_data. "overwrite" (default) replaces
            the table; "append" adds rows, e.g. for incremental runs.
        table_name: destination Delta table.

    Returns:
        The exploded DataFrame. Publications with no relationships produce no
        rows, because explode drops empty arrays.
    """
    relationships_df = (df
                        .select("pmid", "doi", explode("relationships").alias("relationship"))
                        .select("pmid", "doi",
                                col("relationship.exercise_modality").alias("exercise_modality"),
                                col("relationship.relationship_type").alias("relationship_type"),
                                col("relationship.outcome").alias("outcome"),
                                col("relationship.confidence_score").alias("confidence_score"),
                                col("relationship.source_sentence").alias("source_sentence"),
                                col("relationship.method").alias("extraction_method")))

    save_data(relationships_df, table_name, mode)
    return relationships_df

def save_run_metadata(metadata: RunMetadata, spark):
    """Append one row describing ``metadata`` to the ``script_run_metadata`` Delta table.

    ``last_run_timestamp`` and ``run_end`` both hold the run's end time. The
    incremental flag is stored as its string form ("True"/"False").
    """
    # (column name, Spark type, nullable, value) in table column order
    columns = [
        ("script_name", StringType(), False, metadata.script_name),
        ("last_run_timestamp", TimestampType(), False, metadata.end_time),
        ("incremental_run", StringType(), False, str(metadata.incremental_run)),
        ("total_pubs", IntegerType(), True, metadata.total_pubs),
        ("new_records", IntegerType(), True, metadata.new_records),
        ("run_status", StringType(), True, metadata.status),
        ("run_start", TimestampType(), True, metadata.start_time),
        ("run_end", TimestampType(), True, metadata.end_time),
        ("duration_seconds", FloatType(), True, metadata.duration_seconds),
    ]
    schema = StructType([StructField(name, dtype, nullable) for name, dtype, nullable, _ in columns])
    row = tuple(value for *_, value in columns)

    metadata_df = spark.createDataFrame([row], schema=schema)
    metadata_df.write.format("delta").mode("append").saveAsTable("script_run_metadata")

def create_run_metadata(status: str, start_time: datetime, total: int, new: int,
                        incremental: bool = False) -> RunMetadata:
    """Build the RunMetadata record for this run, stamping ``end_time`` as now.

    Args:
        status: ETLStatus value string.
        start_time: when the run began.
        total: publications after normalization.
        new: rows written by this run.
        incremental: whether this was an incremental run. Previously always
            recorded as False; defaults to False for backward compatibility.
    """
    end_time = datetime.now()
    return RunMetadata(
        script_name="pubmed_ingestion",
        start_time=start_time,
        end_time=end_time,
        incremental_run=incremental,
        total_pubs=total,
        new_records=new,
        status=status,
        duration_seconds=(end_time - start_time).total_seconds())

def get_incremental_date_range(spark) -> Tuple[str, str]:
    """Return the ``(start, end)`` window (``YYYY/MM/DD``) for an incremental run.

    ``start`` is the date of the most recent *completed* pubmed_ingestion run
    (status SUCCESS or NO_DATA) in ``script_run_metadata``, and ``end`` is
    today. FAILED runs are ignored so their window is retried rather than
    skipped. Falls back to the past-year window when no completed run exists
    or the metadata table cannot be read.
    """
    try:
        last_run_df = spark.sql("""
            SELECT last_run_timestamp
            FROM script_run_metadata
            WHERE script_name = 'pubmed_ingestion'
              AND run_status IN ('SUCCESS', 'NO_DATA')
            ORDER BY last_run_timestamp DESC
            LIMIT 1""")

        last_run_row = last_run_df.collect()
        if last_run_row:
            start_date = last_run_row[0]['last_run_timestamp'].strftime('%Y/%m/%d')
            end_date = datetime.now().strftime('%Y/%m/%d')
            return start_date, end_date
        return get_date_range_past_year()
    except Exception:
        # Best-effort: e.g. the metadata table does not exist on the first run.
        return get_date_range_past_year()

def generate_relationship_summary(relationships_df, spark):
    """Print summary statistics for the exploded relationships DataFrame.

    Prints the total number of relationships, their mean ``confidence_score``
    ("n/a" when there is none) and the ten most frequent exercise modalities,
    outcomes and relationship types. ``spark`` is unused.
    """
    from pyspark.sql import functions as F

    # Named aggregates read by alias. The dict form of agg() does not document
    # the order of its result columns, so reading them by position is unsafe.
    summary_stats = relationships_df.agg(
        F.count(F.lit(1)).alias("total_relationships"),
        F.avg("confidence_score").alias("avg_confidence"),
    ).collect()[0]
    avg_confidence = summary_stats["avg_confidence"]

    print("Relationship Extraction Summary:")
    print(f"Total relationships extracted: {summary_stats['total_relationships']}")
    if avg_confidence is None:
        # avg() is NULL when there are no rows (or only null scores).
        print("Average confidence score: n/a")
    else:
        print(f"Average confidence score: {avg_confidence:.3f}")

    # Top modalities, outcomes and relationship types
    for heading, column in (("\nTop Exercise Modalities:", "exercise_modality"),
                            ("\nTop Outcomes:", "outcome"),
                            ("\nTop Relationship Types:", "relationship_type")):
        print(heading)
        (relationships_df.groupBy(column).count()
         .orderBy(col("count").desc())
         .limit(10)
         .show())

# MAIN EXTRACTION FUNCTION
def run_pubmed_etl(config: ETLConfig, spark) -> RunMetadata:
    """Run the PubMed ETL end to end and return metadata describing the run.

    Steps:
    1. Configure Entrez and choose the date window (incremental or past year).
    2. Collect PMIDs.
    3. Fetch and enrich the records.
    4. Normalize them into a Spark DataFrame.
    5. Write the relationships table and the ``firedb_pubmed`` table.
    6. Print a relationship summary.

    Pipeline exceptions are caught and logged; they are reported through the
    returned RunMetadata (status FAILED) rather than raised.
    """
    start_time = datetime.now()
    logger.info(f"Starting PubMed ETL with relationship extraction. Config: {config}")

    def _finish(status: str, total: int, new: int) -> RunMetadata:
        metadata = create_run_metadata(status, start_time, total, new)
        # Record the actual run mode on the metadata row.
        metadata.incremental_run = config.incremental
        return metadata

    try:
        Entrez.email = config.email
        Entrez.max_tries = config.max_tries
        Entrez.sleep_between_tries = config.sleep_between_tries

        # generate date range based on incremental vs. bulk load configuration
        if config.incremental:
            start_date, end_date = get_incremental_date_range(spark)
            logger.info(f"Incremental run: {start_date} to {end_date}")
        else:
            start_date, end_date = get_date_range_past_year()
            logger.info(f"Full run: {start_date} to {end_date}")

        # build search + collect relevant pmids
        search_term = build_search_term(start_date, end_date)
        logger.info(f"Search term: {search_term}")
        pmid_list = collect_all_pmids(search_term, config, start_date, end_date)

        if not pmid_list:
            logger.info("No new PMIDs found")
            return _finish(ETLStatus.NO_DATA.value, 0, 0)
        logger.info(f"Collected {len(pmid_list)} PMIDs")

        # extract + transform with relationship extraction
        logger.info("Starting record extraction and relationship mining...")
        records = extract_all_records(pmid_list, config)
        logger.info(f"Extracted {len(records)} records with relationships")
        if not records:
            # Stop before any write: a full run would otherwise overwrite
            # firedb_pubmed with an empty table.
            logger.error("Nothing was extracted for the collected PMIDs; skipping writes")
            return _finish(ETLStatus.FAILED.value, 0, 0)

        # process w/ spark function then normalize
        df = create_spark_dataframe(records, spark)
        normalized_df = normalize_dataframe(df)

        # NOTE(review): create_relationships_table overwrites its target table,
        # so an incremental run replaces all previously stored relationships
        # with only this run's window. Consider appending relationships for
        # final_df instead.
        logger.info("Creating relationships table...")
        relationships_df = create_relationships_table(normalized_df, spark)

        # handle incremental vs full loading logic
        if config.incremental:
            final_df = filter_new_records(normalized_df, "firedb_pubmed", spark)
            new_records = final_df.count()
            if new_records > 0:
                save_data(final_df, "firedb_pubmed", "append")
                logger.info(f"Appended {new_records} new records")
        else:
            new_records = normalized_df.count()
            save_data(normalized_df, "firedb_pubmed", "overwrite")
            logger.info(f"Saved {new_records} records (full refresh)")

        total_records = normalized_df.count()
        total_relationships = relationships_df.count()

        # The summary is diagnostic only: run it after the writes and do not
        # let a failure here mark an otherwise successful run as FAILED.
        try:
            generate_relationship_summary(relationships_df, spark)
        except Exception as e:
            logger.warning(f"Could not generate relationship summary: {e}")

        avg_per_publication = total_relationships / total_records if total_records else 0.0
        logger.info("ETL completed successfully:")
        logger.info(f"  - Total publications: {total_records}")
        logger.info(f"  - Total relationships extracted: {total_relationships}")
        logger.info(f"  - Average relationships per publication: {avg_per_publication:.2f}")

        return _finish(ETLStatus.SUCCESS.value, total_records, new_records)

    except Exception as e:
        logger.error(f"ETL failed with error: {str(e)}", exc_info=True)
        return _finish(ETLStatus.FAILED.value, 0, 0)

if __name__ == "__main__":
    # Driver-side setup. These stay module-level globals on purpose: the
    # functions above read `logger` and `synonym_lookup` as globals.
    spark = SparkSession.builder.getOrCreate()
    logger = setup_logging()

    config = ETLConfig(
        email=PubmedConfig.EMAIL,
        incremental=PubmedConfig.INCREMENTAL,
        batch_size=PubmedConfig.BATCH_SIZE,
        target_count=PubmedConfig.TARGET_COUNT)

    print("\n".join([
        "Configuration:",
        f"Email: {PubmedConfig.EMAIL}",
        f"Incremental: {PubmedConfig.INCREMENTAL}",
        f"Batch Size: {PubmedConfig.BATCH_SIZE}",
        f"Target Count: {PubmedConfig.TARGET_COUNT}",
    ]))

    # Lookup used by modality extraction: {lower-cased term/synonym: canonical term}
    synonym_lookup = build_synonym_lookup(mesh_mapping)
    print(f"Loaded {len(synonym_lookup)} exercise modality synonyms")
    print(f"Loaded {len(outcome_keywords)} outcome keywords")

    # Run the pipeline, then persist its run record.
    metadata = run_pubmed_etl(config, spark)
    save_run_metadata(metadata, spark)

    print("\n".join([
        "\n=== ETL COMPLETED ===",
        f"Status: {metadata.status}",
        f"Total records: {metadata.total_pubs}",
        f"New records: {metadata.new_records}",
        f"Duration: {metadata.duration_seconds:.1f} seconds",
    ]))

    top_confidence_sql = """
        SELECT exercise_modality, relationship_type, outcome, confidence_score, extraction_method
        FROM firedb_exercise_outcome_relationships 
        WHERE confidence_score > 0.8 
        ORDER BY confidence_score DESC 
        LIMIT 20
        """
    method_stats_sql = """
            SELECT extraction_method,
                   COUNT(*) as total_relationships,
                   AVG(confidence_score) as avg_confidence,
                   COUNT(DISTINCT exercise_modality) as unique_modalities,
                   COUNT(DISTINCT outcome) as unique_outcomes
            FROM firedb_exercise_outcome_relationships 
            GROUP BY extraction_method
            ORDER BY total_relationships DESC
        """

    # Best-effort sample reports: (header, query, DataFrame.show kwargs, error label)
    sample_reports = (
        ("\n=== TOP CONFIDENCE RELATIONSHIPS ===", top_confidence_sql, {"truncate": False}, "sample relationships"),
        ("\n=== EXTRACTION METHOD STATISTICS ===", method_stats_sql, {}, "method statistics"),
    )
    for report_header, report_sql, show_kwargs, report_label in sample_reports:
        try:
            report_df = spark.sql(report_sql)
            print(report_header)
            report_df.show(**show_kwargs)
        except Exception as e:
            print(f"Could not display {report_label}: {e}")