In [0]:
%pip uninstall openalex_taxicab -y
%pip uninstall openalex_http -y
%pip install git+https://github.com/ourresearch/openalex-taxicab

In [0]:
import sys
import boto3
# import dlt
import os
import pyspark.sql.functions as F
from pyspark.sql.types import MapType, StringType, StructType, StructField, BooleanType, TimestampType, IntegerType, FloatType, ArrayType
from dataclasses import asdict, is_dataclass
from inspect import getmembers
from openalex_http import http_cache
import sqlite3
from typing import List, Dict, Any
from botocore.config import Config
import time
from tenacity import RetryError
import logging
from threading import Thread, Lock
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
from multiprocessing import Manager, Value, Lock
import warnings
import signal
from functools import partial
from urllib3.exceptions import InsecureRequestWarning

In [0]:

# Quiet the openalex_http library logger: only CRITICAL records get through.
logger = logging.getLogger(name='openalex_http')
logger.setLevel(level=logging.CRITICAL)

# Hide urllib3's InsecureRequestWarning (emitted for unverified HTTPS requests);
# every other warning category is left untouched.
warnings.filterwarnings('ignore', category=InsecureRequestWarning)

In [0]:
# Load credentials and settings from the Databricks secret scope.
scope_name = "openalex-elt"
secrets = dbutils.secrets.list(scope=scope_name)
AWS_ACCESS_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_SECRET_ACCESS_KEY')

# Mirror every secret in the scope into the process environment; the same mapping is
# kept in `env` because worker_process passes it to http_cache.initialize().
env = {
    secret_meta.key: dbutils.secrets.get(scope=scope_name, key=secret_meta.key)
    for secret_meta in secrets
}
os.environ.update(env)

In [0]:
from openalex_taxicab.harvest import HarvestResult
from openalex_taxicab.legacy.harvest import PublisherLandingPageHarvester

In [0]:
# botocore settings shared by the per-process S3 clients created in worker_process.
# The AWS credentials themselves (AWS_ACCESS_KEY / AWS_SECRET_KEY) come from the
# secrets cell above.
S3_CONFIG = Config(
    connect_timeout=10,        # seconds allowed to establish a connection
    read_timeout=20,           # seconds allowed waiting on a response
    signature_version='s3v4',
)


In [0]:
# Parallelism: NUM_PROCESSES worker processes, each with a pool of
# NUM_THREADS_PER_PROCESS threads (up to 28 * 50 = 1,400 harvests in flight).
NUM_PROCESSES, NUM_THREADS_PER_PROCESS = 28, 50

In [0]:
# Intended upper bound (seconds) for a single DOI harvest.
# NOTE(review): nothing in this notebook enforces it per DOI. Futures are only inspected
# after as_completed() has yielded them, i.e. after they have already finished. A real
# limit has to come from the harvester's own HTTP timeouts.
HARVEST_TIMEOUT = 60

# Keys of HarvestResult.to_dict() that are not part of the output row.
_HARVEST_KEYS_TO_DROP = ('content', 'last_harvested_dt')


def harvest_single_doi(doi, plp_harvester: "PublisherLandingPageHarvester") -> Dict[str, Any]:
    """Harvest one DOI and return a flat result dict with an ``error`` field.

    On success the dict is ``plp_harvester.harvest(doi).to_dict()`` with
    ``last_harvested_dt`` copied to ``last_harvested``. On any ``Exception`` it is the
    ``to_dict()`` of an empty ``HarvestResult``. Both paths are then reshaped the same
    way: ``content`` and ``last_harvested_dt`` are removed.

    ``error`` is ``None`` on success and ``str(exception)`` on failure.

    A successful result missing ``last_harvested_dt`` raises KeyError inside the try.
    It is therefore reported as an error row.
    """
    try:
        d = plp_harvester.harvest(doi).to_dict()
        d['last_harvested'] = d['last_harvested_dt']
        error = None
    except Exception as e:
        d = HarvestResult(s3_path=None, last_harvested=None, content=None, url=None).to_dict()
        # Give failure rows the same `last_harvested` key that success rows get.
        d['last_harvested'] = d.get('last_harvested_dt')
        error = str(e)

    for key in _HARVEST_KEYS_TO_DROP:
        d.pop(key, None)
    d['error'] = error
    return d

