### This pipeline parses PDFs with GROBID.

**input**: records from the last 3 days in table `openalex.taxicab.taxicab_results` whose content type is PDF and that do not yet have a GROBID result

**process**: the GROBID API running on ECS

**output**: GROBID XML output, along with the source URLs and IDs, appended to table `openalex.pdf.grobid_processing_results`

In [0]:
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
import time
import random
from urllib3.util import Retry
from requests.adapters import HTTPAdapter
from datetime import datetime, timedelta
from time import sleep
import re
import requests
from requests.exceptions import Timeout
import json
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, BooleanType, TimestampType

In [0]:
# Endpoint of the GROBID parse service (behind an AWS load balancer).
GROBID_URL = "http://grobid-api-load-balancer-1880850154.us-east-1.elb.amazonaws.com/parse"

# BATCH_SIZE * 10 results are buffered before each write to Delta.
BATCH_SIZE = 100
# Number of threads issuing GROBID requests concurrently.
MAX_WORKERS = 70

# Transport-level retry policy applied inside every request made through
# http_session; process_pdf_single runs its own retry loop on top of this.
retry_strategy = Retry(
    total=3,
    backoff_factor=1,
    status_forcelist=[429, 502, 503, 504],
    allowed_methods=["GET", "POST"],
)

# Pool sized from the worker count so concurrent threads can keep and reuse
# their connections (pool_maxsize is twice MAX_WORKERS).
adapter = HTTPAdapter(
    pool_connections=MAX_WORKERS,
    pool_maxsize=MAX_WORKERS * 2,
    max_retries=retry_strategy,
)

# One session shared by all worker threads, with the adapter on both schemes.
http_session = requests.Session()
for _scheme in ("http://", "https://"):
    http_session.mount(_scheme, adapter)
del _scheme

In [0]:
# Row layout of openalex.pdf.grobid_processing_results (keep in sync with the
# CREATE TABLE statement in write_results_to_table). Every column is nullable;
# all are strings except the trailing created_date timestamp. Order matters:
# result dicts are turned into DataFrames against this schema.
grobid_results_schema = StructType(
    [
        StructField(column_name, StringType(), True)
        for column_name in (
            "id",
            "status",
            "source_pdf_id",
            "url",
            "native_id",
            "native_id_namespace",
            "s3_key",
            "s3_path",
            "xml_content",
            "error_message",
        )
    ]
    + [StructField("created_date", TimestampType(), True)]
)

In [0]:
def create_error_response(pdf_uuid, url, native_id, native_id_namespace, error_message):
    """Build the result row recorded for a PDF that could not be processed.

    Args:
        pdf_uuid: id of the source PDF, stored as ``source_pdf_id``.
        url: URL the PDF was fetched from.
        native_id: identifier passed through from the input record.
        native_id_namespace: namespace of ``native_id``.
        error_message: human-readable description of the failure.

    Returns:
        A dict with the same keys, in the same order, as the results schema.
        ``status`` is "failed"; ``id``, ``s3_key``, ``s3_path`` and
        ``xml_content`` are None; ``created_date`` is the current local time.
    """
    # Keyword order below fixes the dict's key order, which must match the
    # column order of the results schema.
    failure_row = dict(
        id=None,
        status="failed",
        source_pdf_id=pdf_uuid,
        url=url,
        native_id=native_id,
        native_id_namespace=native_id_namespace,
        s3_key=None,
        s3_path=None,
        xml_content=None,
        error_message=error_message,
    )
    failure_row["created_date"] = datetime.now()
    return failure_row

def _http_error_message(response):
    """Format a failed GROBID HTTP response (status >= 400) as an error string.

    Uses the body's ``error`` field when the body is a JSON object carrying a
    non-null one; otherwise falls back to the first 200 characters of the raw
    response text.
    """
    try:
        payload = response.json()
    except ValueError:
        # Body is not JSON (requests' JSON decode errors subclass ValueError).
        payload = None
    if isinstance(payload, dict) and payload.get("error") is not None:
        return f"HTTP error {response.status_code}: {payload['error']}"
    return f"HTTP error {response.status_code}: {response.text[:200]}"


