# Matching pipeline

Matching is a method used in statistical analysis to eliminate distortions caused by differences in the baseline characteristics of the groups being compared. Simply put, matching helps to make sure that the results of the experiment are really caused by the studied effect, and not by external factors.

Matching is most often performed in cases where a standard A/B test cannot be used.

In [134]:
import sys
sys.path.append('../..')

In [135]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
from hypex import Matching
from hypex.dataset import Dataset, FeatureRole, InfoRole, TargetRole, TreatmentRole

## Data preparation 

It is important to mark the data fields by assigning the appropriate roles:

* FeatureRole: a role for columns that contain features or predictor variables. Matching will be based on them. Applied by default if the role is not specified for the column.
* TreatmentRole: a role for columns that show the treatment or intervention.
* TargetRole: a role for columns that show the target or outcome variable.
* InfoRole: a role for columns that contain information about the data, such as user IDs.

In [137]:
# Load the CSV and tag the columns with their roles. Columns that are not
# listed in `roles` receive the default role (FeatureRole).
data = Dataset(
    data="data.csv",
    default_role=FeatureRole(),
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
    },
)
data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce


In [138]:
from typing import Optional, Tuple

import numpy as np

from hypex.dataset import Dataset, DefaultRole
from hypex.extensions.scipy_linalg import CholeskyExtension, InverseExtension

def generate_data(
    size: int = 1000,
    x_interval: Tuple[float, float] = (-5, 5),
    y_interval: Tuple[float, float] = (-7, 7),
    x_scale: float = 5,
    y_scale: float = 3,
    rs: Optional[int] = None,
    dotA: Tuple[int, int] = (0, 0),
    dotB: Tuple[int, int] = (0, 5),
    dotC: Tuple[int, int] = (5, 0),
) -> "Dataset":
    """Generate a synthetic dataset of noisy linear ramps with a random treatment flag.

    Each numeric column is an evenly spaced ramp over an interval plus
    Gaussian noise. Columns ``x`` and ``a`` use ``x_interval``/``x_scale``;
    columns ``y`` and ``id`` use ``y_interval``/``y_scale``. ``treat`` is drawn
    uniformly from ``{0, 1}``.

    Args:
        size: Number of rows to generate.
        x_interval: ``(start, stop)`` of the ramp for ``x`` and ``a``.
        y_interval: ``(start, stop)`` of the ramp for ``y`` and ``id``.
        x_scale: Standard deviation of the noise added to ``x`` and ``a``.
        y_scale: Standard deviation of the noise added to ``y`` and ``id``.
        rs: Seed for NumPy's global random state. ``None`` leaves the global
            state untouched; any integer (including ``0``) reseeds it.
        dotA: Currently unused; kept for backward compatibility.
        dotB: Currently unused; kept for backward compatibility.
        dotC: Currently unused; kept for backward compatibility.

    Returns:
        A ``Dataset`` with roles: ``x`` -> Info, ``treat`` -> Treatment,
        ``y`` -> Target, ``a`` and ``id`` -> Feature.
    """
    # 0 is a legitimate seed, so test against None instead of truthiness.
    if rs is not None:
        np.random.seed(rs)

    def _noisy_ramp(interval, scale):
        # Evenly spaced ramp over `interval` with additive Gaussian noise.
        ramp = np.linspace(interval[0], interval[1], size)
        return ramp + np.random.normal(size=size, scale=scale)

    # Dict values are evaluated top to bottom, so the sequence of random draws
    # (x, a, y, id, treat) is the same as in the previous implementation.
    data = Dataset.from_dict(
        {
            "x": _noisy_ramp(x_interval, x_scale),
            "a": _noisy_ramp(x_interval, x_scale),
            "y": _noisy_ramp(y_interval, y_scale),
            "id": _noisy_ramp(y_interval, y_scale),
            "treat": np.random.choice([0, 1], size),
        },
        roles={
            "x": InfoRole(float),
            "treat": TreatmentRole(int),
            "y": TargetRole(float),
            "a": FeatureRole(float),
            "id": FeatureRole(float),
        },
    )
    return data

# data = generate_data(1000006)
# Read the sample from disk rather than generating it in this cell
# (presumably data2.csv was produced by generate_data — TODO confirm).
data = Dataset(
    data="data2.csv",
    default_role=FeatureRole(),
    roles={
        "x": InfoRole(float),
        "treat": TreatmentRole(int),
        "y": TargetRole(float),
        "a": FeatureRole(float),
        "id": FeatureRole(float),
    },
)

In [141]:
data.data['treat'].sum()

4

In [140]:
# Overwrite the treatment column: first mark every row as control (0) ...
data.data['treat'] = 0
# ... then flag only the first four rows (positions 0-3) as treated (1).
data.data.loc[data.data.index[[0, 1, 2, 3]], 'treat'] = 1

In [142]:
data

