In [0]:
# imports
import configparser
from Bio import Entrez, Medline
import pandas as pd
import dateparser
import time
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, concat_ws, regexp_replace, to_date, trim, lit, when, length, udf
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, FloatType
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from functools import partial
import sys

# Configuration bootstrap.
# NOTE(review): `config_parser` is only `.read()` in the __main__ block and none of its
# values are used in the visible code; run settings come from PubmedConfig instead.
config_parser = configparser.ConfigParser()
# Make the repo's config directory importable so `pubmed_config` can be found below.
# Presumably a Databricks Repos checkout path — verify when running elsewhere.
# This append must stay before the import on the next line.
sys.path.append("/Workspace/Repos/firedb/fire-db/config/")
from pubmed_config import PubmedConfig

# configuration & data class definitions
@dataclass
class ETLConfig:
    """Run settings for the PubMed ETL pipeline (consumed by run_pubmed_etl)."""
    # Contact address assigned to Entrez.email before any request is made.
    email: str
    # Assigned to Entrez.max_tries. NOTE(review): fetch_with_retry's own HTTP 429 retry
    # loop is called with its default (5) and does not read this value.
    max_tries: int = 5
    # Assigned to Entrez.sleep_between_tries. NOTE(review): fetch_with_retry's 429
    # back-off uses its own default (20s) and does not read this value.
    sleep_between_tries: int = 20
    # Page size for ESearch PMID paging and chunk size for EFetch record downloads.
    batch_size: int = 1000
    # Upper bound on PMIDs collected; further capped by the query's total hit count.
    target_count: int = 5000
    # True: date window starts at the previous run recorded in script_run_metadata and
    # only unseen DOIs are appended. False: past-year window and the table is overwritten.
    incremental: bool = False

@dataclass
class RunMetadata:
    """Summary of one ETL run, appended to `script_run_metadata` by save_run_metadata."""
    # Identifier used to look up previous runs (create_run_metadata uses "pubmed_ingestion").
    script_name: str
    # Naive local timestamp from datetime.now() taken when the run starts.
    start_time: datetime
    # Naive local timestamp; also written as `last_run_timestamp`, which
    # get_incremental_date_range reads to start the next incremental window.
    end_time: datetime
    # NOTE(review): confirm this reflects the run mode — create_run_metadata does not
    # derive it from the ETLConfig.
    incremental_run: bool
    # Number of records remaining after normalization (0 on failure or when no PMIDs match).
    total_pubs: int
    # Rows written: all normalized rows on overwrite, only unseen DOIs on incremental append.
    new_records: int
    # "SUCCESS" or "FAILED" as produced by run_pubmed_etl.
    status: str
    # (end_time - start_time) in seconds.
    duration_seconds: float

# business logic and function definitions
def get_date_range_past_year() -> Tuple[str, str]:
    """Return the (start, end) dates covering the last twelve months.

    Both dates use the ``YYYY/MM/DD`` form expected by the Entrez date-range query;
    ``end`` is today and ``start`` is the same calendar day twelve months earlier.
    """
    entrez_fmt = "%Y/%m/%d"
    end = datetime.today()
    start = end - relativedelta(months=12)
    return start.strftime(entrez_fmt), end.strftime(entrez_fmt)

def build_mesh_query() -> str:
    """Return the OR-combined MeSH clause selecting exercise / physical-fitness papers.

    Each heading is wrapped as ``"<heading>"[MeSH Terms]`` and the clauses are joined
    with ``" OR "``.
    """
    headings = (
        "Exercise",
        "Physical Conditioning, Human",
        "Resistance Training",
        "Aerobic Exercise",
        "High-Intensity Interval Training",
        "Plyometric Exercise",
        "Endurance Training",
        "Muscle Stretching Exercises",
        "Physical Fitness",
        "Cardiorespiratory Fitness",
        "Exercise Therapy",
        "Walking",
        "Swimming",
        "Gymnastics",
    )
    return " OR ".join(f'"{heading}"[MeSH Terms]' for heading in headings)

def build_search_term(start_date: str, end_date: str) -> str:
    """Combine the MeSH clause with a publication-date range into one PubMed query.

    Args:
        start_date: inclusive lower bound, formatted as Entrez expects (``YYYY/MM/DD``).
        end_date: inclusive upper bound, same format.
    """
    date_clause = f'("{start_date}"[Date - Publication] : "{end_date}"[Date - Publication])'
    return f"({build_mesh_query()}) AND {date_clause}"

