In [1]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'
REGION = "us-central1"

# VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'
VERTEX_SA = 'jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")
print(f"REGION: {REGION}")
print(f"VERTEX_SA: {VERTEX_SA}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1
REGION: us-central1
VERTEX_SA: jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com


In [2]:
# Experiment identifiers; they also name the BigQuery dataset and table below.
REGION = "us-central1"
SERIES = "causal_impact_2"
EXPERIMENT = "control_group1"

# BigQuery location of the prepared data. Hyphens in the series name are
# swapped for underscores to form the dataset name.
BQ_PROJECT = PROJECT_ID
BQ_DATASET = "_".join(SERIES.split("-"))
BQ_TABLE = EXPERIMENT

# Public Citi Bike tables (trips and stations).
BQ_SOURCE1 = "bigquery-public-data.new_york.citibike_trips"
BQ_SOURCE2 = "bigquery-public-data.new_york.citibike_stations"

# NOTE(review): presumably caps how many series are visualized — confirm at the use site.
viz_limit = 12

In [3]:
from google.cloud import bigquery

import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime, timedelta

from google.cloud import aiplatform as vertex_ai

In [4]:
# BigQuery client bound to the notebook's project.
bq = bigquery.Client(project=PROJECT_ID)

# Initialize the Vertex AI SDK for the same project and region.
# No explicit credentials object is passed.
vertex_ai.init(project=PROJECT_ID, location=REGION)

In [5]:
# CUSTOMIZE
# Column roles used when building the SQL statements in later cells.
SERIES_COLUMN = "start_station_name"
TIME_COLUMN = "starttime"
TARGET_COLUMN = "num_trips"
COVARIATE_COLUMNS = [  # could be empty
    "avg_tripduration",
    "pct_subscriber",
    "ratio_gender",
    "capacity",
]

# Table holding the Group A data.
BQ_TABLE_GROUP_A = "control_group1_grp_a"
# BQ_TABLE_GROUP_B = "control_group1_grp_b"

In [6]:
# Boundary dates of the TRAIN / VALIDATE / TEST splits in the Group A table.
# Returns one row with the columns:
#   start_date  - earliest TRAIN date
#   val_start   - earliest VALIDATE date
#   test_start  - earliest TEST date
#   end_date    - latest TEST date
# Conditional aggregation computes all four in a single pass over the table,
# replacing three single-row CTEs stitched together with a ROW_NUMBER() join.
# A split that has no rows yields NULL in its column instead of an empty result.
query = f"""
    SELECT
        MIN(IF(splits = 'TRAIN', {TIME_COLUMN}, NULL)) AS start_date,
        MIN(IF(splits = 'VALIDATE', {TIME_COLUMN}, NULL)) AS val_start,
        MIN(IF(splits = 'TEST', {TIME_COLUMN}, NULL)) AS test_start,
        MAX(IF(splits = 'TEST', {TIME_COLUMN}, NULL)) AS end_date
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
"""
keyDates = bq.query(query).to_dataframe()
keyDates

Unnamed: 0,start_date,val_start,test_start,end_date
0,2013-07-01,2016-09-03,2016-09-17,2016-09-30


In [7]:
# Load the Group A rows (series, time, target, split label and covariates),
# ordered by series and then by time, into a DataFrame.
query = f"""
    SELECT {SERIES_COLUMN}, {TIME_COLUMN}, {TARGET_COLUMN}, splits,
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
    ORDER by {SERIES_COLUMN}, {TIME_COLUMN}
"""
raw_series_job = bq.query(query)
rawSeries = raw_series_job.to_dataframe()

## Train Multiple Linear Regression (MLR) - Group A

In [8]:
# CUSTOMIZE
# Forecast settings.
# NOTE(review): the two lengths are presumably counted in units of
# forecast_granularity — confirm where they are consumed.
forecast_granularity = "DAY"
forecast_horizon = forecast_test_length = 14
# forecast_val_length = 14

In [9]:
# Build the BigQuery ML statement that (re)creates the `<table>_mlr` linear
# regression model, with TARGET_COLUMN as the label, from the TRAIN and
# VALIDATE rows of the Group A table.
# NOTE(review): SERIES_COLUMN is not selected, so one model is fit across all
# series, and TIME_COLUMN is passed in as an ordinary input feature — confirm
# both are intended.
# NOTE(review): no data split option is set here, so the explicit `splits`
# column is only used to filter rows — confirm that is intended.
query = f"""
    CREATE OR REPLACE MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}_mlr`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['{TARGET_COLUMN}']
      ) AS
    SELECT {TIME_COLUMN}, {TARGET_COLUMN},
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
    WHERE splits in ('TRAIN','VALIDATE')
"""
# Show the rendered SQL; this cell only builds the statement, it does not run it.
print(query)


    CREATE OR REPLACE MODEL `hybrid-vertex.causal_impact_2.control_group1_grp_a_mlr`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['num_trips']
      ) AS
    SELECT starttime, num_trips,
        avg_tripduration, pct_subscriber, ratio_gender, capacity
    FROM `hybrid-vertex.causal_impact_2.control_group1_grp_a`
    WHERE splits in ('TRAIN','VALIDATE')



