In [0]:
import os
import sqlite3
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, is_dataclass
from inspect import getmembers
from typing import List, Optional

import boto3
# import dlt
import pyspark.sql.functions as F
from botocore.config import Config
from pyspark.sql.types import MapType, StringType, StructType, StructField, BooleanType, TimestampType, IntegerType, FloatType, ArrayType

from openalex_http import http_cache

In [0]:
# Load every secret in the Databricks scope into `env` and export it to the
# process environment; `env` is also handed to http_cache.initialize in harvest().
scope_name = "openalex-elt"
secrets = dbutils.secrets.list(scope=scope_name)
AWS_ACCESS_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_SECRET_ACCESS_KEY')

env = {
    secret.key: dbutils.secrets.get(scope=scope_name, key=secret.key)
    for secret in secrets
}
os.environ.update(env)

In [0]:
from openalex_taxicab.harvest import HarvestResult
from openalex_taxicab.legacy.harvest import PublisherLandingPageHarvester

In [0]:
# botocore client config for every S3 client built in harvest(): timeouts are in
# seconds. AWS_ACCESS_KEY / AWS_SECRET_KEY come from the env-setup cell above.
S3_CONFIG = Config(
    connect_timeout=10,
    read_timeout=20,
)


In [0]:
BATCH_SIZE = 30
PARTITIONS = 180

In [0]:
def harvest_single_doi(doi, plp_harvester: PublisherLandingPageHarvester) -> Optional[dict]:
    """Harvest one DOI's publisher landing page and flatten the result for Spark.

    Args:
        doi: DOI passed unchanged to ``plp_harvester.harvest``.
        plp_harvester: harvester shared by all threads of the current batch.

    Returns:
        ``result.to_dict()`` with ``last_harvested_dt`` renamed to
        ``last_harvested``, ``content`` removed, and ``error`` set to ``None``;
        or ``None`` if harvesting or the post-processing raised.
        NOTE(review): ``error`` is not a field of ``harvest_result_schema``, so it
        is presumably not persisted; add it to the schema if errors should be kept.
    """
    try:
        result: HarvestResult = plp_harvester.harvest(doi)
        d = result.to_dict()
        d['error'] = None
        # Rename to the schema's `last_harvested` field. A missing key raises
        # KeyError, which is handled below like any other harvest failure.
        d['last_harvested'] = d.pop('last_harvested_dt')
        # Raw page content is not stored in the results table.
        d.pop('content', None)
        return d
    except Exception as e:
        # Best-effort: log and return None so one bad DOI does not fail the batch.
        # Nothing that can raise is done here, so the handler cannot itself
        # propagate an exception to future.result() in harvest().
        print(f'Error harvesting {doi}: {e}')
        return None


def harvest(dois: List[str]):
    """Harvest a batch of DOIs concurrently; this is the function wrapped by harvest_udf.

    A fresh S3 client, HTTP-cache initialization and harvester are built on every
    call (presumably because the UDF runs on executors where driver-side objects
    are unavailable — confirm).

    Args:
        dois: DOIs in one batch (one ``doi_list`` row).

    Returns:
        Dict mapping each DOI to ``harvest_single_doi``'s value (a result dict, or
        ``None`` on failure). An empty or ``None`` batch returns ``{}``.
    """
    # ThreadPoolExecutor raises ValueError for max_workers <= 0, and there is
    # nothing to harvest anyway, so skip client setup for empty batches.
    if not dois:
        return {}

    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY,
        aws_secret_access_key=AWS_SECRET_KEY,
        config=S3_CONFIG,
    )
    http_cache.initialize(env)
    plp_harvester = PublisherLandingPageHarvester(s3)

    results = {}
    # One thread per DOI: BATCH_SIZE bounds the concurrency.
    with ThreadPoolExecutor(max_workers=len(dois)) as executor:
        future_to_doi = {executor.submit(harvest_single_doi, doi, plp_harvester): doi for doi in dois}
        for future in as_completed(future_to_doi):
            results[future_to_doi[future]] = future.result()
    return results

