In [None]:
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import random
import matplotlib.pyplot as plt
from itertools import product
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.optimize import differential_evolution
import numpy as np
import pickle
from datetime import datetime
from collections import defaultdict, deque
from typing import Dict, Tuple, List, Optional
from collections import defaultdict
import pickle
from datetime import datetime
import time
from copy import deepcopy

In [None]:
# Load the raw retail transactions into `df`.
# NOTE(review): the absolute "/content/..." path looks Colab-specific — adjust when running elsewhere.
df = pd.read_csv("/content/retail_price.csv")

In [None]:
def prepare_weekly_data(df: pd.DataFrame, min_weeks: int = 4,
                        date_col: str = 'month_year') -> Dict[str, Tuple[List[int], List[float], float]]:
    """Build per-product weekly demand/price series and a reference price.

    Rows are bucketed into weekly periods derived from ``date_col``; rows whose
    date cannot be parsed are dropped. For every product observed in at least
    ``min_weeks`` distinct weeks the result holds:

    * the weekly summed ``qty``, ordered by week label,
    * the weekly ``qty``-weighted mean ``unit_price`` in the same order
      (plain mean for a week whose quantities sum to zero),
    * a reference price: the ``qty``-weighted mean ``unit_price`` over all rows
      with ``qty > 0`` (median of those prices if the weighted mean cannot be
      computed).

    The caller's DataFrame is not modified.

    Args:
        df: Transactions with ``product_id``, ``qty``, ``unit_price`` and
            ``date_col`` columns.
        min_weeks: Minimum number of weekly observations a product needs.
        date_col: Column parsed with ``pd.to_datetime``.

    Returns:
        ``{product_id: (demand_series, price_series, ref_price)}``. An empty
        dict is returned (after printing the reason) when preparation fails,
        including when no product has enough weeks.
    """
    def _weighted_price(group: pd.DataFrame) -> float:
        # np.average raises ZeroDivisionError when the weights sum to zero;
        # fall back to the plain mean so one empty week cannot abort the
        # preparation of every product.
        if group['qty'].sum() == 0:
            return float(group['unit_price'].mean())
        return float(np.average(group['unit_price'], weights=group['qty']))

    try:
        # assign() returns a new frame, so the caller's columns stay untouched.
        dated = df.assign(date=pd.to_datetime(df[date_col], errors='coerce'))
        dated = dated.dropna(subset=['date'])
        dated = dated.assign(year_week=dated['date'].dt.to_period('W').astype(str))

        weekly_demand = (
            dated.groupby(['product_id', 'year_week'], observed=True)['qty']
            .sum()
            .reset_index()
            .sort_values(['product_id', 'year_week']))

        product_counts = weekly_demand['product_id'].value_counts()
        valid_products = product_counts[product_counts >= min_weeks].index.tolist()

        if not valid_products:
            raise ValueError(f"No products with ≥ {min_weeks} weeks of data")

        products_data = {}
        for product_id in valid_products:
            # Filter the product's rows once and reuse them below.
            product_rows = dated[dated['product_id'] == product_id]
            sold_rows = product_rows[product_rows['qty'] > 0]
            if len(sold_rows) == 0:
                continue

            demand_series = (
                weekly_demand[weekly_demand['product_id'] == product_id]
                .sort_values('year_week')['qty']
                .tolist())

            price_series = (
                product_rows
                .groupby('year_week', observed=True)[['unit_price', 'qty']]
                .apply(_weighted_price)
                .reset_index(name='weighted_price')
                .sort_values('year_week')['weighted_price']
                .tolist())

            try:
                ref_price = float(np.average(sold_rows['unit_price'],
                                             weights=sold_rows['qty']))
            except (ZeroDivisionError, TypeError, ValueError):
                ref_price = float(sold_rows['unit_price'].median())

            # Both series come from the same weekly buckets; truncate
            # defensively so they always have equal length.
            min_length = min(len(demand_series), len(price_series))
            if min_length < min_weeks:
                continue

            products_data[product_id] = (
                demand_series[:min_length],
                price_series[:min_length],
                ref_price)

        print(f"Prepared data for {len(products_data)} products (min {min_weeks} weeks)")
        return products_data

    except Exception as e:
        # Deliberate best-effort: callers receive {} instead of an exception.
        print(f"Data preparation failed: {str(e)}")
        return {}

# Build the {product_id: (demand_series, price_series, ref_price)} mapping from the loaded frame;
# prepare_weekly_data prints the reason and returns an empty dict if preparation fails.
products_data = prepare_weekly_data(df)

Prepared data for 28 products (min 4 weeks)


## **Initial run**

In [None]:
def prepare_weekly_data(df: pd.DataFrame, min_weeks: int = 4,
                        date_col: str = 'month_year') -> Dict[str, Tuple[List[int], List[float], float]]:
    """Build per-product weekly demand/price series and a reference price.

    NOTE(review): a function with this name is also defined in an earlier
    cell; the definition executed last is the one in effect.

    Rows are bucketed into weekly periods derived from ``date_col``; rows whose
    date cannot be parsed are dropped. For every product observed in at least
    ``min_weeks`` distinct weeks the result holds:

    * the weekly summed ``qty``, ordered by week label,
    * the weekly ``qty``-weighted mean ``unit_price`` in the same order
      (plain mean for a week whose quantities sum to zero),
    * a reference price: the ``qty``-weighted mean ``unit_price`` over all rows
      with ``qty > 0`` (median of those prices if the weighted mean cannot be
      computed).

    The caller's DataFrame is not modified.

    Args:
        df: Transactions with ``product_id``, ``qty``, ``unit_price`` and
            ``date_col`` columns.
        min_weeks: Minimum number of weekly observations a product needs.
        date_col: Column parsed with ``pd.to_datetime``.

    Returns:
        ``{product_id: (demand_series, price_series, ref_price)}``. An empty
        dict is returned (after printing the reason) when preparation fails,
        including when no product has enough weeks.
    """
    def _weighted_price(group: pd.DataFrame) -> float:
        # np.average raises ZeroDivisionError when the weights sum to zero;
        # fall back to the plain mean so one empty week cannot abort the
        # preparation of every product.
        if group['qty'].sum() == 0:
            return float(group['unit_price'].mean())
        return float(np.average(group['unit_price'], weights=group['qty']))

    try:
        # assign() returns a new frame, so the caller's columns stay untouched.
        dated = df.assign(date=pd.to_datetime(df[date_col], errors='coerce'))
        dated = dated.dropna(subset=['date'])
        dated = dated.assign(year_week=dated['date'].dt.to_period('W').astype(str))

        weekly_demand = (
            dated.groupby(['product_id', 'year_week'], observed=True)['qty']
            .sum()
            .reset_index()
            .sort_values(['product_id', 'year_week']))

        product_counts = weekly_demand['product_id'].value_counts()
        valid_products = product_counts[product_counts >= min_weeks].index.tolist()

        if not valid_products:
            raise ValueError(f"No products with ≥ {min_weeks} weeks of data")

        products_data = {}
        for product_id in valid_products:
            # Filter the product's rows once and reuse them below.
            product_rows = dated[dated['product_id'] == product_id]
            sold_rows = product_rows[product_rows['qty'] > 0]
            if len(sold_rows) == 0:
                continue

            demand_series = (
                weekly_demand[weekly_demand['product_id'] == product_id]
                .sort_values('year_week')['qty']
                .tolist())

            price_series = (
                product_rows
                .groupby('year_week', observed=True)[['unit_price', 'qty']]
                .apply(_weighted_price)
                .reset_index(name='weighted_price')
                .sort_values('year_week')['weighted_price']
                .tolist())

            try:
                ref_price = float(np.average(sold_rows['unit_price'],
                                             weights=sold_rows['qty']))
            except (ZeroDivisionError, TypeError, ValueError):
                ref_price = float(sold_rows['unit_price'].median())

            # Both series come from the same weekly buckets; truncate
            # defensively so they always have equal length.
            min_length = min(len(demand_series), len(price_series))
            if min_length < min_weeks:
                continue

            products_data[product_id] = (
                demand_series[:min_length],
                price_series[:min_length],
                ref_price)

        print(f"Prepared data for {len(products_data)} products (min {min_weeks} weeks)")
        return products_data

    except Exception as e:
        # Deliberate best-effort: callers receive {} instead of an exception.
        print(f"Data preparation failed: {str(e)}")
        return {}

