Skip to content

Commit

Permalink
1. Add parquet dataloaders
Browse files Browse the repository at this point in the history
2. Support both npz and parquet data format for model training
3. Support using preprocessed parquet data and skip rebuiding by setting `rebuild_dataset: False`
4. Support reading CSV and Parquet files as inputs
  • Loading branch information
xpai committed Apr 23, 2024
1 parent 776268d commit e98de98
Show file tree
Hide file tree
Showing 19 changed files with 402 additions and 132 deletions.
16 changes: 13 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
## FuxiCTR Versions

### FuxiCTR v2.2
### FuxiCTR v2.3
[Doing] Add support for saving pb file, exporting embeddings
[Doing] Add support of NVTabular data
[Doing] Add support of multi-gpu training

**FuxiCTR v2.3.0, 2024-04-20**
+ [Refactor] Support data format of npz and parquet

-------------------------------

### FuxiCTR v2.2

**FuxiCTR v2.2.3, 2024-04-20**
+ [Fix] Quick fix to v2.2.2 that miss one line when committing

**FuxiCTR v2.2.2, 2024-04-18**
**FuxiCTR v2.2.2, 2024-04-18 (Deprecated)**
+ [Feature] Update to use polars instead of pandas for faster feature processing
+ [Fix] When num_workers > 1, NpzBlockDataLoader cannot keep the reading order of samples ([#86](https://github.com/xue-pai/FuxiCTR/issues/86))

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
</div>

<div align="center">
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/python-3.6+-blue" style="max-width: 100%;" alt="Python version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/python-3.9+-blue" style="max-width: 100%;" alt="Python version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/pytorch-1.10+-blue" style="max-width: 100%;" alt="Pytorch version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/badge/tensorflow-2.1+-blue" style="max-width: 100%;" alt="Pytorch version"></a>
<a href="https://pypi.org/project/fuxictr"><img src="https://img.shields.io/pypi/v/fuxictr.svg" style="max-width: 100%;" alt="Pypi version"></a>
Expand Down Expand Up @@ -102,7 +102,7 @@ We have benchmarked FuxiCTR models on a set of open datasets as follows:

FuxiCTR has the following dependencies:

+ python 3.6+
+ python 3.9+
+ pytorch 1.10+ (required only for Torch models)
+ tensorflow 2.1+ (required only for TF models)

Expand Down
3 changes: 2 additions & 1 deletion experiment/run_expid.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@

data_dir = os.path.join(params['data_root'], params['dataset_id'])
feature_map_json = os.path.join(data_dir, "feature_map.json")
if params["data_format"] == "csv":
if params["data_format"] in ["csv", "parquet"]:
# Build feature_map and transform data
feature_encoder = FeatureProcessor(**params)
params["train_data"], params["valid_data"], params["test_data"] = \
build_dataset(feature_encoder, **params)
params["data_format"] = "parquet"
feature_map = FeatureMap(params['dataset_id'], data_dir)
feature_map.load(feature_map_json, params)
logging.info("Feature specs: " + print_to_json(feature_map.features))
Expand Down
4 changes: 3 additions & 1 deletion fuxictr/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def enumerate_params(config_file, exclude_expid=[]):
dataset_para_combs = dict()
for idx, values in enumerate(itertools.product(*map(dataset_dict.get, dataset_para_keys))):
dataset_params = dict(zip(dataset_para_keys, values))
if dataset_params["data_format"] == "npz":
if (dataset_params["data_format"] == "npz" or
(dataset_params["data_format"] == "parquet" and
dataset_params["rebuild_dataset"] == False)):
dataset_para_combs[dataset_id] = dataset_params
else:
hash_id = hashlib.md5("".join(sorted(print_to_json(dataset_params))).encode("utf-8")).hexdigest()[0:8]
Expand Down
2 changes: 1 addition & 1 deletion fuxictr/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def load(self, json_file, params):
feature_map = json.load(fd) #, object_pairs_hook=OrderedDict
if feature_map["dataset_id"] != self.dataset_id:
raise RuntimeError("dataset_id={} does not match feature_map!".format(self.dataset_id))
self.num_fields = feature_map["num_fields"]
self.labels = feature_map.get("labels", [])
self.total_features = feature_map.get("total_features", 0)
self.input_length = feature_map.get("input_length", 0)
self.group_id = feature_map.get("group_id", None)
self.default_emb_dim = params.get("embedding_dim", None)
self.features = OrderedDict((k, v) for x in feature_map["features"] for k, v in x.items())
self.num_fields = self.get_num_fields()
if params.get("use_features", None):
self.features = OrderedDict((x, self.features[x]) for x in params["use_features"])
if params.get("feature_specs", None):
Expand Down
53 changes: 27 additions & 26 deletions fuxictr/preprocess/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,12 @@ def split_train_test(train_ddf=None, valid_ddf=None, test_ddf=None, valid_size=0
return train_ddf, valid_ddf, test_ddf


def save_npz(darray_dict, data_path):
logging.info("Saving data to npz: " + data_path)
os.makedirs(os.path.dirname(data_path), exist_ok=True)
np.savez(data_path, **darray_dict)


def transform_block(feature_encoder, df_block, filename):
darray_dict = feature_encoder.transform(df_block)
save_npz(darray_dict, os.path.join(feature_encoder.data_dir, filename))
df_block = feature_encoder.transform(df_block)
data_path = os.path.join(feature_encoder.data_dir, filename)
logging.info("Saving data to parquet: " + data_path)
os.makedirs(os.path.dirname(data_path), exist_ok=True)
df_block.to_parquet(data_path, index=False, engine="pyarrow")


def transform(feature_encoder, ddf, filename, block_size=0):
Expand All @@ -68,9 +65,8 @@ def transform(feature_encoder, ddf, filename, block_size=0):
df_block = ddf.iloc[idx:(idx + block_size)]
pool.apply_async(
transform_block,
args=(feature_encoder,
df_block,
'{}/part_{:05d}.npz'.format(filename, block_id))
args=(feature_encoder, df_block,
'{}/part_{:05d}.parquet'.format(filename, block_id))
)
block_id += 1
pool.close()
Expand All @@ -79,37 +75,39 @@ def transform(feature_encoder, ddf, filename, block_size=0):
transform_block(feature_encoder, ddf, filename)


def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=None, valid_size=0,
test_size=0, split_type="sequential", data_block_size=0, **kwargs):
def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=None,
valid_size=0, test_size=0, split_type="sequential", data_block_size=0,
rebuild_dataset=True, **kwargs):
""" Build feature_map and transform data """

feature_map_json = os.path.join(feature_encoder.data_dir, "feature_map.json")
if os.path.exists(feature_map_json):
logging.warn("Skip rebuilding {}. Please delete it manually if rebuilding is required." \
.format(feature_map_json))
else:
# Load csv data
train_ddf = feature_encoder.read_csv(train_data, **kwargs)
feature_map_path = os.path.join(feature_encoder.data_dir, "feature_map.json")
if os.path.exists(feature_map_path):
logging.warn(f"Skip rebuilding {feature_map_path}. "
+ "Please delete it manually if rebuilding is required.")

elif rebuild_dataset:
# Load data files
train_ddf = feature_encoder.read_data(train_data, **kwargs)
valid_ddf = None
test_ddf = None

# Split data for train/validation/test
if valid_size > 0 or test_size > 0:
valid_ddf = feature_encoder.read_csv(valid_data, **kwargs)
test_ddf = feature_encoder.read_csv(test_data, **kwargs)
valid_ddf = feature_encoder.read_data(valid_data, **kwargs)
test_ddf = feature_encoder.read_data(test_data, **kwargs)
# TODO: check split_train_test in lazy mode
train_ddf, valid_ddf, test_ddf = split_train_test(train_ddf, valid_ddf, test_ddf,
valid_size, test_size, split_type)

# fit and transform train_ddf
train_ddf = feature_encoder.preprocess(train_ddf)
feature_encoder.fit(train_ddf, **kwargs)
feature_encoder.fit(train_ddf, rebuild_dataset=True, **kwargs)
transform(feature_encoder, train_ddf, 'train', block_size=data_block_size)
del train_ddf
gc.collect()

# Transfrom valid_ddf
if valid_ddf is None and (valid_data is not None):
valid_ddf = feature_encoder.read_csv(valid_data, **kwargs)
valid_ddf = feature_encoder.read_data(valid_data, **kwargs)
if valid_ddf is not None:
valid_ddf = feature_encoder.preprocess(valid_ddf)
transform(feature_encoder, valid_ddf, 'valid', block_size=data_block_size)
Expand All @@ -118,14 +116,17 @@ def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=N

# Transfrom test_ddf
if test_ddf is None and (test_data is not None):
test_ddf = feature_encoder.read_csv(test_data, **kwargs)
test_ddf = feature_encoder.read_data(test_data, **kwargs)
if test_ddf is not None:
test_ddf = feature_encoder.preprocess(test_ddf)
transform(feature_encoder, test_ddf, 'test', block_size=data_block_size)
del test_ddf
gc.collect()
logging.info("Transform csv data to npz done.")

else: # skip rebuilding data but only compute feature_map.json
feature_encoder.fit(train_ddf=None, rebuild_dataset=False, **kwargs)

# Return processed data splits
return os.path.join(feature_encoder.data_dir, "train"), \
os.path.join(feature_encoder.data_dir, "valid"), \
Expand Down
80 changes: 49 additions & 31 deletions fuxictr/preprocess/feature_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,18 @@ def _complete_feature_cols(self, feature_cols):
full_feature_cols.append(col)
return full_feature_cols

def read_csv(self, data_path, sep=",", n_rows=None, **kwargs):
logging.info("Reading file: " + data_path)
def read_data(self, data_path, data_format="csv", sep=",", n_rows=None, **kwargs):
if not data_path.endswith(data_format):
data_path = os.path.join(data_path, "*.{data_format}")
logging.info("Reading files: " + data_path)
file_names = sorted(glob.glob(data_path))
assert len(file_names) > 0, f"Invalid data path: {data_path}"
# Require python >= 3.8 for use polars to scan multiple csv files
file_names = file_names[0]
ddf = pl.scan_csv(source=file_names, separator=sep, dtypes=self.dtype_dict,
low_memory=False, n_rows=n_rows)
dfs = [
pl.scan_csv(source=file_name, separator=sep, dtypes=self.dtype_dict,
low_memory=False, n_rows=n_rows)
for file_name in file_names
]
ddf = pl.concat(dfs)
return ddf

def preprocess(self, ddf):
Expand All @@ -95,14 +99,18 @@ def preprocess(self, ddf):
ddf = ddf.select(active_cols)
return ddf

def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs):
def fit(self, train_ddf, min_categr_count=1, num_buckets=10, rebuild_dataset=True, **kwargs):
logging.info("Fit feature processor...")
self.rebuild_dataset = rebuild_dataset
for col in self.feature_cols:
name = col["name"]
if col["active"]:
logging.info("Processing column: {}".format(col))
col_series = train_ddf.select(name).collect().to_series().to_pandas()
if col["type"] == "meta": # e.g. group_id
col_series = (
train_ddf.select(name).collect().to_series().to_pandas() if self.rebuild_dataset
else None
)
if col["type"] == "meta": # e.g. set group_id in gAUC
self.fit_meta_col(col)
elif col["type"] == "numeric":
self.fit_numeric_col(col, col_series)
Expand Down Expand Up @@ -154,9 +162,9 @@ def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs):

