In [0]:
%pip install /Volumes/openalex/default/libraries/openalex_dlt_utils-0.2.1-py3-none-any.whl

In [0]:
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import builtins

from openalex.utils.environment import *

# ============================================================================
# CONSTANTS AND SCHEMAS
# ============================================================================

# Offsets passed as `id_buffer` when a persistent ID mapping is created for the
# first time: the first minted ID is (highest existing ID + buffer).
ID_BUFFER_CROSSREF = 10 ** 9
ID_BUFFER_DATACITE = 2 * 10 ** 9

# Schema used to parse the JSON-encoded `apc_prices` column.
APC_SCHEMA = ArrayType(
    StructType()
    .add("price", IntegerType(), True)
    .add("currency", StringType(), True)
)

# Schema used to parse the JSON-encoded `societies` column.
SOCIETIES_SCHEMA = ArrayType(
    StructType()
    .add("url", StringType(), True)
    .add("organization", StringType(), True)
)

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def remove_duplicate_issns(df):
    """
    Strip ISSNs that are shared between records from every record but one.

    An ISSN that occurs in the `issns` array of several rows is retained only
    by the row with the smallest `id`; it is removed from the arrays of all
    other rows. Rows that share no ISSN keep their array untouched.

    Args:
        df: DataFrame with an `id` column and an array column `issns`.

    Returns:
        The same rows, with `issns` replaced by the deduplicated array
        (the column ends up last because it is dropped and re-added).
    """
    # One row per (record, ISSN) pair; null array elements are discarded.
    per_issn = (
        df.withColumn('issn', explode('issns'))
        .filter(col('issn').isNotNull())
    )

    # ISSNs seen on more than one exploded row, paired with the lowest owning ID.
    shared_issns = (
        per_issn.groupBy('issn')
        .agg(
            min('id').alias('keeper_id'),
            count('id').alias('record_count'),
        )
        .filter(col('record_count') > 1)
        .select('issn', 'keeper_id')
    )

    # For every non-keeper record, the sorted set of ISSNs it has to give up.
    non_keepers = (
        per_issn.join(shared_issns, 'issn')
        .filter(col('id') != col('keeper_id'))
    )
    removals = non_keepers.groupBy('id').agg(
        array_sort(collect_set('issn')).alias('issns_to_remove')
    )

    # Records without a removal list keep their array as it is.
    cleaned_issns = (
        when(col('issns_to_remove').isNull(), col('issns'))
        .otherwise(array_except(col('issns'), col('issns_to_remove')))
    )

    with_removals = df.join(removals, 'id', 'left')
    return (
        with_removals.withColumn('issns_deduplicated', cleaned_issns)
        .drop('issns', 'issns_to_remove')
        .withColumnRenamed('issns_deduplicated', 'issns')
    )


def get_max_existing_id(*table_names):
    """
    Get the maximum ID across multiple tables.

    Tables whose name contains "sources_from_postgres" are read through their
    `id` column; every other table through `openalex_id`.

    The lookup is best-effort: a table that cannot be read is skipped, but the
    failure is logged instead of being swallowed silently, and only
    `Exception` subclasses are caught so interrupts still propagate.

    Args:
        *table_names: Fully qualified names of the tables to inspect.

    Returns:
        The largest ID found, or 0 when no table yielded a value.
    """
    import logging

    max_ids = []
    for table_name in table_names:
        try:
            # Handle different column names for different tables
            id_column = "id" if "sources_from_postgres" in table_name else "openalex_id"
            max_id = spark.table(table_name).agg(max(id_column)).collect()[0][0]
        except Exception as exc:
            logging.getLogger(__name__).warning(
                "Skipping %r while computing the max existing ID: %s", table_name, exc
            )
            continue

        # An empty table aggregates to NULL and contributes nothing.
        if max_id is not None:
            max_ids.append(max_id)

    return builtins.max(max_ids) if max_ids else 0