class OptimizedSingleProductEnv:
    """Week-by-week pricing/replenishment simulation for a single product.

    Each call to :meth:`step` applies a price and an order quantity to the
    current week, derives a noisy price-elastic demand from the historical
    demand series, books sales against inventory and returns a shaped reward.
    """

    def __init__(self, demand_series: List[int], ref_price: float,
                 init_inventory: int = 500,
                 lead_time: int = 1,
                 elasticity: float = -0.8,
                 holding_rate: float = 0.015,
                 stockout_penalty: float = 5.0,
                 order_cost: float = 0.5,
                 safety_stock: int = 10,
                 service_weight: float = 100.0):
        """Sanitise the parameters and reset the simulation.

        Args:
            demand_series: Historical weekly demand; every entry is floored to 1.
            ref_price: Reference price (floored to 0.1) used for the elasticity
                ratio and the price floor.
            init_inventory: Starting stock, clamped to be non-negative.
            lead_time: Replenishment lead time in weeks; kept at >= 1 and, when
                the series is long enough, at most ``weeks - 1``.
            elasticity: Price elasticity exponent, clamped to ``[-5, 0]``.
            holding_rate: Per-unit holding cost per week (>= 0).
            stockout_penalty: Cost per unit of unmet demand (>= 0).
            order_cost: Cost per ordered unit (>= 0).
            safety_stock: Units that are never sold (>= 0).
            service_weight: Multiplier of the service-level term in the reward.
        """
        self.service_weight = service_weight
        self.original_demand = [max(1, int(d)) for d in demand_series]
        self.ref_price = max(0.1, float(ref_price))
        self.weeks = len(self.original_demand)
        self.init_inventory = max(0, int(init_inventory))
        # Cap the lead time below the horizon, but never let it drop under one
        # week: the previous clamp produced 0 (or -1) for series of one (or
        # zero) weeks, which made the variability estimate sum empty windows.
        self.lead_time = max(1, min(int(lead_time), self.weeks - 1))
        self.elasticity = max(-5.0, min(0.0, float(elasticity)))
        self.holding_rate = max(0.0, float(holding_rate))
        self.stockout_penalty = max(0.0, float(stockout_penalty))
        self.order_cost = max(0.0, float(order_cost))
        self.safety_stock = max(0, int(safety_stock))
        self.avg_demand = np.mean(self.original_demand) if self.original_demand else 1
        self._calculate_demand_variability()
        self.reset()

    def _calculate_demand_variability(self):
        """Set ``std_demand_lead_time`` from rolling lead-time demand sums.

        Falls back to the std of the raw series when it is not longer than the
        lead time, and to 3.0 when there are too few points.
        """
        try:
            if self.weeks > self.lead_time:
                window_totals = [sum(self.original_demand[i:i + self.lead_time])
                                 for i in range(len(self.original_demand) - self.lead_time + 1)]
                self.std_demand_lead_time = np.std(window_totals) if len(window_totals) > 1 else 3.0
            else:
                self.std_demand_lead_time = np.std(self.original_demand) if len(self.original_demand) > 1 else 3.0
        except (TypeError, ValueError):
            self.std_demand_lead_time = 3.0

    def reset(self):
        """Rewind to week 0 with the initial inventory and return the state."""
        self.current_week = 0
        self.inventory = self.init_inventory
        self.total_profit = 0
        self.total_demand = 0
        self.total_sales = 0
        self.weekly_stats = []
        return self._get_state()

    def _get_state(self):
        """Return the observation dict for the current week.

        ``demand`` is 0 once the horizon is exhausted; ``avg_demand`` is the
        mean of the historical demand up to and including the current week and
        is reported as 0 in week 0.
        """
        return {
            'inventory': self.inventory,
            'current_week': self.current_week,
            'demand': self.original_demand[self.current_week] if self.current_week < self.weeks else 0,
            'avg_demand': np.mean(self.original_demand[:self.current_week + 1]) if self.current_week > 0 else 0}

    def step(self, price: float, order_qty: int = 0) -> float:
        """Simulate one week and return the shaped reward.

        The price is floored at 70% of the reference price. ``order_qty`` is
        added to inventory immediately. The reward is
        ``tanh(profit / 1000) + service_weight * service_level * k
        - 0.01 * inventory / max(1, avg_demand)`` with ``k = 1.0`` for a
        profitable week and ``0.2`` otherwise.

        Returns:
            The reward, or 0.0 when the horizon is exhausted or when an error
            occurs inside the step (the error is printed and the week is not
            advanced).
        """
        if self.current_week >= self.weeks:
            return 0.0

        try:
            price = max(self.ref_price * 0.7, price)

            base_demand = self.original_demand[self.current_week]
            price_factor = (max(0.1, price) / self.ref_price) ** self.elasticity
            noise = np.random.normal(0, 0.1 * base_demand)
            demand = max(1, int(base_demand * price_factor + noise))

            self.inventory += order_qty
            # Safety stock is held back and never sold.
            available = max(self.inventory - self.safety_stock, 0)
            sales_qty = min(available, demand)

            revenue = sales_qty * price
            holding_cost = self.inventory * self.holding_rate
            stockout_cost = max(demand - sales_qty, 0) * self.stockout_penalty
            order_cost = order_qty * self.order_cost
            profit = revenue - holding_cost - stockout_cost - order_cost
            service_level = sales_qty / demand if demand > 0 else 0.0

            profit_term = np.tanh(profit / 1000.0)
            service_term = self.service_weight * service_level * (1.0 if profit > 0 else 0.2)
            inventory_term = -0.01 * (self.inventory / max(1, self.avg_demand))
            scaled_reward = profit_term + service_term + inventory_term

            inventory_before_sales = self.inventory
            self.inventory -= sales_qty
            self.total_profit += profit
            self.total_demand += demand
            self.total_sales += sales_qty

            self.weekly_stats.append({
                'week': self.current_week,
                'price': price,
                'order_qty': order_qty,
                'inventory': self.inventory,
                'demand': demand,
                'sales': sales_qty,
                'profit': profit,
                'service_level': service_level,
                'inventory_change': self.inventory - inventory_before_sales})

            self.current_week += 1
            return scaled_reward

        except Exception as e:
            # Best-effort by design: report the problem and yield a neutral reward.
            print(f"Error in step: {str(e)}")
            return 0.0

    def service_level(self) -> float:
        """Return cumulative sales / cumulative demand (0.0 before any demand)."""
        if self.total_demand > 0:
            return self.total_sales / self.total_demand
        return 0.0

