In [1]:
import torch

def reconcile_forecast_torch(grid_forecast: torch.Tensor, country_forecast: float, lr=0.01, max_iters=1000, tol=1e-6):
    """
    Adjusts grid-level forecasts in PyTorch while preserving zero values and ensuring
    the sum matches the country-level forecast. Final output is rounded to integers with high precision.
    
    Parameters:
        grid_forecast (torch.Tensor): Original grid-level forecasts (non-negative).
        country_forecast (float): The total forecast for the country.
        lr (float): Learning rate for optimization.
        max_iters (int): Maximum iterations for gradient descent.
        tol (float): Convergence tolerance for precise adjustment.

    Returns:
        torch.Tensor: Adjusted grid forecasts (integer counts) summing to country_forecast.
    """
    # Ensure input tensor is float and non-negative
    grid_forecast = grid_forecast.clone().float()
    assert torch.all(grid_forecast >= 0), "Grid forecasts must be non-negative"
    assert country_forecast >= 0, "Country forecast must be non-negative"

    # Identify nonzero elements
    mask_nonzero = grid_forecast > 0
    nonzero_values = grid_forecast[mask_nonzero]

    # If all values are zero, return unchanged
    if len(nonzero_values) == 0:
        return grid_forecast

    # Initial guess: proportional scaling for nonzero values
    adjusted_values = nonzero_values * (country_forecast / max(nonzero_values.sum(), 1e-8))
    adjusted_values = adjusted_values.clone().detach().requires_grad_(True)

    # Use LBFGS optimizer for more precise optimization
    optimizer = torch.optim.LBFGS([adjusted_values], lr=lr, max_iter=max_iters, tolerance_grad=tol)

    def closure():
        optimizer.zero_grad()
        loss = torch.sum((adjusted_values - nonzero_values) ** 2)
        loss.backward()
        return loss

    optimizer.step(closure)

    # Projection Step: Ensure sum constraint and non-negativity
    with torch.no_grad():
        scaling_factor = country_forecast / max(adjusted_values.sum(), 1e-8)
        adjusted_values *= scaling_factor  # Scale to match country total
        adjusted_values.clamp_(min=0)  # Ensure non-negativity

    # **Step 2: Round values with precise sum adjustment**
    with torch.no_grad():
        rounded_values = adjusted_values.round()  # Round to nearest integer

        # Compute rounding error
        rounding_error = int(country_forecast - rounded_values.sum())

        if rounding_error != 0:
            # Compute fractional parts
            fractional_parts = adjusted_values - adjusted_values.floor()

            # Sort indices by largest fractional part to minimize distortion
            sorted_indices = torch.argsort(fractional_parts, descending=True)

            # Redistribute rounding error
            for i in range(abs(rounding_error)):
                idx = sorted_indices[i % len(sorted_indices)]
                if rounding_error > 0:
                    rounded_values[idx] += 1  # Add 1 to highest fractional value
                else:
                    rounded_values[idx] -= 1  # Subtract 1 from lowest fractional value

    # Create final adjusted forecast (preserve zero values)
    adjusted_forecast = grid_forecast.clone()
    adjusted_forecast[mask_nonzero] = rounded_values.detach()

    return adjusted_forecast.long()  # Convert to integer tensor

# ✅ **Test with a Large Zero-Inflated Right-Skewed Distribution**
#torch.manual_seed(42)  # For reproducibility

# Generate a highly zero-inflated dataset
num_grid_cells = 100  # Large number of grid cells
zero_mask = torch.rand(num_grid_cells) < 0.7  # 70% zeros
grid_forecast_torch = torch.randint(1, 100, (num_grid_cells,), dtype=torch.float32)  # Right-skewed
grid_forecast_torch[zero_mask] = 0  # Apply zero-inflation

country_forecast_torch = grid_forecast_torch.sum().item() * 1.2  # 20% over-forecast at the country level

# Run reconciliation
adjusted_grid_forecast_torch = reconcile_forecast_torch(grid_forecast_torch, country_forecast_torch)

