In [7]:
from fair_util import BMInterface
from fair_grid import BMGridSearch
from fair_bm import BMType
import pandas as pd

import mkl
# Ask MKL to use 5 threads for its computations.
mkl.set_num_threads(5)

# Name of the label (target) column.
label_name = 'stroke'
# Names of the sensitive attribute columns.
sensitive_attribute = ['age']

# Group definitions keyed by the sensitive attribute: 'age' == 1 is the
# privileged group and 'age' == 0 the unprivileged one.
privileged_group = [{'age': 1}]
unprivileged_group = [{'age': 0}]

# Load the train / validation / test splits, in that order.
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        './data/stroke_pre/stroke_train.csv',
        './data/stroke_pre/stroke_val.csv',
        './data/stroke_pre/stroke_test.csv',
    )
)

# Wrap the three splits together with the label and sensitive attribute names.
bmI = BMInterface(train_df, val_df, test_df, label_name, sensitive_attribute)

In [4]:
from xgboost import XGBClassifier
from sklearn.svm import LinearSVC  
from sklearn.linear_model import LogisticRegression
from pytorch_tabnet.tab_model import TabNetClassifier
from catboost import CatBoostClassifier
from sklearn.neural_network import MLPClassifier

# Candidate classifiers, defined in the same order they appear in model_list.

# TabNet with its default settings.
tabnet = TabNetClassifier()

# XGBoost classifier.
xgbr = XGBClassifier(
    n_estimators=500,
    max_depth=8,
    learning_rate=0.01,
    gamma=0.25,
    subsample=0.8,
    colsample_bytree=0.3,
)

# CatBoost classifier evaluated on accuracy.
cat = CatBoostClassifier(
    iterations=10,
    depth=4,
    learning_rate=0.01,
    eval_metric='Accuracy',
)

# Multi-layer perceptron with four hidden layers of 200 units each.
mlp = MLPClassifier(
    hidden_layer_sizes=(200, 200, 200, 200),
    activation='logistic',
    solver='adam',
    alpha=0.01,
    learning_rate='adaptive',
    max_iter=500,
)

# Logistic regression using the liblinear solver.
lr = LogisticRegression(solver='liblinear')

model_list = [tabnet, xgbr, cat, mlp, lr]



In [None]:
# Show the model's string form (useful when checking its class-name prefix).
str(tabnet)

In [9]:
from itertools import product
from more_itertools import powerset

# Mitigation steps, grouped by the "pre" / "pos" prefix of their BMType names.
pre_bm = [BMType.preReweighing, BMType.preDisparate]
pos_bm = [BMType.posCalibrated, BMType.posEqqOds, BMType.posROC]

# Every (pre, pos) pairing first, then each pre step alone, then each pos step alone.
bm_list = [
    *(list(pair) for pair in product(pre_bm, pos_bm)),
    *([step] for step in pre_bm),
    *([step] for step in pos_bm),
]
bm_list

[[<BMType.preReweighing: 1>, <BMType.posCalibrated: 6>],
 [<BMType.preReweighing: 1>, <BMType.posEqqOds: 7>],
 [<BMType.preReweighing: 1>, <BMType.posROC: 8>],
 [<BMType.preDisparate: 2>, <BMType.posCalibrated: 6>],
 [<BMType.preDisparate: 2>, <BMType.posEqqOds: 7>],
 [<BMType.preDisparate: 2>, <BMType.posROC: 8>],
 [<BMType.preReweighing: 1>],
 [<BMType.preDisparate: 2>],
 [<BMType.posCalibrated: 6>],
 [<BMType.posEqqOds: 7>],
 [<BMType.posROC: 8>]]

In [17]:
import copy

# Build the extended grid in three parts, in this order:
#   1. every combination from bm_list with BMType.inAdversarial appended,
#   2. every combination from bm_list unchanged,
#   3. BMType.inAdversarial on its own.
# Each inner list is freshly created, so bm_list_in shares no list objects
# with bm_list and mutating one can never affect the other.
bm_list_in = (
    [[*combo, BMType.inAdversarial] for combo in bm_list]
    + [list(combo) for combo in bm_list]
    + [[BMType.inAdversarial]]
)
bm_list_in