class MultiProductRetailEnv:
    """Shared-warehouse environment that steps several single-product envs.

    Every product gets its own ``OptimizedSingleProductEnv``; this class caps
    orders by the remaining warehouse space, adds a shared-shipment bonus,
    applies an over-utilisation penalty and keeps aggregate metrics.
    """

    def __init__(self, products_data: Dict[str, Tuple[List[int], List[float], float]],
                 warehouse_capacity: int = 4412,
                 shared_transport_cost: float = 0.2):
        """Create one product env per entry of ``products_data``.

        Args:
            products_data: ``{product_id: (demand_series, price_series, ref_price)}``.
            warehouse_capacity: Total space units available.
            shared_transport_cost: Per-product amount used for the shared
                shipment bonus and the shared cost bookkeeping.

        Raises:
            ValueError: If no product could be initialised.
        """
        self.products = {}
        self.warehouse_capacity = warehouse_capacity
        self.shared_transport_cost = shared_transport_cost
        self.current_week = 0
        self.shared_shipment_count = 0
        self.shared_costs = 0
        self.total_orders_this_week = defaultdict(int)
        self.weekly_warehouse_utilization = []
        self.metrics = self._new_metrics()

        for product_id, (demand_series, price_series, ref_price) in products_data.items():
            try:
                self.products[product_id] = {
                    'env': OptimizedSingleProductEnv(demand_series, ref_price),
                    'pending_orders': [],
                    'inventory': 0,
                    # Space per unit grows with the reference price (at least 1).
                    'space_required': max(1, int(ref_price/15))}
            except Exception as e:
                # A product that cannot be set up is skipped, not fatal.
                print(f"Error initializing product {product_id}: {str(e)}")
                continue

        if not self.products:
            raise ValueError("No valid products initialized")
        self.reset()

    @staticmethod
    def _new_metrics():
        """Return a zeroed metrics dictionary."""
        return {
            'shared_costs_total': 0,
            'over_utilization_penalties': 0,
            'cumulative_rewards': defaultdict(float)}

    def reset(self):
        """Reset counters, metrics and every product env; return the global state."""
        self.current_week = 0
        self.shared_costs = 0
        self.shared_shipment_count = 0
        self.total_orders_this_week = defaultdict(int)
        self.weekly_warehouse_utilization = []
        self.metrics = self._new_metrics()

        for product_id, data in self.products.items():
            data['env'].reset()
            data['inventory'] = data['env'].init_inventory
            data['pending_orders'] = []
        return self.get_global_state()

    def get_global_state(self):
        """Return the warehouse-level observation shared by all agents."""
        return {
            'current_week': self.current_week,
            'warehouse_utilization': self.get_warehouse_utilization(),
            'total_orders': sum(self.total_orders_this_week.values()),
            'shared_costs': self.shared_costs}

    def get_warehouse_utilization(self):
        """Return used space / capacity, capped at 1.0 (0.0 if it cannot be computed)."""
        try:
            total_used = sum(data['inventory'] * data['space_required'] for data in self.products.values())
            return min(1.0, total_used / self.warehouse_capacity)
        except (ZeroDivisionError, TypeError):
            return 0.0

    def step(self, actions: Dict[str, Tuple[float, int]]):
        """Advance all products by one week.

        Args:
            actions: ``{product_id: (price, order_qty)}``; unknown product ids
                are ignored and an ``order_qty`` of ``None`` means "no order".

        Returns:
            ``(global_state, individual_rewards, done, info)`` where
            ``individual_rewards`` maps product id to its reward for the week
            and ``info`` is an empty dict.
        """
        max_weeks = max(len(data['env'].original_demand) for data in self.products.values())
        if self.current_week >= max_weeks:
            return self.get_global_state(), {}, True, {}

        individual_rewards = {}
        self.total_orders_this_week.clear()

        # Phase 1: cap every requested order by the remaining warehouse space
        # and register it as pending.
        for product_id, (price, order_qty) in actions.items():
            if product_id not in self.products:
                continue
            approved_qty = 0 if order_qty is None else self._get_max_order(product_id, order_qty)
            self.total_orders_this_week[product_id] = approved_qty
            if approved_qty > 0:
                product_data = self.products[product_id]
                arrival_week = self.current_week + product_data['env'].lead_time
                product_data['pending_orders'].append((arrival_week, approved_qty))

        # The number of products ordering is fixed after phase 1, so the shared
        # bonus is the same for every product this week.
        active_orders = sum(1 for qty in self.total_orders_this_week.values() if qty > 0)
        shared_cost_reduction = self.shared_transport_cost * active_orders

        # Phase 2: step each product env with its approved order.
        for product_id, (price, order_qty) in actions.items():
            if product_id not in self.products:
                continue
            product_data = self.products[product_id]
            env = product_data['env']
            actual_order_qty = self.total_orders_this_week.get(product_id, 0)
            product_reward = env.step(price, actual_order_qty)
            product_data['inventory'] = env.inventory
            individual_rewards[product_id] = product_reward + shared_cost_reduction

        if active_orders > 1:
            self.shared_shipment_count += 1
            self.shared_costs += self.shared_transport_cost * (active_orders - 1)
            self.metrics['shared_costs_total'] += self.shared_transport_cost * (active_orders - 1)

        self._process_order_arrivals()
        utilization = self.get_warehouse_utilization()
        self.weekly_warehouse_utilization.append(utilization)

        if utilization > 0.85:
            # Penalty grows linearly above 85% utilisation and is split evenly.
            penalty = 50 * (utilization - 0.85)
            self.metrics['over_utilization_penalties'] += penalty
            penalty_per_product = penalty / len(individual_rewards) if individual_rewards else 0
            for product_id in individual_rewards:
                individual_rewards[product_id] -= penalty_per_product

        # Track the final (post-penalty) rewards so summarize() can report them;
        # previously this metric was created but never filled.
        for product_id, reward in individual_rewards.items():
            self.metrics['cumulative_rewards'][product_id] += reward

        self.current_week += 1
        done = self.current_week >= max_weeks
        return self.get_global_state(), individual_rewards, done, {}

    def _get_max_order(self, product_id: str, requested_qty: int) -> int:
        """Return the part of ``requested_qty`` that fits in the free space.

        Free space is the capacity minus the space of all current inventory
        and of all orders that have not arrived yet. The result is never
        negative; unknown products get 0.
        """
        if product_id not in self.products:
            return 0
        space_per_unit = self.products[product_id]['space_required']
        used_space = 0
        for data in self.products.values():
            stock_space = data['inventory'] * data['space_required']
            pending_space = sum(qty * data['space_required']
                                for week, qty in data['pending_orders']
                                if week > self.current_week)
            used_space += stock_space + pending_space
        available_space = max(0, self.warehouse_capacity - used_space)
        max_possible = available_space / space_per_unit
        return max(0, min(requested_qty, int(max_possible)))

    def _process_order_arrivals(self):
        """Book orders arriving this week and drop them from the pending lists.

        NOTE(review): step() already passes the order quantity to the product
        env in the ordering week and copies ``env.inventory`` back on every
        step, so adding the arrival here looks like it counts the same units
        twice in this week's utilisation — confirm the intended lead-time
        semantics before changing it.
        """
        for product_id, product_data in self.products.items():
            arrived_qty = sum(qty for week, qty in product_data['pending_orders'] if week == self.current_week)
            product_data['inventory'] += arrived_qty
            product_data['pending_orders'] = [
                (week, qty) for week, qty in product_data['pending_orders']
                if week > self.current_week]

    def summarize(self):
        """Return aggregate metrics for the weeks simulated so far."""
        summary = {
            'total_weeks': self.current_week,
            'avg_utilization': np.mean(self.weekly_warehouse_utilization) if self.weekly_warehouse_utilization else 0,
            'shared_shipments': self.shared_shipment_count,
            'shared_costs_total': self.metrics['shared_costs_total'],
            'over_utilization_penalties': self.metrics['over_utilization_penalties'],
            'cumulative_rewards': dict(self.metrics['cumulative_rewards']),
            'product_service_levels': {pid: round(data['env'].service_level(), 3)
                                       for pid, data in self.products.items()}}
        return summary

