In [1]:
%%writefile /content/gis_utils.py
import geopandas as gpd
import numpy as np
from shapely.geometry import box
from pyproj import Transformer
from tqdm import tqdm
from scipy.stats import zscore

def get_file_paths(city_name):
    base_path = f"data/{city_name}-share_data"
    codes = {"NY":225,"SF":233,"Chicago":97}
    c = codes.get(city_name)
    if c is None:
        raise ValueError(f"Unsupported city: {city_name}")
    return {
        "population": f"{base_path}/{city_name}_zone_{c}_population.shp",
        "pois":       f"{base_path}/{city_name}_pois.shp",
        "road_network":f"{base_path}/{city_name}_roadnetwork.pkl",
        "accidents":  f"{base_path}/{city_name}_crash_information.csv",
        "demand":     f"{base_path}/{city_name}-taxi_zone.npy",
        "save_path":  f"{base_path}/processed"
    }

def transform_network_coordinates(G, from_crs="EPSG:4326", to_crs="EPSG:3395"):
    transformer = Transformer.from_crs(from_crs, to_crs, always_xy=True)
    for node,data in tqdm(G.nodes(data=True), desc="Transforming Network Coordinates"):
        if "x" in data and "y" in data:
            data["x"], data["y"] = transformer.transform(data["x"], data["y"])

def compute_population_density(area, pop_gdf):
    polys = pop_gdf[pop_gdf.intersects(area)]
    if polys.empty: return 0
    polys = polys.copy()
    polys["overlap_area"] = polys.geometry.intersection(area).area / 1e6
    polys["ratio"] = polys["overlap_area"] / polys["area_km2"]
    polys["pop_catch"] = polys["population"] * polys["ratio"]
    total = polys["pop_catch"].sum()
    return total / (area.area / 1e6)

def extract_road_network_subgraph(area, nodes_gdf, G):
    ids = nodes_gdf[nodes_gdf.within(area)]["node"].values
    return G.subgraph(ids).copy()

def compute_poi_densities(area, pois_gdf, categories):
    sub = pois_gdf[pois_gdf.within(area)]
    area_km2 = area.area / 1e6
    return [ sub[sub["category"]==cat].shape[0] / area_km2 for cat in categories ]

def generate_random_points_within_polygon(poly, n):
    minx,miny,maxx,maxy = poly.bounds
    factor = 5
    xs = np.random.uniform(minx, maxx, n*factor)
    ys = np.random.uniform(miny, maxy, n*factor)
    pts = gpd.GeoSeries(gpd.points_from_xy(xs, ys), crs="EPSG:3395")
    pts = pts[pts.within(poly)]
    if len(pts) < n:
        return generate_random_points_within_polygon(poly, n)
    return pts.iloc[:n].tolist()

def compute_accident_counts(area, acc_gdf):
    sub = acc_gdf[acc_gdf.within(area)]
    total = len(sub)
    sev = sub["Severity"].value_counts().to_dict()
    return [ total, sev.get(1,0), sev.get(2,0), sev.get(3,0), sev.get(4,0) ]

def compute_demand(area, pop_gdf, demand_arr):
    center = area.centroid
    match = pop_gdf[pop_gdf.contains(center)]
    if not match.empty:
        return demand_arr[match.index[0]]
    return np.zeros((24,4))

def construct_features(n, ny_population, ny_pois, nodes_gdf, ny_road_network,
                       categories, accident_data_gdf, demand_data, side_length=2000):
    city_poly = ny_population.unary_union
    pts = generate_random_points_within_polygon(city_poly, n)
    catchments = [
        box(p.x - side_length/2, p.y - side_length/2,
            p.x + side_length/2, p.y + side_length/2)
        for p in pts
    ]
    std_dem = zscore(demand_data, axis=None)
    feat_pop, subGs, feat_pois, feats_acc, feats_dem = [], [], [], [], []
    for area in tqdm(catchments, desc="Processing Catchments"):
        feat_pop.append(compute_population_density(area, ny_population))
        subGs.append(extract_road_network_subgraph(area, nodes_gdf, ny_road_network))
        feat_pois.append(compute_poi_densities(area, ny_pois, categories))
        feats_acc.append(compute_accident_counts(area, accident_data_gdf))
        feats_dem.append(compute_demand(area, ny_population, std_dem))
    return feat_pop, subGs, feat_pois, feats_acc, feats_dem


Overwriting /content/gis_utils.py


In [2]:
import gis_utils
from gis_utils import get_file_paths, construct_features


In [3]:
%%writefile gis_processing.py

import pandas as pd
import geopandas as gpd
import argparse
import numpy as np
import pickle
from gis_utils import (get_file_paths, transform_network_coordinates, construct_features)

def main():
    parser = argparse.ArgumentParser(
        description="Process datasets for population, POIs, and road-network analysis."
    )
    parser.add_argument(
        "--city",
        type=str,
        default="Chicago",
        choices=["NY", "SF", "Chicago"],
        help="City name to process (default: Chicago)."
    )
    parser.add_argument(
        "--num_catchment",
        type=int,
        default=50,
        help="Number of random catchment areas."
    )
    parser.add_argument(
        "--side_length",
        type=int,
        default=2000,
        help="Side length of each catchment area in meters."
    )
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Get all file paths for the chosen city
    # ------------------------------------------------------------------
    file_paths = get_file_paths(args.city)

    # Step 1 ── Load population data
    print(f"Loading population data for {args.city}…")
    population = gpd.read_file(file_paths["population"])
    population = population.to_crs(epsg=3395)
    population["area_km2"] = population.geometry.area / 1e6
    population["pop_den"]  = population["population"] / population["area_km2"]
    print("Population data loaded. Total zones:", len(population))

    # Step 2 ── Load points-of-interest (POI) data
    print(f"Loading POI data for {args.city}…")
    pois = gpd.read_file(file_paths["pois"])
    pois = pois.to_crs(epsg=3395)
    print("POI data loaded. Total points:", len(pois))

    # Step 3 ── Load the road-network graph
    print(f"Loading road-network data for {args.city}…")
    with open(file_paths["road_network"], "rb") as f:
        road_network = pickle.load(f)
    print(
        "Road-network data loaded. "
        "Total nodes:", road_network.number_of_nodes(),
        "Total edges:", road_network.number_of_edges()
    )

    # Step 4 ── Re-project the road network to EPSG 3395
    print("Transforming road-network coordinates to EPSG:3395…")
    transform_network_coordinates(
        road_network,
        from_crs="EPSG:4326",
        to_crs="EPSG:3395"
    )
    print("Coordinate transformation complete.")

    # Step 5 ── Build a GeoDataFrame for road-network nodes
    nodes_data = dict(road_network.nodes(data=True))
    nodes_df = pd.DataFrame.from_dict(nodes_data, orient="index")
    nodes_df["node"] = nodes_df.index
    nodes_gdf = gpd.GeoDataFrame(
        nodes_df,
        geometry=gpd.points_from_xy(nodes_df.x, nodes_df.y),
        crs=population.crs
    )

    # Step 6 ── Define POI categories to aggregate
    categories = [
        "office_tags", "sustenance_tags", "transportation_tags",
        "retail_tags", "leisure_tags", "residence_tags"
    ]

    # Step 7 ── Load traffic-accident data
    print(f"Loading accident data for {args.city}…")
    accident_data = pd.read_csv(file_paths["accidents"])
    accident_data["Start_Lat"] = pd.to_numeric(accident_data["Start_Lat"], errors="coerce")
    accident_data["Start_Lng"] = pd.to_numeric(accident_data["Start_Lng"], errors="coerce")
    accident_data = accident_data.dropna(subset=["Start_Lat", "Start_Lng"])
    accident_data_gdf = gpd.GeoDataFrame(
        accident_data,
        geometry=gpd.points_from_xy(accident_data.Start_Lng, accident_data.Start_Lat),
        crs="EPSG:4326"
    ).to_crs(epsg=3395)
    print("Accident data loaded. Total accidents:", len(accident_data_gdf))

    # Step 8 ── Load ride-demand data
    print("Loading demand data…")
    demand_data = np.load(file_paths["demand"], allow_pickle=True)
    print("Demand data loaded. Shape:", demand_data.shape)

    # Step 9 ── Generate features for random catchment areas
    print(f"Generating features for {args.num_catchment} random locations…")
    (
        feat_pop_list,
        sub_graph_list,
        feat_pois_list,
        target_accident_list,
        demand_list
    ) = construct_features(
        n=args.num_catchment,
        ny_population=population,
        ny_pois=pois,
        nodes_gdf=nodes_gdf,
        ny_road_network=road_network,
        categories=categories,
        accident_data_gdf=accident_data_gdf,
        demand_data=demand_data,
        side_length=args.side_length
    )
    print("Feature generation complete.")

    # Step 10 ── Persist results to disk
    save_path = file_paths["save_path"]

    data = pd.DataFrame({
        "population_density": feat_pop_list,
        "poi_densities"     : feat_pois_list,
        "accident_counts"   : target_accident_list,
    })

    # Split nested columns into flat tables
    poi_densities_df = pd.DataFrame(data["poi_densities"].tolist(), columns=categories)
    accident_counts_df = pd.DataFrame(
        data["accident_counts"].tolist(),
        columns=["total_accidents", "severity_1", "severity_2", "severity_3", "severity_4"]
    )
    data_output = pd.concat(
        [
            data.drop(["poi_densities", "accident_counts"], axis=1),
            poi_densities_df,
            accident_counts_df
        ],
        axis=1
    )
    data_output.to_csv(f"{save_path}/feature_data.csv", index=False)
    print(f"Feature data saved to {save_path}/feature_data.csv.")

    demand_array = np.stack(demand_list, axis=0)
    np.save(f"{save_path}/demand_features.npy", demand_array)
    print(f"Demand features saved to {save_path}/demand_features.npy.")

    with open(f"{save_path}/sub_graph_list.pkl", "wb") as f:
        pickle.dump(sub_graph_list, f)
    print(f"Subgraph list saved to {save_path}/sub_graph_list.pkl.")