def extract_doi(record: Dict) -> Optional[str]:
    """Return the DOI from a MEDLINE record's ``AID`` entries, or None if there is none.

    ``AID`` values look like ``"<identifier> [<type>]"`` (e.g. ``"10.1000/xyz [doi]"``).
    Only the bracketed type tag is inspected. Matching on the whole string would
    misclassify other identifiers (e.g. a ``[pii]``) that happen to contain "doi".
    The tag is compared case-insensitively and the identifier is returned stripped.
    The DOI is used downstream as the de-duplication / matching key.
    """
    aids = record.get("AID")
    if not aids:
        return None
    if isinstance(aids, str):
        # Tolerate a single un-listed value as well as the usual list.
        aids = [aids]

    for aid in aids:
        identifier, _, tag = aid.rpartition("[")
        if tag.rstrip().rstrip("]").strip().lower() != "doi":
            continue
        doi = identifier.strip()
        if doi:
            return doi
    return None

def parse_date_fallback(date_str: Optional[str]) -> Optional[str]:
    """Parse a PubMed publication-date string into ``YYYY-MM-DD``.

    Known fixed formats are tried first with ``datetime.strptime``; ``dateparser`` is
    used only when none of them match. For the fixed formats, missing month/day parts
    default to 1. The dateparser fallback is asked to prefer the first day of the month.
    Either way, the result does not depend on the date the job happens to run.

    Assumes MEDLINE ``DP`` values such as ``"2023 Jan 15"``, ``"2023 Jan"``, ``"2023"``
    or ranges like ``"2023 Jan-Feb"`` — confirm against real records. For a range,
    the start of the range is used.

    Args:
        date_str: raw date text; None or blank yields None.

    Returns:
        The ISO date string, or None if the text cannot be parsed.
    """
    if not date_str or not date_str.strip():
        return None
    text = " ".join(date_str.split())  # collapse irregular whitespace

    candidates = [text]
    if "-" in text:
        # Ranges ("2023 Jan-Feb", "2022 Dec-2023 Jan"): retry with the range start.
        # The full text is tried first so ISO dates like "2023-07-04" are not truncated.
        candidates.append(text.split("-", 1)[0].strip())

    # %b matches English month abbreviations under an English/C locale.
    for candidate in candidates:
        for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%Y %b %d", "%Y %b", "%Y"):
            try:
                return datetime.strptime(candidate, fmt).strftime("%Y-%m-%d")
            except ValueError:
                continue

    parsed = dateparser.parse(text, settings={"PREFER_DAY_OF_MONTH": "first"})
    return parsed.strftime("%Y-%m-%d") if parsed else None

# data processing and transformation functions
def transform_record(record: Dict) -> Dict:
    """Map a raw MEDLINE record onto the flat dict used to build the Spark DataFrame.

    Output keys, in order: pmid, title, abstract, journal, date, doi. Missing source
    fields become None; ``doi`` comes from :func:`extract_doi`.
    """
    # (output key, MEDLINE field tag)
    field_map = (
        ("pmid", "PMID"),
        ("title", "TI"),
        ("abstract", "AB"),
        ("journal", "JT"),
        ("date", "DP"),
    )
    flat = {out_key: record.get(tag) for out_key, tag in field_map}
    flat["doi"] = extract_doi(record)
    return flat

def normalize_dataframe(df):
    """Clean raw PubMed rows and apply basic data-quality filters.

    - Replaces runs of line breaks in ``abstract`` and ``title`` with a single space,
      so words on either side stay separated. The title is then trimmed.
    - Parses the raw ``date`` column into a ``pub_date`` string (``YYYY-MM-DD``) with
      :func:`parse_date_fallback`, then drops ``date``.
    - Keeps only rows with a non-blank ``doi`` and a parsed ``pub_date``.
    - De-duplicates on ``doi``; which of several duplicate rows survives is unspecified.
    """
    parse_date_udf = udf(parse_date_fallback, StringType())
    line_breaks = r"[\r\n]+"

    return (df
            .withColumn("abstract", regexp_replace(col("abstract"), line_breaks, " "))
            # Replace before trimming so a trailing line break doesn't leave a stray space.
            .withColumn("title", trim(regexp_replace(col("title"), line_breaks, " ")))
            .withColumn("pub_date", parse_date_udf(col("date")))
            .drop("date")
            .filter(col("doi").isNotNull()
                    & (trim(col("doi")) != "")
                    & col("pub_date").isNotNull())
            .dropDuplicates(["doi"]))

