In [0]:
%pip uninstall openalex_taxicab -y
%pip install git+https://github.com/ourresearch/openalex-taxicab.git

In [0]:
import sys
import boto3
import dlt
import os
import pyspark.sql.functions as F
from pyspark.sql.types import MapType, StringType, StructType, StructField, BooleanType, TimestampType, IntegerType, FloatType
from dataclasses import asdict, is_dataclass
from inspect import getmembers
import sqlite3

In [0]:
# Setup env variables
# Reads every secret in the scope and mirrors it into os.environ. The same
# mapping is kept in `env`, which harvest() later hands to http_cache.initialize.
scope_name = "openalex-elt"
secrets = dbutils.secrets.list(scope=scope_name)
AWS_ACCESS_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_SECRET_ACCESS_KEY')

# Built with a comprehension rather than a loop so that no loop temporaries
# (one of which would hold a raw secret value) are left behind as notebook globals.
env = {
    secret.key: dbutils.secrets.get(scope=scope_name, key=secret.key)
    for secret in secrets
}
os.environ.update(env)

In [0]:
from openalex_taxicab.harvest import HarvestResult
from openalex_taxicab.legacy.harvest import PublisherLandingPageHarvester

In [0]:
# NOTE(review): AWS_ACCESS_KEY / AWS_SECRET_KEY are already assigned from the
# same secret keys in the env-setup cell earlier in this notebook; these two
# lines repeat those lookups.
AWS_ACCESS_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = dbutils.secrets.get(scope=scope_name, key='AWS_SECRET_ACCESS_KEY')
# Local filesystem location of the S3 lookup SQLite database. harvest() checks
# for this file and, when it is missing, downloads it and renames it to this path.
LOCAL_S3_DB_LOOKUP_PATH = '/tmp/s3_lookup.db'

In [0]:
def explode_dict_column(df, dict_column_name, schema=None):
    """Replace a map column with one top-level column per map key.

    The key set is taken from the first row whose ``dict_column_name`` value
    is not null; every row is then projected onto those keys (a row with a
    null map gets null in each new column).

    Note: finding that sample row calls ``first()``, which runs a Spark job,
    so whatever computes ``dict_column_name`` (for example a UDF) is evaluated
    for part of ``df`` at the time this function is called.

    Args:
        df: DataFrame containing ``dict_column_name``.
        dict_column_name: Name of the map-typed column to expand.
        schema: Optional ``StructType``. When given, the result is narrowed to
            exactly ``schema.fieldNames()``, in that order.

    Returns:
        ``df`` unchanged when no row has a non-null map (this includes an
        empty ``df``); otherwise a new DataFrame without ``dict_column_name``
        and with one added column per sampled key.
    """
    map_col = F.col(dict_column_name)

    # first() already returns None when there is nothing to sample, which
    # covers both "df is empty" and "every map is null". A separate
    # df.rdd.isEmpty() probe would only add another Spark job.
    sample_row = df.select(dict_column_name).filter(map_col.isNotNull()).first()
    if not sample_row:
        return df

    keys = sample_row[0].keys()
    key_columns = [
        F.when(map_col.isNotNull(), map_col[key]).alias(key)
        for key in keys
    ]
    exploded_df = df.select('*', *key_columns).drop(dict_column_name)

    if schema and isinstance(schema, StructType):
        exploded_df = exploded_df.select(schema.fieldNames())

    return exploded_df


def _drop_unreported_keys(d):
    """Remove the 'content' and 'last_harvested_dt' entries from ``d`` in place, if present."""
    for key in ('content', 'last_harvested_dt'):
        d.pop(key, None)
    return d