if __name__ == "__main__":
    main()


Overwriting gis_processing.py


In [4]:
%%writefile 2-generate_graph_data_homo.py
import argparse
import torch
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from sklearn.preprocessing import StandardScaler
import numpy as np
from tqdm import tqdm
import pandas as pd
import pickle
import os

def main():
    parser = argparse.ArgumentParser(
        description="Generate homogeneous PyG Data objects from city-specific datasets."
    )
    parser.add_argument(
        "--city",
        type=str,
        default="Chicago",
        choices=["NY", "SF", "Chicago"],
        help="City name to process (default: Chicago)."
    )
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Utility helpers: resolve I/O paths for the chosen city
    # ------------------------------------------------------------------
    def get_paths_by_city(city_name):
        """Return paths to feature CSV, sub-graph list and demand tensor."""
        base_path = f"data/{city_name}-share_data/processed"
        return {
            "feature_data_path"   : f"{base_path}/feature_data.csv",
            "sub_graph_list_path" : f"{base_path}/sub_graph_list.pkl",
            "demand_data_path"    : f"{base_path}/demand_features.npy"  #  shape (N, 24, 4)
        }

    def get_data_list_save_path(city_name):
        base_path = f"data/{city_name}-share_data/processed"
        return f"{base_path}/data_list.pt"

    # Resolve all paths
    paths              = get_paths_by_city(args.city)
    feature_data_path  = paths["feature_data_path"]
    sub_graph_list_path = paths["sub_graph_list_path"]
    demand_data_path   = paths["demand_data_path"]
    data_list_path     = get_data_list_save_path(args.city)

    # ------------------------------------------------------------------
    # 1)  Load tabular feature data
    # ------------------------------------------------------------------
    data = pd.read_csv(feature_data_path)
    print(f"Feature table loaded from '{feature_data_path}'.")

    # ------------------------------------------------------------------
    # 2)  Load sub-graph list (road-network ego-graphs)
    # ------------------------------------------------------------------
    with open(sub_graph_list_path, "rb") as f:
        sub_graph_list = pickle.load(f)
    print(f"Sub-graph list loaded from '{sub_graph_list_path}'.")

    # ------------------------------------------------------------------
    # 3)  Load demand tensor  (N, 24, 4)
    # ------------------------------------------------------------------
    demand_data = np.load(demand_data_path)    # (N, 24, 4)
    print(f"Demand tensor loaded from '{demand_data_path}', shape: {demand_data.shape}")

    # ------------------------------------------------------------------
    # 4)  Assemble input features
    # ------------------------------------------------------------------
    categories = [
        "office_tags", "sustenance_tags", "transportation_tags",
        "retail_tags", "leisure_tags", "residence_tags"
    ]
    pois = data[categories].values

    # ------------------------------------------------------------------
    # 5)  Target variables (traffic-accident statistics)
    # ------------------------------------------------------------------
    accident_columns = [
        "total_accidents", "severity_1", "severity_2", "severity_3", "severity_4"
    ]
    accidents = data[accident_columns].values

    # ------------------------------------------------------------------
    # 6)  Standardise numerical features
    # ------------------------------------------------------------------
    pois_scaled      = StandardScaler().fit_transform(pois)
    accidents_scaled = StandardScaler().fit_transform(accidents)

    # ------------------------------------------------------------------
    # 7)  Clean node/edge attributes inside each sub-graph
    # ------------------------------------------------------------------
    node_types = set()
    for g in sub_graph_list:
        for node, attrs in g.nodes(data=True):
            node_type = attrs.get("highway", "intersection")
            attrs["type"] = node_type
            attrs.pop("highway", None)
            attrs.pop("ref", None)
            node_types.add(node_type)

    node_type_to_id = {t: i for i, t in enumerate(sorted(node_types))}
    num_node_types  = len(node_type_to_id)
    print(f"Detected {num_node_types} unique node types.")

    for g in sub_graph_list:
        # Node attributes
        for node, attrs in g.nodes(data=True):
            attrs["type_id"] = node_type_to_id[attrs["type"]]
        # Edge attributes: keep only highway / maxspeed
        for _, _, attrs in g.edges(data=True):
            for k in list(attrs.keys()):
                if k not in ["highway", "maxspeed"]:
                    del attrs[k]
            attrs.setdefault("highway", "unknown")
            attrs.setdefault("maxspeed", 20)

    # ------------------------------------------------------------------
    # 8)  Helper: convert a single NetworkX graph → PyG Data
    # ------------------------------------------------------------------
    def convert_to_pyg_data(sub_graph, poi, accident, demand, is_homogeneous=True):
        """Convert one sub-graph plus its tabular features into a PyG Data object."""
        if sub_graph.number_of_nodes() == 0:
            return None

        sub_graph = sub_graph.copy()

        # Ensure every node has a numeric type_id and drop other attributes
        for n in sub_graph.nodes:
            attrs = sub_graph.nodes[n]
            attrs.setdefault("type_id", node_type_to_id.get("unknown", 0))
            for k in list(attrs.keys()):
                if k not in ["type_id"]:
                    del attrs[k]

        # Drop edge attributes if a homogeneous graph is required
        if is_homogeneous:
            for _, _, d in sub_graph.edges(data=True):
                d.clear()

        data = from_networkx(sub_graph)

        # One-hot node feature matrix
        type_ids = torch.tensor(
            [attr["type_id"] for _, attr in sub_graph.nodes(data=True)],
            dtype=torch.long
        )
        data.x = torch.nn.functional.one_hot(type_ids, num_classes=num_node_types).float()

        # Remove leftover keys except x / edge_index
        for k in set(data.keys()) - {"x", "edge_index"}:
            del data[k]

        # Attach graph-level attributes
        data.poi      = torch.tensor(poi,      dtype=torch.float).unsqueeze(0)
        data.accident = torch.tensor(accident, dtype=torch.float).unsqueeze(0)
        data.demand   = torch.tensor(demand,   dtype=torch.float).unsqueeze(0)

        return data

    # ------------------------------------------------------------------
    # 9)  Build the full list of Data objects
    # ------------------------------------------------------------------
    data_list          = []
    num_skipped_graphs = 0

    for i in tqdm(range(len(sub_graph_list)), desc="Converting to PyG Data"):
        sub_graph = sub_graph_list[i]
        poi       = pois_scaled[i]       # (6,)
        accident  = accidents_scaled[i]  # (5,)
        demand    = demand_data[i]       # (24, 4)

        graph_data = convert_to_pyg_data(
            sub_graph, poi, accident, demand, is_homogeneous=True
        )
        if graph_data is None:
            num_skipped_graphs += 1
            continue
        data_list.append(graph_data)

    print(
        f"Converted {len(data_list)} sub-graphs; "
        f"skipped {num_skipped_graphs} empty graphs."
    )

    # ------------------------------------------------------------------
    # 10)  Persist the list to disk
    # ------------------------------------------------------------------
    os.makedirs(os.path.dirname(data_list_path), exist_ok=True)
    torch.save(data_list, data_list_path)
    print(f"PyG Data list saved to '{data_list_path}'.")

