# spark-bestfit API Demo

This notebook demonstrates the complete API for the `spark-bestfit` library, including:

1. **Excluding Distributions** - Customizing which distributions to fit
2. **Distribution Fitting** - Using `DistributionFitter` with direct parameters
3. **Working with Results** - `FitResults` and `DistributionFitResult` objects
4. **Plotting** - Visualization with customizable parameters
5. **Complete Workflow** - All of the above in one production-style example

## Setup

First, let's create a Spark session. Note: **You** are responsible for creating and configuring your SparkSession.

In [None]:
import numpy as np
from pyspark.sql import SparkSession

# Create Spark session (your responsibility - configure as needed for your environment)
spark = (
    SparkSession.builder
    .appName("API-Demo")
    .config("spark.sql.shuffle.partitions", "10")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

print(f"Spark version: {spark.version}")

In [None]:
# Import spark-bestfit components
from spark_bestfit import (
    DistributionFitter,
    DEFAULT_EXCLUDED_DISTRIBUTIONS,
)

## Generate Sample Data

We'll create sample data from known distributions for demonstration.

In [None]:
np.random.seed(42)

# Normal distribution data
normal_data = np.random.normal(loc=50, scale=10, size=50_000)
df_normal = spark.createDataFrame([(float(x),) for x in normal_data], ["value"])

# Exponential distribution data (non-negative)
exp_data = np.random.exponential(scale=5, size=50_000)
df_exp = spark.createDataFrame([(float(x),) for x in exp_data], ["value"])

# Gamma distribution data
gamma_data = np.random.gamma(shape=2.0, scale=2.0, size=50_000)
df_gamma = spark.createDataFrame([(float(x),) for x in gamma_data], ["value"])

print(f"Normal data: {df_normal.count():,} rows, mean={normal_data.mean():.2f}, std={normal_data.std():.2f}")
print(f"Exponential data: {df_exp.count():,} rows, mean={exp_data.mean():.2f}")
print(f"Gamma data: {df_gamma.count():,} rows, mean={gamma_data.mean():.2f}")
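
Because the generating parameters are known, you can check the sample statistics against closed-form moments before fitting anything. For NumPy's parameterization, the normal has mean `loc` and std `scale`. The exponential has mean and std both equal to `scale`. The gamma has mean `shape * scale` and std `sqrt(shape) * scale`. The sketch below uses the newer `np.random.default_rng` generator rather than the legacy `np.random.seed`, so its draws will not match the cell above exactly.

```python
import math

import numpy as np

# Closed-form moments of the generating distributions (NumPy parameterization)
expected = {
    "normal": {"mean": 50.0, "std": 10.0},                      # loc, scale
    "exponential": {"mean": 5.0, "std": 5.0},                   # mean = std = scale
    "gamma": {"mean": 2.0 * 2.0, "std": math.sqrt(2.0) * 2.0},  # k*theta, sqrt(k)*theta
}

rng = np.random.default_rng(42)
samples = {
    "normal": rng.normal(loc=50, scale=10, size=50_000),
    "exponential": rng.exponential(scale=5, size=50_000),
    "gamma": rng.gamma(shape=2.0, scale=2.0, size=50_000),
}

# Compare sample statistics with the theoretical values
for name, x in samples.items():
    e = expected[name]
    print(f"{name:12s} mean={x.mean():.2f} (expected {e['mean']:.2f}), "
          f"std={x.std():.2f} (expected {e['std']:.2f})")
```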

---

# Part 1: Excluded Distributions

spark-bestfit skips a set of slow-to-fit distributions by default. You can change this list by passing `excluded_distributions` to `DistributionFitter`.

## 1.1 DEFAULT_EXCLUDED_DISTRIBUTIONS

`DEFAULT_EXCLUDED_DISTRIBUTIONS` holds the names of the distributions that are skipped unless you override the exclusion list.

In [None]:
# View default excluded distributions
print(f"Default excluded distributions ({len(DEFAULT_EXCLUDED_DISTRIBUTIONS)}):")
for dist in sorted(DEFAULT_EXCLUDED_DISTRIBUTIONS):
    print(f"  - {dist}")

In [None]:
# Include a specific distribution that's excluded by default
custom_exclusions = tuple(d for d in DEFAULT_EXCLUDED_DISTRIBUTIONS if d != "wald")

fitter_with_wald = DistributionFitter(spark, excluded_distributions=custom_exclusions)
print("'wald' removed from exclusions; fitter_with_wald will consider it when fitting")
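
When you need to re-enable some defaults and skip extra distributions at the same time, a small helper keeps the tuple manipulation readable. `customize_exclusions` is a hypothetical helper, not part of spark-bestfit. The `example_defaults` tuple is also made up for illustration; in the notebook you would pass `DEFAULT_EXCLUDED_DISTRIBUTIONS` instead.

```python
def customize_exclusions(defaults, include=(), exclude=()):
    """Build a new exclusion tuple from a default one.

    include: names to take OFF the exclusion list (so they will be fitted)
    exclude: extra names to add to the exclusion list (so they are skipped)
    A name listed in both `include` and `exclude` ends up excluded.
    """
    remaining = set(defaults) - set(include)
    return tuple(sorted(remaining | set(exclude)))


# Hypothetical defaults for illustration only
example_defaults = ("levy_stable", "studentized_range", "wald")

custom = customize_exclusions(example_defaults, include=["wald"], exclude=["cauchy"])
print(custom)  # ('cauchy', 'levy_stable', 'studentized_range')
```

With the real defaults, the result can be passed the same way as `custom_exclusions` above: `DistributionFitter(spark, excluded_distributions=customize_exclusions(DEFAULT_EXCLUDED_DISTRIBUTIONS, include=["wald"]))`.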

---

# Part 2: Distribution Fitting

The `DistributionFitter` class is the main entry point for fitting distributions.

## 2.1 Basic Fitting

In [None]:
# Create fitter with default config
fitter = DistributionFitter(spark)

# Fit distributions to normal data (limit to 20 for demo speed)
print("Fitting distributions to normal data...")
results_normal = fitter.fit(df_normal, column="value", max_distributions=20)

print(f"\nFitted {results_normal.count()} distributions")
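
Results are ranked by SSE by default. One common way to compute a histogram-based SSE is to compare the density-normalized histogram of the data with each candidate PDF evaluated at the bin centers. The NumPy-only sketch below illustrates that general technique. It is an assumption, not spark-bestfit's actual implementation, and `histogram_sse` and `make_normal_pdf` are hypothetical helpers.

```python
import numpy as np


def histogram_sse(data, pdf, bins=50):
    """Sum of squared errors between a density histogram and a candidate PDF."""
    density, edges = np.histogram(data, bins=bins, density=True)
    centers = (edges[:-1] + edges[1:]) / 2.0  # evaluate the PDF at bin midpoints
    return float(np.sum((density - pdf(centers)) ** 2))


def make_normal_pdf(loc, scale):
    """Return a vectorized normal PDF with fixed loc and scale."""
    def pdf(x):
        z = (x - loc) / scale
        return np.exp(-0.5 * z**2) / (scale * np.sqrt(2.0 * np.pi))
    return pdf


rng = np.random.default_rng(42)
data = rng.normal(loc=50, scale=10, size=50_000)

sse_good = histogram_sse(data, make_normal_pdf(50, 10))  # matches the data
sse_bad = histogram_sse(data, make_normal_pdf(40, 5))    # deliberately wrong
print(f"SSE (loc=50, scale=10): {sse_good:.6f}")
print(f"SSE (loc=40, scale=5):  {sse_bad:.6f}")
```

Lower SSE means the candidate PDF tracks the histogram more closely. Note that the value depends on the number of bins, which is one reason the fit methods expose a `bins` parameter.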

## 2.2 Fitting with Custom Parameters

In [None]:
# Fit only non-negative distributions using support_at_zero=True
fitter_nonneg = DistributionFitter(spark)

print("Fitting non-negative distributions to exponential data...")
results_exp = fitter_nonneg.fit(
    df_exp,
    column="value",
    bins=100,
    support_at_zero=True,  # Only fit non-negative distributions
    enable_sampling=True,
    max_distributions=15,
)

print(f"Fitted {results_exp.count()} non-negative distributions")
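
According to the comment in the cell above, `support_at_zero=True` restricts fitting to non-negative distributions. Conceptually, that amounts to keeping only candidates whose support starts at zero or higher. The sketch below shows this filter over a small hypothetical candidate table. The lower bounds follow `scipy.stats` conventions with the default `loc=0`, but how spark-bestfit determines support internally is an assumption.

```python
import math

# Lower bound of each candidate's support (default loc=0).
# Hypothetical candidate list; bounds follow scipy.stats conventions.
SUPPORT_LOWER_BOUND = {
    "norm": -math.inf,
    "cauchy": -math.inf,
    "expon": 0.0,
    "gamma": 0.0,
    "lognorm": 0.0,
    "weibull_min": 0.0,
}


def nonnegative_candidates(bounds):
    """Keep only distributions whose support starts at or above zero."""
    return sorted(name for name, lower in bounds.items() if lower >= 0.0)


print(nonnegative_candidates(SUPPORT_LOWER_BOUND))
# ['expon', 'gamma', 'lognorm', 'weibull_min']
```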

## 2.3 Using Active SparkSession

If a SparkSession is already active, you don't need to pass it explicitly.

In [None]:
# DistributionFitter can use the active session automatically
fitter_active = DistributionFitter()  # No spark parameter needed
print(f"Using active session: {fitter_active.spark.sparkContext.appName}")

---

# Part 3: Working with Results

The `fit()` method returns a `FitResults` object that you can rank, filter, and convert to pandas.

## 3.1 Getting Best Distributions

In [None]:
# Get best distribution by SSE (default)
best_sse = results_normal.best(n=1)[0]
print(f"Best by SSE: {best_sse.distribution}")
print(f"  SSE: {best_sse.sse:.6f}")
print(f"  AIC: {best_sse.aic:.2f}")
print(f"  BIC: {best_sse.bic:.2f}")
print(f"  Parameters: {[f'{p:.4f}' for p in best_sse.parameters]}")

In [None]:
# Get top 5 by different metrics
print("Top 5 by SSE:")
for i, r in enumerate(results_normal.best(n=5, metric="sse"), 1):
    print(f"  {i}. {r.distribution:20s} SSE={r.sse:.6f}")

print("\nTop 5 by AIC:")
for i, r in enumerate(results_normal.best(n=5, metric="aic"), 1):
    print(f"  {i}. {r.distribution:20s} AIC={r.aic:.2f}")

print("\nTop 5 by BIC:")
for i, r in enumerate(results_normal.best(n=5, metric="bic"), 1):
    print(f"  {i}. {r.distribution:20s} BIC={r.bic:.2f}")
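
AIC and BIC both penalize model complexity: AIC = 2k − 2 ln L and BIC = k ln(n) − 2 ln L. Here k is the number of fitted parameters, n is the sample size, and L is the maximized likelihood. Lower is better for both metrics. The sketch below applies these textbook definitions to normal and Laplace fits using their closed-form maximum-likelihood estimates. Whether spark-bestfit evaluates the likelihood on the full data or on a sample is not shown here and is an assumption.

```python
import numpy as np


def aic(log_likelihood, k):
    """Akaike information criterion: 2k - 2 ln L."""
    return 2 * k - 2 * log_likelihood


def bic(log_likelihood, k, n):
    """Bayesian information criterion: k ln(n) - 2 ln L."""
    return k * np.log(n) - 2 * log_likelihood


rng = np.random.default_rng(42)
x = rng.normal(loc=50, scale=10, size=50_000)
n = x.size

# Normal MLE: mean and biased variance (k = 2); ln L = -n/2 * (ln(2*pi*var) + 1)
var = x.var()
ll_norm = -0.5 * n * (np.log(2 * np.pi * var) + 1)

# Laplace MLE: median and mean absolute deviation from it (k = 2)
m = np.median(x)
b = np.mean(np.abs(x - m))
ll_laplace = -n * np.log(2 * b) - n

for name, ll in [("norm", ll_norm), ("laplace", ll_laplace)]:
    print(f"{name:8s} AIC={aic(ll, 2):.2f} BIC={bic(ll, 2, n):.2f}")
```

Both models have two parameters here, so the penalty terms cancel and the comparison comes down to likelihood. BIC's penalty grows with ln(n), so it punishes extra parameters more heavily than AIC on large datasets.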

## 3.2 Filtering Results

In [None]:
# Filter by SSE threshold
good_fits = results_normal.filter(sse_threshold=0.01)
print(f"Distributions with SSE < 0.01: {good_fits.count()}")

for r in good_fits.best(n=10):
    print(f"  {r.distribution:20s} SSE={r.sse:.6f}")
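
The filter-then-rank pattern above can be sketched in plain Python. `FitRecord`, `best_records`, and `filter_by_sse` are hypothetical stand-ins for `DistributionFitResult`, `FitResults.best()`, and `FitResults.filter()`. The metric values are made up purely to show that different metrics can disagree on the winner.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FitRecord:
    """Hypothetical stand-in for a single fitted-distribution result."""
    distribution: str
    sse: float
    aic: float
    bic: float


def best_records(records, n=1, metric="sse"):
    """Return the n records with the lowest value of `metric`."""
    return sorted(records, key=lambda r: getattr(r, metric))[:n]


def filter_by_sse(records, sse_threshold):
    """Keep records whose SSE is strictly below the threshold."""
    return [r for r in records if r.sse < sse_threshold]


# Made-up values for illustration only
records = [
    FitRecord("norm", sse=0.002, aic=100.0, bic=110.0),
    FitRecord("t", sse=0.003, aic=98.0, bic=115.0),
    FitRecord("laplace", sse=0.050, aic=150.0, bic=160.0),
]

print(best_records(records, metric="sse")[0].distribution)  # norm
print(best_records(records, metric="aic")[0].distribution)  # t
print([r.distribution for r in filter_by_sse(records, 0.01)])  # ['norm', 't']
```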

## 3.3 Converting to Pandas

In [None]:
# Convert to pandas DataFrame for further analysis
df_results = results_normal.to_pandas()
print("Results as pandas DataFrame:")
df_results.head(10)

## 3.4 Using Fitted Distributions

In [None]:
# The DistributionFitResult object wraps the scipy.stats distribution
best = results_normal.best(n=1)[0]

# Generate samples from the fitted distribution
samples = best.sample(size=10000, random_state=42)
print(f"Generated {len(samples)} samples from fitted {best.distribution}")
print(f"  Sample mean: {samples.mean():.2f} (original: {normal_data.mean():.2f})")
print(f"  Sample std: {samples.std():.2f} (original: {normal_data.std():.2f})")

In [None]:
# Evaluate the PDF and CDF at specific points
x = np.array([30, 40, 50, 60, 70])
pdf_values = best.pdf(x)
cdf_values = best.cdf(x)

print("PDF and CDF values:")
for xi, pdf, cdf in zip(x, pdf_values, cdf_values):
    print(f"  x={xi}: PDF={pdf:.6f}, CDF={cdf:.4f}")
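
For a sanity check, the normal PDF and CDF have closed forms, and the standard library's `math.erf` is enough to evaluate them. The sketch below uses the generating parameters `loc=50, scale=10`. If the best fit above is `norm`, its values should be close to these but not identical, because the fitted parameters differ slightly from the generating ones. If another distribution won, the comparison is only approximate.

```python
import math


def normal_pdf(x, loc, scale):
    """Closed-form normal density."""
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2 * math.pi))


def normal_cdf(x, loc, scale):
    """Normal CDF via the error function: 0.5 * (1 + erf(z / sqrt(2)))."""
    return 0.5 * (1 + math.erf((x - loc) / (scale * math.sqrt(2))))


for xi in [30, 40, 50, 60, 70]:
    print(f"  x={xi}: PDF={normal_pdf(xi, 50, 10):.6f}, CDF={normal_cdf(xi, 50, 10):.4f}")
```

At the mean, the CDF is exactly 0.5, and the PDF reaches its peak of 1 / (10 * sqrt(2π)) ≈ 0.039894. One standard deviation above the mean, the CDF is about 0.8413.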

---

# Part 4: Plotting

Visualize the fitted distribution's PDF overlaid on a histogram of the data.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

## 4.1 Basic Plot

In [None]:
# Basic plot with default styling
fig, ax = fitter.plot(
    best,
    df_normal,
    "value",
    title="Best Fit Distribution (Normal Data)",
    xlabel="Value",
    ylabel="Density"
)
plt.show()

## 4.2 Plot with Custom Parameters

In [None]:
# Custom plot with direct parameters
fig, ax = fitter.plot(
    best,
    df_normal,
    "value",
    figsize=(14, 8),
    dpi=100,
    histogram_alpha=0.7,
    pdf_linewidth=3,
    title_fontsize=18,
    label_fontsize=14,
    legend_fontsize=12,
    grid_alpha=0.4,
    title="Distribution Fit with Custom Styling",
    xlabel="Value",
    ylabel="Density",
)
plt.show()
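
To build a similar figure without spark-bestfit, for example from a pandas sample, a plain matplotlib version of "density histogram plus PDF overlay" looks roughly like the sketch below. `plot_fit` is a hypothetical helper. Its keyword names loosely mirror the `fitter.plot()` arguments used above, but this is not the library's implementation.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_fit(data, pdf, bins=50, title="Distribution Fit", figsize=(10, 6),
             histogram_alpha=0.6, pdf_linewidth=2, grid_alpha=0.3):
    """Overlay a PDF on a density-normalized histogram of `data`."""
    fig, ax = plt.subplots(figsize=figsize)
    # density=True scales the bars so the histogram integrates to 1,
    # putting it on the same vertical scale as the PDF
    ax.hist(data, bins=bins, density=True, alpha=histogram_alpha, label="Data")
    xs = np.linspace(data.min(), data.max(), 500)
    ax.plot(xs, pdf(xs), linewidth=pdf_linewidth, label="Fitted PDF")
    ax.set_title(title)
    ax.set_xlabel("Value")
    ax.set_ylabel("Density")
    ax.grid(alpha=grid_alpha)
    ax.legend()
    return fig, ax


def fitted_pdf(x, loc=50.0, scale=10.0):
    """Normal PDF with assumed parameters, standing in for a fitted distribution."""
    return np.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * np.sqrt(2 * np.pi))


rng = np.random.default_rng(42)
data = rng.normal(loc=50, scale=10, size=50_000)

fig, ax = plot_fit(data, fitted_pdf, title="Normal Data with N(50, 10) PDF")
plt.show()
```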

## 4.3 Plot Non-Negative Distribution

In [None]:
# Best fit for exponential data
best_exp = results_exp.best(n=1)[0]
print(f"Best fit for exponential data: {best_exp.distribution}")

fig, ax = fitter_nonneg.plot(
    best_exp,
    df_exp,
    "value",
    figsize=(14, 8),
    dpi=100,
    histogram_alpha=0.7,
    pdf_linewidth=3,
    title_fontsize=18,
    title=f"Best Fit: {best_exp.distribution.capitalize()}",
    xlabel="Value",
    ylabel="Density",
)
plt.show()

---

# Part 5: Complete Workflow Example

Putting it all together: a complete, production-style workflow.

In [None]:
# Complete workflow with explicitly set parameters
fitter_gamma = DistributionFitter(spark, random_seed=42)

# Fit distributions
print("Fitting gamma distribution data...")
results = fitter_gamma.fit(
    df_gamma,
    column="value",
    bins=100,
    use_rice_rule=False,
    enable_sampling=True,
    max_sample_size=1_000_000,
    max_distributions=25,
)

# Get best result
best = results.best(n=1)[0]
print(f"\nBest distribution: {best.distribution}")
print(f"SSE: {best.sse:.6f}")
print(f"Parameters: {[f'{p:.4f}' for p in best.parameters]}")

# Plot with custom parameters
fig, ax = fitter_gamma.plot(
    best,
    df_gamma,
    "value",
    figsize=(14, 9),
    dpi=150,
    histogram_alpha=0.6,
    pdf_linewidth=3,
    title_fontsize=16,
    title=f"Gamma Data - Best Fit: {best.distribution.capitalize()}",
    xlabel="Value",
    ylabel="Density",
)
plt.show()

# Show top 5 results (sorted explicitly by SSE so the order doesn't depend on to_pandas())
print("\nTop 5 distributions:")
df_top5 = results.to_pandas().sort_values("sse").head(5)
df_top5[["distribution", "sse", "aic", "bic"]]
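
The workflow above passes `use_rice_rule=False` together with an explicit `bins=100`. Presumably, the flag switches to a data-driven bin count instead. The Rice rule sets the number of bins to 2 · n^(1/3), rounded up here. Rounding up is an assumption, because implementations differ on ceiling versus rounding, and spark-bestfit's exact choice is not shown in this notebook.

```python
import math


def rice_bins(n):
    """Rice rule for histogram bins: ceil(2 * n ** (1/3))."""
    return math.ceil(2 * n ** (1 / 3))


for n in (10_000, 50_000, 250_000):
    print(f"n={n:>7,}: {rice_bins(n)} bins")
```

For the 50,000-row samples in this notebook, the rule gives 74 bins. The bin count grows only with the cube root of the row count, so even very large datasets get a modest number of bins.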

---

# Cleanup

In [None]:
spark.stop()
print("Spark session stopped.")

---

## Summary

This notebook demonstrated:

1. **Excluded Distributions**:
   - `DEFAULT_EXCLUDED_DISTRIBUTIONS` - Slow distributions excluded by default
   - Pass a custom `excluded_distributions` tuple to `DistributionFitter()` to re-enable or skip specific distributions

2. **SparkSession Management**:
   - You create and configure your own SparkSession
   - Pass it to `DistributionFitter(spark)` or use active session

3. **Fitting**:
   - `DistributionFitter.fit()` - Fit distributions to data
   - Parameters: `bins`, `use_rice_rule`, `support_at_zero`, `enable_sampling`, etc.
   - `max_distributions` parameter to limit fitting scope

4. **Results**:
   - `results.best(n, metric)` - Get top N by SSE/AIC/BIC
   - `results.filter()` - Filter by threshold
   - `results.to_pandas()` - Convert to pandas DataFrame
   - `DistributionFitResult.sample()`, `.pdf()`, `.cdf()` - Use fitted distribution

5. **Plotting**:
   - `fitter.plot()` - Visualize fitted distribution with data histogram
   - Customizable with `figsize`, `dpi`, `histogram_alpha`, `pdf_linewidth`, etc.