# Multi Agent Environments for Vehicle Routing Problems

---
# Exploring MAEnvs4VRP library
### April / 2025

---

### Install

Uncomment the following cells:

In [None]:
# !git clone https://github.com/ricgama/maenvs4vrp_beta.git # When using Colab

In [None]:
# When using Colab
# %cd maenvs4vrp_beta/
# ! pip install -e .
#%cd maenvs4vrp/notebooks/

In [None]:
# When using Binder
#%cd ../../
#! pip install -e . 

The objective of this notebook is to guide the user through the **MAEnvs4VRP** library by means of a series of small hands-on coding challenges.

In [None]:
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

## Basic API usage example:

We will start with the Team Orienteering Problem with Time Windows environment.

In [None]:
from maenvs4vrp.environments.toptw.env import Environment
from maenvs4vrp.environments.toptw.env_agent_selector import AgentSelector
from maenvs4vrp.environments.toptw.observations import Observations
from maenvs4vrp.environments.toptw.instances_generator import InstanceGenerator
from maenvs4vrp.environments.toptw.env_agent_reward import DenseReward

In [None]:
gen = InstanceGenerator(batch_size = 8)
obs = Observations()
sel = AgentSelector()
rew = DenseReward()

env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

One important Environment attribute is `env.td_state`. Before `reset`: 

In [None]:
env.td_state

In [None]:
td = env.reset(batch_size = 8, num_agents=4, num_nodes=16)

After `reset` the `env.td_state` changes to:

In [None]:
env.td_state

Also, on the `td` we have:

In [None]:
td

In [None]:
td["done"]

Let's run an episode:

In [None]:
while not td["done"].all():  
    td = env.sample_action(td) # this is where we insert our policy
    td = env.step(td)
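
The comment in the loop marks the place where a policy would replace `env.sample_action`. As a rough, library-independent sketch of what a random policy has to do with an action mask, the snippet below picks one admissible node per batch element. The function name `sample_masked_action` and the toy masks are hypothetical and are not part of MAEnvs4VRP; the library's own sampling code may work differently.

```python
import random


def sample_masked_action(action_mask, rng):
    """Pick one admissible node index uniformly at random from a boolean mask."""
    admissible = [node for node, ok in enumerate(action_mask) if ok]
    if not admissible:
        raise ValueError("no admissible action in mask")
    return rng.choice(admissible)


rng = random.Random(0)

# Hypothetical batch of two masks over five nodes (node 0 plays the depot).
batch_masks = [
    [True, False, True, True, False],
    [True, False, False, False, False],
]

# One action per batch element, always chosen among the admissible nodes.
actions = [sample_masked_action(mask, rng) for mask in batch_masks]
print(actions)
```

A learned policy would replace the uniform choice with a distribution over the admissible nodes, while still respecting the mask.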

## Quick walkthrough

Let's now go through the library's building blocks, exploring their functionalities.

### Instance generation

In [None]:
instance = gen.sample_instance(num_agents=2, num_nodes=10)

In [None]:
instance.keys()

It's possible to load a set of pre-generated instances, to be used as validation/test sets. For example:

In [None]:
gen.get_list_of_benchmark_instances()['servs_100_agents_5']['validation']

In [None]:
set_of_instances = set(gen.get_list_of_benchmark_instances()['servs_100_agents_5']['validation'])

In [None]:
generator = InstanceGenerator(instance_type='validation', set_of_instances=set_of_instances)

In [None]:
instance = generator.sample_instance()

Let's check instance dict keys:

In [None]:
instance.keys()

In [None]:
instance['name']

#### Benchmark instances

In [None]:
from maenvs4vrp.environments.toptw.benchmark_instances_generator import BenchmarkInstanceGenerator

To narrow the current gap between the test beds used for algorithm benchmarking in the RL
and OR communities, the library allows a straightforward integration of classical OR benchmark
instances. Let's see which benchmark instances are available for the TOPTW:

In [None]:
BenchmarkInstanceGenerator.get_list_of_benchmark_instances()

In [None]:
generator = BenchmarkInstanceGenerator(instance_type='Solomon', set_of_instances={'c101', 'c102'})

In [None]:
instance_c101 = generator.get_instance('c101')

In [None]:
instance_c101.keys()

In [None]:
instance_c101['name']

In [None]:
instance_c101['num_agents']

In [None]:
instance_c101['num_nodes']

###  Observations

The observation features available to the active agent while it interacts with the environment are handled by the `Observations` class. 
The class has a `default_feature_list` attribute where the default configuration dictionary is defined.

In [None]:
obs.default_feature_list

Also, five lists of possible features exist, detailing the features available in the class: `POSSIBLE_NODES_STATIC_FEATURES`, `POSSIBLE_NODES_DYNAMIC_FEATURES`, `POSSIBLE_SELF_FEATURES`, `POSSIBLE_AGENTS_FEATURES`, `POSSIBLE_GLOBAL_FEATURES`. For example:

In [None]:
obs.POSSIBLE_NODES_STATIC_FEATURES

In [None]:
obs.POSSIBLE_GLOBAL_FEATURES

When instantiating the `Observations` class, we can pass a feature list dictionary specifying which features will be available to the agent:

In [None]:
import yaml

In [None]:
feature_list = yaml.safe_load("""
    nodes_static:
        x_coordinate_min_max:
            feat: x_coordinate_min_max
            norm: min_max
        y_coordinate_min_max:
            feat: y_coordinate_min_max
            norm: min_max
        tw_low_mm:
            feat: tw_low
            norm: min_max
        tw_high:
            feat: tw_high
            norm: min_max

    nodes_dynamic:
        - time2open_div_end_time
        - time2close_div_end_time
        - time2open_after_step_div_end_time
        - time2close_after_step_div_end_time
        - fract_time_after_step_div_end_time

    agent:
        - x_coordinate_min_max
        - y_coordinate_min_max
        - frac_current_time

    other_agents:
        - x_coordinate_min_max
        - y_coordinate_min_max
        - frac_current_time
        - dist2agent_div_end_time
    
    global:
        - frac_done_agents
        - frac_colect_profits
""")

In [None]:
obs = Observations(feature_list)
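
The `norm: min_max` entries above request min-max normalization of a feature. As a plain-Python illustration, assuming the usual definition $(x - x_{min}) / (x_{max} - x_{min})$ (the library's exact implementation may differ, and the helper name `min_max` is hypothetical):

```python
def min_max(values):
    """Rescale a list of numbers to the [0, 1] interval."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Degenerate case: all values equal, so map everything to 0.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


# Hypothetical x coordinates of three nodes.
x_coords = [10.0, 25.0, 40.0]
scaled = min_max(x_coords)
print(scaled)  # [0.0, 0.5, 1.0]
```

Normalizing in this way keeps features on a comparable scale across instances of different sizes.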

We can test these observations on the environment:

In [None]:
gen = InstanceGenerator(batch_size=8)
sel = AgentSelector()
rew = DenseReward()

env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

In [None]:
td = env.reset(batch_size = 8, num_agents=4, num_nodes=16)

In [None]:
td_observation = env.observe()

In [None]:
td_observation

Let's run an episode:

In [None]:
while not td["done"].all():  
    td = env.sample_action(td) # this is where we insert our policy
    td = env.step(td)

and check the collected profits:

In [None]:
env.td_state['agents']['cum_profit'].sum(-1)

An environment with agents performing random actions is not very impressive. Let's train a policy with the [PPO algorithm](https://spinningup.openai.com/en/latest/algorithms/ppo.html) to get smarter agents:

In [None]:
# When using Binder
#%cd maenvs4vrp/notebooks/
# When using Colab
#%cd maenvs4vrp_beta/maenvs4vrp/learning

In [None]:
%run ../learning/train_ma_ppo.py --vrp_env toptw --num_agents 4 --num_nodes 21

## Challenges

In [None]:
# when using Colab
#%cd ../notebooks/

### Ex0. Warm-up

Ok! Let's now try some small hands-on coding challenges. To simplify solution verification, allowing a pen and paper check, let's use some small toy instances.

In [None]:
from maenvs4vrp.environments.toptw.toy_instance_generator import ToyInstanceGenerator

In [None]:
gen = ToyInstanceGenerator()
obs = Observations()
sel = AgentSelector()
rew = DenseReward()

env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

In [None]:
td = env.reset()

The service and depot locations are:

In [None]:
fig = plt.figure(figsize=(3,3))
plt.plot(env.td_state['coords'][0][:,0].numpy(), env.td_state['coords'][0][:,1].numpy(), 'o')
plt.plot(env.td_state['coords'][0][0,0].numpy(), env.td_state['coords'][0][0,1].numpy(), 'o', color='red' )

and the time windows:

In [None]:
for k, data in enumerate(zip(env.td_state['tw_low'][0].tolist(), env.td_state['tw_high'][0].tolist())):
    print(f'node {k} time window is: [{data[0]}; {data[1]}]')

All the agents start at the depot (node 0 / red dot). The distance (time) from the depot to all the nodes is:

In [None]:
loc = env.td_state['coords'].gather(1, env.td_state['cur_agent']['cur_node'][:,:,None].expand(-1, -1, 2))
time2j = torch.pairwise_distance(loc, env.td_state["coords"], eps=0, keepdim = False)
time2j[0]
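
For a pen-and-paper check, the same quantity can be computed by hand: travel time is taken to be the Euclidean distance between the agent's current node and every other node. A minimal sketch with made-up coordinates (the values below are hypothetical, not the toy instance's):

```python
import math

# Hypothetical coordinates; index 0 plays the role of the depot.
coords = [(0.0, 0.0), (3.0, 4.0), (6.0, 8.0)]
depot = coords[0]

# Euclidean distance (used as travel time) from the depot to every node.
time2j = [math.dist(depot, node) for node in coords]
print(time2j)  # [0.0, 5.0, 10.0]
```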

I) If the agent chooses to visit node 1, what will be the collected profit? 

II) Checking the previous distance values, time windows and the new distances, what will be the mask of the admissible nodes after this step?

(hint: check the `env.td_state` attribute.)

In [None]:
td['action'] = torch.tensor([[1]])

In [None]:
td = env.step(td)

In [None]:
# %load snippets/ex0.py
# your code here!!

Now, let's move on to exploring the `Observations` module:

### Ex1. Team Orienteering Problem with Time Windows - Observations

In [None]:
from maenvs4vrp.environments.toptw.env import Environment
from maenvs4vrp.environments.toptw.env_agent_selector import AgentSelector
from maenvs4vrp.environments.toptw.env_agent_reward import DenseReward

One important aspect of the agents' training is their ability to retrieve useful information from the environment in order to act on it. In MAEnvs4VRP we can build our own custom observation methods within the `Observations` class.

In [None]:
from maenvs4vrp.environments.toptw.observations import Observations

In [None]:
gen = ToyInstanceGenerator()
sel = AgentSelector()
rew = DenseReward()

In [None]:
obs = Observations()

The class has a `default_feature_list` attribute where the default configuration dictionary is defined.

In [None]:
obs.default_feature_list

Also, five lists of possible features exist, detailing the features available in the class: `possible_nodes_static_features`, `possible_nodes_dynamic_features`, `possible_agent_features`, `possible_agents_features`, `possible_global_features`. For example:

In [None]:
obs.possible_nodes_dynamic_features

Let's see how to add another nodes dynamic observation.

I) Change the code below in order to implement the nodes dynamic feature `wait_time_div_end_time`:

In [None]:
# %load snippets/ex1.py
class Observations(Observations):
    
    def __init__(self, feature_list:dict = None):
        super().__init__()
        
        self.default_feature_list['nodes_dynamic'].append('wait_time_div_end_time')
        self.possible_nodes_dynamic_features.append('wait_time_div_end_time')
    
    def get_feat_wait_time_div_end_time(self):
        """ dynamic feature
        Args:

        Returns: 
            Tensor: waiting time at nodes divided by end time.
        """
        loc = self.env.td_state['coords'].gather(1, self.env.td_state['cur_agent']['cur_node'][:,:,None].expand(-1, -1, 2))
        ptime = self.env.td_state['cur_agent']['cur_time'].clone()
        time2j = torch.pairwise_distance(loc, self.env.td_state["coords"], eps=0, keepdim = False)
        #arrivej = !! your code here !!
        #wait = !! your code here !!
        return wait / self.env.td_state['end_time'].unsqueeze(dim=-1)
    

In [None]:
obs = Observations()

We can re-check the possible nodes dynamic features available:

In [None]:
obs.possible_nodes_dynamic_features

and the ones that are going to be used by the agent:

In [None]:
obs.default_feature_list['nodes_dynamic']

Ok! Now, let's create the `TOPTW` environment:

In [None]:
env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

II) Check if your answer is correct, by running a couple of environment steps.

Note: the observation feature will be on the `obs.default_feature_list['nodes_dynamic'].index('wait_time_div_end_time')` position of the `node_dynamic_obs` tensor.

In [None]:
# %load snippets/ex2.py
# check the new feature position here

In [None]:
td = env.reset()

In [None]:
# %load snippets/ex3.py
#check the nodes dynamic observations on the td 

Let's choose a node to move to and perform an env step (change the number `3` to any other option):

In [None]:
td['action'] = torch.tensor([[3]])

In [None]:
td = env.step(td)

In [None]:
# your code here

In [None]:
# your code here

III) Think of another potentially useful observation feature for this environment. Implement and test it.

In [None]:
# your code here

In [None]:
# your code here

### Ex2. Split Delivery Vehicle Routing Problem with Time Windows (SDVRPTW)

The Split Delivery Vehicle Routing Problem with Time Windows (SDVRPTW) is a generalization of the CVRPTW where each customer can be visited more than once by several vehicles and a fraction of the demand can be met.
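
To make the "fraction of the demand" idea concrete, here is a scalar sketch of one common way to model a split delivery: the vehicle hands over as much as it can, and any unmet demand stays at the customer for a later visit. The function name `split_delivery` and the numbers are hypothetical; this is an illustration of the concept, not the library's implementation.

```python
def split_delivery(cur_load, cur_demand):
    """Return (load_transfer, new_load, remaining_demand) for one visit."""
    # The vehicle cannot deliver more than it carries or more than is needed.
    load_transfer = min(cur_load, cur_demand)
    return load_transfer, cur_load - load_transfer, cur_demand - load_transfer


# The vehicle carries less than the demand: the customer is only partly served.
print(split_delivery(5.0, 8.0))   # (5.0, 0.0, 3.0)
# The vehicle carries more than the demand: the customer is fully served.
print(split_delivery(10.0, 4.0))  # (4.0, 6.0, 0.0)
```

In the plain CVRPTW, by contrast, a customer is only admissible when its whole demand fits in the vehicle's remaining load.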

We are going to walk through the changes that we have to perform on the CVRPTW environment to obtain SDVRPTW, but first let's check the CVRPTW environment.

In [None]:
from maenvs4vrp.environments.cvrptw.toy_instance_generator import ToyInstanceGenerator

In [None]:
gen = ToyInstanceGenerator()
inst = gen.sample_instance()

The service and depot locations are:

In [None]:
fig = plt.figure(figsize=(3,3))
plt.plot(inst['data']['coords'][0][:,0].numpy(), inst['data']['coords'][0][:,1].numpy(), 'o')
plt.plot(inst['data']['coords'][0][0,0].numpy(), inst['data']['coords'][0][0,1].numpy(), 'o', color='red' )

With time windows and demands:

In [None]:
for k, data in enumerate(zip(inst['data']['tw_low'][0].tolist(), inst['data']['tw_high'][0].tolist(), inst['data']['demands'][0].tolist())):
    print(f'node {k} time window is: [{data[0]}; {data[1]}], with demand {data[2]}')

In [None]:
from maenvs4vrp.environments.cvrptw.env import Environment
from maenvs4vrp.environments.cvrptw.env_agent_selector import AgentSelector
from maenvs4vrp.environments.cvrptw.env_agent_reward import DenseReward
from maenvs4vrp.environments.cvrptw.observations import Observations

In [None]:
gen = ToyInstanceGenerator()
sel = AgentSelector()
rew = DenseReward()
obs = Observations()

In [None]:
env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

In [None]:
td = env.reset()

I) Check which agent is active and which actions are admissible for it.

Note: on the `td`, access the `cur_agent_idx` and `observations`/`action_mask` keys.

In [None]:
# %load snippets/ex4.py
# your code here

This information is also available by accessing the environment `td_state` attribute on the `cur_agent` key.

In [None]:
env.td_state['cur_agent']

In [None]:
env.td_state['cur_agent']['cur_load']

Let's choose to serve node `2`:

In [None]:
td['action'] = torch.tensor([[2]])
td = env.step(td)

In [None]:
action = torch.tensor([[2]])

II) What should the new `cur_load` and `action_mask` be? Check your answer. 

In [None]:
# %load snippets/ex5.py
# your code here

III) What happens if the agent tries to serve node `1`?

In [None]:
# %load snippets/ex6.py
# your code here

OK. Now, let's see what changes are needed to the CVRPTW environment to obtain the SDVRPTW. We will need to adapt the `_update_feasibility` and `_update_state` methods of the `Environment` class. Everything else will be the same.

In [None]:
from maenvs4vrp.environments.sdvrptw.toy_instance_generator import ToyInstanceGenerator
from maenvs4vrp.environments.sdvrptw.env import Environment
from maenvs4vrp.environments.sdvrptw.env_agent_selector import AgentSelector
from maenvs4vrp.environments.sdvrptw.env_agent_reward import DenseReward
from maenvs4vrp.environments.sdvrptw.observations import Observations

In [None]:
gen = ToyInstanceGenerator()
sel = AgentSelector()
rew = DenseReward()
obs = Observations()

Let's start with the `_update_feasibility` method:

In [None]:
# %load snippets/ex7.py
# your code here
class Environment(Environment):

    def _update_feasibility(self):

        _mask = self.td_state['nodes']['active_nodes_mask'].clone() * self.td_state['cur_agent']['action_mask'].clone()

        # time windows constraints
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:,:,None].expand(-1, -1, 2))
        ptime = self.td_state['cur_agent']['cur_time'].clone()
        time2j = torch.pairwise_distance(loc, self.td_state["coords"], eps=0, keepdim = False)
        if self.n_digits is not None:
            time2j = torch.floor(self.n_digits * time2j) / self.n_digits
        arrivej = ptime + time2j
        waitj = torch.clip(self.td_state['tw_low']-arrivej, min=0)
        service_startj = arrivej + waitj

        c1 = service_startj <= self.td_state['tw_high']
        c2 = service_startj + self.td_state['service_time'] + self.td_state['time2depot'] <= self.td_state['end_time'].unsqueeze(-1)

        # capacity constraints (if there is no load, the agent can only return to the depot)
        c3 = torch.ones_like(_mask, dtype=torch.bool, device=self.device)
        #c3[self.td_state['cur_agent']['cur_load'].le(0).squeeze(-1)] = !!your code here!!
        #c3[self.td_state['cur_agent']['cur_load'].le(0).squeeze(-1), self.td_state['depot_idx']] = !!your code here!!

        _mask = _mask * c1 * c2 * c3
        # update state
        self.td_state['cur_agent'].update({'action_mask': _mask}) 
        self.td_state['agents']['feasible_nodes'].scatter_(1, 
                                            self.td_state['cur_agent_idx'][:,:,None].expand(-1,-1,self.num_nodes), _mask.unsqueeze(1))
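
The time-window part of the method above (constraints `c1` and `c2`) can be checked by hand with a scalar version of the same arithmetic. The sketch below mirrors those lines for a single candidate node; the function name `time_window_feasible` and the numbers are hypothetical, and the capacity constraint `c3` is deliberately left out.

```python
def time_window_feasible(cur_time, travel_time, tw_low, tw_high,
                         service_time, time2depot, end_time):
    """Scalar version of the c1 and c2 checks for one candidate node."""
    arrive = cur_time + travel_time
    wait = max(tw_low - arrive, 0.0)          # wait if we arrive before opening
    service_start = arrive + wait
    c1 = service_start <= tw_high             # service starts before closing
    c2 = service_start + service_time + time2depot <= end_time  # can return in time
    return c1 and c2


# Arrive at 5, wait until 8, serve, and still reach the depot by 20.
print(time_window_feasible(0.0, 5.0, 8.0, 12.0, 1.0, 5.0, 20.0))   # True
# Arrive at 15, after the window closes at 12.
print(time_window_feasible(10.0, 5.0, 8.0, 12.0, 1.0, 5.0, 30.0))  # False
```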

Now the `_update_state` method:

In [None]:
# %load snippets/ex8.py
class Environment(Environment):

    def _update_state(self, action):
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:,:,None].expand(-1, -1, 2))
        next_loc = self.td_state['coords'].gather(1, action[:,:,None].expand(-1, -1, 2))

        ptime = self.td_state['cur_agent']['cur_time'].clone()
        time2j = torch.pairwise_distance(loc, next_loc, eps=0, keepdim = False)
        if self.n_digits is not None:
            time2j = torch.floor(self.n_digits * time2j) / self.n_digits
        tw = self.td_state['tw_low'].gather(1, action)
        service_time = self.td_state['service_time'].gather(1, action)

        arrivej = ptime + time2j
        waitj = torch.clip(tw-arrivej, min=0)

        time_update = arrivej + waitj + service_time
        # update agent cur node
        self.td_state['cur_agent']['cur_node'] = action
        self.td_state['agents']['cur_node'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_node'])
        # update agent cur time
        self.td_state['cur_agent']['cur_time'] = time_update

        # if agent is done, set agent time to end_time
        agents_done = ~self.td_state['agents']['active_agents_mask'].gather(1, self.td_state['cur_agent_idx']).clone()
        self.td_state['cur_agent']['cur_time'] = torch.where(agents_done, self.td_state['end_time'].unsqueeze(-1), 
                                                             self.td_state['cur_agent']['cur_time'])
        self.td_state['agents']['cur_time'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_time'])

        # update agent cum traveled time
        self.td_state['cur_agent']['cur_ttime'] = time2j
        self.td_state['cur_agent']['cum_ttime'] += time2j
        self.td_state['agents']['cur_ttime'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_ttime'])
        self.td_state['agents']['cum_ttime'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cum_ttime'])
        
        # update agent load and node demands
        #cur_demands = !!your code here!!
        #current_load = !!your code here!!
        #load_transfer =  !!your code here!!
        self.td_state['cur_agent']['cur_load'] -= load_transfer

        # if agent is done set agent cur_load to 0
        self.td_state['cur_agent']['cur_load'] = torch.where(agents_done, 0., 
                                                             self.td_state['cur_agent']['cur_load'])
        
        self.td_state['nodes']['cur_demands'].scatter_(1, action, cur_demands-load_transfer)
        # update done nodes
        self.td_state['nodes']['active_nodes_mask'] = self.td_state['nodes']['cur_demands'].gt(0)
        self.td_state['nodes']['active_nodes_mask'].scatter_(1, self.td_state['depot_idx'], True)

        self.td_state['agents']['cur_load'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_load'])
        # update visited nodes
        r = torch.arange(*self.td_state.batch_size, device=self.device)
        self.td_state['agents']['visited_nodes'][r, self.td_state['cur_agent_idx'].squeeze(-1), action.squeeze(-1)] = True
        # update agent step
        self.td_state['cur_agent']['cur_step'] = torch.where(~agents_done, self.td_state['cur_agent']['cur_step']+1, 
                                                             self.td_state['cur_agent']['cur_step'])
        self.td_state['agents']['cur_step'].scatter_(1, self.td_state['cur_agent_idx'], self.td_state['cur_agent']['cur_step'])

        # if all done activate first agent to guarantee batch consistency during agent sampling
        self.td_state['agents']['active_agents_mask'][self.td_state['agents']['active_agents_mask'].sum(1).eq(0), 0] = True
        self._update_feasibility()


Let's test the environment and repeat the steps we performed for the CVRPTW:

In [None]:
env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

In [None]:
td = env.reset()

In [None]:
td['action'] = torch.tensor([[2]])
td = env.step(td)

In [None]:
env.td_state['cur_agent']['action_mask']

In [None]:
env.td_state['cur_agent']['cur_load']

IV) What happens if the agent now goes to node `1`?

In [None]:
# %load snippets/ex9.py
# your code here

V) What should the new `cur_load`, `action_mask` and nodes `cur_demands` be? Check your answer. 

In [None]:
# %load snippets/ex10.py
# your code here

It seems to be working!

### Ex3. Capacitated Vehicle Routing Problem with Soft Time Windows (CVRPSTW)

In this variation of the CVRPTW, time window constraints are relaxed and can be violated at a penalty cost (usually linearly proportional to the interval between the opening/closing times and the vehicle's arrival). Although the penalty function can be defined in several ways, we consider the formulation studied in [M. A. Figliozzi](https://www.sciencedirect.com/science/article/abs/pii/S0968090X09001119). 
Concretely, the time window violation cannot exceed $P_{max}$, and consequently, for each customer, we can enlarge its time window to $[o_i - P_{max}, c_i + P_{max}] = [o^s_i , c^s_i]$ outside which the service cannot be performed. When a vehicle arrives at a customer at time $t_i \in [o^s_i , c^s_i]$, it can have an early arrival penalty cost of $p_e \max (o_i-t_i,0)$ and a late arrival penalty cost of $p_l \max (t_i-c_i, 0)$.

Furthermore, the vehicle's maximum waiting time at any customer, $W_{max}$, is imposed. That is, the vehicles can only arrive at each customer after $o_i - P_{max} - W_{max}$, so that its waiting time doesn't exceed $W_{max}$.
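
The formulas above can be evaluated for a single customer as follows. This is a scalar sketch under the definitions in the text; the helper names `soft_tw_penalty` and `within_enlarged_window`, as well as all numeric values, are hypothetical and do not come from the library.

```python
def soft_tw_penalty(t, o, c, p_e, p_l):
    """Penalty p_e * max(o - t, 0) + p_l * max(t - c, 0) for arrival time t."""
    return p_e * max(o - t, 0.0) + p_l * max(t - c, 0.0)


def within_enlarged_window(t, o, c, p_max):
    """True if t lies in the enlarged window [o - P_max, c + P_max]."""
    return o - p_max <= t <= c + p_max


# Hypothetical customer: window [10, 20], P_max = 2, W_max = 3.
o, c, p_max, w_max = 10.0, 20.0, 2.0, 3.0
p_e, p_l = 1.0, 2.0

# Earliest admissible arrival, so that waiting never exceeds W_max.
earliest_arrival = o - p_max - w_max
print(earliest_arrival)                        # 5.0

print(soft_tw_penalty(9.0, o, c, p_e, p_l))    # 1.0 (one time unit early)
print(soft_tw_penalty(21.5, o, c, p_e, p_l))   # 3.0 (1.5 time units late)
print(soft_tw_penalty(15.0, o, c, p_e, p_l))   # 0.0 (inside the original window)
print(within_enlarged_window(23.0, o, c, p_max))  # False
```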

The environment for this problem is already almost complete. Compared to the base CVRPTW environment, the `early_penalty` and `late_penalty` attributes were added to the environment, and the `tw_low_limit`, `tw_high_limit` and `arrive_limit` attributes were added to `td_state`.

In [None]:
from maenvs4vrp.environments.cvrpstw.env import Environment
from maenvs4vrp.environments.cvrpstw.env_agent_selector import AgentSelector
from maenvs4vrp.environments.cvrpstw.observations import Observations
from maenvs4vrp.environments.cvrpstw.toy_instance_generator import ToyInstanceGenerator
from maenvs4vrp.environments.cvrpstw.env_agent_reward import DenseReward

In [None]:
gen = ToyInstanceGenerator()
sel = AgentSelector()
rew = DenseReward()
obs = Observations()

I) Complete the `_update_feasibility` method in order to take into account the waiting time constraint:

In [None]:
# %load snippets/ex11.py
class Environment(Environment):
 
    def _update_feasibility(self):

        _mask = self.td_state['nodes']['active_nodes_mask'].clone() * self.td_state['cur_agent']['action_mask'].clone()

        # time windows constraints
        loc = self.td_state['coords'].gather(1, self.td_state['cur_agent']['cur_node'][:,:,None].expand(-1, -1, 2))
        ptime = self.td_state['cur_agent']['cur_time'].clone()
        time2j = torch.pairwise_distance(loc, self.td_state["coords"], eps=0, keepdim = False)
        if self.n_digits is not None:
            time2j = torch.floor(self.n_digits * time2j) / self.n_digits

        arrivej = ptime + time2j
        waitj = torch.clip(self.td_state['tw_low_limit']-arrivej, min=0)
        service_startj = arrivej + waitj

        #c0 = !! your code here !! # agents can only arrive at each customer after $o_i - P_{max} - W_{max}$
        c1 = service_startj <= self.td_state['tw_high_limit']
        c2 = service_startj + self.td_state['service_time'] + self.td_state['time2depot'] <= self.td_state['end_time'].unsqueeze(-1)

        # capacity constraints
        c3 = self.td_state['demands'] <= self.td_state['cur_agent']['cur_load']

        _mask = _mask * c0 * c1 * c2 * c3
        # update state
        self.td_state['cur_agent'].update({'action_mask': _mask}) 
        self.td_state['agents']['feasible_nodes'].scatter_(1, 
                                            self.td_state['cur_agent_idx'][:,:,None].expand(-1,-1,self.num_nodes), _mask.unsqueeze(1))


II) Complete the `DenseReward` class in order to take into account the penalty for time window violations:

(hint: check `td_state['cur_agent']['cur_penalty']` )

In [None]:
# %load snippets/ex12.py
class DenseReward(DenseReward):
    """Reward class.
    """

    def get_reward(self, action):
        """
        
        """

        # your code here!!
    
        return reward, penalty

In [None]:
rew = DenseReward()

In [None]:
env = Environment(instance_generator_object=gen,  
                  obs_builder_object=obs,
                  agent_selector_object=sel,
                  reward_evaluator=rew,
                  seed=0)

In [None]:
td = env.reset()

Let's get some information about the environment:

In [None]:
fig = plt.figure(figsize=(3,3))
plt.plot(env.td_state['coords'][0][:,0].numpy(), env.td_state['coords'][0][:,1].numpy(), 'o')
plt.plot(env.td_state['coords'][0][0,0].numpy(), env.td_state['coords'][0][0,1].numpy(), 'o', color='red' )

In [None]:
for k, data in enumerate(zip(env.td_state['tw_low'][0].tolist(), env.td_state['tw_high'][0].tolist(), env.td_state['demands'][0].tolist())):
    print(f'node {k} time window is: [{data[0]}; {data[1]}], with demand {data[2]}')

In [None]:
for k, data in enumerate(zip(env.td_state['tw_low_limit'][0].tolist(), env.td_state['tw_high_limit'][0].tolist(), env.td_state['arrive_limit'][0].tolist())):
    print(f'node {k} time window limit is: [{data[0]:.2f}; {data[1]:.2f}], with arrive time limit {data[2]:.2f}')

For the active agent in the depot, the times (distances) to customers will be:

In [None]:
loc = env.td_state['coords'].gather(1, env.td_state['cur_agent']['cur_node'][:,:,None].expand(-1, -1, 2))
time2j = torch.pairwise_distance(loc, env.td_state["coords"], eps=0, keepdim = False)
time2j[0]

III) Perform some environment steps to check whether our implementation is correct. 

IV) What `reward` and `penalty` values are expected?

In [None]:
# %load snippets/ex13.py

In [None]:
# your code here

In [None]:
# your code here

In [None]:
# your code here

Let's do an episode rollout and check the `reward` and `penalty` through every step:

In [None]:
td = env.reset()
while not td["done"].all():  
    td = env.sample_action(td) 
    td = env.step(td)
    step = env.env_nsteps
    reward = td['reward']
    penalty = td['penalty']
    print(f'env step number:{step}, reward: {reward}, penalty: {penalty}')


##### Well done! That's it for today. For any comments and suggestions, please drop us an email.

---