In [1]:
from hypex.comparators.distances import MahalanobisDistance
from hypex.dataset import Dataset, ExperimentData, InfoRole, TreatmentRole, TargetRole
from hypex.ml.faiss import FaissNearestNeighbors

In [2]:
# Load the raw CSV and declare roles for the non-feature columns.
# Columns not listed here receive a default role (Feature in this run —
# see the `data.roles` cell below).
# Note: the rendered preview shows missing values in `age` and `gender`.
data = Dataset(
    data="data.csv",
    roles={
        "user_id": InfoRole(int),
        "treat": TreatmentRole(int),
        "gender": InfoRole(),
        "industry": InfoRole(),
    },
)
data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           1             8      1       512.5   462.222222  26.0    NaN   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  
0     E-commerce  
1     E-commerce  
2      Logistics  
3     E-com

In [3]:
# Inspect the role assignment. Columns not declared in the Dataset constructor
# (signup_month, pre_spends, post_spends, age) appear as Feature roles in the
# output, and the untyped InfoRole() columns (gender, industry) resolved to str.
data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'gender': Info(<class 'str'>),
 'industry': Info(<class 'str'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'post_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>)}

In [13]:
# Transform the Feature columns with MahalanobisDistance, split by the
# treatment role. MahalanobisDistance is imported in the first (imports) cell.
# NOTE(review): `age` contains NaN (see the preview of `data`), and the
# transformed output has all-NaN rows for some units (e.g. row 0, whose `age`
# is missing); presumably these should be imputed or dropped before
# matching — confirm.
test_maha = MahalanobisDistance(grouping_role=TreatmentRole())
res_maha = test_maha.execute(ExperimentData(data))

In [14]:
# Transformed feature matrices per group ('control' / 'test'), keyed by the
# comparator name and the list of feature columns it used.
# NOTE(review): several control rows (e.g. 0, 10, 9990) are entirely NaN.
# Row 0 has a missing `age` in the raw data, so NaN features presumably
# propagate through the transform — verify before using these rows for matching.
res_maha.groups

{"MahalanobisDistance┴┴['signup_month', 'pre_spends', 'post_spends', 'age']": {'control':                  0          1          2         3
  0              NaN        NaN        NaN       NaN
  3    -4.260553e-17  27.785702  57.098665  2.154735
  10             NaN        NaN        NaN       NaN
  12   -4.254975e-17  26.151249  56.739974  2.353910
  13   -4.259438e-17  28.173539  57.153914  3.014764
  ...            ...        ...        ...       ...
  9990           NaN        NaN        NaN       NaN
  9991 -4.236010e-17  26.733004  56.611141  2.421879
  9992 -4.257207e-17  27.231650  56.959249  1.493617
  9994 -4.254975e-17  26.926922  56.877790  4.139897
  9996 -4.326375e-17  27.730297  57.894674  1.288258
  
  [4936 rows x 4 columns],
  'test':              0          1          2         3
  1     3.602511  29.381883  71.667727  1.204128
  2     3.152197  27.624089  72.268996  1.127620
  4     0.450314  30.208358  69.823313  0.668942
  5     2.701883  27.694667  71.952193  2.

In [15]:
# Nearest-neighbour matching on the Mahalanobis-transformed experiment data.
# two_sides=True presumably requests matching in both directions
# (test -> control and control -> test) — confirm against the hypex docs.
test = FaissNearestNeighbors(two_sides=True, grouping_role=TreatmentRole())
result = test.execute(res_maha)

In [17]:
# Matching output stored by FaissNearestNeighbors.
# NOTE(review): the array holds only 17 values, and entries such as 1908874354
# and the negative numbers cannot be row positions in a 10,000-row dataset.
# This looks like garbage caused by the all-NaN rows in `res_maha` —
# investigate before relying on it.
result.additional_fields

{'FaissNearestNeighbors┴┴': array([       4931,  1908874354,  -954437177, -1908874354,           0,
         1431655765,   477218588,   954437177,  -477218588, -1431655765,
               4810, -1681010624,        4995,        5007,        5032,
               5053,        4969])}

In [8]:
# Match directly on the raw feature columns of `data` (no Mahalanobis step).
# TODO: check with different neighbour numbers — change N_NEIGHBORS and re-run.
FEATURES = ['signup_month', 'pre_spends', 'post_spends', 'age']
N_NEIGHBORS = 3

res = FaissNearestNeighbors.calc(
    data,
    features_fields=FEATURES,
    group_field='treat',
    n_neighbors=N_NEIGHBORS,
)

In [9]:
# One column per requested neighbour (matched_indexes_0, matched_indexes_1, ...).
# NOTE(review): 5,064 rows = 10,000 - 4,936 control rows. Each row is therefore
# presumably one treated (treat=1) unit, holding the positions of its matched
# control units — verify.
res

      matched_indexes_0  matched_indexes_1  matched_indexes_2
0                  2131               2266                507
1                  4915               2183               1850
2                  1784               2131                529
3                  4595               2131               2618
4                  3397               3908               1720
...                 ...                ...                ...
5059               2522               2618               3243
5060                529               2705               2551
5061               2522               1909               1461
5062               2618               1110                244
5063               2131               4298                608

[5064 rows x 3 columns]