class OptimizedProductAgent(nn.Module):
    """DQN-style agent choosing a (price factor, order quantity) pair.

    The action space is the cross product of ``price_bins`` price factors in
    ``[0.8, 1.5]`` and ``order_bins`` order quantities in ``[0, 50]``. The
    agent keeps a prioritised replay memory of
    ``(state, action_idx, reward, next_state, priority)`` tuples and a
    soft-updated target network.
    """

    def __init__(self,
                 input_size: int = 5,
                 hidden_size: int = 64,
                 price_bins: int = 10,
                 order_bins: int = 11,
                 lr: float = 0.00005,
                 batch_size: int = 256,
                 target_update_freq: int = 5,
                 service_weight: float = 20.0):
        """Build the online/target networks, optimiser and bookkeeping.

        Args:
            input_size: Number of state features fed to the network.
            hidden_size: Width of both hidden layers.
            price_bins: Number of discrete price factors.
            order_bins: Number of discrete order quantities.
            lr: AdamW learning rate.
            batch_size: Default replay batch size.
            target_update_freq: Replay steps between soft target updates.
            service_weight: Stored on the agent; not used inside this class.
        """
        super().__init__()

        self.input_size = input_size
        self.price_bins_count = price_bins
        self.order_bins_count = order_bins
        self.action_size = price_bins * order_bins

        self.price_bins = np.linspace(0.8, 1.5, self.price_bins_count)
        self.order_bins = np.linspace(0, 50, self.order_bins_count)

        self.net = self._build_q_network(input_size, hidden_size, self.action_size)

        # Only the online network is trained; the target network follows it.
        self.optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=0.01)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=50)
        self.memory = []
        self.gamma = 0.97
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.tau = 0.01
        self.steps = 0
        self.best_avg_reward = -float('inf')
        self.patience_counter = 0
        self.max_patience = 100
        self.service_weight = service_weight

        self.target_net = self._build_q_network(input_size, hidden_size, self.action_size)
        # Start the target as an exact copy of the online network. A soft
        # update with tau=0.01 here would leave it ~99% randomly initialised.
        self.target_net.load_state_dict(self.net.state_dict())

    @staticmethod
    def _build_q_network(input_size: int, hidden_size: int, action_size: int) -> nn.Sequential:
        """Return the MLP used for both the online and the target network."""
        return nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size))

    @staticmethod
    def _normalize_state(state) -> np.ndarray:
        """Standardise a flat state vector across its own features."""
        vector = np.asarray(state, dtype=float)
        return (vector - vector.mean()) / (vector.std() + 1e-8)

    @staticmethod
    def _dict_features(state: Dict) -> List[float]:
        """Flatten a dict observation into the 5 features the network expects."""
        return [
            state['inventory'],
            state['demand'],
            state['avg_demand'],
            state['global']['warehouse_utilization'],
            state['global']['current_week'] / 52]

    def _nearest_action_index(self, price, order) -> int:
        """Map a (price factor, order qty) pair to the closest discrete action."""
        price_idx = int(np.argmin(np.abs(self.price_bins - price)))
        order_idx = int(np.argmin(np.abs(self.order_bins - order)))
        return price_idx * self.order_bins_count + order_idx

    def _update_target_net(self):
        """Polyak-average the online weights into the target network."""
        for target_param, param in zip(self.target_net.parameters(), self.net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

    def forward(self, x):
        """Return the Q-values of the online network for ``x``."""
        return self.net(x)

    def get_action(self, state: Dict) -> Tuple[float, int]:
        """Epsilon-greedy action for a dict observation (see ``_dict_features``).

        Returns the selected price factor and order quantity taken from the
        discrete bins.
        """
        if random.random() < self.epsilon:
            price_idx = random.randint(0, len(self.price_bins)-1)
            order_idx = random.randint(0, len(self.order_bins)-1)
            return self.price_bins[price_idx], self.order_bins[order_idx]

        state_tensor = torch.FloatTensor(self._dict_features(state))
        with torch.no_grad():
            best_idx = int(torch.argmax(self(state_tensor)).item())
        price_idx, order_idx = divmod(best_idx, self.order_bins_count)
        return self.price_bins[price_idx], self.order_bins[order_idx]

    def act(self, state: List) -> Tuple[float, int]:
        """Epsilon-greedy action for a flat state vector.

        The greedy branch evaluates the network on the standardised state;
        the exploratory branch samples a price factor in ``[0.8, 1.2]`` and an
        order quantity in ``[0, 30]``.
        """
        if random.random() < self.epsilon:
            return random.uniform(0.8, 1.2), random.randint(0, 30)  # Tighter exploration

        state_tensor = torch.FloatTensor(self._normalize_state(state))
        with torch.no_grad():
            best_idx = int(torch.argmax(self(state_tensor.unsqueeze(0))).item())
        price_idx, order_idx = divmod(best_idx, self.order_bins_count)
        return self.price_bins[price_idx], int(self.order_bins[order_idx])

    def remember(self, state, action, reward, next_state, done=False):
        """Store one transition with a TD-error based priority.

        Args:
            state: Dict observation (as for ``get_action``) or flat vector
                (as for ``act``). Flat vectors are standardised exactly like
                ``act`` does, so training and action selection see the same
                input scale.
            action: ``(price_factor, order_qty)``; snapped to the nearest bins.
            reward: Reward received for the transition.
            next_state: Same kind as ``state``; ``None`` is stored as an
                all-zero vector of the same length as ``state``.
            done: Accepted for interface compatibility.
                NOTE(review): the flag is not stored, so replay() bootstraps
                from ``next_state`` even for terminal transitions.
        """
        price, order = action
        action_idx = self._nearest_action_index(price, order)

        if isinstance(state, dict):
            state_vec = self._dict_features(state)
            # Same 5 features as the state; the former 6-element list did not
            # match the network input size.
            next_vec = (self._dict_features(next_state)
                        if next_state is not None else [0.0] * len(state_vec))
        else:
            state_vec = self._normalize_state(state).tolist()
            next_vec = (self._normalize_state(next_state).tolist()
                        if next_state is not None else [0.0] * len(state_vec))

        with torch.no_grad():
            current_q = self(torch.FloatTensor(state_vec).unsqueeze(0))[0, action_idx].item()
            next_q = self.target_net(torch.FloatTensor(next_vec).unsqueeze(0)).max().item()
            td_error = abs(reward + self.gamma * next_q - current_q)

        priority = td_error + 1e-5
        self.memory.append((state_vec, action_idx, reward, next_vec, priority))

    def replay(self, batch_size=None):
        """Run one prioritised-replay gradient step.

        Returns:
            ``True`` when the recent average reward has not improved for
            ``max_patience`` calls (early-stop signal), otherwise ``False``
            (also when the memory holds fewer than ``batch_size`` items).
        """
        if batch_size is None:
            batch_size = self.batch_size
        if len(self.memory) < batch_size:
            return False

        priorities = np.array([mem[4] for mem in self.memory])
        probabilities = priorities / priorities.sum()
        indices = np.random.choice(len(self.memory), batch_size, p=probabilities)

        # Importance-sampling weights correct the bias of prioritised sampling.
        is_weights = (len(self.memory) * probabilities[indices]) ** -0.4
        is_weights = torch.as_tensor(is_weights / is_weights.max(), dtype=torch.float32)

        batch = [self.memory[i] for i in indices]
        states, action_indices, rewards, next_states, _ = zip(*batch)

        state_tensors = torch.FloatTensor(states)
        next_state_tensors = torch.FloatTensor(next_states)
        action_tensors = torch.LongTensor(action_indices).unsqueeze(1)
        reward_tensors = torch.FloatTensor(rewards)

        # squeeze(1) keeps the batch dimension even for a batch of one.
        current_q = self(state_tensors).gather(1, action_tensors).squeeze(1)

        with torch.no_grad():
            next_q = self.target_net(next_state_tensors).max(1)[0]
            target_q = reward_tensors + self.gamma * next_q

        loss = (is_weights * nn.SmoothL1Loss(reduction='none')(current_q, target_q)).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=0.5)
        self.optimizer.step()

        recent_rewards = [mem[2] for mem in self.memory[-100:]] if len(self.memory) > 100 else [0]
        avg_reward = np.mean(recent_rewards)
        self.scheduler.step(avg_reward)

        if len(recent_rewards) >= 50:
            if avg_reward > self.best_avg_reward:
                self.best_avg_reward = avg_reward
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            if self.patience_counter >= self.max_patience:
                return True

        self.steps += 1
        if self.steps % self.target_update_freq == 0:
            self._update_target_net()

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return False

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def validate(self):
        """Return the mean absolute TD error over the 50 newest transitions.

        Returns 0.0 while fewer than 50 transitions are stored.
        """
        if len(self.memory) < 50:
            return 0.0

        recent_losses = []
        # Slice instead of range(-50, -1), which skipped the newest transition.
        for state, action, reward, next_state, _ in self.memory[-50:]:
            with torch.no_grad():
                current_q = self(torch.FloatTensor(state).unsqueeze(0))[0, action].item()
                next_q = self.target_net(torch.FloatTensor(next_state).unsqueeze(0)).max().item()
            recent_losses.append(abs(current_q - (reward + self.gamma * next_q)))

        return np.mean(recent_losses)