def process_pdf_single(row_data):
    """Send one PDF to the GROBID service and return a result row.

    Args:
        row_data: mapping with ``source_pdf_id``, ``url``, ``native_id`` and
            ``native_id_namespace`` keys. All four are forwarded to GROBID
            (``source_pdf_id`` is sent as ``pdf_uuid``).

    Returns:
        A dict with the same keys, in the same order, as the results schema.
        On success ``status`` is taken from the service (default "success");
        on any failure the row comes from ``create_error_response``.
        Never returns None.

    Retry behaviour: 429/503/504 responses, connection errors and other
    exceptions are retried up to ``max_retries`` times with exponential
    backoff plus jitter. Any other response with status >= 400 is recorded
    as a failure immediately. ``http_session`` also carries its own
    transport-level retry policy, which runs inside every ``post`` call on
    top of this loop.
    """
    pdf_uuid = row_data.get('source_pdf_id')
    url = row_data.get('url')
    native_id = row_data.get('native_id')
    native_id_namespace = row_data.get('native_id_namespace')
    max_retries = 2

    print(f"Processing PDF {pdf_uuid} from {url}")
    start_time = time.time()

    data = {
        "pdf_uuid": pdf_uuid,
        "url": url,
        "native_id": native_id,
        "native_id_namespace": native_id_namespace,
    }

    retry_count = 0
    last_error = None
    while retry_count <= max_retries:
        try:
            if retry_count > 0:
                # Small extra jitter so retrying threads don't hit the service in lockstep.
                jitter = random.uniform(0.1, 1.0)
                time.sleep(jitter)

            response = http_session.post(
                GROBID_URL,
                json=data,
                timeout=(30, 120),  # (connect, read) seconds
            )

            if response.status_code in (429, 503, 504):
                retry_count += 1
                last_error = _http_error_message(response)
                if retry_count > max_retries:
                    # Out of attempts: stop here and emit a failure row below
                    # instead of sleeping again for nothing.
                    break
                # exponential backoff with jitter
                wait_time = min(2 ** retry_count + random.uniform(0, 1), 60)
                print(f"Service error {response.status_code} for {url}, retrying in {wait_time:.2f} seconds... ({retry_count}/{max_retries})")
                time.sleep(wait_time)
                continue

            if response.status_code >= 400:
                return create_error_response(
                    pdf_uuid, url, native_id, native_id_namespace,
                    _http_error_message(response),
                )

            result = response.json()

            processing_time = time.time() - start_time
            print(f"GROBID request for {pdf_uuid} took {processing_time:.2f} seconds")

            # successful response
            return {
                "id": result.get("id"),
                "status": result.get("status") or "success",
                # Fall back to the id we sent: the job's anti-join dedup keys on
                # this column, so a missing value would cause repeated reprocessing.
                "source_pdf_id": result.get("source_pdf_id") or pdf_uuid,
                "url": url,
                "native_id": native_id,
                "native_id_namespace": native_id_namespace,
                "s3_key": result.get("s3_key"),
                "s3_path": result.get("s3_path"),
                "xml_content": result.get("xml_content"),
                "error_message": None,
                "created_date": datetime.now(),
            }

        except requests.exceptions.ConnectionError as e:
            retry_count += 1
            error_msg = f"Connection error: {str(e)}"

            # If exceeded retries, return error
            if retry_count > max_retries:
                print(f"Failed to connect for {url} after {max_retries} retries: {error_msg}")
                return create_error_response(
                    pdf_uuid, url, native_id, native_id_namespace,
                    error_msg,
                )

            # longer wait for connection errors
            wait_time = min(5 ** retry_count + random.uniform(0, 2), 120)
            print(f"Connection error for {url}, retrying in {wait_time:.2f} seconds... ({retry_count}/{max_retries})")
            time.sleep(wait_time)

        except Exception as e:
            retry_count += 1
            error_msg = str(e)

            # if exceeded retries, return error
            if retry_count > max_retries:
                print(f"Error processing {url}: {error_msg}")
                return create_error_response(
                    pdf_uuid, url, native_id, native_id_namespace,
                    error_msg,
                )

            # wait before retrying with jitter
            wait_time = min(2 ** retry_count + random.uniform(0, 1), 60)
            print(f"Error: {error_msg}, retrying in {wait_time:.2f} seconds... ({retry_count}/{max_retries})")
            time.sleep(wait_time)

    # Reached only when a retryable status (429/503/504) persisted through the
    # last attempt. Return an explicit failure row rather than falling off the
    # end (None), so callers can always build a DataFrame from the results.
    print(f"Giving up on {url} after {max_retries} retries: {last_error}")
    return create_error_response(
        pdf_uuid, url, native_id, native_id_namespace,
        last_error or "Retries exhausted",
    )

