In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import time
import threading

from datetime import datetime
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

#### Create `openalex.authors.openalex_authors_snapshot` in the same format as the API

In [None]:
# Build the API-shaped authors snapshot in one projection.
# The select list is the explicit field whitelist matching elastic sync; its
# order is the column order of the snapshot table.
df_transformed = spark.read.table("openalex.authors.openalex_authors").select(
    # Numeric id -> full OpenAlex URL
    F.concat(F.lit("https://openalex.org/A"), F.col("id").cast("string")).alias("id"),
    "display_name",
    # Array columns are coalesced so a null array becomes an empty array
    F.coalesce(F.col("display_name_alternatives"), F.array()).alias("display_name_alternatives"),
    "orcid",
    "works_count",
    "cited_by_count",
    "summary_stats",
    "ids",
    F.coalesce(F.col("affiliations"), F.array()).alias("affiliations"),
    F.coalesce(F.col("last_known_institutions"), F.array()).alias("last_known_institutions"),
    # Only the first 5 topics / topic shares are kept (matching elastic sync)
    F.coalesce(F.slice(F.col("topics"), 1, 5), F.array()).alias("topics"),
    F.coalesce(F.slice(F.col("topic_share"), 1, 5), F.array()).alias("topic_share"),
    # x_concepts: prefix the concept id with its URL and expose col4 as `level`
    F.coalesce(F.expr("""
        transform(x_concepts, c -> named_struct(
            'id', concat('https://openalex.org/C', cast(c.id as string)),
            'wikidata', c.wikidata,
            'display_name', c.display_name,
            'level', c.col4,
            'score', c.score,
            'count', c.count
        ))
    """), F.array()).alias("x_concepts"),
    F.coalesce(F.col("sources"), F.array()).alias("sources"),
    F.coalesce(F.col("counts_by_year"), F.array()).alias("counts_by_year"),
    "works_api_url",
    "updated_date",
    "created_date",
)

# Replace the snapshot table (data and schema) with the freshly built frame.
(
    df_transformed.write
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("openalex.authors.openalex_authors_snapshot")
)

#### Export in JSON Lines format to S3

In [None]:
# Settings shared by the export, rename and manifest steps below.
entity_type = "authors"
RECORDS_PER_FILE = 400000  # passed to Spark as maxRecordsPerFile
date_str = f"{datetime.now():%Y-%m-%d}"  # run date (local time) names the snapshot folder
s3_base_path = "s3://openalex-sandbox/snapshots/" + date_str
output_path = "/".join([s3_base_path, entity_type])