In [10]:
# Submit the statement currently held in `query` and wait for it to finish.
job = bq.query(query)
job.result()

# Report the final job state and the elapsed time in seconds.
elapsed_seconds = (job.ended - job.started).total_seconds()
print(job.state, elapsed_seconds)

DONE 17.757


### Review Input Features

In [11]:
# Fetch the per-feature information BigQuery ML reports for the Group A
# `_mlr` model and preview the first rows.
query = f"""
    SELECT *
    FROM ML.FEATURE_INFO(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}_mlr`)
"""
feature_info_job = bq.query(query)
featureInfo = feature_info_job.to_dataframe()
featureInfo.head()

Unnamed: 0,input,min,max,mean,median,stddev,category_count,null_count,dimension
0,starttime,,,,,,1170.0,0,
1,avg_tripduration,123.5,465018.2,945.78661,776.90625,4114.877052,,0,
2,pct_subscriber,0.0,1.0,0.870525,0.91,0.133274,,0,
3,ratio_gender,0.0,44.0,2.734149,2.305085,2.141457,,174,
4,capacity,19.0,114.0,55.700176,49.0,23.141132,,3829,


In [12]:
# Fetch the training-run information BigQuery ML reports for the Group A
# `_mlr` model and preview the first rows.
query = f"""
    SELECT *
    FROM ML.TRAINING_INFO(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}_mlr`)
"""
training_info_job = bq.query(query)
trainingInfo = training_info_job.to_dataframe()
trainingInfo.head()

Unnamed: 0,training_run,iteration,loss,eval_loss,learning_rate,duration_ms
0,0,0,3326.261688,3854.984665,,2531


## Forecast Evaluation

In [13]:
# Evaluate the Group A `_mlr` model on the rows labelled TEST, passing the
# same columns that were selected for training.
query = f"""
    SELECT *
    FROM ML.EVALUATE(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}_mlr`,
        (
            SELECT {TIME_COLUMN}, {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
            WHERE splits = 'TEST'
        )
    )
"""
evaluate_job = bq.query(query)
metrics = evaluate_job.to_dataframe()
metrics

Unnamed: 0,mean_absolute_error,mean_squared_error,mean_squared_log_error,median_absolute_error,r2_score,explained_variance
0,1303.414524,1706948.0,7.480687,1325.740787,-162.579057,0.227741


## Forecast Test Set

In [14]:
# Run the Group A `_mlr` model over the rows labelled TEST. The result holds
# the model's prediction for the target alongside the selected input columns.
query = f"""
    SELECT *
    FROM ML.PREDICT(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}_mlr`,
        (
            SELECT {TIME_COLUMN}, {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
            WHERE splits = 'TEST'
        )
        )
"""
predict_job = bq.query(query)
forecast = predict_job.to_dataframe()
forecast

Unnamed: 0,predicted_num_trips,starttime,num_trips,avg_tripduration,pct_subscriber,ratio_gender,capacity
0,1443.850476,2016-09-25,256,958.222656,0.726562,1.370370,
1,1440.894694,2016-09-25,258,1313.139535,0.666667,0.720000,
2,1448.086884,2016-09-26,261,863.915709,0.808429,2.107143,
3,1453.255953,2016-09-28,261,804.750958,0.931034,3.833333,
4,1457.849414,2016-09-22,517,802.326886,0.965184,2.615385,
...,...,...,...,...,...,...,...
266,1448.007857,2016-09-23,241,882.742739,0.792531,1.563830,
267,1454.654988,2016-09-28,243,991.773663,0.905350,2.115385,
268,1469.576502,2016-09-22,244,865.127049,0.905738,1.595745,66
269,1456.795178,2016-09-23,254,753.574803,0.940945,2.298701,
