In [1]:
import numpy as np
import os
import torch
import pandas as pd
import yaml

from sklearn.externals import joblib

from pytorch_utils.datasets import ArrayDataset
from pytorch_utils.models import FeedforwardNetModel
import pytorch_utils

In [2]:
# Prediction target: selects the label key in label_dict and namespaces the output directories.
outcome = 'los'
# Identifier of the feature-extraction run to load (sub-directory name under data/features/).
feature_set_id = 0

data_path = 'data/'
features_path = os.path.join(data_path, 'features', str(feature_set_id))
label_path = os.path.join(data_path, 'labels')
config_path = os.path.join(data_path, 'config', 'grid', 'baseline')
# Outputs of this notebook are written per outcome so different targets do not overwrite each other.
checkpoints_path = os.path.join(data_path, 'checkpoints', 'scratch', outcome)
performance_path = os.path.join(data_path, 'performance', 'scratch', outcome)

In [3]:
# Create both output directories up front; exist_ok=True makes re-running this cell harmless.
for output_dir in (checkpoints_path, performance_path):
    os.makedirs(output_dir, exist_ok=True)

In [4]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
# NOTE(review): `device` is not passed to FeedforwardNetModel or any other call in this notebook — confirm it is needed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Load the pickled feature dict and label dict produced by the upstream extraction step.
features_file = os.path.join(features_path, 'features.pkl')
labels_file = os.path.join(label_path, 'label_dict.pkl')
features_dict = joblib.load(features_file)
label_dict = joblib.load(labels_file)

In [6]:
# Index of the hyperparameter-grid entry; used to name the saved checkpoint and the performance CSVs
# (and, in the commented-out loader below, to pick the '<grid_element>.yaml' config file).
grid_element = 1

In [7]:
# Per split: the feature matrix, and the label stored under the `outcome` key.
data_dict = {
    split_name: split_features['features']
    for split_name, split_features in features_dict.items()
}
outcome_dict = {
    split_name: split_labels[outcome]
    for split_name, split_labels in label_dict.items()
}

In [8]:
# Alternative: load the hyperparameters for `grid_element` from the grid-search YAML configs.
# (Use yaml.safe_load: yaml.load without an explicit Loader is deprecated in PyYAML >= 5.1 and an error in 6.0.)
# with open(os.path.join(config_path, '{}.yaml'.format(grid_element)), 'r') as fp:
#     config_dict = yaml.safe_load(fp)
# config_dict['num_epochs'] = 3 # For testing

## Alternative: a more complex network (one hidden layer, dropout, normalization, resnet=True)
# config_dict = {
#     'input_dim' : data_dict['train'].shape[1],
#     'lr' : 1e-5,
#     'num_epochs' : 20,
#     'batch_size' : 256,
#     'hidden_dim' : 128,
#     'num_hidden' : 1,
#     'output_dim' : 2,
#     'drop_prob' : 0.5,
#     'normalize' : True,
#     'iters_per_epoch' : 100,
#     'gamma' : 0.99,
#     'resnet' : True,
#     'sparse' : True,
#     'sparse_mode' : 'binary'
# }

## Logistic regression: num_hidden=0, so the model is a single linear map input_dim -> output_dim.
## hidden_dim appears unused in this setting (the printed model below contains no 128-wide layer).
config_dict = dict(
    input_dim=data_dict['train'].shape[1],  # number of feature columns in the training matrix
    lr=1e-5,
    num_epochs=20,
    batch_size=256,
    hidden_dim=128,
    num_hidden=0,
    output_dim=2,
    drop_prob=0.0,
    normalize=False,
    iters_per_epoch=100,
    gamma=0.99,
    resnet=False,
    sparse=True,
    sparse_mode='binary',
)

In [9]:
# Echo the active configuration (input_dim was filled in from the training feature matrix).
config_dict