def worker_process(dois, aws_access_key, aws_secret_key, s3_config, num_threads, progress_dict):
    """Harvest one chunk of DOIs inside a worker process using a thread pool.

    Builds its own S3 client and PublisherLandingPageHarvester. It then initializes
    ``http_cache`` from the notebook-global ``env``, which the child process presumably
    inherits via the 'fork' start method (TODO confirm). The DOIs are fanned out over
    ``num_threads`` threads running ``harvest_single_doi``.

    Args:
        dois: DOIs to harvest in this process.
        aws_access_key: access key for the boto3 S3 client.
        aws_secret_key: secret key for the boto3 S3 client.
        s3_config: botocore ``Config`` for the S3 client.
        num_threads: size of the thread pool.
        progress_dict: Manager dict shared with the parent. It holds the 'processed',
            'errors' and 'last_error' counters, plus a Manager 'lock' that guards
            updates to them.

    Returns:
        List of ``(doi, result_dict)`` tuples in completion order.
    """
    s3 = boto3.client('s3',
                      aws_access_key_id=aws_access_key,
                      aws_secret_access_key=aws_secret_key,
                      config=s3_config)
    http_cache.initialize(env)
    plp_harvester = PublisherLandingPageHarvester(s3)

    # Look the lock up once. Each progress_dict['lock'] access returns a fresh manager
    # proxy, which costs extra round-trips to the manager process per DOI.
    progress_lock = progress_dict['lock']

    results = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        future_to_doi = {executor.submit(harvest_single_doi, doi, plp_harvester): doi for doi in dois}
        # as_completed() only yields futures that have already finished. result() therefore
        # returns immediately, and a timeout passed to it could never fire.
        for future in as_completed(future_to_doi):
            doi = future_to_doi[future]
            try:
                result = future.result()
            except Exception as e:
                # harvest_single_doi already turns harvest failures into error rows. This
                # should only fire if building that fallback row itself failed.
                result = HarvestResult(s3_path=None, last_harvested=None, content=None, url=None).to_dict()
                result['error'] = f"Unexpected error: {str(e)}"

            # Read-modify-write on the shared counters must be serialized across processes.
            with progress_lock:
                progress_dict['processed'] += 1
                if result['error'] is not None:
                    progress_dict['errors'] += 1
                    progress_dict['last_error'] = f"DOI {doi}: {result['error']}"

            results.append((doi, result))

    return results

def harvest(dois: List[str], num_processes: int, num_threads: int):
    """Harvest every DOI in ``dois`` with a pool of processes, each running a thread pool.

    ``dois`` is split into at most ``num_processes`` non-empty chunks of near-equal
    size, and each chunk is handed to ``worker_process`` in its own process. Workers
    report progress through a ``multiprocessing.Manager`` dict. A background thread
    prints a progress line every 5 seconds: count, DOIs/hour, error %, elapsed time
    and last error.

    Args:
        dois: DOIs to harvest.
        num_processes: maximum number of worker processes (must be >= 1).
        num_threads: thread-pool size inside each worker process.

    Returns:
        Dict mapping each DOI to the result dict produced by ``worker_process``.
        A DOI listed more than once appears once. Empty input returns ``{}``.

    Raises:
        ValueError: if ``num_processes`` < 1.
        Any exception raised by a worker is re-raised here. The progress thread is
        stopped before the exception propagates.
    """
    from threading import Event  # the notebook's import cell only brings in Thread/Lock

    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    total_dois = len(dois)
    if total_dois == 0:
        return {}

    # At most num_processes non-empty chunks whose sizes differ by at most one.
    num_chunks = min(num_processes, total_dois)
    doi_chunks = [dois[i::num_chunks] for i in range(num_chunks)]

    # The context manager shuts the manager's server process down when we are done.
    with Manager() as manager:
        progress_dict = manager.dict()
        progress_dict['processed'] = 0
        progress_dict['errors'] = 0
        progress_dict['last_error'] = ''
        progress_dict['lock'] = manager.Lock()

        start = time.monotonic()
        stop_logging = Event()

        def log_progress():
            # Runs until every DOI has been counted or stop_logging is set. The main
            # thread always sets it once the pool exits, including when a worker raised.
            while not stop_logging.is_set() and progress_dict['processed'] < total_dois:
                current_processed = progress_dict['processed']
                current_errors = progress_dict['errors']
                elapsed_time = time.monotonic() - start
                speed = current_processed / (elapsed_time / 3600) if elapsed_time > 0 else 0
                error_pct = (current_errors / current_processed) * 100 if current_processed > 0 else 0
                last_error = progress_dict['last_error']

                print(f"Processed: {current_processed}/{total_dois} | "
                      f"Speed: {speed:.2f} DOIs/hour | "
                      f"Errors: {error_pct:.2f}% | "
                      f"Elapsed: {elapsed_time:.2f} seconds | "
                      f"Last Error: {last_error}", flush=True)

                # Acts like time.sleep(5) but wakes immediately once stop_logging is set.
                stop_logging.wait(5)

        # daemon=True: the progress thread can never keep the interpreter alive.
        log_thread = Thread(target=log_progress, daemon=True)
        log_thread.start()

        worker_func = partial(worker_process,
                              aws_access_key=AWS_ACCESS_KEY,
                              aws_secret_key=AWS_SECRET_KEY,
                              s3_config=S3_CONFIG,
                              num_threads=num_threads,
                              progress_dict=progress_dict)

        try:
            with ProcessPoolExecutor(max_workers=num_chunks) as executor:
                all_results = list(executor.map(worker_func, doi_chunks))
        finally:
            # Stop and join the logger while the manager (and progress_dict) still exists.
            stop_logging.set()
            log_thread.join()

    # Flatten the per-chunk [(doi, result), ...] lists into one dict.
    return {doi: result for chunk in all_results for doi, result in chunk}

