# Deterministic, Scalable Weighted Sampling with PySpark

In [1]:
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.testing import assertDataFrameEqual
from qikits.sampling.weighted_sampling import weighted_sample

In [2]:
# Start a local SparkSession on all available cores ("local[*]").
# getOrCreate() hands back the already-active session if this kernel has one,
# so re-running this cell does not spin up a second session.
spark = SparkSession.builder.master("local[*]").appName(
    "DeterministicWeightedSampling"
).getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/10 18:19:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Build a synthetic input DataFrame for the demo. The sampling approach is meant
# to scale to much larger tables; 1 million rows keeps this run quick.
# A real input DataFrame should have:
# - a column that uniquely identifies each row (here, 'id');
# - a column holding the sampling weight (here, 'weight').
NUM_ROWS = 1_000_000
# F.rand(seed) is only reproducible for a fixed partitioning: Spark seeds each
# partition separately. Without an explicit partition count, spark.range inherits
# the default parallelism of "local[*]", which varies with the machine's core
# count. Pinning it keeps the synthetic weights (and hence the sample below)
# identical across machines.
NUM_PARTITIONS = 8
WEIGHT_SEED = 42

print(f"Creating a sample DataFrame with {NUM_ROWS:,} rows...")
df = (
    # spark.range already yields a non-null bigint 'id' column, unique per row.
    spark.range(0, NUM_ROWS, 1, numPartitions=NUM_PARTITIONS)
    # F.rand is uniform in [0, 1), so weights fall in [1, 101) and are strictly
    # positive (presumably what weighted_sample expects -- confirm).
    .withColumn("weight", F.rand(seed=WEIGHT_SEED) * 100 + 1)
    .withColumn("data", F.concat(F.lit("record_"), F.col("id")))
)

print(f"Original DataFrame has {df.count()} rows.")
df.printSchema()

Creating a sample DataFrame with 1,000,000 rows...
Original DataFrame has 1000000 rows.
root
 |-- id: long (nullable = false)
 |-- weight: double (nullable = false)
 |-- data: string (nullable = false)



                                                                                

In [4]:
# --- Parameters ---
# The number of elements to sample.
k = 10
# A string seed to ensure the sampling is deterministic.
# Change this seed to get a different sample.
sampling_seed = "my-secret-seed-42"

# Deliberately not cached: an identical second plan (the rerun below) could then
# be served from Spark's cache, which would weaken the determinism check.
sampled_df = weighted_sample(
    df, k=k, weight_col="weight", id_columns="id", seed=sampling_seed
)

# Show the Results.
# The input has far more rows than k and only positive weights, so we expect
# exactly k rows back (presumably weighted_sample's contract -- confirm).
# Assert it instead of only printing the count.
sampled_count = sampled_df.count()
assert sampled_count == k, f"expected {k} sampled rows, got {sampled_count}"
print(f"\nSuccessfully sampled {sampled_count} elements.")
print("Sampled DataFrame:")
sampled_df.show(truncate=False)

# To prove it's deterministic, let's run it again with the same seed.
# The result will be identical. Note this shows reproducibility within this
# session; across machines it also requires identical input data.
print("\nRunning again with the same seed to prove determinism...")
rerun_sampled_df = weighted_sample(
    df, k=k, weight_col="weight", id_columns="id", seed=sampling_seed
)

# Verify that both sampled DataFrames are identical.
# assertDataFrameEqual ignores row order by default (checkRowOrder=False), so
# this compares the sampled rows as a set, not their display order.
print("Verifying that both sampled DataFrames are identical...")
assertDataFrameEqual(sampled_df, rerun_sampled_df)
print("✅ Verification successful: The samples are identical.")


Successfully sampled 10 elements.
Sampled DataFrame:


                                                                                

+------+------------------+-------------+
|id    |weight            |data         |
+------+------------------+-------------+
|861053|45.50996449144522 |record_861053|
|377342|75.03059429328809 |record_377342|
|664857|50.54429652966017 |record_664857|
|337391|97.76573696149322 |record_337391|
|507042|27.128717353831842|record_507042|
|536173|78.52170144645547 |record_536173|
|426275|93.9210825314597  |record_426275|
|504778|49.232353180455526|record_504778|
|687258|43.459986093404076|record_687258|
|725749|95.8164756277254  |record_725749|
+------+------------------+-------------+


Running again with the same seed to prove determinism...
Verifying that both sampled DataFrames are identical...
✅ Verification successful: The samples are identical.


In [None]:
# Stop the SparkSession created at the top of the notebook, releasing its local resources.
spark.stop()