In [18]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'
REGION = "us-central1"

# VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'
VERTEX_SA = 'jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")
print(f"REGION: {REGION}")
print(f"VERTEX_SA: {VERTEX_SA}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1
REGION: us-central1
VERTEX_SA: jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com


In [19]:
REGION = 'us-central1'
EXPERIMENT = 'control_group1'
SERIES = 'causal_impact_4'

BQ_PROJECT = PROJECT_ID
BQ_DATASET = SERIES.replace('-','_')
BQ_TABLE = EXPERIMENT

BQ_SOURCE1 = 'bigquery-public-data.new_york.citibike_trips'
BQ_SOURCE2 = 'bigquery-public-data.new_york.citibike_stations'

viz_limit = 12

In [20]:
from google.cloud import bigquery

import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime, timedelta

from google.cloud import aiplatform as vertex_ai

bq = bigquery.Client(project=PROJECT_ID)

vertex_ai.init(
    project=PROJECT_ID, 
    location=REGION,
    # credentials=credentials
)

In [21]:
# CUSTOMIZE
TARGET_COLUMN = 'num_trips'
TIME_COLUMN = 'starttime'
SERIES_COLUMN = 'start_station_name'
COVARIATE_COLUMNS = [
    'avg_tripduration', 
    'pct_subscriber', 
    'ratio_gender', 
    'capacity'
] # could be empty

# BQ_TABLE_GROUP_A="control_group1_grp_a"
BQ_TABLE_GROUP_B="control_group1_grp_b"

In [22]:
query = f"""
    WITH
        SPLIT AS (
            SELECT splits, min({TIME_COLUMN}) as mindate, max({TIME_COLUMN}) as maxdate
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
            GROUP BY splits
        ),
        TRAIN AS (
            SELECT mindate as start_date
            FROM SPLIT
            WHERE splits ='TRAIN'
        ),
        VAL AS (
            SELECT mindate as val_start
            FROM SPLIT
            WHERE splits = 'VALIDATE'
        ),
        TEST AS (
            SELECT mindate as test_start, maxdate as end_date
            FROM SPLIT
            WHERE splits = 'TEST'
        )
    SELECT * EXCEPT(pos) FROM
    (SELECT *, ROW_NUMBER() OVER() pos FROM TRAIN)
    JOIN (SELECT *, ROW_NUMBER() OVER() pos FROM VAL)
    USING (pos)
    JOIN (SELECT *, ROW_NUMBER() OVER() pos FROM TEST)
    USING (pos)
"""
keyDates = bq.query(query).to_dataframe()
keyDates

Unnamed: 0,start_date,val_start,test_start,end_date
0,2013-07-01,2016-05-14,2016-07-23,2016-09-30


In [23]:
query = f"""
    SELECT {SERIES_COLUMN}, {TIME_COLUMN}, {TARGET_COLUMN}, splits,
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
    ORDER by {SERIES_COLUMN}, {TIME_COLUMN}
"""
rawSeries = bq.query(query).to_dataframe()

In [24]:
rawSeries.tail(50)

Unnamed: 0,start_station_name,starttime,num_trips,splits,avg_tripduration,pct_subscriber,ratio_gender,capacity
4189,Marcy Ave & Lafayette Ave,2016-06-25,9,VALIDATE,952.222222,0.777778,0.8,23
4190,Marcy Ave & Lafayette Ave,2016-06-26,13,VALIDATE,1302.615385,0.692308,0.857143,23
4191,Marcy Ave & Lafayette Ave,2016-06-28,4,VALIDATE,779.75,1.0,0.0,23
4192,Marcy Ave & Lafayette Ave,2016-06-30,10,VALIDATE,1281.1,0.9,4.0,23
4193,Marcy Ave & Lafayette Ave,2016-07-02,17,VALIDATE,1179.647059,0.764706,0.545455,23
4194,Marcy Ave & Lafayette Ave,2016-07-03,4,VALIDATE,1033.25,0.75,0.333333,23
4195,Marcy Ave & Lafayette Ave,2016-07-04,6,VALIDATE,908.833333,0.833333,2.0,23
4196,Marcy Ave & Lafayette Ave,2016-07-05,6,VALIDATE,1491.333333,1.0,2.0,23
4197,Marcy Ave & Lafayette Ave,2016-07-06,7,VALIDATE,1033.571429,0.714286,1.333333,23
4198,Marcy Ave & Lafayette Ave,2016-07-07,7,VALIDATE,1180.0,1.0,0.166667,23


## Train MLR - Group B

In [25]:
# CUSTOMIZE
forecast_granularity = 'DAY'
forecast_horizon = 7 #14
forecast_test_length = 14
#forecast_val_length = 14

In [9]:
query = f"""
    CREATE OR REPLACE MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['{TARGET_COLUMN}']
      ) AS
    SELECT {TIME_COLUMN}, {TARGET_COLUMN},
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
    WHERE splits in ('TRAIN','VALIDATE')
"""
print(query)


    CREATE OR REPLACE MODEL `hybrid-vertex.causal_impact_3.control_group1_grp_b_mlr`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['num_trips']
      ) AS
    SELECT starttime, num_trips,
        avg_tripduration, pct_subscriber, ratio_gender
    FROM `hybrid-vertex.causal_impact_3.control_group1_grp_b`
    WHERE splits in ('TRAIN','VALIDATE')