# Struct for one DOI's harvest result; every field is nullable. Keys of the dicts
# returned by harvest_single_doi are matched to these field names by Spark.
harvest_result_schema = (
    StructType()
    .add("s3_path", StringType(), True)
    .add("last_harvested", TimestampType(), True)
    .add("url", StringType(), True)
    .add("code", IntegerType(), True)        # optional
    .add("elapsed", FloatType(), True)       # optional
    .add("resolved_url", StringType(), True)
    .add("content_type", StringType(), True)
    .add("is_soft_block", BooleanType(), True)
)

# harvest_udf: list of DOIs -> map<doi, harvest_result_schema>.
harvest_udf = F.udf(harvest, MapType(StringType(), harvest_result_schema))

In [0]:
# Candidate DOIs come from the Crossref works Delta table; cache + count()
# materializes it so later cells reuse it instead of re-reading the table.
crossref_df = (
    spark.read
    .format("delta")
    .table("crossref.crossref_works")
)
crossref_df.display()
crossref_df.cache().count()


In [0]:
harvested_df = spark.read.format("delta").table("harvest.harvest_results")

# NOTE(review): the two doi.org-derived frames below are never used — the
# anti-join keys on harvested_df's own `doi` column. Confirm which was intended.
harvested_df_filtered = harvested_df.filter(F.col("url").startswith("https://doi.org/"))
harvested_df_with_doi = harvested_df_filtered.withColumn(
    "doi", F.regexp_replace(F.col("url"), "^https://doi.org/", "")
)

# Crossref DOIs with no row in harvest_results, capped at 100k arbitrary
# (unordered) DOIs per run.
unharvested_df = (
    crossref_df
    .join(harvested_df, ["doi"], "left_anti")
    .select("doi")
    .limit(100000)
)
unharvested_df.display()
unharvested_df.cache().count()


In [0]:
# Group DOIs into batches of at most BATCH_SIZE. monotonically_increasing_id() is
# unique but only consecutive within a partition, so batches at partition
# boundaries can hold fewer DOIs.
unharvested_df = unharvested_df.withColumn(
    "batch_id",
    F.floor(F.monotonically_increasing_id() / BATCH_SIZE),
)
unharvested_batches_df = (
    unharvested_df
    .groupBy("batch_id")
    .agg(F.collect_list("doi").alias("doi_list"))
    .repartition(PARTITIONS)
)
unharvested_batches_df.display()
unharvested_batches_df.cache().count()

In [0]:
# Apply the harvest UDF to each batch: one row per batch holding a
# map<doi, harvest result>. Lazy — nothing is harvested until an action runs.
harvest_result_df = (
    unharvested_batches_df
    .withColumn('harvest_results', harvest_udf(F.col('doi_list')))
    .select('harvest_results')
)

In [0]:
# Cache before displaying: each harvest_udf call builds an S3 client and runs the
# landing-page harvester, so without a cache the display below and the later
# final_df.cache().count() would each execute the harvest again.
harvest_result_df = harvest_result_df.cache()
harvest_result_df.display()

In [0]:
# Flatten: one row per DOI with the harvest_result struct expanded into columns.
# NOTE(review): a failed harvest maps to a null struct, so it is written as an
# all-null row keyed by doi and is then skipped by the anti-join on the next run
# (never retried) — confirm that is intended.
exploded_df = harvest_result_df.select(
    F.explode("harvest_results").alias("doi", "harvest_result")
)
final_df = exploded_df.select(
    "doi",
    F.col("harvest_result.*"),
)
final_df.display()
final_df.cache().count()

In [0]:
# Append this run's results to the Delta table that the unharvested-DOI
# anti-join reads from.
(
    final_df.write
    .format("delta")
    .mode("append")
    .saveAsTable("harvest.harvest_results")
)