# Adaptive Market Planning

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/stochastic-optimization/blob/master/notebooks/adaptive_market_planning.ipynb)

In [None]:
# Install JAX and dependencies
!pip install -q jax jaxlib jaxtyping chex numpy matplotlib

# Clone repository (force fresh clone for latest code)
import os
import shutil

if os.path.exists('stochastic-optimization'):
    shutil.rmtree('stochastic-optimization')

!git clone https://github.com/pedronahum/stochastic-optimization.git
os.chdir('stochastic-optimization')

# Clear Python import cache
import sys
for key in list(sys.modules.keys()):
    if key.startswith('problems'):
        del sys.modules[key]

print('✓ Setup complete!')

In [None]:
# Numerical (JAX / NumPy) and plotting libraries used throughout the notebook.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Import problem components
# (loaded from the repository cloned by the setup cell; that cell clears any
# cached `problems` modules first so this import sees the fresh code)
from problems.adaptive_market_planning import (
    AdaptiveMarketPlanningConfig,
    AdaptiveMarketPlanningModel,
    HarmonicStepPolicy,
    KestenStepPolicy,
)

# Sanity-check output: confirm the imports worked and report which JAX
# version and default backend this session is using.
print('✓ Imports successful')
print(f'JAX version: {jax.__version__}')
print(f'JAX backend: {jax.default_backend()}')

In [None]:
# Problem instance: economic parameters and the starting order quantity are
# passed by keyword to AdaptiveMarketPlanningConfig.
config = AdaptiveMarketPlanningConfig(
    price=1.5,
    cost=0.8,
    demand_mean=100.0,
    initial_order_quantity=50.0,
)

# Build the model from the config, plus the learning policy
# (presumably a harmonic step-size rule, judging by its name — confirm in
# problems.adaptive_market_planning).
model, policy = AdaptiveMarketPlanningModel(config), HarmonicStepPolicy()

print('✓ Model and policy ready')

In [None]:
# Initialise the simulation state for the AdaptiveMarketPlanningModel created
# in the configuration cell above. No separate model is built here: the only
# model classes imported by this notebook are the *Planning* ones, so
# constructing an "AdaptiveMarketModel" would raise NameError.
# NOTE(review): assumes AdaptiveMarketPlanningModel exposes
# init_state(key) — confirm against problems.adaptive_market_planning.
key = jax.random.PRNGKey(42)
state = model.init_state(key)
print('✓ Model ready')

In [None]:
# Simple simulation: roll the model forward for 30 steps under one fixed
# decision, recording each step's reward and the decision that produced it.
# NOTE(review): the decision is a 3-element vector labelled as prices; confirm
# it matches the decision shape the model in use expects.
prices, rewards = [], []
decision = jnp.array([1.0, 1.5, 2.0])  # Simple pricing, constant every step
for t in range(30):
    # Three-way split kept so the random stream is unchanged; k1 is unused.
    key, k1, k2 = jax.random.split(key, 3)
    exog = model.sample_exogenous(k2, state, t)
    reward = model.reward(state, decision, exog)
    rewards.append(float(reward))
    prices.append(decision.tolist())
    state = model.transition(state, decision, exog)

# Left panel: per-step reward; right panel: its running total.
fig, (ax_daily, ax_cumulative) = plt.subplots(1, 2, figsize=(12, 4))
ax_daily.plot(rewards)
ax_daily.set_title('Daily Revenue')
ax_cumulative.plot(np.cumsum(rewards))
ax_cumulative.set_title('Cumulative Revenue')
fig.tight_layout()
plt.show()
print(f'Total revenue: {sum(rewards):.1f}')