{'input_dim': 368117,
 'lr': 1e-05,
 'num_epochs': 20,
 'batch_size': 256,
 'hidden_dim': 128,
 'num_hidden': 0,
 'output_dim': 2,
 'drop_prob': 0.0,
 'normalize': False,
 'iters_per_epoch': 100,
 'gamma': 0.99,
 'resnet': False,
 'sparse': True,
 'sparse_mode': 'binary'}

In [10]:
# Build the model wrapper from the configuration above.
model = FeedforwardNetModel(config_dict)

In [11]:
# Show the immediate submodules of the wrapped torch module to sanity-check the architecture.
for submodule in model.model.children():
    print(submodule)

LinearLayerWrapper(
  (linear): EmbeddingBagLinear(
    in_features=368117, out_features=2, bias=True
    (embed): EmbeddingBag(368117, 2, mode=sum)
  )
)
ModuleList(
  (0): LinearLayerWrapper(
    (linear): EmbeddingBagLinear(
      in_features=368117, out_features=2, bias=True
      (embed): EmbeddingBag(368117, 2, mode=sum)
    )
  )
)


In [12]:
%%time
# Train on all splits; `result` is later converted to a DataFrame with model.process_result_dict.
# NOTE(review): the captured output is flooded with sparse-matrix (row, col) value dumps; this cell only
# calls model.train, so the prints presumably come from inside pytorch_utils — remove them there.
result = model.train(data_dict, outcome_dict)