if __name__ == "__main__":
    main()


Overwriting 2-generate_graph_data_homo.py


In [5]:
# If DIG (required for GraphCL) is not installed, uncomment the next line:
# !pip install dive-into-graphs -q

# Pull the latest PyG code from the main repository and install
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

!pip install --no-index --upgrade torch-scatter torch-sparse torch-cluster pyg-lib \
  --find-links https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__.split('+')[0])")+cpu.html

# Install PyG core and its extensions compatible with the current Torch version (CPU build)
!pip install --no-index --upgrade torch-geometric torch-scatter torch-sparse \
  --find-links https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__.split('+')[0])")+cpu.html


from google.colab import drive
drive.mount('/content/drive')


Collecting git+https://github.com/pyg-team/pytorch_geometric.git
  Cloning https://github.com/pyg-team/pytorch_geometric.git to /tmp/pip-req-build-16cwk5ov
  Running command git clone --filter=blob:none --quiet https://github.com/pyg-team/pytorch_geometric.git /tmp/pip-req-build-16cwk5ov
  Resolved https://github.com/pyg-team/pytorch_geometric.git to commit ea74852394618f8f29377c7605b2b673783ca881
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cpu.html
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cpu.html
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [55]:


# !! DUE TO THE CONFIDENTIALITY OF THE ORIGINAL DATASET, WE CANNOT DISTRIBUTE IT. !!

# !! THIS CODE BLOCK PERFORMS 100 RANDOM SAMPLES OF THE DATASET AND SAVES THEM, !!

# !! ALLOWING EXPERIMENTS TO BE REPRODUCED WITHOUT SHARING THE RAW DATA. !!



# # sampling_pipeline.py
# from google.colab import drive
# import subprocess
# import os
# from tqdm import trange

# # 1. Mount your Drive
# drive.mount('/content/drive', force_remount=True)

# # 2. Configuration
# CITIES     = ["Chicago", "SF", "NY"]
# NUM_CATCH  = 100
# SIDE_LEN   = 3000
# DRIVE_DATA = "/content/drive/MyDrive/FractalGCL/city/data"
# TEMP_DATA  = "/content/data"

# # 3. Run the pipeline 100 times with a progress bar
# for i in trange(1, 101, desc="Sampling runs"):
#     # a) Refresh raw data
#     subprocess.run(["rm", "-rf", TEMP_DATA], check=True)
#     subprocess.run(["cp", "-r", DRIVE_DATA, TEMP_DATA], check=True)

#     # b) Process each city
#     for city in CITIES:
#         # Generate GIS features
#         subprocess.run([
#             "python3", "/content/1-gis_processing.py",
#             "--city", city,
#             "--num_catchment", str(NUM_CATCH),
#             "--side_length",  str(SIDE_LEN)
#         ], check=True)

#         # Build homogeneous PyG graph data
#         subprocess.run([
#             "python3", "/content/2-generate_graph_data_homo.py",
#             "--city", city
#         ], check=True)

#         # c) Save this iteration's output
#         src = os.path.join(
#             TEMP_DATA,
#             f"{city}-share_data",
#             "processed"
#         )
#         dst = os.path.join(
#             DRIVE_DATA,
#             f"{city}-share_data",
#             f"processed_{i}"
#         )
#         subprocess.run(["rm", "-rf", dst], check=True)
#         subprocess.run(["cp", "-r", src, dst], check=True)

# # Done: you will find directories
# #   drive/MyDrive/FractalGCL/city/data/Chicago-share_data/processed_1
# #   ...
# #   drive/MyDrive/FractalGCL/city/data/Chicago-share_data/processed_100
# # and similarly for SF and NY.


Mounted at /content/drive


Sampling runs: 100%|██████████| 100/100 [2:23:29<00:00, 86.09s/it] 


In [None]:


# !! THIS FUNCTION SAMPLES DATA FOR THE FULL PIPELINE, BUT IT IS COMMENTED OUT BECAUSE THE DATA CANNOT BE FULLY OPENED. !!



# %%bash
# set -e
#
# # List of cities you want to process
# CITIES=("Chicago" "SF" "NY")
#
# # Sampling count and window size for each city (tweak as needed)
# NUM_CATCH=100
# SIDE_LEN=3000
#
# # Copy the original data directory to /content
# rm -rf /content/data
# cp -r /content/drive/MyDrive/FractalGCL/city/data /content/
#
# for city in "${CITIES[@]}"; do
#   echo "▶️  Pre-processing ${city} …"
#
#   # 1) Generate GIS features
#   python3 /content/1-gis_processing.py \
#     --city "$city" \
#     --num_catchment $NUM_CATCH \
#     --side_length $SIDE_LEN
#
#   # 2) Build homogeneous PyG graph data
#   python3 /content/2-generate_graph_data_homo.py \
#     --city "$city"
#
#   # 3) Sync back to Drive (overwrite the city’s processed directory)
#   rm -rf "/content/drive/MyDrive/FractalGCL/city/data/${city}-share_data/processed"
#   cp -r "/content/data/${city}-share_data/processed" \
#         "/content/drive/MyDrive/FractalGCL/city/data/${city}-share_data/"
#
#   echo "✅  ${city} pre-processing complete"
# done
#
# echo "🎉  All cities processed successfully"


▶️  Pre-processing Chicago …
Loading population data for Chicago…
Population data loaded. Total zones: 97
Loading POI data for Chicago…
POI data loaded. Total points: 54541
Loading road-network data for Chicago…
Road-network data loaded. Total nodes: 29117 Total edges: 76702
Transforming road-network coordinates to EPSG:3395…
Coordinate transformation complete.
Loading accident data for Chicago…
Accident data loaded. Total accidents: 4308
Loading demand data…
Demand data loaded. Shape: (97, 24, 4)
Generating features for 100 random locations…
Feature generation complete.
Feature data saved to data/Chicago-share_data/processed/feature_data.csv.
Demand features saved to data/Chicago-share_data/processed/demand_features.npy.
Subgraph list saved to data/Chicago-share_data/processed/sub_graph_list.pkl.
Feature table loaded from 'data/Chicago-share_data/processed/feature_data.csv'.
Sub-graph list loaded from 'data/Chicago-share_data/processed/sub_graph_list.pkl'.
Demand tensor loaded from 'd

