# This example shows the prediction of taxi pickup flow in New York City using the deep learning model DeepSTN.

Find the details of the ST-ResNet model in the <a href="https://dl.acm.org/doi/10.5555/3298239.3298479">corresponding paper</a>

Details of the dataset can be found <a href="https://github.com/FIBLAB/DeepSTN">here</a>.

### Import Modules

In [1]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from geotorchai.models.grid import STResNet, DeepSTN
from geotorchai.preprocessing.torch_df import SpatiotemporalDfToTorchData
from geotorchai.preprocessing import SedonaRegistration, load_parquet_data
from geotorchai.preprocessing.grid import STManager as stm
from geotorchai.utility import TorchAdapter

# Import Apache Sedona
from sedona.spark import *

## Import PySpark
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, udf, expr, array, concat


## Import distributed modules
from torch.utils.data import DistributedSampler, DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from pyspark.ml.torch.distributor import TorchDistributor
from petastorm import TransformSpec

import warnings
# Ignore FutureWarning warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

## Define spark session and Register with GeoTorchAI

In [2]:
# Spark master URL: read from the SPARK_MASTER_URL environment variable when it is
# set, otherwise run Spark in local mode on all cores. Defining it in this cell keeps
# the notebook runnable from a fresh kernel.
MASTER_URL = os.environ.get("SPARK_MASTER_URL", "local[*]")

# Build (or reuse) the Spark session with the Sedona and GeoTools jars requested
# through spark.jars.packages.
config = SedonaContext.builder().master(MASTER_URL).config('spark.jars.packages',
           'org.apache.sedona:sedona-spark-shaded-3.4_2.12:1.4.1,'
           'org.datasyslab:geotools-wrapper:1.4.0-28.2').getOrCreate()

# Create the Sedona context from that session and keep a handle on its SparkContext.
sedona = SedonaContext.create(config)
sc = sedona.sparkContext

23/08/14 12:06:43 WARN Utils: Your hostname, Kanchans-Laptop.local resolves to a loopback address: 127.0.0.1; using 192.168.1.6 instead (on interface en0)
23/08/14 12:06:43 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Ivy Default Cache set to: /Users/kanchan/.ivy2/cache
The jars for the packages stored in: /Users/kanchan/.ivy2/jars
org.apache.sedona#sedona-spark-shaded-3.4_2.12 added as a dependency
org.datasyslab#geotools-wrapper added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-92768686-8c6e-4bd4-aea2-f844768567d6;1.0
	confs: [default]


:: loading settings :: url = jar:file:/Users/kanchan/.pyenv/versions/3.11.0/lib/python3.11/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


	found org.apache.sedona#sedona-spark-shaded-3.4_2.12;1.4.1 in central
	found org.datasyslab#geotools-wrapper;1.4.0-28.2 in central
:: resolution report :: resolve 98ms :: artifacts dl 2ms
	:: modules in use:
	org.apache.sedona#sedona-spark-shaded-3.4_2.12;1.4.1 from central in [default]
	org.datasyslab#geotools-wrapper;1.4.0-28.2 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   2   |   0   |   0   |   0   ||   2   |   0   |
	---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark-submit-parent-92768686-8c6e-4bd4-aea2-f844768567d6
	confs: [default]
	0 artifacts copied, 2 already retrieved (0kB/3ms)
23/08/14 12:06:44 WARN NativeCodeLoader: Unable to load

### Register SparkSession with GeoTorchAI

In [3]:
# Hand the Sedona session created above to GeoTorchAI's preprocessing module.
SedonaRegistration.set_sedona_context(sedona)

## Load Taxi Pick Up Data to Sedona

In [4]:
# Read the taxi trip parquet file and preview its first five rows.
taxi_df = load_parquet_data("data/yellow_trip_10_fraction.parquet")
taxi_df.show(5)

