In [0]:
%pip install /Volumes/openalex/default/libraries/openalex_dlt_utils-0.2.1-py3-none-any.whl

In [0]:
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import builtins

from openalex.utils.environment import *

@dlt.table(
    name="base_sources",
    table_properties={"quality": "bronze"},
    comment="Sources from original postgresql table."
)
def base_sources():
    """Load the Postgres sources export and parse its JSON-encoded columns.

    Reads ``openalex.sources.sources_from_postgres``, drops ``is_in_doaj``
    and replaces four JSON string columns with typed arrays:

    * ``issns`` and ``alternate_titles`` become arrays of strings;
    * ``apc_prices`` becomes an array of ``{price, currency}`` structs;
    * ``societies`` becomes an array of ``{url, organization}`` structs.

    Returns:
        The source DataFrame with the parsed columns in place.
    """
    string_list = ArrayType(StringType())
    price_entries = ArrayType(StructType([
        StructField("price", IntegerType(), True),
        StructField("currency", StringType(), True),
    ]))
    society_entries = ArrayType(StructType([
        StructField("url", StringType(), True),
        StructField("organization", StringType(), True),
    ]))

    # Column name -> schema used to parse it, in the order the columns are rewritten.
    json_columns = {
        "issns": string_list,
        "alternate_titles": string_list,
        "apc_prices": price_entries,
        "societies": society_entries,
    }

    parsed = spark.table("openalex.sources.sources_from_postgres").drop("is_in_doaj")
    for column_name, schema in json_columns.items():
        raw = col(column_name)
        # NULL input stays NULL; anything else is parsed with the column's schema.
        parsed = parsed.withColumn(
            column_name,
            when(raw.isNull(), None).otherwise(from_json(raw, schema)),
        )
    return parsed

@dlt.table(
    name="crossref_id_mapping",
    comment="Stable ID mapping for crossref journals (reads from persistent table)",
    table_properties={"quality": "silver"}
)
def crossref_id_mapping():
    """Create/maintain the mapping between Crossref journals and stable OpenAlex IDs.

    The mapping is keyed by ``issns_concat_id`` and stored in
    ``openalex.sources.crossref_id_mapping_persistent``.

    * First run (persistent table cannot be read): every Crossref journal gets
      an ID, starting at ``max(sources_from_postgres.id) + 1000000000``.
    * Later runs: journals already in the mapping keep their ID; new journals
      get consecutive IDs starting right after the largest ID found in either
      the mapping or ``sources_from_postgres``.

    New IDs are assigned in ``issns_concat_id`` order.

    Side effects:
        Overwrites the persistent table with the full mapping on every call.

    Returns:
        DataFrame with columns ``issns_concat_id``, ``openalex_id`` and
        ``created_date``.

    NOTE(review): this function runs Spark actions (``count``/``collect``) and
    a ``saveAsTable`` inside a DLT dataset definition. Databricks' DLT
    guidance advises against that; confirm the pipeline evaluates this
    function in a way that makes it safe.
    """
    persistent_table = "openalex.sources.crossref_id_mapping_persistent"

    def max_id(table_name, column_name):
        """Return the largest value of `column_name` in `table_name`, or 0 if there is none."""
        value = spark.table(table_name).agg(max(column_name)).collect()[0][0]
        return value if value is not None else 0

    def assign_ids(keys, first_id):
        """Give each row of `keys` a consecutive openalex_id starting at `first_id`."""
        return (keys
                # The window ordering alone determines the numbering.
                .withColumn("row_num", row_number().over(
                    Window.orderBy("issns_concat_id")
                ))
                .withColumn("openalex_id", col("row_num") + lit(first_id - 1))
                .withColumn("created_date", current_timestamp())
                .select("issns_concat_id", "openalex_id", "created_date"))

    # Check whether the persistent mapping table already exists. Any failure to
    # open it is treated as "no mapping yet"; `except Exception` (rather than a
    # bare except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        existing_mapping = spark.table(persistent_table)
    except Exception:
        existing_mapping = None

    # De-duplicate on the key only: the mapping must hold exactly one row per
    # issns_concat_id, even if the gold table carries several titles for it.
    crossref_keys = (spark.table("openalex.sources.crossref_journals_gold")
                     .select("issns_concat_id")
                     .distinct())

    if existing_mapping is None:
        # Initial mapping: start well above the existing Postgres IDs.
        start_id = max_id("openalex.sources.sources_from_postgres", "id") + 1000000000
        final_mapping = assign_ids(crossref_keys, start_id)
    else:
        new_keys = crossref_keys.join(existing_mapping, "issns_concat_id", "left_anti")

        # limit(1) is enough to learn whether there is anything new.
        if new_keys.limit(1).count() > 0:
            max_mapping_id = existing_mapping.agg(max("openalex_id")).collect()[0][0]
            max_mapping_safe = max_mapping_id if max_mapping_id is not None else 0
            max_postgres_safe = max_id("openalex.sources.sources_from_postgres", "id")

            # Use the higher of the two, plus 1 for the next available ID.
            next_id = builtins.max(max_mapping_safe, max_postgres_safe) + 1
            final_mapping = existing_mapping.unionByName(assign_ids(new_keys, next_id))
        else:
            final_mapping = existing_mapping

    # Save to the persistent table (survives full refresh).
    (final_mapping
     .write
     .mode("overwrite")
     .option("mergeSchema", "true")
     .saveAsTable(persistent_table))

    return final_mapping

