In [1]:
import fair_forge as ff

In [2]:
# Load the Adult dataset through fair_forge, selecting AdultGroup.SEX as the
# attribute that defines the groups used by the group-aware metrics later on.
adult = ff.load_adult(group=ff.AdultGroup.SEX)

In [3]:
# Inspect the feature array of the loaded dataset (numpy truncates the repr,
# so only the corners plus shape/dtype are shown).
adult.data

array([[39.,  0.,  0., ...,  1.,  0.,  0.],
       [50.,  0.,  0., ...,  1.,  0.,  0.],
       [38.,  0.,  0., ...,  1.,  0.,  0.],
       ...,
       [38.,  0.,  0., ...,  1.,  0.,  0.],
       [44.,  0.,  0., ...,  1.,  0.,  0.],
       [35.,  0.,  0., ...,  1.,  0.,  0.]],
      shape=(45222, 101), dtype=float32)

In [4]:
# List the names of the feature columns.
# NOTE(review): entries of the form `column:value` (e.g. 'workclass:Private')
# look like one-hot encoded categoricals — confirm against the loader.
adult.feature_names

['age',
 'workclass:Federal-gov',
 'workclass:Local-gov',
 'workclass:Private',
 'workclass:Self-emp-inc',
 'workclass:Self-emp-not-inc',
 'workclass:State-gov',
 'workclass:Without-pay',
 'education:10th',
 'education:11th',
 'education:12th',
 'education:1st-4th',
 'education:5th-6th',
 'education:7th-8th',
 'education:9th',
 'education:Assoc-acdm',
 'education:Assoc-voc',
 'education:Bachelors',
 'education:Doctorate',
 'education:HS-grad',
 'education:Masters',
 'education:Preschool',
 'education:Prof-school',
 'education:Some-college',
 'education-num',
 'marital-status:Divorced',
 'marital-status:Married-AF-spouse',
 'marital-status:Married-civ-spouse',
 'marital-status:Married-spouse-absent',
 'marital-status:Never-married',
 'marital-status:Separated',
 'marital-status:Widowed',
 'occupation:Adm-clerical',
 'occupation:Armed-Forces',
 'occupation:Craft-repair',
 'occupation:Exec-managerial',
 'occupation:Farming-fishing',
 'occupation:Handlers-cleaners',
 'occupation:Machine-op

In [5]:
# Inspect the prediction target array of the dataset.
adult.target

array([0, 0, 0, ..., 0, 0, 1], shape=(45222,), dtype=int32)

In [6]:
# Inspect the per-sample group labels.
# NOTE(review): with group=AdultGroup.SEX these presumably encode sex —
# TODO confirm which integer corresponds to which group.
adult.groups

array([1, 1, 1, ..., 1, 1, 1], shape=(45222,), dtype=int32)

In [7]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

In [8]:
# All logistic-regression variants share the same solver setting; every entry
# below still builds its own, independent estimator instance.
lr_kwargs = {"solver": "liblinear"}

# Methods to compare, keyed by the label that ends up in the `method` column.
candidate_methods = {
    "Upsampler LR": ff.GroupPipeline(ff.Upsampler(), LogisticRegression(**lr_kwargs)),
    "LR": LogisticRegression(**lr_kwargs),
    "SVM": LinearSVC(),
    "Reweighting": ff.Reweighting(LogisticRegression(**lr_kwargs)),
}

# Group-aware metrics derived from plain metrics via `as_group_metric`:
# accuracy aggregated with MIN_MAX | INDIVIDUAL, and tpr / prob_pos with DIFF_RATIO.
accuracy_by_group = ff.as_group_metric(
    (accuracy_score,),
    ff.MetricAgg.MIN_MAX | ff.MetricAgg.INDIVIDUAL,
)
tpr_and_prob_pos_gaps = ff.as_group_metric(
    (ff.tpr, ff.prob_pos),
    ff.MetricAgg.DIFF_RATIO,
)

# Evaluate every method on the Adult data, twice (repeat=2), with features
# passed through a StandardScaler preprocessor.
results = ff.evaluate(
    dataset=adult,
    methods=candidate_methods,
    metrics=[accuracy_score],
    group_metrics=(ff.cv, *accuracy_by_group, *tpr_and_prob_pos_gaps),
    repeat=2,
    preprocessor=StandardScaler(),
)

In [9]:
# Show the raw evaluation table: one row per (method, repeat) combination,
# with the overall and per-group metric columns.
results

method,repeat_index,split_seed,accuracy,cv,accuracy_min,accuracy_max,accuracy_0,accuracy_1,tpr_diff,tpr_ratio,tpr_0,tpr_1,prob_pos_diff,prob_pos_ratio,prob_pos_0,prob_pos_1
enum,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""Upsampler LR""",0,42,0.841017,0.906259,0.804127,0.917659,0.917659,0.804127,-0.14026,1.264443,0.670659,0.530398,0.093741,0.563733,0.12113,0.214871
"""Upsampler LR""",1,43,0.83969,0.899179,0.806584,0.908472,0.908472,0.806584,-0.081723,1.15022,0.625749,0.544025,0.100821,0.543651,0.120109,0.22093
"""LR""",0,42,0.846545,0.821915,0.808877,0.924804,0.924804,0.808877,0.093743,0.844468,0.508982,0.602725,0.178085,0.302508,0.077237,0.255323
"""LR""",1,43,0.84754,0.817695,0.809532,0.926506,0.926506,0.809532,0.091201,0.851017,0.520958,0.612159,0.182305,0.300341,0.078258,0.260563
"""SVM""",0,42,0.846434,0.827596,0.80904,0.924124,0.924124,0.80904,0.091719,0.844996,0.5,0.591719,0.172404,0.305607,0.075876,0.24828
"""SVM""",1,43,0.846987,0.82262,0.809532,0.924804,0.924804,0.809532,0.089852,0.851441,0.51497,0.604822,0.17738,0.307051,0.078598,0.255978
"""Reweighting""",0,42,0.840354,0.908301,0.802817,0.91834,0.91834,0.802817,-0.154333,1.29213,0.682635,0.528302,0.091699,0.573234,0.123171,0.214871
"""Reweighting""",1,43,0.83958,0.898989,0.806092,0.909153,0.909153,0.806092,-0.083296,1.153554,0.625749,0.542453,0.101011,0.541775,0.119428,0.220439


In [10]:
# Average every column over the repeats of each method.
# maintain_order=True keeps the groups in order of first appearance; without it
# the row order of a group_by result is not guaranteed, so the summary table
# could come out in a different order on each run.
# NOTE(review): repeat_index and split_seed get averaged as well (0.5 / 42.5);
# they are bookkeeping columns rather than metrics — consider dropping them
# before aggregating.
results.group_by("method", maintain_order=True).mean()

method,repeat_index,split_seed,accuracy,cv,accuracy_min,accuracy_max,accuracy_0,accuracy_1,tpr_diff,tpr_ratio,tpr_0,tpr_1,prob_pos_diff,prob_pos_ratio,prob_pos_0,prob_pos_1
enum,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""Upsampler LR""",0.5,42.5,0.840354,0.902719,0.805355,0.913066,0.913066,0.805355,-0.110992,1.207332,0.648204,0.537212,0.097281,0.553692,0.120619,0.2179
"""LR""",0.5,42.5,0.847043,0.819805,0.809204,0.925655,0.925655,0.809204,0.092472,0.847742,0.51497,0.607442,0.180195,0.301425,0.077748,0.257943
"""SVM""",0.5,42.5,0.846711,0.825108,0.809286,0.924464,0.924464,0.809286,0.090785,0.848218,0.507485,0.59827,0.174892,0.306329,0.077237,0.252129
"""Reweighting""",0.5,42.5,0.839967,0.903645,0.804455,0.913746,0.913746,0.804455,-0.118814,1.222842,0.654192,0.535377,0.096355,0.557505,0.1213,0.217655
