
## **Forecast Reconciliation for Probabilistic Models: Ensuring Consistency in Hierarchical Predictions**

---

### **Introduction**  

Forecasting is **never perfect**. Whether we are predicting demand in supply chains, climate patterns, or violent conflict, forecasts are often produced at **multiple levels** of aggregation. 
For example:  
🌍 **Country-Level Prediction:** How many people will die in a country?  
🌐 **Grid-Level Predictions:** How many people will die in each sub-national region (grid cell)?  

A **common problem** occurs when the sum of regional forecasts **does not match** the national forecast. This happens because forecasts are made **independently** at each level.

📉 **Traditional Forecast Reconciliation Approaches:**  
✔ **Top-down approach:** Start from the country-level forecast and allocate it downward (e.g., using historical proportions).  
✔ **Bottom-up approach:** Sum the regional predictions to get the national forecast.  
✔ **MinT (Minimum Trace reconciliation):** Uses the estimated covariance of historical base-forecast errors to find the reconciled forecasts with the smallest total error variance.  
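To make the point-forecast methods concrete, here is a minimal NumPy sketch for a two-level hierarchy (one country, `n` grid cells). It assumes a summing matrix $S$ that stacks the country total on top of the identity, and uses the MinT formula $\tilde{y} = S(S^\top W^{-1} S)^{-1} S^\top W^{-1} \hat{y}$ from Wickramasuriya et al. (2019), where $W$ is the covariance of the base-forecast errors. The function names and the diagonal `W` are illustrative assumptions; in practice `W` is estimated from historical errors.

```python
import numpy as np

def summing_matrix(n_cells):
    """Rows: [country total; cell 1; ...; cell n]. Columns: grid cells."""
    return np.vstack([np.ones((1, n_cells)), np.eye(n_cells)])

def bottom_up(cell_forecasts):
    """Coherent forecasts obtained by summing the cell-level forecasts."""
    S = summing_matrix(len(cell_forecasts))
    return S @ cell_forecasts

def mint(base_forecasts, W):
    """MinT reconciliation: S (S' W^-1 S)^-1 S' W^-1 y_hat.

    base_forecasts: [country, cell_1, ..., cell_n], possibly incoherent.
    W: (n+1, n+1) covariance of the base-forecast errors (assumed known here).
    """
    n_cells = len(base_forecasts) - 1
    S = summing_matrix(n_cells)
    W_inv = np.linalg.inv(W)
    # G maps all base forecasts to reconciled bottom-level (cell) forecasts
    G = np.linalg.solve(S.T @ W_inv @ S, S.T @ W_inv)
    return S @ (G @ base_forecasts)

# Incoherent base forecasts: the cells sum to 10, the country forecast says 12
y_hat = np.array([12.0, 2.0, 3.0, 5.0])
W = np.diag([4.0, 1.0, 1.0, 1.0])  # hypothetical error variances

print(bottom_up(y_hat[1:]))  # country total is 10: the country forecast is ignored
rec = mint(y_hat, W)
print(rec)                   # coherent: rec[0] equals rec[1:].sum()
```

Both outputs are coherent, but each is a single vector of point forecasts; nothing here says how to adjust a whole set of posterior draws.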

👎 **Problem with these methods?**  
They were designed for **point forecasts** and do not, on their own, reconcile **probabilistic models**, where we need to adjust **full distributions** rather than just mean values.

💡 **Our Solution:**  
- Adjust **each sample independently** rather than just the mean.  
- Use **Quadratic Programming (QP)** to make **the smallest possible adjustments** while enforcing the sum constraint.  
- Ensure **zero-inflation is preserved**, so cells with a zero forecast **stay zero**.  

---

### **How Does Our Method Work?**
Instead of applying **simple scaling**, we solve the following optimization problem **for each posterior draw**:  
$$
\min_{x'^{(s)}} \lVert x'^{(s)} - x^{(s)} \rVert_2^2
$$
subject to:  
$$
\sum_{i} x'^{(s)}_i = y^{(s)}, \qquad x'^{(s)}_i \geq 0, \qquad x'^{(s)}_i = 0 \ \text{ whenever } x^{(s)}_i = 0
$$
where:  
- $x^{(s)}$ is the vector of **original grid-cell forecasts** in sample $s$, with entries $x^{(s)}_i$ for cell $i$.  
- $x'^{(s)}$ is the **adjusted grid-cell forecast** for sample $s$.  
- $y^{(s)}$ is the **country-level forecast** for sample $s$.  
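For a single sample this problem has a closed-form structure. The derivation below uses the standard KKT conditions for a Euclidean projection onto a simplex (it is not taken from the original notebook), writing $\mathcal{I}^{(s)} = \{i : x^{(s)}_i > 0\}$ for the non-zero cells, $\lambda$ for the multiplier of the sum constraint and $\mu_i \ge 0$ for the non-negativity multipliers:

$$
\begin{aligned}
\mathcal{L} &= \sum_{i \in \mathcal{I}^{(s)}} \bigl(x'^{(s)}_i - x^{(s)}_i\bigr)^2 - 2\lambda\Bigl(\sum_{i \in \mathcal{I}^{(s)}} x'^{(s)}_i - y^{(s)}\Bigr) - \sum_{i \in \mathcal{I}^{(s)}} \mu_i\, x'^{(s)}_i, \\
0 &= \frac{\partial \mathcal{L}}{\partial x'^{(s)}_i} = 2\bigl(x'^{(s)}_i - x^{(s)}_i\bigr) - 2\lambda - \mu_i, \qquad \mu_i\, x'^{(s)}_i = 0, \\
\Longrightarrow\; x'^{(s)}_i &= \max\bigl(x^{(s)}_i + \lambda,\ 0\bigr), \quad \text{with } \lambda \text{ chosen so that } \sum_{i \in \mathcal{I}^{(s)}} \max\bigl(x^{(s)}_i + \lambda,\ 0\bigr) = y^{(s)}.
\end{aligned}
$$

So every non-zero cell is shifted by the same amount $\lambda$ and clipped at zero, rather than multiplied by a common factor; $\lambda > 0$ when the country total exceeds the grid sum.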

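🛠 **How do we solve this?**  
The implementation below uses **L-BFGS** (`torch.optim.LBFGS`) followed by a projection step:  
- L-BFGS is a quasi-Newton method that converges quickly on smooth objectives such as this quadratic.  
- It operates on batched tensors, so all posterior draws are adjusted at once (on the GPU when `device='cuda'`).  
- L-BFGS itself is **unconstrained**: the sum and non-negativity constraints are enforced afterwards by rescaling each sample to $y^{(s)}$ and clipping at zero. Because the initial guess is a multiple of $x^{(s)}$ and the gradient $2(x'^{(s)} - x^{(s)})$ is then also a multiple of $x^{(s)}$, every iterate stays proportional to $x^{(s)}$, so after the final rescaling the output coincides (up to floating-point error) with proportional scaling of the non-zero cells. Reaching the minimal Euclidean adjustment requires a solver that handles the constraints directly.  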
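An exact alternative is sketched below in NumPy (not the notebook's PyTorch code). It relies on the fact that the per-sample minimizer shifts every non-zero cell by a common amount and clips at zero, $x'^{(s)}_i = \max(x^{(s)}_i + \lambda, 0)$, and finds $\lambda$ with the standard sort-based algorithm for Euclidean projection onto a simplex, vectorized over samples. The function name `reconcile_exact` is hypothetical; porting it to `torch` is a line-by-line translation.

```python
import numpy as np

def reconcile_exact(grid_samples, country_samples):
    """Hypothetical exact solver for the per-sample QP.

    For each sample s, solves
        min ||x' - x||^2  s.t.  sum(x') = y,  x' >= 0,  x'_i = 0 where x_i = 0,
    whose solution is x'_i = max(x_i + lam, 0) on the non-zero cells.

    grid_samples:    (num_samples, num_grid_cells) non-negative array
    country_samples: (num_samples,) non-negative array
    """
    x = np.asarray(grid_samples, dtype=float)
    y = np.asarray(country_samples, dtype=float)[:, None]
    mask = x > 0                                  # cells allowed to be non-zero
    n_cells = x.shape[1]

    # Sort each row in descending order; the non-zero cells come first
    u = -np.sort(-np.where(mask, x, 0.0), axis=1)
    css = np.cumsum(u, axis=1)
    j = np.arange(1, n_cells + 1)
    k = mask.sum(axis=1, keepdims=True)           # number of non-zero cells

    # rho = largest j (within the support) with u_j + (y - css_j) / j > 0
    cond = (u + (y - css) / j > 0) & (j <= k)
    rho = np.where(cond, j, 0).max(axis=1, keepdims=True)
    rho_safe = np.maximum(rho, 1)                 # avoid division by zero for empty rows

    # Common shift lam so that the shifted, clipped cells sum to y
    lam = (y - np.take_along_axis(css, rho_safe - 1, axis=1)) / rho_safe
    return np.where(mask, np.maximum(x + lam, 0.0), 0.0)

# Usage: one sample scaled up, one scaled down
x = np.array([[0.0, 2.0, 3.0, 5.0],
              [1.0, 2.0, 10.0, 0.0]])
y = np.array([12.0, 6.0])
x_rec = reconcile_exact(x, y)
print(x_rec)
print(x_rec.sum(axis=1))  # per-sample totals match y
```

Design note: unlike proportional scaling, a large downward adjustment can push small cells to exactly zero (the second sample above), which is the price of minimizing the absolute squared change. If the country total is positive but a sample has no non-zero cells, the problem is infeasible and this sketch returns zeros.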

---

### **📌 Why Not Just Scale Each Sample?**
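A simple rescaling approach:  
$$
x'^{(s)}_i = x^{(s)}_i \times \frac{y^{(s)}}{\sum_j x^{(s)}_j}
$$
❌ **Why does this fail?**  
- It applies the same **relative** change to every cell in a sample, so the largest cells absorb most of the adjustment; because the factor $y^{(s)} / \sum_j x^{(s)}_j$ varies across samples, it can also **alter the shape of the distribution** (variance and skewness).  
- It **does not minimize** the squared adjustment $\lVert x'^{(s)} - x^{(s)} \rVert_2^2$ in the adjusted samples.  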

✅ **Our method ensures:**  
- **Sum consistency** (grid-level samples add up to the country total).  
- **Minimal adjustment** to the original distribution.  
- **Zero-inflation preservation** (cells with a zero forecast stay zero).  
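One way to see the difference between the two approaches: proportional scaling is itself the solution of a least-squares problem, just with a different weighting. Writing $\mathcal{I}^{(s)} = \{i : x^{(s)}_i > 0\}$ for the non-zero cells and minimizing a *relative* squared error gives:

$$
\begin{aligned}
&\min_{x'^{(s)}} \sum_{i \in \mathcal{I}^{(s)}} \frac{\bigl(x'^{(s)}_i - x^{(s)}_i\bigr)^2}{x^{(s)}_i}
\quad \text{s.t.} \quad \sum_{i \in \mathcal{I}^{(s)}} x'^{(s)}_i = y^{(s)} \\
&\Longrightarrow\; \frac{2\bigl(x'^{(s)}_i - x^{(s)}_i\bigr)}{x^{(s)}_i} = 2\lambda
\;\Longrightarrow\; x'^{(s)}_i = (1 + \lambda)\, x^{(s)}_i = x^{(s)}_i \, \frac{y^{(s)}}{\sum_{j} x^{(s)}_j}
\end{aligned}
$$

where the last step uses the constraint, $(1+\lambda)\sum_j x^{(s)}_j = y^{(s)}$. Scaling therefore penalizes changes relative to each cell's size, whereas the unweighted objective $\lVert x'^{(s)} - x^{(s)} \rVert_2^2$ spreads the adjustment as an equal additive shift across the non-zero cells.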

---

### **Real-World Applications**

🚩 **Conflict Forecasting (our use case)**
- Forecasting violence at both the local and the country level, while ensuring that the sub-national forecasts are consistent with the national ones.  

📦 **Supply Chain Forecasting**  
- Predicting **regional demand** while ensuring forecasts match **national supply constraints**.  

🌎 **Climate Modeling**  
- Forecasting **rainfall or temperature** at the grid level while keeping consistency with national/global climate models.  

⚡ **Energy Demand Forecasting**  
- Regional electricity demand forecasts must match the total power generated in a country.  

📊 **Financial Forecasting**  
- Predicting **branch-level revenue** that must sum to a company's **total projected earnings**.  

---


### **📌 Key Takeaways**
✅ **Forecast reconciliation is essential for hierarchical predictions.**  
✅ **Point-forecast methods do not reconcile full distributions on their own; our approach adjusts distributions, not just means.**  
✅ **Quadratic optimization minimizes distortions while ensuring sum consistency.**  
✅ **Real-world applications include conflict forecasting, supply chains, climate modeling, energy demand, and financial forecasting.**  

---

### **🔗 Further Reading**
- [**Optimal Forecast Reconciliation for Hierarchical and Grouped Time Series Through Trace Minimization** - Wickramasuriya et al. (2019)](https://www.tandfonline.com/doi/full/10.1080/01621459.2018.1448825)
- [**Probabilistic forecast reconciliation: Properties, evaluation and score optimisation** - Panagiotelis et al. (2023)](https://www.sciencedirect.com/science/article/pii/S0377221722006087)
- [**Optimal combination forecasts for hierarchical time series** - Hyndman et al. (2011)](https://www.sciencedirect.com/science/article/pii/S0167947311000971)
- [**Forecasting: Principles and Practice, 3rd ed.** - Hyndman and Athanasopoulos (2021)](https://otexts.com/fpp3/rec-prob.html)



# Class

In [1]:
from forecast_reconciler import ForecastReconciler

import logging

# Configure logging (only needed if it's not already configured)
logging.basicConfig(
    level=logging.INFO, 
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

In [2]:
reconciler = ForecastReconciler(device='cuda')

2025-02-18 12:40:15,730 - forecast_reconciler - INFO - Using device: cuda


In [4]:
reconciler.run_tests_probabilistic()

2025-02-18 12:40:21,452 - forecast_reconciler - INFO - 
🧪 Running Tests on Forecast Reconciliation...

2025-02-18 12:40:21,454 - forecast_reconciler - INFO - 🔹 Running Test: Basic Reconciliation
2025-02-18 12:40:24,153 - forecast_reconciler - INFO -    ✅ Completed in 1.154 sec
2025-02-18 12:40:24,171 - forecast_reconciler - INFO -    🔍 Max Sum Difference: 0.0007324219
2025-02-18 12:40:24,171 - forecast_reconciler - INFO -    🔍 Zeros Correctly Preserved: True

2025-02-18 12:40:24,172 - forecast_reconciler - INFO - 🔹 Running Test: All Zeros (Should Stay Zero)
2025-02-18 12:40:24,175 - forecast_reconciler - INFO -    ✅ Completed in 0.001 sec
2025-02-18 12:40:24,175 - forecast_reconciler - INFO -    🔍 Max Sum Difference: 0.0000000000
2025-02-18 12:40:24,176 - forecast_reconciler - INFO -    🔍 Zeros Correctly Preserved: True

2025-02-18 12:40:24,176 - forecast_reconciler - INFO - 🔹 Running Test: Extreme Skew (Right-Tailed)
2025-02-18 12:40:24,179 - forecast_reconciler - INFO -    ✅ Complete

In [None]:
import torch
import time

class ForecastReconciler:
    """
    A class for reconciling hierarchical forecasts at the country and grid levels.
    
    Supports:
    - Probabilistic forecast reconciliation (adjusting posterior samples).
    - Point estimate reconciliation (for deterministic forecasts).
    - Automatic validation tests for correctness.
    """

    def __init__(self, device=None):
        """
        Initializes the ForecastReconciler class.

        Args:
            device (str, optional): "cuda" for GPU acceleration, "cpu" otherwise. Defaults to auto-detect.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

    def reconcile_probabilistic(self, grid_forecast_samples, country_forecast_samples, lr=0.01, max_iters=500, tol=1e-6):
        """
        Adjusts grid-level probabilistic forecasts to match the country-level forecasts using per-sample quadratic optimization.
        
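        Args:
            grid_forecast_samples (torch.Tensor): (num_samples, num_grid_cells) posterior samples.
            country_forecast_samples (torch.Tensor): (num_samples,) country-level forecast samples.
            lr (float): Learning rate passed to torch.optim.LBFGS.
            max_iters (int): Maximum number of L-BFGS iterations.
            tol (float): Gradient tolerance (tolerance_grad) for L-BFGS.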
        
        Returns:
            torch.Tensor: Adjusted grid forecasts with sum-matching per sample.
        """
        grid_forecast_samples = grid_forecast_samples.clone().float().to(self.device)
        country_forecast_samples = country_forecast_samples.clone().float().to(self.device)

        assert grid_forecast_samples.shape[0] == country_forecast_samples.shape[0], "Mismatch in sample count"

        # Identify nonzero values (to preserve zeros)
        mask_nonzero = grid_forecast_samples > 0
        nonzero_values = grid_forecast_samples.clone()
        nonzero_values[~mask_nonzero] = 0  # Ensure zero values remain unchanged

        # Initial proportional scaling
        sum_nonzero = nonzero_values.sum(dim=1, keepdim=True)
        scaling_factors = country_forecast_samples.view(-1, 1) / (sum_nonzero + 1e-8)
        adjusted_values = nonzero_values * scaling_factors
        adjusted_values = adjusted_values.clone().detach().requires_grad_(True)

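        # Optimizer (L-BFGS). Note: the objective below is unconstrained and its minimizer is
        # nonzero_values itself, so the sum and non-negativity constraints are enforced only
        # by the proportional rescaling and clipping in the projection step below.
        optimizer = torch.optim.LBFGS([adjusted_values], lr=lr, max_iter=max_iters, tolerance_grad=tol)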

        def closure():
            optimizer.zero_grad()
            loss = torch.sum((adjusted_values - nonzero_values) ** 2)
            loss.backward()
            return loss

        optimizer.step(closure)

        # Projection Step: Enforce sum constraint
        with torch.no_grad():
            sum_adjusted = adjusted_values.sum(dim=1, keepdim=True)
            scaling_factors = country_forecast_samples.view(-1, 1) / (sum_adjusted + 1e-8)
            adjusted_values *= scaling_factors
            adjusted_values.clamp_(min=0)

        # Preserve zero values
        final_adjusted = grid_forecast_samples.clone()
        final_adjusted[mask_nonzero] = adjusted_values[mask_nonzero].detach()

        return final_adjusted

    def reconcile_point_forecasts(self, grid_forecast, country_forecast):
        """
        Adjusts grid-level point forecasts to match the country-level forecast via simple scaling.

        Args:
            grid_forecast (torch.Tensor): (num_grid_cells,) point forecast.
            country_forecast (float): Country-level forecast.
        
        Returns:
            torch.Tensor: Adjusted grid forecasts summing exactly to country_forecast.
        """
        grid_forecast = grid_forecast.clone().float().to(self.device)
        assert country_forecast >= 0, "Country forecast must be non-negative"

        # Avoid division by zero
        if grid_forecast.sum() == 0:
            return grid_forecast

        return grid_forecast * (country_forecast / grid_forecast.sum())

    def run_tests(self):
        """
        Runs a suite of validation tests to ensure correct implementation.
        """
        print("\n🧪 Running Tests on Forecast Reconciliation...\n")

        test_cases = [
            {"name": "Standard case", "num_samples": 1000, "num_grid_cells": 100, "zero_fraction": 0.7, "scaling_factor": 1.2},
            {"name": "All zeros (should remain zero)", "num_samples": 1000, "num_grid_cells": 100, "zero_fraction": 1.0, "scaling_factor": 1.2},
            {"name": "Extreme skew (heavy right tail)", "num_samples": 1000, "num_grid_cells": 100, "zero_fraction": 0.3, "scaling_factor": 10},
            {"name": "Large scale test", "num_samples": 10000, "num_grid_cells": 500, "zero_fraction": 0.5, "scaling_factor": 1.1}
        ]

        for test in test_cases:
            print(f"🔹 {test['name']}")

            num_samples = test["num_samples"]
            num_grid_cells = test["num_grid_cells"]

            zero_mask = torch.rand((num_samples, num_grid_cells)) < test["zero_fraction"]
            grid_forecast_samples = torch.randint(1, 100, (num_samples, num_grid_cells), dtype=torch.float32)
            grid_forecast_samples[zero_mask] = 0

            country_forecast_samples = grid_forecast_samples.sum(dim=1) * test["scaling_factor"]
            grid_forecast_samples, country_forecast_samples = grid_forecast_samples.to(self.device), country_forecast_samples.to(self.device)

            start_time = time.time()
            adjusted_grid_forecast_samples = self.reconcile_probabilistic(grid_forecast_samples, country_forecast_samples)
            end_time = time.time()

            print(f"   ✅ Completed in {end_time - start_time:.3f} sec")

            # Validation checks
            sum_diff = torch.abs(adjusted_grid_forecast_samples.sum(dim=1) - country_forecast_samples).max().item()
            assert sum_diff < 1e-2, "❌ Sum constraint violated!"

            # Every cell that was zero in the input must still be zero after reconciliation
            zero_preservation = torch.all(adjusted_grid_forecast_samples[grid_forecast_samples == 0] == 0).item()
            assert zero_preservation, "❌ Zero-inflation not preserved!"

            print(f"   🔍 Max Sum Difference: {sum_diff:.10f}")
            print(f"   🔍 Zeros Correctly Preserved: {zero_preservation}\n")

        print("\n✅ All Tests Passed Successfully!")

# ✅ **Example Usage**
if __name__ == "__main__":
    reconciler = ForecastReconciler()

    # Generate test data
    num_samples, num_grid_cells = 1000, 100
    zero_mask = torch.rand((num_samples, num_grid_cells)) < 0.7
    grid_forecast_samples = torch.randint(1, 100, (num_samples, num_grid_cells), dtype=torch.float32)
    grid_forecast_samples[zero_mask] = 0
    country_forecast_samples = grid_forecast_samples.sum(dim=1) * 1.2

    # Run reconciliation
    adjusted_forecasts = reconciler.reconcile_probabilistic(grid_forecast_samples, country_forecast_samples)
    
    print("✅ Adjusted Forecasts Ready!")
    reconciler.run_tests()  # Run all validation tests


# old

In [None]:
import torch

def reconcile_forecast_torch(grid_forecast: torch.Tensor, country_forecast: float, lr=0.01, max_iters=1000, tol=1e-6):
    """
    Adjusts grid-level forecasts in PyTorch while preserving zero values and ensuring
    the sum matches the country-level forecast. Final output is rounded to integers with high precision.
    
    Parameters:
        grid_forecast (torch.Tensor): Original grid-level forecasts (non-negative).
        country_forecast (float): The total forecast for the country.
        lr (float): Learning rate for optimization.
        max_iters (int): Maximum iterations for gradient descent.
        tol (float): Convergence tolerance for precise adjustment.

    Returns:
        torch.Tensor: Adjusted grid forecasts (integer counts) summing to country_forecast.
    """
    # Ensure input tensor is float and non-negative
    grid_forecast = grid_forecast.clone().float()
    assert torch.all(grid_forecast >= 0), "Grid forecasts must be non-negative"
    assert country_forecast >= 0, "Country forecast must be non-negative"

    # Identify nonzero elements
    mask_nonzero = grid_forecast > 0
    nonzero_values = grid_forecast[mask_nonzero]

    # If all values are zero, return unchanged
    if len(nonzero_values) == 0:
        return grid_forecast

    # Initial guess: proportional scaling for nonzero values
    adjusted_values = nonzero_values * (country_forecast / max(nonzero_values.sum(), 1e-8))
    adjusted_values = adjusted_values.clone().detach().requires_grad_(True)

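    # Use LBFGS optimizer. Note: the objective is unconstrained (its minimizer is nonzero_values),
    # so the sum constraint is enforced only by the projection step below.
    optimizer = torch.optim.LBFGS([adjusted_values], lr=lr, max_iter=max_iters, tolerance_grad=tol)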

    def closure():
        optimizer.zero_grad()
        loss = torch.sum((adjusted_values - nonzero_values) ** 2)
        loss.backward()
        return loss

    optimizer.step(closure)

    # Projection Step: Ensure sum constraint and non-negativity
    with torch.no_grad():
        scaling_factor = country_forecast / max(adjusted_values.sum(), 1e-8)
        adjusted_values *= scaling_factor  # Scale to match country total
        adjusted_values.clamp_(min=0)  # Ensure non-negativity

    # **Step 2: Round to integers with the largest-remainder method**
    with torch.no_grad():
        floored = adjusted_values.floor()
        fractional_parts = adjusted_values - floored

        # Units still needed to reach the (integer-rounded) country total;
        # clamp to [0, n] to guard against floating-point drift
        target_total = round(float(country_forecast))
        remainder = int(target_total - floored.sum().item())
        remainder = max(0, min(remainder, floored.numel()))

        # Give one extra unit to the cells with the largest fractional parts
        rounded_values = floored.clone()
        if remainder > 0:
            top_indices = torch.argsort(fractional_parts, descending=True)[:remainder]
            rounded_values[top_indices] += 1

    # Create final adjusted forecast (preserve zero values)
    adjusted_forecast = grid_forecast.clone()
    adjusted_forecast[mask_nonzero] = rounded_values.detach()

    return adjusted_forecast.long()  # Convert to integer tensor

# ✅ **Test with a Zero-Inflated Distribution**
#torch.manual_seed(42)  # For reproducibility

# Generate a highly zero-inflated dataset
num_grid_cells = 100  # Number of grid cells
zero_mask = torch.rand(num_grid_cells) < 0.7  # ~70% zeros
grid_forecast_torch = torch.randint(1, 100, (num_grid_cells,), dtype=torch.float32)  # Uniform integers in [1, 99]
grid_forecast_torch[zero_mask] = 0  # Apply zero-inflation

country_forecast_torch = grid_forecast_torch.sum().item() * 1.2  # 20% over-forecast at the country level

# Run reconciliation
adjusted_grid_forecast_torch = reconcile_forecast_torch(grid_forecast_torch, country_forecast_torch)

# ✅ **Results**
print("\n🔹 Original Grid Forecasts:", grid_forecast_torch.numpy())
print("\n🔹 Adjusted Grid Forecasts:", adjusted_grid_forecast_torch.numpy())
print("\n🔹 Sum of Adjusted Forecasts:", adjusted_grid_forecast_torch.sum().item())  # Integer total, close to country_forecast
print("\n🔹 Country Forecast:", country_forecast_torch)  # May be fractional (1.2 × an integer sum)


# Distribution:

In [None]:
import torch

def reconcile_forecast_samples_torch(grid_forecast_samples, country_forecast_samples, 
                                     lr=0.01, max_iters=500, tol=1e-6):
    """
    Reconciles grid-level forecast samples to match the country-level forecasts.
    
    Parameters:
        grid_forecast_samples (torch.Tensor): Shape (num_samples, num_grid_cells)
        country_forecast_samples (torch.Tensor): Shape (num_samples,)
    
    Returns:
        torch.Tensor: Adjusted grid forecasts maintaining sum consistency.
    """
    grid_forecast_samples = grid_forecast_samples.clone().float()
    country_forecast_samples = country_forecast_samples.clone().float()

    # Ensure forecasts are non-negative
    assert torch.all(grid_forecast_samples >= 0), "Grid forecasts must be non-negative"
    assert torch.all(country_forecast_samples >= 0), "Country forecasts must be non-negative"

    mask_nonzero = grid_forecast_samples > 0  # Mask nonzero values
    nonzero_values = grid_forecast_samples.clone()
    nonzero_values[~mask_nonzero] = 0

    # Initial proportional scaling
    sum_nonzero = nonzero_values.sum(dim=1, keepdim=True)
    scaling_factors = country_forecast_samples.view(-1, 1) / (sum_nonzero + 1e-8)
    adjusted_values = nonzero_values * scaling_factors
    adjusted_values = adjusted_values.clone().detach().requires_grad_(True)

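    # Optimizer: L-BFGS. Note: the objective is unconstrained (its minimizer is nonzero_values),
    # so the sum constraint is enforced only by the projection step below.
    optimizer = torch.optim.LBFGS([adjusted_values], lr=lr, max_iter=max_iters, tolerance_grad=tol)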

    def closure():
        optimizer.zero_grad()
        loss = torch.sum((adjusted_values - nonzero_values) ** 2)
        loss.backward()
        return loss

    optimizer.step(closure)

    # Projection Step: Final sum correction and non-negativity enforcement
    with torch.no_grad():
        sum_adjusted = adjusted_values.sum(dim=1, keepdim=True)
        scaling_factors = country_forecast_samples.view(-1, 1) / (sum_adjusted + 1e-8)
        adjusted_values *= scaling_factors
        adjusted_values.clamp_(min=0)

    final_adjusted = grid_forecast_samples.clone()
    final_adjusted[mask_nonzero] = adjusted_values[mask_nonzero].detach()

    return final_adjusted


In [None]:

# ✅ **Comprehensive Testing**
def run_tests():
    torch.manual_seed(42)  # For reproducibility
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("\n🧪 Running Tests on Forecast Reconciliation...\n")

    test_cases = [
        {
            "name": "Standard case (1000 samples, 100 grid cells)",
            "num_samples": 1000,
            "num_grid_cells": 100,
            "zero_fraction": 0.7,
            "scaling_factor": 1.2
        },
        {
            "name": "All zeros (should remain zero)",
            "num_samples": 1000,
            "num_grid_cells": 100,
            "zero_fraction": 1.0,  # All zeros
            "scaling_factor": 1.2
        },
        {
            "name": "Extreme skew (heavy right tail)",
            "num_samples": 1000,
            "num_grid_cells": 100,
            "zero_fraction": 0.3,  # Some zeros
            "scaling_factor": 10  # Extreme upscaling
        },
        {
            "name": "Large scale (10,000 samples, 500 grid cells)",
            "num_samples": 10000,
            "num_grid_cells": 500,
            "zero_fraction": 0.5,
            "scaling_factor": 1.1
        }
    ]

    for test in test_cases:
        print(f"🔹 {test['name']}")

        num_samples = test["num_samples"]
        num_grid_cells = test["num_grid_cells"]

        zero_mask = torch.rand((num_samples, num_grid_cells)) < test["zero_fraction"]
        grid_forecast_samples = torch.randint(1, 100, (num_samples, num_grid_cells), dtype=torch.float32)
        grid_forecast_samples[zero_mask] = 0  # Apply zero-inflation

        country_forecast_samples = grid_forecast_samples.sum(dim=1) * test["scaling_factor"]

        # Move tensors to GPU if available
        grid_forecast_samples = grid_forecast_samples.to(device)
        country_forecast_samples = country_forecast_samples.to(device)

        import time
        start_time = time.time()

        # Run reconciliation
        adjusted_grid_forecast_samples = reconcile_forecast_samples_torch(grid_forecast_samples, country_forecast_samples)

        end_time = time.time()
        print(f"   ✅ Completed in {end_time - start_time:.3f} sec")

        # **Validation Checks**
        sum_diff = torch.abs(adjusted_grid_forecast_samples.sum(dim=1) - country_forecast_samples).max().item()
        assert sum_diff < 1e-2, "❌ Sum constraint violated!"

        # Every cell that was zero in the input must still be zero after reconciliation
        zero_preservation = torch.all(adjusted_grid_forecast_samples[grid_forecast_samples == 0] == 0).item()
        assert zero_preservation, "❌ Zero-inflation not preserved!"

        print(f"   🔍 Max Sum Difference: {sum_diff:.6f}")
        print(f"   🔍 Zeros Correctly Preserved: {zero_preservation}\n")

    print("\n✅ All Tests Passed Successfully!")



In [None]:
import torch
import time

# **Step 1: Set up device (CPU/GPU)**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Running on: {device}")

# **Step 2: Generate Posterior Forecast Samples**
num_samples = 1000  # 1000 posterior draws
num_grid_cells = 100  # 100 spatial grid cells

torch.manual_seed(42)  # Ensure reproducibility

# Generate zero-inflated forecast samples (70% zeros)
zero_mask = torch.rand((num_samples, num_grid_cells)) < 0.7  
grid_forecast_samples = torch.randint(1, 100, (num_samples, num_grid_cells), dtype=torch.float32)
grid_forecast_samples[zero_mask] = 0  # Apply zero-inflation

# Compute country-level forecasts (20% over-forecast)
country_forecast_samples = grid_forecast_samples.sum(dim=1) * 1.2

# Move data to GPU if available
grid_forecast_samples = grid_forecast_samples.to(device)
country_forecast_samples = country_forecast_samples.to(device)

# **Step 3: Run Reconciliation and tests**
print("\n🔄 Adjusting Posterior Samples...")
start_time = time.time()

adjusted_grid_forecast_samples = reconcile_forecast_samples_torch(grid_forecast_samples, country_forecast_samples)

end_time = time.time()
print(f"✅ Adjustment Completed in {end_time - start_time:.3f} seconds!")

# Run tests
run_tests()

# **Step 4: Verify Results**
print("\n🔍 Sample Results:")
print("Original Grid Forecast Sum (First 5 Samples):", grid_forecast_samples.sum(dim=1)[:5].cpu().numpy())
print("Adjusted Grid Forecast Sum (First 5 Samples):", adjusted_grid_forecast_samples.sum(dim=1)[:5].cpu().numpy())
print("Country Forecasts (First 5 Samples):", country_forecast_samples[:5].cpu().numpy())

# **Step 5: Validations**
max_sum_diff = torch.abs(adjusted_grid_forecast_samples.sum(dim=1) - country_forecast_samples).max().item()
assert max_sum_diff < 1e-2, "❌ Sum constraint violated!"

# Every cell that was zero in the input must still be zero after reconciliation
zero_preserved = torch.all(adjusted_grid_forecast_samples[grid_forecast_samples == 0] == 0).item()
assert zero_preserved, "❌ Zero-inflation not preserved!"

print(f"\n🎯 Final Checks:")
print(f"✅ Max Sum Difference: {max_sum_diff:.10f}")
print(f"✅ Zero Values Correctly Preserved: {zero_preserved}")
print("\n🎉 Success! Posterior reconciliation is working perfectly!\n")