@dlt.table(
    name="datacite_id_mapping_extended",
    comment="Stable ID mapping for unmatched datacite sources (reads from persistent table)",
    table_properties={"quality": "silver"}
)
def datacite_id_mapping_extended():
    """Create/maintain stable OpenAlex IDs for DataCite sources that have none.

    A DataCite source is "unmatched" when the left join against
    ``openalex.sources.datacite_to_openalex_mapping`` leaves ``openalex_id``
    NULL. The mapping is keyed by ``datacite_id`` and stored in
    ``openalex.sources.datacite_id_mapping_extended_persistent``.

    * First run (persistent table cannot be read): IDs start at
      ``max(postgres id, crossref mapping id) + 2000000000``.
    * Later runs: existing rows keep their ID; new sources get consecutive IDs
      after the largest ID found in the extended mapping, the Postgres sources
      or the Crossref persistent mapping.

    New IDs are assigned in ``datacite_id`` order.

    Side effects:
        Overwrites the persistent table with the full mapping on every call.

    Returns:
        DataFrame with columns ``datacite_id``, ``openalex_id`` and
        ``created_date``.

    NOTE(review): this function runs Spark actions (``count``/``collect``) and
    a ``saveAsTable`` inside a DLT dataset definition. Databricks' DLT
    guidance advises against that; confirm the pipeline evaluates this
    function in a way that makes it safe.
    """
    persistent_table = "openalex.sources.datacite_id_mapping_extended_persistent"

    def max_id(table_name, column_name):
        """Return the largest value of `column_name` in `table_name`, or 0 if there is none."""
        value = spark.table(table_name).agg(max(column_name)).collect()[0][0]
        return value if value is not None else 0

    def max_crossref_id():
        """Return the largest Crossref-mapping ID, or 0 if that table cannot be read."""
        try:
            return max_id("openalex.sources.crossref_id_mapping_persistent", "openalex_id")
        except Exception:
            # Best effort: the Crossref mapping may not have been created yet.
            return 0

    def assign_ids(keys, first_id):
        """Give each row of `keys` a consecutive openalex_id starting at `first_id`."""
        return (keys
                # The window ordering alone determines the numbering.
                .withColumn("row_num", row_number().over(
                    Window.orderBy("datacite_id")
                ))
                .withColumn("openalex_id", col("row_num") + lit(first_id - 1))
                .withColumn("created_date", current_timestamp())
                .select("datacite_id", "openalex_id", "created_date"))

    # Any failure to open the persistent table is treated as "no mapping yet";
    # `except Exception` lets KeyboardInterrupt/SystemExit propagate.
    try:
        existing_mapping = spark.table(persistent_table)
    except Exception:
        existing_mapping = None

    # DataCite sources without an OpenAlex ID in datacite_to_openalex_mapping
    # (the original mapping table, not this extended one).
    datacite_sources = spark.read.table("openalex.sources.datacite_sources")
    existing_openalex_mapping = spark.read.table("openalex.sources.datacite_to_openalex_mapping")

    # De-duplicate on the key only: the mapping must hold exactly one row per
    # datacite_id, even if a source appears with several display names.
    unmatched_sources = (datacite_sources
                        .join(existing_openalex_mapping,
                              datacite_sources["id"] == existing_openalex_mapping["datacite_id"],
                              "left")
                        .filter(col("openalex_id").isNull())
                        .select(datacite_sources["id"].alias("datacite_id"))
                        .distinct())

    if existing_mapping is None:
        # Start well above existing IDs; the 2 billion buffer differs from the
        # one used for Crossref so the two ranges stay apart.
        max_postgres_safe = max_id("openalex.sources.sources_from_postgres", "id")
        start_id = builtins.max(max_postgres_safe, max_crossref_id()) + 2000000000
        final_mapping = assign_ids(unmatched_sources, start_id)
    else:
        # Sources not yet present in the extended mapping.
        new_sources = unmatched_sources.join(existing_mapping, "datacite_id", "left_anti")

        # limit(1) is enough to learn whether there is anything new.
        if new_sources.limit(1).count() > 0:
            max_mapping_id = existing_mapping.agg(max("openalex_id")).collect()[0][0]
            max_mapping_safe = max_mapping_id if max_mapping_id is not None else 0
            max_postgres_safe = max_id("openalex.sources.sources_from_postgres", "id")

            # Use the highest ID found anywhere, plus 1 for the next available ID.
            next_id = builtins.max(max_mapping_safe, max_postgres_safe, max_crossref_id()) + 1
            final_mapping = existing_mapping.unionByName(assign_ids(new_sources, next_id))
        else:
            final_mapping = existing_mapping

    # Save to the persistent table (survives full refresh).
    (final_mapping
     .write
     .mode("overwrite")
     .option("mergeSchema", "true")
     .saveAsTable(persistent_table))

    return final_mapping