# Spark schema of the rows appended to harvest.harvest_results; every field is nullable.
# NOTE(review): there is no `error` field. The per-DOI `error` value in each result dict
# is therefore dropped when the rows become a DataFrame. TODO confirm that is intended.
harvest_result_schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("doi", StringType()),
        ("s3_path", StringType()),
        ("last_harvested", TimestampType()),
        ("url", StringType()),
        ("code", IntegerType()),      # Optional field
        ("elapsed", FloatType()),     # Optional field
        ("resolved_url", StringType()),
        ("content_type", StringType()),
        ("is_soft_block", BooleanType()),
    )
])

In [0]:
# Candidate DOIs: anything in Crossref that is not yet in harvest.harvest_results.
crossref_df = (
    spark.read
    .format("delta")
    .table("crossref.crossref_works")
)
crossref_df.display()
# NOTE(review): this caches every column, but only `doi` is used downstream.
# Caching crossref_df.select("doi") would be much lighter.
crossref_df.cache().count()


In [0]:
# Everything harvested so far; the anti-join below skips DOIs already present here.
harvested_df = spark.read.format("delta").table("harvest.harvest_results")
# NOTE(review): harvested_df_filtered / harvested_df_with_doi are not used below. The
# anti-join keys on harvested_df's own `doi` column. Failed harvests are appended with
# their `doi` (and presumably a null `url`), so they are never retried. Joining on
# harvested_df_with_doi instead would retry them. TODO confirm which is intended.
harvested_df_filtered = harvested_df.filter(F.col("url").startswith("https://doi.org/"))
harvested_df_with_doi = harvested_df_filtered.withColumn("doi", F.regexp_replace(F.col("url"), "^https://doi.org/",""))
unharvested_df = crossref_df.join(harvested_df, ["doi"], "left_anti").select("doi")
# Uncomment to harvest a random sample instead of every pending DOI.
# unharvested_df = unharvested_df.orderBy(F.rand()).limit(100000)

# Cache and count before display() so the anti-join is computed once. Displaying first
# would compute it, and the later count would compute it again.
unharvested_df.cache()
unharvested_count = unharvested_df.count()
unharvested_df.display()
unharvested_count


In [0]:
# Pull the pending DOIs to the driver and harvest them across worker processes.
doi_list = [r['doi'] for r in unharvested_df.collect()]
results = harvest(doi_list, NUM_PROCESSES, NUM_THREADS_PER_PROCESS)

# One flat row per DOI: the DOI plus every field of its result dict.
# Keys outside harvest_result_schema (e.g. `error`) are not written.
harvested_data = [{'doi': doi, **result} for doi, result in results.items()]

harvest_result_df = spark.createDataFrame(harvested_data, schema=harvest_result_schema)

In [0]:
# Preview the harvested rows, then cache and materialize them so the Delta write in
# the next cell can reuse the cached data.
harvest_result_df.display()
harvest_result_df.cache().count()

In [0]:
# Persist this run's results. Rows are appended; existing rows are never overwritten.
(
    harvest_result_df.write
    .format("delta")
    .mode("append")
    .saveAsTable("harvest.harvest_results")
)