Merge pull request #3 from facebookresearch/master
Parallelizing the pre-processing of the dataset. (facebookresearch#117)
tginart committed Sep 29, 2020
2 parents 9bdb7b8 + 78d7045 commit 5cae11a
Showing 3 changed files with 166 additions and 73 deletions.
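In outline, the change replaces the serial per-day preprocessing loop with one worker process per day, publishing each worker's results through multiprocessing.Manager proxies and merging them in the parent. A minimal sketch of that pattern, assuming hypothetical stand-ins (preprocess_day, DAYS, NUM_CAT) for the real process_one_file call and constants in the diff below:

from multiprocessing import Process, Manager

DAYS = 7       # hypothetical: the Kaggle set has 7 splits, Terabyte has 24
NUM_CAT = 26   # categorical features per Criteo sample

def preprocess_day(day, result_day, convert_dicts_day):
    # Stand-in for process_one_file: parse one day's raw file, then
    # publish the sample count and the per-feature category
    # dictionaries through the shared Manager proxies.
    result_day[day] = 1000  # e.g. number of samples parsed for this day
    convert_dicts_day[day] = [{} for _ in range(NUM_CAT)]

if __name__ == "__main__":
    result_day = Manager().dict()        # shared across processes
    convert_dicts_day = Manager().dict()
    processes = [
        Process(target=preprocess_day,
                args=(d, result_day, convert_dicts_day))
        for d in range(DAYS)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()                         # wait for every day to finish
    # Fold the per-day category dictionaries into one global set,
    # mirroring what getCriteoAdData does after joining its workers.
    convert_dicts = [{} for _ in range(NUM_CAT)]
    for d in range(DAYS):
        for f in range(NUM_CAT):
            convert_dicts[f].update(convert_dicts_day[d][f])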
210 changes: 145 additions & 65 deletions data_utils.py
@@ -40,6 +40,7 @@
import sys
# import os
from os import path
from multiprocessing import Process, Manager
# import io
# from io import StringIO
# import collections as coll
@@ -108,7 +109,7 @@ def convertUStringToDistinctIntsUnique(mat, mat_uni, counts):
return out, mat_uni, counts


def processCriteoAdData(d_path, d_file, npzfile, split, convertDicts, pre_comp_counts):
def processCriteoAdData(d_path, d_file, npzfile, i, convertDicts, pre_comp_counts):
# Process Kaggle Display Advertising Challenge or Terabyte Dataset
# by converting unicode strings in X_cat to integers and
# converting negative integer values in X_int.
@@ -117,49 +118,48 @@ def processCriteoAdData(d_path, d_file, npzfile, split, convertDicts, pre_comp_counts):
#
# Inputs:
# d_path (str): path for {kaggle|terabyte}_day_i.npz files
# split (int): total number of splits in the dataset (typically 7 or 24)
# i (int): split index in the dataset (typically 0 to 7 or 0 to 24)

# process data if not all files exist
for i in range(split):
filename_i = npzfile + "_{0}_processed.npz".format(i)
filename_i = npzfile + "_{0}_processed.npz".format(i)

if path.exists(filename_i):
print("Using existing " + filename_i, end="\r")
else:
with np.load(npzfile + "_{0}.npz".format(i)) as data:
# categorical features
'''
# Approach 1a: using empty dictionaries
X_cat, convertDicts, counts = convertUStringToDistinctIntsDict(
data["X_cat"], convertDicts, counts
)
'''
'''
# Approach 1b: using empty np.unique
X_cat, convertDicts, counts = convertUStringToDistinctIntsUnique(
data["X_cat"], convertDicts, counts
)
'''
# Approach 2a: using pre-computed dictionaries
X_cat_t = np.zeros(data["X_cat_t"].shape)
for j in range(26):
for k, x in enumerate(data["X_cat_t"][j, :]):
X_cat_t[j, k] = convertDicts[j][x]
# continuous features
X_int = data["X_int"]
X_int[X_int < 0] = 0
# targets
y = data["y"]

np.savez_compressed(
filename_i,
# X_cat = X_cat,
X_cat=np.transpose(X_cat_t), # transpose of the data
X_int=X_int,
y=y,
)
print("Processed " + filename_i, end="\r")
print("")
if path.exists(filename_i):
print("Using existing " + filename_i, end="\n")
else:
print("Not existing " + filename_i)
with np.load(npzfile + "_{0}.npz".format(i)) as data:
# categorical features
'''
# Approach 1a: using empty dictionaries
X_cat, convertDicts, counts = convertUStringToDistinctIntsDict(
data["X_cat"], convertDicts, counts
)
print("Processed " + filename_i, end="\r")
print("")
'''
'''
# Approach 1b: using empty np.unique
X_cat, convertDicts, counts = convertUStringToDistinctIntsUnique(
data["X_cat"], convertDicts, counts
)
'''
# Approach 2a: using pre-computed dictionaries
X_cat_t = np.zeros(data["X_cat_t"].shape)
for j in range(26):
for k, x in enumerate(data["X_cat_t"][j, :]):
X_cat_t[j, k] = convertDicts[j][x]
# continuous features
X_int = data["X_int"]
X_int[X_int < 0] = 0
# targets
y = data["y"]

np.savez_compressed(
filename_i,
# X_cat = X_cat,
X_cat=np.transpose(X_cat_t), # transpose of the data
X_int=X_int,
y=y,
)
print("Processed " + filename_i, end="\n")
# sanity check (applicable only if counts have been pre-computed & are re-computed)
# for j in range(26):
# if pre_comp_counts[j] != counts[j]:
@@ -882,7 +882,8 @@ def getCriteoAdData(
data_split='train',
randomize='total',
criteo_kaggle=True,
memory_map=False
memory_map=False,
dataset_multiprocessing=False
):
# Passes through entire dataset and defines dictionaries for categorical
# features and determines the total number of categories.
@@ -968,7 +969,13 @@ def process_one_file(
npzfile,
split,
num_data_in_split,
dataset_multiprocessing,
convertDictsDay=None,
resultDay=None
):
if dataset_multiprocessing:
convertDicts_day = [{} for _ in range(26)]

with open(str(datfile)) as f:
y = np.zeros(num_data_in_split, dtype="i4") # 4 byte int
X_int = np.zeros((num_data_in_split, 13), dtype="i4") # 4 byte int
@@ -979,6 +986,7 @@
rand_u = np.random.uniform(low=0.0, high=1.0, size=num_data_in_split)

i = 0
percent = 0
for k, line in enumerate(f):
# process a line (data point)
line = line.split('\t')
@@ -1004,22 +1012,41 @@
list(map(lambda x: int(x, 16), line[14:])),
dtype=np.int32
)
# count uniques
for j in range(26):
convertDicts[j][X_cat[i][j]] = 1

# debug prints
print(
"Load %d/%d Split: %d Label True: %d Stored: %d"
% (
i,
num_data_in_split,
split,
target,
y[i],
),
end="\r",
)
# count uniques
if dataset_multiprocessing:
for j in range(26):
convertDicts_day[j][X_cat[i][j]] = 1
# debug prints
if float(i)/num_data_in_split*100 > percent+1:
percent = int(float(i)/num_data_in_split*100)
print(
"Load %d/%d (%d%%) Split: %d Label True: %d Stored: %d"
% (
i,
num_data_in_split,
percent,
split,
target,
y[i],
),
end="\n",
)
else:
for j in range(26):
convertDicts[j][X_cat[i][j]] = 1
# debug prints
print(
"Load %d/%d Split: %d Label True: %d Stored: %d"
% (
i,
num_data_in_split,
split,
target,
y[i],
),
end="\r",
)
i += 1

# store num_data_in_split samples or extras at the end of file
@@ -1041,7 +1068,13 @@ def process_one_file(
y=y[0:i],
)
print("\nSaved " + npzfile + "_{0}.npz!".format(split))
return i

if dataset_multiprocessing:
resultDay[split] = i
convertDictsDay[split] = convertDicts_day
return
else:
return i

# create all splits (reuse existing files if possible)
recreate_flag = False
@@ -1050,7 +1083,6 @@ def process_one_file(
# np.random.seed(123)
# in this case there is a single split in each day
for i in range(days):
datfile_i = npzfile + "_{0}".format(i) # + ".gz"
npzfile_i = npzfile + "_{0}.npz".format(i)
npzfile_p = npzfile + "_{0}_processed.npz".format(i)
if path.exists(npzfile_i):
@@ -1059,12 +1091,42 @@
print("Skip existing " + npzfile_p)
else:
recreate_flag = True
total_per_file[i] = process_one_file(
datfile_i,
npzfile,
i,
total_per_file[i],
)

if recreate_flag:
if dataset_multiprocessing:
resultDay = Manager().dict()
convertDictsDay = Manager().dict()
processes = [Process(target=process_one_file,
name="process_one_file:%i" % i,
args=(npzfile + "_{0}".format(i),
npzfile,
i,
total_per_file[i],
dataset_multiprocessing,
convertDictsDay,
resultDay,
)
) for i in range(0, days)]
for process in processes:
process.start()
for process in processes:
process.join()
for day in range(days):
total_per_file[day] = resultDay[day]
print("Constructing convertDicts Split: {}".format(day))
convertDicts_tmp = convertDictsDay[day]
for i in range(26):
for j in convertDicts_tmp[i]:
convertDicts[i][j] = 1
else:
for i in range(days):
total_per_file[i] = process_one_file(
npzfile + "_{0}".format(i),
npzfile,
i,
total_per_file[i],
dataset_multiprocessing,
)

# report and save total into a file
total_count = np.sum(total_per_file)
@@ -1103,7 +1165,25 @@ def process_one_file(
counts = data["counts"]

# process all splits
processCriteoAdData(d_path, d_file, npzfile, days, convertDicts, counts)
if dataset_multiprocessing:
processes = [Process(target=processCriteoAdData,
name="processCriteoAdData:%i" % i,
args=(d_path,
d_file,
npzfile,
i,
convertDicts,
counts,
)
) for i in range(0, days)]
for process in processes:
process.start()
for process in processes:
process.join()
else:
for i in range(days):
processCriteoAdData(d_path, d_file, npzfile, i, convertDicts, counts)

o_file = concatCriteoAdData(
d_path,
d_file,
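A note on the design: child processes do not share the parent's memory, so a plain dict filled inside process_one_file would be lost after join(). Manager().dict() instead hands every worker a proxy to a dictionary held by a manager process, which is why resultDay and convertDictsDay survive the Process boundary; the per-day convertDicts_day lists are then folded into the global convertDicts on the parent side, as the merging loop above shows.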
24 changes: 16 additions & 8 deletions dlrm_data_pytorch.py
@@ -55,7 +55,8 @@ def __init__(
split="train",
raw_path="",
pro_data="",
memory_map=False
memory_map=False,
dataset_multiprocessing=False
):
# dataset
# tar_fea = 1 # single target
@@ -112,7 +113,8 @@ def __init__(
split,
randomize,
dataset == "kaggle",
memory_map
memory_map,
dataset_multiprocessing
)

# get a number of samples per day
@@ -345,7 +347,8 @@ def ensure_dataset_preprocessed(args, d_path):
"train",
args.raw_data_file,
args.processed_data_file,
args.memory_map
args.memory_map,
args.dataset_multiprocessing
)

_ = CriteoDataset(
@@ -356,7 +359,8 @@ def ensure_dataset_preprocessed(args, d_path):
"test",
args.raw_data_file,
args.processed_data_file,
args.memory_map
args.memory_map,
args.dataset_multiprocessing
)

for split in ['train', 'val', 'test']:
@@ -442,7 +446,8 @@ def make_criteo_data_and_loaders(args):
"train",
args.raw_data_file,
args.processed_data_file,
args.memory_map
args.memory_map,
args.dataset_multiprocessing
)

test_data = CriteoDataset(
@@ -453,7 +458,8 @@ def make_criteo_data_and_loaders(args):
"test",
args.raw_data_file,
args.processed_data_file,
args.memory_map
args.memory_map,
args.dataset_multiprocessing
)

train_loader = data_loader_terabyte.DataLoader(
@@ -482,7 +488,8 @@ def make_criteo_data_and_loaders(args):
"train",
args.raw_data_file,
args.processed_data_file,
args.memory_map
args.memory_map,
args.dataset_multiprocessing
)

test_data = CriteoDataset(
@@ -493,7 +500,8 @@ def make_criteo_data_and_loaders(args):
"test",
args.raw_data_file,
args.processed_data_file,
args.memory_map
args.memory_map,
args.dataset_multiprocessing
)

train_loader = torch.utils.data.DataLoader(
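Since --dataset-multiprocessing defaults to False everywhere it is threaded through, the serial code path remains the default and existing invocations behave exactly as before.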
5 changes: 5 additions & 0 deletions dlrm_s_pytorch.py
@@ -582,6 +582,11 @@ def dash_separated_floats(value):
parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
parser.add_argument("--dataset-multiprocessing", action="store_true", default=False,
help="The Kaggle dataset can be multiprocessed in an environment \
with more than 7 CPU cores and more than 20 GB of memory. \n \
The Terabyte dataset can be multiprocessed in an environment \
with more than 24 CPU cores and at least 1 TB of memory.")
# training
parser.add_argument("--solver", type=str, default="sgd")
parser.add_argument("--mini-batch-size", type=int, default=1)
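With the flag wired through the data pipeline, parallel preprocessing can be requested from the command line. A hypothetical Kaggle invocation (file paths are placeholders; the other flags come from the existing DLRM CLI):

python dlrm_s_pytorch.py --data-generation=dataset --data-set=kaggle --raw-data-file=./input/train.txt --processed-data-file=./input/kaggleAdDisplayChallenge_processed.npz --dataset-multiprocessing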
