In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import pandas as pd
import networkx as nx
import folium
from math import radians, sin, cos, sqrt, atan2

In [2]:
# header=None makes pandas treat the first line of each file as data, so the
# columns come back numbered 0..N-1 until labels are assigned.
airports_df, routes_df = (
    pd.read_csv(filename, header=None)
    for filename in ("airports.dat", "routes.dat")
)

In [3]:
# Positional labels for the two header-less tables; the order must match the
# column order in the files.
AIRPORT_COLUMNS = [
    'ID', 'Name', 'City', 'Country', 'IATA', 'ICAO',
    'Latitude', 'Longitude', 'Altitude', 'Timezone', 'DST',
    'Tz database time zone', 'Type', 'Source',
]
ROUTE_COLUMNS = [
    'Airline', 'Airline ID',
    'Source Airport', 'Source Airport ID',
    'Destination Airport', 'Destination Airport ID',
    'Codeshare', 'Stops', 'Equipment',
]

airports_df.columns = AIRPORT_COLUMNS
routes_df.columns = ROUTE_COLUMNS

In [4]:
# Keep only airports that have a usable 3-letter IATA code.
has_iata = airports_df['IATA'].apply(lambda code: isinstance(code, str) and len(code) == 3)

# A node index is the airport's row position, so each code must occur exactly
# once: a repeated code would otherwise map to its last row while leaving the
# earlier rows behind as unreachable nodes in the feature matrix.
airports_df = (
    airports_df[has_iata]
    .drop_duplicates(subset='IATA', keep='first')
    .reset_index(drop=True)
)

# Two-way lookup between IATA code and node index (row position).
airport_to_idx = {code: i for i, code in enumerate(airports_df['IATA'])}
idx_to_airport = {i: code for code, i in airport_to_idx.items()}

assert len(airport_to_idx) == len(airports_df), "IATA codes must be unique per row"

In [5]:
def haversine(coord1, coord2):
    """Great-circle distance in kilometres between two points on a sphere.

    Args:
        coord1: ``(latitude, longitude)`` of the first point, in degrees.
        coord2: ``(latitude, longitude)`` of the second point, in degrees.

    Returns:
        The distance in km, using a sphere of radius 6371 km.
    """
    R = 6371  # Earth radius in km
    lat1, lon1 = map(radians, coord1)
    lat2, lon2 = map(radians, coord2)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Floating-point rounding can leave `a` a hair above 1 for antipodal
    # points, which would make sqrt(1 - a) raise ValueError. NaN inputs are
    # left untouched so they still propagate as NaN.
    if a > 1.0:
        a = 1.0
    return R * (2 * atan2(sqrt(a), sqrt(1 - a)))

In [6]:
# Cache one (lat, lon) pair per IATA code up front so the route loop does a
# dict lookup instead of scanning the whole airport table twice per route.
coords_by_iata = {}
for code, lat, lon in zip(airports_df["IATA"], airports_df["Latitude"], airports_df["Longitude"]):
    coords_by_iata.setdefault(code, (lat, lon))  # first row wins if a code repeats

# One directed edge per route whose two endpoints are both known airports;
# the edge weight is the great-circle distance in km.
edges = []
weights = []
for src, dst in zip(routes_df["Source Airport"], routes_df["Destination Airport"]):
    if src in airport_to_idx and dst in airport_to_idx:
        edges.append([airport_to_idx[src], airport_to_idx[dst]])
        weights.append(haversine(coords_by_iata[src], coords_by_iata[dst]))

In [7]:
# Node features: one (latitude, longitude) row per airport.
features = torch.tensor(
    airports_df[["Latitude", "Longitude"]].to_numpy(),
    dtype=torch.float,
)
num_nodes = features.size(0)

# Connectivity as a 2 x num_edges index tensor, plus one weight per edge.
edge_weight = torch.tensor(weights, dtype=torch.float)
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