Transforming Network Coordinates:   0%|          | 0/29117 [00:00<?, ?it/s]Transforming Network Coordinates: 100%|██████████| 29117/29117 [00:00<00:00, 929566.75it/s]
Processing Catchments:   0%|          | 0/100 [00:00<?, ?it/s]Processing Catchments:   4%|▍         | 4/100 [00:00<00:02, 37.87it/s]Processing Catchments:   8%|▊         | 8/100 [00:00<00:02, 38.68it/s]Processing Catchments:  12%|█▏        | 12/100 [00:00<00:02, 35.46it/s]Processing Catchments:  16%|█▌        | 16/100 [00:00<00:02, 35.45it/s]Processing Catchments:  20%|██        | 20/100 [00:00<00:02, 36.69it/s]Processing Catchments:  24%|██▍       | 24/100 [00:00<00:02, 36.86it/s]Processing Catchments:  28%|██▊       | 28/100 [00:00<00:01, 37.46it/s]Processing Catchments:  32%|███▏      | 32/100 [00:00<00:01, 37.60it/s]Processing Catchments:  36%|███▌      | 36/100 [00:00<00:01, 37.04it/s]Processing Catchments:  40%|████      | 40/100 [00:01<00:01, 37.06it/s]Processing Catchments:  44%|████▍     | 44/100 [0

In [6]:
# —— Block: Train five baseline encoders and keep them in memory ——

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dropout_edge, subgraph
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score
from gis_utils import get_file_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Generic Encoder ───────────────────────────────────────────────────────────
class Encoder(torch.nn.Module):
    def __init__(self, in_dim, hidden, embed_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.proj  = torch.nn.Linear(hidden, embed_dim)
    def forward(self, x, ei, batch=None):
        x = F.relu(self.conv1(x, ei))
        x = F.relu(self.conv2(x, ei))
        x = self.proj(x)
        # Fix potential (embed_dim, N) output
        if x.dim() == 2 and x.size(0) == self.proj.out_features and x.size(1) != self.proj.out_features:
            x = x.t()
        return x

# ─── DGI ───────────────────────────────────────────────────────────────────────
def train_dgi(data_list, in_dim, hidden, embed_dim,
              lr, epochs, batch_size):
    enc = Encoder(in_dim, hidden, embed_dim).to(device)
    opt = torch.optim.Adam(enc.parameters(), lr=lr)
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)
    for ep in range(1, epochs + 1):
        total = 0
        for batch in loader:
            batch = batch.to(device)
            z_pos = enc(batch.x, batch.edge_index, batch.batch)
            perm  = torch.randperm(batch.x.size(0), device=device)
            z_neg = enc(batch.x[perm], batch.edge_index, batch.batch)
            s      = torch.sigmoid(global_mean_pool(z_pos, batch.batch))
            s_node = s[batch.batch]
            pos = (z_pos * s_node).sum(dim=1)
            neg = (z_neg * s_node).sum(dim=1)
            logits = torch.stack([pos, neg], dim=1)
            labels = torch.zeros(batch.x.size(0), device=device, dtype=torch.long)
            loss   = F.cross_entropy(logits, labels)
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        print(f"[DGI] {ep}/{epochs} — Loss: {total/len(loader):.4f}")
    return enc

# ─── InfoGraph ────────────────────────────────────────────────────────────────
def train_infograph(data_list, in_dim, hidden, embed_dim,
                    lr, epochs, batch_size, temp):
    enc = Encoder(in_dim, hidden, embed_dim).to(device)
    summary_net = torch.nn.Linear(embed_dim, embed_dim).to(device)
    opt = torch.optim.Adam(list(enc.parameters()) + list(summary_net.parameters()), lr=lr)
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)
    for ep in range(1, epochs + 1):
        total = 0
        for batch in loader:
            batch = batch.to(device)
            h = enc(batch.x, batch.edge_index, batch.batch)      # (N, E)
            g = global_mean_pool(h, batch.batch)                 # (B, E)
            s = summary_net(g)                                   # (B, E)
            logits = (h @ s.t()) / temp                          # (N, B)
            labels = batch.batch                                 # node → graph idx
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        print(f"[InfoGraph] {ep}/{epochs} — Loss: {total/len(loader):.4f}")
    return enc

# ─── GCL (hand-written) ────────────────────────────────────────────────────────
def train_gcl_manual(data_list, in_dim, hidden, embed_dim,
                     lr, epochs, batch_size,
                     drop_prob, subgraph_ratio, temp):
    enc = Encoder(in_dim, hidden, embed_dim).to(device)
    opt = torch.optim.Adam(enc.parameters(), lr=lr)
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

    def nt_xent(z1, z2):
        z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
        logits = torch.mm(z1, z2.t()) / temp
        labels = torch.arange(z1.size(0), device=device)
        return F.cross_entropy(logits, labels)

    for ep in range(1, epochs + 1):
        total = 0
        for batch in loader:
            batch = batch.to(device)

            # Augmentation 1: edge dropout
            ei1, _ = dropout_edge(batch.edge_index, p=drop_prob)
            b1 = Data(x=batch.x, edge_index=ei1, batch=batch.batch).to(device)

            # Augmentation 2: node-induced subgraph
            N = batch.x.size(0)
            k = max(2, int(N * subgraph_ratio))
            idx = torch.randperm(N, device=device)[:k].tolist()          # to list
            ei2, _ = subgraph(idx,
                              batch.edge_index,
                              relabel_nodes=True,
                              num_nodes=N)                                # specify num_nodes
            b2 = Data(x=batch.x[idx], edge_index=ei2, batch=batch.batch[idx]).to(device)

            # Contrastive loss
            g1 = global_mean_pool(enc(b1.x, b1.edge_index, b1.batch), b1.batch)
            g2 = global_mean_pool(enc(b2.x, b2.edge_index, b2.batch), b2.batch)
            loss = nt_xent(g1, g2)

            opt.zero_grad()
            loss.backward()
            opt.step()

            total += loss.item()

        print(f"[GCL] {ep}/{epochs} — Loss: {total/len(loader):.4f}")

    return enc

# ─── JOAO ──────────────────────────────────────────────────────────────────────
def train_joao(data_list, in_dim, hidden, embed_dim,
               lr, epochs, batch_size):
    # Two augmentations: drop-edge & attribute mask
    aug_fns = [
        lambda d: Data(x=d.x,
                       edge_index=dropout_edge(d.edge_index, p=0.2)[0],
                       batch=d.batch),
        lambda d: Data(x=d.x.masked_fill(torch.rand_like(d.x) < 0.2, 0),
                       edge_index=d.edge_index,
                       batch=d.batch)
    ]
    alpha = torch.nn.Parameter(torch.zeros(len(aug_fns), device=device))
    enc   = Encoder(in_dim, hidden, embed_dim).to(device)
    opt   = torch.optim.Adam(list(enc.parameters()) + [alpha], lr=lr)
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

    for ep in range(1, epochs + 1):
        total = 0
        for batch in loader:
            batch = batch.to(device)
            weights = torch.softmax(alpha, dim=0)
            loss_sum = 0
            for w, aug in zip(weights, aug_fns):
                b1 = aug(batch).to(device)
                b2 = aug(batch).to(device)
                g1 = global_mean_pool(enc(b1.x, b1.edge_index, b1.batch), b1.batch)
                g2 = global_mean_pool(enc(b2.x, b2.edge_index, b2.batch), b2.batch)
                logits = (F.normalize(g1, 1) @ F.normalize(g2, 1).t()) / 0.5
                labels = torch.arange(g1.size(0), device=device)
                loss = F.cross_entropy(logits, labels)
                loss_sum = loss_sum + w * loss
            opt.zero_grad(); loss_sum.backward(); opt.step()
            total += loss_sum.item()
        print(f"[JOAO] {ep}/{epochs} — Loss: {total/len(loader):.4f}")
    return enc

# ─── SimGRACE ─────────────────────────────────────────────────────────────────
def train_simgrace(data_list, in_dim, hidden, embed_dim,
                   lr=1e-3, epochs=20, batch_size=32,
                   subgraph_ratio=0.8, mask_prob=0.2, temp=0.5):
    """
    SimGRACE: subgraph + mask augmentations with explicit num_nodes.
    """
    def subview(data):
        # For a Batch object, num_nodes is the total node count
        N = data.num_nodes if hasattr(data, 'num_nodes') else data.x.size(0)
        k = max(2, int(N * subgraph_ratio))
        idx = torch.randperm(N, device=data.x.device)[:k].tolist()
        ei, _ = subgraph(idx, data.edge_index,
                         relabel_nodes=True,
                         num_nodes=N)
        x_new   = data.x[idx]
        batch_n = data.batch[idx]
        return Data(x=x_new, edge_index=ei, batch=batch_n)

    def maskview(data):
        N = data.x.size(0)
        mask = torch.rand(N, device=data.x.device) < mask_prob
        x_new = data.x.clone()
        x_new[mask] = 0
        return Data(x=x_new, edge_index=data.edge_index, batch=data.batch)

    enc = Encoder(in_dim, hidden, embed_dim).to(device)
    opt = torch.optim.Adam(enc.parameters(), lr=lr)
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

    def nt_xent(z1, z2):
        z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
        logits = torch.mm(z1, z2.t()) / temp
        labels = torch.arange(z1.size(0), device=device)
        return F.cross_entropy(logits, labels)

    for ep in range(1, epochs + 1):
        total = 0.0
        for batch in loader:
            batch = batch.to(device)
            b1 = subview(batch).to(device)
            b2 = maskview(batch).to(device)
            g1 = global_mean_pool(enc(b1.x, b1.edge_index, b1.batch), b1.batch)
            g2 = global_mean_pool(enc(b2.x, b2.edge_index, b2.batch), b2.batch)
            loss = nt_xent(g1, g2)
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item()
        print(f"[SimGRACE] {ep}/{epochs} — Loss: {total/len(loader):.4f}")
    return enc