Epoch 0/19
----------
  (0, 2)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 1787)	1.0
  (0, 7288)	1.0
  (0, 17973)	1.0
  (0, 743)	1.0
  (0, 8785)	1.0
  (0, 1642)	1.0
  (0, 17975)	1.0
  (0, 865)	1.0
  (0, 618)	1.0
  (0, 619)	1.0
  (0, 620)	1.0
  (0, 621)	1.0
  (0, 6496)	1.0
  (0, 11685)	1.0
  (0, 325)	1.0
  (0, 329)	1.0
  (0, 201)	1.0
  (0, 11689)	1.0
  (0, 50)	1.0
  (0, 17789)	1.0
  (0, 93)	1.0
  (0, 94)	1.0
  :	:
  (255, 4792)	1.0
  (255, 536)	1.0
  (255, 1312)	1.0
  (255, 538)	1.0
  (255, 1532)	1.0
  (255, 630)	1.0
  (255, 1152)	1.0
  (255, 541)	1.0
  (255, 545)	1.0
  (255, 12653)	1.0
  (255, 8518)	1.0
  (255, 5596)	1.0
  (255, 11148)	1.0
  (255, 21446)	1.0
  (255, 7556)	1.0
  (255, 878)	1.0
  (255, 4720)	1.0
  (255, 1882)	1.0
  (255, 1642)	1.0
  (255, 1883)	1.0
  (255, 49522)	1.0
  (255, 1862)	1.0
  (255, 1863)	1.0
  (255, 7991)	1.0
  (255, 22374)	1.0
  (0, 3)	1.0
  (0, 4)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 1890)	1.0
  (0, 7563)	1.0
  (0, 1801)	1.0
  (0, 12)	1.0
  (0, 13)	1.0
  (0, 15)	1.

  (0, 3)	1.0
  (0, 4)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 984)	1.0
  (0, 15585)	1.0
  (0, 30116)	1.0
  (0, 9318)	1.0
  (0, 65950)	1.0
  (0, 37720)	1.0
  (0, 37657)	1.0
  (0, 12)	1.0
  (0, 13)	1.0
  (0, 14)	1.0
  (0, 15)	1.0
  (0, 16)	1.0
  (0, 17)	1.0
  (0, 18)	1.0
  (0, 19)	1.0
  (0, 20)	1.0
  (0, 21)	1.0
  (0, 23)	1.0
  (0, 24)	1.0
  (0, 26)	1.0
  (0, 27)	1.0
  :	:
  (255, 5613)	1.0
  (255, 5697)	1.0
  (255, 7597)	1.0
  (255, 5615)	1.0
  (255, 7599)	1.0
  (255, 5616)	1.0
  (255, 5701)	1.0
  (255, 5703)	1.0
  (255, 5704)	1.0
  (255, 8593)	1.0
  (255, 13585)	1.0
  (255, 6301)	1.0
  (255, 13591)	1.0
  (255, 15114)	1.0
  (255, 7625)	1.0
  (255, 18113)	1.0
  (255, 7627)	1.0
  (255, 7628)	1.0
  (255, 7629)	1.0
  (255, 7630)	1.0
  (255, 7631)	1.0
  (255, 7632)	1.0
  (255, 7633)	1.0
  (255, 7634)	1.0
  (255, 7662)	1.0
  (0, 2)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 1118)	1.0
  (0, 669)	1.0
  (0, 19451)	1.0
  (0, 1803)	1.0
  (0, 42571)	1.0
  (0, 12)	1.0
  (0, 13)	1.0
  (0, 15)	1.0

  (0, 1)	1.0
  (0, 673)	1.0
  (0, 6)	1.0
  (0, 256)	1.0
  (0, 6336)	1.0
  (0, 8600)	1.0
  (0, 14424)	1.0
  (0, 6319)	1.0
  (0, 8611)	1.0
  (0, 6261)	1.0
  (0, 11508)	1.0
  (0, 6245)	1.0
  (0, 6265)	1.0
  (0, 8586)	1.0
  (0, 5827)	1.0
  (0, 8589)	1.0
  (0, 6272)	1.0
  (0, 6274)	1.0
  (0, 6280)	1.0
  (0, 6282)	1.0
  (0, 7605)	1.0
  (0, 8591)	1.0
  (0, 6284)	1.0
  (0, 6286)	1.0
  (0, 6288)	1.0
  :	:
  (255, 16130)	1.0
  (255, 1440)	1.0
  (255, 277)	1.0
  (255, 28728)	1.0
  (255, 2502)	1.0
  (255, 526)	1.0
  (255, 167)	1.0
  (255, 130691)	1.0
  (255, 8152)	1.0
  (255, 65965)	1.0
  (255, 521)	1.0
  (255, 529)	1.0
  (255, 1873)	1.0
  (255, 4182)	1.0
  (255, 6482)	1.0
  (255, 6477)	1.0
  (255, 608)	1.0
  (255, 1535)	1.0
  (255, 1537)	1.0
  (255, 59973)	1.0
  (255, 2046)	1.0
  (255, 1211)	1.0
  (255, 1403)	1.0
  (255, 3830)	1.0
  (255, 22160)	1.0
  (0, 3)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 14219)	1.0
  (0, 53613)	1.0
  (0, 6532)	1.0
  (0, 20015)	1.0
  (0, 37939)	1.0
  (0, 53615

  (0, 1)	1.0
  (0, 673)	1.0
  (0, 1787)	1.0
  (0, 1373)	1.0
  (0, 24664)	1.0
  (0, 29231)	1.0
  (0, 75369)	1.0
  (0, 3053)	1.0
  (0, 1098)	1.0
  (0, 1754)	1.0
  (0, 699)	1.0
  (0, 20769)	1.0
  (0, 20770)	1.0
  (0, 20771)	1.0
  (0, 20772)	1.0
  (0, 20773)	1.0
  (0, 14864)	1.0
  (0, 3467)	1.0
  (0, 20774)	1.0
  (0, 20775)	1.0
  (0, 20776)	1.0
  (0, 20777)	1.0
  (0, 3469)	1.0
  (0, 11699)	1.0
  (0, 1694)	1.0
  :	:
  (255, 608)	1.0
  (255, 545)	1.0
  (255, 1535)	1.0
  (255, 1536)	1.0
  (255, 6550)	1.0
  (255, 5302)	1.0
  (255, 1387)	1.0
  (255, 2443)	1.0
  (255, 8518)	1.0
  (255, 2960)	1.0
  (255, 43110)	1.0
  (255, 83322)	1.0
  (255, 70)	1.0
  (255, 198)	1.0
  (255, 1539)	1.0
  (255, 618)	1.0
  (255, 619)	1.0
  (255, 620)	1.0
  (255, 621)	1.0
  (255, 208)	1.0
  (255, 4016)	1.0
  (255, 84)	1.0
  (255, 2158)	1.0
  (255, 1403)	1.0
  (255, 138475)	1.0
  (0, 3)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 1890)	1.0
  (0, 7563)	1.0
  (0, 19451)	1.0
  (0, 65825)	1.0
  (0, 61579)	1.0
  (0, 

  (0, 3)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 14)	1.0
  (0, 19)	1.0
  (0, 23)	1.0
  (0, 24)	1.0
  (0, 711)	1.0
  (0, 712)	1.0
  (0, 34)	1.0
  (0, 35)	1.0
  (0, 36)	1.0
  (0, 37)	1.0
  (0, 38)	1.0
  (0, 45)	1.0
  (0, 49)	1.0
  (0, 52)	1.0
  (0, 53)	1.0
  (0, 57)	1.0
  (0, 58)	1.0
  (0, 59)	1.0
  (0, 777)	1.0
  (0, 16211)	1.0
  (0, 713)	1.0
  :	:
  (255, 496)	1.0
  (255, 468)	1.0
  (255, 86)	1.0
  (255, 25)	1.0
  (255, 46)	1.0
  (255, 47)	1.0
  (255, 70)	1.0
  (255, 84)	1.0
  (255, 7229)	1.0
  (255, 986)	1.0
  (255, 15198)	1.0
  (255, 524)	1.0
  (255, 528)	1.0
  (255, 166)	1.0
  (255, 532)	1.0
  (255, 533)	1.0
  (255, 168)	1.0
  (255, 538)	1.0
  (255, 177938)	1.0
  (255, 198)	1.0
  (255, 199)	1.0
  (255, 792)	1.0
  (255, 27849)	1.0
  (255, 12658)	1.0
  (255, 20351)	1.0
  (0, 3)	1.0
  (0, 4)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 163095)	1.0
  (0, 267)	1.0
  (0, 815)	1.0
  (0, 14)	1.0
  (0, 19)	1.0
  (0, 405)	1.0
  (0, 406)	1.0
  (0, 23)	1.0
  (0, 24)	1.0
  (0, 34)	1.0
  (0, 35

  (0, 2)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 1787)	1.0
  (0, 237)	1.0
  (0, 86)	1.0
  (0, 18760)	1.0
  (0, 11319)	1.0
  (0, 117751)	1.0
  (0, 12)	1.0
  (0, 13)	1.0
  (0, 14)	1.0
  (0, 769)	1.0
  (0, 15)	1.0
  (0, 279)	1.0
  (0, 16)	1.0
  (0, 17)	1.0
  (0, 18)	1.0
  (0, 19)	1.0
  (0, 20)	1.0
  (0, 21)	1.0
  (0, 22)	1.0
  (0, 23)	1.0
  (0, 770)	1.0
  (0, 24)	1.0
  :	:
  (255, 1955)	1.0
  (255, 5566)	1.0
  (255, 5567)	1.0
  (255, 5705)	1.0
  (255, 13641)	1.0
  (255, 74242)	1.0
  (255, 74256)	1.0
  (255, 74243)	1.0
  (255, 74257)	1.0
  (255, 74244)	1.0
  (255, 74245)	1.0
  (255, 74246)	1.0
  (255, 74247)	1.0
  (255, 74248)	1.0
  (255, 74249)	1.0
  (255, 23098)	1.0
  (255, 74250)	1.0
  (255, 5896)	1.0
  (255, 5287)	1.0
  (255, 2030)	1.0
  (255, 3383)	1.0
  (255, 3384)	1.0
  (255, 2418)	1.0
  (255, 4599)	1.0
  (255, 1434)	1.0
  (0, 2)	1.0
  (0, 4)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 16575)	1.0
  (0, 12589)	1.0
  (0, 5403)	1.0
  (0, 5971)	1.0
  (0, 2753)	1.0
  (0, 1759)	1.0
  (0, 19388)	1.0

  (0, 3)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 4837)	1.0
  (0, 4868)	1.0
  (0, 4874)	1.0
  (0, 4875)	1.0
  (0, 708)	1.0
  (0, 4850)	1.0
  (0, 36388)	1.0
  (0, 321)	1.0
  (0, 12)	1.0
  (0, 13)	1.0
  (0, 14)	1.0
  (0, 15)	1.0
  (0, 16)	1.0
  (0, 17)	1.0
  (0, 18)	1.0
  (0, 19)	1.0
  (0, 20)	1.0
  (0, 21)	1.0
  (0, 23)	1.0
  (0, 24)	1.0
  (0, 26)	1.0
  :	:
  (255, 1755)	1.0
  (255, 5965)	1.0
  (255, 5515)	1.0
  (255, 93274)	1.0
  (255, 193768)	1.0
  (255, 48989)	1.0
  (255, 115762)	1.0
  (255, 193769)	1.0
  (255, 242504)	1.0
  (255, 10229)	1.0
  (255, 5460)	1.0
  (255, 1758)	1.0
  (255, 2895)	1.0
  (255, 50236)	1.0
  (255, 37591)	1.0
  (255, 50238)	1.0
  (255, 13371)	1.0
  (255, 2221)	1.0
  (255, 78)	1.0
  (255, 4535)	1.0
  (255, 1658)	1.0
  (255, 6070)	1.0
  (255, 2894)	1.0
  (255, 2158)	1.0
  (255, 1403)	1.0
  (0, 2)	1.0
  (0, 4)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 1890)	1.0
  (0, 666)	1.0
  (0, 3762)	1.0
  (0, 6478)	1.0
  (0, 29254)	1.0
  (0, 14175)	1.0
  (0, 6645)	1.0
  (

  (0, 0)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 1787)	1.0
  (0, 243634)	1.0
  (0, 242878)	1.0
  (0, 224)	1.0
  (0, 772)	1.0
  (0, 12249)	1.0
  (0, 2736)	1.0
  (0, 14162)	1.0
  (0, 130271)	1.0
  (0, 321)	1.0
  (0, 65247)	1.0
  (0, 10756)	1.0
  (0, 33603)	1.0
  (0, 3247)	1.0
  (0, 4889)	1.0
  (0, 3087)	1.0
  (0, 29984)	1.0
  (0, 89956)	1.0
  (0, 110477)	1.0
  (0, 176570)	1.0
  (0, 17022)	1.0
  (0, 13514)	1.0
  :	:
  (255, 165)	1.0
  (255, 532)	1.0
  (255, 4182)	1.0
  (255, 1335)	1.0
  (255, 1532)	1.0
  (255, 730)	1.0
  (255, 1692)	1.0
  (255, 54108)	1.0
  (255, 8334)	1.0
  (255, 8336)	1.0
  (255, 58432)	1.0
  (255, 69)	1.0
  (255, 71)	1.0
  (255, 74)	1.0
  (255, 205)	1.0
  (255, 869)	1.0
  (255, 598)	1.0
  (255, 1278)	1.0
  (255, 209)	1.0
  (255, 2000)	1.0
  (255, 1402)	1.0
  (255, 1403)	1.0
  (255, 8338)	1.0
  (255, 51325)	1.0
  (255, 21624)	1.0
  (0, 3)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 10690)	1.0
  (0, 2270)	1.0
  (0, 27703)	1.0
  (0, 9323)	1.0
  (0, 455)	1.0
  (0, 896

  (0, 3)	1.0
  (0, 4)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 25)	1.0
  (0, 46)	1.0
  (0, 283)	1.0
  (0, 390)	1.0
  (0, 288)	1.0
  (0, 11149)	1.0
  (0, 4837)	1.0
  (0, 17047)	1.0
  (0, 321)	1.0
  (0, 75253)	1.0
  (0, 267)	1.0
  (0, 734)	1.0
  (0, 4892)	1.0
  (0, 18689)	1.0
  (0, 93)	1.0
  (0, 307)	1.0
  (0, 94)	1.0
  (0, 95)	1.0
  (0, 96)	1.0
  (0, 98)	1.0
  (0, 99)	1.0
  :	:
  (255, 534)	1.0
  (255, 1335)	1.0
  (255, 535)	1.0
  (255, 536)	1.0
  (255, 538)	1.0
  (255, 930)	1.0
  (255, 938)	1.0
  (255, 1313)	1.0
  (255, 1534)	1.0
  (255, 2094)	1.0
  (255, 545)	1.0
  (255, 884)	1.0
  (255, 1511)	1.0
  (255, 4373)	1.0
  (255, 3617)	1.0
  (255, 17140)	1.0
  (255, 35907)	1.0
  (255, 790)	1.0
  (255, 385)	1.0
  (255, 199)	1.0
  (255, 1055)	1.0
  (255, 792)	1.0
  (255, 1254)	1.0
  (255, 795)	1.0
  (255, 29629)	1.0
  (0, 0)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 4215)	1.0
  (0, 33858)	1.0
  (0, 18143)	1.0
  (0, 242039)	1.0
  (0, 4216)	1.0
  (0, 77288)	1.0
  (0, 2959)	1.0
  (0, 4211)	1.

  (0, 3)	1.0
  (0, 673)	1.0
  (0, 5)	1.0
  (0, 6)	1.0
  (0, 237)	1.0
  (0, 4797)	1.0
  (0, 9492)	1.0
  (0, 7974)	1.0
  (0, 1291)	1.0
  (0, 7961)	1.0
  (0, 8161)	1.0
  (0, 3040)	1.0
  (0, 20335)	1.0
  (0, 93)	1.0
  (0, 94)	1.0
  (0, 95)	1.0
  (0, 96)	1.0
  (0, 98)	1.0
  (0, 99)	1.0
  (0, 100)	1.0
  (0, 101)	1.0
  (0, 103)	1.0
  (0, 480)	1.0
  (0, 867)	1.0
  (0, 1613)	1.0
  :	:
  (255, 524)	1.0
  (255, 527)	1.0
  (255, 166)	1.0
  (255, 533)	1.0
  (255, 1312)	1.0
  (255, 538)	1.0
  (255, 763)	1.0
  (255, 7965)	1.0
  (255, 2235)	1.0
  (255, 1315)	1.0
  (255, 1552)	1.0
  (255, 17780)	1.0
  (255, 16403)	1.0
  (255, 7524)	1.0
  (255, 5083)	1.0
  (255, 27428)	1.0
  (255, 6543)	1.0
  (255, 38036)	1.0
  (255, 2039)	1.0
  (255, 198)	1.0
  (255, 199)	1.0
  (255, 225)	1.0
  (255, 16406)	1.0
  (255, 1434)	1.0
  (255, 46295)	1.0
  (0, 0)	1.0
  (0, 673)	1.0
  (0, 1787)	1.0
  (0, 1373)	1.0
  (0, 2196)	1.0
  (0, 2197)	1.0
  (0, 2198)	1.0
  (0, 2199)	1.0
  (0, 2200)	1.0
  (0, 2204)	1.0
  (0, 2205)	1.0
  

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/labs/shahlab/spfohl/miniconda3/envs/py_env/lib/python3.6/site-packages/IPython/core/magics/execution.py", line 1238, in time
    exec(code, glob, local_ns)
  File "<timed exec>", line 1, in <module>
  File "/home/spfohl/projects/fairness_cf/pytorch_utils/models.py", line 271, in train
    for the_data in loaders[phase]:
  File "/labs/shahlab/spfohl/miniconda3/envs/py_env/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 615, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/labs/shahlab/spfohl/miniconda3/envs/py_env/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 615, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/spfohl/projects/fairness_cf/pytorch_utils/datasets.py", line 21, in __getitem__
    return tuple(tensor[index] for tensor in self.tensors)
  File "/home/spfohl/projects/fairness_cf/pytorch_utils/datasets.py", line 21, in