+---------+-------------------+-------------------+---------------+-------------+------------------+---------------+---------+------------------+------------------+----------------+------------+-----------+---------+-------+----------+------------+------------+
|vendor_id|    pickup_datetime|   dropoff_datetime|passenger_count|trip_distance|  pickup_longitude|pickup_latitude|rate_code|store_and_fwd_flag| dropoff_longitude|dropoff_latitude|payment_type|fare_amount|surcharge|mta_tax|tip_amount|tolls_amount|total_amount|
+---------+-------------------+-------------------+---------------+-------------+------------------+---------------+---------+------------------+------------------+----------------+------------+-----------+---------+-------+----------+------------+------------+
|      CMT|2010-10-01 01:13:21|2010-10-01 01:31:15|              1|          4.7|        -74.001624|      40.750705|        1|                 N|        -73.961121|       40.768957|         CSH|       13.7|      0.

In [5]:
# Keep only the pickup timestamp and coordinates, then retain the rows whose
# coordinates fall inside the latitude/longitude box below (both ends inclusive).
taxi_df = (
    taxi_df
    .select("pickup_datetime", "pickup_longitude", "pickup_latitude")
    .filter(
        col("pickup_latitude").between(40.491370, 40.91553)
        & col("pickup_longitude").between(-74.259090, -73.700180)
    )
)
taxi_df.show(5)

+-------------------+------------------+---------------+
|    pickup_datetime|  pickup_longitude|pickup_latitude|
+-------------------+------------------+---------------+
|2010-10-01 01:13:21|        -74.001624|      40.750705|
|2010-10-01 00:48:05|-73.99296599999998|      40.753019|
|2010-10-01 09:07:57|        -73.981216|      40.767694|
|2010-10-03 12:51:53|        -73.781824|      40.644774|
|2010-10-01 02:24:12|        -73.995343|      40.719125|
+-------------------+------------------+---------------+
only showing top 5 rows



In [None]:
# Build a hexagonal map layer from the pickup coordinates and write it to an HTML file.
# NOTE(review): `fraction=0.5` presumably samples half of the rows — confirm against STManager.
layer = stm.getHexagonalLayer(taxi_df, col_lat='pickup_latitude', col_lon='pickup_longitude', fraction=0.5)
layer.to_html('hexagon-example.html')

## Convert into a Spatial DataFrame

In [8]:
# Attach a point geometry column built from the pickup coordinates, then drop the
# raw latitude/longitude columns that are no longer needed.
taxi_df = stm.add_spatial_points(
    taxi_df,
    lat_column="pickup_latitude",
    lon_column="pickup_longitude",
    new_column_alias="geometry",
)
taxi_df = taxi_df.drop("pickup_latitude", "pickup_longitude")
taxi_df.show(5, truncate=False)

+-------------------+------------------------------------+
|pickup_datetime    |geometry                            |
+-------------------+------------------------------------+
|2010-10-01 01:13:21|POINT (-74.001624 40.750705)        |
|2010-10-01 00:48:05|POINT (-73.99296599999998 40.753019)|
|2010-10-01 09:07:57|POINT (-73.981216 40.767694)        |
|2010-10-03 12:51:53|POINT (-73.781824 40.644774)        |
|2010-10-01 02:24:12|POINT (-73.995343 40.719125)        |
+-------------------+------------------------------------+
only showing top 5 rows



## Convert into a Spatiotemporal Grid DataFrame

In [9]:
# Grid-cell polygons for a 12 x 12 partitioning of the geometry column; reused by
# the map layers further down.
polygons = stm.get_grid_cell_polygons(
    taxi_df,
    "geometry",
    partitions_x=12,
    partitions_y=12,
)

# Aggregate the points into the same 12 x 12 grid with 3600-second time steps.
taxi_df = stm.get_st_grid_dataframe(
    geo_df=taxi_df,
    geometry="geometry",
    partitions_x=12,
    partitions_y=12,
    col_date="pickup_datetime",
    step_duration_sec=3600,
)
taxi_df.show(5)

+-------+------------+------------------+
|cell_id|_id_timestep|aggregated_feature|
+-------+------------+------------------+
|     65|         172|                 1|
|     77|         516|                26|
|     78|         379|                 7|
|    102|         137|                13|
|     77|         589|                18|
+-------+------------+------------------+
only showing top 5 rows



[Stage 18:>                                                         (0 + 1) / 1]                                                                                