def create_id_mappings(new_sources, start_id, id_column, order_column):
    """
    Create OpenAlex ID mappings for new sources with deterministic ordering.

    The input is first reduced to one row per `id_column` value so that a key
    can never receive more than one OpenAlex ID, even when the caller passes
    rows that differ only in other columns. IDs are then assigned
    consecutively, starting at `start_id`, in `order_column` order with
    `id_column` as tie-breaker.

    Args:
        new_sources: DataFrame containing at least `id_column` and `order_column`.
        start_id: ID given to the first row in the ordering.
        id_column: Name of the column identifying a source.
        order_column: Name of the column that defines the numbering order.

    Returns:
        DataFrame with columns `id_column`, `openalex_id`, `created_date`.
    """
    if order_column == id_column:
        keys = new_sources.select(id_column).distinct()
        ordering = [id_column]
    else:
        # Keep a single, reproducible order value per key.
        keys = new_sources.groupBy(id_column).agg(min(order_column).alias(order_column))
        ordering = [order_column, id_column]

    # The window's own ORDER BY fixes the numbering, so no separate sort is needed.
    return (
        keys
        .withColumn("row_num", row_number().over(Window.orderBy(*ordering)))
        .withColumn("openalex_id", col("row_num") + lit(start_id - 1))
        .withColumn("created_date", current_timestamp())
        .select(id_column, "openalex_id", "created_date")
    )


def manage_persistent_id_mapping(
    sources_df,
    persistent_table,
    id_column,
    id_buffer,
    additional_max_tables=None
):
    """
    Generic function to manage persistent ID mappings for sources.

    On the first run (the persistent table does not exist) every row of
    `sources_df` gets an ID starting at `max existing ID + id_buffer`. On later
    runs only keys that are missing from the persistent table get new IDs,
    starting right after the highest ID seen in the mapping and in the
    reference tables.

    Side effect: writes to `persistent_table` (created on first run, appended
    to afterwards; nothing is written when there are no new keys).
    NOTE(review): this write happens inside a DLT dataset function - confirm
    that the pipeline tolerates the function being evaluated more than once.

    Args:
        sources_df: DataFrame containing `id_column`.
        persistent_table: Fully qualified name of the mapping table.
        id_column: Name of the key column shared by `sources_df` and the mapping.
        id_buffer: Gap added to the max existing ID when the mapping is created.
        additional_max_tables: Optional extra tables whose `openalex_id` must
            also be exceeded by newly minted IDs.

    Returns:
        DataFrame reading the persisted mapping, so the caller sees exactly
        the rows (including `created_date`) that were stored.
    """
    # Tables to check for max IDs
    tables_to_check = ["openalex.sources.sources_from_postgres"]
    if additional_max_tables:
        tables_to_check.extend(additional_max_tables)

    # Ask the catalog explicitly: a failed read must never be mistaken for a
    # missing table, because that would re-mint every ID and replace the mapping.
    if spark.catalog.tableExists(persistent_table):
        existing_mapping = spark.table(persistent_table)

        # Find new sources not in existing mapping
        new_sources = sources_df.join(existing_mapping, id_column, "left_anti")
        if not new_sources.head(1):
            return existing_mapping

        # Get max ID from existing mapping and other tables
        max_mapping_id = existing_mapping.agg(max("openalex_id")).collect()[0][0] or 0
        max_other_id = get_max_existing_id(*tables_to_check)
        next_id = builtins.max(max_mapping_id, max_other_id) + 1

        new_mappings = create_id_mappings(new_sources, next_id, id_column, id_column)
        write_mode = "append"  # existing rows stay untouched
    else:
        # Create initial mapping with buffer
        start_id = get_max_existing_id(*tables_to_check) + id_buffer
        new_mappings = create_id_mappings(sources_df, start_id, id_column, id_column)
        write_mode = "overwrite"

    # Save to persistent table
    (new_mappings
     .write
     .mode(write_mode)
     .option("mergeSchema", "true")
     .saveAsTable(persistent_table))

    # Read back what was stored instead of returning the lazy plan, which
    # would re-evaluate current_timestamp() when the caller materializes it.
    return spark.table(persistent_table)

# ============================================================================
# DLT PIPELINE TABLES
# ============================================================================

