In [45]:
import numpy as np
import dimod
import xtools as xt
import matplotlib.pyplot as plt
from openjij import SQASampler

from typing import Callable, Dict, List, Optional, Type, Union

from atm.flight.generator import ScenarioGenerator, Scenario
from atm.separation import recat
from atm.separation.base import Separation
from atm.flight.flight import DUMMY_FLIGHT_RECAT

In [23]:
# Short alias for `np.ndarray`, used in the type hints of the helper functions below.
NP = np.ndarray

In [6]:
# Problem size. `num_vol` (number of flights, "vols") is the single source of truth:
# it is also passed into `cf` below so the two values can no longer disagree.
num_vol = 10
num_rwy = 2
sep = recat.TBS  # separation rule applied between consecutive flights on a runway

# Seed NumPy's global RNG so the random runway assignment later in the notebook is
# reproducible (and any NumPy-based sampling in ScenarioGenerator — TODO confirm).
SEED = 0
np.random.seed(SEED)

cf = xt.Config(dict(
    dt=10,
    num_vol=num_vol,  # previously a hard-coded 30 that contradicted `num_vol` above
    scenario=dict(
        interval=60,
        window=200,
        mode="mix",
        standard="recat",
    )
))
gen = ScenarioGenerator(cf.scenario)

In [7]:
# Generate a scenario of `num_vol` flights and display it as a table
# (one row per flight: code, ready, due, category, operation).
vols = gen(num_vol)
vols.to_dataframe()

Unnamed: 0,code,ready,due,category,operation
0,VOL0001,86,286,F,A
1,VOL0002,143,343,F,D
2,VOL0003,169,369,F,A
3,VOL0004,201,401,E,D
4,VOL0005,209,409,B,A
5,VOL0006,274,474,F,D
6,VOL0007,312,512,D,A
7,VOL0008,319,519,A,A
8,VOL0009,589,789,E,D
9,VOL0010,618,818,C,A


In [61]:
def calc_assign_time(indices: NP, scenario: Scenario, separation: Separation) -> NP:
    """Compute the assigned time of each flight in one runway sequence.

    Flights are visited in the order given by ``indices``. Each one gets the
    later of its own ``ready`` time and the previous flight's assigned time
    plus ``separation(previous_flight, flight)``. The very first flight is
    separated from ``DUMMY_FLIGHT_RECAT``, whose ``ready`` value seeds the
    "previous assigned time".

    Args:
        indices: order in which flights of ``scenario`` use the runway.
        scenario: indexable collection of flights exposing ``ready``.
        separation: callable returning the required gap between two flights.

    Returns:
        Array of assigned times, one per entry of ``indices``, same order.
    """
    prev_flight = DUMMY_FLIGHT_RECAT
    prev_slot = prev_flight.ready

    slots = []
    for flight in (scenario[i] for i in indices):
        earliest_after_prev = prev_slot + separation(prev_flight, flight)
        prev_slot = np.max([flight.ready, earliest_after_prev])
        slots.append(prev_slot)
        prev_flight = flight

    return np.asarray(slots)


def calc_delay(indices: NP, assigned_times: NP, scenario: Scenario) -> NP:
    """Delay of each flight: its assigned time minus its ``ready`` time.

    Args:
        indices: flight indices into ``scenario``, aligned with ``assigned_times``.
        assigned_times: assigned times for the flights in ``indices``.
        scenario: indexable collection of flights exposing ``ready``.

    Returns:
        Element-wise ``assigned_times - ready``.
    """
    ready_times = np.array(list(map(lambda i: scenario[i].ready, indices)))
    return assigned_times - ready_times

def get_due(indices: NP, scenario: Scenario) -> NP:
    """Collect the ``due`` time of each flight referenced by ``indices``.

    Args:
        indices: flight indices into ``scenario``.
        scenario: indexable collection of flights exposing ``due``.

    Returns:
        Array of due times in the order of ``indices``.
    """
    due_times = []
    for i in indices:
        due_times.append(scenario[i].due)
    return np.array(due_times)

def check_overtime(assigned_times: NP, dues: NP) -> NP:
    """Flag flights whose assigned time is strictly later than their due time.

    Args:
        assigned_times: assigned times, element-aligned with ``dues``.
        dues: due times.

    Returns:
        Boolean array, ``True`` where ``assigned_times[i] > dues[i]``. The
        dtype is always ``bool``, including for empty input.

    Raises:
        ValueError: if the two inputs do not have the same shape (the old
            ``zip``-based version silently truncated to the shorter one).
    """
    times = np.asarray(assigned_times)
    limits = np.asarray(dues)
    if times.shape != limits.shape:
        raise ValueError(
            f"assigned_times and dues must have the same shape, got {times.shape} and {limits.shape}"
        )
    return times > limits