In [10]:
# Map layer of the aggregated feature per grid cell for time-step index 20.
layer = stm.getStGridLayer(
    taxi_df,
    timestamp_index=20,
    col_timestamp="_id_timestep",
    col_feature="aggregated_feature",
    col_id="cell_id",
    polygons=polygons,
)
layer.to_html()

## Define parameters

In [11]:
# Training hyper-parameters and grid dimensions used by the cells below.
learning_rate = 0.0002  # learning rate given to the Adam optimizer in train_model
batch_size = 32  # batch size of the DataLoaders in train_distributed / test_distributed
epoch_nums = 500  # number of epochs looped over in train_model
map_height = 12  # grid height passed to SpatiotemporalDfToTorchData
map_width = 12  # grid width passed to SpatiotemporalDfToTorchData

### Convert Sedona DataFrame into GeoTorchAI Dataset

In [12]:
# Temporal step count taken from the "_id_timestep" column of the grid DataFrame.
num_timsteps = stm.get_temporal_steps_count(taxi_df, temporal_steps_column="_id_timestep")
# Wrap the grid DataFrame as a torch-compatible dataset over a map_height x map_width grid.
full_dataset = SpatiotemporalDfToTorchData(taxi_df, "_id_timestep", "cell_id", ["aggregated_feature"], num_timsteps, map_height, map_width)
# NOTE(review): presumably this is what makes samples expose the "x_closeness",
# "x_period" and "x_trend" keys read by the training/evaluation code — confirm.
full_dataset.set_periodical_representation()
# Statistics used further down to convert normalised values back to the original scale.
min_max_difference, min_max_sum = full_dataset.get_min_max_info()

## Split into Train and Test

In [13]:
# Split by position without shuffling: the first 80% of the indices form the
# training subset and the remaining indices form the test subset.
train_ratio = 0.8
dataset_size = len(full_dataset)
train_size = int(train_ratio * dataset_size)

train_indices = range(train_size)
test_indices = range(train_size, dataset_size)

train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
test_dataset = torch.utils.data.Subset(full_dataset, test_indices)

### Method to Return Model

In [14]:
def get_model():
    """Build a new, untrained DeepSTN model.

    The grid size (12 x 12), the single feature channel and the dropout value
    are fixed inside this function, so it does not read any notebook-level
    variable.

    Returns:
        DeepSTN: model constructed with ``is_plus=False`` and ``is_pt=False``.
    """
    map_height, map_width, nb_flow = 12, 12, 1
    is_plus = False
    is_pt = False
    drop = 0.1

    # `nb_flow` supplies the channel count so the value is declared only once.
    model = DeepSTN(H=map_height, W=map_width, channel=nb_flow, is_plus=is_plus, is_pt=is_pt, dropVal=drop)
    return model

### Train the Model
Error will be high if the training is performed for only a few epochs (e.g., 20); the number of epochs is controlled by `epoch_nums`.