@dlt.table(
    name="crossref_journals_unmatched",
    comment="Crossref journals that have NO matching ISSNs with existing base sources"
)
def crossref_journals_unmatched():
    """Return Crossref journals that share no ISSN with any base source.

    A journal from ``openalex.sources.crossref_journals_gold`` is kept only
    when none of its ISSNs appears in ``base_sources.issns``. Kept journals
    are joined (inner) to ``crossref_id_mapping`` to pick up their OpenAlex ID.

    Returns:
        DataFrame with ``id``, ``display_name``, ``issns``, ``publisher``,
        ``is_oa`` (always False), ``type`` (always "journal") and ``issn``
        (first element of ``issns``, NULL when the array is empty).
    """
    gold_table = "openalex.sources.crossref_journals_gold"

    # Every distinct ISSN already known to the base sources.
    known_issns = dlt.read("base_sources").select("id", explode("issns").alias("issn"))
    known_issns = known_issns.select("issn").distinct()

    # One row per (journal, ISSN) pair.
    journal_issn_rows = spark.table(gold_table).select("*", explode("issns").alias("issn"))

    # Keys of journals with at least one ISSN in common with a base source.
    matched_keys = (
        journal_issn_rows
        .join(known_issns, "issn", "inner")
        .select("issns_concat_id")
        .distinct()
    )

    # Journals whose key never matched.
    unmatched_journals = spark.table(gold_table).join(
        matched_keys, "issns_concat_id", "left_anti"
    )

    # Attach the stable OpenAlex ID.
    stable_ids = dlt.read("crossref_id_mapping")
    with_ids = unmatched_journals.join(stable_ids, "issns_concat_id", "inner")

    projected = with_ids.select(
        col("openalex_id").alias("id"),
        col("title").alias("display_name"),
        col("issns"),
        col("publisher"),
        lit(False).alias("is_oa"),
        lit("journal").alias("type"),
    )

    first_issn = when(size(col("issns")) > 0, col("issns")[0]).otherwise(lit(None))
    return projected.withColumn("issn", first_issn)
    
