In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler

# Load the UCI Adult census training file (no header row).
# The raw file separates fields with ", ", so without `skipinitialspace=True` every
# categorical value keeps a leading space (' Private', ' Male'). That leaks into the
# one-hot feature names ('sex_ Male') and makes them inconsistent with 'race_Non-black'.
ADULT_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'
columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
           'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
           'hours-per-week', 'native-country', 'income']
data = pd.read_csv(ADULT_URL, header=None, names=columns, skipinitialspace=True)

# Collapse race to a binary attribute: 'Black' vs. everyone else.
data['race'] = np.where(data['race'] == 'Black', 'Black', 'Non-black')

# Convert the target to 0/1. Fail loudly on unexpected labels (e.g. the '>50K.'
# spelling used in adult.test) instead of silently mapping them to 0.
assert data['income'].isin(['>50K', '<=50K']).all(), "unexpected income labels"
data['income'] = np.where(data['income'] == '>50K', 1, 0)

# Split data into features and target
X_features = data.drop('income', axis=1)
y = data['income']

cat_cols = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
num_cols = ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']

# Decide the train/test partition *before* fitting any preprocessing, so the
# scaler's mean/std are learned from training rows only. Fitting on all rows
# leaks test-set statistics into the training features.
# train_test_split's shuffle depends only on the row count and random_state, so
# this yields the same row partition as splitting the final feature matrix.
row_ids = np.arange(len(X_features))
train_rows, test_rows = train_test_split(row_ids, test_size=0.2, random_state=42)

# One-hot encode categorical variables. The encoder is deliberately fitted on
# all rows: it only learns the category vocabulary (no statistics). This keeps
# the column set stable even for categories too rare to land in the training split.
enc = OneHotEncoder(handle_unknown='ignore')
X_cat = enc.fit_transform(X_features[cat_cols]).toarray()

# Standardize numerical variables using training-row statistics only.
scaler = StandardScaler()
scaler.fit(X_features[num_cols].iloc[train_rows])
X_num = scaler.transform(X_features[num_cols])

# Combine numeric (first) and one-hot (second) features for every row.
X = np.concatenate((X_num, X_cat), axis=1)

# Materialize the splits.
X_train, X_test = X[train_rows], X[test_rows]
y_train, y_test = y.iloc[train_rows], y.iloc[test_rows]

# Class balancing is intentionally not applied. If it is enabled (e.g. with
# RandomOverSampler), resample only X_train / y_train: resampling before the
# split duplicates rows into both the training and test sets.

# Column names for the saved matrices: numeric columns first, then the one-hot
# columns, matching the np.concatenate((X_num, X_cat)) order used to build X.
feature_names = num_cols + enc.get_feature_names_out().tolist()

# Refuse to write files whose columns and names are out of sync.
assert len(feature_names) == X_train.shape[1] == X_test.shape[1], "feature_names out of sync with X columns"

# Save training and testing data (label Series are stored as plain arrays).
for path, arr in [('X_train.npy', X_train), ('y_train.npy', y_train),
                  ('X_test.npy', X_test), ('y_test.npy', y_test)]:
    np.save(path, np.asarray(arr))

np.save("feature_names.npy", np.array(feature_names))

In [14]:
# Position of the binary race indicator among the feature columns. It is named
# rather than written as a bare 60, presumably for later use as the
# sensitive-attribute column (TODO confirm). The assert fails loudly if the
# encoding order ever changes.
RACE_FEATURE_IDX = 60
assert feature_names[RACE_FEATURE_IDX] == 'race_Non-black', "feature order changed; update RACE_FEATURE_IDX"
feature_names[RACE_FEATURE_IDX]

'race_Non-black'

In [15]:
np.shape(X_train)  # (n_training_rows, n_features)

(26048, 105)