diff --git a/examples/pzmm_generate_complete_model_card.ipynb b/examples/pzmm_generate_complete_model_card.ipynb index 60124580..0dd78641 100644 --- a/examples/pzmm_generate_complete_model_card.ipynb +++ b/examples/pzmm_generate_complete_model_card.ipynb @@ -567,7 +567,11 @@ " df = pd.get_dummies(df, columns=[\"WorkClass\", \"Education\", \"MartialStatus\", \"Relationship\", \"Race\", \"Sex\"])\n", " df.columns = df.columns.str.replace(' ', '')\n", " df.columns = df.columns.str.replace('-', '_')\n", - " df = df.drop(['Sex_Male'], axis=1)\n", + " # Ensure that Sex_Female column exists and Sex_Male column is dropped if it exists\n", + " if 'Sex_Female' not in df.columns:\n", + " df['Sex_Female'] = 0\n", + " if 'Sex_Male' in df.columns:\n", + " df = df.drop(['Sex_Male'], axis=1)\n", " if 'index' in df.columns or 'index' in cat_vals.columns:\n", " df = pd.concat([df, cat_vals], axis=1).drop('index', axis=1)\n", " # For the model to score correctly, all OHE columns must exist\n",