class OptimizedMultiAgentCoordinator:
    """Trains one ``OptimizedProductAgent`` per product on a shared environment."""

    def __init__(self, products_data: Dict[str, Tuple[List[int], List[float], float]]):
        """Create the shared environment and one agent per product in it.

        Args:
            products_data: ``{product_id: (demand_series, price_series, ref_price)}``.

        Raises:
            ValueError: If ``products_data`` is empty.
        """
        if not products_data:
            raise ValueError("No products data provided")

        self.env = MultiProductRetailEnv(products_data)
        self.agents = {}
        self.rewards_history = defaultdict(list)
        self.metrics = {'total_episode_rewards': []}
        self.convergence_metrics = []

        # Only products the environment actually initialised get an agent, so
        # train() never looks up a product the environment skipped.
        for product_id in self.env.products:
            self.agents[product_id] = OptimizedProductAgent(
                input_size=4,
                price_bins=10,
                order_bins=11,
                lr=0.00005)

    @staticmethod
    def _state_vector(product_env) -> List[float]:
        """Return the 4-feature agent state read from a product environment."""
        snapshot = product_env._get_state()
        return [
            snapshot['inventory'],
            snapshot['current_week'],
            snapshot['demand'],
            snapshot['avg_demand']]

    def train(self, episodes: int = 1000, eval_every: int = 50, batch_size=32):
        """Run the training loop.

        Every week each active agent picks an action from its pre-step state;
        after the environment step the transition
        ``(pre-step state, own action, reward, post-step state)`` is stored.
        One replay step per agent is attempted at the end of every episode,
        and training stops early once every agent signals convergence.

        Args:
            episodes: Maximum number of episodes.
            eval_every: Print a total-reward line every this many episodes.
            batch_size: Replay batch size handed to each agent.

        Returns:
            ``(rewards_history, env)``: per-product episode rewards and the
            shared environment.
        """
        max_weeks = max([len(data['env'].original_demand) for data in self.env.products.values()])
        converged = False

        for ep in range(episodes):
            if converged:
                print(f"Training converged at episode {ep}")
                break

            self.env.reset()
            episode_rewards = defaultdict(float)
            done = False

            for week in range(max_weeks):
                actions = {}
                # pid -> (state before the step, (price_factor, order_qty)).
                # Kept per product so remember() gets this product's own
                # action and a state that really precedes the transition.
                pre_step = {}
                for pid, agent in self.agents.items():
                    product_env = self.env.products[pid]['env']
                    if product_env.current_week < product_env.weeks:
                        state_before = self._state_vector(product_env)
                        price_factor, order_qty = agent.act(state_before)
                        pre_step[pid] = (state_before, (price_factor, order_qty))
                        actions[pid] = (price_factor * product_env.ref_price, order_qty)
                    else:
                        actions[pid] = (None, None)

                global_state, individual_rewards, done, _ = self.env.step(actions)

                for pid, reward in individual_rewards.items():
                    if pid not in pre_step:
                        continue
                    episode_rewards[pid] += reward
                    state_before, taken_action = pre_step[pid]
                    product_env = self.env.products[pid]['env']
                    self.agents[pid].remember(
                        state_before,
                        taken_action,
                        reward,
                        self._state_vector(product_env),
                        done=product_env.current_week >= product_env.weeks)
                    self.agents[pid].decay_epsilon()

                if done:
                    break

            convergence_signals = []
            for pid, agent in self.agents.items():
                validation_loss = agent.validate()
                if validation_loss < 1000:
                    should_stop = agent.replay(batch_size)
                    convergence_signals.append(should_stop)
                else:
                    print(f"Warning: Agent {pid} validation loss too high ({validation_loss:.2f}), skipping replay")
                    convergence_signals.append(False)

            if all(convergence_signals):
                converged = True

            total_episode_reward = sum(episode_rewards.values())
            self.metrics['total_episode_rewards'].append(total_episode_reward)

            for pid, reward in episode_rewards.items():
                self.rewards_history[pid].append(reward)

            if ep % 20 == 0:
                # calculate_convergence_metrics is not defined in this class;
                # use it when a subclass or later patch provides it, otherwise
                # fall back to a plain progress line instead of raising
                # AttributeError on the very first episode.
                metrics_fn = getattr(self, 'calculate_convergence_metrics', None)
                if callable(metrics_fn):
                    convergence_metrics = metrics_fn()
                    self.convergence_metrics.append(convergence_metrics)
                    status_symbol = "✓" if convergence_metrics['convergence_status'] == "CONVERGED" else "➡" if convergence_metrics['convergence_status'] == "STABILIZING" else "↗"
                    print(f"Episode {ep+1}/{episodes} {status_symbol} | Reward: {total_episode_reward:.2f} | Q-Drift: {convergence_metrics['q_value_drift']:.4f} | Status: {convergence_metrics['convergence_status']}")
                else:
                    print(f"Episode {ep+1}/{episodes} | Reward: {total_episode_reward:.2f}")

            if (ep+1) % eval_every == 0:
                print(f"Episode {ep+1}/{episodes} | Total Reward: {total_episode_reward:.2f}")

        return self.rewards_history, self.env

