### Daily Unpaywall snapshot - all records in a single compressed file

In [0]:
import re
from datetime import datetime, timezone
import pyspark.sql.functions as F

# Export every row of the source table's `json_response` column as one
# gzip-compressed JSON-lines file under s3://<bucket>/<prefix>/.
#
# Steps:
#   1. write the column uncompressed, in parallel, to a per-run temp prefix;
#   2. re-read that output and funnel it through a single partition so the
#      writer emits exactly one gzip part file;
#   3. copy that part file to its final, timestamped name;
#   4. always remove both scratch locations, even when a step fails.
s3_bucket = "unpaywall-data-feed-walden"
source_table = "openalex.unpaywall.unpaywall"
s3_prefix = "full_snapshots"

df = spark.table(source_table)
# UTC timestamp in a colon-free form; it doubles as the unique suffix for the
# scratch locations below and is what the retention cleanup parses back.
current_date = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H%M%S')
filename = f"unpaywall_snapshot_{current_date}.jsonl.gz"

record_count = df.count()
print(f"Found {record_count} records to export")

if record_count == 0:
    print("No records found to export. Job completed.")
    # Stops the notebook run here so no empty snapshot is published.
    dbutils.notebook.exit("No records found to export")

final_path = f"s3://{s3_bucket}/{s3_prefix}/{filename}"

print(f"Exporting {record_count} records to {final_path}")

records_per_partition = 5000000
# Integer floor division (same result as before, without a float round trip);
# at least one partition is always used.
partition_count = max(1, record_count // records_per_partition)
print(f"Using {partition_count} partitions for initial distribution")

# Scratch locations, both unique to this run so that overlapping or retried
# runs cannot overwrite or delete each other's intermediate data.
temp_dir = f"s3://{s3_bucket}/temp_export_{current_date}/"
single_file_dir = f"/tmp/unpaywall_single_{current_date}/"

try:
    # Write uncompressed data in distributed fashion
    (df.select("json_response")
        .repartition(partition_count)
        .write
        .format("text")
        .mode("overwrite")
        .save(temp_dir)
    )

    print("Data written to temp location in uncompressed format")
    print("Compressing data into a single file...")

    # use coalesce instead of repartition for the final step to avoid shuffle
    (spark.read.text(temp_dir)
        .coalesce(1)
        .write
        .format("text")
        .mode("overwrite")
        .option("compression", "gzip")
        .save(single_file_dir)
    )

    # The output directory also holds marker files; pick the single data file.
    part_files = [
        f.path
        for f in dbutils.fs.ls(single_file_dir)
        if f.name.startswith("part-") and f.name.endswith(".gz")
    ]
    if not part_files:
        raise RuntimeError(
            f"No compressed part file was produced in {single_file_dir}"
        )

    dbutils.fs.cp(part_files[0], final_path)
finally:
    # Best-effort removal of scratch data on success AND failure. A problem
    # here is reported but must not mask the real error (or a good export).
    for scratch_path in (single_file_dir, temp_dir):
        try:
            dbutils.fs.rm(scratch_path, recurse=True)
        except Exception as cleanup_error:
            print(f"Warning: could not remove temporary path {scratch_path}: {cleanup_error}")

print(f"Successfully exported {record_count} records to {final_path}")

# cleanup: Keep only the 5 most recent snapshots

# File names produced by the export step: unpaywall_snapshot_<UTC stamp>.jsonl.gz
snapshot_pattern = re.compile(r'unpaywall_snapshot_(\d{4}-\d{2}-\d{2}T\d{6})\.jsonl\.gz$')


def rank_snapshots(entries):
    """Return the snapshot files among *entries*, newest first.

    Args:
        entries: iterable of ``(path, name)`` pairs, one per listed file.

    Returns:
        A list of ``(path, timestamp, name)`` tuples sorted by the timestamp
        embedded in the file name, most recent first. Entries whose name is
        not exactly a snapshot file name, or whose embedded timestamp is not
        a real date/time, are left out (and therefore never deleted).
    """
    ranked = []
    for path, name in entries:
        # `match` anchors at the start of the name, so files that merely END
        # with a snapshot-like name are not treated as snapshots.
        found = snapshot_pattern.match(name)
        if not found:
            continue
        try:
            taken_at = datetime.strptime(found.group(1), '%Y-%m-%dT%H%M%S')
        except ValueError:
            # Digits fit the pattern but are not a valid date (e.g. month 13).
            # Skip this one file instead of aborting the whole cleanup.
            print(f"Skipping snapshot with unparseable timestamp: {name}")
            continue
        ranked.append((path, taken_at, name))

    ranked.sort(key=lambda item: item[1], reverse=True)
    return ranked


print("Starting cleanup of old snapshots...")

try:
    # list all files in the snapshots directory
    snapshot_files = dbutils.fs.ls(f"s3://{s3_bucket}/{s3_prefix}/")

    snapshots_with_timestamps = rank_snapshots(
        (file_info.path, file_info.name) for file_info in snapshot_files
    )
    print(f"Found {len(snapshots_with_timestamps)} snapshot files")

    # keep only the 5 most recent, delete the rest
    snapshots_to_keep = 5
    snapshots_to_delete = snapshots_with_timestamps[snapshots_to_keep:]

    if snapshots_to_delete:
        print(f"Deleting {len(snapshots_to_delete)} old snapshot(s):")
        # The loop variable is deliberately NOT called `filename`, so the
        # name of the snapshot written earlier in this notebook is preserved.
        for file_path, timestamp, snapshot_name in snapshots_to_delete:
            print(f"  - Deleting: {snapshot_name} (created: {timestamp})")
            dbutils.fs.rm(file_path)

        print(f"Cleanup completed. Kept {snapshots_to_keep} most recent snapshots.")
    else:
        print(f"Only {len(snapshots_with_timestamps)} snapshots found. No cleanup needed.")

except Exception as e:
    # Retention is best-effort: the snapshot itself has already been
    # published, so a cleanup problem is reported rather than failing the job.
    print(f"Error during cleanup: {str(e)}")
    print("Snapshot creation was successful, but cleanup failed.")

print("Script completed successfully.")