# Double ML Causal Inference

Double machine learning (double ML) estimates the effect a treatment has on a response while controlling for the influence of the covariates. The goal is to isolate the effect of the treatment itself, not that of any of the other covariates. The method proceeds in the following steps:

1) Split the data into two sets.
2) For each data set model the response based on the covariates.
3) For each data set model the treatment based on the covariates.
4) Calculate the response residuals.
5) Calculate the treatment residuals.
6) Regress the response residuals on the treatment residuals to get the treatment effect.
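
The six steps above can be sketched in plain NumPy on synthetic data. This is only an illustration under stated assumptions: the helper names (`fit_linear`, `predict_linear`, `double_ml_two_fold`) and the simulated data are hypothetical, the nuisance models are simple linear fits, and the split follows the usual cross-fitting arrangement in which models trained on one half produce residuals on the other half.

```python
import numpy as np

def fit_linear(Z, target):
    """Least-squares fit of target on Z plus an intercept; returns the coefficients."""
    design = np.column_stack([np.ones(len(Z)), Z])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef

def predict_linear(coef, Z):
    """Predictions from coefficients produced by fit_linear."""
    return np.column_stack([np.ones(len(Z)), Z]) @ coef

def double_ml_two_fold(Y, X, Z, seed=0):
    """Two-fold cross-fitted estimate of the effect of treatment X on response Y."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(Y)), 2)   # step 1: split the data
    U = np.empty(len(Y))
    V = np.empty(len(Y))
    for k in range(2):
        train, held_out = folds[1 - k], folds[k]
        f = fit_linear(Z[train], Y[train])               # step 2: response model
        g = fit_linear(Z[train], X[train])               # step 3: treatment model
        U[held_out] = Y[held_out] - predict_linear(f, Z[held_out])  # step 4
        V[held_out] = X[held_out] - predict_linear(g, Z[held_out])  # step 5
    return float(U @ V / (V @ V))                        # step 6: residual regression

# Synthetic data (an assumption for illustration): the true treatment effect is 2.0.
rng = np.random.default_rng(42)
n = 4000
Z = rng.normal(size=(n, 3))
X = Z @ np.array([0.5, -1.0, 0.3]) + rng.normal(size=n)
Y = 2.0 * X + Z @ np.array([1.0, 2.0, -1.0]) + rng.normal(size=n)

theta_hat = double_ml_two_fold(Y, X, Z)
print(f"estimated treatment effect: {theta_hat:.3f}")
```

On this simulated data the estimate should land close to the true value of 2.0. Any regressor could replace the linear fits without changing the overall structure.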

Refer to [Double Machine Learning - an easy introduction](https://dm13450.github.io/2021/05/28/Double-ML.html?msclkid=8cc6c026a60911ec874bf7397d3a9b6e) for more details.

In summary:

* Train two regression models: one predicts the outcome variable from the relevant covariates, and the other predicts the treatment variable from the same covariates.
* Compute the residuals of each model. If f(Z) estimates Y from Z and g(Z) estimates X from Z, the residuals are U = Y - f(Z) and V = X - g(Z), respectively.
* Compute the treatment effect. Regress U on V; for mean-zero residuals the slope of that regression, $\hat{\theta} = \sum_i U_i V_i / \sum_i V_i^2$, is the estimated treatment effect.

In [1]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'
REGION = "us-central1"

# VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'
VERTEX_SA = 'jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")
print(f"REGION: {REGION}")
print(f"VERTEX_SA: {VERTEX_SA}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1
REGION: us-central1
VERTEX_SA: jt-vertex-sa@hybrid-vertex.iam.gserviceaccount.com


In [2]:
REGION = 'us-central1'
EXPERIMENT = 'control_group1'
SERIES = 'causal_impact_4'

BQ_PROJECT = PROJECT_ID
BQ_DATASET = SERIES.replace('-','_')
BQ_TABLE = EXPERIMENT

BQ_SOURCE1 = 'bigquery-public-data.new_york.citibike_trips'
BQ_SOURCE2 = 'bigquery-public-data.new_york.citibike_stations'

viz_limit = 9

In [3]:
from google.cloud import bigquery

import pandas as pd
import numpy as np
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

from time import sleep

from google.cloud import aiplatform as vertex_ai

bq = bigquery.Client(project=PROJECT_ID)

vertex_ai.init(
    project=PROJECT_ID, 
    location=REGION,
    # credentials=credentials
)

In [27]:
# CUSTOMIZE
TARGET_COLUMN = 'num_trips'
TIME_COLUMN = 'starttime'
SERIES_COLUMN = 'start_station_name'
SPLIT_COLUMN = 'splits'
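COVARIATE_COLUMNS = [
    'avg_tripduration', 
    'pct_subscriber',  # NOTE: also used as the treatment (Z_TARGET below); the steps above model the response on covariates only
    'ratio_gender', 
    'capacity'
] # could be empty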

BQ_TABLE_GROUP_A="control_group1_grp_a"
# BQ_TABLE_GROUP_B="control_group1_grp_b"

GROUP_A_MODEL_PREFIX="grp_a"
GROUP_B_MODEL_PREFIX="grp_b"

forecast_granularity = 'DAY'
forecast_horizon = 7 #14
forecast_test_length = 14
#forecast_val_length = 14

## Train Y linear learner (outcome model)

In [6]:
model_1_name = f'{GROUP_A_MODEL_PREFIX}_y_linear_learner_model'

In [7]:
query = f"""
    CREATE OR REPLACE MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_1_name}`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['{TARGET_COLUMN}'],
       enable_global_explain=TRUE
      ) AS
    SELECT {TIME_COLUMN}, {TARGET_COLUMN},
        {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
    WHERE splits in ('TRAIN','VALIDATE')
"""
print(query)


    CREATE OR REPLACE MODEL `hybrid-vertex.causal_impact_4.grp_a_y_linear_learner_model`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['num_trips'],
       enable_global_explain=TRUE
      ) AS
    SELECT starttime, num_trips,
        avg_tripduration, pct_subscriber, ratio_gender, capacity
    FROM `hybrid-vertex.causal_impact_4.control_group1_grp_a`
    WHERE splits in ('TRAIN','VALIDATE')



In [8]:
job = bq.query(query)
job.result()
print(job.state, (job.ended-job.started).total_seconds())

DONE 35.759


In [22]:
COVARIATE_COLUMNS

['avg_tripduration', 'ratio_gender', 'capacity']

In [24]:
query = f"""
    SELECT *
    FROM ML.PREDICT(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_1_name}`,
        (
            SELECT 
                {TIME_COLUMN}, 
                {TARGET_COLUMN},
                {', '.join(COVARIATE_COLUMNS)},
                {SERIES_COLUMN}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
        )
        )
"""
forecast = bq.query(query).to_dataframe()
forecast.head(3)

| | predicted_num_trips | starttime | num_trips | avg_tripduration | pct_subscriber | ratio_gender | capacity | start_station_name |
|---|---|---|---|---|---|---|---|---|
| 0 | 19.462041 | 2016-07-11 | 284 | 773.496479 | 0.929577 | 2.021277 | 0 | Lafayette St & Jersey St N |
| 1 | 51.716156 | 2016-07-15 | 278 | 678.579137 | 0.902878 | 2.159091 | 0 | Lafayette St & Jersey St N |
| 2 | -9.165523 | 2016-07-01 | 227 | 797.577093 | 0.929515 | 2.721311 | 0 | Lafayette St & Jersey St N |



## Train Z linear learner (treatment model)

In [28]:
Z_TARGET = "pct_subscriber"
Z_COVARIATES = [
    'avg_tripduration', 
    # 'pct_subscriber', 
    'ratio_gender', 
    'capacity'
] # could be empty

In [17]:
Z_TARGET

'pct_subscriber'

In [14]:
model_2_name = f'{GROUP_A_MODEL_PREFIX}_z_linear_learner_model'

In [18]:
query = f"""
    CREATE OR REPLACE MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_2_name}`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['{Z_TARGET}'],
       enable_global_explain=TRUE
      ) AS
    SELECT {TIME_COLUMN}, {Z_TARGET},
        {', '.join(Z_COVARIATES)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
    WHERE splits in ('TRAIN','VALIDATE')
"""
print(query)


    CREATE OR REPLACE MODEL `hybrid-vertex.causal_impact_4.grp_a_z_linear_learner_model`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['pct_subscriber'],
       enable_global_explain=TRUE
      ) AS
    SELECT starttime, pct_subscriber,
        avg_tripduration, ratio_gender, capacity
    FROM `hybrid-vertex.causal_impact_4.control_group1_grp_a`
    WHERE splits in ('TRAIN','VALIDATE')



In [19]:
job = bq.query(query)
job.result()
print(job.state, (job.ended-job.started).total_seconds())

DONE 30.019


In [20]:
query = f"""
    SELECT *
    FROM ML.PREDICT(
        MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_2_name}`,
        (
            SELECT 
                {TIME_COLUMN}, 
                {Z_TARGET},
                {', '.join(Z_COVARIATES)},
                {SERIES_COLUMN}
            FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
        )
        )
"""
forecast = bq.query(query).to_dataframe()
forecast.head(3)

| | predicted_pct_subscriber | starttime | pct_subscriber | avg_tripduration | ratio_gender | capacity | start_station_name |
|---|---|---|---|---|---|---|---|
| 0 | 0.900576 | 2016-07-11 | 0.929577 | 773.496479 | 2.021277 | 0 | Lafayette St & Jersey St N |
| 1 | 0.92222 | 2016-07-15 | 0.902878 | 678.579137 | 2.159091 | 0 | Lafayette St & Jersey St N |
| 2 | 0.942105 | 2016-07-01 | 0.929515 | 797.577093 | 2.721311 | 0 | Lafayette St & Jersey St N |


## Create Residuals Table

In [31]:
query = f"""
    CREATE OR REPLACE TABLE `{BQ_PROJECT}.{BQ_DATASET}.double_ml_residual_table` AS (
    SELECT
        a.start_station_name,
        a.starttime,
        a.avg_tripduration,
        a.ratio_gender,
        a.capacity,
        a.predicted_num_trips,
        a.num_trips,
        (a.num_trips - a.predicted_num_trips) as residuals_u,
        b.predicted_pct_subscriber,
        b.pct_subscriber,
        (b.pct_subscriber - b.predicted_pct_subscriber) as residuals_z_pct,
        -- b.start_station_name,
        -- b.starttime,
    FROM ML.PREDICT( MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_1_name}`,
            (
                SELECT 
                    {TIME_COLUMN}, 
                    {TARGET_COLUMN},
                    {', '.join(COVARIATE_COLUMNS)},
                    {SERIES_COLUMN}
                FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
            )) a
    JOIN
     ML.PREDICT( MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_2_name}`,
            (
                SELECT 
                    {TIME_COLUMN}, 
                    {Z_TARGET},
                    {', '.join(Z_COVARIATES)},
                    {SERIES_COLUMN}
                FROM `{BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE_GROUP_A}`
            )) b
    ON a.starttime=b.starttime AND a.start_station_name = b.start_station_name
    );
"""
print(query)


    CREATE OR REPLACE TABLE `hybrid-vertex.causal_impact_4.double_ml_residual_table` AS (
    SELECT
        a.start_station_name,
        a.starttime,
        a.avg_tripduration,
        a.ratio_gender,
        a.capacity,
        a.predicted_num_trips,
        a.num_trips,
        (a.num_trips - a.predicted_num_trips) as residuals_u,
        b.predicted_pct_subscriber,
        b.pct_subscriber,
        (b.pct_subscriber - b.predicted_pct_subscriber) as residuals_z_pct,
        -- b.start_station_name,
        -- b.starttime,
    FROM ML.PREDICT( MODEL `hybrid-vertex.causal_impact_4.grp_a_y_linear_learner_model`,
            (
                SELECT 
                    starttime, 
                    num_trips,
                    avg_tripduration, pct_subscriber, ratio_gender, capacity,
                    start_station_name
                FROM `hybrid-vertex.causal_impact_4.control_group1_grp_a`
            )) a
    JOIN
     ML.PREDICT( MODEL `hybrid-vertex.causal_impact_4.grp_a

In [32]:
job = bq.query(query)
job.result()
print(job.state, (job.ended-job.started).total_seconds())

DONE 2.493


In [33]:
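# Read the residuals back from the table; re-running the CREATE statement would return no rows.
query = f"SELECT * FROM `{BQ_PROJECT}.{BQ_DATASET}.double_ml_residual_table`"
residual_table_df = bq.query(query).to_dataframe()
residual_table_df.head(3)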
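
As a rough local cross-check of what the residuals query computes, the same join and subtraction can be sketched in pandas. The three rows below are copied from the two `ML.PREDICT` previews shown earlier in this notebook; the DataFrame names (`y_pred`, `z_pred`, `residuals`) are hypothetical and the real table is built in BigQuery, not locally.

```python
import pandas as pd

keys = ["start_station_name", "starttime"]
station = "Lafayette St & Jersey St N"
dates = ["2016-07-11", "2016-07-15", "2016-07-01"]

# Predictions of the outcome model (rows taken from the first ML.PREDICT preview).
y_pred = pd.DataFrame({
    "start_station_name": [station] * 3,
    "starttime": dates,
    "num_trips": [284, 278, 227],
    "predicted_num_trips": [19.462041, 51.716156, -9.165523],
})

# Predictions of the treatment model (rows taken from the second ML.PREDICT preview).
z_pred = pd.DataFrame({
    "start_station_name": [station] * 3,
    "starttime": dates,
    "pct_subscriber": [0.929577, 0.902878, 0.929515],
    "predicted_pct_subscriber": [0.900576, 0.922220, 0.942105],
})

# Same logic as the SQL: inner join on station and date, then actual minus predicted.
residuals = y_pred.merge(z_pred, on=keys, how="inner")
residuals["residuals_u"] = residuals["num_trips"] - residuals["predicted_num_trips"]
residuals["residuals_z_pct"] = (
    residuals["pct_subscriber"] - residuals["predicted_pct_subscriber"]
)
print(residuals[keys + ["residuals_u", "residuals_z_pct"]])
```

The join keys matter: if a station and date pair appears more than once in either input, the inner join multiplies rows, so it is worth checking that the residual table has the same row count as the source table.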

## Train Residuals Linear Model

In [34]:
model_3_name = "residuals_reg_a"

In [39]:
query = f"""
    CREATE OR REPLACE MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_3_name}`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['residuals_u'],
       enable_global_explain=TRUE,
       CALCULATE_P_VALUES=TRUE,
       CATEGORY_ENCODING_METHOD="DUMMY_ENCODING"
      ) AS
    SELECT
        residuals_u,
        residuals_z_pct
        -- {', '.join(COVARIATE_COLUMNS)}
    FROM `{BQ_PROJECT}.{BQ_DATASET}.double_ml_residual_table`
    -- WHERE splits in ('TRAIN','VALIDATE')
"""
print(query)


    CREATE OR REPLACE MODEL `hybrid-vertex.causal_impact_4.residuals_reg_a`
    OPTIONS
      (model_type = 'linear_reg',
       input_label_cols = ['residuals_u'],
       enable_global_explain=TRUE,
       CALCULATE_P_VALUES=TRUE,
       CATEGORY_ENCODING_METHOD="DUMMY_ENCODING"
      ) AS
    SELECT
        residuals_u,
        residuals_z_pct
        -- avg_tripduration, pct_subscriber, ratio_gender, capacity
    FROM `hybrid-vertex.causal_impact_4.double_ml_residual_table`
    -- WHERE splits in ('TRAIN','VALIDATE')



In [40]:
job = bq.query(query)
job.result()
print(job.state, (job.ended-job.started).total_seconds())

DONE 23.09


## Export Residuals Linear Model

In [41]:
query = f"""
    SELECT
     *
    FROM
     ML.ADVANCED_WEIGHTS(MODEL `{BQ_PROJECT}.{BQ_DATASET}.{model_3_name}`)
"""
trainingInfo = bq.query(query).to_dataframe()
trainingInfo.head()

| | processed_input | category | weight | standard_error | p_value |
|---|---|---|---|---|---|
| 0 | residuals_z_pct | | 43.69787 | 8.329015 | 6e-06 |
| 1 | \_\_INTERCEPT\_\_ | | 1.600385 | 1.087892 | 0.141166 |
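
The weight and standard error reported for `residuals_z_pct` can be turned into an approximate 95% confidence interval for the treatment effect. The sketch below assumes a normal approximation (weight plus or minus about 1.96 standard errors); BigQuery may use a different reference distribution for its p-values, so treat the result as indicative only.

```python
from statistics import NormalDist

# Values copied from the ML.ADVANCED_WEIGHTS output for residuals_z_pct.
weight = 43.69787
standard_error = 8.329015

# Two-sided 95% critical value of the standard normal (about 1.96).
z_crit = NormalDist().inv_cdf(0.975)

ci_low = weight - z_crit * standard_error
ci_high = weight + z_crit * standard_error
print(f"approximate 95% CI: ({ci_low:.2f}, {ci_high:.2f})")
```

The interval runs from roughly 27 to 60 and excludes zero, which agrees with the small p-value in the table.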


Suppose we want to calculate the effect of the number of agents on the profit of a market center:

$$Y_{profit} = a \cdot N_{agents}+f(X_{external})$$

We want to find $a$ and its confidence interval.
First, use a basic ML model to estimate each of:

$$Y_{profit} = g(X_{external})+\epsilon_{g}$$
$$N_{agents} = h(X_{external})+\epsilon_{h}$$

Then regress the residuals of the first model on the residuals of the second:

$$\epsilon_{g} = a_{\epsilon} \cdot \epsilon_{h}+\sigma_{\epsilon}$$

The coefficient $a_{\epsilon}$ is a reasonable estimate of $a$.
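
A small simulation can make the agents and profit example concrete. Everything here is an assumption for illustration: the data are synthetic, the true effect `true_a` is set to 4, and $g$ and $h$ are approximated with cubic polynomial fits via `np.polyfit` rather than a trained ML model. The naive slope of profit on agents is included to show the confounding that the residual regression is meant to remove.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 2000
true_a = 4.0  # hypothetical effect of one extra agent on profit

# The external factor drives both the number of agents and the profit.
x_external = rng.uniform(-2.0, 2.0, size=n)
n_agents = 10.0 + 2.0 * x_external**2 + rng.normal(size=n)
profit = true_a * n_agents + 5.0 * x_external**2 + rng.normal(size=n)

# Naive estimate: slope of profit on agents, ignoring the external factor.
a_naive = np.polyfit(n_agents, profit, 1)[0]

# Stand-ins for g and h: cubic polynomial fits on the external factor.
g = np.polyfit(x_external, profit, 3)
h = np.polyfit(x_external, n_agents, 3)
eps_g = profit - np.polyval(g, x_external)
eps_h = n_agents - np.polyval(h, x_external)

# Residual-on-residual slope, the a_epsilon of the text.
a_eps = float(eps_g @ eps_h / (eps_h @ eps_h))
print(f"naive: {a_naive:.2f}, residual-on-residual: {a_eps:.2f}")
```

In this setup the naive slope is pushed well above 4 because the external factor raises both agents and profit, while the residual-on-residual slope should sit close to 4.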
 


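In the linear case, the model is

$$Y=a \cdot X+b \cdot Z+\sigma$$

First fit the response and the treatment on the covariates $Z$, giving predictions $\hat{Y}$ and $\hat{X}$:

$$Y=a_yZ+\sigma_y$$
$$X=a_xZ+\sigma_x$$

Then fit a linear model to the residuals:

$$Y-\hat{Y}=\hat{a} \cdot (X-\hat{X})+\sigma_a$$

The coefficient $\hat{a}$ is the estimate of $a$.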
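
For the fully linear case, the residual regression reproduces the coefficient on $X$ from the full regression of $Y$ on $X$ and $Z$ (the Frisch-Waugh-Lovell result). The check below is a minimal sketch on synthetic data; the coefficients 1.5 and -2.0 and the helper `ols` are assumptions made for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500
Z = rng.normal(size=n)
X = 0.8 * Z + rng.normal(size=n)
Y = 1.5 * X - 2.0 * Z + rng.normal(size=n)   # true a = 1.5, b = -2.0

def ols(design, target):
    """Ordinary least-squares coefficients for target on the given design matrix."""
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef

ones = np.ones(n)

# Full regression of Y on X, Z and an intercept; the first coefficient is a.
a_full = float(ols(np.column_stack([X, Z, ones]), Y)[0])

# Partial out Z (and the intercept) from both Y and X, then regress the residuals.
ZI = np.column_stack([Z, ones])
Y_res = Y - ZI @ ols(ZI, Y)
X_res = X - ZI @ ols(ZI, X)
a_resid = float(X_res @ Y_res / (X_res @ X_res))

print(f"full regression: {a_full:.6f}, residual regression: {a_resid:.6f}")
```

The two numbers agree up to floating-point error when both nuisance fits are linear and fitted on the same rows. With flexible ML models or cross-fitting the equality is only approximate.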