Unnamed: 0,x,a,y,id,treat
0,-1.261750,3.100834,-11.364112,-4.629049,1
1,-7.334568,-1.976615,-3.603994,-6.046804,1
2,0.473645,-4.431252,-8.427517,-10.654364,1
3,-1.208378,-6.650571,-9.837955,-12.646858,1
4,-7.382247,-1.892355,-8.977479,-4.376245,0
...,...,...,...,...,...
1000001,6.856478,6.969530,11.178526,10.094631,0
1000002,6.768127,11.242613,2.656471,11.719071,0
1000003,4.314690,9.649187,9.698278,2.083988,0
1000004,9.502454,9.700882,3.359847,6.363373,0


In [None]:
data.roles

{'x': Info(<class 'float'>),
 'treat': Treatment(<class 'int'>),
 'y': Target(<class 'float'>),
 'a': Feature(<class 'float'>),
 'id': Feature(<class 'float'>)}

## Simple Matching  
Matching currently consists of 4 steps: 
1. Dummy encoding 
2. Mahalanobis distance computation 
3. Two-sided pair search with faiss 
4. Estimation of the metrics (ATT, ATC, ATE), depending on your data 

In [143]:
data = data.fillna(method="bfill")

In [None]:
# Re-read the sample from disk with the same role mapping.
# NOTE(review): rebinding `data` replaces the Dataset object modified in
# earlier cells — confirm that dropping those in-memory changes is intended.
data = Dataset(
    data="data2.csv",
    default_role=FeatureRole(),
    roles={
        "x": InfoRole(float),
        "treat": TreatmentRole(int),
        "y": TargetRole(float),
        "a": FeatureRole(float),
        "id": FeatureRole(float),
    },
)

In [147]:
# Configure matching with the t-test quality check and the "fast" faiss mode,
# then run it on the prepared dataset.
test = Matching(
    faiss_mode="fast",
    quality_tests=["t-test"],
)
result = test.execute(data)

  return self.data.fillna(value=values, **kwargs)
  return self.data.fillna(value=values, **kwargs)


In [148]:
result.quality_results

Unnamed: 0,feature,group,TTest pass,TTest p-value
0,y,1┆y,OK,0.206798


In [146]:
result.full_data[result.full_data['treat'] == 1]

Unnamed: 0,x,a,y,id,treat,x_matched,a_matched,y_matched,id_matched,treat_matched
0,-1.26175,3.100834,-11.364112,-4.629049,1,0.191006,3.109871,-3.344122,-4.625369,0
1,-7.334568,-1.976615,-3.603994,-6.046804,1,-13.391273,-1.977021,-7.745727,-6.049932,0
2,0.473645,-4.431252,-8.427517,-10.654364,1,-9.822826,-4.418355,0.853453,-10.631227,0
3,-1.208378,-6.650571,-9.837955,-12.646858,1,-2.562922,-6.581405,-7.870214,-12.621051,0


In [None]:
result.full_data[result.full_data['treat'] == 1]

Unnamed: 0,x,a,y,id,treat,x_matched,a_matched,y_matched,id_matched,treat_matched
0,-1.261750,3.100834,-11.364112,-4.629049,1,0.191006,3.109871,-3.344122,-4.625369,0
3,-1.208378,-6.650571,-9.837955,-12.646858,1,-8.001790,-6.615850,-2.967840,-12.560058,0
7,-0.743510,-4.623327,-6.271688,-6.002979,1,2.064135,-4.633225,-3.347532,-6.015859,0
8,-4.341069,-4.115639,-9.161990,-7.077057,1,0.139544,-4.086484,-3.433782,-7.075633,0
10,-1.179751,3.592116,-7.421925,-6.596552,1,-4.237393,3.606829,-8.001289,-6.606263,0
...,...,...,...,...,...,...,...,...,...,...
999990,6.604173,8.799783,5.057766,8.409435,1,9.165009,8.770994,5.933647,8.388481,0
999993,-0.027057,1.751703,7.831820,6.157373,1,7.968290,1.762559,1.205600,6.147802,0
999997,5.801604,5.732662,4.587242,9.028066,1,-1.188640,5.746885,10.875836,9.046493,0
1000000,-2.763483,2.478838,11.706238,4.518487,1,-4.399947,2.488356,0.510845,4.531238,0


In [None]:
result.quality_results

Unnamed: 0,feature,group,TTest pass,TTest p-value,KSTest pass,KSTest p-value
0,y,1┆y,OK,0.133777,OK,0.323264


**ATT** (average treatment effect on the treated) shows the effect in the treated group.   
**ATC** (average treatment effect on the control) shows the effect in the untreated group.   
**ATE** (average treatment effect) is the weighted average of ATT and ATC.  

In [None]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,-96.91,1.55,0.0,-99.95,-93.87,post_spends
ATC,-63.52,1.39,0.0,-66.25,-60.79,post_spends
ATE,-80.43,1.01,0.0,-82.41,-78.45,post_spends