In [0]:
def _failed_result_row(row, status, error_message):
    """Build a result row (schema key order) for a task that yielded no usable result."""
    return {
        "id": None,
        "status": status,
        "source_pdf_id": row.get('source_pdf_id'),
        "url": row.get('url'),
        "native_id": row.get('native_id'),
        "native_id_namespace": row.get('native_id_namespace'),
        "s3_key": None,
        "s3_path": None,
        "xml_content": None,
        "error_message": error_message,
        "created_date": datetime.now(),
    }


def _write_result_batch(results):
    """Convert buffered result dicts to a Spark DataFrame and append them to the results table."""
    result_df = pd.DataFrame(results)
    spark_df = spark.createDataFrame(result_df, schema=grobid_results_schema)
    write_results_to_table(spark_df)


def process_pdfs_with_continuous_batching(candidates_df, write_size=BATCH_SIZE*10):
    """Run every candidate PDF through GROBID concurrently and append results in batches.

    All candidates are collected to the driver and submitted to a pool of
    ``MAX_WORKERS`` threads. Results are handled in completion order, so one
    slow PDF does not hold back writing results that have already finished.
    The call still returns only after every submitted task has finished; how
    long a single straggler can take is bounded by the HTTP timeouts and
    retry limits in ``process_pdf_single``, not by this function.

    Args:
        candidates_df: Spark DataFrame with ``source_pdf_id``, ``url``,
            ``native_id`` and ``native_id_namespace`` columns.
        write_size: flush buffered results to the table once at least this
            many are accumulated.

    Returns:
        Number of tasks whose future returned a result (including rows whose
        ``status`` records a GROBID failure). Tasks that raised are written as
        "timeout"/"error" rows but are not counted.

    A failed intermediate write is logged and retried at the next flush
    (the rows stay buffered); the final write is allowed to raise.
    """
    # Collect candidates to the driver as plain dicts for the worker threads.
    pdf_data = candidates_df.toPandas()
    columns = ['source_pdf_id', 'url', 'native_id', 'native_id_namespace']
    rows = pdf_data[columns].to_dict("records")

    print(f"Processing {len(rows)} records with continuous batching")
    print(f"Will write to database in batches of {write_size} records")

    completed_count = 0
    accumulated_results = []

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit everything up front; the pool bounds actual concurrency.
        future_to_row = {executor.submit(process_pdf_single, row): row for row in rows}

        for future in as_completed(future_to_row):
            row = future_to_row[future]
            try:
                # as_completed only yields finished futures, so this never blocks.
                result = future.result()
            except TimeoutError:
                # Only reachable if the task itself raised a TimeoutError.
                print(f"Task timed out for PDF: {row.get('url')}")
                accumulated_results.append(_failed_result_row(row, "timeout", "Task timed out"))
            except Exception as e:
                print(f"Error processing PDF {row.get('url')}: {str(e)}")
                accumulated_results.append(_failed_result_row(row, "error", str(e)))
            else:
                if result is None:
                    # Defensive: a None row would break DataFrame construction.
                    result = _failed_result_row(row, "error", "No result returned")
                accumulated_results.append(result)
                completed_count += 1

                if completed_count % 50 == 0:
                    print(f"Completed {completed_count}/{len(rows)} records")

            # Flush outside the per-task try so a write failure is never
            # mistaken for (and recorded as) a PDF processing failure.
            if len(accumulated_results) >= write_size:
                print(f"Writing batch of {len(accumulated_results)} results to database")
                try:
                    _write_result_batch(accumulated_results)
                except Exception as e:
                    print(f"Failed to write batch of {len(accumulated_results)} results, will retry on next flush: {e}")
                else:
                    accumulated_results = []

        # Write any remaining results (including any batch whose write failed above).
        if accumulated_results:
            print(f"Writing final batch of {len(accumulated_results)} results to database")
            _write_result_batch(accumulated_results)

    print(f"Successfully processed {completed_count} records")
    return completed_count

