In [16]:
from ml.data import process_data
from ml.model import train_model, compute_model_metrics, inference, sliced_metrics
import pandas as pd
from sklearn.model_selection import train_test_split


In [17]:
# Load the census CSV; the path is relative to the notebook's working directory.
# NOTE(review): the slice-metrics output later in this notebook shows category
# values with a leading space (e.g. ' Private'), so the file appears to have
# whitespace after its delimiters. `skipinitialspace=True` would strip it at
# load time — confirm nothing downstream relies on the padded values first.
data = pd.read_csv("../data/census.csv")

In [18]:
# The header labels carry stray spaces; drop every space character from each
# label so columns can be addressed by their clean names (e.g. "salary").
data.columns = [label.replace(" ", "") for label in data.columns]
print(data.columns)

Index(['age', 'workclass', 'fnlgt', 'education', 'education-num',
       'marital-status', 'occupation', 'relationship', 'race', 'sex',
       'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
       'salary'],
      dtype='object')


In [19]:
# Hold out 20% of the rows as the test split; the fixed seed keeps the
# partition identical from run to run.
train, test = train_test_split(
    data,
    random_state=42,
    test_size=0.20,
)

In [20]:
# Columns that are passed to process_data as categorical features.
cat_features = [
    "workclass", "education", "marital-status", "occupation",
    "relationship", "race", "sex", "native-country",
]

# Training split: training=True, and the encoder / label binarizer returned
# here are kept so the test split can reuse them.
X_train, y_train, encoder, lb = process_data(
    train,
    categorical_features=cat_features,
    label="salary",
    training=True,
)

# Test split: training=False, reusing the encoder and label binarizer obtained
# from the training split; the last two return values are discarded.
X_test, y_test, _, _ = process_data(
    test,
    categorical_features=cat_features,
    label="salary",
    training=False,
    encoder=encoder,
    lb=lb,
)



In [31]:
# Sanity check: show the column labels of the unprocessed training split.
# NOTE(review): this cell's execution count (31) is out of sequence with the
# cells around it — restart the kernel and run all cells before sharing.
train.columns

Index(['age', 'workclass', 'fnlgt', 'education', 'education-num',
       'marital-status', 'occupation', 'relationship', 'race', 'sex',
       'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
       'salary'],
      dtype='object')

In [32]:
# Dimensions (rows, columns) of the processed training feature matrix.
X_train.shape

(26048, 108)

In [21]:
# Train a model on the processed training features and labels
# (train_model is imported from ml.model).
model = train_model(X_train, y_train)


In [22]:
# Run the trained model on the processed test features, then score those
# predictions against the true test labels. compute_model_metrics returns
# three values, unpacked here as precision, recall and fbeta.
preds = inference(model, X_test)
precision, recall, fbeta = compute_model_metrics(y_test, preds)

In [23]:
# Precision of the model's predictions on the test split.
precision

0.7794585987261147

In [24]:
# Recall of the model's predictions on the test split.
recall

0.6231699554423934

In [25]:
# F-beta score of the model's predictions on the test split.
fbeta

0.6926070038910507

In [26]:
# Compute metrics per slice of each categorical feature. The unprocessed test
# frame is passed alongside the processed arrays — presumably so rows can be
# grouped by their original category values (verify in ml.model.sliced_metrics).
all_slice_metrics = sliced_metrics(model, X_test, y_test, test, cat_features)

In [27]:
# Display the nested result: feature -> category value ->
# {'precision', 'recall', 'fbeta', 'samples'}.
all_slice_metrics

{'workclass': {' Private': {'precision': 0.7923076923076923,
   'recall': 0.6173826173826173,
   'fbeta': 0.6939921392476136,
   'samples': 4578},
  ' State-gov': {'precision': 0.75,
   'recall': 0.6986301369863014,
   'fbeta': 0.7234042553191489,
   'samples': 254},
  ' Self-emp-not-inc': {'precision': 0.7474747474747475,
   'recall': 0.4713375796178344,
   'fbeta': 0.5781250000000001,
   'samples': 498},
  ' Self-emp-inc': {'precision': 0.7622950819672131,
   'recall': 0.788135593220339,
   'fbeta': 0.775,
   'samples': 212},
  ' Federal-gov': {'precision': 0.746268656716418,
   'recall': 0.7142857142857143,
   'fbeta': 0.7299270072992701,
   'samples': 191},
  ' Local-gov': {'precision': 0.7623762376237624,
   'recall': 0.7,
   'fbeta': 0.7298578199052131,
   'samples': 387},
  ' ?': {'precision': 0.8421052631578947,
   'recall': 0.38095238095238093,
   'fbeta': 0.5245901639344263,
   'samples': 389},
  ' Without-pay': {'precision': 1.0,
   'recall': 1.0,
   'fbeta': 1.0,
   'sample