@dlt.table(
    name="base_sources",
    table_properties={"quality": "bronze"},
    comment="Sources from original postgresql table with ISSN deduplication."
)
def base_sources():
    """
    Load the postgres sources, parse their JSON columns and deduplicate ISSNs.

    Rows with a non-null `merge_into_id` are excluded and the `is_in_doaj`
    column is dropped. The JSON string columns `issns`, `alternate_titles`,
    `apc_prices` and `societies` are converted to typed arrays before
    cross-record ISSN deduplication is applied.
    """
    string_array = ArrayType(StringType())
    # JSON-encoded column -> schema it is parsed with.
    json_columns = {
        'issns': string_array,
        'alternate_titles': string_array,
        'apc_prices': APC_SCHEMA,
        'societies': SOCIETIES_SCHEMA,
    }

    parsed = (
        spark.table("openalex.sources.sources_from_postgres")
        .drop("is_in_doaj")
        .filter(col("merge_into_id").isNull())
    )
    for column_name, schema in json_columns.items():
        parsed = parsed.withColumn(
            column_name,
            when(col(column_name).isNull(), None)
            .otherwise(from_json(col(column_name), schema)),
        )

    return remove_duplicate_issns(parsed)

@dlt.table(
    name="crossref_id_mapping",
    comment="Stable ID mapping for crossref journals (reads from persistent table)",
    table_properties={"quality": "silver"}
)
def crossref_id_mapping():
    """
    Create/maintain mapping between crossref sources and stable OpenAlex IDs.

    Only the key column is passed on: the mapping logic never reads any other
    column, and carrying `title` along could produce several distinct rows
    (and therefore several OpenAlex IDs) for one `issns_concat_id`.
    """
    crossref_sources = (
        spark.table("openalex.sources.crossref_journals_gold")
        .select("issns_concat_id")
        .distinct()
    )

    return manage_persistent_id_mapping(
        sources_df=crossref_sources,
        persistent_table="openalex.sources.crossref_id_mapping_persistent",
        id_column="issns_concat_id",
        id_buffer=ID_BUFFER_CROSSREF,
        additional_max_tables=None
    )

@dlt.table(
    name="datacite_id_mapping_extended",
    comment="Stable ID mapping for unmatched datacite sources (reads from persistent table)",
    table_properties={"quality": "silver"}
)
def datacite_id_mapping_extended():
    """
    Create/maintain mapping for datacite sources that don't have existing OpenAlex IDs.

    A datacite source counts as unmatched when the left join against
    `datacite_to_openalex_mapping` yields no `openalex_id`. Only the key column
    is passed on: the mapping logic never reads any other column, and carrying
    `display_name` along could produce several distinct rows (and therefore
    several OpenAlex IDs) for one `datacite_id`.
    """
    # Get datacite sources without existing OpenAlex mappings
    datacite_sources = spark.read.table("openalex.sources.datacite_sources")
    existing_openalex_mapping = spark.read.table("openalex.sources.datacite_to_openalex_mapping")

    unmatched_sources = (
        datacite_sources
        .join(
            existing_openalex_mapping,
            datacite_sources["id"] == existing_openalex_mapping["datacite_id"],
            "left"
        )
        .filter(col("openalex_id").isNull())
        .select(datacite_sources["id"].alias("datacite_id"))
        .distinct()
    )

    return manage_persistent_id_mapping(
        sources_df=unmatched_sources,
        persistent_table="openalex.sources.datacite_id_mapping_extended_persistent",
        id_column="datacite_id",
        id_buffer=ID_BUFFER_DATACITE,
        additional_max_tables=["openalex.sources.crossref_id_mapping_persistent"]
    )

@dlt.table(
    name="crossref_journals_unmatched",
    comment="Crossref journals that have NO matching ISSNs with existing base sources"
)
def crossref_journals_unmatched():
    """
    Return crossref journals that share no ISSN with `base_sources`.

    Each remaining journal is given its stable OpenAlex ID from
    `crossref_id_mapping` and is emitted with `type` "journal",
    `is_oa` False and `issn` set to the first element of `issns`.
    """
    gold_table = "openalex.sources.crossref_journals_gold"

    # Every distinct ISSN already present in the base sources.
    known_issns = (
        dlt.read("base_sources")
        .select("id", explode("issns").alias("issn"))
        .select("issn")
        .distinct()
    )

    # Keys of crossref journals that have at least one ISSN in common.
    matched_keys = (
        spark.table(gold_table)
        .select("*", explode("issns").alias("issn"))
        .join(known_issns, "issn", "inner")
        .select("issns_concat_id")
        .distinct()
    )

    # Everything else has no ISSN overlap at all.
    unmatched = spark.table(gold_table).join(matched_keys, "issns_concat_id", "left_anti")

    # Attach the stable IDs and shape the rows like source records.
    stable_ids = dlt.read("crossref_id_mapping")
    first_issn = when(size(col("issns")) > 0, col("issns")[0]).otherwise(lit(None))

    return (
        unmatched
        .join(stable_ids, "issns_concat_id", "inner")
        .select(
            col("openalex_id").alias("id"),
            col("title").alias("display_name"),
            col("issns"),
            col("publisher"),
            lit(False).alias("is_oa"),
            lit("journal").alias("type"),
            first_issn.alias("issn"),
        )
    )
    
