In [1]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'
REGION = "us-central1"

# VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'
VERTEX_SA = 'jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")
print(f"REGION: {REGION}")
print(f"VERTEX_SA: {VERTEX_SA}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1
REGION: us-central1
VERTEX_SA: jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com


In [2]:
# Experiment identifiers and region.
SERIES = "causal_impact_2"
EXPERIMENT = "control_group1"
REGION = "us-central1"

# BigQuery names derived from the identifiers above.
BQ_PROJECT = PROJECT_ID
BQ_DATASET = SERIES.replace("-", "_")  # hyphens in SERIES become underscores
BQ_TABLE = EXPERIMENT

# Public Citi Bike source tables.
BQ_SOURCE1 = "bigquery-public-data.new_york.citibike_trips"
BQ_SOURCE2 = "bigquery-public-data.new_york.citibike_stations"

# NOTE(review): not referenced in the cells visible here; presumably caps how
# many series are plotted — confirm against the plotting cells.
viz_limit = 12

In [3]:
from google.cloud import bigquery

import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime, timedelta

from google.cloud import aiplatform as vertex_ai

# BigQuery client bound to the notebook's project; used by every query cell.
bq = bigquery.Client(project=PROJECT_ID)

# Initialise the Vertex AI SDK for the same project and region.
# No explicit credentials object is passed.
vertex_ai.init(project=PROJECT_ID, location=REGION)

In [5]:
# CUSTOMIZE: column roles used when building the SQL in the cells below.
SERIES_COLUMN = "start_station_name"  # selected alongside predictions / raw series
TIME_COLUMN = "starttime"             # used for split date boundaries and ordering
TARGET_COLUMN = "num_trips"           # label column of the regression model
COVARIATE_COLUMNS = [                 # could be empty
    "avg_tripduration",
    "pct_subscriber",
    "ratio_gender",
    "capacity",
]

# Table the rest of the notebook reads from.
# BQ_TABLE_GROUP_A="control_group1_grp_a"
BQ_TABLE_GROUP_B = "control_group1_grp_b"

In [6]:
# Key dates of the data splits, returned as a single row:
#   start_date  = earliest TIME_COLUMN value in the TRAIN split
#   val_start   = earliest TIME_COLUMN value in the VALIDATE split
#   test_start  = earliest TIME_COLUMN value in the TEST split
#   end_date    = latest TIME_COLUMN value in the TEST split
# Each CTE yields one row per matching split; the rows are stitched together
# side by side by joining on a ROW_NUMBER() position.
query = f"""
    WITH
        SPLIT AS (
            SELECT splits, min({TIME_COLUMN}) as mindate, max({TIME_COLUMN}) as maxdate
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
            GROUP BY splits
        ),
        TRAIN AS (
            SELECT mindate as start_date
            FROM SPLIT
            WHERE splits ='TRAIN'
        ),
        VAL AS (
            SELECT mindate as val_start
            FROM SPLIT
            WHERE splits = 'VALIDATE'
        ),
        TEST AS (
            SELECT mindate as test_start, maxdate as end_date
            FROM SPLIT
            WHERE splits = 'TEST'
        )
    SELECT * EXCEPT(pos) FROM
    (SELECT *, ROW_NUMBER() OVER() pos FROM TRAIN)
    JOIN (SELECT *, ROW_NUMBER() OVER() pos FROM VAL)
    USING (pos)
    JOIN (SELECT *, ROW_NUMBER() OVER() pos FROM TEST)
    USING (pos)
"""
key_dates_job = bq.query(query)
keyDates = key_dates_job.to_dataframe()
keyDates

Unnamed: 0,start_date,val_start,test_start,end_date
0,2013-07-01,2016-09-03,2016-09-17,2016-09-30


In [7]:
# Pull the full group table (series id, time, target, split label and
# covariates) into pandas, sorted by series and then by time.
query = f"""
    SELECT {SERIES_COLUMN}, {TIME_COLUMN}, {TARGET_COLUMN}, splits,
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
    ORDER by {SERIES_COLUMN}, {TIME_COLUMN}
"""
raw_series_job = bq.query(query)
rawSeries = raw_series_job.to_dataframe()

In [17]:
# Preview the last rows of the ordered series frame.
PREVIEW_ROWS = 50
rawSeries.tail(PREVIEW_ROWS)

Unnamed: 0,start_station_name,starttime,num_trips,splits,avg_tripduration,pct_subscriber,ratio_gender,capacity
16949,Washington Pl & Broadway,2016-06-18,89,TRAIN,686.247191,0.719101,1.069767,27
16950,Washington Pl & Broadway,2016-06-19,72,TRAIN,826.708333,0.847222,1.666667,27
16951,Washington Pl & Broadway,2016-06-20,141,TRAIN,722.680851,0.971631,2.27907,27
16952,Washington Pl & Broadway,2016-06-21,150,TRAIN,646.6,0.973333,2.409091,27
16953,Washington Pl & Broadway,2016-06-23,127,TRAIN,654.346457,0.96063,2.02381,27
16954,Washington Pl & Broadway,2016-06-24,136,TRAIN,888.455882,0.941176,2.090909,27
16955,Washington Pl & Broadway,2016-06-27,121,TRAIN,710.710744,0.966942,3.321429,27
16956,Washington Pl & Broadway,2016-06-29,136,TRAIN,676.889706,0.941176,3.533333,27
16957,Washington Pl & Broadway,2016-06-30,140,TRAIN,706.1,0.921429,1.916667,27
16958,Washington Pl & Broadway,2016-07-03,77,TRAIN,793.675325,0.571429,0.711111,27


## Train MLR - Group B

In [8]:
# CUSTOMIZE: forecasting settings.
# NOTE(review): none of these are referenced in the cells visible here;
# the lengths are presumably expressed in units of forecast_granularity — confirm.
forecast_granularity = "DAY"
forecast_test_length = 14
forecast_horizon = 14
#forecast_val_length = 14

In [9]:
# Build the BigQuery ML statement that (re)creates a linear regression model
# named `<group table>_mlr`, with TARGET_COLUMN as the label and TIME_COLUMN
# plus the covariates as inputs. Only rows in the TRAIN and VALIDATE splits
# are used for training.
#
# The statement is only assembled and printed in this cell; it is left in the
# global `query` for a later cell to execute.
#
# NOTE(review): TIME_COLUMN is passed as an input feature. The ML.FEATURE_INFO
# output recorded in this notebook reports `starttime` with a category_count
# and no min/max, which suggests it is handled as a categorical value — TEST
# dates would then be unseen categories. Confirm this is intended.
query = f"""
    CREATE OR REPLACE MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['{TARGET_COLUMN}']
      ) AS
    SELECT {TIME_COLUMN}, {TARGET_COLUMN},
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
    WHERE splits in ('TRAIN','VALIDATE')
"""
# Show the rendered SQL so the substituted names can be checked before running.
print(query)


    CREATE OR REPLACE MODEL `hybrid-vertex.causal_impact_2.control_group1_grp_b_mlr`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['num_trips']
      ) AS
    SELECT starttime, num_trips,
        avg_tripduration, pct_subscriber, ratio_gender, capacity
    FROM `hybrid-vertex.causal_impact_2.control_group1_grp_b`
    WHERE splits in ('TRAIN','VALIDATE')