# ─── SVM evaluation ───────────────────────────────────────────────────────────
def extract_embeddings(enc, data_list):
    enc.eval()
    embs = []
    for d in data_list:
        d = d.to(device)
        h = enc(d.x, d.edge_index, d.batch)
        g = global_mean_pool(h, d.batch)
        embs.append(g.detach().cpu().numpy())
    return np.vstack(embs)

from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from sklearn.svm import SVC
import numpy as np

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
import numpy as np

def evaluate_repeated_cv_fast(X, y,
                              n_splits=10,
                              n_repeats=1000,
                              random_state=42,
                              classifier=None):
    """
    Repeated Stratified K-Fold, performing n_repeats × n_splits evaluations,
    and using a faster linear classifier (SGDClassifier by default) to speed up training.
    Returns (mean_accuracy, std_accuracy) rounded to 4 decimal places.
    """
    # Replace RBF-SVC with a linear SVM (SGDClassifier with loss='hinge') to speed up training
    if classifier is None:
        classifier = SGDClassifier(loss='hinge', max_iter=1000, tol=1e-3, random_state=random_state)

    rkf = RepeatedStratifiedKFold(
        n_splits=n_splits,
        n_repeats=n_repeats,
        random_state=random_state
    )
    # Run in parallel
    scores = cross_val_score(
        classifier, X, y,
        cv=rkf,
        scoring='accuracy',
        n_jobs=-1
    )

    mean = float(np.mean(scores))
    std  = float(np.std(scores))
    return round(mean, 4), round(std, 4)

# —— Example usage ——
# mean, std = evaluate_repeated_cv_fast(X_dgi, y, n_splits=10, n_repeats=5)
# print(f"SVM 5×10-fold acc: {mean:.4f} ± {std:.4f}")


In [7]:
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphSAGE, global_mean_pool
from torch_geometric.utils import to_networkx, subgraph, dropout_edge

from sklearn.linear_model import LinearRegression


def build_sage(in_dim, hidden_channels=64, num_layers=2, embed_dim=128):
    """Build a GraphSAGE encoder."""
    return GraphSAGE(
        in_channels=in_dim,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        out_channels=embed_dim
    )


def compute_box_dim(G):
    """Compute the box dimension of a NetworkX graph."""
    if nx.is_directed(G):
        G = G.to_undirected()
    if not nx.is_connected(G):
        G = G.subgraph(max(nx.connected_components(G), key=len)).copy()
    try:
        D = nx.diameter(G)
    except Exception:
        return 0.0, 0.0

    max_l, ls, Ns = max(1, D // 2), [], []
    for l in range(1, max_l + 1):
        r, unc, cnt = l // 2, set(G.nodes()), 0
        if l % 2 == 0:                        # single-center covering
            while unc:
                best = max((set(nx.single_source_shortest_path_length(G, u, cutoff=r)) & unc)
                           for u in unc)
                unc -= best
                cnt += 1
        else:                                 # try dual-center covering
            while unc:
                ubest = set()
                for u in unc:
                    for v in G.neighbors(u):
                        uni = ((set(nx.single_source_shortest_path_length(G, u, cutoff=r)) |
                                set(nx.single_source_shortest_path_length(G, v, cutoff=r)))
                               & unc)
                        if len(uni) > len(ubest):
                            ubest = uni
                if ubest:
                    unc -= ubest
                else:                         # fall back to single center
                    best = max((set(nx.single_source_shortest_path_length(G, u, cutoff=r)) & unc)
                               for u in unc)
                    unc -= best
                cnt += 1
        ls.append(l)
        Ns.append(cnt)

    if len(ls) < 2:
        return 0.0, 0.0

    log_l = np.log(np.array(ls)).reshape(-1, 1)
    log_N = np.log(np.array(Ns))
    reg = LinearRegression().fit(log_l, log_N)
    return reg.score(log_l, log_N), -reg.coef_[0]


def renormalise(data, r=1):
    """Three-Step Random-Centre Renormalisation augmentation."""
    G = to_networkx(data, to_undirected=True)
    if not nx.is_connected(G):
        G = G.subgraph(max(nx.connected_components(G), key=len)).copy()

    nodes, supers = set(G.nodes()), []
    while nodes:
        u = nodes.pop()
        reach = set(nx.single_source_shortest_path_length(G, u, cutoff=r).keys())
        U = {u} | (reach & nodes)
        supers.append(U)
        nodes -= U

    assign = {v: i for i, U in enumerate(supers) for v in U}
    edges = {
        (min(assign[u], assign[v]), max(assign[u], assign[v]))
        for u, v in G.edges() if assign[u] != assign[v]
    }
    feats = [
        data.x[torch.tensor(list(U), dtype=torch.long)].mean(dim=0)
        for U in supers
    ]
    x_new = torch.stack(feats, dim=0)
    ei = (torch.tensor(list(edges), dtype=torch.long).t().contiguous()
          if edges else torch.empty((2, 0), dtype=torch.long))
    return Data(x=x_new, edge_index=ei)


def augment_drop(data, p=0.2):
    """Node-drop augmentation with fallback when everything is dropped."""
    Nn = data.x.size(0)
    mask = torch.rand(Nn, device=data.x.device) > p
    idx = mask.nonzero(as_tuple=False).view(-1)
    if idx.numel() == 0:  # if all nodes were dropped, randomly keep one
        idx = torch.tensor([torch.randint(Nn, (1,), device=data.x.device)],
                           dtype=torch.long)
    ei, _ = subgraph(idx, data.edge_index, relabel_nodes=True, num_nodes=Nn)
    return Data(x=data.x[idx], edge_index=ei)


def extract_embeddings(enc, data_list):
    """Extract graph-level embeddings for a list of PyG Data objects."""
    enc.eval()
    out = []
    for d in data_list:
        batch = torch.zeros(d.x.size(0), dtype=torch.long, device=d.x.device)
        h = enc(d.x, d.edge_index, batch)
        g = global_mean_pool(h, batch)
        out.append(g.detach().cpu().numpy())
    return np.vstack(out)


def train_fractal(data_list, in_dim, diameters, dims,
                  hidden, layers, embed_dim,
                  alpha=1.0, renorm_r=1,
                  lr=1e-3, epochs=20, batch_size=16, drop_prob=0.2):
    """Train a FractalGCL encoder and return the trained model."""
    enc = build_sage(in_dim, hidden, layers, embed_dim)
    opt = torch.optim.Adam(enc.parameters(), lr=lr)

    def tau2(D):
        """Variance schedule used in the Gn perturbation."""
        Dm = max(D, 2)
        return 6.0 / (Dm * (np.log(Dm) ** 2 + 1e-12))

    enc.train()
    for ep in range(1, epochs + 1):
        perm = torch.randperm(len(data_list)).tolist()
        total = 0.0
        for i in range(0, len(data_list), batch_size):
            ids = perm[i:i + batch_size]

            # Two complementary views
            B1 = [renormalise(data_list[j], renorm_r) for j in ids]
            B2 = [augment_drop(data_list[j], drop_prob) for j in ids]
            b1 = Batch.from_data_list(B1)
            b2 = Batch.from_data_list(B2)

            h1 = enc(b1.x, b1.edge_index, b1.batch)
            h2 = enc(b2.x, b2.edge_index, b2.batch)
            z1 = global_mean_pool(h1, b1.batch)
            z2 = global_mean_pool(h2, b2.batch)

            # Contrastive similarity matrix
            S = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) / 0.5  # (B, B)

            B = S.size(0)
            # Dynamically construct Gn
            Gn = torch.zeros((B, B), device=S.device)
            for m in range(B):
                for n in range(B):
                    if m == n:
                        Gn[m, n] = torch.randn(1, device=S.device) * np.sqrt(tau2(diameters[ids[m]]))
                    else:
                        mu  = abs(dims[ids[m]] - dims[ids[n]])
                        var = tau2(diameters[ids[m]]) + tau2(diameters[ids[n]])
                        Gn[m, n] = mu + torch.randn(1, device=S.device) * np.sqrt(var)

            S2 = S + alpha * Gn
            E  = torch.exp(S2)
            loss = (-torch.log(E.diag() / E.sum(dim=1))).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()

            total += loss.item()

        n_batches = (len(data_list) - 1) // batch_size + 1
        print(f"[Fractal] Epoch {ep}/{epochs}  Loss = {total / n_batches:.4f}")

    return enc


