#### Format works the same way as the API and save to `openalex.works.openalex_works_snapshot`

In [0]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import math
import time

from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
from pyspark.sql.window import Window

#### Create `openalex.works.openalex_works_snapshot` in the same format as the API

In [0]:
@udf(StringType())
def truncate_abstract_index_string(raw_json: str, max_bytes: int = 32760) -> str:
    """
    Return the abstract inverted index JSON, truncated to at most ``max_bytes`` UTF-8 bytes.

    Args:
        raw_json: JSON text of the inverted index (an object mapping a token to a
            list of positions is what the truncation logic assumes).
        max_bytes: Maximum size of the returned string, measured in UTF-8 bytes.

    Returns:
        - ``None`` when the input is empty/NULL or is not valid JSON.
        - The input unchanged when it already fits within ``max_bytes``.
        - Otherwise the longest prefix that ends on a ``]``, closed with ``}``,
          that still parses as JSON and fits in ``max_bytes``; ``None`` if no
          such prefix exists.
    """
    try:
        if not raw_json:
            return None

        try:
            parsed = json.loads(raw_json)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass: invalid JSON -> NULL.
            return None

        # A UTF-8 character is at most 4 bytes, so short strings cannot exceed
        # the limit; skip the encode for them.
        if len(raw_json) <= (max_bytes // 4):
            return raw_json

        encoded = raw_json.encode('utf-8')
        if len(encoded) <= max_bytes:
            return raw_json

        # Closing a prefix with '}' only makes sense for a JSON object.
        if not isinstance(parsed, dict):
            return None

        # Keep one byte in reserve for the closing '}' so the result never
        # exceeds max_bytes. errors='ignore' drops a multi-byte character that
        # the byte slice may have cut in half.
        truncated = encoded[:max_bytes - 1].decode('utf-8', errors='ignore')

        # A ']' can also occur inside a token (e.g. the key "[1]"), in which
        # case cutting there yields invalid JSON. Walk back through earlier
        # ']' positions until the candidate actually parses.
        cut = truncated.rfind(']')
        while cut != -1:
            candidate = truncated[:cut + 1] + '}'
            try:
                json.loads(candidate)
            except ValueError:
                cut = truncated.rfind(']', 0, cut)
            else:
                return candidate

        return None
    except Exception:
        # Best effort inside a Spark UDF: a bad row becomes NULL instead of
        # failing the whole job.
        return None


def sanitize_name(col_name: str):
    """
    Build a Column expression that cleans the string column ``col_name``.

    Steps, in order:
      1. Remove every character that is not a Unicode letter, number,
         punctuation mark, symbol or separator (so text in any script is kept).
      2. Collapse each run of whitespace into a single space.
      3. Trim leading and trailing spaces.
    """
    # NOTE(review): tab/newline are control characters rather than \p{Z}
    # separators, so step 1 appears to delete them instead of letting step 2
    # turn them into a space - confirm this is intended.
    disallowed = r"[^\p{L}\p{N}\p{P}\p{S}\p{Z}]"
    whitespace_run = r"\s+"

    without_disallowed = F.regexp_replace(F.col(col_name), disallowed, "")
    single_spaced = F.regexp_replace(without_disallowed, whitespace_run, " ")
    return F.trim(single_spaced)


def sanitize_string(col_name: str, max_len: int = 32000):
    """Return a Column with the first ``max_len`` characters of ``col_name``; NULL stays NULL."""
    source = F.col(col_name)
    clipped = F.substring(source, 1, max_len)
    return F.when(source.isNotNull(), clipped).otherwise(None)


# Typed empty array used below as the fallback when sustainable_development_goals is NULL.
empty_sdg_array = F.array().cast("array<struct<id:string,display_name:string,score:double>>")

# Reshape openalex.works.openalex_works column by column. The order of the
# withColumn calls matters: later steps read columns rewritten by earlier ones
# (e.g. publication_year reads the range-checked publication_date, indexed_in
# reads the rebuilt locations).
df_transformed = (
    spark.read.table("openalex.works.openalex_works")
    # display_name starts as a copy of title; both are sanitized further down.
    .withColumn("display_name", F.col("title"))
    # Parse the date columns into timestamp / date types.
    .withColumn("created_date", F.to_timestamp("created_date"))
    .withColumn("updated_date", F.to_timestamp("updated_date"))
    .withColumn("publication_date", F.to_date("publication_date"))
    # Rebuild each concept struct, prefixing its id with the OpenAlex concept URL.
    .withColumn(
        "concepts",
        F.transform(
            F.col("concepts"),
            lambda c: F.struct(
                F.concat(F.lit("https://openalex.org/C"), c.id).alias("id"),
                c.wikidata.alias("wikidata"),
                c.display_name.alias("display_name"),
                c.level.alias("level"),
                c.score.alias("score")
            )
        )
    )
    # Null out dates that fall outside the accepted range.
    .withColumn(
        "created_date",
        F.when(
            F.col("created_date").between(F.lit("1000-01-01"), F.lit("9999-12-31")),
            F.col("created_date")
        ).otherwise(F.lit(None).cast("timestamp"))
    )
    .withColumn(
        "updated_date",
        F.when(
            F.col("updated_date").between(F.lit("1000-01-01"), F.lit("9999-12-31")),
            F.col("updated_date")
        ).otherwise(F.lit(None).cast("timestamp"))
    )
    # publication_date has a tighter upper bound (2050-12-31) than the other two dates.
    .withColumn(
        "publication_date",
        F.when(
            F.col("publication_date").between(F.lit("1000-01-01"), F.lit("2050-12-31")),
            F.col("publication_date")
        ).otherwise(F.lit(None).cast("date"))
    )
    # Prefix the work id with the OpenAlex work URL.
    .withColumn("id", F.concat(F.lit("https://openalex.org/W"), F.col("id")))
    # Derived from the range-checked publication_date, so it is NULL when that date was rejected.
    .withColumn("publication_year", F.year("publication_date"))
    .withColumn("title", sanitize_name("title"))
    .withColumn("display_name", sanitize_name("display_name"))
    # In the ids map, turn the "doi" value into a https://doi.org/ URL; other entries are unchanged.
    .withColumn("ids", 
        F.transform_values("ids",
            lambda k, v: F.when(k == "doi", 
                    F.concat(F.lit("https://doi.org/"),v)).otherwise(v)
        )
    )
    # Cap free-text columns at sanitize_string's default length.
    .withColumn("doi", sanitize_string("doi"))
    .withColumn("language", sanitize_string("language"))
    .withColumn("type", sanitize_string("type"))
    .withColumn("abstract", sanitize_string("abstract"))
    # Prefix every referenced work id with the OpenAlex work URL.
    .withColumn("referenced_works", 
                F.expr("transform(referenced_works, x -> 'https://openalex.org/W' || x)"))
    # Count of referenced works; 0 when the array is NULL.
    .withColumn("referenced_works_count", 
                F.when(F.col("referenced_works").isNotNull(), F.size("referenced_works")).otherwise(0))
    # Validate and size-limit the inverted index JSON via the UDF.
    .withColumn("abstract_inverted_index", truncate_abstract_index_string(F.col("abstract_inverted_index")))
    # Rebuild open_access; any_repository_has_fulltext is hard-coded to False here.
    .withColumn("open_access", F.struct(
        F.col("open_access.is_oa"),
        sanitize_string("open_access.oa_status").alias("oa_status"),
        F.lit(False).cast("boolean").alias("any_repository_has_fulltext"),
        F.col("open_access.oa_url")
    ))
    # Rebuild each authorship struct, capping its string fields at 32000 characters.
    .withColumn("authorships", F.expr("""
        transform(authorships, x -> named_struct(
            'affiliations', x.affiliations,
            'author', x.author,
            'author_position', substring(x.author_position, 1, 32000),
            'countries', x.countries,
            'raw_author_name', substring(x.raw_author_name, 1, 32000),
            'is_corresponding', x.is_corresponding,
            'raw_affiliation_strings', transform(x.raw_affiliation_strings, aff -> substring(aff, 1, 32000)),
            'institutions', x.institutions
        ))
    """))
    # Rebuild each location struct; is_published is derived from version = 'publishedVersion'
    # and the URL fields are capped at 32000 characters.
    .withColumn("locations", F.expr("""
        transform(locations, x -> named_struct(
            'is_oa', x.is_oa,
            'is_published', x.version = 'publishedVersion',
            'landing_page_url', substring(x.landing_page_url, 1, 32000),
            'pdf_url', substring(x.pdf_url, 1, 32000),
            'source', x.source,
            'raw_source_name', x.raw_source_name,
            'native_id', x.native_id,
            'provenance', x.provenance,
            'license', x.license,
            'license_id', x.license_id
        ))
    """))
    # Keep at most the first 40 concepts.
    .withColumn("concepts", F.slice(F.col("concepts"), 1, 40))
    # indexed_in: sorted, de-duplicated, NULL-free list of index names derived from each
    # location's provenance, native_id prefix and source fields.
    .withColumn("indexed_in", F.expr("""
        array_sort(
            array_distinct(
                array_compact(
                    flatten(
                        TRANSFORM(locations, loc ->
                            CASE
                            WHEN loc.provenance IN ('crossref', 'pubmed', 'datacite')
                                THEN array(loc.provenance, IF(loc.source.is_in_doaj, 'doaj', NULL))
                            WHEN loc.provenance = 'repo' AND lower(loc.native_id) like 'oai:arxiv.org%'
                                THEN array('arxiv')
                            WHEN loc.provenance = 'repo' AND lower(loc.native_id) like 'oai:doaj.org/%'
                                THEN array('doaj')
                            WHEN loc.provenance = 'mag' AND lower(loc.source.display_name) = 'pubmed'
                                THEN array('pubmed')
                            ELSE array()
                            END
                        )
                    )
                )
            )
        )
    """))
    # Replace NULLs with defaults: empty arrays for list columns, 0 for fwci.
    .withColumn("corresponding_author_ids", F.coalesce(F.col("corresponding_author_ids"), F.lit([])))
    .withColumn("corresponding_institution_ids", F.coalesce(F.col("corresponding_institution_ids"), F.lit([])))
    .withColumn("sustainable_development_goals", F.coalesce(F.col("sustainable_development_goals"), empty_sdg_array))
    .withColumn("related_works", F.coalesce(F.col("related_works"), F.lit([])))
    .withColumn("fwci", F.coalesce(F.col("fwci"), F.lit(0)))
    .withColumn("mesh", F.coalesce(F.col("mesh"), F.lit([])))
)

# Replace the snapshot table, allowing its schema to change with this write.
df_transformed.write \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable("openalex.works.openalex_works_snapshot")

#### Export in JSON Lines format to S3

In [0]:
# Entity being exported; becomes the last path segment of the export location.
entity_type = "works"
# Snapshot date taken from the driver's local clock (naive datetime, no timezone attached).
date_str = datetime.now().strftime("%Y-%m-%d")
# Upper bound on records per output file (used as Spark's maxRecordsPerFile in export()).
RECORDS_PER_FILE = 400000
# All snapshot output for this run lives under a date-stamped prefix.
s3_base_path = f"s3://openalex-sandbox/snapshots/{date_str}"
output_path = f"{s3_base_path}/{entity_type}"

def export():
    """
    Write the works snapshot table to ``output_path`` as gzip-compressed JSON lines.

    The output is partitioned by ``updated_date`` (truncated to a date). Dates
    with many records are spread over several shuffle partitions with a
    hash-derived salt so that a single task does not have to write a whole
    large date on its own; ``maxRecordsPerFile`` then caps the size of each file.

    Reads the module-level ``output_path`` and ``RECORDS_PER_FILE`` and the
    notebook-provided ``spark`` session. Returns nothing.
    """
    print(f"Starting export to: {output_path}")
    print(f"Records per file: {RECORDS_PER_FILE:,}")

    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")
    spark.conf.set("spark.sql.shuffle.partitions", "2000")

    df = spark.read.table("openalex.works.openalex_works_snapshot")
    df = df.withColumn("updated_date", F.to_date("updated_date"))

    # One row per date. NOTE(review): the inner join below never matches a NULL
    # updated_date, so such rows are left out of the export; the filter only
    # makes that pre-existing behavior explicit - confirm it is intended.
    date_counts = (
        df.groupBy("updated_date").count()
        .withColumnRenamed("count", "date_count")
        .where(F.col("updated_date").isNotNull())
    )
    df_with_count = df.join(date_counts, on="updated_date")

    # Read the statistics from the small aggregate instead of de-duplicating
    # the full joined DataFrame.
    date_stats = date_counts.orderBy(F.desc("date_count")).limit(10).collect()
    print("\nDate distribution (top 10):")
    for row in date_stats:
        expected_files = (row['date_count'] + RECORDS_PER_FILE - 1) // RECORDS_PER_FILE
        print(f"  {row['updated_date']}: {row['date_count']:,} records → {expected_files} files expected")

    # apply hash-based salting for predictable distribution
    # (minimum date_count, number of salt buckets), largest tier first.
    salt_tiers = (
        (100_000_000, 1400),
        (40_000_000, 160),
        (10_000_000, 50),
        (5_000_000, 25),
        (2_000_000, 10),
        (800_000, 3),
    )
    salt = None
    for min_count, buckets in salt_tiers:
        in_tier = F.col("date_count") > min_count
        # pmod always yields a value in [0, buckets); abs(hash) would overflow
        # for the most negative 32-bit hash value.
        bucket = F.expr(f"pmod(hash(id), {buckets})")
        salt = F.when(in_tier, bucket) if salt is None else salt.when(in_tier, bucket)
    df_salted = df_with_count.withColumn("salt", salt.otherwise(0)).drop("date_count")

    print("\nRepartitioning and writing to S3...")
    df_out = df_salted.repartition(F.col("updated_date"), F.col("salt")).drop("salt")

    (df_out.write
         .mode("overwrite")
         .option("compression", "gzip")
         .option("maxRecordsPerFile", RECORDS_PER_FILE)
         .partitionBy("updated_date")
         .json(output_path))

    print("Export completed!")

export()

#### Rename the files to sequential numbers and remove Spark metadata

In [0]:
def rename_files_and_cleanup(output_path, max_workers=30):
    """
    Rename exported data files to ``part_NNNN.gz`` and delete other files in each partition.

    Every ``updated_date=...`` directory under ``output_path`` is handled in a
    worker thread. Within a directory:
      - files already named ``part_*.gz`` are left alone,
      - remaining ``*.json.gz`` files are renamed, in name order, to the next
        free ``part_NNNN.gz`` numbers,
      - anything else is treated as metadata and removed.
    Directories with nothing left to rename are reported and skipped.

    Args:
        output_path: Directory that contains the ``updated_date=...`` partitions.
        max_workers: Number of partitions processed concurrently.

    One status line is printed per partition; a partition is reported as failed
    if listing it raised or if any of its renames failed.
    """
    partitions = dbutils.fs.ls(output_path)
    partitions_to_process = [p for p in partitions if p.name.startswith("updated_date=")]

    print(f"Found {len(partitions_to_process)} partitions to process")

    def process_single_partition(partition):
        """Process one partition directory; returns (partition name, success, message)."""
        try:
            files = dbutils.fs.ls(partition.path)

            # separate already-renamed files from those needing renaming
            already_renamed = []
            needs_renaming = []
            metadata_files = []
            for f in files:
                if f.name.startswith('part_') and f.name.endswith('.gz'):
                    already_renamed.append(f)
                elif f.name.endswith('.json.gz'):
                    needs_renaming.append(f)
                else:
                    metadata_files.append(f)

            if not needs_renaming:
                # already processed
                total_data_files = len(already_renamed)
                return partition.name, True, f"{total_data_files} files already renamed"

            # sort to ensure a consistent numbering order
            needs_renaming.sort(key=lambda x: x.name)

            # find the highest existing part number
            max_existing = -1
            for f in already_renamed:
                try:
                    num = int(f.name[len('part_'):-len('.gz')])
                except ValueError:
                    # part_* file without a purely numeric index: ignore it
                    continue
                max_existing = max(max_existing, num)

            # start renaming from the next available number
            start_idx = max_existing + 1

            # Names present when the directory was listed, for O(1) collision checks.
            existing_names = {f.name for f in files}

            renamed_count = 0
            failed_count = 0
            for idx, file_info in enumerate(needs_renaming):
                new_name = f"part_{str(start_idx + idx).zfill(4)}.gz"

                # safety check: never overwrite a file that is already there
                if new_name in existing_names:
                    print(f"  WARNING: {new_name} already exists in {partition.name}, skipping")
                    continue

                try:
                    dbutils.fs.mv(file_info.path, f"{partition.path}{new_name}")
                except Exception as e:
                    failed_count += 1
                    print(f"  Error renaming in {partition.name}: {e}")
                    continue

                renamed_count += 1
                # Progress indicator for large directories
                if renamed_count % 100 == 0:
                    print(f"  {partition.name}: Renamed {renamed_count}/{len(needs_renaming)} files...")

            # clean up metadata files (best effort, but report what could not be removed)
            cleanup_count = 0
            for f in metadata_files:
                try:
                    dbutils.fs.rm(f.path)
                except Exception as e:
                    print(f"  Could not remove {f.path}: {e}")
                else:
                    cleanup_count += 1

            summary = f"{renamed_count} renamed, {len(already_renamed)} existing, {cleanup_count} cleaned"
            if failed_count:
                # Surface partial failures instead of reporting the partition as done.
                return partition.name, False, f"{failed_count} renames failed ({summary})"
            return partition.name, True, summary

        except Exception as e:
            return partition.name, False, str(e)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_partition, p) for p in partitions_to_process]

        completed = 0
        start_time = time.time()

        for future in as_completed(futures):
            partition_name, success, message = future.result()
            completed += 1
            elapsed = time.time() - start_time

            if success:
                print(f"  [{completed}/{len(partitions_to_process)}] ✓ {partition_name}: {message} ({elapsed:.1f}s)")
            else:
                print(f"  [{completed}/{len(partitions_to_process)}] ✗ {partition_name}: Error - {message}")