In [None]:
class AblationStudy:
    """One-factor-at-a-time ablation study around a baseline configuration.

    Each variant is the baseline config with a few keys overridden. Every
    variant is trained/evaluated ``num_trials`` times and compared against the
    baseline on total profit, service level and stockouts.

    Attributes:
        products_data: mapping ``product_id -> per-product data`` handed
            unchanged to ``MultiProductRetailEnv``.
        baseline_config: dict of reference hyper-parameters.
        results: ``config_name -> list of per-trial result dicts``.
    """

    # Config keys that `_run_configuration` actually reads. Any other key is
    # carried in the config dict but never reaches the environment or agents.
    _FORWARDED_KEYS = frozenset({
        'name', 'warehouse_capacity', 'shared_transport_cost', 'input_size',
        'hidden_size', 'price_bins', 'order_bins', 'learning_rate',
        'batch_size', 'target_update_freq', 'service_weight'})

    def __init__(self, products_data, baseline_config):
        self.products_data = products_data
        self.baseline_config = baseline_config
        self.results = {}

    def run_ablation_study(self, num_trials=3, episodes=500):
        """Run every ablation group, print the report and return ``self.results``.

        Args:
            num_trials: independent train/evaluate repetitions per configuration.
            episodes: training episodes per trial.
        """
        print("Starting Ablation Study...")
        print(f"Running {num_trials} trials per configuration")

        self._test_reward_components(num_trials, episodes)
        self._test_learning_parameters(num_trials, episodes)
        self._test_inventory_components(num_trials, episodes)
        self._test_network_architecture(num_trials, episodes)
        self._generate_ablation_report()

        return self.results

    def _test_reward_components(self, num_trials, episodes):
        """Test reward function variations (also runs the baseline itself)."""
        print("\nTesting Reward Function Components...")

        configurations = {
            'baseline': self.baseline_config,
            'no_service_weight': self._modify_config(self.baseline_config, {'service_weight': 0}),
            'high_service_weight': self._modify_config(self.baseline_config, {'service_weight': 100}),
            'no_inventory_penalty': self._modify_config(self.baseline_config, {'inventory_penalty_weight': 0}),
            'no_stockout_penalty': self._modify_config(self.baseline_config, {'stockout_penalty': 0}),
            'no_profit_scaling': self._modify_config(self.baseline_config, {'profit_scaling': False}),}

        self._run_configurations(configurations, num_trials, episodes)

    def _test_learning_parameters(self, num_trials, episodes):
        """Test learning parameter variations."""
        print("\nTesting Learning Parameters...")

        configurations = {
            'high_lr': self._modify_config(self.baseline_config, {'learning_rate': 0.001}),
            'very_low_lr': self._modify_config(self.baseline_config, {'learning_rate': 0.00001}),
            'small_batch': self._modify_config(self.baseline_config, {'batch_size': 32}),
            'large_batch': self._modify_config(self.baseline_config, {'batch_size': 512}),
            'no_target_updates': self._modify_config(self.baseline_config, {'target_update_freq': 1000}),
            'frequent_updates': self._modify_config(self.baseline_config, {'target_update_freq': 1})}

        self._run_configurations(configurations, num_trials, episodes)

    def _test_inventory_components(self, num_trials, episodes):
        """Test inventory management variations."""
        print("\nTesting Inventory Components...")

        configurations = {
            'no_safety_stock': self._modify_config(self.baseline_config, {'safety_stock': 0}),
            'high_safety_stock': self._modify_config(self.baseline_config, {'safety_stock': 30}),
            'no_warehouse_constraints': self._modify_config(self.baseline_config, {'warehouse_constraints': False}),
            'long_lead_time': self._modify_config(self.baseline_config, {'lead_time': 3}),}

        self._run_configurations(configurations, num_trials, episodes)

    def _test_network_architecture(self, num_trials, episodes):
        """Test network architecture variations."""
        print("\nTesting Network Architecture...")

        configurations = {
            'small_network': self._modify_config(self.baseline_config, {'hidden_size': 32, 'num_layers': 1}),
            'large_network': self._modify_config(self.baseline_config, {'hidden_size': 256, 'num_layers': 4}),
            'no_normalization': self._modify_config(self.baseline_config, {'layer_norm': False}),
            'different_optimizer': self._modify_config(self.baseline_config, {'optimizer': 'SGD'}),}

        self._run_configurations(configurations, num_trials, episodes)

    def _run_configurations(self, configurations, num_trials, episodes):
        """Run each named configuration and store its trials in ``self.results``.

        Every variant is a deep copy of the baseline and therefore inherits the
        baseline's ``'name'``; the config is relabelled with its own key here so
        progress output and stored results identify the variant being run.
        """
        for config_name, config in configurations.items():
            labelled = self._modify_config(config, {'name': config_name})
            self.results[config_name] = self._run_configuration(labelled, num_trials, episodes)

    def _modify_config(self, config, modifications):
        """Return a deep copy of ``config`` with ``modifications`` applied."""
        new_config = deepcopy(config)
        new_config.update(modifications)
        return new_config

    def _unused_overrides(self, config):
        """Return the sorted keys of ``config`` that differ from the baseline
        but are not read by `_run_configuration` (and so cannot affect a run)."""
        missing = object()
        return sorted(
            key for key, value in config.items()
            if key not in self._FORWARDED_KEYS
            and self.baseline_config.get(key, missing) != value)

    def _run_configuration(self, config, num_trials, episodes):
        """Train and evaluate ``config`` for ``num_trials`` independent trials.

        Returns:
            A list with one dict per trial holding the ``config``, the
            ``performance`` report, the ``rewards_history`` and the
            coordinator's ``convergence_metrics``.
        """
        label = config.get('name', 'config')

        # Make silent no-op ablations visible instead of reporting them as if
        # the overridden setting had been exercised.
        ignored = self._unused_overrides(config)
        if ignored:
            print(f"   Warning: {label} overrides {ignored}, which are not forwarded to the "
                  f"environment or agents; these settings have no effect on this run")

        # Read once so the agents and the training loop agree on the batch size.
        batch_size = config.get('batch_size', 256)
        trial_results = []

        for trial in range(num_trials):
            print(f"   Trial {trial+1}/{num_trials} for {label}...")

            env = MultiProductRetailEnv(
                self.products_data,
                warehouse_capacity=config.get('warehouse_capacity', 4412),
                shared_transport_cost=config.get('shared_transport_cost', 0.2))

            # One independent agent per product; only the product id is needed here.
            agents = {
                product_id: OptimizedProductAgent(
                    input_size=config.get('input_size', 4),
                    hidden_size=config.get('hidden_size', 64),
                    price_bins=config.get('price_bins', 10),
                    order_bins=config.get('order_bins', 11),
                    lr=config.get('learning_rate', 0.0001),
                    batch_size=batch_size,
                    target_update_freq=config.get('target_update_freq', 5),
                    service_weight=config.get('service_weight', 20))
                for product_id in self.products_data}

            coordinator = AblationCoordinator(env, agents)
            rewards_history, final_env = coordinator.train(
                episodes=episodes,
                eval_every=50,
                batch_size=batch_size)

            results = coordinator.evaluate(num_episodes=5)
            performance_report = self._generate_performance_report(results, final_env)

            trial_results.append({
                'config': config,
                'performance': performance_report,
                'rewards_history': rewards_history,
                'convergence_metrics': coordinator.convergence_metrics})

        return trial_results

    def _generate_performance_report(self, results, env):
        """Aggregate per-product evaluation stats into one summary dict.

        Args:
            results: ``pid -> {'total_profit', 'total_sales', 'weekly_stats'}``
                where ``weekly_stats`` is a list of dicts with an ``'inventory'``
                key and optional ``'service_level'`` / ``'demand'`` keys.
            env: object exposing ``weekly_warehouse_utilization`` (a sequence).

        Returns:
            Dict with ``total_profit``, ``avg_service_level``, ``stockout_weeks``,
            ``inventory_turnover`` and ``warehouse_utilization``. Empty inputs
            yield 0 instead of NaN.
        """
        inventory_metrics = {}

        for pid, stats in results.items():
            weekly = stats['weekly_stats']
            inventory_levels = [w['inventory'] for w in weekly]
            avg_inventory = np.mean(inventory_levels) if inventory_levels else 0
            inventory_turnover = stats['total_sales'] / avg_inventory if avg_inventory > 0 else 0
            # NOTE(review): a week without 'service_level' counts as fully served
            # here (default 1) but as unserved in the average below (default 0).
            stockout_weeks = sum(1 for w in weekly if w.get('service_level', 1) < 0.5)
            avg_service = np.mean([w.get('service_level', 0) for w in weekly]) if weekly else 0

            inventory_metrics[pid] = {
                'avg_inventory': avg_inventory,
                'inventory_turnover': inventory_turnover,
                'stockout_weeks': stockout_weeks,
                'max_inventory': max(inventory_levels) if inventory_levels else 0,
                'min_inventory': min(inventory_levels) if inventory_levels else 0,
                'avg_service': avg_service,
                'avg_demand': np.mean([w.get('demand', 0) for w in weekly]) if weekly else 0}

        total_profit = sum(stats['total_profit'] for stats in results.values())
        per_product = list(inventory_metrics.values())

        return {
            'total_profit': total_profit,
            'avg_service_level': np.mean([m['avg_service'] for m in per_product]) if per_product else 0,
            'stockout_weeks': sum(m['stockout_weeks'] for m in per_product),
            'inventory_turnover': np.mean([m['inventory_turnover'] for m in per_product]) if per_product else 0,
            'warehouse_utilization': np.mean(env.weekly_warehouse_utilization) if env.weekly_warehouse_utilization else 0}

    def _baseline_profit(self):
        """Mean total profit over the baseline trials (requires 'baseline' in results)."""
        return np.mean([trial['performance']['total_profit'] for trial in self.results['baseline']])

    def _compute_impact_analysis(self):
        """Compare every non-baseline configuration against the baseline.

        Returns:
            ``config_name -> {'avg_profit', 'profit_change_pct', 'stability',
            'service_level', 'stockouts'}``. ``profit_change_pct`` is relative to
            the magnitude of the baseline profit, so an improvement is positive
            even when the baseline loses money; it is NaN when the baseline
            profit is exactly zero. ``stability`` is the coefficient of
            variation of profit across trials, in percent.
        """
        baseline_perf = self._baseline_profit()
        impact_analysis = {}

        for config_name, trials in self.results.items():
            if config_name == 'baseline':
                continue

            profits = [trial['performance']['total_profit'] for trial in trials]
            avg_profit = np.mean(profits)

            if baseline_perf != 0:
                # abs() keeps the sign meaningful for a negative baseline.
                profit_change = ((avg_profit - baseline_perf) / abs(baseline_perf)) * 100
            else:
                profit_change = float('nan')

            # 100 remains the fallback only when the mean is exactly zero.
            stability = (np.std(profits) / abs(avg_profit)) * 100 if avg_profit != 0 else 100

            impact_analysis[config_name] = {
                'avg_profit': avg_profit,
                'profit_change_pct': profit_change,
                'stability': stability,
                'service_level': np.mean([trial['performance']['avg_service_level'] for trial in trials]),
                'stockouts': np.mean([trial['performance']['stockout_weeks'] for trial in trials])}

        return impact_analysis

    def _generate_ablation_report(self):
        """Print the grouped results and ranking, then save them to CSV."""
        print("\n" + "="*80)
        print("ABLATION STUDY RESULTS")
        print("="*80)

        baseline_perf = self._baseline_profit()
        impact_analysis = self._compute_impact_analysis()

        print(f"\nBASELINE PERFORMANCE: ${baseline_perf:,.2f}")
        print("\nREWARD FUNCTION COMPONENTS:")
        self._print_category_results(impact_analysis, ['no_service_weight', 'high_service_weight',
                                                     'no_inventory_penalty', 'no_stockout_penalty',
                                                     'no_profit_scaling'])

        print("\nLEARNING PARAMETERS:")
        self._print_category_results(impact_analysis, ['high_lr', 'very_low_lr', 'small_batch',
                                                     'large_batch', 'no_target_updates', 'frequent_updates'])

        print("\nINVENTORY MANAGEMENT:")
        self._print_category_results(impact_analysis, ['no_safety_stock', 'high_safety_stock',
                                                     'no_warehouse_constraints', 'long_lead_time'])

        print("\nNETWORK ARCHITECTURE:")
        self._print_category_results(impact_analysis, ['small_network', 'large_network',
                                                     'no_normalization', 'different_optimizer'])

        self._generate_impact_ranking(impact_analysis)
        self._save_detailed_results(impact_analysis)

    def _print_category_results(self, impact_analysis, config_names):
        """Print one line per configuration in ``config_names`` that has results."""
        for config_name in config_names:
            if config_name in impact_analysis:
                data = impact_analysis[config_name]
                # The '+' format flag keeps the sign attached to the number and
                # gives positive and negative rows the same width.
                print(f"   {config_name:25s}: {data['profit_change_pct']:+7.1f}%  "
                      f"(Service: {data['service_level']:.3f}, Stockouts: {data['stockouts']:.1f})")

    def _generate_impact_ranking(self, impact_analysis):
        """Print all configurations ordered by absolute profit impact, largest first."""
        print("\nPERFORMANCE IMPACT RANKING:")
        print("-" * 80)

        def impact_magnitude(item):
            change = item[1]['profit_change_pct']
            # NaN does not order consistently; rank undefined changes last.
            return -1.0 if np.isnan(change) else abs(change)

        sorted_impact = sorted(impact_analysis.items(), key=impact_magnitude, reverse=True)

        print(f"{'Configuration':30s} {'Impact':>8s} {'Stability':>10s} {'Service':>8s} {'Stockouts':>10s}")
        print("-" * 80)

        for config_name, data in sorted_impact:
            impact_str = f"{data['profit_change_pct']:+.1f}%"
            stability_str = f"{data['stability']:.1f}%"
            service_str = f"{data['service_level']:.3f}"
            stockouts_str = f"{data['stockouts']:.1f}"

            print(f"{config_name:30s} {impact_str:>8s} {stability_str:>10s} {service_str:>8s} {stockouts_str:>10s}")

    def _save_detailed_results(self, impact_analysis):
        """Write the impact table to 'ablation_study_results.csv' in the working directory."""
        results_df = pd.DataFrame.from_dict(impact_analysis, orient='index')
        results_df.to_csv('ablation_study_results.csv')
        print("\nDetailed results saved to 'ablation_study_results.csv'")