In [15]:
def train_one_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one optimisation pass over every batch of ``train_loader``.

    Args:
        model: module called as ``model(x_closeness, x_period, x_trend)``.
        train_loader: iterable of dict-like batches with the keys
            ``"x_closeness"``, ``"x_period"``, ``"x_trend"`` and ``"y_data"``.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable taking ``(outputs, targets)`` and returning the loss.
        device: device the batch tensors are moved to.

    Returns:
        float: loss of the last processed batch (not an average over the
        epoch), or ``nan`` when ``train_loader`` yields no batches.
    """
    model.train()
    loss = None  # stays None when the loader is empty
    for sample in train_loader:
        X_c = sample["x_closeness"].type(torch.FloatTensor).to(device)
        X_p = sample["x_period"].type(torch.FloatTensor).to(device)
        X_t = sample["x_trend"].type(torch.FloatTensor).to(device)
        Y_batch = sample["y_data"].type(torch.FloatTensor).to(device)

        with torch.set_grad_enabled(True):
            optimizer.zero_grad()

            # Forward pass
            outputs = model(X_c, X_p, X_t)
            loss = loss_fn(outputs, Y_batch)

            # Backward and optimize
            loss.backward()
            optimizer.step()

    # Without any batch there is no loss tensor to read; report NaN instead of failing.
    if loss is None:
        return float("nan")
    return loss.item()

In [None]:
def train_model(model, loader, device):
    """Train ``model`` on ``loader`` and print the loss after every epoch.

    Uses MSE loss and an Adam optimizer configured with the notebook-level
    ``learning_rate``; the number of epochs comes from the notebook-level
    ``epoch_nums``.

    Args:
        model: module to optimise.
        loader: iterable of training batches forwarded to ``train_one_epoch``.
        device: device forwarded to ``train_one_epoch``.
    """
    criterion = nn.MSELoss()
    adam = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch_number in range(1, epoch_nums + 1):
        latest_loss = train_one_epoch(model, loader, adam, criterion, device)
        print('Epoch [{}/{}], Training Loss: {:.4f}'.format(epoch_number, epoch_nums, latest_loss))

In [None]:
def train_distributed(use_gpu):
    """Per-process entry point that trains DeepSTN with DistributedDataParallel.

    Args:
        use_gpu: when truthy, the "nccl" backend is used and the device is the
            integer read from the ``LOCAL_RANK`` environment variable;
            otherwise the "gloo" backend is used on CPU.
    """
    backend = "nccl" if use_gpu else "gloo"
    dist.init_process_group(backend)
    try:
        device = int(os.environ["LOCAL_RANK"]) if use_gpu else "cpu"
        model = get_model().to(device)
        model_ddp = DDP(model)

        # The sampler yields positions within the dataset it was built from, so the
        # loader has to wrap that same dataset.
        sampler = DistributedSampler(train_dataset)
        loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

        train_model(model_ddp, loader, device)
    finally:
        # Release the process group even when training raises.
        dist.destroy_process_group()

## Start Distributed Training

In [None]:
# Run train_distributed in 2 local CPU processes; the trailing False is the
# use_gpu argument forwarded to train_distributed.
distributor = TorchDistributor(num_processes=2, local_mode=True, use_gpu=False)
distributor.run(train_distributed, False)

## Evaluate a Trained Model

In [17]:
# Path of the saved DeepSTN weights read by load_model.
# NOTE(review): no cell visible in this notebook writes this file — it presumably has to exist beforehand.
MODEL_PATH = "models/deepstn_nyc/deepstn_nyc.pth"

### Load a pretrained model

In [21]:
def load_model(model_path, device):
    """Create a model with ``get_model()`` and load saved weights into it.

    Args:
        model_path: path of the file holding the saved state dict.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The model with the loaded weights. It is not moved to ``device`` here;
        callers do that themselves.
    """
    net = get_model()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
    net.load_state_dict(torch.load(model_path, map_location=device))
    return net

In [16]:
def evaluate_model(model, test_loader, device):
    """Compute prediction errors of ``model`` over ``test_loader``.

    Args:
        model: module called as ``model(x_closeness, x_period, x_trend)``.
        test_loader: iterable of dict-like batches with the keys
            ``"x_closeness"``, ``"x_period"``, ``"x_trend"`` and ``"y_data"``.
        device: device the batch tensors are moved to.

    Returns:
        tuple: ``(mae, rmse)``, each the mean of the per-batch values returned
        by ``TorchAdapter.compute_prediction_errors``. Both are ``nan`` when
        ``test_loader`` yields no batches.
    """
    model.eval()

    rmse_list = []
    mae_list = []
    # Inference only: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for sample in test_loader:
            X_c = sample["x_closeness"].type(torch.FloatTensor).to(device)
            X_p = sample["x_period"].type(torch.FloatTensor).to(device)
            X_t = sample["x_trend"].type(torch.FloatTensor).to(device)
            Y_batch = sample["y_data"].type(torch.FloatTensor).to(device)

            # Forward pass
            outputs = model(X_c, X_p, X_t)
            _, mae, rmse = TorchAdapter.compute_prediction_errors(outputs.cpu().numpy(), Y_batch.cpu().numpy())

            rmse_list.append(rmse)
            mae_list.append(mae)

    # Nothing was evaluated: return NaN explicitly rather than averaging an empty list.
    if not rmse_list:
        return float("nan"), float("nan")

    rmse = np.mean(rmse_list)
    mae = np.mean(mae_list)
    return mae, rmse

In [None]:
def test_distributed(use_gpu):
    """Per-process entry point that evaluates the saved model on the test subset.

    Each process evaluates the share of ``test_dataset`` assigned to it by the
    DistributedSampler and prints its own normalised and rescaled MAE / RMSE.

    Args:
        use_gpu: when truthy, the "nccl" backend is used and the device is the
            integer read from the ``LOCAL_RANK`` environment variable;
            otherwise the "gloo" backend is used on CPU.
    """
    backend = "nccl" if use_gpu else "gloo"
    dist.init_process_group(backend)
    try:
        device = int(os.environ["LOCAL_RANK"]) if use_gpu else "cpu"
        # Evaluation needs no gradient synchronisation, so the model is used
        # directly rather than through a DistributedDataParallel wrapper.
        model = load_model(MODEL_PATH, device).to(device)

        # The sampler yields positions within test_dataset, so the loader must wrap
        # test_dataset as well; wrapping full_dataset would read the first
        # (training) samples instead.
        sampler = DistributedSampler(test_dataset)
        loader = DataLoader(test_dataset, batch_size=batch_size, sampler=sampler)

        mae, rmse = evaluate_model(model, loader, device)
        print('Test mae (norm): %.6f rmse (norm): %.6f, mae (real): %.6f, rmse (real): %.6f' % (mae, rmse, mae * min_max_difference / 2, rmse * min_max_difference / 2))
    finally:
        # Release the process group even when evaluation raises.
        dist.destroy_process_group()

### Start Distributed Testing

In [None]:
# Run test_distributed in 4 local CPU processes; the trailing False is the
# use_gpu argument forwarded to test_distributed.
distributor = TorchDistributor(num_processes=4, local_mode=True, use_gpu=False)
distributor.run(test_distributed, False)

## Output of a Selected Timestep

In [22]:
# Pick the device via TorchAdapter, load the saved weights onto it and switch the
# model to evaluation mode.
device = TorchAdapter.get_training_device()
model = load_model(MODEL_PATH, device).to(device)
model.eval()

DeepSTN(
  (conv1): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (conv2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (conv3): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (ptTrans): _PT_trans(
    (t_trans): _T_trans(
      (T_mid): Conv2d(31, 56, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (T_fin): Conv2d(56, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)
    )
    (multiply): _Multiply()
    (conv): Conv2d(9, 9, kernel_size=(1, 1), stride=(1, 1), padding=same)
  )
  (cpt1_0): _Conv_unit1(
    (batchNorm2d): BatchNorm2d(201, eps=0.001, momentum=0.99, affine=False, track_running_stats=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (conv): Conv2d(201, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)
  )
  (cpt0_0): _Conv_unit0(
    (batchNorm2d): BatchNorm2d(201, eps=0.001, momentum=0.99, affine=False, track_running_stats=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (conv): Conv2d(201, 64

### Select Timestep Index

In [24]:
# Position within test_dataset of the sample to visualise.
timestep_index = 1

In [29]:
# Fetch the selected test sample and move its tensors, cast to FloatTensor, to the device.
sample = test_dataset[timestep_index]
X_c = sample["x_closeness"].type(torch.FloatTensor).to(device)
X_p = sample["x_period"].type(torch.FloatTensor).to(device)
X_t = sample["x_trend"].type(torch.FloatTensor).to(device)
Y_batch = sample["y_data"].type(torch.FloatTensor).to(device)
# NOTE(review): `outputs` (the model prediction) is not used afterwards; the DataFrame
# below is built from the target `Y_batch`, so the map shows the target values rather
# than the prediction — confirm this is intended.
outputs = model(X_c, X_p, X_t)
# Rescale the first slice of the target with the dataset's min/max statistics and
# turn it into a per-cell DataFrame with "cell_ids" and "num_pickups" columns.
topPickupDf = stm.get_cells_df((Y_batch[0] * min_max_difference + min_max_sum) / 2, "cell_ids", "num_pickups")

In [30]:
# Build a map layer of the per-cell pickup counts over the grid polygons.
layer = stm.getGridLayer(topPickupDf, "num_pickups", "cell_ids", polygons)
layer.to_html()