# ✅ **Results**
print("\n🔹 Original Grid Forecasts:", grid_forecast_torch.numpy())
print("\n🔹 Adjusted Grid Forecasts:", adjusted_grid_forecast_torch.numpy())
print("\n🔹 Sum of Adjusted Forecasts:", adjusted_grid_forecast_torch.sum().item())  # Should match country_forecast
print("\n🔹 Country Forecast:", country_forecast_torch)  # Should match sum of adjusted forecasts



🔹 Original Grid Forecasts: [ 0.  0.  0.  0.  0.  5.  0.  0.  0.  0.  0. 58.  0.  0.  0. 86.  0.  0.
  0.  0. 82.  0.  0.  0.  0. 63. 45.  0.  0. 28.  0.  0.  0. 75.  0. 50.
 13.  0.  0.  0. 39.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0. 54.  4. 53.  0. 64. 55.  0.  0.  0. 82.  0.  0. 98.  0.  0.
  0.  0.  0.  0.  0.  0. 43.  0.  0. 14.  0.  0.  0. 33. 97.  0. 66. 68.
 75.  0.  0.  0.  0.  0.  0.  9.  0.  0.]

🔹 Adjusted Grid Forecasts: [  0   0   0   0   0   6   0   0   0   0   0  70   0   0   0 103   0   0
   0   0  98   0   0   0   0  76  54   0   0  34   0   0   0  90   0  60
  16   0   0   0  46   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0  64   5  64   0  76  66   0   0   0  98   0   0 118   0   0
   0   0   0   0   0   0  52   0   0  17   0   0   0  40 116   0  79  82
  90   0   0   0   0   0   0  11   0   0]

🔹 Sum of Adjusted Forecasts: 1631

🔹 Country Forecast: 1630.8


# With distribution:

In [2]:
import torch

def reconcile_forecast_samples_torch(grid_forecast_samples: torch.Tensor, country_forecast_samples: torch.Tensor, 
                                     lr=0.01, max_iters=500, tol=1e-6):
    """
    Adjusts each posterior sample of grid-level forecasts to ensure sum consistency with the 
    corresponding country-level forecast, using per-sample quadratic optimization.
    
    Parameters:
        grid_forecast_samples (torch.Tensor): Tensor of shape (num_samples, num_grid_cells) 
                                              containing posterior samples for grid forecasts.
        country_forecast_samples (torch.Tensor): Tensor of shape (num_samples,) containing 
                                                 posterior samples for the country-level forecast.
        lr (float): Learning rate for optimization.
        max_iters (int): Maximum iterations for gradient descent.
        tol (float): Convergence tolerance for precise adjustment.

    Returns:
        torch.Tensor: Adjusted grid forecasts with sum-matching for each posterior sample.
    """
    # Ensure input is float and non-negative
    grid_forecast_samples = grid_forecast_samples.clone().float()
    country_forecast_samples = country_forecast_samples.clone().float()
    
    assert torch.all(grid_forecast_samples >= 0), "Grid forecasts must be non-negative"
    assert torch.all(country_forecast_samples >= 0), "Country forecasts must be non-negative"
    assert grid_forecast_samples.shape[0] == country_forecast_samples.shape[0], "Mismatch in sample count"

    # Identify nonzero values for each sample (boolean mask)
    mask_nonzero = grid_forecast_samples > 0  # Shape: (num_samples, num_grid_cells)

    # Extract nonzero values only for optimization
    nonzero_values = grid_forecast_samples.clone()
    nonzero_values[~mask_nonzero] = 0  # Set zero values explicitly

    # Initial proportional scaling for nonzero values
    sum_nonzero = nonzero_values.sum(dim=1, keepdim=True)  # Sum per sample
    scaling_factors = country_forecast_samples.view(-1, 1) / (sum_nonzero + 1e-8)
    adjusted_values = nonzero_values * scaling_factors  # Proportional scaling
    
    # Ensure requires_grad for optimization
    adjusted_values = adjusted_values.clone().detach().requires_grad_(True)

    # Optimizer: L-BFGS (better for constrained optimization)
    optimizer = torch.optim.LBFGS([adjusted_values], lr=lr, max_iter=max_iters, tolerance_grad=tol)

    def closure():
        optimizer.zero_grad()
        loss = torch.sum((adjusted_values - nonzero_values) ** 2)  # Minimize distortion
        loss.backward()
        return loss

    optimizer.step(closure)

    # Projection Step: Ensure sum constraint and non-negativity
    with torch.no_grad():
        sum_adjusted = adjusted_values.sum(dim=1, keepdim=True)
        scaling_factors = country_forecast_samples.view(-1, 1) / (sum_adjusted + 1e-8)
        adjusted_values *= scaling_factors  # Scale to match country total
        adjusted_values.clamp_(min=0)  # Ensure non-negativity

    # Corrected Assignment: Use Masked Indexing Properly
    final_adjusted = grid_forecast_samples.clone()
    final_adjusted[mask_nonzero] = adjusted_values[mask_nonzero].detach()

    return final_adjusted