In [8]:
from gis_utils import get_file_paths
import torch
import random
import os

cities = ["Chicago", "SF", "NY"]
baseline_encoders = {}

# for city in cities:
#     print(f"\n===== Training baselines for {city} =====")
#     paths     = get_file_paths(city)
#     data_list = torch.load(f"{paths['save_path']}/data_list.pt", weights_only=False)
#     in_dim    = data_list[0].x.size(1)
#     print(f"  • {len(data_list)} graphs, feature-dim = {in_dim}")


for city in cities:
    # Randomly select a sample ID between 1 and 100
    sample_id = random.randint(1, 100)
    print(f"\n===== Training baselines for {city} (using sample {sample_id}) =====")

    # Construct the full sample directory path on Drive
    sample_dir = f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}"
    file_path  = f"{sample_dir}/data_list.pt"
    print("  Loading from", file_path)

    # Load the sampled data list
    data_list = torch.load(file_path, weights_only=False)
    in_dim    = data_list[0].x.size(1)
    print(f"  • {len(data_list)} graphs, feature-dim = {in_dim}")



    # Train each method
    print("  ⏳ DGI…", end=" ")
    dgi_enc = train_dgi(      data_list, in_dim, 64, 128, 1e-3, 20, 32)
    print("done")

    print("  ⏳ InfoGraph…", end=" ")
    infograph_enc = train_infograph(
        data_list, in_dim, 64, 128, 1e-3, 20, 32, 0.5
    )
    print("done")

    print("  ⏳ GCL…", end=" ")
    gcl_enc = train_gcl_manual(
        data_list, in_dim, 64, 128, 1e-3, 20, 32, 0.2, 0.8, 0.5
    )
    print("done")

    print("  ⏳ JOAO…", end=" ")
    joao_enc = train_joao(
        data_list, in_dim, 64, 128, 1e-3, 20, 32
    )
    print("done")

    print("  ⏳ SimGRACE…", end=" ")
    simgrace_enc = train_simgrace(
        data_list, in_dim, 64, 128, 1e-3, 20, 32, 0.8, 0.2, 0.5
    )
    print("done")

    baseline_encoders[city] = {
        "dgi": dgi_enc,
        "infograph": infograph_enc,
        "gcl": gcl_enc,
        "joao": joao_enc,
        "simgrace": simgrace_enc,
    }

print("\n✅ All baseline encoders trained and stored in `baseline_encoders`.")



===== Training baselines for Chicago (using sample 21) =====
  Loading from /content/drive/MyDrive/FractalGCL/city/data/Chicago-share_data/processed_21/data_list.pt
  • 100 graphs, feature-dim = 9
  ⏳ DGI… [DGI] 1/20 — Loss: 0.6804
[DGI] 2/20 — Loss: 0.6571
[DGI] 3/20 — Loss: 0.6267
[DGI] 4/20 — Loss: 0.6091
[DGI] 5/20 — Loss: 0.6177
[DGI] 6/20 — Loss: 0.5842
[DGI] 7/20 — Loss: 0.6099
[DGI] 8/20 — Loss: 0.5979
[DGI] 9/20 — Loss: 0.6018
[DGI] 10/20 — Loss: 0.5566
[DGI] 11/20 — Loss: 0.5755
[DGI] 12/20 — Loss: 0.5710
[DGI] 13/20 — Loss: 0.5768
[DGI] 14/20 — Loss: 0.5445
[DGI] 15/20 — Loss: 0.5717
[DGI] 16/20 — Loss: 0.5618
[DGI] 17/20 — Loss: 0.5715
[DGI] 18/20 — Loss: 0.5604
[DGI] 19/20 — Loss: 0.5574
[DGI] 20/20 — Loss: 0.5714
done
  ⏳ InfoGraph… [InfoGraph] 1/20 — Loss: 2.9411
[InfoGraph] 2/20 — Loss: 2.9197
[InfoGraph] 3/20 — Loss: 2.8516
[InfoGraph] 4/20 — Loss: 2.8640
[InfoGraph] 5/20 — Loss: 2.6996
[InfoGraph] 6/20 — Loss: 2.8135
[InfoGraph] 7/20 — Loss: 2.7655
[InfoGraph] 8/20 —

In [9]:
from google.colab import drive
import torch
import networkx as nx
import pandas as pd
from torch_geometric.utils import to_networkx

# Cities to process
cities = ["Chicago", "SF", "NY"]

# ─── Hyperparameters ─────────────────────────────────────────────
batch_size      = 16
hidden_channels = 64
num_layers      = 2
embed_dim       = 128
drop_prob       = 0.1
alpha           = 0.4
renorm_radius   = 1.0
fractal_epochs  = 20
lr              = 1e-3
# ─────────────────────────────────────────────────────────────────

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

fractal_encoders = {}
for city in cities:
    print(f"\n===== Training FractalGCL for {city} =====")
    paths      = get_file_paths(city)
    # full_list  = torch.load(f"{paths['save_path']}/data_list.pt", weights_only=False)
    full_list = torch.load(f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}/data_list.pt",weights_only=False)
    # df         = pd.read_csv(f"{paths['save_path']}/feature_data.csv")
    df = pd.read_csv(f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}/feature_data.csv")
    median_acc = df['total_accidents'].median()
    y_full     = (df['total_accidents'] > median_acc).astype(int).values

    # Align lengths
    n_graphs  = min(len(full_list), len(y_full))
    data_list = full_list[:n_graphs]
    print(f"{city}: {n_graphs} graphs")

    # Compute diameter & box dimension
    diameters, dims = [], []
    for d in data_list:
        G = to_networkx(d, to_undirected=True)
        if not nx.is_connected(G):
            G = G.subgraph(max(nx.connected_components(G), key=len)).copy()
        diameters.append(nx.diameter(G))
        dims.append(compute_box_dim(G)[1])

    # Pad / trim features to a uniform dimension
    fractal_in_dim = data_list[0].x.size(1)
    def pad_trim(dl, dim):
        for d in dl:
            x, diff = d.x, dim - d.x.size(1)
            if diff > 0:
                pad = torch.zeros(x.size(0), diff, device=x.device)
                d.x = torch.cat([x, pad], dim=1)
            elif diff < 0:
                d.x = x[:, :dim]
        return dl
    data_list = pad_trim(data_list, fractal_in_dim)

    # Train FractalGCL
    fractal_enc = train_fractal(
        data_list, fractal_in_dim, diameters, dims,
        hidden_channels, num_layers, embed_dim,
        alpha=alpha, renorm_r=renorm_radius,
        lr=lr, epochs=fractal_epochs, batch_size=batch_size,
        drop_prob=drop_prob
    )
    fractal_enc.eval()
    fractal_encoders[city] = fractal_enc

# Save all models to disk (e.g., with pickle)
import pickle
with open('/content/drive/MyDrive/fractal_encoders.pkl', 'wb') as f:
    pickle.dump(fractal_encoders, f)
print("\n✅ FractalGCL models have been trained and saved.")


Mounted at /content/drive

===== Training FractalGCL for Chicago =====
Chicago: 100 graphs
[Fractal] Epoch 1/20  Loss = 2.5871
[Fractal] Epoch 2/20  Loss = 2.5127
[Fractal] Epoch 3/20  Loss = 2.2753
[Fractal] Epoch 4/20  Loss = 2.0797
[Fractal] Epoch 5/20  Loss = 2.0905
[Fractal] Epoch 6/20  Loss = 2.0343
[Fractal] Epoch 7/20  Loss = 2.0291
[Fractal] Epoch 8/20  Loss = 1.9891
[Fractal] Epoch 9/20  Loss = 2.2122
[Fractal] Epoch 10/20  Loss = 2.0388
[Fractal] Epoch 11/20  Loss = 2.0781
[Fractal] Epoch 12/20  Loss = 2.0003
[Fractal] Epoch 13/20  Loss = 1.8624
[Fractal] Epoch 14/20  Loss = 1.9165
[Fractal] Epoch 15/20  Loss = 1.8333
[Fractal] Epoch 16/20  Loss = 1.8551
[Fractal] Epoch 17/20  Loss = 1.7560
[Fractal] Epoch 18/20  Loss = 1.8686
[Fractal] Epoch 19/20  Loss = 1.7209
[Fractal] Epoch 20/20  Loss = 1.6942