In [10]:
# Submit the statement currently held in `query` and wait for it to finish.
job = bq.query(query)
job.result()

# Report the final job state and its wall-clock duration in seconds.
elapsed_seconds = (job.ended - job.started).total_seconds()
print(job.state, elapsed_seconds)

DONE 15.758


### Review Input Features

In [11]:
# Fetch the per-feature statistics BigQuery ML recorded for the trained model.
query = f"""
    SELECT *
    FROM ML.FEATURE_INFO(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`)
"""
feature_info_job = bq.query(query)
featureInfo = feature_info_job.to_dataframe()
featureInfo.head()

Unnamed: 0,input,min,max,mean,median,stddev,category_count,null_count,dimension
0,starttime,,,,,,1170.0,0,
1,avg_tripduration,79.0,52736.768116,924.828415,778.598425,1052.170266,,0,
2,pct_subscriber,0.0,1.0,0.870197,0.907692,0.132956,,0,
3,ratio_gender,0.0,36.0,2.748721,2.285714,2.190748,,175,
4,capacity,19.0,114.0,55.856575,55.0,23.271001,,3828,


In [12]:
# Fetch the training-run log (loss, eval loss, duration per iteration) of the model.
query = f"""
    SELECT *
    FROM ML.TRAINING_INFO(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`)
"""
training_info_job = bq.query(query)
trainingInfo = training_info_job.to_dataframe()
trainingInfo.head()

