In [1]:
from pathlib import Path

import torch
import numpy as np
import pandas as pd

In [2]:
import logging

# Let INFO-level records (e.g. the folktexts `INFO:root:` progress messages) through the root logger.
logging.root.setLevel(logging.INFO)

## Set important local paths

Set your root directory:

In [3]:
# Root for all local paths (models, data, results); defaults to the user's home directory.
ROOT_DIR = Path.home().resolve()
# ROOT_DIR = Path("/fast/groups/sf")    # Path to cluster dir
ROOT_DIR

PosixPath('/Users/acruz')

Directory where LLMs are saved:

In [4]:
MODELS_DIR = ROOT_DIR.joinpath("huggingface-models")  # one sub-folder per model, e.g. google--gemma-2b

Directory where data is saved (or will be saved to):

In [5]:
DATA_DIR = ROOT_DIR.joinpath("data")  # ACS data cache (written on first load if not already present)

Other configuration (model, task, results location, and compute device):

In [6]:
# Hugging Face model identifier to benchmark. Other models used with this notebook:
#   "meta-llama/Meta-Llama-3-8B"
#   "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_NAME = "google/gemma-2b"

TASK_NAME = "ACSIncome"  # ACS prediction task to evaluate on

RESULTS_ROOT_DIR = ROOT_DIR.joinpath("folktexts-results")  # per-model result folders are created inside

# Pick the best available torch backend: CUDA, then Apple-silicon MPS, then CPU.
# Checking MPS explicitly avoids selecting "mps" on machines that have no such backend.
# NOTE(review): DEVICE is not passed to any call visible in this notebook; the model
# loader appears to choose its own device ("Moving model to device" log) — confirm.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

In [7]:
from folktexts.llm_utils import get_model_folder_path, load_model_tokenizer

# Map MODEL_NAME to its local folder under MODELS_DIR (e.g. google/gemma-2b -> google--gemma-2b),
# then load the weights and tokenizer from disk.
model_folder_path = get_model_folder_path(
    model_name=MODEL_NAME,
    root_dir=MODELS_DIR,
)
model, tokenizer = load_model_tokenizer(model_folder_path)

INFO:root:Loading model '/Users/acruz/huggingface-models/google--gemma-2b'
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

INFO:root:Moving model to device: mps


In [8]:
# One results folder per model, named after the local model folder (e.g. google--gemma-2b).
model_dir_name = Path(model_folder_path).name
results_dir = RESULTS_ROOT_DIR.joinpath(model_dir_name)
results_dir.mkdir(parents=True, exist_ok=True)
results_dir

PosixPath('/Users/acruz/folktexts-results/google--gemma-2b')

### Construct LLM Classifier

Load the prediction task (which maps tabular data to text):

In [9]:
from folktexts.acs import ACSTaskMetadata

# Look up the task definition registered under TASK_NAME.
task = ACSTaskMetadata.get_task(TASK_NAME)

In [10]:
from folktexts.classifier import LLMClassifier

# Wrap the loaded model/tokenizer as a classifier for `task`.
# batch_size presumably sets how many rows are scored per forward pass — confirm in folktexts.
llm_clf = LLMClassifier(
    task=task,
    model=model,
    tokenizer=tokenizer,
    batch_size=16,
)

### Load Dataset

In [11]:
%%time
from folktexts.acs import ACSDataset
dataset = ACSDataset(task=task, cache_dir=DATA_DIR)

Loading ACS data...
CPU times: user 33.8 s, sys: 8.86 s, total: 42.7 s
Wall time: 43.1 s


Optionally, subsample to quickly get approximate results:

In [12]:
# Fraction of the original data to keep; skip this cell to benchmark on the full dataset.
SUBSAMPLE_FRACTION = 0.01

# NOTE(review): this rebinds `dataset`, so re-running the cell without reloading the
# dataset (cell above) may subsample an already-subsampled dataset.
dataset = dataset.subsample(SUBSAMPLE_FRACTION)
print(f"{dataset.subsampling=}")

INFO:root:Subsampled dataset to 1.0% of the original size. Train size: 14981, Test size: 1665, Val size: 0;


dataset.subsampling=0.01