def export():
    """Write the authors snapshot table to ``output_path`` as gzipped JSON lines.

    Reads ``openalex.authors.openalex_authors_snapshot``, derives a
    ``_partition_date`` column from ``updated_date`` and writes one
    ``_partition_date=<date>/`` directory per date, with at most
    ``RECORDS_PER_FILE`` records per file. Dates holding many records are
    spread over several shuffle partitions with a hash-based salt on ``id``
    so that a single task does not have to write a whole date.

    Uses the notebook globals ``spark``, ``output_path`` and
    ``RECORDS_PER_FILE``, and changes three session-level Spark SQL settings.

    Records whose ``updated_date`` yields a null ``_partition_date`` are not
    exported: the inner join on ``_partition_date`` cannot match a null key.
    Their number is printed as a warning instead of being dropped silently.

    Returns:
        None. Progress is printed.
    """
    print(f"Starting export to: {output_path}")
    print(f"Records per file: {RECORDS_PER_FILE:,}")

    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")
    spark.conf.set("spark.sql.shuffle.partitions", "2000")

    df = spark.read.table("openalex.authors.openalex_authors_snapshot")
    # Create partition column (this gets removed from the records by partitionBy).
    # updated_date itself stays in the data for the JSON output.
    df = df.withColumn("_partition_date", F.to_date("updated_date"))

    # One row per date with its record count.
    date_counts = df.groupBy("_partition_date").count().withColumnRenamed("count", "date_count")

    # Read the distribution straight from the aggregate. Deriving it from the
    # joined frame with distinct() would re-run the join and a second
    # aggregation over the full dataset for the same rows.
    all_date_stats = date_counts.orderBy(F.desc("date_count")).collect()
    date_stats = [row for row in all_date_stats if row['_partition_date'] is not None]
    undated_records = sum(row['date_count'] for row in all_date_stats if row['_partition_date'] is None)
    if undated_records:
        print(f"\nWARNING: {undated_records:,} records have no usable updated_date "
              f"and will NOT be exported (null partition date)")

    print("\nDate distribution (top 10):")
    for row in date_stats[:10]:
        expected_files = (row['date_count'] + RECORDS_PER_FILE - 1) // RECORDS_PER_FILE
        print(f"  {row['_partition_date']}: {row['date_count']:,} records → {expected_files} files expected")

    # Attach each record's date_count; the counts table is broadcast so the
    # large side of the join is not shuffled.
    df_with_count = df.join(F.broadcast(date_counts), on="_partition_date")

    # Hash-based salting for predictable distribution: the more records a date
    # holds, the more salt values (and therefore shuffle partitions) it gets.
    id_hash = F.abs(F.hash("id"))
    df_salted = df_with_count.withColumn(
        "salt",
        F.when(F.col("date_count") > 100_000_000, id_hash % 1400)
        .when(F.col("date_count") > 40_000_000, id_hash % 160)
        .when(F.col("date_count") > 10_000_000, id_hash % 50)
        .when(F.col("date_count") > 5_000_000, id_hash % 25)
        .when(F.col("date_count") > 2_000_000, id_hash % 10)
        .when(F.col("date_count") > 800_000, id_hash % 3)
        .otherwise(0)
    ).drop("date_count")

    print("\nRepartitioning and writing to S3...")
    # The salt is only needed to drive the repartitioning, not in the output.
    df_out = df_salted.repartition(F.col("_partition_date"), F.col("salt")).drop("salt")

    (df_out.write
         .mode("overwrite")
         .option("compression", "gzip")
         .option("maxRecordsPerFile", RECORDS_PER_FILE)
         .partitionBy("_partition_date")
         .json(output_path))

    print("Export completed!")

# Run the export; files are written under output_path.
export()

#### Rename the files to sequential numbers and remove Spark metadata