Unnamed: 0,training_run,iteration,loss,eval_loss,learning_rate,duration_ms
0,0,0,3483.702781,3808.823979,,2932


## Forecast Evaluation

In [13]:
# Score the model on the held-out TEST split, feeding it the same columns
# that were selected for training.
query = f"""
    SELECT *
    FROM ML.EVALUATE(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`,
        (
            SELECT {TIME_COLUMN}, {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
            WHERE splits = 'TEST'
        )
    )
"""
evaluate_job = bq.query(query)
metrics = evaluate_job.to_dataframe()
metrics

Unnamed: 0,mean_absolute_error,mean_squared_error,mean_squared_log_error,median_absolute_error,r2_score,explained_variance
0,137.259329,27801.765211,5.318435,113.101263,-1.638028,0.145994


## Forecast Test Set

In [18]:
# Predict the target for every TEST-split row. The series column is selected
# too, so each prediction row carries its station name.
query = f"""
    SELECT *
    FROM ML.PREDICT(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`,
        (
            SELECT 
                {TIME_COLUMN}, 
                {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)},
                {SERIES_COLUMN}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
            WHERE splits = 'TEST'
        )
        )
"""
predict_job = bq.query(query)
forecast = predict_job.to_dataframe()
forecast

Unnamed: 0,predicted_num_trips,starttime,num_trips,avg_tripduration,pct_subscriber,ratio_gender,capacity,start_station_name
0,13.624835,2016-09-20,258,954.251938,0.844961,1.898876,,Broadway & W 60 St
1,21.487411,2016-09-22,4,920.500000,1.000000,3.000000,55,Broadway & W 41 St
2,-14.835330,2016-09-30,5,390.600000,1.000000,4.000000,29,Lorimer St & Broadway
3,15.083044,2016-09-24,261,841.567050,0.858238,1.610000,,Broadway & E 22 St
4,-1.352712,2016-09-28,262,657.694656,0.931298,2.638889,41,Broadway & W 29 St
...,...,...,...,...,...,...,...,...
262,45.443082,2016-09-21,248,954.931452,0.866935,2.263158,79,Broadway & W 58 St
263,18.990593,2016-09-27,251,737.872510,0.944223,2.861538,,Broadway & W 32 St
264,17.699815,2016-09-20,251,966.007968,0.936255,3.254237,,Broadway & W 55 St
265,14.978212,2016-09-18,252,796.444444,0.853175,1.470588,,Broadway & E 22 St


In [19]:
# CUSTOMIZE
# Materialise the TEST-split predictions into `<group table>_pred_Test`,
# replacing the table if it already exists.
query = f"""
CREATE OR REPLACE TABLE `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_pred_Test` AS (
    SELECT * FROM ML.PREDICT(
            MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`,
            (
                SELECT
                {TIME_COLUMN}, 
                {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)},
                {SERIES_COLUMN}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
                WHERE splits = 'TEST'
            )
            )
)
"""
job = bq.query(query)
job.result()  # block until the table has been written

# Display the job's wall-clock duration in seconds.
job_duration = job.ended - job.started
job_duration.total_seconds()

1.636

In [None]:
# Read back the TEST-split predictions table written by the previous cell.
# GROUP_B_PREDS_BQ_URI was referenced without ever being defined in this
# notebook, so it is built here from the same project/dataset/table naming
# used by the CREATE OR REPLACE TABLE statement.
GROUP_B_PREDS_BQ_URI = f"{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_pred_Test"

# Order by the configured time column (newest first) rather than a
# hard-coded column name, consistent with the other queries.
query = f"""
    SELECT *
    FROM `{GROUP_B_PREDS_BQ_URI}`
    ORDER BY {TIME_COLUMN} DESC;
"""
groupb_test_preds = bq.query(query).to_dataframe()

print(f"Shape: {groupb_test_preds.shape}")
groupb_test_preds.head(10)