### Load and run ACS Benchmark

**_Note:_** The helper constructors `CalibrationBenchmark.make_acs_benchmark` and `CalibrationBenchmark.make_benchmark` can replace the boilerplate code above.

In [13]:
from folktexts.benchmark import BenchmarkConfig, CalibrationBenchmark

# Default benchmark settings; outputs go to a hashed sub-folder of `results_dir`.
bench_config = BenchmarkConfig.default_config()
bench = CalibrationBenchmark(
    llm_clf=llm_clf,
    dataset=dataset,
    config=bench_config,
    results_dir=results_dir,
)

INFO:root:
** Benchmark initialization **
Model: google--gemma-2b;
Task: ACSIncome;
Results dir: /Users/acruz/folktexts-results/google--gemma-2b/google--gemma-2b_bench-1971313738;
Hash: 1971313738;



Optionally, you can fit the classifier's decision threshold on a small sample of training data.

This is generally quite fast because it is _not fine-tuning_: it changes only a single parameter, `llm_clf.threshold`.

In [14]:
%%time
X_sample, y_sample = dataset.sample_n_train_examples(n=100)
llm_clf.fit(X_sample, y_sample)

Computing risk estimates:   0%|          | 0/7 [00:00<?, ?it/s]

INFO:root:Set threshold to 0.49806201550387597.


CPU times: user 1.72 s, sys: 833 ms, total: 2.55 s
Wall time: 22.2 s


Run the benchmark on the test split (this is the slowest step):

In [15]:
%%time
bench.run()

INFO:root:Test data features shape: (1665, 10)


Computing risk estimates:   0%|          | 0/105 [00:00<?, ?it/s]

INFO:root:
** Test results **
Model balanced accuracy:  64.1%;
Model accuracy:           61.4%;
Model ROC AUC :           67.3%;



CPU times: user 13.2 s, sys: 1min 9s, total: 1min 22s
Wall time: 9min 31s


0.10749160111839026

In [19]:
from pprint import pprint

bench.plot_results()
bench.save_results()

# Show only the top level of the results dict; nested entries (config, plots) print as {...}.
pprint(bench.results, depth=1)

INFO:root:Skipping group 3 plot as it's too small.
INFO:root:Skipping group 4 plot as it's too small.
INFO:root:Skipping group 5 plot as it's too small.
INFO:root:Skipping group 7 plot as it's too small.
INFO:root:Saving JSON file to '/Users/acruz/folktexts-results/google--gemma-2b/google--gemma-2b_bench-1971313738/results.bench-3131106163.json'
INFO:root:Saved experiment results to '/Users/acruz/folktexts-results/google--gemma-2b/google--gemma-2b_bench-1971313738'


{'accuracy': 0.6144144144144145,
 'accuracy_diff': 1.0,
 'accuracy_ratio': 0.0,
 'balanced_accuracy': 0.6413802870090635,
 'balanced_accuracy_diff': 0.6593359488520778,
 'balanced_accuracy_ratio': 0.0,
 'brier_score_loss': 0.24640440358741766,
 'config': {...},
 'ece': 0.10749160111839026,
 'ece_quantile': 0.11862214134564972,
 'equalized_odds_diff': 1.0,
 'equalized_odds_ratio': 0.0,
 'fnr': 0.21875,
 'fnr_diff': 1.0,
 'fnr_ratio': 0.0,
 'fpr': 0.4984894259818731,
 'fpr_diff': 1.0,
 'fpr_ratio': 0.0,
 'log_loss': 0.6859491970007683,
 'model_name': 'google--gemma-2b',
 'n_negatives': 993,
 'n_positives': 672,
 'n_samples': 1665,
 'plots': {...},
 'ppr': 0.6126126126126126,
 'ppr_diff': 1.0,
 'ppr_ratio': 0.0,
 'precision': 0.5147058823529411,
 'precision_diff': 0.65,
 'precision_ratio': 0.0,
 'predictions_path': '/Users/acruz/folktexts-results/google--gemma-2b/google--gemma-2b_bench-1971313738/ACSIncome_subsampled-0.01_seed-42_hash-1830080098.test_predictions.csv',
 'roc_auc': 0.673336

---