In [1]:
from srai.embedders import GeoVexEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS
from srai.neighbourhoods import H3Neighbourhood
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.plotting import plot_regions
from srai.h3 import ring_buffer_h3_regions_gdf

import warnings

from pytorch_lightning import seed_everything
import pandas as pd
import torch

In [2]:
# Global seed for reproducibility. Lightning's seed_everything seeds Python's
# `random`, NumPy and torch; workers=True additionally gives DataLoader worker
# processes deterministic, seed-derived states. The trailing `;` suppresses the
# echoed return value in the cell output.
SEED = 71
seed_everything(SEED, workers=True);

Seed set to 71


71

## Get City Boundary

In [3]:
# Geocode the study area to its boundary polygon and show it on a light basemap.
CITY_QUERY = "Calgary, Alberta"
area_gdf = geocode_to_region_gdf(CITY_QUERY)
plot_regions(area_gdf, tiles_style="CartoDB positron")

## Get H3 Regions

In [4]:
# H3 grid configuration. The ring-buffer radius is reused further down as the
# GeoVex `neighbourhood_radius`.
resolution = 10
k_ring_buffer_radius = 3

# Hexagons covering the city boundary.
regionalizer = H3Regionalizer(resolution=resolution)
base_h3_regions = regionalizer.transform(area_gdf)

# Extend the cover outward by `k_ring_buffer_radius` rings, then dissolve it into
# one geometry; that geometry is the area handed to the OSM loader below.
buffered_h3_regions = ring_buffer_h3_regions_gdf(
    base_h3_regions,
    distance=k_ring_buffer_radius,
)
buffered_h3_geometry = buffered_h3_regions.unary_union

for _label, _regions in (
    ("Base regions:", base_h3_regions),
    ("Buffered regions:", buffered_h3_regions),
):
    print(_label, len(_regions))

Base regions: 59799
Buffered regions: 64248


In [5]:
# Peek at the first rows: index is the H3 region_id, column is the hexagon polygon.
buffered_h3_regions.head(n=5)

Unnamed: 0_level_0,geometry
region_id,Unnamed: 1_level_1
8a12ccd40cf7fff,"POLYGON ((-114.00126 50.99468, -114.00196 50.9..."
8a12ccc54417fff,"POLYGON ((-114.06325 50.84542, -114.06395 50.8..."
8a12ccd50447fff,"POLYGON ((-114.08356 50.97384, -114.08426 50.9..."
8a12ccd52697fff,"POLYGON ((-114.11083 50.95556, -114.11153 50.9..."
8a12ccc0395ffff,"POLYGON ((-114.16807 50.89519, -114.16877 50.8..."


## Get OSM Features/Tags

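The OSM PBF file can be downloaded from https://download.geofabrik.de/ and then cropped to a bounding box around the study area with osmium, e.g. `osmium extract --bbox <min_lon>,<min_lat>,<max_lon>,<max_lat> input.osm.pbf -o files/calgary.osm.pbf`.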

In [6]:
# Tag filter grouped into Geofabrik thematic layers.
tags = GEOFABRIK_LAYERS

# Local OSM extract (see the download note above).
PBF_PATH = "files/calgary.osm.pbf"
loader = OSMPbfLoader(pbf_file=PBF_PATH)

# Load features for the whole buffered hexagon area, not just the city boundary.
features_gdf = loader.load(buffered_h3_geometry, tags)

[calgary.osm.pbf] Counting pbf features: 2166490it [00:03, 604767.13it/s]
[calgary.osm.pbf] Parsing pbf file #1: 100%|█| 2166490/2166490 [00:30<00:00, 712
Grouping features: 100%|████████████████████████| 28/28 [00:04<00:00,  6.35it/s]


## Join OSM Features and Regions

In [7]:
# Pair each buffered hexagon with the OSM features it intersects; the result is
# indexed by (region_id, feature_id), one row per pair.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(
    buffered_h3_regions,
    features_gdf,
)
joint_gdf

region_id,feature_id
8a12ccd40cf7fff,way/1180114739
8a12ccd40cc7fff,way/1180114739
8a12ccd40cf7fff,node/10960664951
8a12ccd40cf7fff,way/421794295
8a12ccd40caffff,way/421794295
...,...
8a12ccd41277fff,node/9330166960
8a12ccd41277fff,way/541627430
8a12ccd41277fff,node/9330172115
8a12ccd41277fff,node/9330171824


## Train GeoVex Embedder

In [8]:
# Neighbourhood lookup over the same buffered hexagons the model is trained on.
neighbourhood = H3Neighbourhood(buffered_h3_regions)

# - target_features: reuse `tags` (the filter the loader used) so the features
#   the embedder expects can never drift from what was actually extracted.
# - neighbourhood_radius: set equal to the ring-buffer radius used to build
#   `buffered_h3_regions`.
embedder = GeoVexEmbedder(
    target_features=tags,
    batch_size=32,
    neighbourhood_radius=k_ring_buffer_radius,
    convolutional_layers=2,
    embedding_size=3,
)

In [9]:
# Train GeoVex on the buffered regions and compute their embeddings.
# All warnings raised during training are suppressed inside this context only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # GeoVexEmbedder does not support MPS: force CPU when MPS is present,
    # otherwise let Lightning pick the accelerator.
    accelerator = "cpu" if torch.backends.mps.is_available() else "auto"
    trainer_kwargs = {"max_epochs": 10, "accelerator": accelerator}

    embeddings = embedder.fit_transform(
        regions_gdf=buffered_h3_regions,
        features_gdf=features_gdf,
        joint_gdf=joint_gdf,
        neighbourhood=neighbourhood,
        trainer_kwargs=trainer_kwargs,
        learning_rate=0.001,
    )

# NOTE(review): some buffered region_ids (e.g. 8a12ccc54417fff) are absent from
# the output; presumably regions without a complete neighbourhood are dropped — verify.
embeddings.head()

100%|██████████████████████████████████| 64248/64248 [00:05<00:00, 11472.67it/s]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 2.0 M 
1 | decoder | Sequential | 1.4 M 
2 | _loss   | GeoVeXLoss | 0     
---------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.493    Total estimated model params size (MB)


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.




Unnamed: 0_level_0,0,1,2
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
8a12ccd40cf7fff,11.212477,20.826132,24.475948
8a12ccd50447fff,10.915584,11.286699,47.095558
8a12ccd52697fff,32.745316,-3.102635,65.7164
8a12ccc0395ffff,36.695339,35.995281,8.540507
8a12ccc7134ffff,28.678244,11.095741,47.559105


In [10]:
# Persist the embeddings (index: region_id, one column per embedding dimension).
# The file-name suffix is taken from the actual embedding width so it stays
# correct if `embedding_size` is changed; for the current run it is 3.
embeddings.to_csv(f"geovex_embeddings_{embeddings.shape[1]}.csv")