In [None]:
def rename_files_and_cleanup(output_path, max_workers=30, file_workers=50):
    """Renumber exported files and move them into ``updated_date=`` directories.

    For every ``_partition_date=<date>/`` directory directly under
    ``output_path`` the data files (``*.json.gz`` or ``part_*.gz``) are sorted
    by name and moved to ``updated_date=<date>/part_NNNN.gz`` (numbered from
    0000). Every other file in the directory is treated as metadata and
    deleted, then the old directory itself is removed.

    The old directory is removed only when every data file was moved. If any
    move fails, the partition is reported as failed and nothing in it is
    deleted, so files that were not moved are never lost.

    NOTE(review): numbering restarts at 0 on each run, so re-running after a
    partial failure can target names that already exist in
    ``updated_date=<date>/`` - confirm how ``dbutils.fs.mv`` treats an
    existing destination before retrying.

    Args:
        output_path: Directory that contains the ``_partition_date=`` folders.
        max_workers: Number of partitions processed concurrently.
        file_workers: Threads used to move files inside a single partition
            that holds more than 100 data files.

    Returns:
        None. Progress, errors and a final failure summary are printed.
    """
    # Partitions with more data files than this are moved with a thread pool.
    parallel_threshold = 100

    partitions = dbutils.fs.ls(output_path)
    partitions_to_process = [p for p in partitions if p.name.startswith("_partition_date=")]

    print(f"Found {len(partitions_to_process)} partitions to process")
    print("Will rename directories from _partition_date= to updated_date= during file processing")

    def remove_best_effort(path, recurse=False):
        """Delete ``path``; a failure is reported but does not abort the run."""
        try:
            dbutils.fs.rm(path, recurse=recurse)
        except Exception as e:
            print(f"    Warning: could not remove {path}: {e}")

    def process_single_partition_fast(partition):
        """Process partition: rename files and move to updated_date= directory.

        Returns a ``(partition name, success, message)`` tuple.
        """
        try:
            # Determine new directory path
            date_value = partition.name.replace("_partition_date=", "").rstrip("/")
            new_partition_path = f"{output_path}/updated_date={date_value}/"

            # Split the listing into data files and everything else (metadata).
            json_files = []
            metadata_files = []
            for f in dbutils.fs.ls(partition.path):
                if f.name.endswith('.json.gz') or (f.name.startswith('part_') and f.name.endswith('.gz')):
                    json_files.append(f)
                else:
                    metadata_files.append(f)

            # Sort by full name so the new numbering follows the original order.
            json_files.sort(key=lambda x: x.name)

            if not json_files:
                # Nothing to keep: clean up the empty partition.
                for f in metadata_files:
                    remove_best_effort(f.path)
                remove_best_effort(partition.path, recurse=True)
                return partition.name, True, "empty partition cleaned up"

            total = len(json_files)
            use_pool = total > parallel_threshold
            state_lock = threading.Lock()
            state = {'moved': 0}
            failures = []

            def move_single_file(file_info, file_number):
                """Move one data file to its numbered name; record any failure."""
                new_name = f"part_{str(file_number).zfill(4)}.gz"
                new_path = f"{new_partition_path}{new_name}"
                try:
                    dbutils.fs.mv(file_info.path, new_path)
                except Exception as e:
                    print(f"    Error moving {file_info.path}: {e}")
                    with state_lock:
                        failures.append(f"{file_info.name}: {e}")
                    return False
                with state_lock:
                    state['moved'] += 1
                    if use_pool and state['moved'] % 100 == 0:
                        print(f"    {partition.name}: {state['moved']}/{total} moved...")
                return True

            if use_pool:
                print(f"  {partition.name}: Large directory ({total} files), using parallel processing...")
                with ThreadPoolExecutor(max_workers=file_workers) as executor:
                    futures = [executor.submit(move_single_file, f, num)
                               for num, f in enumerate(json_files)]
                    for future in as_completed(futures):
                        future.result()
            else:
                # small directories - sequential
                for idx, file_info in enumerate(json_files):
                    move_single_file(file_info, idx)

            if failures:
                # Files that failed to move are still in the old directory;
                # deleting it would destroy them, so leave everything in place.
                return (partition.name, False,
                        f"{len(failures)} of {total} files failed to move; "
                        f"source directory kept (first error: {failures[0]})")

            # Everything was moved: drop metadata files and the old directory.
            for f in metadata_files:
                remove_best_effort(f.path)
            remove_best_effort(partition.path, recurse=True)

            return partition.name, True, f"{state['moved']} files moved to updated_date={date_value}"

        except Exception as e:
            return partition.name, False, str(e)

    # process partitions
    start_time = time.time()
    failed_partitions = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_single_partition_fast, p) for p in partitions_to_process]

        for completed, future in enumerate(as_completed(futures), start=1):
            partition_name, success, message = future.result()
            elapsed = time.time() - start_time

            if success:
                print(f"  [{completed}/{len(partitions_to_process)}] ✓ {partition_name}: {message} ({elapsed:.1f}s)")
            else:
                failed_partitions.append(partition_name)
                print(f"  [{completed}/{len(partitions_to_process)}] ✗ {partition_name}: Error - {message}")

    print(f"\nTotal time: {time.time() - start_time:.1f} seconds")
    if failed_partitions:
        print(f"WARNING: {len(failed_partitions)} partition(s) were not fully processed: "
              f"{', '.join(sorted(failed_partitions))}")

# Renumber the files just exported under output_path and move them into updated_date= directories.
rename_files_and_cleanup(output_path)

#### Create manifest

