Merge pull request #101 from richford/bf/label-encode
BF: Fix label encoding in the presence of NaN labels
richford committed Dec 7, 2021
2 parents 362c9d2 + b85da9a commit d6334e3
Showing 2 changed files with 74 additions and 4 deletions.
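
For context, a minimal sketch (not part of this commit; the "site" column and its values are illustrative) of the encoding strategy the fix adopts: missing targets are filled with the string "NaN" before LabelEncoder runs, so a missing label becomes an explicit class that can be identified and dropped later, rather than a raw float NaN that LabelEncoder misbehaves on (the exact failure depends on the pandas/scikit-learn versions).

# Illustrative sketch only; not code from this repository.
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

y = pd.DataFrame({"site": ["A", "B", np.nan]})

le = LabelEncoder()
# Encoding the raw column would mix strings with a float NaN; fillna("NaN")
# gives the missing value its own string class instead.
encoded = le.fit_transform(y["site"].fillna("NaN"))

print(encoded)      # [0 1 2]
print(le.classes_)  # ['A' 'B' 'NaN']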
afqinsight/datasets.py: 21 changes (20 additions & 1 deletion)
@@ -303,7 +303,7 @@ def load_afq_data(
 
         le = LabelEncoder()
         for col in label_encode_cols:
-            y.loc[:, col] = le.fit_transform(y[col])
+            y.loc[:, col] = le.fit_transform(y[col].fillna("NaN"))
             classes[col] = le.classes_
     else:
         classes = None
@@ -484,6 +484,7 @@ def __init__(
 
         self.groups = afq_data.groups
         self.feature_names = afq_data.feature_names
+        self.target_cols = target_cols
         self.group_names = afq_data.group_names
         self.subjects = afq_data.subjects
         self.sessions = afq_data.sessions
@@ -511,6 +512,24 @@ def drop_target_na(self):
             nan_mask = nan_mask.astype(int).sum(axis=1).astype(bool)
 
         nan_mask = ~nan_mask
+
+        # This nan_mask contains booleans for float NaN values
+        # But we also potentially label encoded NaNs above so we need to
+        # check for the string "NaN" in the encoded labels
+        nan_encoding = {
+            label: "NaN" in vals for label, vals in self.classes.items()
+        }
+        for label, nan_encoded in nan_encoding.items():
+            if nan_encoded:
+                encoded_value = np.where(self.classes[label] == "NaN")[0][0]
+                encoded_col = self.target_cols.index(label)
+                if len(self.y.shape) > 1:
+                    nan_mask = np.logical_and(
+                        nan_mask, self.y[:, encoded_col] != encoded_value
+                    )
+                else:
+                    nan_mask = np.logical_and(nan_mask, self.y != encoded_value)
+
         self.X = self.X[nan_mask]
         self.y = self.y[nan_mask]
         self.subjects = [sub for mask, sub in zip(nan_mask, self.subjects) if mask]
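
The new drop_target_na logic above then treats the "NaN" class as a sentinel: it looks up the integer code the string "NaN" received and masks out the matching rows, in addition to the usual float-NaN check. A standalone sketch of the same idea, with hypothetical values standing in for the self.classes, self.target_cols, and self.y attributes used above:

# Illustrative sketch only; the values below are made up for demonstration.
import numpy as np

classes = {"site": np.array(["A", "B", "NaN"])}  # from LabelEncoder.classes_
target_cols = ["age", "site"]
y = np.array([[21.0, 0], [34.0, 1], [19.0, 2]])  # "site" column is label encoded

# Start from the float-NaN mask (True = keep the row); no float NaNs here.
nan_mask = ~np.isnan(y).any(axis=1)

for label, vals in classes.items():
    if "NaN" in vals:
        # Integer code that the string "NaN" was encoded to
        encoded_value = np.where(vals == "NaN")[0][0]
        encoded_col = target_cols.index(label)
        nan_mask = np.logical_and(nan_mask, y[:, encoded_col] != encoded_value)

print(y[nan_mask])  # the row whose "site" was NaN is dropped, leaving 2 rows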
afqinsight/tests/test_datasets.py: 57 changes (54 additions & 3 deletions)
@@ -38,6 +38,60 @@ def test_standardize_subject_id():
     assert standardize_subject_id("01") == "sub-01"
 
 
+def test_afqdataset_label_encode():
+    sub_dicts = [
+        {"subject_id": "1", "age": 0, "site": "A"},
+        {"subject_id": "2", "age": 1, "site": "B"},
+        {"subject_id": "3", "age": 2},
+    ]
+    node_dicts = [
+        {"subjectID": "sub-1", "tractID": "A", "nodeID": 0, "fa": 0.1},
+        {"subjectID": "sub-1", "tractID": "A", "nodeID": 1, "fa": 0.2},
+        {"subjectID": "sub-1", "tractID": "B", "nodeID": 0, "fa": 0.3},
+        {"subjectID": "sub-1", "tractID": "B", "nodeID": 1, "fa": 0.3},
+        {"subjectID": "sub-2", "tractID": "A", "nodeID": 0, "fa": 0.4},
+        {"subjectID": "sub-2", "tractID": "A", "nodeID": 1, "fa": 0.5},
+        {"subjectID": "sub-2", "tractID": "B", "nodeID": 0, "fa": 0.6},
+        {"subjectID": "sub-2", "tractID": "B", "nodeID": 1, "fa": 0.6},
+        {"subjectID": "3", "tractID": "A", "nodeID": 0, "fa": 0.7},
+        {"subjectID": "3", "tractID": "A", "nodeID": 1, "fa": 0.8},
+        {"subjectID": "3", "tractID": "B", "nodeID": 0, "fa": 0.9},
+        {"subjectID": "3", "tractID": "B", "nodeID": 1, "fa": 0.9},
+    ]
+    subs = pd.DataFrame(sub_dicts)
+    nodes = pd.DataFrame(node_dicts)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        subs.to_csv(op.join(temp_dir, "subjects.csv"), index=False)
+        nodes.to_csv(op.join(temp_dir, "nodes.csv"), index=False)
+
+        tmp_dataset = afqi.AFQDataset(
+            fn_nodes=op.join(temp_dir, "nodes.csv"),
+            fn_subjects=op.join(temp_dir, "subjects.csv"),
+            target_cols=["site"],
+            dwi_metrics=["fa"],
+            index_col="subject_id",
+            label_encode_cols=["site"],
+        )
+
+        assert tmp_dataset.y.shape == (3,)
+        tmp_dataset.drop_target_na()
+        assert tmp_dataset.y.shape == (2,)
+
+        tmp_dataset = afqi.AFQDataset(
+            fn_nodes=op.join(temp_dir, "nodes.csv"),
+            fn_subjects=op.join(temp_dir, "subjects.csv"),
+            target_cols=["age", "site"],
+            dwi_metrics=["fa"],
+            index_col="subject_id",
+            label_encode_cols=["site"],
+        )
+
+        assert tmp_dataset.y.shape == (3, 2)
+        tmp_dataset.drop_target_na()
+        assert tmp_dataset.y.shape == (2, 2)
+
+
 def test_afqdataset_sub_prefix():
     sub_dicts = [
         {"subject_id": "1", "age": 0},
@@ -61,9 +115,6 @@ def test_afqdataset_sub_prefix():
     subs = pd.DataFrame(sub_dicts)
     nodes = pd.DataFrame(node_dicts)
 
-    import os.path as op
-    import tempfile
-
     with tempfile.TemporaryDirectory() as temp_dir:
         subs.to_csv(op.join(temp_dir, "subjects.csv"), index=False)
         nodes.to_csv(op.join(temp_dir, "nodes.csv"), index=False)