KeyboardInterrupt: 

In [13]:
# Evaluate the trained model on the val and test splits only.
# NOTE(review): the second element (result_eval[1]) is what gets passed to process_result_dict later —
# presumably the metrics dict; confirm against pytorch_utils.models.
result_eval = model.predict(data_dict, outcome_dict, phases = ['val', 'test'])

KeyboardInterrupt: 

In [None]:
# Inspect the raw evaluation output.
result_eval

In [None]:
## Save weights
# Checkpoint file is named after the grid element so different configs do not overwrite each other.
checkpoint_file = os.path.join(checkpoints_path, '{}.chk'.format(grid_element))
model.save_weights(checkpoint_file)

In [None]:
# Convert the raw result dicts to DataFrames: one for the training run, one for the val/test evaluation.
result_df_training, result_df_eval = (
    model.process_result_dict(result),
    model.process_result_dict(result_eval[1]),
)

print(result_df_training)
print(result_df_eval)

In [None]:
## Get performance by group
# For each sensitive attribute, split every data split into per-group subsets, then evaluate the
# trained model on the val/test subsets of each group.
sensitive_variables = ['race_eth', 'gender', 'age']
data_dict_by_group = {sensitive_variable: {} for sensitive_variable in sensitive_variables}
outcome_dict_by_group = {sensitive_variable: {} for sensitive_variable in sensitive_variables}
for sensitive_variable in sensitive_variables:
    # Groups are enumerated from the training split only; a group that appears only in val/test
    # is not evaluated.
    groups = np.unique(label_dict['train'][sensitive_variable])
    for group in groups:
        data_dict_by_group[sensitive_variable][group] = {}
        outcome_dict_by_group[sensitive_variable][group] = {}
        for split in data_dict.keys():
            # Build the row mask once per (group, split) and reuse it for both features and labels.
            # np.asarray makes the selection positional even when the labels are a pandas Series.
            mask = np.asarray(label_dict[split][sensitive_variable] == group)
            data_dict_by_group[sensitive_variable][group][split] = data_dict[split][mask]
            outcome_dict_by_group[sensitive_variable][group][split] = outcome_dict[split][mask]