@dlt.table(
    name="datacite_sources_unmatched",
    comment="Datacite sources that have NO match existing sources"
)
def datacite_sources_unmatched():
    """Return DataCite sources that have no existing OpenAlex mapping.

    Sources from ``openalex.sources.datacite_sources`` are left-joined to
    ``openalex.sources.datacite_to_openalex_mapping``; rows whose
    ``openalex_id`` is NULL are kept and given the ID recorded for them in
    ``datacite_id_mapping_extended`` (inner join).

    Returns:
        DataFrame with ``id`` (the new OpenAlex ID), ``datacite_id``,
        ``display_name``, ``issns``, ``type``, ``issn``, ``publisher``,
        ``is_oa`` (always True), ``updated_date`` and ``created_date``.
    """
    datacite = spark.read.table("openalex.sources.datacite_sources")
    legacy_mapping = spark.read.table("openalex.sources.datacite_to_openalex_mapping")

    # DataCite sources with no row (or no ID) in the existing mapping.
    without_openalex_id = (
        datacite
        .join(legacy_mapping, col("id") == col("datacite_id"), "left")
        .filter(col("openalex_id").isNull())
    )

    # Stable IDs, renamed so they cannot clash with columns of the left side.
    minted_ids = dlt.read("datacite_id_mapping_extended").select(
        col("datacite_id").alias("mapping_datacite_id"),
        col("openalex_id").alias("new_openalex_id"),
        col("created_date").alias("mapping_created_date"),
    )

    joined = without_openalex_id.join(
        minted_ids,
        without_openalex_id["id"] == minted_ids["mapping_datacite_id"],
        "inner",
    )

    # Collapse the three ISSN fields into one array without NULLs or repeats.
    issn_values = array_distinct(
        array_compact(
            array(
                col("issns.issnl"),
                col("issns.print"),
                col("issns.electronic"),
            )
        )
    )

    projected = joined.select(
        col("new_openalex_id").alias("id"),
        col("id").alias("datacite_id"),
        col("display_name"),
        issn_values.alias("issns"),
        col("provider_name"),
        col("type"),
    )

    first_issn = when(size(col("issns")) > 0, col("issns")[0]).otherwise(lit(None))
    # "periodical" is reported as "journal"; every other type passes through.
    normalized_type = when(col("type") == "periodical", "journal").otherwise(col("type"))

    return (
        projected
        .withColumn("issn", first_issn)
        .withColumn("publisher", col("provider_name"))
        .withColumn("is_oa", lit(True))
        .withColumn("type", normalized_type)
        .withColumn("updated_date", current_timestamp())
        .withColumn("created_date", current_timestamp().cast("string"))
        .drop("provider_name")
    )

