# Distributed Sampling & Fit Quality Demo

This notebook demonstrates:
1. **Distributed Sampling** - Generate millions of samples using Spark parallelism
2. **Fit Quality Warnings** - Automatic detection of poor fits
3. **Performance Comparison** - Local vs distributed sampling

In [1]:
import time
import warnings

import numpy as np
from pyspark.sql import SparkSession

from spark_bestfit import DistributionFitter

In [2]:
# Create Spark session
spark = (
    SparkSession.builder
    .master("local[4]")
    .appName("SamplingDemo")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("ERROR")





## 1. Fit a Distribution

First, let's generate some sample data and fit distributions to it.

In [3]:
# Generate gamma-distributed data
np.random.seed(42)
data = np.random.gamma(shape=2.0, scale=5.0, size=10_000)

# Create DataFrame
df = spark.createDataFrame([(float(x),) for x in data], ["value"])

# Fit distributions
fitter = DistributionFitter(spark, random_seed=42)
results = fitter.fit(df, column="value", max_distributions=20)

# Get best fit
best = results.best(n=1)[0]
print(f"Best fit: {best.distribution}")
print(f"Parameters: {best.parameters}")
print(f"K-S statistic: {best.ks_statistic:.4f}")
print(f"P-value: {best.pvalue:.4f}")


Best fit: chi2
Parameters: [4.071206092834473, 0.009233290329575539, 2.472316026687622]
K-S statistic: 0.0051
P-value: 0.9572
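
The best fit is `chi2` even though the data was drawn from a gamma distribution. This is not a contradiction: a chi-squared distribution with `df` degrees of freedom and scale `s` is the same distribution as a gamma with shape `df / 2` and scale `2 * s`. The sketch below checks this with SciPy, assuming the printed parameters follow SciPy's `(df, loc, scale)` order. The fitted values map to a gamma shape of about 2.04 and a scale of about 4.94, close to the `shape=2.0, scale=5.0` used to generate the data.

```python
import numpy as np
from scipy import stats

# Parameters printed above, assumed to be in SciPy's (df, loc, scale) order.
df_fit, loc, scale = 4.071206092834473, 0.009233290329575539, 2.472316026687622

chi2_fit = stats.chi2(df_fit, loc=loc, scale=scale)

# chi2(df, loc, scale) is the same distribution as gamma(df / 2, loc, 2 * scale).
gamma_shape = df_fit / 2
gamma_scale = 2 * scale
gamma_equiv = stats.gamma(gamma_shape, loc=loc, scale=gamma_scale)

# Compare the two densities on a grid covering most of the data range.
x = np.linspace(0.1, 60.0, 200)
max_pdf_diff = float(np.max(np.abs(chi2_fit.pdf(x) - gamma_equiv.pdf(x))))

print(f"Equivalent gamma shape: {gamma_shape:.3f}")
print(f"Equivalent gamma scale: {gamma_scale:.3f}")
print(f"Largest PDF difference: {max_pdf_diff:.2e}")
```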


## 2. Distributed Sampling with `sample_spark()`

Generate samples using Spark's distributed computing. This is most useful when you need tens of millions of samples, or when the samples should stay in a Spark DataFrame for downstream processing.

In [4]:
# Generate 100,000 samples using Spark
samples_df = best.sample_spark(
    n=100_000,
    spark=spark,
    random_seed=42,
)

print(f"Schema: {samples_df.schema}")
print(f"Sample count: {samples_df.count():,}")
samples_df.show(5)

Schema: StructType([StructField('sample', DoubleType(), False)])
Sample count: 100,000
+------------------+
|            sample|
+------------------+
| 21.62513549334104|
|10.715187326200923|
|14.942145266139688|
|6.2094943252486114|
| 11.09381252827893|
+------------------+
only showing top 5 rows


In [5]:
# Custom column name and explicit partitions
samples_df = best.sample_spark(
    n=50_000,
    spark=spark,
    num_partitions=8,
    column_name="generated_values",
    random_seed=123,
)

samples_df.show(5)

+------------------+
|  generated_values|
+------------------+
| 10.83504621526457|
| 4.173895196887788|
|22.627553500211572|
|14.512516565715536|
| 7.652403897328518|
+------------------+
only showing top 5 rows
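
This notebook does not show how `sample_spark()` combines `num_partitions` and `random_seed` internally. One common pattern for reproducible parallel sampling is to split `n` across partitions and give each partition its own independent seed derived from the base seed. The sketch below illustrates that pattern with NumPy and SciPy only; `partition_sizes` and `sample_in_partitions` are hypothetical helpers, not part of `spark_bestfit`.

```python
import numpy as np
from scipy import stats


def partition_sizes(n, num_partitions):
    """Split n samples as evenly as possible across partitions."""
    base, extra = divmod(n, num_partitions)
    return [base + (1 if i < extra else 0) for i in range(num_partitions)]


def sample_in_partitions(dist, n, num_partitions, seed):
    """Draw n samples in chunks, one independent random stream per chunk."""
    # spawn() derives statistically independent child seeds from one base seed.
    child_seeds = np.random.SeedSequence(seed).spawn(num_partitions)
    chunks = [
        dist.rvs(size=size, random_state=np.random.default_rng(child))
        for size, child in zip(partition_sizes(n, num_partitions), child_seeds)
    ]
    return np.concatenate(chunks)


dist = stats.gamma(2.0, scale=5.0)
first = sample_in_partitions(dist, 50_000, num_partitions=8, seed=123)
second = sample_in_partitions(dist, 50_000, num_partitions=8, seed=123)

print(f"Chunk sizes: {partition_sizes(50_000, 8)}")
print(f"Same seed reproduces the samples: {np.array_equal(first, second)}")
```

With this scheme the output depends on both the seed and the partition count, so changing `num_partitions` changes the individual samples even when the seed is fixed.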


## 3. Performance Comparison: Local vs Distributed

Compare the performance of local sampling (`sample()`) vs distributed sampling (`sample_spark()`).

In [6]:
def benchmark_sampling(n_samples: int, n_runs: int = 3):
    """Benchmark local vs distributed sampling; returns average times in ms."""

    # Local sampling
    local_times = []
    for _ in range(n_runs):
        start = time.perf_counter()  # monotonic, high-resolution timer
        best.sample(size=n_samples, random_state=42)
        local_times.append(time.perf_counter() - start)

    # Distributed sampling
    spark_times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        samples_df = best.sample_spark(n=n_samples, spark=spark, random_seed=42)
        samples_df.count()  # Force execution; Spark evaluates lazily
        spark_times.append(time.perf_counter() - start)
    
    return {
        "n_samples": n_samples,
        "local_avg_ms": np.mean(local_times) * 1000,
        "spark_avg_ms": np.mean(spark_times) * 1000,
    }

# Run benchmarks for different sizes
print(f"{'N Samples':>12} | {'Local (ms)':>12} | {'Spark (ms)':>12} | {'Winner':>10}")
print("-" * 55)

for n in [1_000, 1_000_000, 10_000_000, 50_000_000]:
    result = benchmark_sampling(n)
    winner = "Local" if result["local_avg_ms"] < result["spark_avg_ms"] else "Spark"
    print(f"{result['n_samples']:>12,} | {result['local_avg_ms']:>12.1f} | {result['spark_avg_ms']:>12.1f} | {winner:>10}")

   N Samples |   Local (ms) |   Spark (ms) |     Winner
-------------------------------------------------------
       1,000 |          0.4 |        369.9 |      Local
   1,000,000 |         16.5 |         64.9 |      Local
  10,000,000 |        154.3 |        152.3 |      Spark
  50,000,000 |        774.2 |        538.9 |      Spark

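**Key Takeaway:** Local sampling is faster at small and medium sizes because every Spark job carries fixed overhead: at 1,000,000 samples, local sampling is still about 4x faster (16.5 ms vs 64.9 ms). The two are roughly even at 10,000,000 samples, and Spark pulls ahead at 50,000,000 (538.9 ms vs 774.2 ms). On this `local[4]` session, distributed sampling pays off for tens of millions of samples, or when the result must stay distributed rather than fit in driver memory.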
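
In the table, the 1,000-sample Spark run (369.9 ms) is slower than the 1,000,000-sample run (64.9 ms), which suggests the first measurement includes one-time warm-up cost. A benchmark can reduce that effect by running an untimed warm-up call and reporting the median instead of the mean. The helper below is a minimal sketch of that idea; `time_call` is a hypothetical name, and the lambda stands in for a call such as `best.sample_spark(...).count()`.

```python
import statistics
import time


def time_call(fn, n_runs=3, warmup=1):
    """Return the median wall-clock time of fn() in milliseconds."""
    # Untimed warm-up runs absorb one-time costs such as JIT or job setup.
    for _ in range(warmup):
        fn()

    timings_ms = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000)

    # The median is less sensitive to a single slow run than the mean.
    return statistics.median(timings_ms)


calls = []
elapsed_ms = time_call(lambda: calls.append(sum(range(10_000))), n_runs=3, warmup=1)
print(f"Calls made: {len(calls)}, median time: {elapsed_ms:.3f} ms")
```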

## 4. Fit Quality Warnings

Use `warn_if_poor=True` to automatically detect poor fits.

In [7]:
# Create data that doesn't fit well with common distributions
# (bimodal distribution)
np.random.seed(42)
bimodal_data = np.concatenate([
    np.random.normal(10, 2, 5000),
    np.random.normal(30, 2, 5000)
])

bimodal_df = spark.createDataFrame([(float(x),) for x in bimodal_data], ["value"])

# Fit only 3 distributions for speed (bimodal data is slow to fit)
results = fitter.fit(bimodal_df, column="value", max_distributions=3)

# Get best with warning enabled
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    best_bimodal = results.best(n=1, warn_if_poor=True)[0]
    
    # Filter for UserWarning only (ignore ResourceWarning from Spark sockets)
    user_warnings = [x for x in w if issubclass(x.category, UserWarning)]
    
    if user_warnings:
        print("Warning raised!")
        print(f"  Message: {user_warnings[0].message}")
    else:
        print(f"Best fit: {best_bimodal.distribution} (p-value: {best_bimodal.pvalue:.4f})")

Warning raised!
  Message: Best fit 'arcsine' has p-value 0.0000 < 0.05, indicating a potentially poor fit. Consider using quality_report() for detailed diagnostics.
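
To make the mechanism concrete, the sketch below reproduces the same kind of check with SciPy alone: fit a single normal distribution to bimodal data, run a Kolmogorov-Smirnov test, and emit a `UserWarning` when the p-value falls below 0.05. `warn_if_poor_fit` is a hypothetical helper and is not how `spark_bestfit` is necessarily implemented.

```python
import warnings

import numpy as np
from scipy import stats


def warn_if_poor_fit(data, dist, name, threshold=0.05):
    """Run a K-S test against a frozen distribution; warn on a low p-value."""
    ks_stat, pvalue = stats.kstest(data, dist.cdf)
    if pvalue < threshold:
        warnings.warn(
            f"Fit '{name}' has p-value {pvalue:.4f} < {threshold}, "
            "indicating a potentially poor fit.",
            UserWarning,
        )
    return ks_stat, pvalue


rng = np.random.default_rng(42)
bimodal = np.concatenate([rng.normal(10, 2, 5000), rng.normal(30, 2, 5000)])

# A single normal cannot capture two well-separated modes.
normal_fit = stats.norm(*stats.norm.fit(bimodal))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ks_stat, pvalue = warn_if_poor_fit(bimodal, normal_fit, "norm")

print(f"K-S statistic: {ks_stat:.4f}, warnings captured: {len(caught)}")
```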


## 5. Quality Report

Get a comprehensive quality assessment with `quality_report()`.

In [8]:
# Get quality report
report = results.quality_report(n=5)

print("=" * 50)
print("QUALITY REPORT")
print("=" * 50)

print(f"\nTotal distributions fitted: {report['summary']['total_distributions']}")
print(f"Distributions meeting thresholds: {report['n_acceptable']}")

print("\nTop 5 Fits:")
for i, fit in enumerate(report['top_fits'], 1):
    print(f"  {i}. {fit.distribution}: KS={fit.ks_statistic:.4f}, p={fit.pvalue:.4f}")

print("\nWarnings:")
if report['warnings']:
    for warning in report['warnings']:
        print(f"  - {warning}")
else:
    print("  None - all fits look good!")

==================================================
QUALITY REPORT
==================================================

Total distributions fitted: 3
Distributions meeting thresholds: 0

Top 5 Fits:
  1. arcsine: KS=0.1940, p=0.0000
  2. alpha: KS=0.2374, p=0.0000
  3. anglit: KS=0.2381, p=0.0000

Warnings:
  - Best fit 'arcsine' has low p-value (0.0000 < 0.05)
  - Best fit 'arcsine' has high K-S statistic (0.1940 > 0.1)
  - Best fit 'arcsine' has high A-D statistic (775.4061 > 2.0)
  - No distributions meet all quality thresholds


In [9]:
# Custom thresholds for stricter quality assessment
strict_report = results.quality_report(
    n=3,
    pvalue_threshold=0.10,  # Stricter p-value
    ks_threshold=0.05,      # Stricter KS
    ad_threshold=1.0        # Stricter A-D
)

print(f"With strict thresholds: {strict_report['n_acceptable']} acceptable fits")
print(f"Warnings: {strict_report['warnings']}")

With strict thresholds: 0 acceptable fits
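
The report output above implies three threshold checks: p-value below 0.05, K-S statistic above 0.1, and A-D statistic above 2.0. The sketch below shows how such checks could be combined and how stricter thresholds change the outcome. `quality_warnings` is a hypothetical function written for illustration, and the "borderline" fit values are made up; only the arcsine numbers come from the report.

```python
def quality_warnings(name, pvalue, ks_statistic, ad_statistic,
                     pvalue_threshold=0.05, ks_threshold=0.10, ad_threshold=2.0):
    """Return one message per threshold that the fit fails."""
    messages = []
    if pvalue < pvalue_threshold:
        messages.append(
            f"Best fit '{name}' has low p-value ({pvalue:.4f} < {pvalue_threshold})")
    if ks_statistic > ks_threshold:
        messages.append(
            f"Best fit '{name}' has high K-S statistic ({ks_statistic:.4f} > {ks_threshold})")
    if ad_statistic > ad_threshold:
        messages.append(
            f"Best fit '{name}' has high A-D statistic ({ad_statistic:.4f} > {ad_threshold})")
    return messages


# Values taken from the arcsine fit in the quality report.
default_msgs = quality_warnings("arcsine", 0.0, 0.1940, 775.4061)

# Hypothetical borderline fit: passes the defaults, fails the strict settings.
borderline = dict(name="example", pvalue=0.08, ks_statistic=0.06, ad_statistic=1.5)
default_borderline = quality_warnings(**borderline)
strict_borderline = quality_warnings(
    **borderline, pvalue_threshold=0.10, ks_threshold=0.05, ad_threshold=1.0)

for message in default_msgs:
    print(f"  - {message}")
print(f"Borderline fit: {len(default_borderline)} default warnings, "
      f"{len(strict_borderline)} strict warnings")
```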


## 6. Complete Workflow Example

Putting it all together: fit, validate, sample.

In [10]:
# Generate well-behaved data
np.random.seed(42)
good_data = np.random.exponential(scale=5.0, size=10_000)
good_df = spark.createDataFrame([(float(x),) for x in good_data], ["value"])

# Fit (using fewer distributions for demo speed)
results = fitter.fit(good_df, column="value", max_distributions=5)

# Validate with quality report
report = results.quality_report()
if report['warnings']:
    print("Quality issues detected - review before using")
else:
    print("Quality check passed!")

# Get best fit
best = results.best(n=1, warn_if_poor=True)[0]
print(f"\nBest distribution: {best.distribution}")
print(f"Parameters: {dict(zip(best.get_param_names(), best.parameters))}")

# Generate samples for downstream use
synthetic_df = best.sample_spark(n=50_000, spark=spark, random_seed=42)
print(f"\nGenerated {synthetic_df.count():,} synthetic samples")

# Verify statistical properties
samples = synthetic_df.toPandas()["sample"].values
print(f"Sample mean: {samples.mean():.2f} (expected ~5.0)")
print(f"Sample std: {samples.std():.2f} (expected ~5.0)")

Quality issues detected - review before using

Best distribution: beta
Parameters: {'a': 0.9914222955703735, 'b': 77.39903259277344, 'loc': 5.8174115110887215e-05, 'scale': 386.3881530761719}

Generated 50,000 synthetic samples
Sample mean: 4.89 (expected ~5.0)
Sample std: 4.86 (expected ~5.0)
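
The best fit here is `beta` rather than an exponential, yet the sample statistics still land near 5.0. A beta distribution with `a` close to 1, a large `b`, and a wide `scale` closely approximates an exponential with mean roughly `scale / b`. The sketch below computes the analytic mean and standard deviation of the fitted beta with SciPy, using the parameters printed above, so the sampled values can be compared against them.

```python
from scipy import stats

# Parameters printed above for the best-fit beta distribution.
params = {
    "a": 0.9914222955703735,
    "b": 77.39903259277344,
    "loc": 5.8174115110887215e-05,
    "scale": 386.3881530761719,
}

fitted = stats.beta(params["a"], params["b"], loc=params["loc"], scale=params["scale"])

# Analytic moments of the fitted distribution (no sampling involved).
fitted_mean = float(fitted.mean())
fitted_std = float(fitted.std())

# Rough exponential-equivalent mean when a is close to 1 and b is large.
approx_exp_mean = params["scale"] / params["b"]

print(f"Fitted mean: {fitted_mean:.2f}")
print(f"Fitted std: {fitted_std:.2f}")
print(f"scale / b: {approx_exp_mean:.2f}")
```

The analytic mean and standard deviation come out at about 4.89 and 4.85, which agrees with the sampled values of 4.89 and 4.86.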


In [11]:
# Cleanup
spark.stop()