# Matching pipeline

Matching is a statistical method used to eliminate distortions caused by differences in the baseline characteristics of the groups being compared. Simply put, matching helps ensure that the observed results are actually caused by the effect under study rather than by external factors.

Matching is most often used when a standard A/B test is not possible, for example when the treatment was not assigned at random.

In [1]:
```python
import copy

from hypex.dataset import Dataset, InfoRole, TreatmentRole, TargetRole
from hypex.experiments.matching import Matching
```

## Data preparation 

It is important to mark the data columns by assigning them the appropriate roles:

* FeatureRole: a role for columns that contain features or predictor variables. Matching pairs are searched for based on these columns. This role is applied by default if no role is specified for a column.
* TreatmentRole: a role for columns that show the treatment or intervention.
* TargetRole: a role for columns that show the target or outcome variable.
* InfoRole: a role for columns that contain information about the data, such as user IDs.
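The default-role rule can be illustrated with a minimal sketch, assuming roles are modeled as plain strings. `assign_roles` is a hypothetical helper, not part of HypEx; it only mirrors the behaviour described in the list, where any column without an explicit role becomes a feature.

```python
# Hypothetical helper, not HypEx code: roles are modeled as plain strings.
def assign_roles(columns, explicit_roles):
    """Give every column a role; columns without an explicit role become features."""
    return {col: explicit_roles.get(col, "Feature") for col in columns}


columns = ["user_id", "signup_month", "treat", "pre_spends",
           "post_spends", "age", "gender", "industry"]
explicit = {"user_id": "Info", "treat": "Treatment", "post_spends": "Target"}

roles = assign_roles(columns, explicit)
features = [col for col, role in roles.items() if role == "Feature"]
print(features)  # ['signup_month', 'pre_spends', 'age', 'gender', 'industry']
```

The resulting feature set is the same one HypEx reports in `data.roles` for this dataset.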

In [2]:
```python
# Columns without an explicit role get FeatureRole by default
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "post_spends": TargetRole(float),
    },
    data="data.csv",
)
data
```

```
      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           1             8      1       512.5   462.222222  26.0    NaN   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  
0     E-commerce  
1     E-commerce  
2      Logistics  
3     E-com
```

In [3]:
```python
data.roles
```

```
{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}
```

## Simple Matching  
The matching pipeline consists of four steps:
1. Dummy encoding of categorical features.
2. Computation of the Mahalanobis distance.
3. Two-sided pair search with faiss.
4. Estimation of the metrics (ATT, ATC, ATE), depending on your data.
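The four steps can be sketched in plain Python. This is a minimal illustration under stated assumptions, not HypEx's implementation: a brute-force search stands in for faiss, each unit gets exactly one nearest neighbour, categorical columns are dummy-encoded with the first level dropped, and ATE is taken as the size-weighted average of ATT and ATC. All function names and the toy data are hypothetical.

```python
# Minimal sketch of the four matching steps; NOT HypEx's implementation.
# Assumptions: brute-force search stands in for faiss, each unit gets one
# nearest neighbour, and ATE is the size-weighted average of ATT and ATC.


def dummy_encode(rows, column):
    """Step 1: replace a categorical column with 0/1 columns (first level dropped)."""
    levels = sorted({r[column] for r in rows})
    for r in rows:
        value = r.pop(column)
        for level in levels[1:]:
            r[f"{column}_{level}"] = 1.0 if value == level else 0.0
    return rows


def covariance(X):
    """Sample covariance matrix of a list of feature vectors."""
    n, d = len(X), len(X[0])
    means = [sum(x[j] for x in X) / n for j in range(d)]
    return [[sum((x[i] - means[i]) * (x[j] - means[j]) for x in X) / (n - 1)
             for j in range(d)] for i in range(d)]


def invert(m):
    """Invert a square matrix by Gauss-Jordan elimination with partial pivoting."""
    d = len(m)
    a = [list(row) + [1.0 if i == j else 0.0 for j in range(d)]
         for i, row in enumerate(m)]
    for col in range(d):
        pivot = max(range(col, d), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        p = a[col][col]
        a[col] = [v / p for v in a[col]]
        for r in range(d):
            if r != col:
                f = a[r][col]
                a[r] = [vr - f * vc for vr, vc in zip(a[r], a[col])]
    return [row[d:] for row in a]


def mahalanobis(x, y, inv_cov):
    """Step 2: Mahalanobis distance between two feature vectors."""
    diff = [a - b for a, b in zip(x, y)]
    q = sum(diff[i] * inv_cov[i][j] * diff[j]
            for i in range(len(diff)) for j in range(len(diff)))
    return max(q, 0.0) ** 0.5


def match_and_estimate(rows, features, treat_col, target_col):
    X = [[float(r[f]) for f in features] for r in rows]
    inv_cov = invert(covariance(X))
    y = [r[target_col] for r in rows]
    treated = [i for i, r in enumerate(rows) if r[treat_col] == 1]
    control = [i for i, r in enumerate(rows) if r[treat_col] == 0]

    def nearest(i, candidates):
        # Step 3: brute-force nearest neighbour (faiss does this efficiently at scale)
        return min(candidates, key=lambda j: mahalanobis(X[i], X[j], inv_cov))

    # Step 4: two-sided search gives ATT (treated -> control) and ATC (control -> treated)
    att = sum(y[i] - y[nearest(i, control)] for i in treated) / len(treated)
    atc = sum(y[nearest(i, treated)] - y[i] for i in control) / len(control)
    ate = (len(treated) * att + len(control) * atc) / len(rows)
    return {"ATT": att, "ATC": atc, "ATE": ate}


# Made-up toy data: every treated unit has a control twin, and the true effect is +10.
rows = [
    {"age": 20, "gender": "M", "treat": 1, "post_spends": 50},
    {"age": 20, "gender": "M", "treat": 0, "post_spends": 40},
    {"age": 35, "gender": "F", "treat": 1, "post_spends": 80},
    {"age": 35, "gender": "F", "treat": 0, "post_spends": 70},
    {"age": 50, "gender": "M", "treat": 1, "post_spends": 110},
    {"age": 50, "gender": "M", "treat": 0, "post_spends": 100},
]
rows = dummy_encode(rows, "gender")
result = match_and_estimate(rows, ["age", "gender_M"], "treat", "post_spends")
print(result)  # {'ATT': 10.0, 'ATC': 10.0, 'ATE': 10.0}
```

Because every treated unit has an exact control twin, each unit's nearest neighbour is its twin, so all three estimates recover the built-in effect of 10. Dropping the first dummy level avoids perfectly collinear columns, which would make the covariance matrix singular.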

In [4]:
```python
c_data = copy.deepcopy(data)  # run on a copy, leaving the original `data` unchanged
test = Matching()
result = test.execute(c_data)
```

In [5]:
```python
result.full_data
```

```
      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           0             0      0       488.0   414.444444   NaN      M   
2           0             0      0       488.0   414.444444   NaN      M   
3           0             0      0       488.0   414.444444   NaN      M   
4           0             0      0       488.0   414.444444   NaN      M   
...       ...           ...    ...         ...          ...   ...    ...   
9995     5053            10      1       497.5   440.888889  65.0      F   
9996     5053            10      1       497.5   440.888889  65.0      F   
9997     5053            10      1       497.5   440.888889  65.0      F   
9998     5062             0      0       463.0   425.777778  29.0      M   
9999     5062             0      0       463.0   425.777778  29.0      M   

      industry  user_id_matched  signup_month_matched  treat_matched  \
0         True
```

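**ATT** (average treatment effect on the treated) is the average effect in the treated group.  
**ATC** (average treatment effect on the controls) is the average effect in the untreated (control) group.  
**ATE** (average treatment effect) is the average of ATT and ATC, weighted by the sizes of the treated and control groups.  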
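Assuming the standard one-nearest-neighbour matching estimator (an assumption about what the library computes, not confirmed by this notebook), the three estimates can be written as follows, where $T_i$ is the treatment indicator, $Y_i$ the target, $m(i)$ the unit matched to $i$, and $N_T$, $N_C$ the sizes of the treated and control groups:

$$
\widehat{\mathrm{ATT}} = \frac{1}{N_T}\sum_{i:\,T_i=1}\bigl(Y_i - Y_{m(i)}\bigr), \qquad
\widehat{\mathrm{ATC}} = \frac{1}{N_C}\sum_{i:\,T_i=0}\bigl(Y_{m(i)} - Y_i\bigr),
$$

$$
\widehat{\mathrm{ATE}} = \frac{N_T\,\widehat{\mathrm{ATT}} + N_C\,\widehat{\mathrm{ATC}}}{N_T + N_C}.
$$

Because ATE is a weighted average of ATT and ATC, it always lies between them, which agrees with the values reported by `result.resume` in the next cell.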

In [6]:
```python
result.resume
```

```
{'ATT': 209.90982351881865,
 'ATC': 268.1156310338775,
 'ATE': 239.38524444444445}
```

We can change the **metric** and run the estimation again.

In [7]:
```python
c_data = copy.deepcopy(data)
test = Matching(metric="atc")  # estimate only ATC
result = test.execute(c_data)
```

In [8]:
```python
result.resume
```

```
{'ATC': 257.18683078813416}
```

In [9]:
```python
result.full_data
```

```
      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           0             0      0       488.0   414.444444   NaN      M   
2           0             0      0       488.0   414.444444   NaN      M   
3           0             0      0       488.0   414.444444   NaN      M   
4           0             0      0       488.0   414.444444   NaN      M   
...       ...           ...    ...         ...          ...   ...    ...   
9995     8592            10      1       506.5   429.888889  52.0      M   
9996     9040             7      1       489.0   485.777778   NaN      F   
9997     9790             1      1       521.5   519.888889   NaN      F   
9998     9910             5      1       480.5   499.222222   NaN      F   
9999     9950             2      1       520.5   522.777778   NaN      M   

      industry  user_id_matched  signup_month_matched  treat_matched  \
0         True
```

It is also possible to search for pairs only for the **test group**. In this case the default metric `"auto"` estimates only **ATT**.
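Why a one-sided search yields only ATT follows from what each estimate needs under one-nearest-neighbour matching: ATT needs a control match for every treated unit, ATC needs a treated match for every control unit, and ATE needs both. The hypothetical helper below (not part of HypEx) encodes that rule; it is consistent with the outputs in this notebook, where `two_sides=False` returns only ATT and `metric="atc"` returns only ATC.

```python
# Hypothetical illustration, not HypEx code: under 1-NN matching, each
# estimand needs matches for a particular group of units.
REQUIRED_MATCHES = {
    "ATT": {"treated"},             # each treated unit needs a control match
    "ATC": {"control"},             # each control unit needs a treated match
    "ATE": {"treated", "control"},  # both directions are needed
}


def estimable(matched_groups):
    """Return the estimands computable when only `matched_groups` received matches."""
    return sorted(name for name, need in REQUIRED_MATCHES.items()
                  if need <= set(matched_groups))


print(estimable({"treated"}))             # ['ATT']
print(estimable({"treated", "control"}))  # ['ATC', 'ATE', 'ATT']
```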

In [10]:
```python
c_data = copy.deepcopy(data)
test = Matching(two_sides=False)  # search pairs only for the test (treated) group
result = test.execute(c_data)
```

In [11]:
```python
result.resume
```

```
{'ATT': 209.90982351881865}
```

In [12]:
```python
result.full_data
```

```
      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           0             0      0       488.0   414.444444   NaN      M   
2           0             0      0       488.0   414.444444   NaN      M   
3           0             0      0       488.0   414.444444   NaN      M   
4           0             0      0       488.0   414.444444   NaN      M   
...       ...           ...    ...         ...          ...   ...    ...   
9995     4931             0      0       480.5   426.777778  48.0    NaN   
9996     4931             0      0       480.5   426.777778  48.0    NaN   
9997     4931             0      0       480.5   426.777778  48.0    NaN   
9998     4931             0      0       480.5   426.777778  48.0    NaN   
9999     4931             0      0       480.5   426.777778  48.0    NaN   

      industry  user_id_matched  signup_month_matched  treat_matched  \
0         True
```

Finally, we can search for pairs using the L2 (Euclidean) distance instead of the Mahalanobis distance.
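The choice of distance matters when features have very different scales. In the toy example below, the features and values are made up and the covariance matrix is simply assumed rather than estimated from data. L2 picks the control that is closer in raw units, while the Mahalanobis distance, which rescales differences by the feature covariance, picks the other one.

```python
import math


def l2(x, y):
    """Euclidean (L2) distance."""
    return math.dist(x, y)


def mahalanobis_2d(x, y, cov):
    """Mahalanobis distance for 2-D points, using the closed-form 2x2 inverse."""
    (a, b), (c, d) = cov
    det = a * d - b * c
    inv = [[d / det, -b / det], [-c / det, a / det]]
    dx, dy = x[0] - y[0], x[1] - y[1]
    q = dx * (inv[0][0] * dx + inv[0][1] * dy) + dy * (inv[1][0] * dx + inv[1][1] * dy)
    return math.sqrt(q)


# Made-up data: feature 1 has variance 1, feature 2 has variance 100 (assumed).
treated = (0.0, 0.0)
controls = {"A": (3.0, 0.0), "B": (0.0, 4.0)}
cov = [[1.0, 0.0], [0.0, 100.0]]

nearest_l2 = min(controls, key=lambda k: l2(treated, controls[k]))
nearest_mah = min(controls, key=lambda k: mahalanobis_2d(treated, controls[k], cov))
print(nearest_l2, nearest_mah)  # A B
```

L2 treats one unit of every feature as equally important, so a high-variance feature dominates it. Mahalanobis distance removes this scale dependence and also accounts for correlations between features, so the two settings can produce different pairs and therefore different estimates.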

In [13]:
```python
c_data = copy.deepcopy(data)
test = Matching(distance="l2", two_sides=False)  # L2 distance, one-sided search
result = test.execute(c_data)
```

In [14]:
```python
result.resume
```

```
{'ATT': 187.37531514496666}
```

In [15]:
```python
result.full_data
```

```
      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           0             0      0       488.0   414.444444   NaN      M   
2           0             0      0       488.0   414.444444   NaN      M   
3           0             0      0       488.0   414.444444   NaN      M   
4           0             0      0       488.0   414.444444   NaN      M   
...       ...           ...    ...         ...          ...   ...    ...   
9995     4931             0      0       480.5   426.777778  48.0    NaN   
9996     4931             0      0       480.5   426.777778  48.0    NaN   
9997     4932             0      0       489.0   417.111111  60.0      F   
9998     4934             1      1       562.0   538.333333  64.0      M   
9999     4934             1      1       562.0   538.333333  64.0      M   

      industry  user_id_matched  signup_month_matched  treat_matched  \
0         True
```