In [8]:
class GCN(torch.nn.Module):
    """Two stacked ``GCNConv`` layers with a ReLU between them."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two convolutions: in -> hidden and hidden -> out."""
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return the second layer's output for the graph held in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_weight``; the
        same connectivity and weights are fed to both layers.
        """
        topology, strengths = data.edge_index, data.edge_weight
        hidden = F.relu(self.conv1(data.x, topology, edge_weight=strengths))
        return self.conv2(hidden, topology, edge_weight=strengths)

In [9]:

# Bundle node features, connectivity and per-edge weights into one graph object.
data = Data(
    x=features,
    edge_index=edge_index,
    edge_weight=edge_weight,
)


In [10]:
# Training hyper-parameters, named so they are easy to find and tune.
SEED = 42
HIDDEN_CHANNELS = 16
EPOCHS = 100
LEARNING_RATE = 0.01

# Seed before building the model so weight initialisation, and with it the
# whole training run, is repeatable across Restart & Run All.
torch.manual_seed(SEED)

# The input width follows the feature matrix; the output stays 2-wide because
# the training target below is the first two feature columns.
model = GCN(in_channels=features.shape[1], hidden_channels=HIDDEN_CHANNELS, out_channels=2)

# Train the model to reproduce its own first two input columns (MSE loss).
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
model.train()
for epoch in range(EPOCHS):
    optimizer.zero_grad()
    out = model(data)
    loss = F.mse_loss(out, features[:, :2])  # both are [num_nodes, 2]
    loss.backward()
    optimizer.step()

In [11]:
# Inference only: switch to eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    embeddings = model(data).numpy()

# One graph node per embedding row, identified by its row index.
G = nx.Graph()
G.add_nodes_from(range(len(embeddings)))

In [12]:
# Mirror every tensor edge into the NetworkX graph, carrying its weight along
# as the "weight" attribute. A pair that occurs more than once keeps the
# weight of its last occurrence.
edge_triples = (
    (u, v, w)
    for (u, v), w in zip(edge_index.t().tolist(), edge_weight.tolist())
)
G.add_weighted_edges_from(edge_triples, weight="weight")

In [13]:
# Endpoints of the demo query as IATA codes: Raipur -> Los Angeles.
src_code, dst_code = "RPR", "LAX"

In [14]:
# Translate the IATA codes into node indices; .get() yields None for a code
# that is missing from the lookup table.
src_idx, dst_idx = (airport_to_idx.get(code) for code in (src_code, dst_code))

In [15]:
if src_idx is None or dst_idx is None:
    print("❌ Invalid IATA codes.")
else:
    try:
        # Route with the smallest summed edge weight (Dijkstra on G).
        path = nx.dijkstra_path(G, source=src_idx, target=dst_idx, weight="weight")
        path_iata = list(map(idx_to_airport.__getitem__, path))
        print("✅ GNN Path:", path_iata)

        # Step 12: Show path on Folium map
        m = folium.Map(location=[20.5937, 78.9629], zoom_start=3)
        coords = []
        for code in path_iata:
            # First airport row carrying this IATA code.
            airport = airports_df.loc[airports_df["IATA"] == code].iloc[0]
            point = (airport["Latitude"], airport["Longitude"])
            coords.append(point)
            folium.Marker(list(point), tooltip=airport["IATA"]).add_to(m)

        # Connect the stops in travel order and write the map to disk.
        folium.PolyLine(coords, color="blue", weight=3).add_to(m)
        m.save("gnn_path_map.html")
        print("🗺️ Map saved as 'gnn_path_map.html'")
    except nx.NetworkXNoPath:
        print("❌ No path found by GNN")

✅ GNN Path: ['RPR', 'CCU', 'KMG', 'PEK', 'LAX']
🗺️ Map saved as 'gnn_path_map.html'


In [16]:
import networkx as nx

def shortest_path_via_gnn(G, src_idx, dst_idx, idx_to_airport, max_stops=3, allowed_airlines=None):
    """
    Find the minimum-total-weight route from ``src_idx`` to ``dst_idx`` in ``G``,
    restricted to a set of airlines and a maximum number of intermediate stops.

    Args:
        G: NetworkX graph whose edges carry a ``"weight"`` attribute and,
            optionally, an ``"airline"`` attribute. An undirected ``G`` is
            treated as traversable in both directions.
        src_idx: Node id of the origin.
        dst_idx: Node id of the destination.
        idx_to_airport: Mapping from node id to a display name; only used to
            build the error message.
        max_stops: Maximum number of intermediate airports, i.e. the route may
            use at most ``max_stops + 1`` edges.
        allowed_airlines: Optional iterable of airline codes. When non-empty,
            only edges whose ``"airline"`` attribute (compared upper-cased)
            is in this set are kept. Edges without an ``"airline"`` attribute
            count as airline ``""`` and are therefore dropped whenever a
            filter is active.

    Returns:
        The best route as a list of node ids, origin first.

    Raises:
        nx.NetworkXNoPath: If no route satisfies the constraints (including
            the case where filtering removed the origin or destination).

    Note:
        Every simple path within the edge limit is enumerated, so the running
        time grows very quickly with ``max_stops`` on well-connected graphs.
    """
    if allowed_airlines:
        allowed_airlines = {a.strip().upper() for a in allowed_airlines}

    def _no_route():
        return nx.NetworkXNoPath(
            f"No path found from {idx_to_airport[src_idx]} to {idx_to_airport[dst_idx]} within {max_stops} stops."
        )

    # Create filtered graph. G.edges() reports an undirected edge only once,
    # in arbitrary orientation, so both directions are added explicitly.
    undirected = not G.is_directed()
    H = nx.DiGraph()
    for u, v, attr in G.edges(data=True):
        airline = attr.get("airline", "").upper()
        if allowed_airlines and airline not in allowed_airlines:
            continue
        H.add_edge(u, v, weight=attr["weight"], airline=airline)
        if undirected:
            H.add_edge(v, u, weight=attr["weight"], airline=airline)

    # An endpoint absent from H (isolated, or removed by the filter) cannot be
    # routed; report it as "no path" rather than a node-lookup error.
    if src_idx not in H or dst_idx not in H:
        raise _no_route()

    def _total_weight(path):
        return sum(H[u][v]["weight"] for u, v in zip(path[:-1], path[1:]))

    # `cutoff` limits the number of edges: max_stops intermediate airports
    # means at most max_stops + 1 edges.
    candidates = nx.all_simple_paths(H, source=src_idx, target=dst_idx, cutoff=max_stops + 1)

    # Stream the candidates instead of materialising them all in a list.
    best_path = min(candidates, key=_total_weight, default=None)
    if best_path is None:
        raise _no_route()
    return best_path


In [17]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import folium

# Dropdowns: offer only airports that are an endpoint of at least one edge,
# since any other choice could never produce a route.
used_nodes = {node for edge in G.edges() for node in edge}
airport_options = sorted(
    idx_to_airport[i] for i in used_nodes
    if i in idx_to_airport and isinstance(idx_to_airport[i], str)
)

src_widget = widgets.Dropdown(options=airport_options, description='Source:')
dst_widget = widgets.Dropdown(options=airport_options, description='Destination:')

# Max stops
max_stops_slider = widgets.IntSlider(value=1, min=0, max=5, step=1, description='Max Stops:')

# Airline filter (comma-separated codes; empty means no filter)
airline_textbox = widgets.Text(value='', placeholder='e.g. AA,BA,LH', description='Allowed Airlines:')

# Run button
run_button = widgets.Button(description='Find Route', button_style='success')

# Output display
output = widgets.Output()

# Button click handler
def on_run_button_click(b):
    """Read the widget values, compute a route and render it in ``output``.

    ``b`` is the button instance passed by the click callback and is unused.
    Any exception raised while routing or drawing is printed, not propagated.
    """
    with output:
        clear_output()

        origin, destination = src_widget.value, dst_widget.value
        stop_limit = max_stops_slider.value

        # Comma-separated airline codes; blank entries are ignored.
        carriers = []
        for token in airline_textbox.value.split(','):
            token = token.strip()
            if token:
                carriers.append(token.upper())

        if origin == destination:
            print("⚠️ Source and destination cannot be the same.")
            return

        try:
            route = shortest_path_via_gnn(
                G,
                airport_to_idx[origin],
                airport_to_idx[destination],
                idx_to_airport,
                max_stops=stop_limit,
                allowed_airlines=carriers,
            )
            codes = [idx_to_airport[node] for node in route]
            print("✅ Optimized Path (IATA):", codes)

            # Plot map: one marker per stop found in the airport table,
            # joined by a line in travel order.
            route_map = folium.Map(location=[0, 0], zoom_start=2)
            waypoints = []
            for code in codes:
                matches = airports_df[airports_df['IATA'] == code]
                if matches.empty:
                    continue
                first = matches.iloc[0]
                lat, lon = first['Latitude'], first['Longitude']
                waypoints.append((lat, lon))
                folium.Marker(location=(lat, lon), popup=code).add_to(route_map)

            folium.PolyLine(locations=waypoints, color='blue').add_to(route_map)
            display(route_map)

        except Exception as e:
            print(f"❌ Error: {e}")

# Link button to function
run_button.on_click(on_run_button_click)

# Show UI: controls stacked vertically, with the result area at the bottom.
controls = [
    src_widget,
    dst_widget,
    max_stops_slider,
    airline_textbox,
    run_button,
    output,
]
ui = widgets.VBox(controls)
display(ui)


VBox(children=(Dropdown(description='Source:', options=('AAE', 'AAL', 'AAN', 'AAQ', 'AAR', 'AAT', 'AAX', 'AAY'…