def filter_new_records(normalized_df, existing_table_name: str, spark):
    """Drop rows whose ``doi`` already exists in ``existing_table_name``.

    Returns ``normalized_df`` unchanged only when the target table does not exist yet
    (e.g. the first incremental run). Any other problem propagates to the caller,
    such as missing permissions or a table without a ``doi`` column. Treating such
    errors as "no existing rows" would make the caller append duplicate records.

    Requires PySpark >= 3.3 for ``Catalog.tableExists``.
    """
    if not spark.catalog.tableExists(existing_table_name):
        print(f"Table {existing_table_name} not found; treating all records as new.")
        return normalized_df

    existing_dois = spark.table(existing_table_name).select("doi")
    # left_anti keeps only rows of normalized_df with no matching doi in the table.
    return normalized_df.join(existing_dois, on="doi", how="left_anti")

# functions to manage interactions with Entrez API (biopython)
def fetch_with_retry(fetch_func, max_tries: int = 5, sleep_time: int = 20):
    """Call ``fetch_func`` and retry it on HTTP 429 (rate-limit) errors.

    Only exceptions whose message contains ``"HTTP Error 429"`` are retried, sleeping
    ``sleep_time`` seconds between attempts. Any other error, or a 429 on the last
    attempt, is re-raised unchanged. If ``max_tries`` is less than 1, ``fetch_func``
    is never called and a generic ``Exception`` is raised.
    """
    attempt = 0
    while attempt < max_tries:
        try:
            return fetch_func()
        except Exception as err:
            rate_limited = "HTTP Error 429" in str(err)
            attempts_left = attempt < max_tries - 1
            if not (rate_limited and attempts_left):
                raise
            print(f"Rate limit exceeded. Retrying after {sleep_time}s... (attempt {attempt + 1})")
            time.sleep(sleep_time)
        attempt += 1
    raise Exception(f"Failed after {max_tries} attempts")

def get_total_count(search_term: str) -> int:
    """Return how many PubMed records match ``search_term``.

    Issues an ESearch with ``retmax=0`` so only the hit count is returned. The call is
    wrapped in :func:`fetch_with_retry` (default settings) to retry HTTP 429 responses.
    """
    def _fetch():
        handle = Entrez.esearch(db="pubmed", term=search_term, retmax=0)
        try:
            return int(Entrez.read(handle)["Count"])
        finally:
            # Close the HTTP handle even if parsing fails, so retries don't leak connections.
            handle.close()

    return fetch_with_retry(_fetch)

def fetch_pmid_batch(search_term: str, start: int, batch_size: int) -> List[str]:
    """Fetch one page of PMIDs for ``search_term``.

    Args:
        search_term: full PubMed query string.
        start: zero-based offset into the result set (ESearch ``retstart``).
        batch_size: maximum number of PMIDs to return (ESearch ``retmax``).

    NOTE(review): NCBI documents a cap on how deep PubMed ESearch can page (about
    10,000 records). Verify before raising ``target_count`` beyond that.
    """
    def _fetch():
        handle = Entrez.esearch(
            db="pubmed",
            term=search_term,
            retmax=batch_size,
            retstart=start
        )
        try:
            return Entrez.read(handle)["IdList"]
        finally:
            # Always release the HTTP handle, including on parse errors / retries.
            handle.close()

    return fetch_with_retry(_fetch)

def fetch_records_batch(pmids: List[str]) -> List[Dict]:
    """Download full MEDLINE records for ``pmids`` and parse them into dicts.

    Uses EFetch with ``rettype="medline"`` / ``retmode="text"`` and
    ``Bio.Medline.parse``. The whole response is parsed before the handle is closed.
    """
    def _fetch():
        handle = Entrez.efetch(db="pubmed", id=pmids, rettype="medline", retmode="text")
        try:
            # list() forces the lazy parser to consume the stream before we close it.
            return list(Medline.parse(handle))
        finally:
            handle.close()

    return fetch_with_retry(_fetch)

# orchestration and transformation functions
def collect_all_pmids(search_term: str, config: ETLConfig) -> List[str]:
    """Page through ESearch results and return up to ``config.target_count`` PMIDs.

    The number collected is ``min(config.target_count, total hits)``. The last page
    requests only the remaining amount, so the target is not overshot. Paging stops
    early if a page comes back empty (e.g. the result set changed between calls).

    Raises:
        ValueError: if ``config.batch_size`` is less than 1.
    """
    if config.batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {config.batch_size}")

    total_count = get_total_count(search_term)
    print(f"Total available records: {total_count}")

    target_count = min(config.target_count, total_count)
    pmid_list: List[str] = []
    start = 0
    while start < target_count:
        page_size = min(config.batch_size, target_count - start)
        batch_pmids = fetch_pmid_batch(search_term, start, page_size)
        if not batch_pmids:
            break
        pmid_list.extend(batch_pmids)
        start += page_size
        print(f"Fetched {len(pmid_list)} PMIDs so far...")

    return pmid_list