[[<BMType.preReweighing: 1>,
  <BMType.posCalibrated: 6>,
  <BMType.inAdversarial: 4>],
 [<BMType.preReweighing: 1>, <BMType.posEqqOds: 7>, <BMType.inAdversarial: 4>],
 [<BMType.preReweighing: 1>, <BMType.posROC: 8>, <BMType.inAdversarial: 4>],
 [<BMType.preDisparate: 2>,
  <BMType.posCalibrated: 6>,
  <BMType.inAdversarial: 4>],
 [<BMType.preDisparate: 2>, <BMType.posEqqOds: 7>, <BMType.inAdversarial: 4>],
 [<BMType.preDisparate: 2>, <BMType.posROC: 8>, <BMType.inAdversarial: 4>],
 [<BMType.preReweighing: 1>, <BMType.inAdversarial: 4>],
 [<BMType.preDisparate: 2>, <BMType.inAdversarial: 4>],
 [<BMType.posCalibrated: 6>, <BMType.inAdversarial: 4>],
 [<BMType.posEqqOds: 7>, <BMType.inAdversarial: 4>],
 [<BMType.posROC: 8>, <BMType.inAdversarial: 4>],
 [<BMType.preReweighing: 1>, <BMType.posCalibrated: 6>],
 [<BMType.preReweighing: 1>, <BMType.posEqqOds: 7>],
 [<BMType.preReweighing: 1>, <BMType.posROC: 8>],
 [<BMType.preDisparate: 2>, <BMType.posCalibrated: 6>],
 [<BMType.preDisparate: 

In [None]:
# Run the bm_list grid once for each candidate classifier.
for model in model_list:
    bmG = BMGridSearch(
        bmI,
        model=model,
        bm_list=bm_list,
        privileged_group=privileged_group,
        unprivileged_group=unprivileged_group,
    )
    bmG.run_single_sensitive()

In [13]:
# Search over a single configuration: BMType.inAdversarial alone, with no
# external model passed in.
bmG = BMGridSearch(
    bmI,
    model=None,
    bm_list=[[BMType.inAdversarial]],
    privileged_group=privileged_group,
    unprivileged_group=unprivileged_group,
)
bmG.run_single_sensitive()

epoch 0; iter: 0; batch classifier loss: 0.745372
epoch 1; iter: 0; batch classifier loss: 0.612154
epoch 2; iter: 0; batch classifier loss: 0.635546
epoch 3; iter: 0; batch classifier loss: 0.565819
epoch 4; iter: 0; batch classifier loss: 0.600401
epoch 5; iter: 0; batch classifier loss: 0.577236
epoch 6; iter: 0; batch classifier loss: 0.567886
epoch 7; iter: 0; batch classifier loss: 0.506691
epoch 8; iter: 0; batch classifier loss: 0.465262
epoch 9; iter: 0; batch classifier loss: 0.546089
epoch 10; iter: 0; batch classifier loss: 0.468253
epoch 11; iter: 0; batch classifier loss: 0.522293
epoch 12; iter: 0; batch classifier loss: 0.659548
epoch 13; iter: 0; batch classifier loss: 0.436789
epoch 14; iter: 0; batch classifier loss: 0.508616
epoch 15; iter: 0; batch classifier loss: 0.490434
epoch 16; iter: 0; batch classifier loss: 0.357344
epoch 17; iter: 0; batch classifier loss: 0.367088
epoch 18; iter: 0; batch classifier loss: 0.472829
epoch 19; iter: 0; batch classifier loss:

In [19]:
# Run the extended grid (bm_list_in) once, with no external model passed in.
# The whole list of configurations is handed to BMGridSearch in a single call;
# wrapping this in a loop whose variable is never used would only repeat the
# identical search once per entry of bm_list.
bmG = BMGridSearch(
    bmI,
    model=None,
    bm_list=bm_list_in,
    privileged_group=privileged_group,
    unprivileged_group=unprivileged_group,
)
bmG.run_single_sensitive()

epoch 0; iter: 0; batch classifier loss: 0.719803
epoch 1; iter: 0; batch classifier loss: 0.654992
epoch 2; iter: 0; batch classifier loss: 0.610857
epoch 3; iter: 0; batch classifier loss: 0.581193
epoch 4; iter: 0; batch classifier loss: 0.558953
epoch 5; iter: 0; batch classifier loss: 0.539920
epoch 6; iter: 0; batch classifier loss: 0.502686
epoch 7; iter: 0; batch classifier loss: 0.493570
epoch 8; iter: 0; batch classifier loss: 0.614913
epoch 9; iter: 0; batch classifier loss: 0.539116
epoch 10; iter: 0; batch classifier loss: 0.472148
epoch 11; iter: 0; batch classifier loss: 0.496182
epoch 12; iter: 0; batch classifier loss: 0.460344
epoch 13; iter: 0; batch classifier loss: 0.436358
epoch 14; iter: 0; batch classifier loss: 0.418672
epoch 15; iter: 0; batch classifier loss: 0.394867
epoch 16; iter: 0; batch classifier loss: 0.549934
epoch 17; iter: 0; batch classifier loss: 0.408042
epoch 18; iter: 0; batch classifier loss: 0.423845
epoch 19; iter: 0; batch classifier loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.679534; batch adversarial loss: 0.728273
epoch 1; iter: 0; batch classifier loss: 0.659544; batch adversarial loss: 0.700448
epoch 2; iter: 0; batch classifier loss: 0.649007; batch adversarial loss: 0.703333
epoch 3; iter: 0; batch classifier loss: 0.648143; batch adversarial loss: 0.708444
epoch 4; iter: 0; batch classifier loss: 0.623381; batch adversarial loss: 0.716429
epoch 5; iter: 0; batch classifier loss: 0.626076; batch adversarial loss: 0.708169
epoch 6; iter: 0; batch classifier loss: 0.587354; batch adversarial loss: 0.720147
epoch 7; iter: 0; batch classifier loss: 0.547467; batch adversarial loss: 0.715715
epoch 8; iter: 0; batch classifier loss: 0.607473; batch adversarial loss: 0.717915
epoch 9; iter: 0; batch classifier loss: 0.617357; batch adversarial loss: 0.696319
epoch 10; iter: 0; batch classifier loss: 0.581906; batch adversarial loss: 0.699845
epoch 11; iter: 0; batch classifier loss: 0.611414; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.713886; batch adversarial loss: 0.796376
epoch 1; iter: 0; batch classifier loss: 0.697121; batch adversarial loss: 0.921602
epoch 2; iter: 0; batch classifier loss: 0.727936; batch adversarial loss: 0.923898
epoch 3; iter: 0; batch classifier loss: 0.700186; batch adversarial loss: 1.038342
epoch 4; iter: 0; batch classifier loss: 0.704873; batch adversarial loss: 1.086937
epoch 5; iter: 0; batch classifier loss: 0.812813; batch adversarial loss: 1.227965
epoch 6; iter: 0; batch classifier loss: 0.864645; batch adversarial loss: 1.397516
epoch 7; iter: 0; batch classifier loss: 0.767736; batch adversarial loss: 1.227604
epoch 8; iter: 0; batch classifier loss: 0.852062; batch adversarial loss: 1.208915
epoch 9; iter: 0; batch classifier loss: 0.796743; batch adversarial loss: 1.240277
epoch 10; iter: 0; batch classifier loss: 1.093536; batch adversarial loss: 1.396934
epoch 11; iter: 0; batch classifier loss: 0.918514; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.676391; batch adversarial loss: 0.774378
epoch 1; iter: 0; batch classifier loss: 0.659110; batch adversarial loss: 0.808660
epoch 2; iter: 0; batch classifier loss: 0.636593; batch adversarial loss: 0.833747
epoch 3; iter: 0; batch classifier loss: 0.629125; batch adversarial loss: 0.805388
epoch 4; iter: 0; batch classifier loss: 0.610011; batch adversarial loss: 0.913593
epoch 5; iter: 0; batch classifier loss: 0.585976; batch adversarial loss: 0.940167
epoch 6; iter: 0; batch classifier loss: 0.552903; batch adversarial loss: 0.895137
epoch 7; iter: 0; batch classifier loss: 0.553292; batch adversarial loss: 0.770156
epoch 8; iter: 0; batch classifier loss: 0.541631; batch adversarial loss: 0.881264
epoch 9; iter: 0; batch classifier loss: 0.652719; batch adversarial loss: 0.739511
epoch 10; iter: 0; batch classifier loss: 0.676888; batch adversarial loss: 0.779180
epoch 11; iter: 0; batch classifier loss: 0.552928; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.667338; batch adversarial loss: 0.781846
epoch 1; iter: 0; batch classifier loss: 0.589211; batch adversarial loss: 0.766533
epoch 2; iter: 0; batch classifier loss: 0.638645; batch adversarial loss: 0.785962
epoch 3; iter: 0; batch classifier loss: 0.639287; batch adversarial loss: 0.817909
epoch 4; iter: 0; batch classifier loss: 0.722497; batch adversarial loss: 0.862662
epoch 5; iter: 0; batch classifier loss: 0.719715; batch adversarial loss: 0.841944
epoch 6; iter: 0; batch classifier loss: 0.575639; batch adversarial loss: 0.758846
epoch 7; iter: 0; batch classifier loss: 0.795517; batch adversarial loss: 0.855204
epoch 8; iter: 0; batch classifier loss: 0.769704; batch adversarial loss: 0.844806
epoch 9; iter: 0; batch classifier loss: 0.760459; batch adversarial loss: 0.821356
epoch 10; iter: 0; batch classifier loss: 0.697964; batch adversarial loss: 0.815039
epoch 11; iter: 0; batch classifier loss: 0.709055; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.634651; batch adversarial loss: 0.858790
epoch 1; iter: 0; batch classifier loss: 0.695794; batch adversarial loss: 0.961884
epoch 2; iter: 0; batch classifier loss: 0.670720; batch adversarial loss: 0.953141
epoch 3; iter: 0; batch classifier loss: 0.649666; batch adversarial loss: 1.010586
epoch 4; iter: 0; batch classifier loss: 0.828670; batch adversarial loss: 1.067489
epoch 5; iter: 0; batch classifier loss: 0.793809; batch adversarial loss: 1.020226
epoch 6; iter: 0; batch classifier loss: 0.883036; batch adversarial loss: 1.120602
epoch 7; iter: 0; batch classifier loss: 1.015721; batch adversarial loss: 1.136048
epoch 8; iter: 0; batch classifier loss: 0.907666; batch adversarial loss: 1.083309
epoch 9; iter: 0; batch classifier loss: 1.033359; batch adversarial loss: 1.042416
epoch 10; iter: 0; batch classifier loss: 0.943179; batch adversarial loss: 1.043541
epoch 11; iter: 0; batch classifier loss: 0.806724; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)
  f1 = 2 * ((precision * recall)/(precision + recall))


epoch 0; iter: 0; batch classifier loss: 0.700389
epoch 1; iter: 0; batch classifier loss: 0.638720
epoch 2; iter: 0; batch classifier loss: 0.622764
epoch 3; iter: 0; batch classifier loss: 0.521397
epoch 4; iter: 0; batch classifier loss: 0.543191
epoch 5; iter: 0; batch classifier loss: 0.561388
epoch 6; iter: 0; batch classifier loss: 0.482110
epoch 7; iter: 0; batch classifier loss: 0.527573
epoch 8; iter: 0; batch classifier loss: 0.544490
epoch 9; iter: 0; batch classifier loss: 0.502498
epoch 10; iter: 0; batch classifier loss: 0.516447
epoch 11; iter: 0; batch classifier loss: 0.401210
epoch 12; iter: 0; batch classifier loss: 0.457146
epoch 13; iter: 0; batch classifier loss: 0.404309
epoch 14; iter: 0; batch classifier loss: 0.512547
epoch 15; iter: 0; batch classifier loss: 0.404949
epoch 16; iter: 0; batch classifier loss: 0.411443
epoch 17; iter: 0; batch classifier loss: 0.459982
epoch 18; iter: 0; batch classifier loss: 0.434491
epoch 19; iter: 0; batch classifier loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.670593; batch adversarial loss: 0.607333
epoch 1; iter: 0; batch classifier loss: 0.618034; batch adversarial loss: 0.583971
epoch 2; iter: 0; batch classifier loss: 0.602646; batch adversarial loss: 0.627170
epoch 3; iter: 0; batch classifier loss: 0.564518; batch adversarial loss: 0.657006
epoch 4; iter: 0; batch classifier loss: 0.522517; batch adversarial loss: 0.710122
epoch 5; iter: 0; batch classifier loss: 0.496973; batch adversarial loss: 0.767244
epoch 6; iter: 0; batch classifier loss: 0.520830; batch adversarial loss: 0.777683
epoch 7; iter: 0; batch classifier loss: 0.464272; batch adversarial loss: 0.718226
epoch 8; iter: 0; batch classifier loss: 0.480203; batch adversarial loss: 0.787807
epoch 9; iter: 0; batch classifier loss: 0.426310; batch adversarial loss: 0.735890
epoch 10; iter: 0; batch classifier loss: 0.487289; batch adversarial loss: 0.843212
epoch 11; iter: 0; batch classifier loss: 0.513411; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.687020
epoch 1; iter: 0; batch classifier loss: 0.607689
epoch 2; iter: 0; batch classifier loss: 0.618853
epoch 3; iter: 0; batch classifier loss: 0.575718
epoch 4; iter: 0; batch classifier loss: 0.552570
epoch 5; iter: 0; batch classifier loss: 0.504162
epoch 6; iter: 0; batch classifier loss: 0.505480
epoch 7; iter: 0; batch classifier loss: 0.519438
epoch 8; iter: 0; batch classifier loss: 0.451345
epoch 9; iter: 0; batch classifier loss: 0.509365
epoch 10; iter: 0; batch classifier loss: 0.548173
epoch 11; iter: 0; batch classifier loss: 0.455651
epoch 12; iter: 0; batch classifier loss: 0.455464
epoch 13; iter: 0; batch classifier loss: 0.467490
epoch 14; iter: 0; batch classifier loss: 0.482486
epoch 15; iter: 0; batch classifier loss: 0.472437
epoch 16; iter: 0; batch classifier loss: 0.396048
epoch 17; iter: 0; batch classifier loss: 0.399102
epoch 18; iter: 0; batch classifier loss: 0.461994
epoch 19; iter: 0; batch classifier loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.660501; batch adversarial loss: 0.782430
epoch 1; iter: 0; batch classifier loss: 0.667382; batch adversarial loss: 0.792547
epoch 2; iter: 0; batch classifier loss: 0.606400; batch adversarial loss: 0.864757
epoch 3; iter: 0; batch classifier loss: 0.635594; batch adversarial loss: 0.860699
epoch 4; iter: 0; batch classifier loss: 0.722526; batch adversarial loss: 0.912197
epoch 5; iter: 0; batch classifier loss: 0.713963; batch adversarial loss: 0.874728
epoch 6; iter: 0; batch classifier loss: 0.738538; batch adversarial loss: 0.905705
epoch 7; iter: 0; batch classifier loss: 0.806590; batch adversarial loss: 0.940937
epoch 8; iter: 0; batch classifier loss: 0.876248; batch adversarial loss: 0.966174
epoch 9; iter: 0; batch classifier loss: 0.742084; batch adversarial loss: 0.907201
epoch 10; iter: 0; batch classifier loss: 0.777785; batch adversarial loss: 0.891733
epoch 11; iter: 0; batch classifier loss: 0.877293; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.709157; batch adversarial loss: 0.733996
epoch 1; iter: 0; batch classifier loss: 0.642358; batch adversarial loss: 0.677238
epoch 2; iter: 0; batch classifier loss: 0.615455; batch adversarial loss: 0.735028
epoch 3; iter: 0; batch classifier loss: 0.644773; batch adversarial loss: 0.715169
epoch 4; iter: 0; batch classifier loss: 0.606711; batch adversarial loss: 0.733675
epoch 5; iter: 0; batch classifier loss: 0.551583; batch adversarial loss: 0.761652
epoch 6; iter: 0; batch classifier loss: 0.521553; batch adversarial loss: 0.762047
epoch 7; iter: 0; batch classifier loss: 0.521096; batch adversarial loss: 0.758723
epoch 8; iter: 0; batch classifier loss: 0.497032; batch adversarial loss: 0.723090
epoch 9; iter: 0; batch classifier loss: 0.519047; batch adversarial loss: 0.786051
epoch 10; iter: 0; batch classifier loss: 0.498354; batch adversarial loss: 0.691449
epoch 11; iter: 0; batch classifier loss: 0.473334; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.656744; batch adversarial loss: 0.760360
epoch 1; iter: 0; batch classifier loss: 0.647561; batch adversarial loss: 0.739507
epoch 2; iter: 0; batch classifier loss: 0.689231; batch adversarial loss: 0.771907
epoch 3; iter: 0; batch classifier loss: 0.628378; batch adversarial loss: 0.762422
epoch 4; iter: 0; batch classifier loss: 0.644035; batch adversarial loss: 0.787597
epoch 5; iter: 0; batch classifier loss: 0.680275; batch adversarial loss: 0.812769
epoch 6; iter: 0; batch classifier loss: 0.705135; batch adversarial loss: 0.784431
epoch 7; iter: 0; batch classifier loss: 0.730333; batch adversarial loss: 0.777083
epoch 8; iter: 0; batch classifier loss: 0.673996; batch adversarial loss: 0.773020
epoch 9; iter: 0; batch classifier loss: 0.741829; batch adversarial loss: 0.781551
epoch 10; iter: 0; batch classifier loss: 0.735905; batch adversarial loss: 0.767628
epoch 11; iter: 0; batch classifier loss: 0.682088; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.707723
epoch 1; iter: 0; batch classifier loss: 0.673028
epoch 2; iter: 0; batch classifier loss: 0.624231
epoch 3; iter: 0; batch classifier loss: 0.586178
epoch 4; iter: 0; batch classifier loss: 0.616272
epoch 5; iter: 0; batch classifier loss: 0.539740
epoch 6; iter: 0; batch classifier loss: 0.515726
epoch 7; iter: 0; batch classifier loss: 0.563883
epoch 8; iter: 0; batch classifier loss: 0.460168
epoch 9; iter: 0; batch classifier loss: 0.473977
epoch 10; iter: 0; batch classifier loss: 0.479535
epoch 11; iter: 0; batch classifier loss: 0.558231
epoch 12; iter: 0; batch classifier loss: 0.542097
epoch 13; iter: 0; batch classifier loss: 0.495395
epoch 14; iter: 0; batch classifier loss: 0.457554
epoch 15; iter: 0; batch classifier loss: 0.471926
epoch 16; iter: 0; batch classifier loss: 0.437409
epoch 17; iter: 0; batch classifier loss: 0.518007
epoch 18; iter: 0; batch classifier loss: 0.450994
epoch 19; iter: 0; batch classifier loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.650862; batch adversarial loss: 0.739318
epoch 1; iter: 0; batch classifier loss: 0.659691; batch adversarial loss: 0.789088
epoch 2; iter: 0; batch classifier loss: 0.675921; batch adversarial loss: 0.799783
epoch 3; iter: 0; batch classifier loss: 0.646468; batch adversarial loss: 0.793489
epoch 4; iter: 0; batch classifier loss: 0.686371; batch adversarial loss: 0.820268
epoch 5; iter: 0; batch classifier loss: 0.681104; batch adversarial loss: 0.845017
epoch 6; iter: 0; batch classifier loss: 0.847352; batch adversarial loss: 0.931536
epoch 7; iter: 0; batch classifier loss: 0.856703; batch adversarial loss: 0.894274
epoch 8; iter: 0; batch classifier loss: 0.752944; batch adversarial loss: 0.865300
epoch 9; iter: 0; batch classifier loss: 0.743242; batch adversarial loss: 0.845164
epoch 10; iter: 0; batch classifier loss: 0.851550; batch adversarial loss: 0.896326
epoch 11; iter: 0; batch classifier loss: 0.968200; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)