In [None]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched,signup_month_matched,treat_matched,pre_spends_matched,post_spends_matched,age_matched,gender_matched,industry_matched
0,0,0,1,488.0,414.444444,26.0,M,E-commerce,9367,2,0,484.0,522.777778,25.0,M,Logistics
1,1,8,0,512.5,462.222222,26.0,M,E-commerce,1897,0,1,525.5,422.000000,28.0,M,E-commerce
2,2,7,0,483.0,479.444444,25.0,M,Logistics,5165,0,1,498.5,412.222222,25.0,F,Logistics
3,3,0,1,501.5,424.333333,39.0,M,E-commerce,7497,1,0,508.5,525.444444,37.0,M,Logistics
4,4,1,0,543.0,514.555556,18.0,F,E-commerce,539,0,1,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,0,538.5,450.444444,42.0,M,Logistics,5893,0,1,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,1,500.5,430.888889,26.0,F,Logistics,7731,1,0,500.0,515.888889,25.0,M,Logistics
9997,9997,3,0,473.0,534.111111,22.0,F,E-commerce,7066,0,1,480.0,423.222222,22.0,F,Logistics
9998,9998,2,0,495.0,523.222222,67.0,F,E-commerce,1885,0,1,499.0,423.000000,67.0,F,Logistics


In [None]:
result.indexes

Unnamed: 0,indexes
0,9367
1,1897
2,5165
3,7497
4,539
...,...
9995,5893
9996,7731
9997,7066
9998,1885


In [None]:
result.full_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>),
 'user_id_matched': Info(<class 'int'>),
 'treat_matched': Treatment(<class 'int'>),
 'post_spends_matched': Target(<class 'float'>),
 'signup_month_matched': Feature(<class 'int'>),
 'pre_spends_matched': Feature(<class 'float'>),
 'age_matched': Feature(<class 'float'>),
 'gender_matched': Feature(<class 'str'>),
 'industry_matched': Feature(<class 'str'>)}

We can change **metric** and do estimation again.

In [None]:
# Request only the ATC metric and rerun matching on the same dataset.
test = Matching(metric="atc")
result = test.execute(data)

  return self.data.fillna(value=values, **kwargs)


In [None]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATC,96.47,0.14,0.0,96.21,96.74,post_spends


In [None]:
result.indexes

Unnamed: 0,indexes
0,9433
1,-1
2,-1
3,1735
4,-1
...,...
9995,-1
9996,7731
9997,-1
9998,-1


It is also possible to search for pairs only for the **test group**. In this case we set the metric to "att", and only **ATT** will be estimated. 

In [None]:
# Request only the ATT metric and rerun matching on the same dataset.
test = Matching(metric='att')
result = test.execute(data)

  return self.data.fillna(value=values, **kwargs)


In [None]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.37,0.46,0.0,62.46,64.28,post_spends


In [None]:
result.indexes

Unnamed: 0,indexes
0,-1
1,5438
2,5165
3,-1
4,539
...,...
9995,5893
9996,-1
9997,7066
9998,1885


In [None]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched,signup_month_matched,treat_matched,pre_spends_matched,post_spends_matched,age_matched,gender_matched,industry_matched
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,,,,,,,,
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,5438.0,0.0,0.0,529.0,417.111111,23.0,F,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,5165.0,0.0,0.0,498.5,412.222222,25.0,F,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,,,,,,,,
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,539.0,0.0,0.0,531.0,414.000000,20.0,F,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5893.0,0.0,0.0,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,,,,,,,,
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,7066.0,0.0,0.0,480.0,423.222222,22.0,F,Logistics
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,1885.0,0.0,0.0,499.0,423.000000,67.0,F,Logistics


Finally, we can search for pairs using the L2 distance. 

In [None]:
# Same ATT estimation, but with the pair search distance set to "l2".
test = Matching(distance="l2", metric='att')
result = test.execute(data)

  return self.data.fillna(value=values, **kwargs)


In [None]:
result.resume

Unnamed: 0,Effect Size,Standard Error,P-value,CI Lower,CI Upper,outcome
ATT,63.37,0.46,0.0,62.46,64.27,post_spends


In [None]:
result.indexes

Unnamed: 0,indexes
0,-1
1,2490
2,5493
3,-1
4,321
...,...
9995,5893
9996,-1
9997,8670
9998,507


In [None]:
result.full_data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry,user_id_matched,signup_month_matched,treat_matched,pre_spends_matched,post_spends_matched,age_matched,gender_matched,industry_matched
0,0,0,0,488.0,414.444444,26.0,M,E-commerce,,,,,,,,
1,1,8,1,512.5,462.222222,26.0,M,E-commerce,2490.0,0.0,0.0,511.5,417.444444,27.0,F,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics,5493.0,0.0,0.0,483.0,408.000000,25.0,M,E-commerce
3,3,0,0,501.5,424.333333,39.0,M,E-commerce,,,,,,,,
4,4,1,1,543.0,514.555556,18.0,F,E-commerce,321.0,0.0,0.0,538.0,421.444444,29.0,M,E-commerce
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics,5893.0,0.0,0.0,535.0,414.555556,40.0,M,E-commerce
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics,,,,,,,,
9997,9997,3,1,473.0,534.111111,22.0,F,E-commerce,8670.0,0.0,0.0,473.0,415.777778,22.0,F,Logistics
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce,507.0,0.0,0.0,495.0,429.777778,67.0,F,Logistics
