# Forecasting

Building models, running quality checks, and estimating reserves all require forecasts of well production. This notebook demonstrates:

1. **Data Processing**: Preparing well production data for forecasting
2. **Arps Decline Curves**: Automatically fitting exponential, hyperbolic, and harmonic decline curves
3. **Forecasting**: Generating production forecasts for individual wells

## Arps Decline Curves

The Arps decline curve equations are fundamental tools in petroleum engineering for forecasting oil and gas production:

- **Exponential Decline (b=0)**: `q(t) = qi * exp(-Di * t)`
- **Hyperbolic Decline (0<b<1)**: `q(t) = qi * (1 + b * Di * t)^(-1/b)`
- **Harmonic Decline (b=1)**: `q(t) = qi / (1 + Di * t)`

Where:
- `qi` = initial production rate
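- `Di` = initial (nominal) decline rate, as a fraction per unit time
- `b` = decline exponent
- `t` = time since the start of the decline, in the same time unit as `Di` (e.g. months)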
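
As a minimal NumPy sketch of the three rate equations above (the function names are illustrative and not part of `petrinex`; `Di` in 1/month and `t` in months are assumptions, since the units only need to match):

```python
import numpy as np

def exponential_rate(t, qi, di):
    """Exponential decline (b = 0): q(t) = qi * exp(-Di * t)."""
    t = np.asarray(t, dtype=float)
    return qi * np.exp(-di * t)

def hyperbolic_rate(t, qi, di, b):
    """Hyperbolic decline: q(t) = qi * (1 + b * Di * t) ** (-1 / b).

    Intended for 0 < b < 1; b = 1 reproduces the harmonic case and
    b -> 0 approaches the exponential case.
    """
    if b <= 0:
        raise ValueError("b must be positive; use exponential_rate for b = 0")
    t = np.asarray(t, dtype=float)
    return qi * (1.0 + b * di * t) ** (-1.0 / b)

def harmonic_rate(t, qi, di):
    """Harmonic decline (b = 1): q(t) = qi / (1 + Di * t)."""
    t = np.asarray(t, dtype=float)
    return qi / (1.0 + di * t)

# Hypothetical well: qi = 1000 rate units, Di = 0.1 per month, two years of months
months = np.arange(0, 25)
exp_q = exponential_rate(months, qi=1000.0, di=0.1)
hyp_q = hyperbolic_rate(months, qi=1000.0, di=0.1, b=0.5)
har_q = harmonic_rate(months, qi=1000.0, di=0.1)
print(exp_q[12], hyp_q[12], har_q[12])  # approx. 301.19, 390.63, 454.55
```

For the same `qi` and `Di`, a larger `b` flattens the tail, so at every `t > 0` the exponential rate is lowest and the harmonic rate is highest.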

In [0]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pyspark.sql import SparkSession

from petrinex.config import DotConfig
from petrinex.forecast import (
    forecast_spark_workflow
)

In [0]:
config = DotConfig("config.yaml")

In [None]:
# Initialize Spark session
spark = SparkSession.builder.appName("PetrinexForecasting").getOrCreate()


In [0]:
# Define the input table (replace with your actual silver table)
input_table = f"{config.catalog}.{config.schema}.ngl_silver"

# Check if the table exists and get basic info
try:
    input_df = spark.table(input_table)
    row_count = input_df.count()
    well_count = input_df.select("WellID").distinct().count()
    print(f"Input table: {input_table}")
    print(f"Total rows: {row_count:,}")
    print(f"Unique wells: {well_count:,}")
    
    # Show sample of the data
    print("\nSample data:")
    input_df.show(5)
except Exception as e:
    print(f"Error accessing table {input_table}: {e}")
    print("Ensure the silver table exists, or set input_table to an alternative data source.")

In [0]:
# Run Spark forecast workflow with memory-efficient batching
output_tables = forecast_spark_workflow(
    spark=spark,
    config=config,
    input_table=input_table,
    # Optional overrides:
    # batch_size=300,  # Smaller batches for very large datasets
    # forecast_months=24,  # Override config default
    # curve_type='auto',
    # min_r_squared=0.5
)
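
The notebook delegates curve fitting to `forecast_spark_workflow`, whose internals are not shown here. For intuition about what an automatic fit involves, the following is a minimal NumPy sketch for a single well, under stated assumptions: for each candidate `b` on a grid it linearises the Arps equation (for `b > 0`, `q**(-b)` is linear in `t`; for `b = 0`, `ln q` is), solves it with `np.polyfit`, and keeps the candidate with the highest R². The grid search and the helper names are hypothetical and not necessarily what petrinex does.

```python
import numpy as np

def arps_rate(t, qi, di, b):
    """Arps rate: exponential for b == 0, hyperbolic for 0 < b < 1, harmonic for b == 1."""
    t = np.asarray(t, dtype=float)
    if b == 0:
        return qi * np.exp(-di * t)
    return qi * (1.0 + b * di * t) ** (-1.0 / b)

def r_squared(observed, predicted):
    ss_res = np.sum((observed - predicted) ** 2)
    ss_tot = np.sum((observed - observed.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

def fit_arps(t, q, b_grid=None):
    """Hypothetical helper: grid-search b, solve qi and Di by linear least squares, keep the best R².

    Assumes strictly positive rates (zero-production months must be filtered out first).
    """
    t = np.asarray(t, dtype=float)
    q = np.asarray(q, dtype=float)
    if b_grid is None:
        b_grid = np.round(np.arange(0.0, 1.0001, 0.05), 2)  # 0.00, 0.05, ..., 1.00
    best = None
    for b in b_grid:
        if b == 0:
            # Exponential: ln q = ln qi - Di * t
            slope, intercept = np.polyfit(t, np.log(q), 1)
            qi, di = np.exp(intercept), -slope
        else:
            # Hyperbolic/harmonic: q**(-b) = qi**(-b) + qi**(-b) * b * Di * t
            slope, intercept = np.polyfit(t, q ** (-b), 1)
            if intercept <= 0:
                continue
            qi = intercept ** (-1.0 / b)
            di = slope / (intercept * b)
        if di <= 0:
            continue  # not a declining curve
        score = r_squared(q, arps_rate(t, qi, di, b))
        if best is None or score > best["r_squared"]:
            best = {"qi": float(qi), "di": float(di), "b": float(b), "r_squared": score}
    return best

# Noise-free synthetic well: qi = 1000, Di = 0.1 per month, b = 0.5
months = np.arange(36)
rates = arps_rate(months, qi=1000.0, di=0.1, b=0.5)
print(fit_arps(months, rates))
```

A threshold such as the workflow's `min_r_squared` option would then presumably drop wells whose best R² falls below it. Least squares on transformed rates weights points differently from least squares on the rates themselves, so a nonlinear refinement (for example with `scipy.optimize.curve_fit`) is a common next step.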

In [0]:
# The workflow automatically uses batching for memory efficiency
# Adjust batch_size parameter above if you need smaller batches for very large datasets
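
How the workflow batches wells is not shown in this notebook. One plausible reading of `batch_size`, assumed here, is that distinct `WellID`s are split into fixed-size groups that are fitted and written one group at a time, so peak memory tracks `batch_size` rather than the total well count. A pure-Python sketch of that partitioning (the helper name is hypothetical):

```python
from typing import Iterator, List, Sequence

def iter_well_batches(well_ids: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Hypothetical helper: yield consecutive groups of at most batch_size unique well IDs.

    Because this is a generator, the batch_size check runs on first iteration.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    ordered = sorted(set(well_ids))  # de-duplicate and fix a stable order
    for start in range(0, len(ordered), batch_size):
        yield ordered[start:start + batch_size]

batches = list(iter_well_batches(["W3", "W1", "W2", "W1", "W5"], batch_size=2))
print(batches)  # [['W1', 'W2'], ['W3', 'W5']]
```

In Spark, each batch could then be selected with `input_df.filter(F.col("WellID").isin(batch))`, where `F` is `pyspark.sql.functions`.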


In [0]:
# Verify and explore the output tables
if 'output_tables' in locals():
    for production_type, tables in output_tables.items():
        print(f"=== {production_type} Tables ===")
        
        for table_type, table_name in tables.items():
            table_df = spark.table(table_name)
            row_count = table_df.count()
            print(f"{table_type.title()}: {table_name} ({row_count:,} rows)")
            
            if row_count > 0:
                table_df.show(3)
else:
    print("No output tables available.")


In [0]:
# Analyze gas forecast quality and the fitted Arps parameters
if 'output_tables' in locals() and 'GasProduction' in output_tables:
    # Get summary table for gas production
    summary_table_name = output_tables['GasProduction']['summary']
    summary_df = spark.table(summary_table_name)
    
    # Convert to Pandas for analysis and plotting
    summary_pd = summary_df.toPandas()
    
    print("Gas Production Forecast Summary:")
    print(f"Wells forecast: {len(summary_pd)}")
    print(f"Forecast date: {summary_pd['ForecastDate'].iloc[0]}")
    print(f"Average R-squared: {summary_pd['RSquared'].mean():.3f}")
    print(f"Average AIC: {summary_pd['AIC'].mean():.1f}")
    print(f"Curve types: {summary_pd['CurveType'].value_counts().to_dict()}")
    
    # Display ARPS parameter ranges
    print("\nARPS Parameter Ranges:")
    print(f"Initial Rate (qi): {summary_pd['InitialRate_qi'].min():.1f} - {summary_pd['InitialRate_qi'].max():.1f}")
    print(f"Decline Rate (di): {summary_pd['DeclineRate_di'].min():.3f} - {summary_pd['DeclineRate_di'].max():.3f}")
    
    # Show data coverage
    print("\nData Coverage:")
    print(f"Earliest data: {summary_pd['HistoricalDataMinDate'].min()}")
    print(f"Latest data: {summary_pd['HistoricalDataMaxDate'].max()}")
    print(f"Average data points per well: {summary_pd['DataPointsUsed'].mean():.1f}")
    
    # Enhanced plotting
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # R-squared histogram
    ax1.hist(summary_pd['RSquared'], bins=20, alpha=0.7, edgecolor='black')
    ax1.set_xlabel('R-squared')
    ax1.set_ylabel('Number of Wells')
    ax1.set_title('Forecast Quality Distribution')
    ax1.grid(True, alpha=0.3)
    
    # Curve type pie chart
    curve_counts = summary_pd['CurveType'].value_counts()
    ax2.pie(curve_counts.values, labels=curve_counts.index, autopct='%1.1f%%')
    ax2.set_title('Decline Curve Types')
    
    # Initial rate vs decline rate scatter, colored by R-squared (with a colorbar as the legend)
    points = ax3.scatter(summary_pd['InitialRate_qi'], summary_pd['DeclineRate_di'],
                         c=summary_pd['RSquared'], cmap='viridis', alpha=0.6)
    fig.colorbar(points, ax=ax3, label='R-squared')
    ax3.set_xlabel('Initial Rate (qi)')
    ax3.set_ylabel('Decline Rate (di)')
    ax3.set_title('ARPS Parameters (colored by R²)')
    ax3.grid(True, alpha=0.3)
    
    # Data points vs R-squared
    ax4.scatter(summary_pd['DataPointsUsed'], summary_pd['RSquared'], alpha=0.6)
    ax4.set_xlabel('Data Points Used')
    ax4.set_ylabel('R-squared')
    ax4.set_title('History Length vs Fit Quality')
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Show a sample of the summary; an expression nested inside an `if` block is not auto-displayed
    print("\nSample summary (first 3 wells):")
    display_cols = ['WellID', 'CurveType', 'RSquared', 'InitialRate_qi', 'DeclineRate_di',
                    'DataPointsUsed', 'HistoricalDataMinDate', 'HistoricalDataMaxDate']
    print(summary_pd[display_cols].head(3).to_string(index=False))
else:
    print("No GasProduction summary table available.")
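
The summary reports `RSquared` and `AIC` per well, but the notebook does not show how the workflow computes them. Assuming the common least-squares definitions, R² = 1 - SS_res / SS_tot and, for Gaussian errors with constant terms dropped, AIC = n * ln(RSS / n) + 2k, where k is the number of fitted parameters, a minimal NumPy sketch is:

```python
import numpy as np

def r_squared(observed, predicted):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    observed = np.asarray(observed, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    ss_res = np.sum((observed - predicted) ** 2)
    ss_tot = np.sum((observed - observed.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

def aic_least_squares(observed, predicted, n_params):
    """Least-squares AIC assuming Gaussian errors, constants dropped: n * ln(RSS / n) + 2k."""
    residuals = np.asarray(observed, dtype=float) - np.asarray(predicted, dtype=float)
    n = residuals.size
    rss = float(np.sum(residuals ** 2))
    if rss == 0.0:
        raise ValueError("AIC is undefined for a perfect fit (RSS = 0)")
    return float(n * np.log(rss / n) + 2 * n_params)

# Toy monthly rates and a fit that misses each point by exactly 1
observed = [100.0, 80.0, 65.0, 54.0]
predicted = [101.0, 79.0, 66.0, 53.0]
print(r_squared(observed, predicted))
print(aic_least_squares(observed, predicted, n_params=2))  # 4.0, e.g. exponential (qi, Di)
print(aic_least_squares(observed, predicted, n_params=3))  # 6.0, e.g. hyperbolic (qi, Di, b)
```

Lower AIC is better, and the `2k` term charges the hyperbolic model one parameter more than a fixed-`b` curve, so it wins only if it lowers RSS enough. Because AIC depends on each well's number of points and rate scale, it is meaningful for comparing candidate curves on the same well; an average across wells is only a rough indicator.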


In [0]:
# Tables are stored in Spark/Delta format for efficient querying
# Use spark.table('table_name') to access the full datasets

# Optional: Export samples for external analysis
# if 'output_tables' in locals():
#     for production_type, tables in output_tables.items():
#         summary_df = spark.table(tables['summary'])
#         summary_sample = summary_df.limit(1000).toPandas()
#         summary_sample.to_parquet(f"../fixtures/{production_type.lower()}_forecast_summary.parquet", index=False)


In [0]:
# Summary
if 'output_tables' in locals():
    total_production_types = len(output_tables)
    total_tables = sum(len(tables) for tables in output_tables.values())
    
    print("=== FORECASTING COMPLETE ===")
    print(f"Production types: {total_production_types}")
    print(f"Tables created: {total_tables}")
    print(f"Config: {config.catalog}.{config.schema}, {config.forecast.horizon_months}mo horizon, {config.forecast.min_months}mo minimum data")
    
    print("\nOutput tables:")
    for production_type, tables in output_tables.items():
        print(f"\n{production_type}:")
        for table_type, table_name in tables.items():
            count = spark.table(table_name).count()
            print(f"  {table_type}: {table_name} ({count:,} rows)")
    
    print("\nNext steps:")
    print("- Review forecast quality in the summary tables")
    print("- Use the forecast tables for production planning")
    print("- Use the combined tables for visualization")
    
else:
    print("No forecasting workflow completed.")
    print(f"Ensure the input table exists and that wells have at least {config.forecast.min_months} months of history.")
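
The summary prints `config.forecast.min_months` and `config.forecast.horizon_months`, and the fallback message asks for wells with enough history. How the workflow applies these settings is not shown; as a minimal pandas sketch, assuming a monthly history table with `WellID` and a hypothetical `ProductionMonth` column, the two settings could translate into an eligibility filter and a forecast date range (both helpers are hypothetical):

```python
import pandas as pd

def eligible_wells(history: pd.DataFrame, min_months: int) -> list:
    """Hypothetical helper: WellIDs with at least min_months distinct production months."""
    months_per_well = history.groupby("WellID")["ProductionMonth"].nunique()
    return sorted(months_per_well[months_per_well >= min_months].index)

def forecast_month_starts(last_month, horizon_months: int) -> pd.DatetimeIndex:
    """Hypothetical helper: month-start dates for the horizon_months months after last_month."""
    first = pd.Timestamp(last_month) + pd.offsets.MonthBegin(1)  # first day of the next month
    return pd.date_range(start=first, periods=horizon_months, freq="MS")

# Toy history: well A has three months of data, well B only one
history = pd.DataFrame({
    "WellID": ["A", "A", "A", "B"],
    "ProductionMonth": pd.to_datetime(["2024-04-01", "2024-05-01", "2024-06-01", "2024-06-01"]),
})
print(eligible_wells(history, min_months=3))  # ['A']
print(forecast_month_starts("2024-06-01", horizon_months=3))
```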
