In [0]:
# imports
import configparser
from Bio import Entrez, Medline
import pandas as pd
import dateparser
import time
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, concat_ws, regexp_replace, to_date, trim, lit, to_date, when, length, udf, regexp_replace
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, FloatType

@udf(StringType())
def fallback_dateparser(date_str):
    """
    Parses a free-form date string with dateparser and normalizes it.
    Registered as a Spark UDF with a string return type.
    Returns the date formatted as "YYYY-MM-DD", or None when the input is
    empty/None or dateparser cannot interpret it.
    """
    if date_str:
        result = dateparser.parse(date_str)
        if result:
            # normalize to an ISO-style date string
            return result.strftime("%Y-%m-%d")
    return None

def extract_doi(rec):
    """
    Extracts the document identifier (DOI) from the record and returns it as a string.

    The DOI is parsed from the record's "AID" entries, which carry a trailing type
    tag such as "10.1000/xyz [doi]". Only entries whose tag is "[doi]" (matched
    case-insensitively) are accepted, so identifiers of another type that merely
    contain the letters "doi" are not mistaken for a DOI.

    Returns the first DOI found with its tag and surrounding whitespace removed,
    or None when the record has no "AID" field or no DOI entry.
    """
    doi_tag = "[doi]"
    aids = rec.get("AID") or []
    # tolerate a single identifier given as a bare string instead of a list
    if isinstance(aids, str):
        aids = [aids]
    for aid in aids:
        if not isinstance(aid, str):
            continue
        candidate = aid.strip()
        if candidate.lower().endswith(doi_tag):
            doi = candidate[:-len(doi_tag)].rstrip()
            # an entry consisting of the tag alone carries no identifier
            if doi:
                return doi
    return None

def get_pubmed_date_range_past_year():
    """
    Builds the date range covering the twelve months up to today for the pubmed query.
    Returns a (start, end) tuple of strings in the format "YYYY/MM/DD",
    where end is today's date and start is the same day twelve months earlier.
    """
    # PubMed compatible date layout
    pubmed_fmt = "%Y/%m/%d"

    range_end = datetime.today()
    range_start = range_end - relativedelta(months=12)

    return range_start.strftime(pubmed_fmt), range_end.strftime(pubmed_fmt)

# spark config
spark = SparkSession.builder.getOrCreate()

# job metadata metrics
# Every value the run-metadata record reads is initialised here, before the try
# block, so the record can still be built when ingestion fails part-way through.
run_start = datetime.now()
run_status = "SUCCESS"
new_records = 0
total_count = 0

# set batch vs incremental
incremental = False

# Code to connect to AWS S3 (commented out for Databricks Free Edition compatibility)
# In production, S3 access is granted via IAM role attached to the cluster,
# or through Databricks Secrets configured for storage credentials
# spark.conf.set("fs.s3a.aws.credentials.provider", "com.amazonaws.auth.InstanceProfileCredentialsProvider")

# Explicit all-string schema for the raw records. Letting Spark infer the schema
# fails when a column is None in every row (e.g. no record in a small run has a DOI).
raw_record_schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in ("pmid", "title", "abstract", "journal", "date", "doi")])


def _esearch_with_retry(**search_kwargs):
    """
    Runs a pubmed esearch and returns the parsed result.

    A request rejected with HTTP 429 is retried after Entrez.sleep_between_tries
    seconds, up to Entrez.max_tries attempts. Returns None when the retries are
    exhausted or when any other error occurs (the error is displayed, not raised).
    """
    for _ in range(Entrez.max_tries):
        try:
            handle = Entrez.esearch(db="pubmed", **search_kwargs)
            try:
                return Entrez.read(handle)
            finally:
                handle.close()
        except Exception as e:
            if "HTTP Error 429" in str(e):
                print("Rate limit exceeded. Retrying after a delay...")
                time.sleep(Entrez.sleep_between_tries)
            else:
                displayHTML(f"Error fetching PMIDs: {e}")
                return None
    print(f"Rate limit still exceeded after {Entrez.max_tries} attempts; giving up on this request.")
    return None