@dlt.table(
    name="datacite_sources_unmatched",
    comment="Datacite sources that have NO match existing sources"
)
def datacite_sources_unmatched():
    """
    Return datacite sources that have no entry in `datacite_to_openalex_mapping`.

    Each row receives its stable OpenAlex ID from `datacite_id_mapping_extended`.
    `created_date` is taken from that mapping (the moment the ID was minted)
    rather than from the current run, so it stays constant across pipeline
    runs; `updated_date` is the current run timestamp.
    """
    df = spark.read.table("openalex.sources.datacite_sources")
    df_mapping = spark.read.table("openalex.sources.datacite_to_openalex_mapping")

    # unmatched datacite sources (explicit column references avoid ambiguity
    # between the two joined tables)
    unmatched_df = df.join(
        df_mapping,
        df["id"] == df_mapping["datacite_id"],
        "left"
    ).filter(col("openalex_id").isNull())

    # join with the stable ID mapping - alias the openalex_id to avoid ambiguity
    id_mapping = dlt.read("datacite_id_mapping_extended").select(
        col("datacite_id").alias("mapping_datacite_id"),
        col("openalex_id").alias("new_openalex_id"),
        col("created_date").alias("mapping_created_date")
    )

    return (unmatched_df
        .join(id_mapping,
              unmatched_df["id"] == id_mapping["mapping_datacite_id"],
              "inner")
        .select(
            col("new_openalex_id").alias("id"),
            col("id").alias("datacite_id"),
            col("display_name"),
            # issnl / print / electronic collapsed into one array without
            # nulls or repeated values
            array_distinct(
                array_compact(
                    array(
                        col("issns.issnl"),
                        col("issns.print"),
                        col("issns.electronic")
                    )
                )
            ).alias("issns"),
            col("provider_name"),
            col("type"),
            col("mapping_created_date")
        )
        .withColumn("issn", when(size(col("issns")) > 0, col("issns")[0]).otherwise(lit(None)))
        .withColumn("publisher", col("provider_name"))
        .withColumn("is_oa", lit(True))
        .withColumn("type",
            when(col("type") == "periodical", "journal")
            .otherwise(col("type"))
        )
        .withColumn("updated_date", current_timestamp())
        # Stable creation date: when the OpenAlex ID was first assigned.
        .withColumn("created_date", col("mapping_created_date").cast("string"))
        .drop("provider_name", "mapping_created_date")
    )