# ✅ **Comprehensive Testing**
def run_tests():
    torch.manual_seed(42)  # For reproducibility
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("\n🧪 Running Tests on Forecast Reconciliation...\n")

    test_cases = [
        {
            "name": "Standard case (1000 samples, 100 grid cells)",
            "num_samples": 1000,
            "num_grid_cells": 100,
            "zero_fraction": 0.7,
            "scaling_factor": 1.2
        },
        {
            "name": "All zeros (should remain zero)",
            "num_samples": 1000,
            "num_grid_cells": 100,
            "zero_fraction": 1.0,  # All zeros
            "scaling_factor": 1.2
        },
        {
            "name": "Extreme skew (heavy right tail)",
            "num_samples": 1000,
            "num_grid_cells": 100,
            "zero_fraction": 0.3,  # Some zeros
            "scaling_factor": 10  # Extreme upscaling
        },
        {
            "name": "Large scale (10,000 samples, 500 grid cells)",
            "num_samples": 10000,
            "num_grid_cells": 500,
            "zero_fraction": 0.5,
            "scaling_factor": 1.1
        }
    ]

    for test in test_cases:
        print(f"🔹 {test['name']}")

        num_samples = test["num_samples"]
        num_grid_cells = test["num_grid_cells"]

        zero_mask = torch.rand((num_samples, num_grid_cells)) < test["zero_fraction"]
        grid_forecast_samples = torch.randint(1, 100, (num_samples, num_grid_cells), dtype=torch.float32)
        grid_forecast_samples[zero_mask] = 0  # Apply zero-inflation

        country_forecast_samples = grid_forecast_samples.sum(dim=1) * test["scaling_factor"]

        # Move tensors to GPU if available
        grid_forecast_samples = grid_forecast_samples.to(device)
        country_forecast_samples = country_forecast_samples.to(device)

        import time
        start_time = time.time()

        # Run reconciliation
        adjusted_grid_forecast_samples = reconcile_forecast_samples_torch(grid_forecast_samples, country_forecast_samples)

        end_time = time.time()
        print(f"   ✅ Completed in {end_time - start_time:.3f} sec")

        # **Validation Checks**
        sum_diff = torch.abs(adjusted_grid_forecast_samples.sum(dim=1) - country_forecast_samples).max().item()
        assert sum_diff < 1e-2, "❌ Sum constraint violated!"

        zero_preservation = torch.all(grid_forecast_samples == 0) == torch.all(adjusted_grid_forecast_samples == 0)
        assert zero_preservation, "❌ Zero-inflation not preserved!"

        print(f"   🔍 Max Sum Difference: {sum_diff:.6f}")
        print(f"   🔍 Zeros Correctly Preserved: {zero_preservation}\n")

    print("\n✅ All Tests Passed Successfully!")


# Run tests
run_tests()



🧪 Running Tests on Forecast Reconciliation...