try:
    # Entrez configuration + set max retry configuration to gracefully handle issues
    Entrez.email = "mikeandersen622@gmail.com"
    Entrez.max_tries = 5
    Entrez.sleep_between_tries = 20

    print(f"Incremental mode: {incremental}")

    # incremental determines date range of search conditions
    if incremental:
        # Only successful runs move the watermark; a failed run must not hide
        # the publications it never ingested from the next incremental run.
        last_run_df = spark.sql("""
                SELECT last_run_timestamp
                FROM script_run_metadata
                WHERE script_name = 'pubmed_ingestion'
                  AND run_status = 'SUCCESS'
                ORDER BY last_run_timestamp DESC
                LIMIT 1""")

        last_run_row = last_run_df.head(1)
        if last_run_row:
            start_date = last_run_row[0]['last_run_timestamp'].strftime('%Y/%m/%d')
            # "3000" is the PubMed convention for an open-ended upper bound
            end_date = "3000"
        else:
            print("No previous run found — defaulting to past year.")
            start_date, end_date = get_pubmed_date_range_past_year()
    else:
        # full load always covers the past year
        start_date, end_date = get_pubmed_date_range_past_year()

    displayHTML(f"start date: {start_date}")
    displayHTML(f"end date: {end_date}")

    # pubmed search condition definition
    # pulls documents tagged with MeSH terms for physical conditioning
    mesh_query = (
        '"Exercise"[MeSH Terms] OR '
        '"Physical Conditioning, Human"[MeSH Terms] OR '
        '"Resistance Training"[MeSH Terms] OR '
        '"Aerobic Exercise"[MeSH Terms] OR '
        '"High-Intensity Interval Training"[MeSH Terms] OR '
        '"Plyometric Exercise"[MeSH Terms] OR '
        '"Endurance Training"[MeSH Terms] OR '
        '"Muscle Stretching Exercises"[MeSH Terms] OR '
        '"Physical Fitness"[MeSH Terms] OR '
        '"Cardiorespiratory Fitness"[MeSH Terms] OR '
        '"Exercise Therapy"[MeSH Terms] OR '
        '"Walking"[MeSH Terms] OR '
        '"Swimming"[MeSH Terms] OR '
        '"Gymnastics"[MeSH Terms]')

    # generate search query for API
    search_term = f'({mesh_query}) AND ("{start_date}"[Date - Publication] : "{end_date}"[Date - Publication])'

    # page size used for both the PMID search and the record fetch
    batch_size = 1000
    target_count = 30000 # only grab up to 30000
    pmid_list = []

    # get total counts of pubs (total_count stays 0 when the request fails)
    count_result = _esearch_with_retry(term=search_term, retmax=0)
    if count_result is not None:
        total_count = int(count_result["Count"])
        displayHTML(f"<b>total available {total_count}</b>")

    # fetch results matching search criteria in batches
    # gets pmids based on defined query and batch size
    if total_count > 0:
        # NOTE(review): the NCBI ESearch documentation states that PubMed only serves
        # the first 10,000 UIDs of a result set, so pages with retstart >= 10000
        # presumably fail — confirm, and split the date range if more are needed.
        for start in range(0, min(target_count, total_count), batch_size):
            batch_result = _esearch_with_retry(
                term=search_term,
                retmax=batch_size,
                retstart=start)
            if batch_result is None:
                # best effort: a failed page is skipped, the remaining pages are still tried
                continue
            pmid_list.extend(batch_result["IdList"])
            displayHTML(f"<b>Fetched {len(pmid_list)} PMIDs so far...</b>")

        # iterate over pmid results and fetch record details
        all_records = []
        for i in range(0, len(pmid_list), batch_size):
            batch_pmids = pmid_list[i:i+batch_size]
            try:
                handle = Entrez.efetch(db="pubmed", id=batch_pmids, rettype="medline", retmode="text")
                try:
                    records = list(Medline.parse(handle))
                finally:
                    handle.close()
                # tuples follow the field order of raw_record_schema
                all_records.extend(
                    (rec.get("PMID"), rec.get("TI"), rec.get("AB"),
                     rec.get("JT"), rec.get("DP"), extract_doi(rec))
                    for rec in records)
            except Entrez.Parser.ValidationError as e:
                print(f"Entrez XML parsing error: {e}.")
            except Exception as e:
                displayHTML(f"Unexpected error fetching records for the PMID batch starting at index {i}: {e}")

        # Abort before any table is touched: a full load would otherwise
        # overwrite firedb_pubmed with an empty result.
        if not all_records:
            raise RuntimeError("No PubMed records could be fetched for the matched PMIDs.")

        # read into pyspark dataframe
        df = spark.createDataFrame(all_records, schema=raw_record_schema)
        displayHTML(f"Records prior to normalization: {len(all_records)}")

        # data quality and normalization:
        #   replace new lines in abstract and title, trim the title
        #   parse the raw date into pub_date ("YYYY-MM-DD") and drop the raw column
        #   remove records with null/blank doi (required for matching later) or unparseable date
        #   remove duplicate records based on doi (considered true duplicate)
        normalized_df = (
            df.withColumn("abstract", regexp_replace(col("abstract"), "\n", " "))
            .withColumn("title", regexp_replace(trim(col("title")), "\n", ""))
            .withColumn("pub_date", fallback_dateparser(col("date"))).drop("date"))

        normalized_df = normalized_df.filter(
            (col("doi").isNotNull()) & (trim(col("doi")) != "") & (col("pub_date").isNotNull()))
        normalized_df = normalized_df.dropDuplicates(["doi"])
        # from here on total_count is the number of records kept after normalization
        total_count = normalized_df.count()

        # if incremental, only append new records
        if incremental:
            loaded_df = spark.table("firedb_pubmed")
            # find records not in existing table
            new_records_df = normalized_df.join(loaded_df.select("doi"), on="doi", how="left_anti")
            new_records = new_records_df.count()
            if new_records > 0:
                new_records_df.write.format("delta").mode("append").saveAsTable("firedb_pubmed")
            else:
                print("No new records to append.")
        # full load mode overwrites entire table
        else:
            normalized_df.write.format("delta").mode("overwrite").option(
                "overwriteSchema", "true").saveAsTable("firedb_pubmed")
    else:
        print("No new records")
        total_count = 0

    # main logic complete, job run successful
    run_status = 'SUCCESS'