In [0]:
def write_results_to_table(results_df):
    """Append result rows to the Delta table ``openalex.pdf.grobid_processing_results``.

    Creates the table first if it does not exist. Only logs and returns when
    ``results_df`` is None or has no rows.

    Args:
        results_df: Spark DataFrame whose columns match the table definition
            below, or None.
    """
    if results_df is None:
        print("No results to write")
        return

    # Count once up front: used for both the empty check and the final log
    # line, instead of re-evaluating the DataFrame after the write.
    row_count = results_df.count()
    if row_count == 0:
        print("No results to write")
        return

    # create the table if it doesn't exist
    spark.sql("""
    CREATE TABLE IF NOT EXISTS openalex.pdf.grobid_processing_results (
        id STRING,
        status STRING,
        source_pdf_id STRING,
        url STRING,
        native_id STRING,
        native_id_namespace STRING,
        s3_key STRING,
        s3_path STRING,
        xml_content STRING,
        error_message STRING,
        created_date TIMESTAMP
    )
    USING DELTA
    """)

    results_df.write.format("delta").mode("append").saveAsTable("openalex.pdf.grobid_processing_results")

    print(f"Successfully appended {row_count} records to grobid_processing_results")

In [0]:
print("Starting GROBID PDF processing job...")

# Only taxicab records processed within this window are considered.
three_days_ago = datetime.now() - timedelta(days=3)

# get existing IDs from grobid_processing_results (used below to skip PDFs
# that already have a result)
try:
    existing_ids_df = spark.table("openalex.pdf.grobid_processing_results")
    existing_ids = existing_ids_df.select("source_pdf_id").distinct()
except Exception as e:
    # Expected on the first run, when the table does not exist yet; it is
    # created by write_results_to_table on the first write.
    print(f"No existing grobid_processing_results table found ({type(e).__name__}); it will be created on first write.")
    existing_ids = spark.createDataFrame([], schema=StructType([StructField("source_pdf_id", StringType(), True)]))

# get candidate records from taxicab_results
taxicab_df = spark.table("openalex.taxicab.taxicab_results") \
    .filter(
        (F.col("taxicab_id").isNotNull()) &
        (F.col("content_type").contains("pdf")) &
        (F.col("processed_date") >= F.lit(three_days_ago))
    )

# filter out records that already exist in grobid_processing_results
candidates_df = taxicab_df \
    .join(existing_ids, taxicab_df["taxicab_id"] == existing_ids["source_pdf_id"], "left_anti") \
    .select(
        F.col("taxicab_id").alias("source_pdf_id"),
        "url",
        "native_id",
        "native_id_namespace"
    )

# Cached because it is both counted here and collected during processing.
candidates_df = candidates_df.cache()

# get the count of records to process
total_records = candidates_df.count()
print(f"Found {total_records} records to process")

In [0]:
if total_records > 0:
    write_batch_size = BATCH_SIZE * 10

    print(f"Using continuous batching method with {MAX_WORKERS} workers")
    print(f"Will write to database in batches of {write_batch_size} records")

    # Process with continuous batching; release the cached candidates even
    # if processing raises.
    try:
        processed_count = process_pdfs_with_continuous_batching(
            candidates_df,
            write_size=write_batch_size
        )
    finally:
        candidates_df.unpersist()

    print(f"Successfully processed {processed_count} of {total_records} records")
    print("GROBID PDF processing job completed successfully")

else:
    # Nothing to do, but still release the cache taken in the previous cell.
    candidates_df.unpersist()
    print("No records to process. Exiting.")