#### Create manifest

In [0]:
def create_manifest():
    """
    Create a manifest file with all file metadata using parallel processing.

    Lists every ``.gz`` file in the ``updated_date=...`` partitions under
    ``{s3_base_path}/{entity_type}``, counts the records in each file with
    Spark (in a thread pool), and writes ``manifest`` next to the partitions.

    Returns:
        The manifest dict: ``{"entries": [{"url", "meta": {"content_length",
        "record_count"}}, ...], "meta": {"content_length", "record_count"}}``
        with entries sorted by URL. Files whose record count could not be
        obtained are left out, and a warning with their number is printed.
    """
    output_path = f"{s3_base_path}/{entity_type}"

    print("\nCreating manifest...")

    partitions = dbutils.fs.ls(output_path)
    partitions_to_process = sorted([p for p in partitions if p.name.startswith("updated_date=")],
                                   key=lambda x: x.name, reverse=True)

    def process_file(partition_name, file_info):
        """Return the manifest data for one file, or None if it could not be counted."""
        try:
            # count records in the file
            # NOTE(review): for JSON-lines files a plain line count
            # (spark.read.text) may be a cheaper equivalent - confirm before changing.
            record_count = spark.read.json(file_info.path).count()
        except Exception as e:
            print(f"Error processing {partition_name}{file_info.name}: {e}")
            return None

        s3_url = file_info.path.replace("dbfs:/", "s3://")
        return {
            "entry": {
                "url": s3_url,
                "meta": {
                    "content_length": file_info.size,
                    "record_count": record_count
                }
            },
            "partition": partition_name,
            "file": file_info.name,
            "size": file_info.size,
            "count": record_count
        }

    # collect all file tasks
    file_tasks = []
    for partition in partitions_to_process:
        for file_info in dbutils.fs.ls(partition.path):
            if file_info.name.endswith('.gz'):
                file_tasks.append((partition.name, file_info))

    print(f"Processing {len(file_tasks)} files across {len(partitions_to_process)} partitions...")

    # process files in parallel
    entries = []
    total_content_length = 0
    total_record_count = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=50) as executor:  # Increase workers for file reading
        futures = [executor.submit(process_file, partition_name, file_info)
                   for partition_name, file_info in file_tasks]

        completed = 0
        for future in as_completed(futures):
            result = future.result()
            completed += 1

            if result is None:
                failed += 1
            else:
                entries.append(result["entry"])
                total_content_length += result["size"]
                total_record_count += result["count"]

                # print details for large files
                if result["size"] > 100 * 1024 * 1024:  # Files > 100MB
                    print(f"  {result['partition']}{result['file']}: "
                          f"{result['count']:,} records, {result['size']/(1024*1024):.1f} MB")

            # Report progress for every completed file, including failed ones,
            # so the final "N/N" line is always printed.
            if completed % 50 == 0 or completed == len(file_tasks):
                print(f"  Progress: {completed}/{len(file_tasks)} files processed...")

    if failed:
        # The manifest totals would otherwise look complete while files are missing.
        print(f"WARNING: {failed} of {len(file_tasks)} files could not be counted "
              f"and are missing from the manifest")

    entries.sort(key=lambda x: x["url"])

    manifest = {
        "entries": entries,
        "meta": {
            "content_length": total_content_length,
            "record_count": total_record_count
        }
    }

    manifest_path = f"{output_path}/manifest"
    manifest_json = json.dumps(manifest, indent=2)
    dbutils.fs.put(manifest_path, manifest_json, overwrite=True)

    print(f"\nManifest created: {manifest_path}")
    print(f"Total files: {len(entries)}")
    print(f"Total size (compressed): {total_content_length / (1024**3):.2f} GB")
    print(f"Total records: {total_record_count:,}")

    return manifest

manifest = create_manifest()