In [None]:
def create_manifest(max_workers=50, fail_on_error=False):
    """Create a manifest file with all file metadata using parallel processing.

    Lists every ``.gz`` file in the ``updated_date=`` directories under
    ``{s3_base_path}/{entity_type}``, counts the records of each file with
    Spark, and writes ``{s3_base_path}/{entity_type}/manifest`` as JSON with
    one entry per file (URL, compressed size, record count) plus overall
    totals. Entry URLs are rewritten to live under ``s3://openalex/data``.

    Uses the notebook globals ``s3_base_path``, ``entity_type``, ``spark`` and
    ``dbutils``.

    Args:
        max_workers: Number of files processed concurrently.
        fail_on_error: When False (default) a file that cannot be processed is
            left out of the manifest and listed in a warning. When True a
            ``RuntimeError`` is raised instead and no manifest is written.

    Returns:
        The manifest as a dict: ``{"entries": [...], "meta": {...}}``.

    Raises:
        RuntimeError: If ``fail_on_error`` is True and any file failed.
    """
    output_path = f"{s3_base_path}/{entity_type}"

    print(f"\nCreating manifest...")

    partitions = dbutils.fs.ls(output_path)
    partitions_to_process = sorted([p for p in partitions if p.name.startswith("updated_date=")],
                                   key=lambda x: x.name, reverse=True)

    def process_file(partition_name, file_info):
        """Build the manifest data for a single file; raises on any error."""
        # count records in the file
        record_count = spark.read.text(file_info.path).count()

        # set the s3 url to the prod s3 folder: keep everything from
        # "/<entity_type>/" onwards and put it under the prod prefix
        raw = file_info.path.replace("dbfs:/", "s3://")
        marker = f"/{entity_type}/"
        idx = raw.find(marker)
        if idx == -1:
            raise ValueError(f"Could not find '{marker}' in path: {raw}")
        s3_url = f"s3://openalex/data{raw[idx:]}"

        return {
            "entry": {
                "url": s3_url,
                "meta": {
                    "content_length": file_info.size,
                    "record_count": record_count
                }
            },
            "partition": partition_name,
            "file": file_info.name,
            "size": file_info.size,
            "count": record_count
        }

    # collect all file tasks (only .gz files are data files)
    file_tasks = []
    for partition in partitions_to_process:
        for file_info in dbutils.fs.ls(partition.path):
            if file_info.name.endswith('.gz'):
                file_tasks.append((partition.name, file_info))

    print(f"Processing {len(file_tasks)} files across {len(partitions_to_process)} partitions...")

    # process files in parallel
    entries = []
    failed_files = []
    total_content_length = 0
    total_record_count = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_file, task[0], task[1]): task
                   for task in file_tasks}

        for completed, future in enumerate(as_completed(futures), start=1):
            partition_name, file_info = futures[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"Error processing {partition_name}{file_info.name}: {e}")
                failed_files.append(f"{partition_name}{file_info.name}")
                result = None

            if result:
                entries.append(result["entry"])
                total_content_length += result["size"]
                total_record_count += result["count"]

                # print details for large files
                if result["size"] > 100 * 1024 * 1024:  # Files > 100MB
                    print(f"  {result['partition']}{result['file']}: "
                          f"{result['count']:,} records, {result['size']/(1024*1024):.1f} MB")

            # Progress counts every finished file, whether it succeeded or not.
            if completed % 50 == 0 or completed == len(file_tasks):
                print(f"  Progress: {completed}/{len(file_tasks)} files processed...")

    if failed_files and fail_on_error:
        raise RuntimeError(
            f"{len(failed_files)} file(s) could not be processed; manifest not written: "
            f"{', '.join(sorted(failed_files))}")

    entries.sort(key=lambda x: x["url"])

    manifest = {
        "entries": entries,
        "meta": {
            "content_length": total_content_length,
            "record_count": total_record_count
        }
    }

    manifest_path = f"{output_path}/manifest"
    manifest_json = json.dumps(manifest, indent=2)
    dbutils.fs.put(manifest_path, manifest_json, overwrite=True)

    print(f"\nManifest created: {manifest_path}")
    print(f"Total files: {len(entries)}")
    print(f"Total size (compressed): {total_content_length / (1024**3):.2f} GB")
    print(f"Total records: {total_record_count:,}")
    if failed_files:
        print(f"WARNING: manifest is INCOMPLETE - {len(failed_files)} file(s) failed and are missing: "
              f"{', '.join(sorted(failed_files))}")

    return manifest

# Build and write the manifest for the renamed files.
# NOTE(review): the manifest dict is returned here as the cell's last expression, so it will likely be echoed in full in the cell output - confirm that is wanted.
create_manifest()