Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ACS dataset #433

Merged
merged 4 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conduit/data/datasets/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(

if self.transform is not None:
self.x = self.transform(self.x)
if self.target_transform is not None:
self.y = self.target_transform(self.x)
if self.target_transform is not None and self.y is not None:
self.y = self.target_transform(self.y)

self.cont_indexes = cont_indexes
self.disc_indexes = disc_indexes
Expand Down
3 changes: 1 addition & 2 deletions conduit/data/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,7 @@ def __init__(
# pytorch_lightning inspects the signature and if it sees `**kwargs`, it assumes that
# the __init__ takes all arguments that DataLoader.__init__ takes, so we have to
# manually remove "collate_fn" here in order to avoid passing it in *twice*.
if "collate_fn" in kwargs:
del kwargs["collate_fn"] # type: ignore
kwargs.pop("collate_fn", None) # type: ignore
super().__init__(
dataset, # type: ignore
collate_fn=cdt_collate(cast_to_sample=cast_to_sample, converter=converter),
Expand Down
6 changes: 3 additions & 3 deletions conduit/data/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,16 +591,16 @@ def __imul__(self, value: float) -> Self:
def __mul__(self, value: float) -> Self:
copy = gcopy(self, deep=True)
copy *= value
return copy # pyright: ignore
return copy

def __idiv__(self, value: float) -> Self:
self *= 1 / value
return self # pyright: ignore
return self

def __div__(self, value: float) -> Self:
copy = gcopy(self, deep=True)
copy *= 1 / value
return copy # pyright: ignore
return copy


R_co = TypeVar("R_co", covariant=True)
Expand Down
1 change: 1 addition & 0 deletions conduit/fair/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .acs import *
from .dummy import *
from .ethicml import *
from .vision import *
188 changes: 188 additions & 0 deletions conduit/fair/data/datasets/acs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from collections.abc import Iterable
from enum import Enum, auto
from typing import TypeAlias

from folktables import (
ACSDataSource,
ACSEmployment,
ACSIncome,
BasicProblem,
generate_categories,
)
import numpy as np
import pandas as pd
from torch import Tensor

from conduit.data.datasets.tabular.base import CdtTabularDataset
from conduit.data.structures import TernarySample
from conduit.transforms.tabular import TabularTransform

__all__ = ["ACSSetting", "ACSDataset", "ACSState", "ACSHorizon", "ACSSurvey", "SurveyYear"]


class SurveyYear(Enum):
    """ACS PUMS survey years available for download.

    The string value is forwarded verbatim to ``folktables.ACSDataSource``.
    Category metadata (``generate_categories``) is only available for years
    >= 2017, so earlier years yield non-one-hot-encoded features (see
    ``ACSDataset.__init__``).
    """

    YEAR_2014 = "2014"
    YEAR_2015 = "2015"
    YEAR_2016 = "2016"
    YEAR_2017 = "2017"
    YEAR_2018 = "2018"


class ACSState(Enum):
    """Two-letter USPS codes for the 50 US states plus Puerto Rico.

    Only the member *name* is consumed downstream (``state.name`` in
    ``ACSDataset.__init__``), so the ``auto()`` values are arbitrary.
    """

    AL = auto()
    AK = auto()
    AZ = auto()
    AR = auto()
    CA = auto()
    CO = auto()
    CT = auto()
    DE = auto()
    FL = auto()
    GA = auto()
    HI = auto()
    ID = auto()
    IL = auto()
    IN = auto()
    IA = auto()
    KS = auto()
    KY = auto()
    LA = auto()
    ME = auto()
    MD = auto()
    MA = auto()
    MI = auto()
    MN = auto()
    MS = auto()
    MO = auto()
    MT = auto()
    NE = auto()
    NV = auto()
    NH = auto()
    NJ = auto()
    NM = auto()
    NY = auto()
    NC = auto()
    ND = auto()
    OH = auto()
    OK = auto()
    OR = auto()
    PA = auto()
    RI = auto()
    SC = auto()
    SD = auto()
    TN = auto()
    TX = auto()
    UT = auto()
    VT = auto()
    VA = auto()
    WA = auto()
    WV = auto()
    WI = auto()
    WY = auto()
    PR = auto()


class ACSHorizon(Enum):
    """Time horizon of the ACS sample; the value is passed to ``ACSDataSource``."""

    ONE_YEAR = "1-Year"
    FIVE_YEARS = "5-Year"


class ACSSurvey(Enum):
    """ACS survey unit (person- vs household-level); value is passed to ``ACSDataSource``."""

    PERSON = "person"
    HOUSEHOLD = "household"


class ACSSetting(Enum):
    """Prediction problems supported by :class:`ACSDataset`.

    Each member's value is a folktables problem specification (typed as
    ``BasicProblem`` in ``ACSDataset.__init__``) defining features, target,
    and group variable.
    """

    employment = ACSEmployment
    income = ACSIncome


class ACSDataset(CdtTabularDataset[TernarySample, Tensor, Tensor]):
    """Wrapper for the ACS (American Community Survey) dataset from Folktables.

    Downloads the requested PUMS sample, extracts the features, target, and
    group variable for the chosen prediction problem, and hands them to
    :class:`CdtTabularDataset`. For survey years >= 2017 categorical features
    are one-hot encoded using folktables' category metadata; for earlier years
    no such metadata exists, so features are left as-is.
    """

    # Convenience aliases so callers can write e.g. ``ACSDataset.Setting.income``.
    Setting: TypeAlias = ACSSetting
    Horizon: TypeAlias = ACSHorizon
    State: TypeAlias = ACSState
    Survey: TypeAlias = ACSSurvey
    SurveyYear: TypeAlias = SurveyYear

    def __init__(
        self,
        setting: ACSSetting,
        survey_year: SurveyYear = SurveyYear.YEAR_2018,
        horizon: Horizon = Horizon.ONE_YEAR,
        survey: ACSSurvey = ACSSurvey.PERSON,
        states: Iterable[ACSState] = (ACSState.AL,),
        transform: TabularTransform | None = None,
        target_transform: TabularTransform | None = None,
    ):
        source = ACSDataSource(
            survey_year=survey_year.value, horizon=horizon.value, survey=survey.value
        )
        state_codes = [state.name for state in states]
        raw_data = source.get_data(states=state_codes, download=True)
        problem: BasicProblem = setting.value

        if int(survey_year.value) < 2017:
            # Categorical features are *not* one-hot encoded for years < 2017,
            # since `generate_categories` is only available for years >= 2017.
            features, label, group = problem.df_to_numpy(raw_data)
            cont_indexes, disc_indexes, feature_groups = None, None, None
        else:
            categories = generate_categories(
                features=problem.features,
                definition_df=source.get_definitions(download=True),
            )
            # One-hot encode the categorical features based on the categories.
            features_df, label_df, group_df = problem.df_to_pandas(
                raw_data, categories=categories, dummies=True
            )
            feature_groups, disc_indexes = feature_groups_from_categories(
                categories, features_df
            )
            features = features_df.to_numpy(dtype=np.float32)
            label = label_df.to_numpy(dtype=np.int64)
            group = group_df.to_numpy(dtype=np.int64)
            # Every column that is not part of a one-hot group is continuous.
            cont_indexes = list(set(range(features.shape[1])) - set(disc_indexes))

        super().__init__(
            x=features,
            y=label,
            s=group,
            transform=transform,
            target_transform=target_transform,
            cont_indexes=cont_indexes,
            disc_indexes=disc_indexes,
            feature_groups=feature_groups,
        )


CategoryName: TypeAlias = str
ValueName: TypeAlias = str


def feature_groups_from_categories(
categories: dict[CategoryName, dict[float, ValueName]], features: pd.DataFrame
) -> tuple[list[slice], list[int]]:
slices: list[slice] = []
disc_indexes: list[int] = []

for category_name, value_entries in categories.items():
indexes = []
for value_name in value_entries.values():
feature_name = f"{category_name}_{value_name}"
if feature_name in features.columns:
feature_index = features.columns.get_loc(feature_name)
indexes.append(feature_index)

# Determine the slice bounds for this category.
start = min(indexes)
stop = max(indexes) + 1

# Check that the indexes are contiguous.
index_set = set(indexes)
for i in range(start, stop):
assert i in index_set

slices.append(slice(start, stop))
disc_indexes.extend(indexes)
return slices, disc_indexes
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,22 @@ hydra-core = {version = "^1.1.1", optional = true}
rich = {version = "^12.5.1", optional = true}
# fair
ethicml = {version = "^1.2.1", extras = ["data"], optional = true}
folktables = {version = "^0.0.12", optional = true}

[tool.poetry.extras]
download = ["gdown", "kaggle"]
image = ["albumentations", "opencv-python"]
audio = ["soundfile", "sox"]
hydra = ["hydra-core"]
logging = ["rich"]
fair = ["ethicml"]
fair = ["ethicml", "folktables"]
all = [
"gdown", "kaggle", # download
"albumentations", "opencv-python", # image
"soundfile", "sox", # audio
"hydra-core", # hydra
"rich", # logging
"ethicml", # fair
"ethicml", "folktables", # fair
]

[[tool.poetry.source]]
Expand Down
12 changes: 11 additions & 1 deletion tests/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SSRP,
Waterbirds,
)
from conduit.fair.data.datasets import DummyDataset
from conduit.fair.data.datasets import ACSDataset, DummyDataset


@pytest.mark.parametrize("greyscale", [True, False])
Expand Down Expand Up @@ -427,3 +427,13 @@ def test_stratified_split():

assert n_train == pytest.approx(0.45 * n_all, abs=1)
assert n_test == pytest.approx(0.55 * n_all, abs=1)


def test_acs_dataset() -> None:
    """Smoke-test ``ACSDataset`` with the income setting and default arguments
    (2018 1-Year person survey, state AL); downloads the data on first run."""
    dataset = ACSDataset(setting=ACSDataset.Setting.income)
    groups = dataset.feature_groups
    assert groups is not None
    assert groups[0] == slice(2, 10)
    expected_rows = 22_268
    assert dataset.x.shape == (expected_rows, 729)
    assert dataset.s.shape == (expected_rows,)
    assert dataset.y.shape == (expected_rows,)
    assert dataset.cont_indexes == [0, 1]
13 changes: 13 additions & 0 deletions typings/folktables/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .acs import ACSDataSource as ACSDataSource
from .acs import ACSEmployment as ACSEmployment
from .acs import ACSEmploymentFiltered as ACSEmploymentFiltered
from .acs import ACSHealthInsurance as ACSHealthInsurance
from .acs import ACSIncome as ACSIncome
from .acs import ACSIncomePovertyRatio as ACSIncomePovertyRatio
from .acs import ACSMobility as ACSMobility
from .acs import ACSPublicCoverage as ACSPublicCoverage
from .acs import ACSTravelTime as ACSTravelTime
from .folktables import BasicProblem as BasicProblem
from .folktables import DataSource as DataSource
from .folktables import Problem as Problem
from .load_acs import generate_categories as generate_categories
Loading
Loading