In [0]:
import dlt
import pyspark.sql.functions as F
from pyspark.sql.window import Window


def parse_investigator(inv_col):
    """
    Reshape one Crossref investigator struct into the output investigator struct.

    Intended to be passed to ``F.transform`` over the ``lead-investigator``,
    ``co-lead-investigator`` and ``investigator`` arrays, so ``inv_col`` is a
    single array element.

    Returned struct fields:
        given_name, family_name, orcid -- copied from ``given``, ``family``, ``ORCID``.
        role_start  -- DATE built from ``role-start.date-parts[0]``; NULL when the
                       day component is missing.
        affiliation -- only the FIRST affiliation (name, country, ids); NULL when
                       the affiliation list has no elements.
    """
    # Crossref stores dates as [[year, month, day]]; we only read the first triple.
    start_ymd = inv_col["role-start"]["date-parts"][0]
    role_start = F.when(
        start_ymd[2].isNotNull(),
        F.make_date(start_ymd[0], start_ymd[1], start_ymd[2]),
    )

    first_aff = inv_col["affiliation"][0]

    def _reshape_affiliation_id(aff_id):
        # Rename the hyphenated Crossref keys to snake_case output fields.
        return F.struct(
            aff_id["id"].alias("id"),
            aff_id["id-type"].alias("type"),
            aff_id["asserted-by"].alias("asserted_by"),
        )

    affiliation = F.when(
        F.size(inv_col["affiliation"]) > 0,
        F.struct(
            first_aff["name"].alias("name"),
            first_aff["country"].alias("country"),
            F.transform(first_aff["id"], _reshape_affiliation_id).alias("ids"),
        ),
    )

    return F.struct(
        inv_col["given"].alias("given_name"),
        inv_col["family"].alias("family_name"),
        inv_col["ORCID"].alias("orcid"),
        role_start.alias("role_start"),
        affiliation.alias("affiliation"),
    )


@dlt.table(name="crossref_grants")
def crossref_grants():
    """Explode the raw Crossref snapshot into one row per item and keep only grants.

    ``items`` is unpacked with ``F.inline`` so each struct field of an item
    becomes its own column; rows are then restricted to ``type == "grant"``.
    """
    raw_snapshot = spark.read.table('openalex.crossref.crossref_items')
    one_row_per_item = raw_snapshot.select(F.inline('items'))
    is_grant = F.col("type") == "grant"
    return one_row_per_item.filter(is_grant)


@dlt.table(name="crossref_grants_deduplicated")
def crossref_grants_deduplicated():
    """Keep a single row per grant ``DOI``: the one with the newest ``indexed.timestamp``.

    Rows are numbered within each DOI, newest first, and only row 1 survives.
    There is no secondary sort key, so rows that tie on ``indexed.timestamp``
    are resolved arbitrarily by Spark.
    """
    newest_first = (
        Window
        .partitionBy("DOI")
        .orderBy(F.col("indexed.timestamp").desc())
    )

    ranked = dlt.read("crossref_grants").withColumn(
        "row_num", F.row_number().over(newest_first)
    )
    # The helper ranking column is dropped so the output schema matches the input.
    return ranked.where(F.col("row_num") == 1).drop("row_num")


def _date_from_parts(parts):
    """Build a DATE from a Crossref ``[year, month, day]`` column; NULL if the day is missing."""
    return F.when(parts[2].isNotNull(), F.make_date(parts[0], parts[1], parts[2]))


def _stage_grant_columns(df_grants):
    """Add helper columns to the deduplicated grants.

    Only the FIRST project, that project's FIRST funding entry and that funder's
    FIRST id are used. Adds ``proj_struct``, ``fund_struct``, ``funder_id_struct``,
    the award ``start_parts`` / ``end_parts`` date triples, and the funder match
    key ``join_id_type`` / ``join_id`` (the first funder id's ``id-type`` / ``id``).
    """
    first_project = F.col("project").getItem(0)
    first_funding = first_project["funding"][0]
    return df_grants.select(
        "*",
        first_project.alias("proj_struct"),
        first_funding.alias("fund_struct"),
        first_funding["funder"]["id"].getItem(0).alias("funder_id_struct"),
    ).select(
        "*",
        F.col("proj_struct.award-start.date-parts").getItem(0).alias("start_parts"),
        F.col("proj_struct.award-end.date-parts").getItem(0).alias("end_parts"),
        F.col("funder_id_struct.id-type").alias("join_id_type"),
        F.col("funder_id_struct.id").alias("join_id"),
    )


