# Multi-echelon Supply Chain Optimization Using Deep Reinforcement Learning

This notebook demonstrates how replenishment and transportation decisions in a complex multi-echelon supply chain can be optimized using reinforcement learning methods.

### Use Case
We assume a multi-echelon supply chain that serves stochastic time-varying demand. The economic model of the chain includes storage and transportation costs, capacity constraints, and other complexities. Our goal is to learn an optimal policy that controls replenishment and transportation decisions, balancing various costs and stock-out risks/losses.

### Prototype: Approach and Data
We create a relatively advanced simulator (`World of Supply`) of a multi-echelon supply chain that includes multiple inventory management and transportation controls. We then use an off-the-shelf deep reinforcement learning component to learn the control policy.

### Usage and Productization
This prototype is provided mainly for educational and experimentation purposes. Productization of this approach can be challenging because of a relatively large action space (multiple control variables) which results in instability of the learning process and highly irregular control policies.

In [None]:
#
# Imports and settings
#
import importlib

import numpy as np
from tqdm import tqdm

import world_of_supply_tools as wst

# Re-execute the helper module so the kernel uses its current source
importlib.reload(wst)

wst.print_hardware_status()

# Core Simulation Logic and Rendering

In this section, we test the core simulator and renderer (without RL adapters and integrations).

In [None]:
import world_of_supply_environment as ws
import world_of_supply_renderer as wsr
import world_of_supply_tools as wst
for module in (ws, wsr, wst):
    importlib.reload(module)

# Measure the simulation rate, steps/sec (the tqdm progress bar displays the iteration rate)
eposod_len = 1000
n_episods = 10
# A throwaway world is built first only to obtain the facility keys for the tracker
world = ws.WorldBuilder.create()
tracker = wst.SimulationTracker(eposod_len, n_episods, world.facilities.keys())
with tqdm(total=eposod_len * n_episods) as progress:
    for episode in range(n_episods):
        # Every episode starts from a freshly created world and policy
        world = ws.WorldBuilder.create()
        policy = ws.SimpleControlPolicy()
        for step in range(eposod_len):
            control = policy.compute_control(world)
            outcome = world.act(control)
            global_total = world.economy.global_balance().total()
            facility_totals = {
                facility_id: sheet.total()
                for facility_id, sheet in outcome.facility_step_balance_sheets.items()
            }
            tracker.add_sample(episode, step, global_total, facility_totals)
            progress.update(1)
tracker.render()
    
# Test rendering: collect 300 frames of a world driven by the simple control policy
renderer = wsr.AsciiWorldRenderer()
frame_seq = []
world = ws.WorldBuilder.create()
policy = ws.SimpleControlPolicy()
for epoch in tqdm(range(300)):
    # The frame is captured before the world is advanced by one step
    frame_seq.append(np.asarray(renderer.render(world)))
    control = policy.compute_control(world)
    world.act(control)

print('Rendering the animation...')
wsr.AsciiWorldRenderer.plot_sequence_images(frame_seq)

# Policy Training

In this section, we run RLlib policy trainers. These trainers evaluate the hand-coded policy, learn a new policy from scratch, or learn a new policy by playing against the hand-coded policy.

In [None]:
# Each RLlib-related project module is reloaded right after it is imported,
# in the order models -> environment/policies -> training.
import world_of_supply_rllib_models as wsm
importlib.reload(wsm)
import world_of_supply_rllib as wsrl
importlib.reload(wsrl)
import world_of_supply_rllib_training as wsrt
importlib.reload(wsrt)

wsrt.print_model_summaries()

# Policy training
# The resulting `trainer` is read by the evaluation cell when policy_mode == 'trained'.
# Alternative run (presumably plays the hand-coded baseline policy; confirm in
# world_of_supply_rllib_training before enabling):
#trainer = wsrt.play_baseline(n_iterations = 1)
trainer = wsrt.train_ppo(n_iterations = 600)

# Policy Evaluation

In this section, we evaluate the trained policy.

### Rendering One Episode for the Trained Policy

In [None]:
import world_of_supply_renderer as wsren
import world_of_supply_tools as wst
import world_of_supply_rllib as wsrl
import world_of_supply_rllib_training as wstr
for module in (wsren, wst, wsrl, wstr):
    importlib.reload(module)

