In [0]:
import re
import boto3
import math
from datetime import datetime, timezone
from pyspark.sql.functions import col
from concurrent.futures import ThreadPoolExecutor, as_completed

# Export settings: destination bucket/prefix and the table that is snapshotted.
s3_bucket = "unpaywall-data-feed-walden"
source_table = "openalex.unpaywall.unpaywall"
s3_prefix = "full_snapshots"

# Credentials for the boto3 client are read from the "webscraper" secret scope.
AWS_ACCESS_KEY_ID = dbutils.secrets.get("webscraper", "aws_access_key_id")
AWS_SECRET_ACCESS_KEY = dbutils.secrets.get("webscraper", "aws_secret_access_key")

df = spark.table(source_table)

# A single UTC timestamp names both the final snapshot object and the
# temporary folder that holds the part files for this run.
current_date = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H%M%S')
final_filename = f"unpaywall_snapshot_{current_date}.jsonl.gz"
final_key = f"{s3_prefix}/{final_filename}"
temp_prefix = f"temp_parts_{current_date}"
temp_s3_path = f"s3://{s3_bucket}/{temp_prefix}"

record_count = df.count()
print(f"Found {record_count} records to export")

if record_count == 0:
    dbutils.notebook.exit("No records found to export")

# Target about one million records per part file. Every part file later becomes
# one part of an S3 multipart upload, and S3 accepts at most 10,000 parts per
# upload, so the partition count is capped at that limit.
records_per_partition = 1000000
max_multipart_parts = 10000
num_partitions = min(
    max_multipart_parts,
    max(1, record_count // records_per_partition),
)

print(f"Writing data in parallel to {temp_s3_path} using {num_partitions} partitions...")

# Write only the json_response column as gzip-compressed text part files.
(df.select("json_response")
    .repartition(num_partitions)
    .write
    .mode("overwrite")
    .option("compression", "gzip")
    .text(temp_s3_path)
)

# combine the parts without downloading them to the driver

print("Parallel write complete. Starting S3 Multipart Merge...")

# Module-level client shared by copy_part, s3_multipart_merge and the
# snapshot cleanup cell.
s3_client = boto3.client(
    's3', 
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY
)

def copy_part(bucket, source_key, dest_key, part_num, upload_id):
    """Copy one existing object into a numbered slot of a multipart upload.

    Uses the module-level ``s3_client``. Source and destination live in the
    same bucket.

    Args:
        bucket: Bucket containing both the source object and the destination.
        source_key: Key of the object to copy from.
        dest_key: Key of the multipart upload's target object.
        part_num: Part number to assign to the copied data.
        upload_id: Identifier of the multipart upload being assembled.

    Returns:
        A dict with ``'ETag'`` and ``'PartNumber'``, the shape expected in the
        parts list of ``complete_multipart_upload``.

    Raises:
        Exception: Any error from the copy is logged and re-raised unchanged.
    """
    try:
        result = s3_client.upload_part_copy(
            Bucket=bucket,
            Key=dest_key,
            UploadId=upload_id,
            PartNumber=part_num,
            CopySource={'Bucket': bucket, 'Key': source_key},
        )
        etag = result['CopyPartResult']['ETag']
        return {'ETag': etag, 'PartNumber': part_num}
    except Exception as e:
        # Log which part failed, then let the caller decide how to recover.
        print(f"Error copying part {part_num}: {e}")
        raise

def s3_multipart_merge(bucket, source_prefix, dest_key):
    """Merge all gzip part files under a prefix into a single S3 object.

    Each part file becomes one part of a multipart upload and is copied with
    ``copy_part`` from a pool of 20 threads, so the merged object is the
    byte-wise concatenation of the part files in sorted key order.

    Args:
        bucket: Bucket holding both the part files and the merged object.
        source_prefix: Key prefix to list. Keys that end in ``.gz`` and
            contain ``part-`` are treated as parts.
        dest_key: Key of the merged object to create.

    Raises:
        Exception: If no part file is found.
        ValueError: If more part files are found than one multipart upload
            can hold (10,000).
        Exception: Any error raised while copying or completing the upload is
            re-raised after a best-effort abort of the multipart upload.
    """
    max_parts = 10000  # S3 limit on the number of parts in one multipart upload

    # List all part files, sorted to ensure a deterministic merge order.
    paginator = s3_client.get_paginator('list_objects_v2')
    parts = sorted(
        obj['Key']
        for page in paginator.paginate(Bucket=bucket, Prefix=source_prefix)
        for obj in page.get('Contents', [])
        if obj['Key'].endswith('.gz') and 'part-' in obj['Key']
    )

    print(f"Found {len(parts)} parts to merge.")
    if not parts:
        raise Exception("No parts found to merge.")
    if len(parts) > max_parts:
        # Fail before any data is copied instead of after every copy finished.
        raise ValueError(
            f"Found {len(parts)} parts, but a multipart upload accepts at most {max_parts}."
        )

    mp_upload = s3_client.create_multipart_upload(Bucket=bucket, Key=dest_key)
    upload_id = mp_upload['UploadId']

    completed_parts = []

    try:
        print("Starting parallel merge...")
        with ThreadPoolExecutor(max_workers=20) as executor:
            futures = [
                executor.submit(copy_part, bucket, part_key, dest_key, part_num, upload_id)
                for part_num, part_key in enumerate(parts, start=1)
            ]
            try:
                for future in as_completed(futures):
                    completed_parts.append(future.result())
            except BaseException:
                # The upload is about to be aborted, so drop the copies that
                # have not started; otherwise the executor would run all of
                # them before the error could propagate.
                for pending in futures:
                    pending.cancel()
                raise

        # S3 requires the part list in ascending part-number order.
        completed_parts.sort(key=lambda x: x['PartNumber'])

        print("Finalizing multipart upload...")
        s3_client.complete_multipart_upload(
            Bucket=bucket,
            Key=dest_key,
            UploadId=upload_id,
            MultipartUpload={'Parts': completed_parts}
        )
        print("Merge completed successfully.")

    except Exception as e:
        print(f"Merge failed: {e}")
        try:
            s3_client.abort_multipart_upload(Bucket=bucket, Key=dest_key, UploadId=upload_id)
        except Exception as abort_error:
            # Keep the original failure visible even when the abort itself fails.
            print(f"Warning: could not abort multipart upload {upload_id}: {abort_error}")
        raise

# Merge the temporary part files into the final snapshot object. The part
# files are removed whether or not the merge succeeds, because their prefix is
# timestamped per run and nothing else in this notebook deletes them.
try:
    s3_multipart_merge(s3_bucket, temp_prefix, final_key)
    print(f"Successfully created single file: s3://{s3_bucket}/{final_key}")

except Exception as e:
    print(f"Error during S3 merge: {str(e)}")
    raise

finally:
    try:
        dbutils.fs.rm(temp_s3_path, recurse=True)
    except Exception as cleanup_error:
        # Best effort: a failed temp cleanup must neither mask a merge error
        # nor fail a run whose snapshot was created successfully.
        print(f"Warning: could not remove temporary parts at {temp_s3_path}: {cleanup_error}")

In [0]:
# Retention: keep the newest snapshots (by the timestamp embedded in the file
# name) and delete the rest. Failures here are reported as warnings only.
print("Starting cleanup of old snapshots...")

try:
    snapshot_pattern = re.compile(r'unpaywall_snapshot_(\d{4}-\d{2}-\d{2}T\d{6})\.jsonl\.gz$')
    snapshots_to_keep = 5

    # The trailing slash keeps sibling prefixes that merely start with the
    # same text (e.g. "<s3_prefix>_old/") out of the listing.
    listing_prefix = s3_prefix.rstrip('/') + '/'

    # Paginate: a single list_objects_v2 call returns at most 1,000 keys.
    snapshot_files = []
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=s3_bucket, Prefix=listing_prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            filename = key.split('/')[-1]
            match = snapshot_pattern.search(filename)
            if not match:
                continue
            try:
                timestamp = datetime.strptime(match.group(1), '%Y-%m-%dT%H%M%S')
            except ValueError:
                # Digits matched the pattern but are not a real date/time;
                # leave the object alone rather than abort the whole cleanup.
                print(f"Warning: skipping snapshot with unparseable timestamp: {key}")
                continue
            snapshot_files.append({'Key': key, 'LastModified': timestamp})

    # Newest first, so everything past the first `snapshots_to_keep` is stale.
    snapshot_files.sort(key=lambda x: x['LastModified'], reverse=True)
    to_delete = snapshot_files[snapshots_to_keep:]

    if to_delete:
        print(f"Deleting {len(to_delete)} old snapshot(s)...")

        objects_to_delete = [{'Key': item['Key']} for item in to_delete]

        # delete_objects accepts at most 1,000 keys per request.
        for i in range(0, len(objects_to_delete), 1000):
            batch = objects_to_delete[i:i+1000]
            result = s3_client.delete_objects(Bucket=s3_bucket, Delete={'Objects': batch})

            # Per-key failures come back in the response instead of raising.
            failed = result.get('Errors', [])
            for failure in failed:
                print(
                    f"Warning: could not delete {failure.get('Key')}: "
                    f"{failure.get('Code')} - {failure.get('Message')}"
                )
            print(f"Deleted batch of {len(batch) - len(failed)} files.")
    else:
        print("No cleanup needed.")

except Exception as e:
    print(f"Warning: Cleanup failed, but export was successful. Error: {e}")

print("Job completed.")