In [37]:
import torch

def reconcile_forecast_torch(grid_forecast: torch.Tensor, country_forecast: float, lr=0.01, max_iters=1000, tol=1e-6):
    """
    Adjusts grid-level forecasts in PyTorch while preserving zero values and ensuring
    the sum matches the country-level forecast. Final output is rounded to integers with high precision.
    
    Parameters:
        grid_forecast (torch.Tensor): Original grid-level forecasts (non-negative).
        country_forecast (float): The total forecast for the country.
        lr (float): Learning rate for optimization.
        max_iters (int): Maximum iterations for gradient descent.
        tol (float): Convergence tolerance for precise adjustment.

    Returns:
        torch.Tensor: Adjusted grid forecasts (integer counts) summing to country_forecast.
    """
    # Ensure input tensor is float and non-negative
    grid_forecast = grid_forecast.clone().float()
    assert torch.all(grid_forecast >= 0), "Grid forecasts must be non-negative"
    assert country_forecast >= 0, "Country forecast must be non-negative"

    # Identify nonzero elements
    mask_nonzero = grid_forecast > 0
    nonzero_values = grid_forecast[mask_nonzero]

    # If all values are zero, return unchanged
    if len(nonzero_values) == 0:
        return grid_forecast

    # Initial guess: proportional scaling for nonzero values
    adjusted_values = nonzero_values * (country_forecast / max(nonzero_values.sum(), 1e-8))
    adjusted_values = adjusted_values.clone().detach().requires_grad_(True)

    # Use LBFGS optimizer for more precise optimization
    optimizer = torch.optim.LBFGS([adjusted_values], lr=lr, max_iter=max_iters, tolerance_grad=tol)

    def closure():
        optimizer.zero_grad()
        loss = torch.sum((adjusted_values - nonzero_values) ** 2)
        loss.backward()
        return loss

    optimizer.step(closure)

    # Projection Step: Ensure sum constraint and non-negativity
    with torch.no_grad():
        scaling_factor = country_forecast / max(adjusted_values.sum(), 1e-8)
        adjusted_values *= scaling_factor  # Scale to match country total
        adjusted_values.clamp_(min=0)  # Ensure non-negativity

    # **Step 2: Round values with precise sum adjustment**
    with torch.no_grad():
        rounded_values = adjusted_values.round()  # Round to nearest integer

        # Compute rounding error
        rounding_error = int(country_forecast - rounded_values.sum())

        if rounding_error != 0:
            # Compute fractional parts
            fractional_parts = adjusted_values - adjusted_values.floor()

            # Sort indices by largest fractional part to minimize distortion
            sorted_indices = torch.argsort(fractional_parts, descending=True)

            # Redistribute rounding error
            for i in range(abs(rounding_error)):
                idx = sorted_indices[i % len(sorted_indices)]
                if rounding_error > 0:
                    rounded_values[idx] += 1  # Add 1 to highest fractional value
                else:
                    rounded_values[idx] -= 1  # Subtract 1 from lowest fractional value

    # Create final adjusted forecast (preserve zero values)
    adjusted_forecast = grid_forecast.clone()
    adjusted_forecast[mask_nonzero] = rounded_values.detach()

    return adjusted_forecast.long()  # Convert to integer tensor

# ✅ **Test with a Large Zero-Inflated Right-Skewed Distribution**
#torch.manual_seed(42)  # For reproducibility

# Generate a highly zero-inflated dataset
num_grid_cells = 100  # Large number of grid cells
zero_mask = torch.rand(num_grid_cells) < 0.7  # 70% zeros
grid_forecast_torch = torch.randint(1, 100, (num_grid_cells,), dtype=torch.float32)  # Right-skewed
grid_forecast_torch[zero_mask] = 0  # Apply zero-inflation

country_forecast_torch = grid_forecast_torch.sum().item() * 1.2  # 20% over-forecast at the country level

# Run reconciliation
adjusted_grid_forecast_torch = reconcile_forecast_torch(grid_forecast_torch, country_forecast_torch)

# ✅ **Results**
print("\n🔹 Original Grid Forecasts:", grid_forecast_torch.numpy())
print("\n🔹 Adjusted Grid Forecasts:", adjusted_grid_forecast_torch.numpy())
print("\n🔹 Sum of Adjusted Forecasts:", adjusted_grid_forecast_torch.sum().item())  # Should match country_forecast
print("\n🔹 Country Forecast:", country_forecast_torch)  # Should match sum of adjusted forecasts



🔹 Original Grid Forecasts: [ 0. 71.  0.  0.  0. 73.  0.  0.  0. 53.  0.  0.  0.  0. 14.  0.  0.  0.
 37.  0.  0.  0. 78.  0. 31.  0. 30.  0.  0. 99. 62.  0.  0.  0.  0.  0.
  0.  0. 58.  0.  0.  0.  0.  0.  0.  0. 22.  0.  0.  0.  0.  0.  0.  0.
 29.  0.  4.  0. 22.  0.  0.  0. 35.  0.  0.  0.  0. 81.  0.  0.  0.  0.
  0. 55. 51. 80.  0.  0. 70.  0.  0.  0. 58. 79.  0. 48.  0.  0.  0.  0.
  0.  0. 12. 36.  0.  0.  0.  0.  0.  0.]

🔹 Adjusted Grid Forecasts: [  0  85   0   0   0  88   0   0   0  64   0   0   0   0  17   0   0   0
  44   0   0   0  94   0  37   0  36   0   0 119  74   0   0   0   0   0
   0   0  70   0   0   0   0   0   0   0  26   0   0   0   0   0   0   0
  35   0   5   0  26   0   0   0  42   0   0   0   0  97   0   0   0   0
   0  66  61  96   0   0  84   0   0   0  70  95   0  58   0   0   0   0
   0   0  14  43   0   0   0   0   0   0]

🔹 Sum of Adjusted Forecasts: 1546

🔹 Country Forecast: 1545.6