def extract_all_records(pmid_list: List[str], config: ETLConfig) -> List[Dict]:
    """Fetch and flatten MEDLINE records for every PMID, ``config.batch_size`` at a time.

    Best-effort: if fetching or transforming a chunk raises, the error is printed and
    that whole chunk is skipped. Records from other chunks are still returned.
    """
    step = config.batch_size
    collected: List[Dict] = []

    for offset in range(0, len(pmid_list), step):
        chunk = pmid_list[offset:offset + step]
        try:
            fetched = fetch_records_batch(chunk)
            # Build the chunk's list first so a failure mid-chunk adds nothing partial.
            collected.extend([transform_record(raw) for raw in fetched])
        except Exception as e:
            print(f"Error processing batch starting at {offset}: {e}")

    return collected

def create_spark_dataframe(records: List[Dict], spark) -> 'DataFrame':
    """Convert transformed record dicts into a PySpark DataFrame with a fixed schema.

    Every column is a nullable string, in the order pmid, title, abstract, journal,
    date, doi. The same explicit schema is used for empty and non-empty input.
    Letting Spark infer types from the rows can fail when a column is None in every
    row (e.g. a batch with no abstracts). Keys missing from a record become None.
    """
    schema = StructType([
        StructField("pmid", StringType(), True),
        StructField("title", StringType(), True),
        StructField("abstract", StringType(), True),
        StructField("journal", StringType(), True),
        StructField("date", StringType(), True),
        StructField("doi", StringType(), True)])

    # Tuples in schema order; an empty list yields an empty DataFrame with this schema.
    rows = [tuple(record.get(field.name) for field in schema.fields) for record in records]
    return spark.createDataFrame(rows, schema)

def save_data(df, table_name: str, mode: str = "overwrite"):
    """Write ``df`` to the Delta table ``table_name`` using save mode ``mode``.

    In ``"overwrite"`` mode the table schema may be replaced too
    (``overwriteSchema=true``); other modes write with the default options.
    """
    writer = df.write.format("delta").mode(mode)
    if mode != "overwrite":
        writer.saveAsTable(table_name)
        return
    writer.option("overwriteSchema", "true").saveAsTable(table_name)

def save_run_metadata(metadata: RunMetadata, spark):
    """Append one row describing ``metadata`` to the ``script_run_metadata`` Delta table.

    ``end_time`` is stored twice: as ``last_run_timestamp`` (used by incremental runs)
    and as ``run_end``. ``incremental_run`` is stored as its string form.
    """
    columns = [
        ("script_name", StringType(), False),
        ("last_run_timestamp", TimestampType(), False),
        ("incremental_run", StringType(), False),
        ("total_pubs", IntegerType(), True),
        ("new_records", IntegerType(), True),
        ("run_status", StringType(), True),
        ("run_start", TimestampType(), True),
        ("run_end", TimestampType(), True),
        ("duration_seconds", FloatType(), True),
    ]
    schema = StructType([StructField(name, dtype, nullable) for name, dtype, nullable in columns])

    record = (
        metadata.script_name,
        metadata.end_time,
        str(metadata.incremental_run),
        metadata.total_pubs,
        metadata.new_records,
        metadata.status,
        metadata.start_time,
        metadata.end_time,
        metadata.duration_seconds,
    )
    run_row_df = spark.createDataFrame([record], schema=schema)
    run_row_df.write.format("delta").mode("append").saveAsTable("script_run_metadata")

