Rzeczy do (ewentualnego) poprawienia/obsłużenia:
- zamiast zwykłego multigrafu, zrobić multigraf skierowany, bo traktujemy loty w jedną strone jako loty w obie naraz *//poniekąd obsłużone*
- zawsze wyświetlana jest data wylotu. Oznacza to, że jeżeli przylot ma miejsce nastęonego dnia, ma on miejsce dzień później niż data w `arrival time` *//obsłużone*

In [169]:
import networkx as nx
import pandas as pd
import numpy as np
import queue
import datetime
import time
import random


def get_price(dist):
    return round(np.random.normal(0.20, 0.07) * dist, 2)

def parse_date(year, month, day, time):
    return datetime.datetime(year, month, day, int('{:04d}'.format(time)[:2]), int('{:04d}'.format(time)[2:]))

COLS = ['YEAR', 'MONTH', 'DAY', 'ORIGIN_AIRPORT', 'DESTINATION_AIRPORT', 'SCHEDULED_DEPARTURE', 'DISTANCE', 'SCHEDULED_TIME', 'SCHEDULED_ARRIVAL']
N = 100000 #max = 5819079

flights = pd.read_csv("data/flights.csv")[COLS]
flights = flights.head(N)
flights = flights[flights['SCHEDULED_TIME'].notna()]
flights['PRICE'] = flights['DISTANCE'].apply(lambda x: get_price(x))
# print(flights.head(10))

MG = nx.MultiGraph()

for idx, row in flights.iterrows():
    departure_time = parse_date(row['YEAR'], row['MONTH'], row['DAY'], row['SCHEDULED_DEPARTURE'])
    arrival_time = parse_date(row['YEAR'], row['MONTH'], row['DAY'], row['SCHEDULED_ARRIVAL'])
    if arrival_time < departure_time:
        arrival_time += datetime.timedelta(days=1)

    MG.add_edge(row['ORIGIN_AIRPORT'], row['DESTINATION_AIRPORT'], origin=row['ORIGIN_AIRPORT'], destination=row['DESTINATION_AIRPORT'], departure_time=departure_time, arrival_time=arrival_time, flight_time=row['SCHEDULED_TIME'], price=row['PRICE'], pheromone_level=0, pheromone_update_time=0)

  exec(code_obj, self.user_global_ns, self.user_ns)


In [174]:
ANTS_NUMBER = 100
ANTS_SPAWN_ITERS = 10
SIMULATION_TIME_S = 20
SIMULATION_ITERS_NUM = 100
CONNECTION_SAMPLES = 3
DIRECT_CONNECTION_IMPACT = 0.8
TIME_IMPACT = 0.4
COST_IMPACT = 0.2
PHEROMONE_IMPACT = 0.4
ANT_PHEROMONE_LEVEL = 100
PHEROMONE_UPDATE = 0.5
PHEROMONE_UPDATING_TIME = 1000

In [175]:
class Ant:
    def __init__(self, curr_time, curr_airport):
        self.curr_time = curr_time
        self.curr_airport = curr_airport
        self.birth_time = curr_time
        
        self.curr_trav_cost = 0
        self.curr_conn_numb = 0
        self.mode = 0
        self.path = [curr_airport]
        
    def __lt__ (self, other):
        if self.mode != other.mode:
            return self.mode < other.mode
        return self.curr_time < other.curr_time
        
    def __gt__ (self, other):
        if self.mode != other.mode:
            return self.mode > other.mode
        return self.curr_time > other.curr_time

    def __eq__ (self, other):
        return self.mode == other.mode and self.curr_time == other.curr_time

    def __ne__ (self, other):
        return not self.__eq__(other)
    
    def _print_ant(self):
        print(f'\ntravel cost: {self.curr_trav_cost}\nconnections number: {self.curr_conn_numb}\nflights:\n {self.path}\n')
        
    def _update(self, next_flight):
        self.curr_time = next_flight[2]['arrival_time']
        self.curr_airport = next_flight[0]
        self.curr_trav_cost += next_flight[2]['price']
        self.curr_conn_numb += 1
        self.path.append((next_flight[0], next_flight[2]['departure_time'], next_flight[2]['arrival_time']))