def harvest(doi):
    """Harvest the publisher landing page for ``doi`` and return its metadata.

    On first use in a process whose local lookup database is missing, the S3
    lookup SQLite file is downloaded and moved to ``LOCAL_S3_DB_LOOKUP_PATH``;
    later calls reuse that file.

    Args:
        doi: DOI passed to ``PublisherLandingPageHarvester.harvest``.

    Returns:
        dict: ``HarvestResult.to_dict()`` without the 'content' and
        'last_harvested_dt' entries, plus an 'error' entry. 'error' is None
        on success. If anything raises, the dict is built from an empty
        ``HarvestResult`` and 'error' holds ``str(exception)``.

    NOTE(review): a new boto3 client and SQLite connection are created for
    every DOI; consider caching them per worker if this becomes a bottleneck.
    """
    try:
        from openalex_http import http_cache
        from openalex_taxicab.legacy.s3_cache import download_s3_lookup_db
        s3 = boto3.client('s3',
                          aws_access_key_id=AWS_ACCESS_KEY,
                          aws_secret_access_key=AWS_SECRET_KEY)
        if not os.path.exists(LOCAL_S3_DB_LOOKUP_PATH):
            downloaded_path = download_s3_lookup_db(s3)
            os.rename(downloaded_path, LOCAL_S3_DB_LOOKUP_PATH)
        http_cache.initialize(env)

        # Always open the final location. After the rename the download path
        # no longer exists, and sqlite3.connect would silently create an empty
        # database there instead of failing.
        db_conn = sqlite3.connect(LOCAL_S3_DB_LOOKUP_PATH)
        try:
            plp_harvester = PublisherLandingPageHarvester(s3, db_conn)
            result: HarvestResult = plp_harvester.harvest(doi)
            d = result.to_dict()
        finally:
            # Close on every path so connections are not leaked per DOI.
            db_conn.close()

        d['error'] = None
        return _drop_unreported_keys(d)
    except Exception as e:
        print(f'Error harvesting {doi}: {e}')
        d = HarvestResult(s3_path=None, last_harvested=None, content=None, url=None).to_dict()
        d['error'] = str(e)
        # Return the error record (same keys as a successful result) so the
        # failure reason is kept instead of being discarded.
        return _drop_unreported_keys(d)

# Spark UDF wrapper around harvest(); its result is exposed as a map<string, string> column.
harvest_udf = F.udf(harvest, MapType(StringType(), StringType()))

# Typed, all-nullable layout for harvest results, assembled field by field.
harvest_result_schema = (
    StructType()
    .add("s3_path", StringType(), True)
    .add("last_harvested", TimestampType(), True)
    .add("url", StringType(), True)
    .add("code", IntegerType(), True)  # Optional field
    .add("elapsed", FloatType(), True)  # Optional field
    .add("resolved_url", StringType(), True)
    .add("content_type", StringType(), True)
    .add("is_soft_block", BooleanType(), True)
)


In [0]:
# def harvested_content():
#     crossref_df = spark.read.format("delta").load("dbfs:/pipelines/063e402f-4261-43ca-8952-8bdc31aa3d48/tables/crossref_works").limit(100)
#     dois_df = crossref_df.select('DOI')
#     harvest_result_df = dois_df.withColumn('harvest_result', harvest_udf(dois_df['DOI']))
#     exploded_df = explode_dict_column(harvest_result_df, 'harvest_result')
#     return exploded_df
# harvested_content().display()

In [0]:
@dlt.table(
    name="harvested_content",
    comment="Metadata about harvested URLs (S3 path, response code, etc)",
    table_properties={'quality': 'bronze'}
)
def harvested_content():
    """Build the ``harvested_content`` table from the DOIs in ``crossref_works``.

    Reads the ``crossref_works`` dataset, keeps only the DOI column spread
    over 30 partitions, applies ``harvest_udf`` to each DOI, and expands the
    resulting map column into individual columns.
    """
    dois = dlt.read("crossref_works").select('DOI').repartition(30)
    with_results = dois.withColumn('harvest_result', harvest_udf(dois['DOI']))
    return explode_dict_column(with_results, 'harvest_result')
