In [13]:
import duckdb
import polars as pl
import torch
from torch_geometric.data import Data
import numpy as np
from torch_geometric.utils import to_networkx
from viz_utils import visualize_graph

In [2]:
con = duckdb.connect(database="distanze.db")

In [3]:
df = pl.from_arrow(
    con.execute("""
    SELECT 
        trim(split_part(OR_DEST,' - ',1)) as OR, trim(split_part(OR_DEST,' - ',2)) as DEST, TEP_TOT, KM_TOT, TTP_TOT 
    FROM '../../data/01_raw/Italia/*/*.csv'
    """
    ).fetch_arrow_table()
)
print(df.head())

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

shape: (5, 5)
┌───────┬───────┬─────────┬────────┬─────────┐
│ OR    ┆ DEST  ┆ TEP_TOT ┆ KM_TOT ┆ TTP_TOT │
│ ---   ┆ ---   ┆ ---     ┆ ---    ┆ ---     │
│ str   ┆ str   ┆ i64     ┆ f64    ┆ i64     │
╞═══════╪═══════╪═════════╪════════╪═════════╡
│ 66001 ┆ 66001 ┆ 0       ┆ 0.0    ┆ 0       │
│ 66001 ┆ 66055 ┆ 6       ┆ 4.9    ┆ 5       │
│ 66001 ┆ 66086 ┆ 9       ┆ 7.1    ┆ 7       │
│ 66001 ┆ 66031 ┆ 10      ┆ 8.1    ┆ 9       │
│ 66001 ┆ 66027 ┆ 13      ┆ 10.8   ┆ 11      │
└───────┴───────┴─────────┴────────┴─────────┘


In [4]:
df.group_by("OR").agg(pl.count("DEST").alias("DEST_COUNT")).sort("DEST_COUNT")

OR,DEST_COUNT
str,u32
"""71026""",1
"""81014""",1
"""84020""",1
"""49005""",7
"""82075""",10
…,…
"""15176""",1362
"""15211""",1372
"""15182""",1381
"""15170""",1385


In [5]:
df.group_by("DEST").agg(pl.count("OR").alias("OR_COUNT")).sort("OR_COUNT")

DEST,OR_COUNT
str,u32
"""84020""",1
"""81014""",1
"""7041""",5
"""49005""",7
"""7022""",10
…,…
"""15250""",1332
"""15182""",1334
"""15211""",1350
"""15170""",1351


In [6]:
def create_pyg_graph_from_polars(df: pl.DataFrame, or_col: str, dest_col: str, edge_attr_cols: list):
    """
    Build a PyG Data graph from columns of a Polars DataFrame.

    Parameters:
        df: polars.DataFrame
        or_col: str - Name of the origin node column
        dest_col: str - Name of the destination node column
        edge_attr_cols: list of str - List of column names for edge features
        
    Returns:
        PyG Data object with attributes: edge_index, edge_attr, num_nodes, node2idx, idx2node
    """
    # 1. Unique node labels and mapping
    all_nodes = np.unique(
        np.concatenate([df[or_col].to_numpy(), df[dest_col].to_numpy()])
    )
    node2idx = {code: idx for idx, code in enumerate(all_nodes)}
    
    # 2. Build edge_index
    src_idx = [node2idx[x] for x in df[or_col].to_numpy()]
    dest_idx = [node2idx[x] for x in df[dest_col].to_numpy()]
    edge_index = torch.tensor([src_idx, dest_idx], dtype=torch.long)
    
    # 3. Edge attributes
    edge_attr = torch.tensor(
        np.column_stack([df[col].to_numpy() for col in edge_attr_cols]),
        dtype=torch.float
    )
    
    # 4. Create Data object
    data = Data(
        edge_index=edge_index,
        edge_attr=edge_attr,
        num_nodes=len(all_nodes)
    )
    data.node2idx = node2idx
    data.idx2node = {idx: code for code, idx in node2idx.items()}
    
    return data

In [7]:
data = create_pyg_graph_from_polars(df, "OR", "DEST", ["TEP_TOT", "KM_TOT", "TTP_TOT"])
print(data)

Data(
  edge_index=[2, 3245271],
  edge_attr=[3245271, 3],
  num_nodes=7903,
  node2idx={
    100001=0,
    100002=1,
    100003=2,
    100004=3,
    100005=4,
    100006=5,
    100007=6,
    10001=7,
    10002=8,
    10003=9,
    10004=10,
    10005=11,
    10006=12,
    10007=13,
    10008=14,
    10009=15,
    1001=16,
    10010=17,
    10011=18,
    10012=19,
    10013=20,
    10014=21,
    10015=22,
    10016=23,
    10017=24,
    10018=25,
    10019=26,
    1002=27,
    10020=28,
    10021=29,
    10022=30,
    10023=31,
    10024=32,
    10025=33,
    10026=34,
    10027=35,
    10028=36,
    10029=37,
    1003=38,
    10030=39,
    10031=40,
    10032=41,
    10033=42,
    10034=43,
    10035=44,
    10036=45,
    10037=46,
    10038=47,
    10039=48,
    1004=49,
    10040=50,
    10041=51,
    10042=52,
    10043=53,
    10044=54,
    10045=55,
    10046=56,
    10047=57,
    10048=58,
    10049=59,
    10050=60,
    10051=61,
    10052=62,
    10053=63,
    10054=64,
    100

In [10]:
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Number of nodes: 7903
Number of edges: 3245271
Average node degree: 410.64
Has isolated nodes: True
Has self-loops: True
Is undirected: False


In [None]:
visualize_graph(
    to_networkx(data, to_undirected=False),
    'lightblue'
)