def _evaluate_by_group(sensitive_variable):
    """Return val/test performance for every group of `sensitive_variable`, concatenated with the group as outer key."""
    return pd.concat({
        group: model.process_result_dict(model.predict(data_dict_by_group[sensitive_variable][group],
                                                       outcome_dict_by_group[sensitive_variable][group],
                                                       phases = ['val', 'test'])[1])
        for group in data_dict_by_group[sensitive_variable].keys()
    })


result_df_by_group = pd.concat({sensitive_variable: _evaluate_by_group(sensitive_variable)
                                for sensitive_variable in data_dict_by_group.keys()})
result_df_by_group.index = result_df_by_group.index.set_names(['sensitive_variable', 'group', 'index'])
# Move the attribute and group keys into columns; the innermost 'index' level stays as the index.
result_df_by_group = result_df_by_group.reset_index(level = [0, 1])
result_df_by_group.head()

In [None]:
# Full per-group performance table.
result_df_by_group

In [None]:
# Write each performance table under performance_path, prefixed with the grid element.
# NOTE(review): these files are written without a '.csv' extension — consider adding one if no
# downstream reader depends on the current names.
for name_template, frame in (('{}_training', result_df_training),
                             ('{}_eval', result_df_eval),
                             ('{}_by_group', result_df_by_group)):
    frame.to_csv(os.path.join(performance_path, name_template.format(grid_element)), index = False)