self.feature_map.num_fields = self.feature_map.get_num_fields()
self.feature_map.set_column_index()
self.feature_map.save(self.json_file)
self.save_pickle(self.pickle_file)
self.save_vocab(self.vocab_file)
self.feature_map.save(self.json_file)
logging.info("Set feature processor done.")

def fit_meta_col(self, col):
Expand All @@ -178,7 +186,8 @@ def fit_numeric_col(self, col, col_series):
self.feature_map.features[name]["feature_encoder"] = col["feature_encoder"]
if "normalizer" in col:
normalizer = Normalizer(col["normalizer"])
normalizer.fit(col_series.dropna().values)
if self.rebuild_dataset:
normalizer.fit(col_series.dropna().values)
self.processor_dict[name + "::normalizer"] = normalizer

def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=10):
Expand All @@ -196,9 +205,15 @@ def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=1
self.feature_map.features[name]["emb_output_dim"] = col["emb_output_dim"]
if "category_processor" not in col:
tokenizer = Tokenizer(min_freq=min_categr_count,
na_value=col.get("fill_na", ""),
na_value=col.get("fill_na", ""),
remap=col.get("remap", True))
tokenizer.fit_on_texts(col_series)
if self.rebuild_dataset:
tokenizer.fit_on_texts(col_series)
else:
if "vocab_size" in col:
tokenizer.update_vocab(range(col["vocab_size"] - 1))
else:
raise ValueError(f"{name}: vocab_size is required when rebuild_dataset=False")
if "share_embedding" in col:
self.feature_map.features[name]["share_embedding"] = col["share_embedding"]
tknzr_name = col["share_embedding"] + "::tokenizer"
Expand All @@ -217,10 +232,11 @@ def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=1
if category_processor == "quantile_bucket": # transform numeric value to bucket
num_buckets = col.get("num_buckets", num_buckets)
qtf = sklearn_preprocess.QuantileTransformer(n_quantiles=num_buckets + 1)
qtf.fit(col_series.values)
boundaries = qtf.quantiles_[1:-1]
if self.rebuild_dataset:
qtf.fit(col_series.values)
boundaries = qtf.quantiles_[1:-1]
self.processor_dict[name + "::boundaries"] = boundaries
self.feature_map.features[name]["vocab_size"] = num_buckets
self.processor_dict[name + "::boundaries"] = boundaries
elif category_processor == "hash_bucket":
num_buckets = col.get("num_buckets", num_buckets)
self.feature_map.features[name]["vocab_size"] = num_buckets
Expand Down Expand Up @@ -249,7 +265,13 @@ def fit_sequence_col(self, col, col_series, min_categr_count=1):
tokenizer = Tokenizer(min_freq=min_categr_count, splitter=splitter,
na_value=na_value, max_len=max_len, padding=padding,
remap=col.get("remap", True))
tokenizer.fit_on_texts(col_series)
if self.rebuild_dataset:
tokenizer.fit_on_texts(col_series)
else:
if "vocab_size" in col:
tokenizer.update_vocab(range(col["vocab_size"] - 1))
else:
raise ValueError(f"{name}: vocab_size is required when rebuild_dataset=False")
if "share_embedding" in col:
self.feature_map.features[name]["share_embedding"] = col["share_embedding"]
tknzr_name = col["share_embedding"] + "::tokenizer"
Expand All @@ -265,40 +287,36 @@ def fit_sequence_col(self, col, col_series, min_categr_count=1):
"vocab_size": tokenizer.vocab_size()})