🔹 Standard case (1000 samples, 100 grid cells)
   ✅ Completed in 0.445 sec
   🔍 Max Sum Difference: 0.000366
   🔍 Zeros Correctly Preserved: True

🔹 All zeros (should remain zero)
   ✅ Completed in 0.002 sec
   🔍 Max Sum Difference: 0.000000
   🔍 Zeros Correctly Preserved: True

🔹 Extreme skew (heavy right tail)
   ✅ Completed in 0.003 sec
   🔍 Max Sum Difference: 0.000000
   🔍 Zeros Correctly Preserved: True

🔹 Large scale (10,000 samples, 500 grid cells)
   ✅ Completed in 0.010 sec
   🔍 Max Sum Difference: 0.002930
   🔍 Zeros Correctly Preserved: True


✅ All Tests Passed Successfully!


In [6]:
import torch
import time

# **Step 1: Set up device (CPU/GPU)**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Running on: {device}")

# **Step 2: Generate Posterior Forecast Samples**
num_samples = 1000  # 1000 posterior draws
num_grid_cells = 100  # 100 spatial grid cells

torch.manual_seed(42)  # Ensure reproducibility

# Generate zero-inflated forecast samples (70% zeros)
zero_mask = torch.rand((num_samples, num_grid_cells)) < 0.7  
grid_forecast_samples = torch.randint(1, 100, (num_samples, num_grid_cells), dtype=torch.float32)
grid_forecast_samples[zero_mask] = 0  # Apply zero-inflation

# Compute country-level forecasts (120% over-forecast)
country_forecast_samples = grid_forecast_samples.sum(dim=1) * 1.2

# Move data to GPU if available
grid_forecast_samples = grid_forecast_samples.to(device)
country_forecast_samples = country_forecast_samples.to(device)

# **Step 3: Run Reconciliation**
print("\n🔄 Adjusting Posterior Samples...")
start_time = time.time()

adjusted_grid_forecast_samples = reconcile_forecast_samples_torch(grid_forecast_samples, country_forecast_samples)

end_time = time.time()
print(f"✅ Adjustment Completed in {end_time - start_time:.3f} seconds!")

# **Step 4: Verify Results**
print("\n🔍 Sample Results:")
print("Original Grid Forecast Sum (First 5 Samples):", grid_forecast_samples.sum(dim=1)[:5].cpu().numpy())
print("Adjusted Grid Forecast Sum (First 5 Samples):", adjusted_grid_forecast_samples.sum(dim=1)[:5].cpu().numpy())
print("Country Forecasts (First 5 Samples):", country_forecast_samples[:5].cpu().numpy())

# **Step 5: Validations**
max_sum_diff = torch.abs(adjusted_grid_forecast_samples.sum(dim=1) - country_forecast_samples).max().item()
assert max_sum_diff < 1e-2, "❌ Sum constraint violated!"

zero_preserved = torch.all(grid_forecast_samples == 0) == torch.all(adjusted_grid_forecast_samples == 0)
assert zero_preserved, "❌ Zero-inflation not preserved!"

print(f"\n🎯 Final Checks:")
print(f"✅ Max Sum Difference: {max_sum_diff:.10f}")
print(f"✅ Zero Values Correctly Preserved: {zero_preserved}")
print("\n🎉 Success! Posterior reconciliation is working perfectly!\n")

🚀 Running on: cuda

🔄 Adjusting Posterior Samples...
✅ Adjustment Completed in 0.008 seconds!

🔍 Sample Results:
Original Grid Forecast Sum (First 5 Samples): [1886.  983. 1558. 1235.  967.]
Adjusted Grid Forecast Sum (First 5 Samples): [2263.2002 1179.6001 1869.6001 1482.     1160.4   ]
Country Forecasts (First 5 Samples): [2263.2002 1179.6001 1869.6001 1482.     1160.4   ]

🎯 Final Checks:
✅ Max Sum Difference: 0.0003662109
✅ Zero Values Correctly Preserved: True

🎉 Success! Posterior reconciliation is working perfectly!