def _funder_match_lookup(df_funders):
    """Return one row per (funder, identifier) pair, keyed for an equi-join.

    Every funder contributes a ``("DOI", doi)`` row and a ``("ROR", ror_id)`` row
    in ``f_match_type`` / ``f_match_id``, alongside the ``f_``-prefixed funder
    columns. Rows whose identifier is NULL are dropped since they can never match.
    """
    funders = df_funders.select(
        F.col("funder_id").alias("f_funder_id"),
        F.col("display_name").alias("f_display_name"),
        F.col("ror_id").alias("f_ror_id"),
        F.col("doi").alias("f_doi"),
    )
    by_doi = funders.select(
        "*", F.lit("DOI").alias("f_match_type"), F.col("f_doi").alias("f_match_id")
    )
    by_ror = funders.select(
        "*", F.lit("ROR").alias("f_match_type"), F.col("f_ror_id").alias("f_match_id")
    )
    return by_doi.unionByName(by_ror).filter(F.col("f_match_id").isNotNull())


@dlt.table(name="crossref_awards")
def crossref_awards():
    """Project deduplicated Crossref grants into the awards schema, enriched with funders.

    Each grant yields one row per matching funder (a left join, so grants with
    no matching funder are kept with ``funder`` = NULL). A grant matches a funder
    when its first funder id is of type ``DOI`` and equals the funder's ``doi``,
    or is of type ``ROR`` and equals the funder's ``ror_id``.

    NOTE(review): the ``[i]`` lookups and ``[2].isNotNull()`` date checks rely on
    out-of-range array indexes evaluating to NULL, which is Spark's behaviour
    only when ``spark.sql.ansi.enabled`` is false -- confirm the pipeline setting.
    """
    df_staged = _stage_grant_columns(dlt.read("crossref_grants_deduplicated"))
    df_funder_lookup = _funder_match_lookup(spark.read.table('openalex.common.funder'))

    # Joining on (id-type, id) gives Spark equi-join keys, so it can plan a
    # broadcast HASH join. An OR of equalities (doi match OR ror match) has no
    # equi-join keys and forces a broadcast nested-loop join that compares every
    # grant with every funder. The matches are identical: a DOI-typed id can only
    # meet the funder's DOI row and a ROR-typed id only its ROR row.
    return (
        df_staged
        .join(
            F.broadcast(df_funder_lookup),
            (F.col("join_id_type") == F.col("f_match_type"))
            & (F.col("join_id") == F.col("f_match_id")),
            "left",
        )
        .select(
            F.col("DOI").alias("id"),
            F.col("proj_struct.project-title")[0]["title"].alias("display_name"),
            F.col("proj_struct.project-description")[0]["description"].alias("description"),
            F.col("award").alias("funder_award_id"),
            F.col("fund_struct.award-amount.amount").alias("amount"),
            F.col("fund_struct.award-amount.currency").alias("currency"),
            F.when(
                F.col("f_funder_id").isNotNull(),
                F.struct(
                    F.concat(F.lit("https://openalex.org/F"), F.col("f_funder_id")).alias("id"),
                    F.coalesce(F.col("f_display_name"), F.col("fund_struct.funder.name")).alias("display_name"),
                    F.col("f_ror_id").alias("ror_id"),
                    F.col("f_doi").alias("doi"),
                ),
            ).alias("funder"),
            F.col("fund_struct.type").alias("funding_type"),
            F.col("fund_struct.scheme").alias("funder_scheme"),
            F.lit("crossref").alias("provenance"),
            _date_from_parts(F.col("start_parts")).alias("start_date"),
            _date_from_parts(F.col("end_parts")).alias("end_date"),
            F.col("start_parts")[0].alias("start_year"),
            F.col("end_parts")[0].alias("end_year"),
            # Lead / co-lead are arrays in Crossref; only the first element is kept.
            F.transform(
                F.col("proj_struct.lead-investigator"),
                parse_investigator,
            ).getItem(0).alias("lead_investigator"),
            F.transform(
                F.col("proj_struct.co-lead-investigator"),
                parse_investigator,
            ).getItem(0).alias("co_lead_investigator"),
            F.transform(
                F.col("proj_struct.investigator"),
                parse_investigator,
            ).alias("investigators"),
            F.col("resource.primary.URL").alias("landing_page_url"),
            F.col("URL").alias("doi"),
            F.to_timestamp(F.col("created.date-time")).alias("created_date"),
            F.to_timestamp(F.col("indexed.date-time")).alias("updated_date"),
        )
    )