In [10]:
job = bq.query(query)
job.result()
print(job.state, (job.ended-job.started).total_seconds())

DONE 15.811


### Review Input Features

In [11]:
query = f"""
    SELECT *
    FROM ML.FEATURE_INFO(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`)
"""
featureInfo = bq.query(query).to_dataframe()
featureInfo.head()

Unnamed: 0,input,min,max,mean,median,stddev,category_count,null_count,dimension
0,starttime,,,,,,1114.0,0,
1,avg_tripduration,123.5,44411.216216,916.334062,774.007092,977.067231,,0,
2,pct_subscriber,0.0,1.0,0.872957,0.913043,0.133236,,0,
3,ratio_gender,0.0,44.0,2.810144,2.333333,2.26681,,166,
4,capacity,19.0,114.0,56.312325,55.0,23.09984,,3567,


In [12]:
query = f"""
    SELECT *
    FROM ML.TRAINING_INFO(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`)
"""
trainingInfo = bq.query(query).to_dataframe()
trainingInfo.head()

Unnamed: 0,training_run,iteration,loss,eval_loss,learning_rate,duration_ms
0,0,0,3197.432983,3873.933853,,2831


## Forecast Evaluation

In [13]:
query = f"""
    SELECT *
    FROM ML.EVALUATE(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`,
        (
            SELECT {TIME_COLUMN}, {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
            WHERE splits = 'TEST'
        )
    )
"""
metrics = bq.query(query).to_dataframe()
metrics

Unnamed: 0,mean_absolute_error,mean_squared_error,mean_squared_log_error,median_absolute_error,r2_score,explained_variance
0,10660.613641,113657800.0,22.102688,10685.116292,-12154.285054,0.020232


## Forecast Test Set

In [14]:
query = f"""
    SELECT *
    FROM ML.PREDICT(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`,
        (
            SELECT 
                {TIME_COLUMN}, 
                {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)},
                {SERIES_COLUMN}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
            WHERE splits = 'TEST'
        )
        )
"""
forecast = bq.query(query).to_dataframe()
forecast
# print(query)

Unnamed: 0,predicted_num_trips,starttime,num_trips,avg_tripduration,pct_subscriber,ratio_gender,start_station_name
0,10790.671655,2016-09-25,256,958.222656,0.726562,1.370370,Broadway & W 24 St
1,10797.808615,2016-08-26,256,797.410156,0.839844,2.047619,Broadway & W 49 St
2,10786.673324,2016-09-25,258,1313.139535,0.666667,0.720000,Broadway & W 60 St
3,10784.396166,2016-08-21,258,1139.682171,0.627907,0.804196,Broadway & W 60 St
4,10786.493156,2016-08-06,514,1087.603113,0.659533,0.736486,E 17 St & Broadway
...,...,...,...,...,...,...,...
1308,10804.341365,2016-09-15,508,731.131890,0.944882,2.628571,Broadway & E 22 St
1309,10800.996973,2016-09-26,253,769.075099,0.897233,3.362069,Broadway & W 39 St
1310,10804.186539,2016-09-23,254,753.574803,0.940945,2.298701,Broadway & W 32 St
1311,10786.155250,2016-08-07,255,853.988235,0.650980,0.961538,Broadway & W 24 St


In [15]:
# CUSTOMIZE
query = f"""
CREATE OR REPLACE TABLE `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_pred_Test` AS (
    SELECT * FROM ML.PREDICT(
            MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}_mlr`,
            (
                SELECT
                {TIME_COLUMN}, 
                {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)},
                {SERIES_COLUMN}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_B}`
                WHERE splits = 'TEST'
            )
            )
)
"""
job = bq.query(query = query)
job.result()
(job.ended-job.started).total_seconds()

1.609

In [49]:
BQ_DATASET

'causal_impact_3'