diff --git a/CHANGELOG.md b/CHANGELOG.md index 74ebf69..be05abc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,20 @@ ## FuxiCTR Versions -### FuxiCTR v2.2 +### FuxiCTR v2.3 [Doing] Add support for saving pb file, exporting embeddings -[Doing] Add support of NVTabular data +[Doing] Add support of multi-gpu training + +**FuxiCTR v2.3.0, 2024-04-20** ++ [Refactor] Support data format of npz and parquet + +------------------------------- + +### FuxiCTR v2.2 + +**FuxiCTR v2.2.3, 2024-04-20** ++ [Fix] Quick fix to v2.2.2 that miss one line when committing -**FuxiCTR v2.2.2, 2024-04-18** +**FuxiCTR v2.2.2, 2024-04-18 (Deprecated)** + [Feature] Update to use polars instead of pandas for faster feature processing + [Fix] When num_workers > 1, NpzBlockDataLoader cannot keep the reading order of samples ([#86](https://github.com/xue-pai/FuxiCTR/issues/86)) diff --git a/README.md b/README.md index 0b809bc..7fa7733 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
-Python version +Python version Pytorch version Pytorch version Pypi version @@ -102,7 +102,7 @@ We have benchmarked FuxiCTR models on a set of open datasets as follows: FuxiCTR has the following dependencies: -+ python 3.6+ ++ python 3.9+ + pytorch 1.10+ (required only for Torch models) + tensorflow 2.1+ (required only for TF models) diff --git a/experiment/run_expid.py b/experiment/run_expid.py index 7acff75..f42cfaf 100644 --- a/experiment/run_expid.py +++ b/experiment/run_expid.py @@ -52,11 +52,12 @@ data_dir = os.path.join(params['data_root'], params['dataset_id']) feature_map_json = os.path.join(data_dir, "feature_map.json") - if params["data_format"] == "csv": + if params["data_format"] in ["csv", "parquet"]: # Build feature_map and transform data feature_encoder = FeatureProcessor(**params) params["train_data"], params["valid_data"], params["test_data"] = \ build_dataset(feature_encoder, **params) + params["data_format"] = "parquet" feature_map = FeatureMap(params['dataset_id'], data_dir) feature_map.load(feature_map_json, params) logging.info("Feature specs: " + print_to_json(feature_map.features)) diff --git a/fuxictr/autotuner.py b/fuxictr/autotuner.py index 5e4ba85..fe0ede0 100644 --- a/fuxictr/autotuner.py +++ b/fuxictr/autotuner.py @@ -69,7 +69,9 @@ def enumerate_params(config_file, exclude_expid=[]): dataset_para_combs = dict() for idx, values in enumerate(itertools.product(*map(dataset_dict.get, dataset_para_keys))): dataset_params = dict(zip(dataset_para_keys, values)) - if dataset_params["data_format"] == "npz": + if (dataset_params["data_format"] == "npz" or + (dataset_params["data_format"] == "parquet" and + dataset_params["rebuild_dataset"] == False)): dataset_para_combs[dataset_id] = dataset_params else: hash_id = hashlib.md5("".join(sorted(print_to_json(dataset_params))).encode("utf-8")).hexdigest()[0:8] diff --git a/fuxictr/features.py b/fuxictr/features.py index a98a48f..7726c11 100644 --- a/fuxictr/features.py +++ b/fuxictr/features.py @@ -41,13 +41,13 @@ def load(self, json_file, params): feature_map = json.load(fd) #, object_pairs_hook=OrderedDict if feature_map["dataset_id"] != self.dataset_id: raise RuntimeError("dataset_id={} does not match feature_map!".format(self.dataset_id)) - self.num_fields = feature_map["num_fields"] self.labels = feature_map.get("labels", []) self.total_features = feature_map.get("total_features", 0) self.input_length = feature_map.get("input_length", 0) self.group_id = feature_map.get("group_id", None) self.default_emb_dim = params.get("embedding_dim", None) self.features = OrderedDict((k, v) for x in feature_map["features"] for k, v in x.items()) + self.num_fields = self.get_num_fields() if params.get("use_features", None): self.features = OrderedDict((x, self.features[x]) for x in params["use_features"]) if params.get("feature_specs", None): diff --git a/fuxictr/preprocess/build_dataset.py b/fuxictr/preprocess/build_dataset.py index 07a9d65..8578908 100644 --- a/fuxictr/preprocess/build_dataset.py +++ b/fuxictr/preprocess/build_dataset.py @@ -48,15 +48,12 @@ def split_train_test(train_ddf=None, valid_ddf=None, test_ddf=None, valid_size=0 return train_ddf, valid_ddf, test_ddf -def save_npz(darray_dict, data_path): - logging.info("Saving data to npz: " + data_path) - os.makedirs(os.path.dirname(data_path), exist_ok=True) - np.savez(data_path, **darray_dict) - - def transform_block(feature_encoder, df_block, filename): - darray_dict = feature_encoder.transform(df_block) - save_npz(darray_dict, os.path.join(feature_encoder.data_dir, filename)) + df_block = feature_encoder.transform(df_block) + data_path = os.path.join(feature_encoder.data_dir, filename) + logging.info("Saving data to parquet: " + data_path) + os.makedirs(os.path.dirname(data_path), exist_ok=True) + df_block.to_parquet(data_path, index=False, engine="pyarrow") def transform(feature_encoder, ddf, filename, block_size=0): @@ -68,9 +65,8 @@ def transform(feature_encoder, ddf, filename, block_size=0): df_block = ddf.iloc[idx:(idx + block_size)] pool.apply_async( transform_block, - args=(feature_encoder, - df_block, - '{}/part_{:05d}.npz'.format(filename, block_id)) + args=(feature_encoder, df_block, + '{}/part_{:05d}.parquet'.format(filename, block_id)) ) block_id += 1 pool.close() @@ -79,37 +75,39 @@ def transform(feature_encoder, ddf, filename, block_size=0): transform_block(feature_encoder, ddf, filename) -def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=None, valid_size=0, - test_size=0, split_type="sequential", data_block_size=0, **kwargs): +def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=None, + valid_size=0, test_size=0, split_type="sequential", data_block_size=0, + rebuild_dataset=True, **kwargs): """ Build feature_map and transform data """ - - feature_map_json = os.path.join(feature_encoder.data_dir, "feature_map.json") - if os.path.exists(feature_map_json): - logging.warn("Skip rebuilding {}. Please delete it manually if rebuilding is required." \ - .format(feature_map_json)) - else: - # Load csv data - train_ddf = feature_encoder.read_csv(train_data, **kwargs) + feature_map_path = os.path.join(feature_encoder.data_dir, "feature_map.json") + if os.path.exists(feature_map_path): + logging.warn(f"Skip rebuilding {feature_map_path}. " + + "Please delete it manually if rebuilding is required.") + + elif rebuild_dataset: + # Load data files + train_ddf = feature_encoder.read_data(train_data, **kwargs) valid_ddf = None test_ddf = None # Split data for train/validation/test if valid_size > 0 or test_size > 0: - valid_ddf = feature_encoder.read_csv(valid_data, **kwargs) - test_ddf = feature_encoder.read_csv(test_data, **kwargs) + valid_ddf = feature_encoder.read_data(valid_data, **kwargs) + test_ddf = feature_encoder.read_data(test_data, **kwargs) + # TODO: check split_train_test in lazy mode train_ddf, valid_ddf, test_ddf = split_train_test(train_ddf, valid_ddf, test_ddf, valid_size, test_size, split_type) # fit and transform train_ddf train_ddf = feature_encoder.preprocess(train_ddf) - feature_encoder.fit(train_ddf, **kwargs) + feature_encoder.fit(train_ddf, rebuild_dataset=True, **kwargs) transform(feature_encoder, train_ddf, 'train', block_size=data_block_size) del train_ddf gc.collect() # Transfrom valid_ddf if valid_ddf is None and (valid_data is not None): - valid_ddf = feature_encoder.read_csv(valid_data, **kwargs) + valid_ddf = feature_encoder.read_data(valid_data, **kwargs) if valid_ddf is not None: valid_ddf = feature_encoder.preprocess(valid_ddf) transform(feature_encoder, valid_ddf, 'valid', block_size=data_block_size) @@ -118,7 +116,7 @@ def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=N # Transfrom test_ddf if test_ddf is None and (test_data is not None): - test_ddf = feature_encoder.read_csv(test_data, **kwargs) + test_ddf = feature_encoder.read_data(test_data, **kwargs) if test_ddf is not None: test_ddf = feature_encoder.preprocess(test_ddf) transform(feature_encoder, test_ddf, 'test', block_size=data_block_size) @@ -126,6 +124,9 @@ def build_dataset(feature_encoder, train_data=None, valid_data=None, test_data=N gc.collect() logging.info("Transform csv data to npz done.") + else: # skip rebuilding data but only compute feature_map.json + feature_encoder.fit(train_ddf=None, rebuild_dataset=False, **kwargs) + # Return processed data splits return os.path.join(feature_encoder.data_dir, "train"), \ os.path.join(feature_encoder.data_dir, "valid"), \ diff --git a/fuxictr/preprocess/feature_processor.py b/fuxictr/preprocess/feature_processor.py index 8072c19..1e8498b 100644 --- a/fuxictr/preprocess/feature_processor.py +++ b/fuxictr/preprocess/feature_processor.py @@ -67,14 +67,18 @@ def _complete_feature_cols(self, feature_cols): full_feature_cols.append(col) return full_feature_cols - def read_csv(self, data_path, sep=",", n_rows=None, **kwargs): - logging.info("Reading file: " + data_path) + def read_data(self, data_path, data_format="csv", sep=",", n_rows=None, **kwargs): + if not data_path.endswith(data_format): + data_path = os.path.join(data_path, "*.{data_format}") + logging.info("Reading files: " + data_path) file_names = sorted(glob.glob(data_path)) assert len(file_names) > 0, f"Invalid data path: {data_path}" - # Require python >= 3.8 for use polars to scan multiple csv files - file_names = file_names[0] - ddf = pl.scan_csv(source=file_names, separator=sep, dtypes=self.dtype_dict, - low_memory=False, n_rows=n_rows) + dfs = [ + pl.scan_csv(source=file_name, separator=sep, dtypes=self.dtype_dict, + low_memory=False, n_rows=n_rows) + for file_name in file_names + ] + ddf = pl.concat(dfs) return ddf def preprocess(self, ddf): @@ -95,14 +99,18 @@ def preprocess(self, ddf): ddf = ddf.select(active_cols) return ddf - def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs): + def fit(self, train_ddf, min_categr_count=1, num_buckets=10, rebuild_dataset=True, **kwargs): logging.info("Fit feature processor...") + self.rebuild_dataset = rebuild_dataset for col in self.feature_cols: name = col["name"] if col["active"]: logging.info("Processing column: {}".format(col)) - col_series = train_ddf.select(name).collect().to_series().to_pandas() - if col["type"] == "meta": # e.g. group_id + col_series = ( + train_ddf.select(name).collect().to_series().to_pandas() if self.rebuild_dataset + else None + ) + if col["type"] == "meta": # e.g. set group_id in gAUC self.fit_meta_col(col) elif col["type"] == "numeric": self.fit_numeric_col(col, col_series) @@ -154,9 +162,9 @@ def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs): self.feature_map.num_fields = self.feature_map.get_num_fields() self.feature_map.set_column_index() + self.feature_map.save(self.json_file) self.save_pickle(self.pickle_file) self.save_vocab(self.vocab_file) - self.feature_map.save(self.json_file) logging.info("Set feature processor done.") def fit_meta_col(self, col): @@ -178,7 +186,8 @@ def fit_numeric_col(self, col, col_series): self.feature_map.features[name]["feature_encoder"] = col["feature_encoder"] if "normalizer" in col: normalizer = Normalizer(col["normalizer"]) - normalizer.fit(col_series.dropna().values) + if self.rebuild_dataset: + normalizer.fit(col_series.dropna().values) self.processor_dict[name + "::normalizer"] = normalizer def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=10): @@ -196,9 +205,15 @@ def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=1 self.feature_map.features[name]["emb_output_dim"] = col["emb_output_dim"] if "category_processor" not in col: tokenizer = Tokenizer(min_freq=min_categr_count, - na_value=col.get("fill_na", ""), + na_value=col.get("fill_na", ""), remap=col.get("remap", True)) - tokenizer.fit_on_texts(col_series) + if self.rebuild_dataset: + tokenizer.fit_on_texts(col_series) + else: + if "vocab_size" in col: + tokenizer.update_vocab(range(col["vocab_size"] - 1)) + else: + raise ValueError(f"{name}: vocab_size is required when rebuild_dataset=False") if "share_embedding" in col: self.feature_map.features[name]["share_embedding"] = col["share_embedding"] tknzr_name = col["share_embedding"] + "::tokenizer" @@ -217,10 +232,11 @@ def fit_categorical_col(self, col, col_series, min_categr_count=1, num_buckets=1 if category_processor == "quantile_bucket": # transform numeric value to bucket num_buckets = col.get("num_buckets", num_buckets) qtf = sklearn_preprocess.QuantileTransformer(n_quantiles=num_buckets + 1) - qtf.fit(col_series.values) - boundaries = qtf.quantiles_[1:-1] + if self.rebuild_dataset: + qtf.fit(col_series.values) + boundaries = qtf.quantiles_[1:-1] + self.processor_dict[name + "::boundaries"] = boundaries self.feature_map.features[name]["vocab_size"] = num_buckets - self.processor_dict[name + "::boundaries"] = boundaries elif category_processor == "hash_bucket": num_buckets = col.get("num_buckets", num_buckets) self.feature_map.features[name]["vocab_size"] = num_buckets @@ -249,7 +265,13 @@ def fit_sequence_col(self, col, col_series, min_categr_count=1): tokenizer = Tokenizer(min_freq=min_categr_count, splitter=splitter, na_value=na_value, max_len=max_len, padding=padding, remap=col.get("remap", True)) - tokenizer.fit_on_texts(col_series) + if self.rebuild_dataset: + tokenizer.fit_on_texts(col_series) + else: + if "vocab_size" in col: + tokenizer.update_vocab(range(col["vocab_size"] - 1)) + else: + raise ValueError(f"{name}: vocab_size is required when rebuild_dataset=False") if "share_embedding" in col: self.feature_map.features[name]["share_embedding"] = col["share_embedding"] tknzr_name = col["share_embedding"] + "::tokenizer" @@ -265,8 +287,7 @@ def fit_sequence_col(self, col, col_series, min_categr_count=1): "vocab_size": tokenizer.vocab_size()}) def transform(self, ddf): - logging.info("Transform feature columns with ID mapping...") - data_dict = dict() + logging.info("Transform feature columns to IDs...") for feature, feature_spec in self.feature_map.features.items(): if feature in ddf.columns: feature_type = feature_spec["type"] @@ -274,31 +295,28 @@ def transform(self, ddf): if feature_type == "meta": if feature + "::tokenizer" in self.processor_dict: tokenizer = self.processor_dict[feature + "::tokenizer"] - data_dict[feature] = tokenizer.encode_meta(col_series) + ddf[feature] = tokenizer.encode_meta(col_series) # Update vocab in tokenizer self.processor_dict[feature + "::tokenizer"] = tokenizer - else: - data_dict[feature] = col_series.values elif feature_type == "numeric": - col_values = col_series.values normalizer = self.processor_dict.get(feature + "::normalizer") if normalizer: - col_values = normalizer.transform(col_values) - data_dict[feature] = col_values + ddf[feature] = normalizer.transform(col_series.values) elif feature_type == "categorical": category_processor = feature_spec.get("category_processor") if category_processor is None: - data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_category(col_series) + ddf[feature] = ( + self.processor_dict.get(feature + "::tokenizer") + .encode_category(col_series) + ) elif category_processor == "numeric_bucket": raise NotImplementedError elif category_processor == "hash_bucket": raise NotImplementedError elif feature_type == "sequence": - data_dict[feature] = self.processor_dict.get(feature + "::tokenizer").encode_sequence(col_series) - for label in self.feature_map.labels: - if label in ddf.columns: - data_dict[label] = ddf[label].values - return data_dict + ddf[feature] = (self.processor_dict.get(feature + "::tokenizer") + .encode_sequence(col_series)) + return ddf def load_pickle(self, pickle_file=None): """ Load feature processor from cache """ diff --git a/fuxictr/preprocess/tokenizer.py b/fuxictr/preprocess/tokenizer.py index c23bd59..a01beeb 100644 --- a/fuxictr/preprocess/tokenizer.py +++ b/fuxictr/preprocess/tokenizer.py @@ -22,6 +22,7 @@ from keras_preprocessing.sequence import pad_sequences from concurrent.futures import ProcessPoolExecutor, as_completed import multiprocessing as mp +from ..utils import load_pretrain_emb class Tokenizer(object): @@ -96,7 +97,7 @@ def update_vocab(self, word_list): new_words = 0 for word in word_list: if word not in self.vocab: - self.vocab[word] = self.vocab["__OOV__"] + new_words + self.vocab[word] = self.vocab.get("__OOV__", 0) + new_words new_words += 1 if new_words > 0: self.vocab["__OOV__"] = self.vocab_size() @@ -122,16 +123,12 @@ def encode_sequence(self, series): seqs = pad_sequences(series.to_list(), maxlen=self.max_len, value=self.vocab["__PAD__"], padding=self.padding, truncating=self.padding) - return np.array(seqs) + return seqs.tolist() def load_pretrained_vocab(self, feature_dtype, pretrain_path, expand_vocab=True): - if pretrain_path.endswith(".h5"): - with h5py.File(pretrain_path, 'r') as hf: - keys = hf["key"][:] - # in case mismatch of dtype between int and str - keys = keys.astype(feature_dtype) - elif pretrain_path.endswith(".npz"): - keys = np.load(pretrain_path)["key"] + keys = load_pretrain_emb(pretrain_path, keys=["key"]) + # in case mismatch of dtype between int and str + keys = keys.astype(feature_dtype) # Update vocab with pretrained keys in case new tokens appear in validation or test set # Do NOT update OOV index here since it is used in PretrainedEmbedding if expand_vocab: diff --git a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py index eb81592..2f57061 100644 --- a/fuxictr/pytorch/dataloaders/npz_block_dataloader.py +++ b/fuxictr/pytorch/dataloaders/npz_block_dataloader.py @@ -18,15 +18,16 @@ import numpy as np from itertools import chain -import torch -from torch.utils import data +from torch.utils.data.dataloader import default_collate +from torch.utils.data import IterDataPipe, DataLoader, get_worker_info import glob +import os -class BlockDataPipe(data.IterDataPipe): - def __init__(self, block_datapipe, feature_map): +class NpzIterDataPipe(IterDataPipe): + def __init__(self, data_blocks, feature_map): self.feature_map = feature_map - self.block_datapipe = block_datapipe + self.data_blocks = data_blocks def load_data(self, data_path): data_dict = np.load(data_path) @@ -38,8 +39,7 @@ def load_data(self, data_path): data_arrays.append(array.reshape(-1, 1)) else: data_arrays.append(array) - data_tensor = torch.from_numpy(np.hstack(data_arrays)) - return data_tensor + return np.hstack(data_arrays) def read_block(self, data_block): darray = self.load_data(data_block) @@ -47,37 +47,39 @@ def read_block(self, data_block): yield darray[idx, :] def __iter__(self): - worker_info = data.get_worker_info() + worker_info = get_worker_info() if worker_info is None: # single-process data loading - block_list = self.block_datapipe + block_list = self.data_blocks else: # in a worker process block_list = [ block - for idx, block in enumerate(self.block_datapipe) + for idx, block in enumerate(self.data_blocks) if idx % worker_info.num_workers == worker_info.id ] return chain.from_iterable(map(self.read_block, block_list)) -class NpzBlockDataLoader(data.DataLoader): - def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, +class NpzBlockDataLoader(DataLoader): + def __init__(self, feature_map, data_path, split="train", batch_size=32, shuffle=False, num_workers=1, buffer_size=100000, **kwargs): - data_blocks = glob.glob(data_path + "/*.npz") + if not data_path.endswith("npz"): + data_path = os.path.join(data_path, "*.npz") + data_blocks = sorted(glob.glob(data_path)) # sort by part name assert len(data_blocks) > 0, f"invalid data_path: {data_path}" - if len(data_blocks) > 1: - data_blocks.sort() # sort by part name self.data_blocks = data_blocks self.num_blocks = len(self.data_blocks) self.feature_map = feature_map self.batch_size = batch_size self.num_batches, self.num_samples = self.count_batches_and_samples() - datapipe = BlockDataPipe(data_blocks, feature_map) + datapipe = NpzIterDataPipe(self.data_blocks, feature_map) if shuffle: datapipe = datapipe.shuffle(buffer_size=buffer_size) - else: - num_workers = 1 # multiple workers cannot keep the order of data reading - super(NpzBlockDataLoader, self).__init__(dataset=datapipe, batch_size=batch_size, - num_workers=num_workers) + elif split == "test": + num_workers = 1 # multiple workers cannot keep the order of data reading + super().__init__(dataset=datapipe, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=BatchCollator(feature_map)) def __len__(self): return self.num_batches @@ -89,3 +91,16 @@ def count_batches_and_samples(self): num_samples += block_size num_batches = int(np.ceil(num_samples / self.batch_size)) return num_batches, num_samples + + +class BatchCollator(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, batch): + batch_tensor = default_collate(batch) + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + batch_dict = dict() + for col in all_cols: + batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)] + return batch_dict diff --git a/fuxictr/pytorch/dataloaders/npz_dataloader.py b/fuxictr/pytorch/dataloaders/npz_dataloader.py index e5740a1..057674a 100644 --- a/fuxictr/pytorch/dataloaders/npz_dataloader.py +++ b/fuxictr/pytorch/dataloaders/npz_dataloader.py @@ -16,11 +16,11 @@ import numpy as np -from torch.utils import data -import torch +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.dataloader import default_collate -class Dataset(data.Dataset): +class NpzDataset(Dataset): def __init__(self, feature_map, data_path): self.feature_map = feature_map self.darray = self.load_data(data_path) @@ -41,20 +41,33 @@ def load_data(self, data_path): data_arrays.append(array.reshape(-1, 1)) else: data_arrays.append(array) - data_tensor = torch.from_numpy(np.hstack(data_arrays)) - return data_tensor + return np.hstack(data_arrays) -class NpzDataLoader(data.DataLoader): +class NpzDataLoader(DataLoader): def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, num_workers=1, **kwargs): if not data_path.endswith(".npz"): data_path += ".npz" - self.dataset = Dataset(feature_map, data_path) + self.dataset = NpzDataset(feature_map, data_path) super(NpzDataLoader, self).__init__(dataset=self.dataset, batch_size=batch_size, - shuffle=shuffle, num_workers=num_workers) + shuffle=shuffle, num_workers=num_workers, + collate_fn=BatchCollator(feature_map)) self.num_samples = len(self.dataset) self.num_blocks = 1 self.num_batches = int(np.ceil(self.num_samples * 1.0 / self.batch_size)) def __len__(self): return self.num_batches + + +class BatchCollator(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, batch): + batch_tensor = default_collate(batch) + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + batch_dict = dict() + for col in all_cols: + batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)] + return batch_dict diff --git a/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py b/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py new file mode 100644 index 0000000..e901aae --- /dev/null +++ b/fuxictr/pytorch/dataloaders/parquet_block_dataloader.py @@ -0,0 +1,108 @@ +# ========================================================================= +# Copyright (C) 2023-2024. FuxiCTR Authors. All rights reserved. +# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +import numpy as np +from itertools import chain +from torch.utils.data.dataloader import default_collate +from torch.utils.data import IterDataPipe, DataLoader, get_worker_info +import glob +import polars as pl +import pandas as pd +import os + + +class ParquetIterDataPipe(IterDataPipe): + def __init__(self, data_blocks, feature_map): + self.feature_map = feature_map + self.data_blocks = data_blocks + + def load_data(self, data_path): + df = pd.read_parquet(data_path) + data_arrays = [] + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + for col in all_cols: + array = np.array(df[col].to_list()) + if array.ndim == 1: + data_arrays.append(array.reshape(-1, 1)) + else: + data_arrays.append(array) + return np.hstack(data_arrays) + + def read_block(self, data_block): + darray = self.load_data(data_block) + for idx in range(darray.shape[0]): + yield darray[idx, :] + + def __iter__(self): + worker_info = get_worker_info() + if worker_info is None: # single-process data loading + block_list = self.data_blocks + else: # in a worker process + block_list = [ + block + for idx, block in enumerate(self.data_blocks) + if idx % worker_info.num_workers == worker_info.id + ] + return chain.from_iterable(map(self.read_block, block_list)) + + +class ParquetBlockDataLoader(DataLoader): + def __init__(self, feature_map, data_path, split="train", batch_size=32, shuffle=False, + num_workers=1, buffer_size=100000, **kwargs): + if not data_path.endswith("parquet"): + data_path = os.path.join(data_path, "*.parquet") + data_blocks = sorted(glob.glob(data_path)) # sort by part name + assert len(data_blocks) > 0, f"invalid data_path: {data_path}" + self.data_blocks = data_blocks + self.num_blocks = len(self.data_blocks) + self.feature_map = feature_map + self.batch_size = batch_size + self.num_batches, self.num_samples = self.count_batches_and_samples() + datapipe = ParquetIterDataPipe(self.data_blocks, feature_map) + if shuffle: + datapipe = datapipe.shuffle(buffer_size=buffer_size) + elif split == "test": + num_workers = 1 # multiple workers cannot keep the order of data reading + super().__init__(dataset=datapipe, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=BatchCollator(feature_map)) + + def __len__(self): + return self.num_batches + + def count_batches_and_samples(self): + num_samples = 0 + for data_block in self.data_blocks: + df = pl.scan_parquet(data_block) + num_samples += df.select(pl.count()).collect().item() + num_batches = int(np.ceil(num_samples / self.batch_size)) + return num_batches, num_samples + + +class BatchCollator(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, batch): + batch_tensor = default_collate(batch) + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + batch_dict = dict() + for col in all_cols: + batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)] + return batch_dict diff --git a/fuxictr/pytorch/dataloaders/parquet_dataloader.py b/fuxictr/pytorch/dataloaders/parquet_dataloader.py new file mode 100644 index 0000000..1c461d1 --- /dev/null +++ b/fuxictr/pytorch/dataloaders/parquet_dataloader.py @@ -0,0 +1,75 @@ +# ========================================================================= +# Copyright (C) 2024. FuxiCTR Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +import numpy as np +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.dataloader import default_collate +import pandas as pd + + +class ParquetDataset(Dataset): + def __init__(self, feature_map, data_path): + self.feature_map = feature_map + self.darray = self.load_data(data_path) + + def __getitem__(self, index): + return self.darray[index, :] + + def __len__(self): + return self.darray.shape[0] + + def load_data(self, data_path): + df = pd.read_parquet(data_path) + data_arrays = [] + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + for col in all_cols: + array = np.array(df[col].to_list()) + if array.ndim == 1: + data_arrays.append(array.reshape(-1, 1)) + else: + data_arrays.append(array) + return np.hstack(data_arrays) + + +class ParquetDataLoader(DataLoader): + def __init__(self, feature_map, data_path, batch_size=32, shuffle=False, + num_workers=1, **kwargs): + if not data_path.endswith(".parquet"): + data_path += ".parquet" + self.dataset = ParquetDataset(feature_map, data_path) + super().__init__(dataset=self.dataset, batch_size=batch_size, + shuffle=shuffle, num_workers=num_workers, + collate_fn=BatchCollator(feature_map)) + self.num_samples = len(self.dataset) + self.num_blocks = 1 + self.num_batches = int(np.ceil(self.num_samples / self.batch_size)) + + def __len__(self): + return self.num_batches + + +class BatchCollator(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, batch): + batch_tensor = default_collate(batch) + all_cols = list(self.feature_map.features.keys()) + self.feature_map.labels + batch_dict = dict() + for col in all_cols: + batch_dict[col] = batch_tensor[:, self.feature_map.get_column_index(col)] + return batch_dict diff --git a/fuxictr/pytorch/dataloaders/rank_dataloader.py b/fuxictr/pytorch/dataloaders/rank_dataloader.py index 1ca830f..faaab59 100644 --- a/fuxictr/pytorch/dataloaders/rank_dataloader.py +++ b/fuxictr/pytorch/dataloaders/rank_dataloader.py @@ -17,29 +17,48 @@ from .npz_block_dataloader import NpzBlockDataLoader from .npz_dataloader import NpzDataLoader +from .parquet_block_dataloader import ParquetBlockDataLoader +from .parquet_dataloader import ParquetDataLoader import logging class RankDataLoader(object): def __init__(self, feature_map, stage="both", train_data=None, valid_data=None, test_data=None, - batch_size=32, shuffle=True, streaming=False, **kwargs): + batch_size=32, shuffle=True, streaming=False, data_format="npz", **kwargs): logging.info("Loading datasets...") train_gen = None valid_gen = None test_gen = None - DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader + if data_format == "npz": + DataLoader = NpzBlockDataLoader if streaming else NpzDataLoader + elif data_format == "parquet": + DataLoader = ParquetBlockDataLoader if streaming else ParquetDataLoader + else: + raise ValueError(f"data_format={data_format} not supported.") self.stage = stage if stage in ["both", "train"]: - train_gen = DataLoader(feature_map, train_data, batch_size=batch_size, shuffle=shuffle, **kwargs) - logging.info("Train samples: total/{:d}, blocks/{:d}".format(train_gen.num_samples, train_gen.num_blocks)) + train_gen = DataLoader(feature_map, train_data, split="train", batch_size=batch_size, + shuffle=shuffle, **kwargs) + logging.info( + "Train samples: total/{:d}, blocks/{:d}" + .format(train_gen.num_samples, train_gen.num_blocks) + ) if valid_data: - valid_gen = DataLoader(feature_map, valid_data, batch_size=batch_size, shuffle=False, **kwargs) - logging.info("Validation samples: total/{:d}, blocks/{:d}".format(valid_gen.num_samples, valid_gen.num_blocks)) + valid_gen = DataLoader(feature_map, valid_data, split="valid", + batch_size=batch_size, shuffle=False, **kwargs) + logging.info( + "Validation samples: total/{:d}, blocks/{:d}" + .format(valid_gen.num_samples, valid_gen.num_blocks) + ) if stage in ["both", "test"]: if test_data: - test_gen = DataLoader(feature_map, test_data, batch_size=batch_size, shuffle=False, **kwargs) - logging.info("Test samples: total/{:d}, blocks/{:d}".format(test_gen.num_samples, test_gen.num_blocks)) + test_gen = DataLoader(feature_map, test_data, split="test", batch_size=batch_size, + shuffle=False, **kwargs) + logging.info( + "Test samples: total/{:d}, blocks/{:d}" + .format(test_gen.num_samples, test_gen.num_blocks) + ) self.train_gen, self.valid_gen, self.test_gen = train_gen, valid_gen, test_gen def make_iterator(self): diff --git a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py index 488333f..5b67275 100644 --- a/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py +++ b/fuxictr/pytorch/layers/embeddings/pretrained_embedding.py @@ -17,12 +17,12 @@ import torch from torch import nn -import h5py import os import io import json import numpy as np import logging +from ....utils import load_pretrain_emb class PretrainedEmbedding(nn.Module): @@ -66,17 +66,6 @@ def reset_parameters(self, embedding_initializer): nn.init.zeros_(self.id_embedding.weight) # set oov token embeddings to zeros embedding_initializer(self.id_embedding.weight[1:self.oov_idx, :]) - def get_pretrained_embedding(self, pretrain_path): - logging.info("Loading pretrained_emb: {}".format(pretrain_path)) - if pretrain_path.endswith("h5"): - with h5py.File(pretrain_path, 'r') as hf: - keys = hf["key"][:] - embeddings = hf["value"][:] - elif pretrain_path.endswith("npz"): - npz = np.load(pretrain_path) - keys, embeddings = npz["key"], npz["value"] - return keys, embeddings - def load_feature_vocab(self, vocab_path, feature_name): with io.open(vocab_path, "r", encoding="utf-8") as fd: vocab = json.load(fd) @@ -94,7 +83,8 @@ def load_pretrained_embedding(self, vocab_size, pretrain_dim, pretrain_path, voc embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(vocab_size, pretrain_dim)) if padding_idx: embedding_matrix[padding_idx, :] = np.zeros(pretrain_dim) # set as zero for PAD - keys, embeddings = self.get_pretrained_embedding(pretrain_path) + logging.info("Loading pretrained_emb: {}".format(pretrain_path)) + keys, embeddings = load_pretrain_emb(pretrain_path, keys=["key", "value"]) assert embeddings.shape[-1] == pretrain_dim, f"pretrain_dim={pretrain_dim} not correct." vocab, vocab_type = self.load_feature_vocab(vocab_path, feature_name) keys = keys.astype(vocab_type) # ensure the same dtype between pretrained keys and vocab keys diff --git a/fuxictr/pytorch/models/rank_model.py b/fuxictr/pytorch/models/rank_model.py index 2a5da95..88b5b76 100644 --- a/fuxictr/pytorch/models/rank_model.py +++ b/fuxictr/pytorch/models/rank_model.py @@ -1,4 +1,5 @@ # ========================================================================= +# Copyright (C) 2023. FuxiCTR Authors. All rights reserved. # Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -110,18 +111,18 @@ def get_inputs(self, inputs, feature_source=None): continue if spec["type"] == "meta": continue - X_dict[feature] = inputs[:, self.feature_map.get_column_index(feature)].to(self.device) + X_dict[feature] = inputs[feature].to(self.device) return X_dict def get_labels(self, inputs): """ Please override get_labels() when using multiple labels! """ labels = self.feature_map.labels - y = inputs[:, self.feature_map.get_column_index(labels[0])].to(self.device) + y = inputs[labels[0]].to(self.device) return y.float().view(-1, 1) def get_group_id(self, inputs): - return inputs[:, self.feature_map.get_column_index(self.feature_map.group_id)] + return inputs[self.feature_map.group_id] def model_to_device(self): self.to(device=self.device) diff --git a/fuxictr/utils.py b/fuxictr/utils.py index 0011ff6..a0b9219 100644 --- a/fuxictr/utils.py +++ b/fuxictr/utils.py @@ -20,6 +20,9 @@ import yaml import glob import json +import h5py +import numpy as np +import pandas as pd from collections import OrderedDict @@ -90,6 +93,7 @@ def print_to_json(data, sort_keys=True): def print_to_list(data): return ' - '.join('{}: {:.6f}'.format(k, v) for k, v in data.items()) + class Monitor(object): def __init__(self, kv): if isinstance(kv, str): @@ -104,3 +108,20 @@ def get_value(self, logs): def get_metrics(self): return list(self.kv_pairs.keys()) + + +def load_pretrain_emb(pretrain_path, keys=["key", "value"]): + if type(keys) != list: + keys = [keys] + if pretrain_path.endswith("h5"): + with h5py.File(pretrain_path, 'r') as hf: + values = [hf[k][:] for k in keys] + elif pretrain_path.endswith("npz"): + npz = np.load(pretrain_path) + values = [npz[k] for k in keys] + elif pretrain_path.endswith("parquet"): + df = pd.read_parquet(pretrain_path) + values = [df[k].values for k in keys] + else: + raise ValueError(f"Embedding format not supported: {pretrain_path}") + return values[0] if len(values) == 1 else values diff --git a/fuxictr/version.py b/fuxictr/version.py index 6f43348..1108fcc 100644 --- a/fuxictr/version.py +++ b/fuxictr/version.py @@ -1 +1 @@ -__version__="2.2.3" +__version__="2.3.0" diff --git a/requirements.txt b/requirements.txt index 6711007..eb1bba3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -torch keras_preprocessing -PyYAML pandas +PyYAML scikit-learn numpy h5py diff --git a/setup.py b/setup.py index 89bab60..9f141d1 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="fuxictr", - version="2.2.3", + version="2.3.0", author="RECZOO", author_email="reczoo@users.noreply.github.com", description="A configurable, tunable, and reproducible library for CTR prediction", @@ -17,7 +17,8 @@ exclude=["model_zoo", "tests", "data", "docs", "demo"]), include_package_data=True, python_requires=">=3.6", - install_requires=["pandas", "numpy", "h5py", "PyYAML>=5.1", "scikit-learn", "tqdm"], + install_requires=["keras_preprocessing", "pandas", "PyYAML>=5.1", "scikit-learn", + "numpy", "h5py", "tqdm", "pyarrow", "polars"], classifiers=( "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", @@ -25,7 +26,6 @@ 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development',