# Run `folktexts` benchmark with a custom ACS task

This notebook shows how to define a custom prediction task based on
American Community Survey (ACS) data and run the `folktexts` benchmark on it.

In [None]:
import folktexts
folktexts.__version__

In [None]:
from pathlib import Path

**Note:** Set the `DATA_DIR` variable to the directory on your file system that holds the ACS data. If you have no local copy, the data will be downloaded instead.

In [None]:
DATA_DIR = Path("/fast/groups/sf") / "data"
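
If the same notebook runs on several machines, the choice of path can be made explicit. The helper below is a hypothetical convenience and not part of `folktexts`: it returns the preferred directory when it exists, and otherwise creates a fallback directory (the name and location are assumptions) to hold downloaded data.

```python
from pathlib import Path


def resolve_data_dir(preferred: Path, fallback: Path) -> Path:
    """Return `preferred` if it is an existing directory, else create and return `fallback`."""
    if preferred.is_dir():
        return preferred
    # Hypothetical fallback location; downloaded files would be cached here.
    fallback.mkdir(parents=True, exist_ok=True)
    return fallback
```

For example, `DATA_DIR = resolve_data_dir(Path("/fast/groups/sf") / "data", Path.cwd() / "data")` would keep the shared path where it is available and fall back to a local `data` folder elsewhere.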

### Define the custom task objects

In this example, we'll try to predict whether someone is or ever was part of the military, based on a series of demographic features.

We must define objects of the following classes:
- `Threshold`: Used to binarize the target column (not needed if the target is already binary);
- `MultipleChoiceQA`: Defines a question and answer scheme for the target column;
- `TaskMetadata`: Defines which columns to use as features and which to use as the prediction target (below we use `ACSTaskMetadata`).

Select the target column (among those defined in `folktexts.acs.acs_columns`):

In [None]:
TARGET_COLUMN = "MIL"

Define the threshold to apply to the target column:

In [None]:
from folktexts.threshold import Threshold
target_threshold = Threshold(4, "!=")

# data["MIL"] != 4 means "Is on active duty in the military or was in the past"
# data["MIL"] == 4 means "Never served in the military"
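
To make the binarization concrete, the sketch below applies the same `!= 4` comparison to a few toy `MIL` codes in plain Python. It does not call `folktexts`; the `binarize` helper is hypothetical, and the code meanings in the comments are assumed from the ACS PUMS data dictionary (check the dictionary for your survey year).

```python
import operator

# Toy MIL codes; meanings assumed from the ACS PUMS data dictionary:
# 1 = now on active duty, 2 = on active duty in the past,
# 3 = only on active duty for training (Reserves/National Guard),
# 4 = never served. NaN stands in for a missing value.
mil_codes = [1, 2, 3, 4, float("nan")]


def binarize(values, threshold, op=operator.ne):
    """Return 1 where `op(value, threshold)` holds, else 0."""
    return [int(op(value, threshold)) for value in values]


labels = binarize(mil_codes, 4)
print(labels)  # [1, 1, 1, 0, 1]
```

Two things stand out under these assumptions: code 3 (training only) lands in the positive class, and so does a missing value, because NaN compares as "not equal" to 4. Whether the loaded data contains missing values depends on how the dataset is preprocessed, so it may be worth inspecting the target column once the dataset is loaded.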

Define question and answer interface used for prompting:

In [None]:
from folktexts.qa_interface import Choice, MultipleChoiceQA

target_column_qa = MultipleChoiceQA(
    column=target_threshold.apply_to_column_name(TARGET_COLUMN),
    text="Has this person ever served in the military?",
    choices=(
        Choice("Yes, this person is now on active duty or was in the past", 1),
        Choice("No, this person has never served in the military", 0),
    ),
)
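
Conceptually, a multiple-choice scheme pairs each answer text with an answer key shown in the prompt and with the label value it stands for. The sketch below illustrates that idea with hypothetical stand-ins (`ToyChoice`, `render_question`); it is not the `folktexts` implementation, and the actual prompt format may differ.

```python
from dataclasses import dataclass
from string import ascii_uppercase


@dataclass(frozen=True)
class ToyChoice:
    """Hypothetical stand-in for an answer option: its text and its label value."""
    text: str
    value: int


def render_question(text, choices):
    """Return the prompt text and a map from answer key (A, B, ...) to label value."""
    lines = [f"Question: {text}"]
    key_to_value = {}
    for key, choice in zip(ascii_uppercase, choices):
        lines.append(f"{key}. {choice.text}")
        key_to_value[key] = choice.value
    lines.append("Answer:")
    return "\n".join(lines), key_to_value


prompt, key_to_value = render_question(
    "Has this person ever served in the military?",
    (
        ToyChoice("Yes, this person is now on active duty or was in the past", 1),
        ToyChoice("No, this person has never served in the military", 0),
    ),
)
print(prompt)
print(key_to_value)  # {'A': 1, 'B': 0}
```

The value attached to each choice (1 or 0) should match the binarized target, so that an answer of "A" here corresponds to the positive class produced by the threshold.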

Define task metadata:

In [None]:
from folktexts.acs import ACSTaskMetadata

task = ACSTaskMetadata.make_task(
    name="ACSMilitary",
    description="predict if a person has ever served in the military",
    features=[
        "AGEP", "SCHL", "MAR", "POBP", "WKHP", "SEX", "RAC1P", "ST", "CIT", "DIS", "PINCP",
    ],
    target=TARGET_COLUMN,
    target_threshold=target_threshold,
    sensitive_attribute="RAC1P",
    multiple_choice_qa=target_column_qa,
)

Load ACS dataset for the custom task:

In [None]:
%%time
from folktexts.acs.acs_dataset import ACSDataset
acs_dataset = ACSDataset.make_from_task(
    task=task,
    cache_dir=DATA_DIR,
)

**Note:** This dataset contains all samples in the ACS PUMS file. You can
now filter `acs_dataset.data` if you want to use only a portion of the data
(e.g., only people in California, or only people aged 18 or older).

As an example, we'll filter the ACS data to contain only people aged 18 or older.

In [None]:
print(f"Original number of samples: {len(acs_dataset.data)}")

acs_dataset.data = acs_dataset.data[acs_dataset.data["AGEP"] >= 18]
print(f"Filtered number of samples: {len(acs_dataset.data)}")

**Optionally,** you can subsample the dataset for faster but noisier results (below, we keep 1% of the samples).

In [None]:
acs_dataset.subsample(0.01)
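
To get a rough sense of the speed versus noise trade-off, recall that the standard error of an accuracy estimate computed on n independent samples is approximately sqrt(p(1 - p) / n). The sketch below evaluates this for a hypothetical dataset size and a hypothetical accuracy; both numbers are illustrative assumptions, not measurements from this task.

```python
import math


def accuracy_std_error(accuracy: float, n_samples: int) -> float:
    """Binomial standard error of an accuracy estimate from `n_samples` i.i.d. samples."""
    return math.sqrt(accuracy * (1 - accuracy) / n_samples)


n_full = 1_000_000       # hypothetical number of evaluation samples
assumed_accuracy = 0.8   # hypothetical model accuracy

for fraction in (1.0, 0.1, 0.01):
    n = round(n_full * fraction)
    std_error = accuracy_std_error(assumed_accuracy, n)
    print(f"fraction={fraction:<5} n={n:>9,} std. error={std_error:.4f}")
```

Because the standard error scales with 1/sqrt(n), keeping 1% of the samples makes it about 10 times larger than on the full data.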

### Load LLM from Hugging Face or from local disk

In [None]:
MODEL_NAME_OR_PATH = "meta-llama/Meta-Llama-3-8B"

In [None]:
%%time
from folktexts.llm_utils import load_model_tokenizer
model, tokenizer = load_model_tokenizer(MODEL_NAME_OR_PATH)

Create the LLM classifier object (here, a `TransformersLLMClassifier`).

In [None]:
from folktexts.classifier import TransformersLLMClassifier

llm_clf = TransformersLLMClassifier(
    model=model,
    tokenizer=tokenizer,
    task=task,
    batch_size=20,
    context_size=600,
)

**Validation:** Print an example prompt for the custom task to check that it makes sense and to spot anything that needs adjusting.

In [None]:
X_sample, _y_sample = acs_dataset.sample_n_train_examples(n=1)
print(llm_clf.encode_row(X_sample.iloc[0], question=llm_clf.task.question))

### Run benchmark on the custom task

In [None]:
from folktexts.benchmark import Benchmark
bench = Benchmark(llm_clf=llm_clf, dataset=acs_dataset)

In [None]:
%%time
bench.run(results_root_dir=".")

In [None]:
bench.plot_results();