# main pipeline/ETL function
def run_pubmed_etl(config: ETLConfig, spark) -> RunMetadata:
    """Run the PubMed extract-transform-load pipeline once and describe the outcome.

    Steps:
    1. Configure Entrez.
    2. Pick the date window (incremental or past year) and build the search term.
    3. Collect PMIDs and download their records.
    4. Normalize the records.
    5. Either append only unseen DOIs (incremental) or overwrite the ``firedb_pubmed``
       table.

    Pipeline errors are not raised. They are printed with a traceback and reported as
    status ``"FAILED"`` with zero counts. The returned metadata's ``incremental_run``
    reflects ``config.incremental``.
    """
    import traceback

    target_table = "firedb_pubmed"
    start_time = datetime.now()

    try:
        Entrez.email = config.email
        Entrez.max_tries = config.max_tries
        Entrez.sleep_between_tries = config.sleep_between_tries

        if config.incremental:
            start_date, end_date = get_incremental_date_range(spark)
        else:
            start_date, end_date = get_date_range_past_year()
        print(f"Date range: {start_date} to {end_date}")

        search_term = build_search_term(start_date, end_date)
        pmid_list = collect_all_pmids(search_term, config)

        if not pmid_list:
            # Nothing matched the window: record a successful no-op run.
            metadata = create_run_metadata("SUCCESS", start_time, 0, 0)
        else:
            records = extract_all_records(pmid_list, config)
            normalized_df = normalize_dataframe(create_spark_dataframe(records, spark))
            # normalized_df does not read the target table, so counting it once
            # before writing gives the same total as counting afterwards.
            total_records = normalized_df.count()

            if config.incremental:
                # Only append DOIs not already in the table, so reruns stay idempotent.
                final_df = filter_new_records(normalized_df, target_table, spark)
                new_records = final_df.count()
                if new_records > 0:
                    save_data(final_df, target_table, "append")
            else:
                new_records = total_records
                save_data(normalized_df, target_table, "overwrite")

            metadata = create_run_metadata("SUCCESS", start_time, total_records, new_records)

    except Exception as e:
        print(f"ETL failed: {e}")
        traceback.print_exc()
        metadata = create_run_metadata("FAILED", start_time, 0, 0)

    # create_run_metadata is not given the run mode, so record it explicitly here.
    metadata.incremental_run = config.incremental
    return metadata

def create_run_metadata(status: str, start_time: datetime, total: int, new: int,
                        incremental: bool = False) -> RunMetadata:
    """Build the RunMetadata record for one ``pubmed_ingestion`` run.

    Args:
        status: run outcome label, e.g. "SUCCESS" or "FAILED".
        start_time: when the run began. ``end_time`` is taken as ``datetime.now()``
            here, and the duration is the difference between the two.
        total: number of records after normalization.
        new: number of records written by this run.
        incremental: whether the run used the incremental date window. Defaults to
            False so existing four-argument callers get the same result as before.
    """
    end_time = datetime.now()
    return RunMetadata(
        script_name="pubmed_ingestion",
        start_time=start_time,
        end_time=end_time,
        incremental_run=incremental,
        total_pubs=total,
        new_records=new,
        status=status,
        duration_seconds=(end_time - start_time).total_seconds())

def get_incremental_date_range(spark) -> Tuple[str, str]:
    """Return the (start, end) Entrez date strings (``YYYY/MM/DD``) for an incremental run.

    The window starts on the day of the most recent *successful* ``pubmed_ingestion``
    run in ``script_run_metadata`` and ends today. Failed runs are ignored, so a
    failure does not move the window past data that was never loaded. Falls back to
    :func:`get_date_range_past_year` when there is no successful run or the metadata
    table cannot be queried.

    NOTE(review): the search filters on ``[Date - Publication]``. Records added to
    PubMed after the last run but carrying an earlier publication date would
    presumably fall outside this window. Consider an entry-date filter (confirm
    against the PubMed search field docs).
    """
    try:
        last_run_rows = spark.sql("""
            SELECT last_run_timestamp
            FROM script_run_metadata
            WHERE script_name = 'pubmed_ingestion'
              AND run_status = 'SUCCESS'
            ORDER BY last_run_timestamp DESC
            LIMIT 1""").collect()
    except Exception as e:
        print(f"Could not read run metadata ({e}); using the past-year date range.")
        return get_date_range_past_year()

    if not last_run_rows:
        return get_date_range_past_year()

    start_date = last_run_rows[0]['last_run_timestamp'].strftime('%Y/%m/%d')
    end_date = datetime.now().strftime('%Y/%m/%d')
    return start_date, end_date

if __name__ == "__main__":
    # Entry point: run the ETL once, log run metadata, and surface failures to the scheduler.
    spark = SparkSession.builder.getOrCreate()

    # NOTE(review): the parsed .pubmedcfg values are not used below; settings come from PubmedConfig.
    config_parser.read('.pubmedcfg')
    config = ETLConfig(
        email=PubmedConfig.EMAIL,
        incremental=PubmedConfig.INCREMENTAL,
        batch_size=PubmedConfig.BATCH_SIZE,
        target_count=PubmedConfig.TARGET_COUNT)

    metadata = run_pubmed_etl(config, spark)
    # Persist the outcome first so failed runs are still recorded.
    save_run_metadata(metadata, spark)

    print(f"ETL completed with status: {metadata.status}")
    print(f"Total records: {metadata.total_pubs}, New records: {metadata.new_records}")

    # run_pubmed_etl reports failures via metadata instead of raising. Raise here,
    # after logging, so a job scheduler sees the run as failed.
    if metadata.status != "SUCCESS":
        raise RuntimeError("PubMed ETL run failed; see the log output above for details.")