@dlt.table(
    name="sources",
    comment=f"Combined sources with DOAJ status, J-STAGE status, and sample PMH records in {ENV.upper()}"
)
def sources():
    """
    Build the combined `sources` table.

    Steps, in order:
      1. Union `base_sources`, `crossref_journals_unmatched` and
         `datacite_sources_unmatched` (missing columns are filled with null).
      2. Prepare per-ISSN lookups: DOAJ, J-STAGE, SciELO (crossref publishers
         starting with "scielo"), OJS and the high-OA-rate list with approved
         curation requests layered on top; plus per-source approved curations.
      3. For sources with ISSNs: explode the ISSNs, left-join every lookup,
         aggregate back to one row per source and keep one source per first
         (sorted) ISSN.
      4. For sources without ISSNs: fill the same columns with defaults.
      5. Union both sets and keep one row per `id`.
      6. Apply source-level curations, derive `is_oa` and
         `is_fully_open_in_jstage`.
      7. Attach `sample_pmh_record` and `datacite_ids`, force `type` to
         "repository" for rxiv / Research Square names, drop two hard-coded IDs.
      8. Append a hard-coded Preprints.org row.
    """
    # combine base sources with unmatched crossref and datacite sources
    base_combined = (
        dlt.read("base_sources")
        .unionByName(
            dlt.read("crossref_journals_unmatched"),
            allowMissingColumns=True
        )
        .unionByName(
            dlt.read("datacite_sources_unmatched"),
            allowMissingColumns=True
        )
    )

    # get DOAJ ISSNs, one row per (ISSN, oa_start_year, normalized license)
    doaj = (
        spark.table("openalex.sources.doaj_from_csv")
        .withColumn(
            "doaj_license_normalized",
            when(col("journal_license") == "CC BY", "cc-by")
            .when(col("journal_license") == "CC BY-NC", "cc-by-nc")
            .when(col("journal_license") == "CC BY-NC-ND", "cc-by-nc-nd")
            .when(col("journal_license") == "CC BY-NC-SA", "cc-by-nc-sa")
            .when(col("journal_license") == "CC BY-SA", "cc-by-sa")
            .when(col("journal_license") == "CC BY-ND", "cc-by-nd")
            .when(col("journal_license") == "Public domain", "public-domain")
            .otherwise(None)  # For multi-license cases or unrecognized licenses
        )
        .selectExpr("explode(issns) as doaj_issn", "oa_start_year", "doaj_license_normalized")
        .distinct()
    )

    # get J-STAGE ISSNs
    jstage = (
        spark.table("openalex.sources.jstage_oa")
        .selectExpr("explode(issns) as jstage_issn", "jstage_oa_start_year", "jstage_oa_end_year")
        .distinct()
    )

    # get scielo issns (all are oa)
    scielo = (
        spark.table("openalex.sources.crossref_journals_gold")
        .filter(lower(col("publisher")).startswith("scielo"))
        .selectExpr("explode(issns) as scielo_issn", "publisher as scielo_publisher")
        .distinct()
    )

    # OJS journals: ISSN plus the OA flag recorded for it
    ojs = (
        spark.table("openalex.sources.ojs_journals")
        .select(
            col("issn").alias("ojs_issn"),
            col("is_oa").alias("ojs_is_oa")
        )
        .distinct()
    )

    # new curation process from postgres
    approved_curations_window = Window.partitionBy("source_id", "property").orderBy(col("moderated_date").desc())

    approved_curations = (
        spark.table("openalex.curations.approved_curations")
        .filter(col("entity") == "sources")
        .filter(col("status") == "approved")
        # entity_id carries an "S" prefix; strip it to get the numeric source id
        .withColumn("source_id", regexp_replace(col("entity_id"), "^S", "").cast("long"))
        .withColumn("row_num", row_number().over(approved_curations_window))
        .filter(col("row_num") == 1)  # Get most recent curation per (source, property)
        .select(
            col("source_id"),
            col("property"),
            col("property_value")
        )
    )

    # Pivot curations to get oa_flip_year and is_oa columns
    curations_pivoted = (
        approved_curations
        .groupBy("source_id")
        .agg(
            max(when(col("property") == "oa_flip_year", col("property_value").cast("int"))).alias("curation_oa_flip_year"),
            max(when(col("property") == "is_oa", col("property_value").cast("boolean"))).alias("curation_is_oa")
        )
    )

    # get curation requests and convert to high OA rate table format
    base_high_oa_rate_issns = (
        spark.table("openalex.sources.high_oa_rate_issns")
        .select("issn_l", "oa_year")
        .distinct()
    )

    # Get most recent approved curation per ISSN (even if sets value to false)
    # NOTE(review): the local name `window` hides pyspark.sql.functions.window
    # (star-imported at the top of the file) for the rest of this function.
    window = Window.partitionBy("issn").orderBy(col("ingestion_timestamp").desc())

    curation_requests = (
        spark.table("openalex.unpaywall.journal_curation_requests")
        .filter(col("approved") == "yes")
        .withColumn("row_num", row_number().over(window))
        .filter(col("row_num") == 1)
        .select(
            col("issn").alias("issn_l"),
            col("new_is_oa").alias("is_oa_high_oa_rate"),
            col("new_oa_date").cast("int").alias("high_oa_rate_start_year")
        )
    )

    # Apply curation overrides: if a curation exists, it takes full priority
    high_oa_rate_issns = (
        base_high_oa_rate_issns.alias("base")
        .join(curation_requests.alias("cur"), on="issn_l", how="outer")
        .select(
            coalesce(col("cur.issn_l"), col("base.issn_l")).alias("issn_l"),
            # without a curated flag, an ISSN counts as high-OA when the base list has a year for it
            when(col("cur.is_oa_high_oa_rate").isNotNull(), col("cur.is_oa_high_oa_rate"))
                .otherwise(col("base.oa_year").isNotNull()).alias("is_oa_high_oa_rate"),
            when(col("cur.issn_l").isNotNull(), col("cur.high_oa_rate_start_year"))
                .otherwise(col("base.oa_year")).alias("high_oa_rate_start_year")
        )
    )

    # process records with and without ISSNs
    sources_with_issns = (
        base_combined
        .filter(col("issns").isNotNull() & (size(col("issns")) > 0))
        .withColumn("exploded_issn", explode(col("issns")))
        .join(doaj, col("exploded_issn") == doaj["doaj_issn"], "left")
        .join(jstage, col("exploded_issn") == jstage["jstage_issn"], "left")
        .join(scielo, col("exploded_issn") == scielo["scielo_issn"], "left")
        .join(ojs, col("exploded_issn") == ojs["ojs_issn"], "left")
        .join(high_oa_rate_issns, col("exploded_issn") == high_oa_rate_issns["issn_l"], "left")
        # NOTE(review): the five row-level columns below are computed again,
        # with the same expressions, inside the aggregation that follows, so
        # these withColumn calls look redundant - confirm before removing.
        .withColumn("is_in_doaj", 
                    when(doaj["doaj_issn"].isNotNull(), True).otherwise(False))
        .withColumn("is_in_scielo", 
                when(scielo["scielo_issn"].isNotNull(), True).otherwise(False))
        .withColumn("is_ojs", 
                    when(ojs["ojs_issn"].isNotNull(), True).otherwise(False))
        .withColumn("is_oa_high_oa_rate", 
            when(lower(col("publisher")).rlike("^(mdpi|academic journals|edorium journals)"), True)
            .when(scielo["scielo_issn"].isNotNull(), True)
            .when(ojs["ojs_is_oa"] == True, True)
            .otherwise(coalesce(high_oa_rate_issns["is_oa_high_oa_rate"], lit(False))))
        .withColumn("high_oa_rate_start_year",
            when(lower(col("publisher")).rlike("^(mdpi|academic journals|edorium journals)"), lit(None).cast("int"))
            .when(scielo["scielo_issn"].isNotNull(), lit(None).cast("int"))
            .when(ojs["ojs_is_oa"] == True, lit(None).cast("int"))
            .otherwise(high_oa_rate_issns["high_oa_rate_start_year"]))
        # collapse the exploded rows back to one row per source: group on every
        # original column except `issns`, which is rebuilt in the aggregation
        .groupBy("id", *[c for c in base_combined.columns if c not in ["id", "issns"]])
        .agg( # make sure the issns are sorted and deduplicated - avoid non-deterministic sorting
            array_sort(collect_set("exploded_issn")).alias("issns"),
            
            # doaj
            max(when(col("doaj_issn").isNotNull(), lit(True)).otherwise(lit(False))).alias("is_in_doaj"),
            min(col("oa_start_year")).alias("is_in_doaj_start_year"),
            # NOTE(review): `first` takes whichever non-null license it meets;
            # with several distinct licenses per source the pick may not be
            # stable between runs - confirm whether that matters.
            first(col("doaj_license_normalized"), ignorenulls=True).alias("doaj_license"),

            # scielo
            max(when(col("scielo_issn").isNotNull(), lit(True)).otherwise(lit(False))).alias("is_in_scielo"),

            # ojs
            max(when(col("ojs_issn").isNotNull(), lit(True)).otherwise(lit(False))).alias("is_ojs"),

            # j-stage data for later use
            max(col("jstage_oa_start_year")).alias("jstage_oa_start_year"),
            max(col("jstage_oa_end_year")).alias("jstage_oa_end_year"),
            max(when(col("jstage_issn").isNotNull(), lit(True)).otherwise(lit(False))).alias("has_jstage_issn"),

            # high oa rate: true if any of the source's ISSN rows qualifies
            max(
                when(lower(col("publisher")).rlike("^(mdpi|academic journals|edorium journals)"), lit(True))
                .when(col("scielo_issn").isNotNull(), lit(True))
                .when(col("ojs_is_oa") == True, lit(True))
                .otherwise(coalesce(col("is_oa_high_oa_rate"), lit(False)))
            ).alias("is_oa_high_oa_rate"),

            # high oa rate start year: earliest year among the ISSN rows that
            # are not covered by the publisher / scielo / ojs rules
            min(
                when(lower(col("publisher")).rlike("^(mdpi|academic journals|edorium journals)"), lit(None).cast("int"))
                .when(col("scielo_issn").isNotNull(), lit(None).cast("int"))
                .when(col("ojs_is_oa") == True, lit(None).cast("int"))
                .otherwise(col("high_oa_rate_start_year"))
            ).alias("high_oa_rate_start_year")
        )
        # NOTE(review): none of these columns is a grouping key or aggregate
        # output in the visible code, so this drop looks like a no-op - confirm.
        .drop("doaj_issn", "jstage_issn", "scielo_issn", "ojs_issn", "issn_l")
        # primary ISSN = first element of the sorted array
        .withColumn("issn", when(size(col("issns")) > 0, col("issns")[0]).otherwise(lit(None)))
        .withColumn("rank", row_number().over( # deduplicate by ISSN
            Window.partitionBy("issn").orderBy(
                size("issns").desc(),
                length("display_name").desc(),
                col("id").asc()
            )
        ))
        .filter((col("issn").isNull()) | (col("rank") == 1))
        .drop("rank")
    )
    
    # sources without any ISSN get the same enrichment columns, filled with defaults
    sources_null_issns = (
        base_combined
        .filter((col("issns").isNull()) | (size(col("issns")) == 0))
        .withColumn("is_in_doaj", lit(False))
        .withColumn("doaj_license", lit(None).cast("string"))
        .withColumn("is_in_scielo", lit(False))
        .withColumn("is_ojs", lit(False))
        .withColumn("jstage_oa_start_year", lit(None).cast("int"))
        .withColumn("jstage_oa_end_year", lit(None).cast("int"))
        .withColumn("has_jstage_issn", lit(False))
        .withColumn("is_oa_high_oa_rate", 
            when(lower(col("publisher")).rlike("^(mdpi|academic journals|edorium journals)"), True)
            .otherwise(lit(False)))
        # added as a null column and renamed two lines below to is_in_doaj_start_year
        .withColumn("oa_start_year", lit(None).cast("int"))
        .withColumn("high_oa_rate_start_year", lit(None).cast("int"))
        .withColumnRenamed("oa_start_year", "is_in_doaj_start_year")
        # NOTE(review): assigns datacite_id to itself; appears to have no effect
        # beyond requiring the column to exist - confirm.
        .withColumn("datacite_id", col("datacite_id"))
    )
    
    # combine ISSN and non-ISSN records, then keep a single row per id
    sources_with_doaj_jstage_and_oa = (
        sources_with_issns
        .unionByName(sources_null_issns)
        .withColumn("rank", row_number().over(
            Window.partitionBy("id").orderBy(
                # prefer records that have ISSNs in de-duplication
                (col("issns").isNotNull() & (size(col("issns")) > 0)).desc(),
                col("updated_date").desc(),
                length("display_name").desc(),
                col("created_date").desc()
            )
        ))
        .filter(col("rank") == 1)
        # NOTE(review): no `has_issns` column is created in the visible code.
        .drop("rank", "has_issns")
    )

    # apply source-level curations on top of the ISSN-derived values
    sources_with_curations = (
        sources_with_doaj_jstage_and_oa
        .join(curations_pivoted, sources_with_doaj_jstage_and_oa["id"] == curations_pivoted["source_id"], "left")
        .withColumn("high_oa_rate_start_year",
            when(col("curation_is_oa") == False, lit(None))
            # curated flip year + 1 becomes the start year
            .when(col("curation_oa_flip_year").isNotNull(), col("curation_oa_flip_year") + 1)
            .otherwise(col("high_oa_rate_start_year")))
        .withColumn("is_oa_high_oa_rate",
            when(col("curation_is_oa").isNotNull(), col("curation_is_oa"))
            # reads the start year as just rewritten by the previous withColumn
            .when(col("high_oa_rate_start_year").isNotNull(), lit(True))
            .otherwise(col("is_oa_high_oa_rate")))
        .drop("source_id", "curation_oa_flip_year", "curation_is_oa")
    )

    # update is_oa column based on is_oa_high_oa_rate
    # NOTE(review): is_oa is assigned again in sources_with_jstage below with
    # the same rules plus the J-STAGE rule - this step may be redundant.
    sources_with_updated_oa = (
        sources_with_curations
        .withColumn("is_oa",
                    when(col("is_in_doaj") == True, True)
                    .when(col("is_in_scielo") == True, True)
                    .when(col("is_oa_high_oa_rate") == True, True)
                    .when(col("is_oa_high_oa_rate") == False, False)
                    .otherwise(col("is_oa")))
    )

    # publication year range per source, needed for the J-STAGE check
    sources_api_years = (
        spark.table("openalex.sources.sources_api")
        .select("id", "first_publication_year", "last_publication_year")
    )

    sources_with_jstage = (
        sources_with_updated_oa
        .join(sources_api_years, "id", "left")
        .withColumn(
            "is_fully_open_in_jstage",
            # true only when all four years are present and the J-STAGE OA
            # window covers the whole publication range
            (
                col("has_jstage_issn") &
                col("first_publication_year").cast("int").isNotNull() &
                col("last_publication_year").cast("int").isNotNull() &
                col("jstage_oa_start_year").cast("int").isNotNull() &
                col("jstage_oa_end_year").cast("int").isNotNull() &
                (col("jstage_oa_start_year").cast("int") <= col("first_publication_year").cast("int")) &
                (col("jstage_oa_end_year").cast("int") >= col("last_publication_year").cast("int"))
            )
        )
        .withColumn(
            "is_oa",
            when(col("is_in_doaj"), True)
            .when(col("is_fully_open_in_jstage"), True)
            .when(col("is_in_scielo"), True)
            .when(col("is_oa_high_oa_rate"), True)
            .when(~col("is_oa_high_oa_rate"), False)
            .otherwise(col("is_oa"))
        )
        .drop(
            "first_publication_year", "last_publication_year",
            "has_jstage_issn", "jstage_oa_start_year", "jstage_oa_end_year"
        )
    )

    
    # join with endpoint mapping
    endpoints_table = (
        spark.table("openalex.sources.endpoint_mapping")
        .select("endpoint_id", "sample_pmh_record")
    )

    # join with datacite id mapping to get all datacite_ids for a given openalex_id
    datacite_id_mapping = (
        spark.table("openalex.sources.datacite_to_openalex_mapping")
        .groupBy("openalex_id")
        .agg(
            array_sort(collect_set("datacite_id")).alias("all_datacite_ids")
        )
    )

    final_result = (
        sources_with_jstage
        .join(
            endpoints_table,
            sources_with_jstage["repository_id"] == endpoints_table["endpoint_id"],
            "left"
        )
        .join(
            datacite_id_mapping,
            sources_with_jstage["id"] == datacite_id_mapping["openalex_id"],
            "left"
        )
        # a source's own datacite_id wins; otherwise use the ids mapped to its
        # OpenAlex id; otherwise an empty array
        .withColumn("datacite_ids", 
            when(col("datacite_id").isNotNull(), 
                array(col("datacite_id")))
            .when(col("all_datacite_ids").isNotNull(), 
                col("all_datacite_ids"))
            .otherwise(array())
        )
        # names containing "rxiv" or "research square" are typed as repositories
        .withColumn("type",
            when(
                lower(col("display_name")).contains("rxiv") |
                lower(col("display_name")).contains("research square"), 
                "repository"
            )
            .otherwise(col("type"))
        )
        # remove "Deleted Journal" and other bad journal from sources
        .filter(~col("id").isin(4317411217, 4363604846))
        .drop("endpoint_id", "openalex_id", "all_datacite_ids")
    )

    # add preprints sources, need to move this to separate table
    preprints_org = spark.createDataFrame([
        {
            "id": 6309402219,
            "display_name": "Preprints.org",
            "publisher": "MDPI AG",
            "webpage": "https://preprints.org",
            "type": "repository",
            "publisher_id": 4310310987,
            "is_oa": True,
            "is_oa_high_oa_rate": True,
            "is_in_doaj": False,
            "is_in_scielo": False,
            "is_ojs": False,
            "is_fully_open_in_jstage": False,
            "updated_date": "2025-08-12",
            "created_date": "2025-08-12"
        }
    ]).withColumn("datacite_ids", array())

    # columns absent from the hard-coded row are filled with null
    return final_result.unionByName(preprints_org, allowMissingColumns=True)