# Matching pipeline

Matching is a statistical technique used to reduce the bias caused by differences in the baseline characteristics of the groups being compared. Simply put, matching helps ensure that the observed results are caused by the effect under study and not by external factors.

Matching is most often used when a standard A/B test is not possible.

In [21]:
import copy

from hypex.dataset import Dataset, InfoRole, TreatmentRole, TargetRole, DefaultRole, FeatureRole
from hypex.experiments.matching import Matching

## Data preparation 

It is important to mark the data fields by assigning the appropriate roles:

* FeatureRole: a role for columns that contain features or predictor variables. Pairs are matched on these columns. In this tutorial, every column without an explicitly specified role is converted to FeatureRole.
* TreatmentRole: a role for columns that show the treatment or intervention.
* TargetRole: a role for columns that show the target or outcome variable.
* InfoRole: a role for columns that contain information about the data, such as user IDs.

In [22]:
data = Dataset(
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int), 
        "post_spends": TargetRole(float)
    }, data="data.csv",
)
data.replace_roles({DefaultRole(): FeatureRole()}, auto_roles_types=True)
data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           1             8      1       512.5   462.222222  26.0    NaN   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  
0     E-commerce  
1     E-commerce  
2      Logistics  
3     E-com

In [23]:
data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}

## Simple Matching  
The matching pipeline currently consists of 4 steps: 
1. Dummy encoding of categorical features 
2. Mahalanobis distance computation 
3. Two-sided pair search with faiss 
4. Estimation of the metrics (ATT, ATC, ATE), depending on your data 
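
The four steps above can be illustrated with a minimal, standard-library sketch. It is not HypEx's implementation: the toy rows are made up, the Mahalanobis step is simplified to per-column rescaling by standard deviation (an assumption that ignores correlations between features), and a brute-force search stands in for faiss.

```python
from statistics import mean, pstdev

# Toy rows: (treat, pre_spends, gender, post_spends). Values are made up.
rows = [
    (1, 500.0, "M", 470.0),
    (1, 520.0, "F", 500.0),
    (0, 505.0, "M", 420.0),
    (0, 515.0, "F", 430.0),
    (0, 480.0, "M", 400.0),
]

# Step 1: dummy-encode the categorical feature (one 0/1 column, "M" is the baseline).
features = [[pre, 1.0 if gender == "F" else 0.0] for _, pre, gender, _ in rows]

# Step 2 (simplified): rescale each column by its standard deviation.
# A full Mahalanobis distance would also account for correlations between columns.
scales = [pstdev(col) or 1.0 for col in zip(*features)]
scaled = [[v / s for v, s in zip(row, scales)] for row in features]


def distance(a, b):
    """Euclidean distance between two rescaled feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


def nearest(i, candidates):
    """Step 3: brute-force nearest neighbour among the candidate row indexes."""
    return min(candidates, key=lambda j: distance(scaled[i], scaled[j]))


treated = [i for i, r in enumerate(rows) if r[0] == 1]
control = [i for i, r in enumerate(rows) if r[0] == 0]

# Step 4: average the outcome differences within matched pairs.
att = mean(rows[i][3] - rows[nearest(i, control)][3] for i in treated)
atc = mean(rows[nearest(i, treated)][3] - rows[i][3] for i in control)
ate = (att * len(treated) + atc * len(control)) / len(rows)

print(att, atc, ate)
```

The search runs in both directions (treated to control and control to treated), which is what makes the matching two-sided and lets all three metrics be estimated.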

In [24]:
data = data.fillna(method="bfill")
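
Backward fill replaces each missing value with the next non-missing value further down the same column. The sketch below illustrates that rule on a plain Python list; the `bfill` helper is hypothetical and written only for illustration, it is not part of HypEx or pandas.

```python
def bfill(values):
    """Replace each None with the next non-missing value further down the list."""
    filled = list(values)
    next_valid = None
    # Walk from the end so that the "next" valid value is always known.
    for i in range(len(filled) - 1, -1, -1):
        if filled[i] is None:
            filled[i] = next_valid
        else:
            next_valid = filled[i]
    return filled


ages = [None, 26.0, 25.0, None, None, 18.0, None]
print(bfill(ages))  # [26.0, 26.0, 25.0, 18.0, 18.0, 18.0, None]
```

A missing value at the very end of a column stays missing, because there is nothing below it to copy from.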

In [25]:
test = Matching()
result = test.execute(data)

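**ATT** (average treatment effect on the treated) is the estimated effect for the treated group.   
**ATC** (average treatment effect on the control) is the estimated effect for the untreated group.   
**ATE** (average treatment effect) is the weighted average of ATT and ATC.  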
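
One common way to write these matching estimators is given below. The notation is an assumption introduced here, not taken from HypEx: $W_i$ is the treatment flag, $Y_i$ the target, $m(i)$ the index of the unit matched to $i$ from the opposite group, and $N_T$, $N_C$ the sizes of the treated and control groups. Any bias correction the library may apply is not shown.

$$
\mathrm{ATT} = \frac{1}{N_T} \sum_{i:\, W_i = 1} \bigl( Y_i - Y_{m(i)} \bigr), \qquad
\mathrm{ATC} = \frac{1}{N_C} \sum_{i:\, W_i = 0} \bigl( Y_{m(i)} - Y_i \bigr)
$$

$$
\mathrm{ATE} = \frac{N_T \cdot \mathrm{ATT} + N_C \cdot \mathrm{ATC}}{N_T + N_C}
$$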

In [26]:
result.resume

         ATT        ATC        ATE
0  64.343019  96.393761  80.573515

In [27]:
result.indexes

      FaissNearestNeighbors┴┴
0                        9433
1                        5438
2                        5165
3                        1735
4                         539
...                       ...
9995                     5893
9996                     7731
9997                     7066
9998                     1885
9999                     5748

[10000 rows x 1 columns]

In [28]:
result.full_data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444  26.0      M   
1           1             8      1       512.5   462.222222  26.0      M   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  user_id_matched  signup_month_matched  treat_matched  \
0     E-comme

We can change the **metric** and run the estimation again.

In [29]:
test = Matching(metric="atc")
result = test.execute(data)

In [30]:
result.resume

         ATC
0  96.393761

In [31]:
result.indexes

      FaissNearestNeighbors┴┴
0                        9433
1                        5438
2                        5165
3                        1735
4                         539
...                       ...
9995                     5893
9996                     7731
9997                     7066
9998                     1885
9999                     5748

[10000 rows x 1 columns]

In [32]:
result.full_data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444  26.0      M   
1           1             8      1       512.5   462.222222  26.0      M   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  user_id_matched  signup_month_matched  treat_matched  \
0     E-comme

It is also possible to search for pairs only for the **test group**. In this case the metric keeps its "auto" value, and only **ATT** is estimated. 

In [33]:
test = Matching(two_sides=False)
result = test.execute(data)

In [34]:
result.resume

         ATT
0  64.343019

In [35]:
result.indexes

      FaissNearestNeighbors┴┴
0                          -1
1                        5438
2                        5165
3                          -1
4                         539
...                       ...
9995                     5893
9996                       -1
9997                     7066
9998                     1885
9999                     5748

[10000 rows x 1 columns]
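
In the index table above, `-1` appears in rows where `treat` is 0, which suggests that it marks rows for which no pair was searched. Under that assumption (it is inferred from the output, not from the HypEx documentation), an ATT computed from such an index column could be sketched as follows, using made-up toy values:

```python
# Toy data: the outcome per row and the index of the matched row (-1 = no pair).
post_spends = [400.0, 460.0, 480.0, 420.0, 510.0]
matched_index = [-1, 3, 0, -1, 3]

# Keep only rows that actually have a matched pair.
pairs = [(i, j) for i, j in enumerate(matched_index) if j != -1]

# ATT: mean outcome difference between each treated row and its matched control row.
att = sum(post_spends[i] - post_spends[j] for i, j in pairs) / len(pairs)
print(att)  # 70.0
```

Several treated rows may share the same matched control row, as rows 1 and 4 do here.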

In [36]:
result.full_data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444  26.0      M   
1           1             8      1       512.5   462.222222  26.0      M   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  user_id_matched  signup_month_matched  treat_matched  \
0     E-comme

Finally, we can search for pairs using the L2 (Euclidean) distance instead of the Mahalanobis distance. 

In [37]:
test = Matching(distance="l2", two_sides=False)
result = test.execute(data)
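
The choice of distance can change which row counts as the nearest neighbour, which is why the estimate may differ between the two settings. The sketch below compares the two distances on made-up two-feature data. It is a standalone illustration with hypothetical helper names, not the way HypEx computes distances internally.

```python
import math

# Toy feature rows (pre_spends, age); the values are made up for illustration.
points = [(480, 20), (500, 40), (520, 30), (540, 50), (460, 25)]


def covariance_2d(pts):
    """Sample covariance terms (sxx, sxy, syy) of the two columns."""
    n = len(pts)
    mx = sum(p[0] for p in pts) / n
    my = sum(p[1] for p in pts) / n
    sxx = sum((p[0] - mx) ** 2 for p in pts) / (n - 1)
    syy = sum((p[1] - my) ** 2 for p in pts) / (n - 1)
    sxy = sum((p[0] - mx) * (p[1] - my) for p in pts) / (n - 1)
    return sxx, sxy, syy


def mahalanobis_2d(a, b, cov):
    """sqrt(d^T S^-1 d), with the inverse of the 2x2 matrix S written out by hand."""
    sxx, sxy, syy = cov
    det = sxx * syy - sxy ** 2
    dx, dy = a[0] - b[0], a[1] - b[1]
    return math.sqrt((syy * dx * dx - 2 * sxy * dx * dy + sxx * dy * dy) / det)


cov = covariance_2d(points)
query = (500, 30)
candidates = {"A": (510, 30), "B": (500, 38)}

nearest_l2 = min(candidates, key=lambda k: math.dist(query, candidates[k]))
nearest_mahalanobis = min(
    candidates, key=lambda k: mahalanobis_2d(query, candidates[k], cov)
)
print(nearest_l2, nearest_mahalanobis)  # B A
```

With L2, a gap of 8 years in age looks smaller than a gap of 10 units in spending. The Mahalanobis distance rescales by the covariance of the features, so the spending gap, which is small relative to how much spending varies, becomes the closer match.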

In [38]:
result.resume

         ATT
0  64.182315

In [39]:
result.indexes

      FaissNearestNeighbors┴┴
0                          -1
1                        2490
2                        5493
3                          -1
4                         321
...                       ...
9995                     5893
9996                       -1
9997                     8670
9998                      507
9999                     7155

[10000 rows x 1 columns]

In [40]:
result.full_data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444  26.0      M   
1           1             8      1       512.5   462.222222  26.0      M   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  user_id_matched  signup_month_matched  treat_matched  \
0     E-comme