===== Training FractalGCL for SF =====
SF: 100 graphs
[Fractal] Epoch 1/20  Loss = 2.5492
[Fractal] Epoch 2/20  Loss = 2.3646
[Fractal] Epoch 3/20  Loss = 1.9510
[Fractal] Epoch

In [11]:
from gis_utils import get_file_paths
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from sklearn.linear_model import SGDClassifier

# Automatically adjust the number of folds; skip the city-task pair if samples are insufficient
def evaluate_with_auto_splits(
    X, y,
    n_splits=10,
    n_repeats=1000,
    random_state=42,
    classifier=None
):
    """
    Perform repeated stratified k-fold evaluation with an adaptive number of splits.
    Returns (mean_acc, std_acc) rounded to 4 decimals, or None if the task is skipped.
    """
    counts = np.bincount(y)
    min_count = counts[counts > 0].min()
    n_splits_adj = min(n_splits, min_count)
    if n_splits_adj < 2:
        return None                     # Not enough samples, skip
    if classifier is None:
        classifier = SGDClassifier(
            loss='hinge',
            max_iter=1000,
            tol=1e-3,
            random_state=random_state
        )

    rkf = RepeatedStratifiedKFold(
        n_splits=n_splits_adj,
        n_repeats=n_repeats,
        random_state=random_state
    )
    scores = cross_val_score(
        classifier, X, y,
        cv=rkf,
        scoring='accuracy',
        n_jobs=-1
    )
    return round(scores.mean(), 4), round(scores.std(), 4)


# Assume these are already in memory:
# baseline_encoders = {"Chicago": {...}, "SF": {...}, "NY": {...}}
# fractal_encoders  = {"Chicago": enc_chi, "SF": enc_sf, "NY": enc_ny}
# extract_embeddings(enc, data_list) is defined elsewhere

cities = ["Chicago", "SF", "NY"]
model_names = list(baseline_encoders[cities[0]].keys()) + ["FractalGCL"]

# Accident-related downstream tasks
tasks = {
    # 0. Total accident volume: high vs. low (binary)
    "total_accidents_high": lambda df: df['total_accidents']
        .gt(df['total_accidents'].median()).astype(int),

    # 1. Accident volume level: low / mid / high (three-class)
    "accident_volume_level": lambda df: pd.cut(
        df['total_accidents'],
        bins=[-1,
              df['total_accidents'].quantile(0.33),
              df['total_accidents'].quantile(0.67),
              df['total_accidents'].max()],
        labels=[0, 1, 2]
    ).astype(int),

    # 2. Entropy of severity distribution: high vs. low (binary)
    "severity_entropy": lambda df: (
        pd.concat([
            df[['severity_1', 'severity_2', 'severity_3', 'severity_4']]
              .div(df['total_accidents'], axis=0)
              .replace(0, np.nan)
              .apply(lambda row: -np.nansum(row * np.log(row)), axis=1)
        ], axis=1).iloc[:, 0]
        .gt(pd.concat([
            df[['severity_1', 'severity_2', 'severity_3', 'severity_4']]
              .div(df['total_accidents'], axis=0)
              .replace(0, np.nan)
              .apply(lambda row: -np.nansum(row * np.log(row)), axis=1)
        ], axis=1).iloc[:, 0].median())
        .astype(int)
    ),

    # 3. Presence of severity-3 accidents (binary)
    "has_sev3": lambda df: (df['severity_3'] > 0).astype(int),

    # 4. Presence of severity-4 accidents (binary)
    "has_sev4": lambda df: (df['severity_4'] > 0).astype(int),

    # 5. Composite risk level: low / mid / high (three-class)
    #    High risk: total > median  &  (sev3+sev4)/total > median  → 2
    #    Low  risk: total ≤ median &  (sev3+sev4)/total ≤ median  → 0
    #    Mid  risk: otherwise                                      → 1
    "risk_level": lambda df: (
        pd.Series(np.where(
            (df['total_accidents'] > df['total_accidents'].median()) &
            (((df['severity_3'] + df['severity_4']) / df['total_accidents'])
              > ((df['severity_3'] + df['severity_4']) / df['total_accidents']).median()),
            2,
            np.where(
                (df['total_accidents'] <= df['total_accidents'].median()) &
                (((df['severity_3'] + df['severity_4']) / df['total_accidents'])
                  <= ((df['severity_3'] + df['severity_4']) / df['total_accidents']).median()),
                0, 1
            )
        ), index=df.index).astype(int)
    )
}

# Aggregate and print results by task
for task_name, label_fn in tasks.items():
    print(f"\n===== Task: {task_name} =====")
    for city in cities:
        paths     = get_file_paths(city)
        # df        = pd.read_csv(f"{paths['save_path']}/feature_data.csv")

        df = pd.read_csv(f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}/feature_data.csv")

        # data_list = torch.load(f"{paths['save_path']}/data_list.pt", weights_only=False)

        data_list = torch.load(f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}/data_list.pt",weights_only=False)

        # Build labels; skip if construction fails or only one class exists
        try:
            y = label_fn(df)
        except Exception as e:
            print(f"{city}: failed to build labels ({e}), skipped")
            continue
        if y.nunique() < 2:
            print(f"{city}: num_classes = {y.nunique()} (insufficient), skipped")
            continue

        # Align sample counts
        n = min(len(data_list), len(y))
        data_list = data_list[:n]
        y = y.values[:n]

        # Extract embeddings and evaluate
        results = {}
        for name in model_names:
            enc = baseline_encoders[city].get(name) or fractal_encoders[city]
            X = extract_embeddings(enc, data_list)
            res = evaluate_with_auto_splits(X[:n], y)
            if res is not None:
                results[name] = res

        # Print per-city results
        if results:
            line = city + ": " + ", ".join(
                f"{m} {results[m][0]:.4f}±{results[m][1]:.4f}"
                for m in model_names if m in results
            )
            print(line)
        else:
            print(f"{city}: no model evaluated or all skipped")



===== Task: total_accidents_high =====
Chicago: dgi 0.5389±0.1047, infograph 0.5312±0.0979, gcl 0.5596±0.1227, joao 0.5390±0.1186, simgrace 0.5644±0.1239, FractalGCL 0.6094±0.1382
SF: dgi 0.7322±0.1380, infograph 0.7585±0.1328, gcl 0.7778±0.1296, joao 0.7631±0.1446, simgrace 0.7790±0.1303, FractalGCL 0.7736±0.1287
NY: dgi 0.5065±0.0550, infograph 0.5338±0.0898, gcl 0.5876±0.1269, joao 0.5050±0.1011, simgrace 0.5117±0.1159, FractalGCL 0.6418±0.1288

===== Task: accident_volume_level =====
Chicago: dgi 0.3998±0.1134, infograph 0.4133±0.1184, gcl 0.4878±0.1434, joao 0.4118±0.1267, simgrace 0.4665±0.1329, FractalGCL 0.5656±0.1476
SF: dgi 0.5813±0.1269, infograph 0.5955±0.1240, gcl 0.6042±0.1240, joao 0.5790±0.1158, simgrace 0.6049±0.1233, FractalGCL 0.6044±0.1278
NY: dgi 0.3484±0.0926, infograph 0.3479±0.0906, gcl 0.3901±0.1298, joao 0.3414±0.1099, simgrace 0.3388±0.1186, FractalGCL 0.4205±0.1286

===== Task: severity_entropy =====
Chicago: dgi 0.4964±0.0553, infograph 0.4958±0.0598, gcl 

In [12]:
from gis_utils import get_file_paths
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from sklearn.linear_model import SGDClassifier