In [179]:
#add max_conn_numb_handling
class AntColonyAlgorithm:
    def __init__(self, flights, origin, destination, min_time, max_time, min_conn_time, max_conn_numb, max_price):
        self.flights = flights
        self.origin = origin
        self.destination = destination
        self.min_time = min_time
        self.max_time = max_time
        self.min_conn_time = min_conn_time
        self.max_conn_numb = max_conn_numb
        self.max_price = max_price

        self.global_time = 0
        self.events = queue.PriorityQueue()

    def _init_ants(self):
        time_available_min = (self.max_time - self.min_time).total_seconds() // 60

        ants_spawn_gap = time_available_min // ANTS_SPAWN_ITERS
        for i in range(ANTS_SPAWN_ITERS):
            for j in range(ANTS_NUMBER // ANTS_SPAWN_ITERS):
                self.events.put((i * ants_spawn_gap, Ant(self.min_time + datetime.timedelta(minutes=i * ants_spawn_gap), self.origin)))
                
    def _run_next_event(self):
        self.global_time, curr_ant = self.events.get()
        available_flights = self._find_available_flights(curr_ant)
        self._update_pheromones(available_flights)
        next_flight = self._choose_flight(available_flights, curr_ant)
        self._make_next_flight(next_flight, curr_ant)

    ###
    def _find_available_flights(self, curr_ant): #todo: change to binary search
        all_flights = self.flights[curr_ant.curr_airport]
        available_flights = []
        for airport in all_flights:
            flights_added = 0

            for flight_idx in range(len(self.flights.adj[curr_ant.curr_airport][airport])):
                flight = self.flights.adj[curr_ant.curr_airport][airport][flight_idx]
                if self._is_accessible_flight(curr_ant, flight):
                    available_flights.append((airport, flight_idx, flight))

                    flights_added += 1
                    if flights_added == CONNECTION_SAMPLES:
                        break
        return available_flights

    def _is_accessible_flight(self, curr_ant, flight):
        good_airport = curr_ant.curr_airport == flight['origin']
        good_departure = curr_ant.curr_time + datetime.timedelta(minutes=self.min_conn_time) < flight['departure_time']
        good_arrival = flight['arrival_time'] < self.max_time
        good_cost = curr_ant.curr_trav_cost + flight['price'] <= self.max_price
        return good_airport and good_departure and good_arrival and good_cost

    # def _find_flight_idx(self, flights, curr_time): #something is not working :ch
    #     acc_time_min = curr_time + datetime.timedelta(minutes=self.min_conn_time)
    #     min_idx, max_idx = 0, len(flights) - 1
    #     while True:
    #         mid_idx = (min_idx + max_idx) // 2
    #         if flights[mid_idx]['departure_time'] < acc_time_min:
    #             min_idx = mid_idx
    #         elif mid_idx - 1 > 0 and flights[mid_idx - 1]['departure_time'] > acc_time_min:
    #             max_idx = mid_idx
    #         elif mid_idx in [(mid_idx + max_idx) // 2, (min_idx + mid_idx) // 2]:
    #             return None
    #         else:
    #             return (min_idx + max_idx) // 2

    ###
    def _update_pheromones(self, available_flights):
        for airport, index, flight in available_flights:
            time_gap = self.global_time - flight['pheromone_update_time']
            flight['pheromone_level'] = flight['pheromone_level'] * PHEROMONE_UPDATE ** (time_gap // PHEROMONE_UPDATING_TIME)
            flight['pheromone_update_time'] = self.global_time

    ###
    def _choose_flight(self, flights, curr_ant):
        if len(flights) == 0:
            return None

        for airport, index, flight in flights:
            if airport == self.destination:
                if random.random() < DIRECT_CONNECTION_IMPACT:
                    flight['pheromone_level'] += 100
                    return airport, index, flight

        waiting_times, prices, pheromones = np.empty(len(flights)), np.empty(len(flights)), np.empty(len(flights))
        min_time = curr_ant.curr_time + datetime.timedelta(minutes=self.min_conn_time)
        for index, (airport, flight_index, flight) in enumerate(flights):
            waiting_times[index] = (flight['departure_time'] - min_time).total_seconds() // 60
            prices[index] = flight['price']
            pheromones[index] = flight['pheromone_level']

        time_coeffs = (1 - waiting_times / np.max(waiting_times)) ** 3
        prices_coeffs = (1 - prices / np.max(prices)) ** 3
        pheromones_coeffs = pheromones / np.max(pheromones)
        combined_params = time_coeffs  * TIME_IMPACT + prices_coeffs  * COST_IMPACT + pheromones_coeffs * PHEROMONE_IMPACT

        flight_pos = random.uniform(0, np.sum(combined_params))
        for flight, params in zip(flights, combined_params):
            if params >= flight_pos:
                flight[2]['pheromone_level'] += 100
                return flight
            flight_pos -= params

    ###
    def _kill_ant(self, curr_ant):
        print(f'ant died :c')
        curr_ant._print_ant()

    def _spawn_ant(self):
        time_available_min = (self.max_time - self.min_time).total_seconds() // 60
        ant_spawn_gap = random.randint(0, time_available_min)
        self.events.put((self.global_time + 1, Ant(self.min_time + datetime.timedelta(minutes=ant_spawn_gap), self.origin)))

    ###
    def _make_next_flight(self, next_flight, curr_ant):
        if next_flight is None:
            self._kill_ant(curr_ant)
            self._spawn_ant()
        else:
            time_diff = (next_flight[2]['arrival_time'] - curr_ant.curr_time).total_seconds() // 60
            new_time = self.global_time + time_diff

            curr_ant._update(next_flight)
            if next_flight[0] == self.destination:
                curr_ant._print_ant()
                self._spawn_ant()
            else:
                self.events.put((new_time, curr_ant))

    ###
    def run(self, mode):
        self._init_ants()

        if mode == 'TIME':
            start_time = time.time()
            counter = 0
            while not self.events.empty() and (counter % 100 != 0 or time.time() - start_time < SIMULATION_TIME_S):
                self._run_next_event()
                counter += 1

        if mode == 'ITERS':
            counter = 0
            while not self.events.empty() and counter < SIMULATION_ITERS_NUM:
                self._run_next_event()
                counter += 1


In [180]:
simulation = AntColonyAlgorithm(MG, 'LAS', 'JAX', datetime.datetime(2015, 1, 1, 8, 0, 0), datetime.datetime(2015, 1, 8, 8, 0, 0), 90, 5, 5000)
simulation.run('ITERS')
print('DONE')

  pheromones_coeffs = pheromones / np.max(pheromones)


ant died :c

travel cost: 58.54
connections number: 1
flights:
 ['LAS', ('SNA', datetime.datetime(2015, 1, 1, 11, 50), datetime.datetime(2015, 1, 1, 12, 55))]

ant died :c

travel cost: 75.53
connections number: 1
flights:
 ['LAS', ('SAN', datetime.datetime(2015, 1, 1, 12, 5), datetime.datetime(2015, 1, 1, 13, 10))]

ant died :c

travel cost: 40.48
connections number: 1
flights:
 ['LAS', ('PHX', datetime.datetime(2015, 1, 1, 11, 15), datetime.datetime(2015, 1, 1, 13, 20))]


travel cost: 437.17
connections number: 2
flights:
 ['LAS', ('ORD', datetime.datetime(2015, 1, 1, 14, 34), datetime.datetime(2015, 1, 1, 20, 8)), ('JAX', datetime.datetime(2015, 1, 2, 7, 50), datetime.datetime(2015, 1, 2, 11, 13))]


travel cost: 383.91999999999996
connections number: 2
flights:
 ['LAS', ('BNA', datetime.datetime(2015, 1, 4, 18, 45), datetime.datetime(2015, 1, 5, 0, 15)), ('JAX', datetime.datetime(2015, 1, 5, 8, 5), datetime.datetime(2015, 1, 5, 10, 30))]


travel cost: 558.63
connections number: 2

In [38]:
MG_test = nx.MultiGraph()

MG_test.add_edge('a', 'b', departure_time=1, arrival_time=2, flight_time=1, price=3)
MG_test.add_edge('b', 'a', departure_time=2, arrival_time=6, flight_time=4, price=8)
MG_test.add_edge('c', 'a', departure_time=3, arrival_time=9, flight_time=6, price=11)

print(MG_test.adj['a'])
print(MG_test.adj['b'])


# for airport_name in MG.nodes:
#     print(airport_name, len(MG.adj['LAS'][airport_name]) if airport_name in MG.adj['LAS'] else 0)


# for flight in MG.adj['LAS']:
#     print(MG.adj['LAS'][flight])
#     print('\n\n\n')

# print(MG.adj['LAS']['HNL'])

# for i in range(0, len(MG.adj['LAS']['HNL'])):
#     print(MG.adj['LAS']['HNL'][i-1]['departure_time'])

{'b': {0: {'departure_time': 1, 'arrival_time': 2, 'flight_time': 1, 'price': 3}, 1: {'departure_time': 2, 'arrival_time': 6, 'flight_time': 4, 'price': 8}}, 'c': {0: {'departure_time': 3, 'arrival_time': 9, 'flight_time': 6, 'price': 11}}}
{'a': {0: {'departure_time': 1, 'arrival_time': 2, 'flight_time': 1, 'price': 3}, 1: {'departure_time': 2, 'arrival_time': 6, 'flight_time': 4, 'price': 8}}}


In [93]:
print((1-np.array([1,2,2])+np.array([2,1,1]))**3)

[8 0 0]


[0.5 1.  1.5]


In [134]:
test_arr = [Ant(0, 'ok'), Ant(1, 'ko')]

def get_ant():
    return test_arr[1]

ant = get_ant()
ant.curr_conn_numb = 100

ant_2_or_1 = get_ant()
ant_2_or_1._print_ant()


travel cost: 0
connections number: 100
flights:
 ['ko']

