<img src="images/logo_city.png" align="right" width="20%">

# Training A Graph State Prediction Network (GSPNet): Simplified version

This notebook is a prototype showing how to train a graph state prediction network using a CNN and an LSTM.
The content is divided into __3__ parts:

1. Data preprocessing
2. Model building, Training and Tuning
3. Prediction and on-demand modification

The model is built with [PyTorch](https://pytorch.org/).

<p style='color: darkred'><strong>This version is simplified: most of the intermediate explanatory code and comments have been removed, except for those that directly contribute to generating images. For the full process, see the prototype version.</strong></p>

## Part 1: Data Preprocessing

### 1. Load Libraries and Data

Load the data with Dask and check its integrity.

In [1]:
```python
import pandas as pd
import numpy as np
import seaborn as sns
import sqlalchemy
import dask
import dask.dataframe as dd
import time
import psycopg2
import warnings
import re
import torch
from PIL import Image
from IPython.display import display
from scipy import stats
from matplotlib import pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
```

Load data:

In [6]:
```python
table = dd.read_csv('dataset/nytaxi_yellow_2017_mar.csv')
table.head()
```

|   | tripid | tpep_pickup_datetime | tpep_dropoff_datetime | pulocationid | dolocationid | trip_distance | passenger_count | total_amount | trip_time | trip_avg_speed | trip_time_sec |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 35972297 | 2017-03-09 21:30:11 | 2017-03-09 21:44:20 | 148 | 48 | 4.06 | 1 | 18.36 | 00:14:09 | 17.215548 | 849 |
| 1 | 35972298 | 2017-03-09 21:47:00 | 2017-03-09 21:58:01 | 48 | 107 | 2.73 | 1 | 12.8 | 00:11:01 | 14.868381 | 661 |
| 2 | 35972299 | 2017-03-09 22:01:08 | 2017-03-09 22:11:16 | 79 | 162 | 2.27 | 1 | 14.12 | 00:10:08 | 13.440789 | 608 |
| 3 | 35972300 | 2017-03-09 22:16:05 | 2017-03-10 06:26:11 | 237 | 41 | 3.86 | 1 | 17.29 | 08:10:06 | 0.472557 | 29406 |
| 4 | 35972301 | 2017-03-31 06:31:53 | 2017-03-31 06:41:48 | 41 | 162 | 3.45 | 1 | 13.3 | 00:09:55 | 20.87395 | 595 |


Check table shape:

In [3]:
```python
table.shape[0].compute(), table.shape[1]
```

```
(10222693, 11)
```

Check data:

Ratio of trips lasting longer than 30 minutes (ideally below 0.1):

In [5]:
```python
(table.loc[table['trip_time_sec'] > 1800].shape[0] / table.shape[0]).compute()
```

```
0.08854594381343546
```
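The same ratio can be written more compactly as the mean of a boolean mask, since the mean of `True`/`False` values is the fraction of `True` rows. A minimal sketch on a small pandas frame holding the five `trip_time_sec` values from the `head()` output; on the Dask frame the equivalent expression should be `(table['trip_time_sec'] > 1800).mean().compute()`, though that variant is not re-run here.

```python
import pandas as pd

# The five trip durations (seconds) shown in the head() output.
trips = pd.DataFrame({"trip_time_sec": [849, 661, 608, 29406, 595]})

# Mean of a boolean mask = fraction of rows where the condition holds.
long_trip_ratio = (trips["trip_time_sec"] > 1800).mean()
print(long_trip_ratio)  # 0.2
```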

Select only the needed columns:

In [14]:
```python
# Keep the four columns needed to build images, plus the `tripid` column.
# The `tripid` column is kept for the sake of naming.
tensor_gen = table.loc[:, ['tripid',
                           'tpep_pickup_datetime',
                           'tpep_dropoff_datetime',
                           'pulocationid', 'dolocationid']]
tensor_gen.head()
```

|   | tripid | tpep_pickup_datetime | tpep_dropoff_datetime | pulocationid | dolocationid |
|---|---|---|---|---|---|
| 0 | 35972297 | 2017-03-09 21:30:11 | 2017-03-09 21:44:20 | 148 | 48 |
| 1 | 35972298 | 2017-03-09 21:47:00 | 2017-03-09 21:58:01 | 48 | 107 |
| 2 | 35972299 | 2017-03-09 22:01:08 | 2017-03-09 22:11:16 | 79 | 162 |
| 3 | 35972300 | 2017-03-09 22:16:05 | 2017-03-10 06:26:11 | 237 | 41 |
| 4 | 35972301 | 2017-03-31 06:31:53 | 2017-03-31 06:41:48 | 41 | 162 |


### 2. Time Interval Preprocessing

Create time intervals and process the original table to generate images.

Convert the datetime columns from strings to Timestamps:

In [15]:
```python
# Convert the datetime columns here.
# Very expensive operation: do it only once.
tensor_gen['tpep_pickup_datetime'] = dd.to_datetime(tensor_gen['tpep_pickup_datetime'])
tensor_gen['tpep_dropoff_datetime'] = dd.to_datetime(tensor_gen['tpep_dropoff_datetime'])

tensor_gen.dtypes
```

```
tripid                            int64
tpep_pickup_datetime     datetime64[ns]
tpep_dropoff_datetime    datetime64[ns]
pulocationid                      int64
dolocationid                      int64
dtype: object
```
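Part of the cost of this conversion is that the parser has to infer the timestamp layout. A minimal sketch, assuming every value follows the `YYYY-MM-DD HH:MM:SS` layout seen in the `head()` output: passing an explicit `format` to `pd.to_datetime` skips that inference. `dd.to_datetime` passes extra keyword arguments through to pandas, so the same `format=` should work there as well; the speed-up has not been measured on this dataset.

```python
import pandas as pd

# Raw strings in the same layout as the pickup/dropoff columns.
raw = pd.Series(["2017-03-09 21:30:11", "2017-03-09 21:44:20", "2017-03-10 06:26:11"])

# An explicit format avoids per-value layout inference.
parsed = pd.to_datetime(raw, format="%Y-%m-%d %H:%M:%S")
print(parsed.dtype)
```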

Create the time **intervals** (10-minute steps spanning March 2017):

In [16]:
```python
intervals = pd.date_range('2017-03-01 00:00:00', '2017-04-01 00:00:00', freq='10min')
print(len(intervals))
intervals # <-- to be used later
```

```
4465
```

```
DatetimeIndex(['2017-03-01 00:00:00', '2017-03-01 00:10:00',
               '2017-03-01 00:20:00', '2017-03-01 00:30:00',
               '2017-03-01 00:40:00', '2017-03-01 00:50:00',
               '2017-03-01 01:00:00', '2017-03-01 01:10:00',
               '2017-03-01 01:20:00', '2017-03-01 01:30:00',
               ...
               '2017-03-31 22:30:00', '2017-03-31 22:40:00',
               '2017-03-31 22:50:00', '2017-03-31 23:00:00',
               '2017-03-31 23:10:00', '2017-03-31 23:20:00',
               '2017-03-31 23:30:00', '2017-03-31 23:40:00',
               '2017-03-31 23:50:00', '2017-04-01 00:00:00'],
              dtype='datetime64[ns]', length=4465, freq='10T')
```
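Each trip then has to be placed into one of these 10-minute bins. A minimal sketch, assuming a trip is assigned to the bin in which it was picked up (the notebook does not state its rule at this point): `Series.dt.floor('10min')` rounds a timestamp down to its bin start, and `DatetimeIndex.get_indexer` returns that start's position in `intervals`, or `-1` when there is no match (e.g. a pickup outside March 2017).

```python
import pandas as pd

# Same 10-minute bins as in the notebook.
intervals = pd.date_range("2017-03-01 00:00:00", "2017-04-01 00:00:00", freq="10min")

# Example pickup times; cast to the bins' dtype so the lookup compares like with like.
pickups = pd.Series(pd.to_datetime([
    "2017-03-09 21:30:11",
    "2017-03-09 21:47:00",
    "2017-02-28 23:59:59",  # before March 2017
])).astype(intervals.dtype)

# Round each pickup down to the start of its 10-minute bin.
bin_start = pickups.dt.floor("10min")

# Position of each bin start in `intervals`; -1 means no matching bin.
bin_index = intervals.get_indexer(bin_start)
print(bin_index.tolist())  # [1281, 1282, -1]
```

Note that 4465 timestamps are bin *edges*, so they delimit 4464 ten-minute bins; the final edge (`2017-04-01 00:00:00`) only closes the last bin.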

### 3. Location Preprocessing

Preprocess location-related information: map location IDs to locations, then generate adjacency matrices.
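The location step is only outlined here. A minimal sketch, assuming an adjacency matrix counts trips from each pickup zone to each dropoff zone (an assumption; the notebook does not define the matrices in this section): each distinct location ID is mapped to a consecutive row/column index, then `np.add.at` adds one count per trip. `build_od_matrix` is a hypothetical helper name, and the IDs are the `pulocationid`/`dolocationid` values from the `head()` output.

```python
import numpy as np


def build_od_matrix(pu_ids, do_ids, id_to_idx):
    """Hypothetical helper: count trips from each pickup zone to each dropoff zone."""
    n = len(id_to_idx)
    od = np.zeros((n, n), dtype=np.int64)
    rows = np.array([id_to_idx[i] for i in pu_ids])
    cols = np.array([id_to_idx[i] for i in do_ids])
    # np.add.at accumulates repeated (row, col) pairs; plain `od[rows, cols] += 1` would not.
    np.add.at(od, (rows, cols), 1)
    return od


# Pickup / dropoff location IDs from the head() output.
pu = [148, 48, 79, 237, 41]
do = [48, 107, 162, 41, 162]

# Map each distinct location ID to a consecutive index, in sorted order.
location_ids = sorted(set(pu) | set(do))
id_to_idx = {loc: idx for idx, loc in enumerate(location_ids)}

od = build_od_matrix(pu, do, id_to_idx)
print(location_ids)  # [41, 48, 79, 107, 148, 162, 237]
print(od.sum())      # 5
```

In practice the mapping would more likely be built from the full list of zones rather than only the IDs observed in a sample, so that every matrix shares the same shape and ordering.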

### 4. Generate Tensors: 3 Connection Matrices (Layers)

Generate images that represent traffic states.
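A minimal sketch of how the three layers could be combined, assuming each layer is an N×N matrix over the same zone ordering: `np.stack` along a new leading axis gives one `(3, N, N)` array per interval, a channel-first layout like a 3-channel image (the layout PyTorch's `Conv2d` expects per sample), and stacking those in time order gives `(T, 3, N, N)`. What each layer contains is not specified here, so random integers stand in as placeholders; `interval_image` is a hypothetical name, and `torch.from_numpy` could convert the result for PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)
n_zones = 4       # placeholder; the real size comes from the location mapping
n_intervals = 6   # placeholder; 4465 edges in the notebook give 4464 bins


def interval_image(n):
    """Hypothetical: three placeholder N x N layers for one time interval."""
    layers = [rng.integers(0, 5, size=(n, n)) for _ in range(3)]
    # Channel-first, like a 3-channel image: (3, N, N).
    return np.stack(layers, axis=0)


# One image per interval, stacked in time order: (T, 3, N, N).
sequence = np.stack([interval_image(n_zones) for _ in range(n_intervals)], axis=0)
print(sequence.shape)  # (6, 3, 4, 4)
```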