class AblationCoordinator:
    """Drives one multi-product environment with one agent per product.

    Attributes:
        env: multi-product environment exposing ``products`` (``pid -> {'env': ...}``),
            ``reset()`` and ``step(actions)``.
        agents: ``pid -> agent`` exposing ``act``, ``remember``, ``decay_epsilon``
            and ``replay``.
        rewards_history: ``pid -> list`` of per-episode summed rewards.
        convergence_metrics: list kept for callers that read it; this class
            never appends to it.
    """

    def __init__(self, env, agents):
        self.env = env
        self.agents = agents
        self.rewards_history = defaultdict(list)
        self.convergence_metrics = []

    def _max_weeks(self):
        """Longest demand horizon across all products (0 when there are none)."""
        return max((len(data['env'].original_demand) for data in self.env.products.values()),
                   default=0)

    def _is_active(self, pid):
        """True while the product's own environment still has weeks left to play."""
        product_env = self.env.products[pid]['env']
        return product_env.current_week < product_env.weeks

    def _agent_state(self, pid):
        """Build the 4-feature state vector fed to the product's agent."""
        state_data = self.env.products[pid]['env']._get_state()
        return [
            state_data['inventory'],
            state_data['current_week'],
            state_data['demand'],
            state_data['avg_demand']]

    def _select_actions(self):
        """Ask every active agent for an action.

        Returns:
            ``(actions, decisions)``. ``actions`` maps every pid to the
            ``(actual_price, order_qty)`` tuple passed to ``env.step``, with
            ``(None, None)`` for finished products. ``decisions`` maps active
            pids to ``(state, (price_factor, order_qty))`` so the transition can
            be stored after the step.
        """
        actions = {}
        decisions = {}
        for pid, agent in self.agents.items():
            if self._is_active(pid):
                state = self._agent_state(pid)
                price_factor, order_qty = agent.act(state)
                # The agent picks a multiplier; the env expects an absolute price.
                actual_price = price_factor * self.env.products[pid]['env'].ref_price
                actions[pid] = (actual_price, order_qty)
                decisions[pid] = (state, (price_factor, order_qty))
            else:
                actions[pid] = (None, None)
        return actions, decisions

    def train(self, episodes=500, eval_every=50, batch_size=32):
        """Run ``episodes`` training episodes and update the agents.

        After every environment step the transition of each product that is
        still active is stored via ``agent.remember`` and its exploration rate
        decayed; after every episode each agent runs ``replay(batch_size)``.

        Args:
            episodes: number of episodes to run.
            eval_every: accepted for interface compatibility; not used.
            batch_size: mini-batch size handed to ``agent.replay``.

        Returns:
            ``(rewards_history, env)``.
        """
        max_weeks = self._max_weeks()

        for ep in range(episodes):
            self.env.reset()
            episode_rewards = defaultdict(float)

            for week in range(max_weeks):
                actions, decisions = self._select_actions()
                global_state, individual_rewards, done, _ = self.env.step(actions)

                for pid, reward in individual_rewards.items():
                    episode_rewards[pid] += reward

                    decision = decisions.get(pid)
                    # The final step of a product is not stored: its next state
                    # is only read while the product still has weeks left, the
                    # same condition that guards _get_state() before acting.
                    if decision is not None and self._is_active(pid):
                        state, action = decision
                        self.agents[pid].remember(state, action, reward, self._agent_state(pid))
                        self.agents[pid].decay_epsilon()

                if done:
                    break

            # Learn from the stored experience once per episode.
            for agent in self.agents.values():
                agent.replay(batch_size)

            for pid, reward in episode_rewards.items():
                self.rewards_history[pid].append(reward)

        return self.rewards_history, self.env

    def evaluate(self, num_episodes=5):
        """Play ``num_episodes`` episodes without learning and average the totals.

        Returns:
            ``pid -> {'total_profit', 'total_demand', 'total_sales', 'weekly_stats'}``.
            The three totals are per-episode averages; ``weekly_stats`` is the
            concatenation of the weekly records of all episodes.

        NOTE(review): actions come from ``agent.act``; whether that still
        explores during evaluation depends on the agent — confirm.
        """
        results = {
            pid: {
                'total_profit': 0,
                'total_demand': 0,
                'total_sales': 0,
                'weekly_stats': []}
            for pid in self.agents}

        max_weeks = self._max_weeks()

        for _ in range(num_episodes):
            self.env.reset()
            for week in range(max_weeks):
                actions, _decisions = self._select_actions()
                global_state, individual_rewards, done, _ = self.env.step(actions)

                for pid in self.agents:
                    if pid in self.env.products:
                        product_env = self.env.products[pid]['env']
                        if product_env.weekly_stats and len(product_env.weekly_stats) > week:
                            results[pid]['weekly_stats'].append(product_env.weekly_stats[week])

                if done:
                    break

            # Read the totals once per episode. They are presumably running
            # totals since reset() (TODO confirm against the product env);
            # adding them after every week would sum the running totals.
            for pid in self.agents:
                if pid in self.env.products:
                    product_env = self.env.products[pid]['env']
                    results[pid]['total_profit'] += product_env.total_profit
                    results[pid]['total_demand'] += product_env.total_demand
                    results[pid]['total_sales'] += product_env.total_sales

        if num_episodes > 0:
            for pid in results:
                results[pid]['total_profit'] /= num_episodes
                results[pid]['total_demand'] /= num_episodes
                results[pid]['total_sales'] /= num_episodes

        return results

