# Matching pipeline

Matching is a method used in statistical analysis to reduce the bias caused by differences in the observed baseline characteristics of the compared groups. Simply put, matching helps ensure that the observed difference in outcomes is really caused by the studied effect (the treatment) and not by external factors.

Matching is most often used when a standard A/B test cannot be run, for example when the treatment was not assigned at random.

```python
from hypex import Matching
from hypex.dataset import (
    Dataset,
    FeatureRole,
    GroupingRole,
    InfoRole,
    TargetRole,
    TreatmentRole,
)
```

## Data preparation 

Before matching, mark the data columns by assigning them the appropriate roles:

* FeatureRole: a role for columns that contain features (predictor variables). Pairs are searched based on these columns. It is applied by default to every column whose role is not specified.
* TreatmentRole: a role for the column that indicates the treatment or intervention.
* TargetRole: a role for columns that contain the target (outcome) variable.
* InfoRole: a role for columns with information about the data that is not used as a feature, such as user IDs.
* GroupingRole: a role for a column that splits the data into groups; with group matching, pairs are searched only within each group.

```python
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
    },
    data="data.csv",
    # columns without an explicit role become features
    default_role=FeatureRole(),
)
data
```

| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | NaN | M | E-commerce |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | NaN | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 10 | 1 | 538.5 | 450.444444 | 42.0 | M | Logistics |
| 9996 | 9996 | 0 | 0 | 500.5 | 430.888889 | 26.0 | F | Logistics |
| 9997 | 9997 | 3 | 1 | 473.0 | 534.111111 | 22.0 | F | E-commerce |
| 9998 | 9998 | 2 | 1 | 495.0 | 523.222222 | 67.0 | F | E-commerce |


```python
data.roles
```

```
{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}
```

## Simple Matching  
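Matching consists of four steps:
1. Dummy (one-hot) encoding of categorical features.
2. Computing the Mahalanobis distance between units.
3. Two-sided pair search with faiss: a partner in the control group for every treated unit, and a partner in the treated group for every control unit.
4. Estimation of the effects (ATT, ATC, ATE).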
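
The four steps can be illustrated with a small NumPy sketch. It is not HypEx's implementation: the faiss search of step 3 is replaced by a brute-force nearest-neighbour search, each unit gets exactly one partner, and ATE is assumed to weight ATT and ATC by the group sizes. The helper names (`dummy_encode`, `nearest`, `match_effects`) are hypothetical.

```python
import numpy as np

def dummy_encode(labels):
    """Step 1: one-hot encode categorical labels, dropping the first level
    so the dummy columns are not linearly dependent."""
    levels = sorted(set(labels))[1:]
    return np.array([[float(v == lvl) for lvl in levels] for v in labels])

def nearest(from_X, to_X, VI):
    """Steps 2-3: for each row of from_X, position of the closest row of to_X
    under the Mahalanobis distance with inverse covariance VI.
    Brute force here; HypEx uses faiss for this search."""
    diff = from_X[:, None, :] - to_X[None, :, :]
    d2 = np.einsum("ijk,kl,ijl->ij", diff, VI, diff)  # squared distances
    return d2.argmin(axis=1)

def match_effects(X, treat, y):
    """Step 4: ATT, ATC and ATE from two-sided 1-nearest-neighbour matching."""
    VI = np.linalg.pinv(np.atleast_2d(np.cov(X, rowvar=False)))
    t, c = np.flatnonzero(treat == 1), np.flatnonzero(treat == 0)
    t_to_c = c[nearest(X[t], X[c], VI)]  # control partner of each treated unit
    c_to_t = t[nearest(X[c], X[t], VI)]  # treated partner of each control unit
    att = np.mean(y[t] - y[t_to_c])
    atc = np.mean(y[c_to_t] - y[c])
    ate = (len(t) * att + len(c) * atc) / len(y)  # assumed group-size weights
    return att, atc, ate

# Toy data: the true effect of the treatment is +5 for every unit
age = np.array([20, 30, 40, 21, 31, 41], dtype=float)
gender = ["F", "M", "F", "F", "M", "F"]
treat = np.array([1, 1, 1, 0, 0, 0])
y = np.array([15, 18, 15, 10, 13, 10], dtype=float)

X = np.column_stack([age, dummy_encode(gender)])
att, atc, ate = match_effects(X, treat, y)
print(att, atc, ate)  # 5.0 5.0 5.0
```

Because the Mahalanobis distance divides by the feature variances, a gender mismatch (variance of the dummy is small) costs far more than a one-year age gap, so every unit is paired within its own gender and the true effect is recovered.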

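The data contains missing values (for example, `age` in row 0 and `gender` in row 1), so we fill them first with a backward fill:

```python
data = data.fillna(method="bfill")
```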
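
Backward fill replaces each missing value with the next non-missing value further down the column. A plain-Python sketch of that rule, assuming `Dataset.fillna(method="bfill")` behaves like pandas `DataFrame.bfill()` (the `bfill` helper is hypothetical), applied to the first rows of `age` and `gender` shown above, reproduces the values that later appear in `result.full_data` (age 26.0 in row 0, gender M in row 1):

```python
def bfill(values):
    """Backward fill: replace each missing value (None) with the next
    non-missing value further down; trailing gaps stay None."""
    filled = list(values)
    next_valid = None
    for i in range(len(filled) - 1, -1, -1):  # walk from the bottom up
        if filled[i] is None:
            filled[i] = next_valid
        else:
            next_valid = filled[i]
    return filled

# First five rows of the `age` and `gender` columns from the table above
age = [None, 26.0, 25.0, 39.0, 18.0]
gender = ["M", None, "M", "M", "F"]
print(bfill(age))     # [26.0, 26.0, 25.0, 39.0, 18.0]
print(bfill(gender))  # ['M', 'M', 'M', 'M', 'F']
```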

```python
test = Matching(quality_tests=["t-test", "ks-test"])
result = test.execute(data)
```

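* **ATT** (average treatment effect on the treated) is the average effect among treated units.
* **ATC** (average treatment effect on the controls) is the average effect among untreated units, i.e. the effect they would have had if treated.
* **ATE** (average treatment effect) is the average effect over the whole sample, i.e. the average of ATT and ATC weighted by the sizes of the two groups.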
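
If ATE is the average of ATT and ATC weighted by the number of treated and control units (an assumption; the notebook does not show HypEx's exact weights), the reported estimates imply the share of treated units. A short check with the numbers from the default run below (the helper names are hypothetical):

```python
def ate_from_att_atc(att, atc, n_treated, n_control):
    """ATE as the average of ATT and ATC weighted by group sizes."""
    return (n_treated * att + n_control * atc) / (n_treated + n_control)

def implied_treated_share(att, atc, ate):
    """Invert the weighting: treated share consistent with the reported ATE."""
    return (atc - ate) / (atc - att)

share = implied_treated_share(att=63.37, atc=96.47, ate=80.13)
print(round(share, 3))  # 0.494
```

The implied share is about 0.49, i.e. the three estimates are consistent with roughly half of the rows being treated (up to the rounding of the reported values).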

```python
result.resume
```

| | Effect Size | Standard Error | P-value | CI Lower | CI Upper | outcome |
|---|---|---|---|---|---|---|
| ATT | 63.37 | 2.45 | 0.0 | 58.57 | 68.16 | post_spends |
| ATC | 96.47 | 1.57 | 0.0 | 93.4 | 99.55 | post_spends |
| ATE | 80.13 | 1.44 | 0.0 | 77.31 | 82.95 | post_spends |


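`result.full_data` adds the features and the outcome of each unit's partner as `_matched` columns:

```python
result.full_data
```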

| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry | user_id_matched | signup_month_matched | treat_matched | pre_spends_matched | post_spends_matched | age_matched | gender_matched | industry_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | 26.0 | M | E-commerce | 9433 | 1 | 1 | 488.5 | 518.444444 | 37.0 | F | Logistics |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | M | E-commerce | 5438 | 0 | 0 | 529.0 | 417.111111 | 23.0 | F | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics | 5165 | 0 | 0 | 498.5 | 412.222222 | 25.0 | F | Logistics |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce | 1735 | 1 | 1 | 504.0 | 516.333333 | 33.0 | M | Logistics |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce | 539 | 0 | 0 | 531.0 | 414.000000 | 20.0 | F | E-commerce |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 10 | 1 | 538.5 | 450.444444 | 42.0 | M | Logistics | 5893 | 0 | 0 | 535.0 | 414.555556 | 40.0 | M | E-commerce |
| 9996 | 9996 | 0 | 0 | 500.5 | 430.888889 | 26.0 | F | Logistics | 7731 | 1 | 1 | 500.0 | 515.888889 | 25.0 | M | Logistics |
| 9997 | 9997 | 3 | 1 | 473.0 | 534.111111 | 22.0 | F | E-commerce | 7066 | 0 | 0 | 480.0 | 423.222222 | 22.0 | F | Logistics |
| 9998 | 9998 | 2 | 1 | 495.0 | 523.222222 | 67.0 | F | E-commerce | 1885 | 0 | 0 | 499.0 | 423.000000 | 67.0 | F | Logistics |


`result.quality_results` reports the outcome of the tests requested with `quality_tests`:

```python
result.quality_results
```

| | feature | group | TTest pass | TTest p-value | KSTest pass | KSTest p-value |
|---|---|---|---|---|---|---|
| 0 | post_spends | 1 | NOT OK | 0.0 | NOT OK | 0.0 |
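
The two quality tests compare the distribution of a column in two samples: the t-test compares means, and the Kolmogorov-Smirnov (KS) test compares whole distributions. The sketch below shows what such a check computes, assuming a test passes ("OK") when its p-value exceeds 0.05. The threshold, the exact samples HypEx compares, and the helper names are assumptions; `scipy.stats.ttest_ind` and `scipy.stats.ks_2samp` would give exact p-values instead of the large-sample approximations used here.

```python
import math
from bisect import bisect_right
from statistics import mean, variance

ALPHA = 0.05  # assumed significance level behind "OK" / "NOT OK"

def t_test_pvalue(a, b):
    """Two-sided Welch t-test; normal approximation, fine for large samples."""
    se = math.sqrt(variance(a) / len(a) + variance(b) / len(b))
    t = (mean(a) - mean(b)) / se
    return math.erfc(abs(t) / math.sqrt(2))

def ks_test_pvalue(a, b):
    """Two-sample Kolmogorov-Smirnov test with the asymptotic p-value."""
    a, b = sorted(a), sorted(b)
    n, m = len(a), len(b)
    # D: largest gap between the two empirical CDFs, checked at every sample point
    d = max(abs(bisect_right(a, x) / n - bisect_right(b, x) / m) for x in a + b)
    en = math.sqrt(n * m / (n + m))
    lam = (en + 0.12 + 0.11 / en) * d
    if lam < 0.2:  # the Kolmogorov tail probability is ~1 in this range
        return 1.0
    p = 2 * sum((-1) ** (k - 1) * math.exp(-2 * k * k * lam * lam)
                for k in range(1, 101))
    return min(max(p, 0.0), 1.0)

def quality_row(a, b):
    """Summarise both tests in the same layout as the table above."""
    pt, pk = t_test_pvalue(a, b), ks_test_pvalue(a, b)
    return {
        "TTest pass": "OK" if pt > ALPHA else "NOT OK",
        "TTest p-value": pt,
        "KSTest pass": "OK" if pk > ALPHA else "NOT OK",
        "KSTest p-value": pk,
    }

base = [float(x) for x in range(100)]
shifted = [x + 200 for x in base]
print(quality_row(base, base)["TTest pass"])      # OK
print(quality_row(base, shifted)["KSTest pass"])  # NOT OK
```

A "NOT OK" with a p-value of 0.0, as in the table above, means the two compared samples differ far beyond what chance would explain.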


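`result.indexes` holds, for every row, the index of its matched partner (for example, row 0 is paired with 9433, the `user_id_matched` value above):

```python
result.indexes
```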

| | indexes |
|---|---|
| 0 | 9433 |
| 1 | 5438 |
| 2 | 5165 |
| 3 | 1735 |
| 4 | 539 |
| ... | ... |
| 9995 | 5893 |
| 9996 | 7731 |
| 9997 | 7066 |
| 9998 | 1885 |


The `_matched` columns inherit the roles of the original columns:

```python
result.full_data.roles
```

```
{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>),
 'user_id_matched': Info(<class 'int'>),
 'treat_matched': Treatment(<class 'int'>),
 'post_spends_matched': Target(<class 'float'>),
 'signup_month_matched': Feature(<class 'int'>),
 'pre_spends_matched': Feature(<class 'float'>),
 'age_matched': Feature(<class 'float'>),
 'gender_matched': Feature(<class 'str'>),
 'industry_matched': Feature(<class 'str'>)}
```

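We can change the **metric** and run the estimation again. With `metric="atc"`, pairs are searched only for the control group, so only **ATC** is estimated and treated rows get the index `-1` (no pair).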

```python
test = Matching(metric="atc")
result = test.execute(data)
```

```python
result.resume
```

| | Effect Size | Standard Error | P-value | CI Lower | CI Upper | outcome |
|---|---|---|---|---|---|---|
| ATC | 96.47 | 0.14 | 0.0 | 96.21 | 96.74 | post_spends |


```python
result.indexes
```

| | indexes |
|---|---|
| 0 | 9433 |
| 1 | -1 |
| 2 | -1 |
| 3 | 1735 |
| 4 | -1 |
| ... | ... |
| 9995 | -1 |
| 9996 | 7731 |
| 9997 | -1 |
| 9998 | -1 |
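
In `result.indexes`, `-1` marks rows for which no pair was searched. A sketch of turning such indexes into matched values; the `matched_values` helper and the toy data are hypothetical, and the indexes are assumed to be row positions:

```python
def matched_values(indexes, column):
    """Look up each row's partner value; -1 (no pair searched) becomes None."""
    return [None if i == -1 else column[i] for i in indexes]

user_id = [10, 11, 12, 13]
treat = [0, 1, 1, 0]
# Hypothetical indexes for metric="atc": only control rows (treat == 0) get a partner
indexes = [2, -1, -1, 1]
print(matched_values(indexes, user_id))  # [12, None, None, 11]
```

Missing partners are also why, for a one-sided metric, the `_matched` columns of `full_data` are empty for unsearched rows and the matched IDs appear as floats (`5438.0`): pandas stores missing values in a numeric column as `NaN`, which turns an integer column into a float one.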


It is also possible to search for pairs only for the **test group** by setting `metric="att"`; in that case only **ATT** is estimated and control rows get the index `-1`.

```python
test = Matching(metric="att")
result = test.execute(data)
```

```python
result.resume
```

| | Effect Size | Standard Error | P-value | CI Lower | CI Upper | outcome |
|---|---|---|---|---|---|---|
| ATT | 63.37 | 0.46 | 0.0 | 62.46 | 64.28 | post_spends |


```python
result.indexes
```

| | indexes |
|---|---|
| 0 | -1 |
| 1 | 5438 |
| 2 | 5165 |
| 3 | -1 |
| 4 | 539 |
| ... | ... |
| 9995 | 5893 |
| 9996 | -1 |
| 9997 | 7066 |
| 9998 | 1885 |


```python
result.full_data
```

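| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry | user_id_matched | signup_month_matched | treat_matched | pre_spends_matched | post_spends_matched | age_matched | gender_matched | industry_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | 26.0 | M | E-commerce | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | M | E-commerce | 5438.0 | 0.0 | 0.0 | 529.0 | 417.111111 | 23.0 | F | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics | 5165.0 | 0.0 | 0.0 | 498.5 | 412.222222 | 25.0 | F | Logistics |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce | 539.0 | 0.0 | 0.0 | 531.0 | 414.000000 | 20.0 | F | E-commerce |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 10 | 1 | 538.5 | 450.444444 | 42.0 | M | Logistics | 5893.0 | 0.0 | 0.0 | 535.0 | 414.555556 | 40.0 | M | E-commerce |
| 9996 | 9996 | 0 | 0 | 500.5 | 430.888889 | 26.0 | F | Logistics | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 9997 | 9997 | 3 | 1 | 473.0 | 534.111111 | 22.0 | F | E-commerce | 7066.0 | 0.0 | 0.0 | 480.0 | 423.222222 | 22.0 | F | Logistics |
| 9998 | 9998 | 2 | 1 | 495.0 | 523.222222 | 67.0 | F | E-commerce | 1885.0 | 0.0 | 0.0 | 499.0 | 423.000000 | 67.0 | F | Logistics |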


Finally, we can search for pairs using the L2 (Euclidean) distance instead of the Mahalanobis distance.

```python
test = Matching(distance="l2", metric="att")
result = test.execute(data)
```

```python
result.resume
```

| | Effect Size | Standard Error | P-value | CI Lower | CI Upper | outcome |
|---|---|---|---|---|---|---|
| ATT | 63.37 | 0.46 | 0.0 | 62.46 | 64.27 | post_spends |


```python
result.indexes
```

| | indexes |
|---|---|
| 0 | -1 |
| 1 | 2490 |
| 2 | 5493 |
| 3 | -1 |
| 4 | 321 |
| ... | ... |
| 9995 | 5893 |
| 9996 | -1 |
| 9997 | 8670 |
| 9998 | 507 |


```python
result.full_data
```

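| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry | user_id_matched | signup_month_matched | treat_matched | pre_spends_matched | post_spends_matched | age_matched | gender_matched | industry_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | 26.0 | M | E-commerce | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | M | E-commerce | 2490.0 | 0.0 | 0.0 | 511.5 | 417.444444 | 27.0 | F | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics | 5493.0 | 0.0 | 0.0 | 483.0 | 408.000000 | 25.0 | M | E-commerce |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce | 321.0 | 0.0 | 0.0 | 538.0 | 421.444444 | 29.0 | M | E-commerce |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 10 | 1 | 538.5 | 450.444444 | 42.0 | M | Logistics | 5893.0 | 0.0 | 0.0 | 535.0 | 414.555556 | 40.0 | M | E-commerce |
| 9996 | 9996 | 0 | 0 | 500.5 | 430.888889 | 26.0 | F | Logistics | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 9997 | 9997 | 3 | 1 | 473.0 | 534.111111 | 22.0 | F | E-commerce | 8670.0 | 0.0 | 0.0 | 473.0 | 415.777778 | 22.0 | F | Logistics |
| 9998 | 9998 | 2 | 1 | 495.0 | 523.222222 | 67.0 | F | E-commerce | 507.0 | 0.0 | 0.0 | 495.0 | 429.777778 | 67.0 | F | Logistics |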
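
The L2 run pairs units differently from the Mahalanobis run (row 1, for example, is matched to 2490 instead of 5438), although ATT barely changes. The difference comes from scaling: the L2 distance works in raw feature units, so large-scale features such as `pre_spends` dominate, while the Mahalanobis distance divides by the feature (co)variances. A toy NumPy illustration with a hypothetical covariance matrix; whether HypEx standardises features before an L2 search is not shown here:

```python
import numpy as np

def l2(u, v):
    """Plain Euclidean distance in raw feature units."""
    return float(np.sqrt(np.sum((u - v) ** 2)))

def mahalanobis(u, v, cov):
    """Distance that rescales each direction by the inverse covariance."""
    d = u - v
    return float(np.sqrt(d @ np.linalg.inv(cov) @ d))

query = np.array([500.0, 1.0])  # (pre_spends, gender_M dummy)
cands = {"A": np.array([501.0, 0.0]), "B": np.array([503.0, 1.0])}
cov = np.array([[100.0, 0.0], [0.0, 0.25]])  # hypothetical feature covariance

print(min(cands, key=lambda k: l2(query, cands[k])))                # A
print(min(cands, key=lambda k: mahalanobis(query, cands[k], cov)))  # B
```

Under L2, the gender mismatch of candidate A costs as much as one unit of spend; under Mahalanobis it costs as much as 20 units, so the same-gender candidate B wins.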


## Group Matching

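Group matching searches for pairs strictly within the groups defined by the column with GroupingRole (here `gender`) and reports the effects separately for each group.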
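
Conceptually, group matching runs the pair search separately inside each group, so a unit can only be paired with a unit from its own group. A simplified sketch, not HypEx's code: it uses 1-nearest-neighbour matching on a single numeric feature, and the function names are hypothetical.

```python
def nearest_partner(rows, i, candidates, feature):
    """Index of the candidate row with the closest value of `feature`."""
    return min(candidates, key=lambda j: abs(rows[j][feature] - rows[i][feature]))

def group_matching(rows, group_col, treat_col, feature):
    """Two-sided 1-NN matching done separately inside every group;
    rows without a possible partner keep the index -1."""
    partners = [-1] * len(rows)
    for g in sorted({r[group_col] for r in rows}):
        members = [i for i, r in enumerate(rows) if r[group_col] == g]
        treated = [i for i in members if rows[i][treat_col] == 1]
        control = [i for i in members if rows[i][treat_col] == 0]
        if not treated or not control:
            continue  # no pairs possible in this group
        for i in treated:
            partners[i] = nearest_partner(rows, i, control, feature)
        for i in control:
            partners[i] = nearest_partner(rows, i, treated, feature)
    return partners

rows = [
    {"gender": "M", "treat": 1, "pre_spends": 500.0},
    {"gender": "F", "treat": 0, "pre_spends": 501.0},
    {"gender": "M", "treat": 0, "pre_spends": 520.0},
    {"gender": "F", "treat": 1, "pre_spends": 530.0},
]
print(group_matching(rows, "gender", "treat", "pre_spends"))  # [2, 3, 0, 1]
```

Without grouping, row 0 (`pre_spends` 500.0) would be paired with row 1 (501.0) from the other gender group.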

```python
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
        # pairs will be searched only within each gender
        "gender": GroupingRole(str),
    },
    data="data.csv",
    default_role=FeatureRole(),
)
data
```

| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | NaN | M | E-commerce |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | NaN | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 10 | 1 | 538.5 | 450.444444 | 42.0 | M | Logistics |
| 9996 | 9996 | 0 | 0 | 500.5 | 430.888889 | 26.0 | F | Logistics |
| 9997 | 9997 | 3 | 1 | 473.0 | 534.111111 | 22.0 | F | E-commerce |
| 9998 | 9998 | 2 | 1 | 495.0 | 523.222222 | 67.0 | F | E-commerce |


```python
data = data.fillna(method="bfill")
test = Matching(group_match=True)  # search pairs within each gender group
result = test.execute(data)
```



```python
result.resume
```

| | F Effect Size | M Effect Size | F Standard Error | M Standard Error | F P-value | M P-value | F CI Lower | M CI Lower | F CI Upper | M CI Upper | outcome |
|---|---|---|---|---|---|---|---|---|---|---|---|
| ATT | 62.76 | 63.53 | 2.56 | 1.59 | 0.0 | 0.0 | 57.74 | 60.41 | 67.77 | 66.66 | post_spends |
| ATC | 97.96 | 93.52 | 2.36 | 1.65 | 0.0 | 0.0 | 93.33 | 90.28 | 102.6 | 96.76 | post_spends |
| ATE | 81.04 | 78.31 | 1.71 | 1.09 | 0.0 | 0.0 | 77.69 | 76.17 | 84.4 | 80.46 | post_spends |


```python
result.full_data
```

| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry | user_id_matched | signup_month_matched | treat_matched | pre_spends_matched | post_spends_matched | age_matched | gender_matched | industry_matched |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | 26.0 | M | E-commerce | 9367 | 2 | 1 | 484.0 | 522.777778 | 25.0 | M | Logistics |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | M | E-commerce | 4961 | 0 | 0 | 526.5 | 416.666667 | 23.0 | M | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics | 1479 | 0 | 0 | 497.0 | 428.111111 | 25.0 | M | Logistics |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce | 1735 | 1 | 1 | 504.0 | 516.333333 | 33.0 | M | Logistics |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce | 539 | 0 | 0 | 531.0 | 414.000000 | 20.0 | F | E-commerce |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 9995 | 9995 | 10 | 1 | 538.5 | 450.444444 | 42.0 | M | Logistics | 5893 | 0 | 0 | 535.0 | 414.555556 | 40.0 | M | E-commerce |
| 9996 | 9996 | 0 | 0 | 500.5 | 430.888889 | 26.0 | F | Logistics | 924 | 1 | 1 | 503.0 | 531.555556 | 27.0 | F | Logistics |
| 9997 | 9997 | 3 | 1 | 473.0 | 534.111111 | 22.0 | F | E-commerce | 7066 | 0 | 0 | 480.0 | 423.222222 | 22.0 | F | Logistics |
| 9998 | 9998 | 2 | 1 | 495.0 | 523.222222 | 67.0 | F | E-commerce | 7341 | 0 | 0 | 500.0 | 425.000000 | 67.0 | F | Logistics |


## Bias estimation

Bias estimation is enabled by default and can be disabled by setting the `bias_estimation` argument to `False`.

The data used here has already been backward-filled:

```python
data.data.head()
```

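| | user_id | signup_month | treat | pre_spends | post_spends | age | gender | industry |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 488.0 | 414.444444 | 26.0 | M | E-commerce |
| 1 | 1 | 8 | 1 | 512.5 | 462.222222 | 26.0 | M | E-commerce |
| 2 | 2 | 7 | 1 | 483.0 | 479.444444 | 25.0 | M | Logistics |
| 3 | 3 | 0 | 0 | 501.5 | 424.333333 | 39.0 | M | E-commerce |
| 4 | 4 | 1 | 1 | 543.0 | 514.555556 | 18.0 | F | E-commerce |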


```python
test = Matching(bias_estimation=False)
result = test.execute(data)
```

```python
result.resume
```

| | Effect Size | Standard Error | P-value | CI Lower | CI Upper | outcome |
|---|---|---|---|---|---|---|
| ATT | 63.61 | 2.46 | 0.0 | 58.8 | 68.43 | post_spends |
| ATC | 99.01 | 1.56 | 0.0 | 95.95 | 102.08 | post_spends |
| ATE | 81.54 | 1.44 | 0.0 | 78.71 | 84.37 | post_spends |
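
Matching on several covariates leaves small differences between partners, and these differences bias the plain matched estimate. A common remedy is the regression-based bias correction of Abadie and Imbens: fit an outcome model on the controls and subtract the part of each matched difference that the model attributes to covariate mismatch. Whether HypEx's `bias_estimation` uses exactly this form (a linear model fitted on the control group, shown here for ATT only) is an assumption, and the helper names are hypothetical.

```python
import numpy as np

def fit_linear(X, y):
    """Least-squares fit of y on X with an intercept; returns a predictor."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Z: np.column_stack([np.ones(len(Z)), Z]) @ coef

def att_bias_corrected(X_t, y_t, X_c, y_c, match):
    """Return (plain ATT, bias-corrected ATT).
    match[k] is the position (in the control arrays) of treated unit k's partner."""
    mu0 = fit_linear(X_c, y_c)          # outcome model fitted on controls only
    raw = y_t - y_c[match]              # plain matched differences
    bias = mu0(X_t) - mu0(X_c[match])   # part explained by covariate mismatch
    return float(np.mean(raw)), float(np.mean(raw - bias))

# Toy data: outcome = 2 * x + 1, plus a true effect of +5 for treated units
X_c = np.array([[0.0], [10.0], [20.0]])
y_c = np.array([1.0, 21.0, 41.0])
X_t = np.array([[1.0], [11.0]])
y_t = np.array([8.0, 28.0])
match = np.array([0, 1])  # each treated unit is one x-unit away from its partner

raw_att, corrected_att = att_bias_corrected(X_t, y_t, X_c, y_c, match)
print(raw_att, round(corrected_att, 6))  # 7.0 5.0
```

The plain estimate absorbs the 2-unit outcome gap caused by the imperfect matches; the correction removes it and recovers the true effect.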