# Parameters of the tracing simulation
policy_mode = 'baseline'   # 'baseline' or 'trained'
episod_duration = 500
# Range of steps to capture as frames: (0, episod_duration), or None to skip rendering
steps_to_render = None

# Create the environment
renderer = wsren.AsciiWorldRenderer()
frame_seq = []
# Work on a copy of the training configuration, overriding only the downsampling rate
env_config_for_rendering = wstr.env_config.copy()
env_config_for_rendering['downsampling_rate'] = 1
env = wsrl.WorldOfSupplyEnv(env_config_for_rendering)
env.set_iteration(1, 1)
print(
    f"Environment: Producer action space {env.action_space_producer}, "
    f"Consumer action space {env.action_space_consumer}, "
    f"Observation space {env.observation_space}"
)
states = env.reset()
# No info dictionaries exist until the first environment step has been taken
infos = None
    
def load_policy(agent_id):
    """Build the policy object that controls the given agent.

    The kind of policy is selected by the notebook-level ``policy_mode``:

    * ``'baseline'`` - a simple producer or consumer policy from ``wsrl``,
      chosen according to the agent type;
    * ``'trained'`` - the policy held by ``trainer`` for this agent, resolved
      through the policy mapping provided by ``wstr``.

    Args:
        agent_id: Identifier of the agent to create the policy for.

    Returns:
        The policy object for the agent.

    Raises:
        ValueError: If ``policy_mode`` is not one of the supported modes, or if
            the agent is neither a producer nor a consumer in baseline mode.
    """
    if policy_mode == 'baseline':
        if wsrl.Utils.is_producer_agent(agent_id):
            return wsrl.ProducerSimplePolicy(env.observation_space, env.action_space_producer, wsrl.SimplePolicy.get_config_from_env(env))
        elif wsrl.Utils.is_consumer_agent(agent_id):
            return wsrl.ConsumerSimplePolicy(env.observation_space, env.action_space_consumer, wsrl.SimplePolicy.get_config_from_env(env))
        else:
            raise ValueError(f'Unknown agent type {agent_id}')

    if policy_mode == 'trained':
        # Update a copy so that the module-level mapping is left untouched
        policy_map = wstr.policy_mapping_global.copy()
        wstr.update_policy_map(policy_map)
        return trainer.get_policy(wstr.create_policy_mapping_fn(policy_map)(agent_id))

    # Fail loudly instead of implicitly returning None, which would only surface
    # later as an AttributeError when the policy is first used.
    raise ValueError(f"Unknown policy mode {policy_mode!r}; expected 'baseline' or 'trained'")

# Instantiate one policy per agent together with its initial recurrent state
policies = {}
rnn_states = {}
for agent_id in states.keys():
    policies[agent_id] = load_policy(agent_id)
    rnn_states[agent_id] = policies[agent_id].get_initial_state()

# Simulation loop
tracker = wst.SimulationTracker(episod_duration, 1, env.agent_ids())
for epoch in tqdm(range(episod_duration)):

    # Actions are recomputed only on every N-th step, where N is the downsampling
    # rate of the training configuration; on all other steps the environment is
    # stepped with an empty action dictionary.
    action_dict = {}
    if epoch % wstr.env_config['downsampling_rate'] == 0:
        for agent_id, state in states.items():
            policy = policies[agent_id]
            # Pass the agent's info only once the environment has produced one
            extra_args = {}
            if infos is not None and agent_id in infos:
                extra_args['info'] = infos[agent_id]
            # Store the returned recurrent state so that the next call continues
            # from it instead of restarting from the initial state.
            action_dict[agent_id], rnn_states[agent_id], _ = policy.compute_single_action(
                state, state=rnn_states[agent_id], **extra_args
            )

    states, rewards, dones, infos = env.step(action_dict)
    tracker.add_sample(0, epoch, env.world.economy.global_balance().total(), rewards)

    # Capture frames only for steps inside the half-open range [start, end)
    if steps_to_render is not None and steps_to_render[0] <= epoch < steps_to_render[1]:
        frame = renderer.render(env.world)
        frame_seq.append(np.asarray(frame))

tracker.render()

if steps_to_render is not None:
    print('Rendering the animation...')
    wsren.AsciiWorldRenderer.plot_sequence_images(frame_seq)