if __name__ == "__main__":
    # Reference hyper-parameters that every ablation variant is derived from.
    baseline_config = dict(
        name='baseline',
        learning_rate=1e-4,
        batch_size=256,
        target_update_freq=5,
        service_weight=20,
        hidden_size=64,
        num_layers=2,
        safety_stock=10,
        warehouse_capacity=4412,
        stockout_penalty=5.0,
        inventory_penalty_weight=0.01,
    )

    # Run the full study: 5 trials of 793 training episodes per configuration.
    ablation_study = AblationStudy(products_data, baseline_config)
    results = ablation_study.run_ablation_study(
        num_trials=5,
        episodes=793,
    )

    print("\nAblation study completed successfully!")

Starting Ablation Study...
Running 5 trials per configuration

Testing Reward Function Components...
   Trial 1/5 for baseline...
   Trial 2/5 for baseline...
   Trial 3/5 for baseline...
   Trial 4/5 for baseline...
   Trial 5/5 for baseline...
   Trial 1/5 for baseline...
   Trial 2/5 for baseline...
   Trial 3/5 for baseline...
   Trial 4/5 for baseline...
   Trial 5/5 for baseline...
   Trial 1/5 for baseline...
   Trial 2/5 for baseline...
   Trial 3/5 for baseline...
   Trial 4/5 for baseline...
   Trial 5/5 for baseline...
   Trial 1/5 for baseline...
   Trial 2/5 for baseline...
   Trial 3/5 for baseline...
   Trial 4/5 for baseline...
   Trial 5/5 for baseline...
   Trial 1/5 for baseline...
   Trial 2/5 for baseline...
   Trial 3/5 for baseline...
   Trial 4/5 for baseline...
   Trial 5/5 for baseline...
   Trial 1/5 for baseline...
   Trial 2/5 for baseline...
   Trial 3/5 for baseline...
   Trial 4/5 for baseline...
   Trial 5/5 for baseline...

Testing Learning Parameters.