# VN2 Submission: TimesFM 2.5 Quantile Forecasting Demo

**Purpose**: Demonstrate TimesFM 2.5's quantile head for VN2 inventory planning

**Approach**:
1. Load all 599 SKUs from VN2
2. Generate quantile forecasts (P10-P90) using TimesFM 2.5
3. Select cost-optimal quantile: Cu/(Cu+Co) = 1.0/(1.0+0.2) = 0.833
4. Convert to order quantities using base-stock policy
5. Generate submission CSV

**Note**: This is a DEMO of quantile forecasting capabilities, not competing with the official hierarchical Bayes submission.


## 1. Setup


In [None]:
# Standard library and numeric stack used by the cells below.
import sys
from pathlib import Path
import pandas as pd
import numpy as np
from scipy.stats import norm  # NOTE(review): not referenced in the cells visible here
import warnings
# Suppress every warning for the rest of the session to keep notebook output
# short; note that this also hides deprecation and numerical warnings.
warnings.filterwarnings('ignore')

# TimesFM: put the parent directory on sys.path before importing the package
# (presumably the local TimesFM checkout; confirm against the repo layout).
import torch
timesfm_path = Path("..").resolve()
if str(timesfm_path) not in sys.path:
    sys.path.insert(0, str(timesfm_path))  # index 0 so this location is searched first
import timesfm

# VN2 policy helpers: same approach for the vn2inventory directory two levels up.
vn2_path = Path("../../vn2inventory").resolve()
if str(vn2_path) not in sys.path:
    sys.path.insert(0, str(vn2_path))
from vn2inventory.policy import compute_orders

# Print the PyTorch version and whether CUDA is available.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


In [None]:
# VN2 configuration: data/submission locations, cost parameters and
# replenishment timing shared by the cells below.
DATA_DIR = Path("../../vn2inventory/data").resolve()
SUB_DIR = Path("../../vn2inventory/submissions").resolve()
# parents=True so the cell also works when intermediate directories do not
# exist yet (plain exist_ok=True raises FileNotFoundError in that case).
SUB_DIR.mkdir(parents=True, exist_ok=True)

# Costs (from VN2 competition)
SHORTAGE_COST = 1.0
HOLDING_COST = 0.2
LEAD_WEEKS = 2
REVIEW_WEEKS = 1
# Orders must cover demand until the next order after this one arrives.
PROTECTION_WEEKS = LEAD_WEEKS + REVIEW_WEEKS  # 3 weeks

# Critical fractile (Newsvendor optimal service level): Cu / (Cu + Co)
CRITICAL_RATIO = SHORTAGE_COST / (SHORTAGE_COST + HOLDING_COST)
# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print("\n📊 VN2 Cost Structure:")
print(f"  Shortage cost (Cu): ${SHORTAGE_COST}")
print(f"  Holding cost (Co): ${HOLDING_COST}")
print(f"  Critical fractile: {CRITICAL_RATIO:.4f} ({CRITICAL_RATIO*100:.2f}%)")
print(f"  Protection period: {PROTECTION_WEEKS} weeks")


## 2. Load VN2 Data


In [None]:
# Load the week-0 competition files: wide sales history, starting inventory
# state, and the submission template that defines the SKU list and its order.
sales_df = pd.read_csv(DATA_DIR / "Week 0 - 2024-04-08 - Sales.csv")
initial_state = pd.read_csv(DATA_DIR / "Week 0 - 2024-04-08 - Initial State.csv")
template = pd.read_csv(DATA_DIR / "Week 0 - Submission Template.csv")

# Convert to long format: one row per (Store, Product, date).
id_cols = ["Store", "Product"]
sales_long = sales_df.melt(id_vars=id_cols, var_name="date", value_name="sales_qty")
sales_long["date"] = pd.to_datetime(sales_long["date"])
# Non-numeric or missing sales are coerced to NaN and then treated as zero sales.
sales_long["sales_qty"] = pd.to_numeric(sales_long["sales_qty"], errors="coerce").fillna(0)
sales_long = sales_long.sort_values(["Store", "Product", "date"]).reset_index(drop=True)

# SKU list comes from the template, so later outputs follow the template's row order.
sku_list = template[["Store", "Product"]].copy()

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print("\n📦 Data Loaded:")
print(f"  Total SKUs: {len(sku_list)}")
print(f"  Sales history: {sales_long.shape}")
print(f"  Weeks of data: {sales_long['date'].nunique()}")
print(f"  Date range: {sales_long['date'].min()} to {sales_long['date'].max()}")


## 3. Load TimesFM 2.5 with Quantile Head


In [None]:
# Fetch the pretrained TimesFM 2.5 checkpoint, then compile it with the
# forecast options used for the rest of the notebook.
print("Loading TimesFM 2.5...")
checkpoint_id = "google/timesfm-2.5-200m-pytorch"
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(checkpoint_id)

print("Compiling with quantile head...")
# Options are collected in one mapping so they can be read at a glance.
# The forecast cell uses a 140-step context and a 3-step horizon, both
# within the limits set here.
forecast_options = {
    "max_context": 512,
    "max_horizon": 128,
    "normalize_inputs": True,
    "use_continuous_quantile_head": True,
    "force_flip_invariance": True,
    "infer_is_positive": True,
    "fix_quantile_crossing": True,
}
model.compile(timesfm.ForecastConfig(**forecast_options))
print("✓ Model ready with quantile forecasting enabled")


## 4. Generate Quantile Forecasts (All SKUs)


In [None]:
# Prepare one history array per SKU (in sku_list order) and run the model.
CONTEXT_LENGTH = 140  # Use last 140 weeks
HORIZON = PROTECTION_WEEKS  # 3 weeks

# Group the sales table once instead of re-filtering all rows for every SKU.
# Rows are date-sorted first so each group's values are in chronological order.
history_by_sku = {
    sku_key: sku_rows["sales_qty"].to_numpy()
    for sku_key, sku_rows in sales_long.sort_values("date", kind="stable").groupby(
        ["Store", "Product"], sort=False
    )
}

# A SKU absent from the sales table gets an empty history, as before.
no_history = np.empty(0, dtype=float)
# Slicing with a negative start already returns the whole array when it is
# shorter than CONTEXT_LENGTH, so no explicit length check is needed.
inputs = [
    history_by_sku.get((store, product), no_history)[-CONTEXT_LENGTH:]
    for store, product in zip(sku_list["Store"], sku_list["Product"])
]

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print(f"\n🔮 Generating forecasts for {len(inputs)} SKUs...")
print(f"  Context: {CONTEXT_LENGTH} weeks")
print(f"  Horizon: {HORIZON} weeks")

# Generate quantile forecasts
point_forecast, quantile_forecast = model.forecast(
    horizon=HORIZON,
    inputs=inputs,
)

print("\n✓ Forecasts generated!")
print(f"  Point forecast shape: {point_forecast.shape}")
print(f"  Quantile forecast shape: {quantile_forecast.shape}")
print("  Quantiles: [mean, P10, P20, P30, P40, P50, P60, P70, P80, P90]")


## 5. Select Cost-Optimal Quantile

The critical fractile is 0.8333; the closest level on the available P10–P90 grid is P80, which sits slightly below the newsvendor optimum.


In [None]:
# Map critical fractile to closest available quantile
quantile_levels = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

# The indexing below assumes the last axis is [mean, P10, ..., P90];
# fail with a clear message if the model output has a different layout.
expected_cols = len(quantile_levels) + 1
if quantile_forecast.shape[-1] != expected_cols:
    raise ValueError(
        f"Expected {expected_cols} quantile columns (mean + {len(quantile_levels)} levels), "
        f"got {quantile_forecast.shape[-1]}"
    )

closest_idx = np.argmin(np.abs(quantile_levels - CRITICAL_RATIO))
chosen_quantile = quantile_levels[closest_idx]
# Round rather than truncate: e.g. 0.3 * 100 is 30.000000000000004 in floating
# point, and other levels can land just below the integer.
chosen_pct = int(round(float(chosen_quantile) * 100))

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print("\n💰 Cost-Optimal Quantile Selection:")
print(f"  Critical fractile: {CRITICAL_RATIO:.4f}")
print(f"  Closest quantile: P{chosen_pct}")
print(f"  This means: Order enough to satisfy demand in {chosen_pct}% of scenarios")

# Extract chosen quantile (add 1 for mean offset in array)
optimal_quantile_forecast = quantile_forecast[:, :, closest_idx + 1]
print(f"\n  Selected forecast shape: {optimal_quantile_forecast.shape}")


## 6. Aggregate to Protection Period

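Sum the weekly forecasts over the 3-week protection period and estimate the uncertainty of that total. Summing per-week P80 values only approximates the P80 of total demand: it is exact only when weekly demands move together, and otherwise tends to overstate it.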


In [None]:
# Build demand_stats for policy: one row per SKU with a 3-week demand level
# and a spread estimate, indexed by (Store, Product).
#
# NOTE(review): "mean_demand" carries the selected upper-quantile total, not
# an expected value. compute_orders also receives the shortage/holding costs;
# if it adds its own safety stock on top of mean_demand, the buffer is applied
# twice. Confirm against vn2inventory.policy.

# Column positions on the last axis of quantile_forecast: column 0 is the
# mean, so the level at position k of [P10..P90] sits at column k + 1.
P50_COL = 5
P80_COL = 8
Z_P80 = 0.84  # approximate standard-normal z-score of the 80th percentile
STD_FLOOR_FRACTION = 0.1  # std is never allowed below 10% of the demand level

demand_stats = []

for i, sku in enumerate(sku_list.itertuples(index=False)):
    # 3-week total of the selected quantile (sum across horizon)
    demand_3w = optimal_quantile_forecast[i, :].sum()

    # Spread between the summed median and summed P80 (upper-tail spread,
    # not an interquartile range)
    q50 = quantile_forecast[i, :, P50_COL].sum()
    q80 = quantile_forecast[i, :, P80_COL].sum()

    # Rough std under a normal assumption: (P80 - P50) / z_0.80, floored at a
    # fraction of the selected-quantile total so the std is not zero unless
    # that total is zero too
    std_3w = max((q80 - q50) / Z_P80, demand_3w * STD_FLOOR_FRACTION)

    demand_stats.append({
        "Store": sku.Store,
        "Product": sku.Product,
        "mean_demand": float(demand_3w),
        "std_demand": float(std_3w)
    })

demand_stats_df = pd.DataFrame(demand_stats).set_index(["Store", "Product"])

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print("\n📊 Demand Statistics (3-week protection period):")
print(f"  Mean demand: {demand_stats_df['mean_demand'].mean():.2f} units")
print(f"  Mean std: {demand_stats_df['std_demand'].mean():.2f} units")
print(f"  Max demand: {demand_stats_df['mean_demand'].max():.0f} units")
demand_stats_df.head()


## 7. Generate Orders Using Base-Stock Policy


In [None]:
# Prepare current state: on-hand inventory plus everything already in transit.
# Chained rename returns a new frame, so no in-place mutation is needed.
state = initial_state[
    ["Store", "Product", "End Inventory", "In Transit W+1", "In Transit W+2"]
].rename(columns={"End Inventory": "on_hand"})
state["on_order"] = state[["In Transit W+1", "In Transit W+2"]].sum(axis=1)
current_state = state[["Store", "Product", "on_hand", "on_order"]].set_index(["Store", "Product"])

# Get index from template so orders follow the submission's SKU order
index_df = template[["Store", "Product"]].set_index(["Store", "Product"])

# Compute orders
# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print("\n🎯 Computing orders...")
orders = compute_orders(
    index_df=index_df,
    demand_stats=demand_stats_df,
    current_state=current_state,
    lead_time_weeks=LEAD_WEEKS,
    review_period_weeks=REVIEW_WEEKS,
    shortage_cost_per_unit=SHORTAGE_COST,
    holding_cost_per_unit_per_week=HOLDING_COST,
)

print(f"\n✓ Orders computed for {len(orders)} SKUs")
print(f"  Total units: {orders.sum():,.0f}")
print(f"  Mean order: {orders.mean():.2f} units")
print(f"  Median order: {orders.median():.0f} units")
print(f"  Max order: {orders.max():.0f} units")
print(f"  Zero orders: {(orders == 0).sum()} SKUs")


## 8. Save Submission


In [None]:
# Create submission DataFrame: template index plus a single order column "0".
submission = index_df.copy()

# Align orders to the template by (Store, Product) label when the result is
# indexed that way, so a differently ordered result cannot be written against
# the wrong SKU. Otherwise fall back to positional assignment.
if isinstance(orders, pd.Series) and list(orders.index.names) == list(submission.index.names):
    missing_skus = submission.index.difference(orders.index)
    if len(missing_skus) > 0:
        raise ValueError(f"Orders missing for {len(missing_skus)} template SKUs")
    submission["0"] = orders.reindex(submission.index).to_numpy()
else:
    submission["0"] = orders.values

# Validate with explicit raises; assert statements are skipped under `python -O`.
template_index = template.set_index(["Store", "Product"]).index
if len(submission) != len(template):
    raise ValueError("Row count mismatch!")
if not submission.index.equals(template_index):
    raise ValueError("Index mismatch!")

# Save
output_path = SUB_DIR / "orders_timesfm_quantile_demo.csv"
submission.to_csv(output_path)

# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n.
print(f"\n✅ SUBMISSION SAVED: {output_path}")
print("\n📄 Submission Summary:")
print(f"  Rows: {len(submission)}")
print(f"  Columns: {list(submission.columns)}")
print(f"  Total units ordered: {submission['0'].sum():,.0f}")
print(f"  Mean order: {submission['0'].mean():.2f}")
print(f"  Zeros: {(submission['0'] == 0).sum()}")

print("\n🔝 Top 10 Orders:")
top_10 = submission.nlargest(10, "0")
print(top_10)

# Report the quantile that was actually selected instead of a hard-coded "P80".
selected_pct = int(round(float(chosen_quantile) * 100))
print("\n💡 This demo shows TimesFM 2.5's quantile forecasting capability.")
print(f"   Cost-optimal quantile (P{selected_pct}) was automatically selected based on Cu/Co ratio.")
print("   Ready for VN2 submission validation!")