def transform(self, ddf):
logging.info("Transform feature columns with ID mapping...")
data_dict = dict()
logging.info("Transform feature columns to IDs...")
for feature, feature_spec in self.feature_map.features.items():
if feature in ddf.columns:
feature_type = feature_spec["type"]
col_series = ddf[feature]
if feature_type == "meta":
if feature + "::tokenizer" in self.processor_dict:
tokenizer = self.processor_dict[feature + "::tokenizer"]
data_dict[feature] = tokenizer.encode_meta(col_series)
ddf[feature] = tokenizer.encode_meta(col_series)
# Update vocab in tokenizer
self.processor_dict[feature + "::tokenizer"] = tokenizer
else:
data_dict[feature] = col_series.values
elif feature_type == "numeric":
col_values = col_series.values
normalizer = self.processor_dict.get(feature + "::normalizer")
if normalizer:
col_values = normalizer.transform(col_values)
data_dict[feature] = col_values
ddf[feature] = normalizer.transform(col_series.values)
elif feature_type == "categorical":
category_processor = feature_spec.get("category_processor")
if category_processor is None:
data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_category(col_series)
ddf[feature] = (
self.processor_dict.get(feature + "::tokenizer")
.encode_category(col_series)
)
elif category_processor == "numeric_bucket":
raise NotImplementedError
elif category_processor == "hash_bucket":
raise NotImplementedError
elif feature_type == "sequence":
data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_sequence(col_series)
for label in self.feature_map.labels:
if label in ddf.columns:
data_dict[label] = ddf[label].values
return data_dict
ddf[feature] = (self.processor_dict.get(feature + "::tokenizer")
.encode_sequence(col_series))
return ddf

def load_pickle(self, pickle_file=None):
""" Load feature processor from cache """
Expand Down

0 comments on commit e98de98

Please sign in to comment.