# Automatically adjust the number of folds; skip the city–task pair if samples are insufficient
def evaluate_with_auto_splits(
    X, y,
    n_splits=10,
    n_repeats=1000,
    random_state=42,
    classifier=None
):
    counts = np.bincount(y)
    min_count = counts[counts > 0].min()
    n_splits_adj = min(n_splits, min_count)
    if n_splits_adj < 2:
        return None                                       # Too few samples → skip
    if classifier is None:
        classifier = SGDClassifier(
            loss='hinge',
            max_iter=1000,
            tol=1e-3,
            random_state=random_state
        )
    rkf = RepeatedStratifiedKFold(
        n_splits=n_splits_adj,
        n_repeats=n_repeats,
        random_state=random_state
    )
    scores = cross_val_score(
        classifier, X, y,
        cv=rkf,
        scoring='accuracy',
        n_jobs=-1
    )
    return round(scores.mean(), 4), round(scores.std(), 4)

# Assumed to already exist in memory:
# baseline_encoders = {"Chicago": {...}, "SF": {...}, "NY": {...}}
# fractal_encoders  = {"Chicago": enc_chi, "SF": enc_sf, "NY": enc_ny}
# extract_embeddings(enc, data_list) is defined elsewhere

cities = ["Chicago", "SF", "NY"]
model_names = list(baseline_encoders[cities[0]].keys()) + ["FractalGCL"]

# Define six non-accident downstream tasks
tasks = {
    # 1. Dominant land-use function (6-class):
    #    office, sustenance, transportation, retail, leisure, residence
    "dominant_poi": lambda df: df[
        ['office_tags', 'sustenance_tags', 'transportation_tags',
         'retail_tags', 'leisure_tags', 'residence_tags']
    ].idxmax(axis=1).map({
        'office_tags': 0, 'sustenance_tags': 1, 'transportation_tags': 2,
        'retail_tags': 3, 'leisure_tags': 4, 'residence_tags': 5
    }).astype(int),

    # 2. POI mix-entropy level (3-class): tertiles of POI entropy
    "poi_entropy_level": lambda df: pd.qcut(
        -np.sum(
            (p := df[
                ['office_tags', 'sustenance_tags', 'transportation_tags',
                 'retail_tags', 'leisure_tags', 'residence_tags']
            ].div(
                df[
                    ['office_tags', 'sustenance_tags', 'transportation_tags',
                     'retail_tags', 'leisure_tags', 'residence_tags']
                ].sum(axis=1),
                axis=0
            )).replace(0, 1e-12) * np.log(p),
            axis=1
        ),
        3, labels=[0, 1, 2]
    ).astype(int),

    # 3. Population-density level (4-class): quartiles of population density
    "pop_density_level": lambda df: pd.qcut(
        df['population_density'], 4, labels=[0, 1, 2, 3]
    ).astype(int),

    # 4. Function × density combo (12-class):
    #    dominant_poi + 6 × (high-pop? 1 : 0)
    "func_density_combo": lambda df: (
        df[
            ['office_tags', 'sustenance_tags', 'transportation_tags',
             'retail_tags', 'leisure_tags', 'residence_tags']
        ].idxmax(axis=1).map({
            'office_tags'      : 0, 'sustenance_tags'     : 1,
            'transportation_tags': 2, 'retail_tags'       : 3,
            'leisure_tags'     : 4, 'residence_tags'      : 5
        }).astype(int)
        + 6 * (df['population_density'] > df['population_density'].median()).astype(int)
    ),

    # 5. Night-life hotspot (binary):
    #    leisure / residence ratio above median → 1, else 0
    "nightlife_hotspot": lambda df: (
        (df['leisure_tags'] / (df['residence_tags'] + 1e-8))
        .gt((df['leisure_tags'] / (df['residence_tags'] + 1e-8)).median())
        .astype(int)
    ),

    # 6. Commercial–residential mix (3-class):
    #    tertiles of retail / (retail + residence)
    "commercial_residential_mix": lambda df: pd.qcut(
        df['retail_tags'] / (df['retail_tags'] + df['residence_tags'] + 1e-8),
        3, labels=[0, 1, 2]
    ).astype(int),
}

# Evaluate every task and print aggregated results
for task_name, label_fn in tasks.items():
    print(f"\n===== Task: {task_name} =====")
    for city in cities:
        paths     = get_file_paths(city)

        #  df        = pd.read_csv(f"{paths['save_path']}/feature_data.csv")
        df = pd.read_csv(f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}/feature_data.csv")

        # data_list = torch.load(f"{paths['save_path']}/data_list.pt", weights_only=False)
        data_list = torch.load(f"/content/drive/MyDrive/FractalGCL/city/data/{city}-share_data/processed_{sample_id}/data_list.pt",weights_only=False)


        # Build labels; skip city if label construction fails or has <2 classes
        try:
            y = label_fn(df)
        except Exception as e:
            print(f"{city}: label construction failed ({e}), skipped")
            continue
        if y.nunique() < 2:
            print(f"{city}: num_classes = {y.nunique()} (insufficient), skipped")
            continue

        # Align sample counts
        n = min(len(data_list), len(y))
        data_list = data_list[:n]
        y = y.values[:n]

        # Extract embeddings and evaluate
        results = {}
        for name in model_names:
            enc = baseline_encoders[city].get(name) or fractal_encoders[city]
            X = extract_embeddings(enc, data_list)
            res = evaluate_with_auto_splits(X[:n], y)
            if res is not None:
                results[name] = res

        # Print per-city results
        if results:
            line = city + ": " + ", ".join(
                f"{m} {results[m][0]:.4f}±{results[m][1]:.4f}"
                for m in model_names if m in results
            )
            print(line)
        else:
            print(f"{city}: no model evaluated or all skipped")



===== Task: dominant_poi =====
Chicago: dgi 0.7959±0.1558, infograph 0.7869±0.1566, gcl 0.7908±0.1153, joao 0.7897±0.1202, simgrace 0.7971±0.1175, FractalGCL 0.7934±0.0886
SF: dgi 0.6476±0.1742, infograph 0.6510±0.1735, gcl 0.6686±0.1491, joao 0.6572±0.1602, simgrace 0.6690±0.1525, FractalGCL 0.6913±0.1342
NY: dgi 0.4849±0.1421, infograph 0.4748±0.1404, gcl 0.5057±0.1398, joao 0.4976±0.1467, simgrace 0.5027±0.1450, FractalGCL 0.5058±0.1405

===== Task: poi_entropy_level =====


  result = func(self.values, **kwargs)
  diff_b_a = subtract(b, a)


Chicago: dgi 0.3588±0.0951, infograph 0.3622±0.0972, gcl 0.3883±0.1173, joao 0.3720±0.1118, simgrace 0.3847±0.1183, FractalGCL 0.4071±0.1301


  result = func(self.values, **kwargs)
  diff_b_a = subtract(b, a)


SF: dgi 0.4220±0.1173, infograph 0.4146±0.1169, gcl 0.4329±0.1270, joao 0.4092±0.1204, simgrace 0.4340±0.1256, FractalGCL 0.4323±0.1315


  result = func(self.values, **kwargs)
  diff_b_a = subtract(b, a)


NY: dgi 0.4378±0.1254, infograph 0.4432±0.1253, gcl 0.4513±0.1382, joao 0.4039±0.1293, simgrace 0.4200±0.1344, FractalGCL 0.4722±0.1374

===== Task: pop_density_level =====
Chicago: dgi 0.3400±0.1015, infograph 0.3400±0.1017, gcl 0.3528±0.1137, joao 0.3451±0.1069, simgrace 0.3538±0.1102, FractalGCL 0.3659±0.1224
SF: dgi 0.4254±0.1338, infograph 0.4455±0.1342, gcl 0.4602±0.1374, joao 0.3965±0.1195, simgrace 0.4642±0.1367, FractalGCL 0.4671±0.1413
NY: dgi 0.4402±0.1311, infograph 0.4587±0.1345, gcl 0.4810±0.1395, joao 0.4389±0.1300, simgrace 0.4667±0.1392, FractalGCL 0.5083±0.1398

===== Task: func_density_combo =====
Chicago: no model evaluated or all skipped
SF: dgi 0.4116±0.1429, infograph 0.4184±0.1428, gcl 0.4314±0.1432, joao 0.4041±0.1486, simgrace 0.4360±0.1401, FractalGCL 0.4373±0.1429
NY: dgi 0.3755±0.1341, infograph 0.3650±0.1357, gcl 0.3839±0.1348, joao 0.3755±0.1384, simgrace 0.3809±0.1347, FractalGCL 0.3927±0.1339

===== Task: nightlife_hotspot =====
Chicago: dgi 0.5589±0.11