In [0]:
# No external packages needed - using synthetic data generation

# NYC Green Taxi Data Ingestion - Transactional Batch Mode

This notebook demonstrates ingesting synthetic trip records, modeled on the NYC Green Taxi dataset, into Azure Cosmos DB using **transactional batch mode**.

**Configuration:**
- Authentication: Managed Identity (Azure Databricks)
- Same data preparation as normal bulk mode for fair comparison
- Only difference: `spark.cosmos.write.bulk.transactional = true`

**Key Requirements for Transactional Batches:**
- All operations within a batch must be for the same partition key
- Maximum 100 operations per batch
- Maximum 2MB per batch
- All operations succeed or all fail atomically

Dataset (used as a schema reference only; the records in this notebook are generated synthetically): [Azure Open Datasets - NYC Green Taxi](https://learn.microsoft.com/en-gb/azure/open-datasets/dataset-taxi-green)

## Configuration
Update these values with your Cosmos DB account details (endpoint, database, container, subscription, resource group, and tenant).

In [0]:
# Cosmos DB Configuration - Managed Identity
# Account-specific values are named up front so they are easy to find and replace.
COSMOS_ACCOUNT_ENDPOINT = "https://tvk-my-cosmos-account.documents.azure.com:443/"
COSMOS_DATABASE = "spark-load-tests"
COSMOS_CONTAINER = "transactional-bulk"

# Azure resource coordinates - required for ARM metadata lookup
AZURE_SUBSCRIPTION_ID = "220fc532-6091-423c-8ba0-66c2397d591b"
AZURE_RESOURCE_GROUP = "tvk-mgt-tests"
AZURE_TENANT_ID = "72f988bf-86f1-41af-91ab-2d7cd011db47"

# Connector options shared by the write and read-back cells
config = {
    "spark.cosmos.accountEndpoint": COSMOS_ACCOUNT_ENDPOINT,
    "spark.cosmos.database": COSMOS_DATABASE,
    "spark.cosmos.container": COSMOS_CONTAINER,
    # Use MI (no client secret)
    "spark.cosmos.auth.type": "ManagedIdentity",
    "spark.cosmos.account.subscriptionId": AZURE_SUBSCRIPTION_ID,
    "spark.cosmos.account.resourceGroupName": AZURE_RESOURCE_GROUP,
    "spark.cosmos.account.tenantId": AZURE_TENANT_ID,
}

In [0]:
# ========================================
# PERFORMANCE TEST CONFIGURATION
# ========================================
import random

# Total records to generate
RECORD_LIMIT = 1000000  # Change this value to control data volume

# Minimum and maximum records per partition (logical partition key)
# IMPORTANT: For transactional batches, must be <= 100
# Provide a range to vary partition sizes. Set MIN_RECORDS_PER_PARTITION and MAX_RECORDS_PER_PARTITION.
# Examples:
#   MIN=5, MAX=10  - Variable small batches
#   MIN=50, MAX=90 - Variable medium batches
MIN_RECORDS_PER_PARTITION = 10
MAX_RECORDS_PER_PARTITION = 10  # Ensures no partition exceeds this count

# Upper bound on operations in one transactional batch (see the notebook introduction)
MAX_OPERATIONS_PER_BATCH = 100

# Seed for the partition-size draw, so a rerun with MIN != MAX yields the same layout
PARTITION_SIZE_SEED = 42


def build_partition_sizes(total_records, min_size, max_size, seed=None):
    """Split ``total_records`` into consecutive partition sizes.

    Each size is drawn uniformly from ``[min_size, max_size]``; the final
    partition is truncated so the sizes sum to exactly ``total_records``
    (it may therefore be smaller than ``min_size``).

    Args:
        total_records: Number of records to distribute; 0 yields an empty list.
        min_size: Smallest size to draw. Must be at least 1, otherwise a draw
            of 0 would never reduce the remaining count.
        max_size: Largest size to draw. Must be >= ``min_size``.
        seed: Seed for a private ``random.Random`` instance. Passing the same
            seed reproduces the same list; ``None`` uses an unpredictable seed.

    Returns:
        A list of positive ints whose sum is ``total_records``.

    Raises:
        ValueError: If ``total_records`` is negative, ``min_size`` is below 1,
            or ``max_size`` is below ``min_size``.
    """
    if total_records < 0:
        raise ValueError(f"total_records must be >= 0, got {total_records}")
    if min_size < 1:
        raise ValueError(f"min_size must be >= 1, got {min_size}")
    if max_size < min_size:
        raise ValueError(f"max_size ({max_size}) must be >= min_size ({min_size})")

    # A private generator keeps the draw reproducible without disturbing the
    # global `random` state that other cells seed for themselves.
    rng = random.Random(seed)
    sizes = []
    remaining = total_records
    while remaining > 0:
        size = min(rng.randint(min_size, max_size), remaining)
        sizes.append(size)
        remaining -= size
    return sizes


if MAX_RECORDS_PER_PARTITION > MAX_OPERATIONS_PER_BATCH:
    # Warn rather than fail: the generation cell reports oversized partitions as well.
    print(f"⚠️ WARNING: MAX_RECORDS_PER_PARTITION={MAX_RECORDS_PER_PARTITION} exceeds "
          f"the transactional batch limit of {MAX_OPERATIONS_PER_BATCH} operations")

# Build partition size list using random sizes within the range while summing to RECORD_LIMIT
partition_sizes = build_partition_sizes(
    RECORD_LIMIT,
    MIN_RECORDS_PER_PARTITION,
    MAX_RECORDS_PER_PARTITION,
    seed=PARTITION_SIZE_SEED,
)

num_partitions = len(partition_sizes)

print(f"Record limit: {RECORD_LIMIT}")
print(f"Partition size range: {MIN_RECORDS_PER_PARTITION}-{MAX_RECORDS_PER_PARTITION}")
print(f"Will generate {num_partitions} partitions (hours) with variable records per partition")

Record limit: 1000000
Max records per partition: 10
Will generate 100000 partitions (hours) with ~10 records each


## Generate Synthetic Taxi Trip Data

Creating realistic synthetic data - SAME AS NORMAL MODE for fair comparison

In [0]:
from pyspark.sql import Row
from pyspark.sql.functions import col, date_format, count as sql_count
from datetime import datetime, timedelta
import random

# Set random seed for reproducibility - SAME AS NORMAL MODE
random.seed(42)

# Generate synthetic taxi trip data with controlled partition distribution
print(f"Generating {RECORD_LIMIT} synthetic taxi trip records...")
print(f"Distributing across {num_partitions} partitions with max {MAX_RECORDS_PER_PARTITION} records each")

# All rows are first collected in this driver-side list, then handed to Spark in one call
data = []
base_time = datetime(2022, 1, 1, 0, 0, 0)

# Distribute records according to partition_sizes: partition N holds pickups that
# fall inside hour N after base_time, so each hour becomes one logical partition.
# NOTE: the order of the random.* calls below determines the generated values;
# keep it unchanged so a given seed keeps producing the same data.
partition_index = 0
for sz in partition_sizes:
    for _ in range(sz):
        # Generate pickup time for this partition (random minute within the hour)
        pickup_time = base_time + timedelta(hours=partition_index, minutes=random.randint(0, 59))
        trip_duration = random.randint(5, 60)  # 5-60 minutes
        dropoff_time = pickup_time + timedelta(minutes=trip_duration)
        data.append(Row(
            vendorID=random.choice([1, 2]),
            lpepPickupDatetime=pickup_time,
            lpepDropoffDatetime=dropoff_time,
            passengerCount=random.randint(1, 6),
            tripDistance=round(random.uniform(0.5, 20.0), 2),
            fareAmount=round(random.uniform(5.0, 75.0), 2),
            extra=round(random.uniform(0, 2.0), 2),
            tipAmount=round(random.uniform(0, 15.0), 2),
            tollsAmount=round(random.uniform(0, 5.0), 2),
            totalAmount=0.0  # Placeholder, replaced below by the sum of the components
        ))
    partition_index += 1

# Create DataFrame
df = spark.createDataFrame(data)

# Calculate total amount from the fare components
df = df.withColumn("totalAmount",
    col("fareAmount") + col("extra") + col("tipAmount") + col("tollsAmount"))

print(f"Generated {df.count()} records")

# Verify partition distribution using the same hour format as the Cosmos DB partition key
partition_counts = df.groupBy(date_format(col("lpepPickupDatetime"), "yyyy-MM-dd-HH").alias("hour")) \
    .agg(sql_count("*").alias("count")) \
    .orderBy("hour")

print("\n✓ Partition distribution (records per hour):")
partition_counts.show(100, truncate=False)

max_partition = partition_counts.agg({"count": "max"}).collect()[0][0]
print(f"\n✓ Maximum records in any partition: {max_partition} (limit: {MAX_RECORDS_PER_PARTITION})")

if max_partition > 100:
    print("⚠️ WARNING: Some partitions exceed 100 records - not suitable for transactional batches!")
elif max_partition <= MAX_RECORDS_PER_PARTITION:
    print("✓ All partitions are within limits!")
else:
    # Previously this case printed nothing: the largest partition is under the
    # 100-record batch limit but over the configured maximum.
    # NOTE(review): presumably only possible if two generated hours format to the
    # same hour string (e.g. a session-timezone shift) - confirm.
    print(f"⚠️ WARNING: Largest partition has {max_partition} records, "
          f"above the configured maximum of {MAX_RECORDS_PER_PARTITION}")

df.show(5)

Generating 1000000 synthetic taxi trip records...
Distributing across 100000 partitions with max 10 records each
Generated 1000000 records

✓ Partition distribution (records per hour):
+-------------+-----+
|hour         |count|
+-------------+-----+
|2022-01-01-00|10   |
|2022-01-01-01|10   |
|2022-01-01-02|10   |
|2022-01-01-03|10   |
|2022-01-01-04|10   |
|2022-01-01-05|10   |
|2022-01-01-06|10   |
|2022-01-01-07|10   |
|2022-01-01-08|10   |
|2022-01-01-09|10   |
|2022-01-01-10|10   |
|2022-01-01-11|10   |
|2022-01-01-12|10   |
|2022-01-01-13|10   |
|2022-01-01-14|10   |
|2022-01-01-15|10   |
|2022-01-01-16|10   |
|2022-01-01-17|10   |
|2022-01-01-18|10   |
|2022-01-01-19|10   |
|2022-01-01-20|10   |
|2022-01-01-21|10   |
|2022-01-01-22|10   |
|2022-01-01-23|10   |
|2022-01-02-00|10   |
|2022-01-02-01|10   |
|2022-01-02-02|10   |
|2022-01-02-03|10   |
|2022-01-02-04|10   |
|2022-01-02-05|10   |
|2022-01-02-06|10   |
|2022-01-02-07|10   |
|2022-01-02-08|10   |
|2022-01-02-09|10   |
|

## Prepare Data for Cosmos DB

Add required fields - SAME AS NORMAL MODE for fair comparison:
- `id`: Unique identifier
- `partitionKey`: Date-hour format (`yyyy-MM-dd-HH`)
- Order by partition key
- Repartition by partition key (applied at write time, in the write section below)

In [0]:
from pyspark.sql.functions import col, date_format, concat_ws, monotonically_increasing_id

# Column expressions for the two added fields - SAME AS NORMAL MODE
# partitionKey: the pickup hour formatted as yyyy-MM-dd-HH
partition_key_expr = date_format(col("lpepPickupDatetime"), "yyyy-MM-dd-HH")
# id: "<partitionKey>-<generated row id>"
document_id_expr = concat_ws(
    "-",
    col("partitionKey"),
    monotonically_increasing_id().cast("string"),
)

# Attach both fields, then order by partitionKey - same as normal mode
df_prepared = (
    df.withColumn("partitionKey", partition_key_expr)
    .withColumn("id", document_id_expr)
    .orderBy("partitionKey")
)

prepared_count = df_prepared.count()
print(f"Prepared {prepared_count} records for transactional batch ingestion")

print("\nSample data (showing partition key grouping):")
preview_columns = ["id", "partitionKey", "lpepPickupDatetime", "fareAmount", "tripDistance"]
df_prepared.select(*preview_columns).show(15)

Prepared 1000000 records for transactional batch ingestion

Sample data (showing partition key grouping):
+----------------+-------------+-------------------+----------+------------+
|              id| partitionKey| lpepPickupDatetime|fareAmount|tripDistance|
+----------------+-------------+-------------------+----------+------------+
| 2022-01-01-00-0|2022-01-01-00|2022-01-01 00:40:00|     20.62|        5.86|
| 2022-01-01-00-1|2022-01-01-00|2022-01-01 00:05:00|      20.3|        1.08|
| 2022-01-01-00-2|2022-01-01-00|2022-01-01 00:41:00|     24.47|        9.26|
| 2022-01-01-00-3|2022-01-01-00|2022-01-01 00:27:00|     58.44|         4.7|
| 2022-01-01-00-4|2022-01-01-00|2022-01-01 00:22:00|     42.54|       14.73|
| 2022-01-01-00-5|2022-01-01-00|2022-01-01 00:53:00|      9.87|        4.25|
| 2022-01-01-00-6|2022-01-01-00|2022-01-01 00:54:00|      49.5|        5.92|
| 2022-01-01-00-7|2022-01-01-00|2022-01-01 00:17:00|     42.39|       12.88|
| 2022-01-01-00-8|2022-01-01-00|2022-01-01 00:5

## Write to Cosmos DB - Transactional Batch Mode

**Only difference from normal mode:** `spark.cosmos.write.bulk.transactional = true`

All other configurations and optimizations are identical for fair performance comparison.

In [0]:
# Cosmos DB write configuration - TRANSACTIONAL BATCH MODE
# Start from the shared Managed Identity settings, then layer the write options on top.
write_config = dict(config)
write_config.update({
    "spark.cosmos.write.strategy": "ItemOverwrite",
    "spark.cosmos.write.bulk.enabled": "true",
    # The one setting that differs from the normal-mode run
    "spark.cosmos.write.bulk.transactional": "true",
})

print("Starting write to Cosmos DB (transactional batch mode)...")
print("⚡ Each batch will be atomic - all operations succeed or all fail")
start_time = datetime.now()

# Repartition on the partition key, exactly as the normal-mode run does, then write
repartitioned = df_prepared.repartition("partitionKey")
cosmos_writer = (
    repartitioned.write
    .format("cosmos.oltp")
    .options(**write_config)
    .mode("append")
)
cosmos_writer.save()

end_time = datetime.now()
duration = (end_time - start_time).total_seconds()

print(f"\n✓ Transactional batch write completed in {duration:.2f} seconds")
# This is the row count of the source DataFrame, recomputed after the write returns
records_in_source = df_prepared.count()
print(f"Records written: {records_in_source}")
print("All batches committed atomically!")

Starting write to Cosmos DB (transactional batch mode)...
⚡ Each batch will be atomic - all operations succeed or all fail

✓ Transactional batch write completed in 98.41 seconds
Records written: 1000000
All batches committed atomically!


## Verify Data in Cosmos DB

In [0]:
from pyspark.sql.functions import count

# Read back from Cosmos DB to verify
# Reading needs nothing beyond the shared Managed Identity settings
read_config = dict(config)

df_verify = (
    spark.read
    .format("cosmos.oltp")
    .options(**read_config)
    .load()
)
total_in_container = df_verify.count()
print(f"Total records in Cosmos DB: {total_in_container}")

# Show distribution by partition
print("\nRecords by partition key:")
records_per_key = df_verify.groupBy("partitionKey").agg(count("*").alias("count"))
records_per_key.orderBy("partitionKey").show(20)

print("\nSample records:")
sample_columns = ["id", "partitionKey", "lpepPickupDatetime", "fareAmount"]
df_verify.select(*sample_columns).show(10)

Total records in Cosmos DB: 1000000

Records by partition key:
+-------------+-----+
| partitionKey|count|
+-------------+-----+
|2022-01-01-00|   10|
|2022-01-01-01|   10|
|2022-01-01-02|   10|
|2022-01-01-03|   10|
|2022-01-01-04|   10|
|2022-01-01-05|   10|
|2022-01-01-06|   10|
|2022-01-01-07|   10|
|2022-01-01-08|   10|
|2022-01-01-09|   10|
|2022-01-01-10|   10|
|2022-01-01-11|   10|
|2022-01-01-12|   10|
|2022-01-01-13|   10|
|2022-01-01-14|   10|
|2022-01-01-15|   10|
|2022-01-01-16|   10|
|2022-01-01-17|   10|
|2022-01-01-18|   10|
|2022-01-01-19|   10|
+-------------+-----+
only showing top 20 rows


Sample records:
+----------------+-------------+------------------+----------+
|              id| partitionKey|lpepPickupDatetime|fareAmount|
+----------------+-------------+------------------+----------+
|2022-01-01-02-20|2022-01-01-02|  1641002820000000|     46.18|
|2022-01-01-02-21|2022-01-01-02|  1641004800000000|     65.17|
|2022-01-01-02-22|2022-01-01-02|  1641003180000000|