@dlt.table(
    name="sources",
    comment=f"Combined sources with DOAJ status and sample PMH records in {ENV.upper()}"
)
def sources():
    """Build the final sources table.

    Steps:

    1. Union ``base_sources`` with ``crossref_journals_unmatched`` and
       ``datacite_sources_unmatched`` (columns missing on one side become NULL).
    2. Build two ISSN lookups: DOAJ membership (``doaj_from_csv``) and the
       high-OA-rate list (``high_oa_rate_issns``) overridden by the most
       recent approved curation request per ISSN.
    3. For records with ISSNs: explode the ISSNs, join both lookups, regroup
       to one row per record, then keep one record per ``issn`` value.
    4. For records without ISSNs: fill the DOAJ / high-OA columns with defaults.
    5. Union both sets and keep one row per ``id``.
    6. Recompute ``is_oa`` from ``is_in_doaj`` and ``is_oa_high_oa_rate``.
    7. Attach ``sample_pmh_record`` from the endpoint mapping and a
       ``datacite_ids`` array, then drop two hard-coded source IDs.

    Returns:
        The combined sources DataFrame.
    """
    # combine base sources with unmatched crossref and datacite sources
    base_combined = (
        dlt.read("base_sources")
        .unionByName(
            dlt.read("crossref_journals_unmatched"),
            allowMissingColumns=True
        )
        .unionByName(
            dlt.read("datacite_sources_unmatched"),
            allowMissingColumns=True
        )
    )

    # get DOAJ ISSNs (one row per ISSN / oa_start_year pair)
    doaj = (
        spark.table("openalex.sources.doaj_from_csv")
        .selectExpr("explode(issns) as doaj_issn", "oa_start_year")
        .distinct()
    )

    # baseline high OA rate table
    base_high_oa_rate_issns = (
        spark.table("openalex.sources.high_oa_rate_issns")
        .select("issn_l", "oa_year")
        .distinct()
    )

    # Get most recent approved curation per ISSN (even if it sets the value to false)
    window = Window.partitionBy("issn").orderBy(col("ingestion_timestamp").desc())

    # curation requests, converted to the high OA rate table format
    curation_requests = (
        spark.table("openalex.unpaywall.journal_curation_requests")
        .filter(col("approved") == "yes")
        .withColumn("row_num", row_number().over(window))
        .filter(col("row_num") == 1)
        .select(
            col("issn").alias("issn_l"),
            col("new_is_oa").alias("is_oa_high_oa_rate"),
            col("new_oa_date").cast("int").alias("high_oa_rate_start_year")
        )
    )

    # Apply curation overrides — if a curation exists, it takes full priority.
    # Without one, an ISSN counts as high-OA-rate when the baseline has an oa_year.
    high_oa_rate_issns = (
        base_high_oa_rate_issns.alias("base")
        .join(curation_requests.alias("cur"), on="issn_l", how="outer")
        .select(
            coalesce(col("cur.issn_l"), col("base.issn_l")).alias("issn_l"),
            when(col("cur.is_oa_high_oa_rate").isNotNull(), col("cur.is_oa_high_oa_rate"))
                .otherwise(col("base.oa_year").isNotNull()).alias("is_oa_high_oa_rate"),
            when(col("cur.issn_l").isNotNull(), col("cur.high_oa_rate_start_year"))
                .otherwise(col("base.oa_year")).alias("high_oa_rate_start_year")
        )
    )

    # process records with ISSNs
    sources_with_issns = (
        base_combined
        .filter(col("issns").isNotNull() & (size(col("issns")) > 0))
        .withColumn("exploded_issn", explode(col("issns")))
        .join(doaj, col("exploded_issn") == doaj["doaj_issn"], "left")
        .join(high_oa_rate_issns, col("exploded_issn") == high_oa_rate_issns["issn_l"], "left")
        .withColumn("is_in_doaj",
                    when(doaj["doaj_issn"].isNotNull(), True).otherwise(False))
        # MDPI publishers outside DOAJ are forced to high-OA-rate; everything
        # else takes the lookup value, defaulting to False.
        .withColumn("is_oa_high_oa_rate",
            when((lower(col("publisher")).startswith("mdpi")) & (col("is_in_doaj") == False), True)
            .otherwise(coalesce(high_oa_rate_issns["is_oa_high_oa_rate"], lit(False))))
        .drop("doaj_issn", "issn_l")
        # collapse the exploded rows back to one row per record
        .groupBy("id", *[c for c in base_combined.columns if c not in ["id", "issns"]])
        .agg( # make sure the issns are sorted and deduplicated - avoid non-deterministic sorting
            array_sort(collect_set("exploded_issn")).alias("issns"),
            max("is_in_doaj").alias("is_in_doaj"),
            max("is_oa_high_oa_rate").alias("is_oa_high_oa_rate"),
            max("oa_start_year").alias("oa_start_year"),
            max("high_oa_rate_start_year").alias("high_oa_rate_start_year")
        )
        .withColumnRenamed("oa_start_year", "is_in_doaj_start_year")
        .withColumn("rank", row_number().over( # deduplicate by ISSN
            Window.partitionBy("issn").orderBy(
                size("issns").desc(),
                length("display_name").desc(),
                # final tie-breaker so the surviving record is deterministic
                col("id").asc()
            )
        ))
        # records with a NULL `issn` are never collapsed into each other
        .filter((col("issn").isNull()) | (col("rank") == 1))
        .drop("rank")
    )

    # process records without ISSNs: no lookups apply, so fill in defaults
    sources_null_issns = (
        base_combined
        .filter((col("issns").isNull()) | (size(col("issns")) == 0))
        .withColumn("is_in_doaj", lit(False))
        .withColumn("is_oa_high_oa_rate",
            when(lower(col("publisher")).startswith("mdpi"), True)
            .otherwise(lit(False)))
        .withColumn("oa_start_year", lit(None).cast("int"))
        .withColumn("high_oa_rate_start_year", lit(None).cast("int"))
        .withColumnRenamed("oa_start_year", "is_in_doaj_start_year")
    )

    # combine ISSN and non-ISSN records, keeping one row per id
    sources_with_doaj_and_oa = (
        sources_with_issns
        .unionByName(sources_null_issns)
        .withColumn("rank", row_number().over(
            Window.partitionBy("id").orderBy(
                # prefer records that have ISSNs in de-duplication
                (col("issns").isNotNull() & (size(col("issns")) > 0)).desc(),
                col("updated_date").desc(),
                length("display_name").desc(),
                # final tie-breaker so the surviving record is deterministic
                col("display_name").asc()
            )
        ))
        .filter(col("rank") == 1)
        .drop("rank")
    )

    # update is_oa column based on is_in_doaj / is_oa_high_oa_rate
    # NOTE(review): is_oa_high_oa_rate is built from coalesce(..., False) or
    # when(...).otherwise(False) in both branches above, so it looks never-NULL
    # here. If so, the final otherwise() is unreachable and the incoming is_oa
    # value is always replaced — confirm that this is intended.
    sources_with_updated_oa = (
        sources_with_doaj_and_oa
        .withColumn("is_oa",
                    when(col("is_in_doaj") == True, True)
                    .when(col("is_oa_high_oa_rate") == True, True)
                    .when(col("is_oa_high_oa_rate") == False, False)
                    .otherwise(col("is_oa")))
    )

    # endpoint mapping supplies the sample PMH record for each repository
    endpoints_table = (
        spark.table("openalex.sources.endpoint_mapping")
        .select("endpoint_id", "sample_pmh_record")
    )

    # datacite id mapping: all datacite_ids for a given openalex_id
    datacite_id_mapping = (
        spark.table("openalex.sources.datacite_to_openalex_mapping")
        .groupBy("openalex_id")
        .agg(
            collect_set("datacite_id").alias("all_datacite_ids")
        )
    )

    final_result = (
        sources_with_updated_oa
        .join(
            endpoints_table,
            sources_with_updated_oa["repository_id"] == endpoints_table["endpoint_id"],
            "left"
        )
        .join(
            datacite_id_mapping,
            sources_with_updated_oa["id"] == datacite_id_mapping["openalex_id"],
            "left"
        )
        # a record's own datacite_id wins; otherwise use the mapped ids; otherwise empty
        .withColumn("datacite_ids",
            when(col("datacite_id").isNotNull(),
                array(col("datacite_id")))
            .when(col("all_datacite_ids").isNotNull(),
                col("all_datacite_ids"))
            .otherwise(array())
        )
        # remove "Deleted Journal" and DOAJ source records from upstream
        .filter(~col("id").isin(4317411217, 4306401280))
        .drop("endpoint_id", "openalex_id", "all_datacite_ids")
    )
    return final_result