except Exception as e:
    # unhandled exception: record the failure so it ends up in the run metadata
    print('error during pubmed ingestion', e)
    run_status = 'FAILED'

# logging in run metadata
# Appends one record describing this run to script_run_metadata, the table the
# incremental mode reads last_run_timestamp from.
run_end = datetime.now()
duration = (run_end - run_start).total_seconds()
schema = StructType([
    StructField("script_name", StringType(), False),
    StructField("last_run_timestamp", TimestampType(), False),
    StructField("incremental_run", StringType(), False),
    StructField("total_pubs", IntegerType(), True),
    StructField("new_records", IntegerType(), True),
    StructField("run_status", StringType(), True),
    StructField("run_start", TimestampType(), True),
    StructField("run_end", TimestampType(), True),
    StructField("duration_seconds", FloatType(), True)])
# run_end is reused for both timestamp columns so the stored values agree with
# the instant duration was computed from (instead of calling datetime.now() again).
run_metadata = [(
    'pubmed_ingestion', run_end, str(incremental),
    int(total_count), int(new_records), run_status, run_start,
    run_end, float(duration))]
# write metadata to delta table
metadata_df = spark.createDataFrame(run_metadata, schema=schema)
metadata_df.write.format("delta").mode("append").saveAsTable("script_run_metadata")