In [62]:
def calc_assign_time_for_multi_runway(indices: NP, scenario: Scenario, separation: Separation) -> List[NP]:
    """Run ``calc_assign_time`` independently for every runway.

    Args:
        indices: one sequence of flight indices per runway (e.g. the rows of a
            2-D array), each in the order that runway serves its flights.
        scenario: indexable collection of flights.
        separation: separation rule between consecutive flights on a runway.

    Returns:
        List with one array of assigned times per runway, in the order of
        ``indices``. Runways do not interact: each sequence is computed on its own.
    """
    return [
        calc_assign_time(runway_indices, scenario, separation)
        for runway_indices in indices
    ]


def calc_delay_for_multi_runway(indices: NP, assigned_times: List[NP], scenario: Scenario) -> List[NP]:
    """Apply ``calc_delay`` to each runway's (indices, assigned times) pair.

    Returns:
        List of per-runway delay arrays, in runway order.
    """
    per_runway = []
    for runway_indices, runway_times in zip(indices, assigned_times):
        per_runway.append(calc_delay(runway_indices, runway_times, scenario))
    return per_runway

def get_due_for_multi_runway(indices: NP, scenario: Scenario) -> List[NP]:
    """Apply ``get_due`` to each runway's index sequence.

    Returns:
        List of per-runway due-time arrays, in runway order.
    """
    return list(map(lambda runway_indices: get_due(runway_indices, scenario), indices))

def check_over_time_for_multi_runway(assigned_times: List[NP], dues: List[NP]) -> List[NP]:
    """Apply ``check_overtime`` to each runway's (assigned times, dues) pair.

    Returns:
        List of per-runway overtime flag arrays, in runway order.
    """
    return [check_overtime(*pair) for pair in zip(assigned_times, dues)]

def count_num_overtime(is_overtimes: Union[NP, List[NP]]) -> int:
    """Count the flagged (``True``) entries in overtime indicators.

    Accepts a single boolean array of any shape, a scalar flag, or a (possibly
    ragged) list of per-runway boolean arrays. An example is the output of
    ``check_over_time_for_multi_runway`` when runways serve different numbers of
    flights, which ``np.asarray`` cannot stack into one array.

    Args:
        is_overtimes: overtime flags, in one of the forms above.

    Returns:
        Number of flagged entries as a Python ``int``.
    """
    if np.isscalar(is_overtimes) or (
        isinstance(is_overtimes, np.ndarray) and is_overtimes.dtype != object
    ):
        # Homogeneous input: a single vectorised sum.
        return np.sum(np.asarray(is_overtimes).astype(int)).item()
    # Sequence (possibly ragged): count each part separately and add up.
    return sum(count_num_overtime(flags) for flags in is_overtimes)

In [63]:
def get_objective(scenario: Scenario, separation: Separation, penalty_coef: float = 1.0) -> Callable:
    """Build the scheduling objective for ``scenario`` under ``separation``.

    The returned function takes a runway assignment ``xs``: one sequence of
    flight indices per runway, each in the order that runway serves them. It
    returns the mean delay over all assigned flights plus ``penalty_coef``
    times the number of flights whose assigned time exceeds their due time.

    Args:
        scenario: flights to schedule.
        separation: separation rule between consecutive flights on a runway.
        penalty_coef: weight added to the objective per overtime flight.

    Returns:
        ``func(xs) -> float``.
    """
    def func(xs):
        """Evaluate the objective for the runway assignment ``xs``."""
        times = calc_assign_time_for_multi_runway(xs, scenario, separation)
        delays = calc_delay_for_multi_runway(xs, times, scenario)
        dues = get_due_for_multi_runway(xs, scenario)
        is_overtimes = check_over_time_for_multi_runway(times, dues)

        # Pool per-runway results before reducing. Runways may serve different
        # numbers of flights, and np.mean / np.asarray cannot stack such a
        # ragged list into one array.
        all_delays = (
            np.concatenate([np.ravel(d) for d in delays]) if len(delays) else np.empty(0)
        )
        num_overtime = sum(count_num_overtime(flags) for flags in is_overtimes)

        return float(np.mean(all_delays)) + penalty_coef * num_overtime
    return func

In [64]:
# penalty_coef=0.0: overtime flights add nothing, so the objective is just the mean delay.
obj = get_objective(vols, sep, penalty_coef=0.0)

In [67]:
# Random baseline: shuffle all flights and split them into `num_rwy` equal-length
# runway sequences (row r = the order in which runway r serves its flights).
assert num_vol % num_rwy == 0, "num_vol must be divisible by num_rwy to split flights evenly across runways"
xs = np.random.permutation(num_vol).reshape((num_rwy, num_vol // num_rwy))
obj(xs)

178.8