epoch 0; iter: 0; batch classifier loss: 0.719228; batch adversarial loss: 0.740388
epoch 1; iter: 0; batch classifier loss: 0.708387; batch adversarial loss: 0.739791
epoch 2; iter: 0; batch classifier loss: 0.648779; batch adversarial loss: 0.748028
epoch 3; iter: 0; batch classifier loss: 0.658029; batch adversarial loss: 0.687131
epoch 4; iter: 0; batch classifier loss: 0.657254; batch adversarial loss: 0.733654
epoch 5; iter: 0; batch classifier loss: 0.590790; batch adversarial loss: 0.748695
epoch 6; iter: 0; batch classifier loss: 0.593804; batch adversarial loss: 0.769261
epoch 7; iter: 0; batch classifier loss: 0.658408; batch adversarial loss: 0.762331
epoch 8; iter: 0; batch classifier loss: 0.508819; batch adversarial loss: 0.865844
epoch 9; iter: 0; batch classifier loss: 0.545554; batch adversarial loss: 0.767737
epoch 10; iter: 0; batch classifier loss: 0.546303; batch adversarial loss: 0.821102
epoch 11; iter: 0; batch classifier loss: 0.569473; batch adversarial loss:

  return metric_fun(privileged=False) / metric_fun(privileged=True)
