-
Notifications
You must be signed in to change notification settings - Fork 943
/
mnist_data_setup.py
65 lines (51 loc) · 2.74 KB
/
mnist_data_setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
if __name__ == "__main__":
    # Dependencies are imported only when the file is executed as a script.
    import argparse
    from pyspark.context import SparkContext
    from pyspark.conf import SparkConf
    import tensorflow as tf
    import tensorflow_datasets as tfds

    # Command-line options: number of output partitions and the output directory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--num_partitions",
        help="Number of output partitions",
        type=int,
        default=10,
    )
    arg_parser.add_argument(
        "--output",
        help="HDFS directory to save examples in parallelized format",
        default="data/mnist",
    )
    args = arg_parser.parse_args()
    print("args:", args)

    # Spark application context used for all exports below.
    spark_conf = SparkConf().setAppName("mnist_data_setup")
    sc = SparkContext(conf=spark_conf)

    # Load the MNIST splits together with the dataset metadata, and log the metadata.
    splits, dataset_info = tfds.load('mnist', with_info=True)
    print(dataset_info.as_json)

    # convert to numpy, then RDDs
    train_examples = tfds.as_numpy(splits['train'])
    test_examples = tfds.as_numpy(splits['test'])

    # Both RDDs are cached because each one feeds more than one export.
    train_rdd = sc.parallelize(train_examples, args.num_partitions).cache()
    test_rdd = sc.parallelize(test_examples, args.num_partitions).cache()
# save as CSV (label,comma-separated-features)
def to_csv(example):
    """Render one example as a single CSV line.

    Args:
        example: mapping with a 'label' entry and an 'image' entry; the image
            must support ``reshape(784)`` (i.e. hold exactly 784 values).

    Returns:
        A string of the form ``label,p0,p1,...,p783``: the label followed by
        the flattened image values, all comma-separated.
    """
    label_text = str(example['label'])
    flat_pixels = example['image'].reshape(784)
    # Label comes first, then every pixel in flattened order.
    columns = [label_text]
    columns.extend(str(pixel) for pixel in flat_pixels)
    return ','.join(columns)
# Write each split as text files, one CSV line per example, under <output>/csv/.
csv_targets = [
    (train_rdd, "/csv/train"),
    (test_rdd, "/csv/test"),
]
for split_rdd, suffix in csv_targets:
    split_rdd.map(to_csv).saveAsTextFile(args.output + suffix)
# save as TFRecords (numpy vs. PNG)
# note: the MNIST tensorflow_dataset is already provided as TFRecords but with a PNG bytes_list
# this export format is less-efficient, but easier to work with later
def to_tfr(example):
    """Serialize one example as a ``tf.train.Example`` record.

    Args:
        example: mapping with a 'label' entry and an 'image' entry; both must
            support ``astype`` and the image must support ``reshape(784)``.

    Returns:
        A ``(bytearray, None)`` tuple: the serialized Example bytes as the key
        and ``None`` as the value.
    """
    # Both features are stored as int64 lists: a single-element list for the
    # label and the 784 flattened image values for the image.
    label_values = [example['label'].astype("int64")]
    pixel_values = example['image'].reshape(784).astype("int64")

    feature_map = {
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label_values)),
        'image': tf.train.Feature(int64_list=tf.train.Int64List(value=pixel_values)),
    }
    record = tf.train.Example(features=tf.train.Features(feature=feature_map))

    serialized = record.SerializeToString()
    return (bytearray(serialized), None)
# Write each split as TFRecord files under <output>/tfr/ through the TensorFlow
# Hadoop output format. Each record is the (serialized bytes, None) pair
# produced by to_tfr, written as a BytesWritable key with a NullWritable value.
tfr_targets = [
    (train_rdd, "/tfr/train"),
    (test_rdd, "/tfr/test"),
]
try:
    for split_rdd, suffix in tfr_targets:
        split_rdd.map(to_tfr).saveAsNewAPIHadoopFile(
            args.output + suffix,
            "org.tensorflow.hadoop.io.TFRecordFileOutputFormat",
            keyClass="org.apache.hadoop.io.BytesWritable",
            valueClass="org.apache.hadoop.io.NullWritable",
        )
finally:
    # These exports are the final Spark jobs visible in this script, so shut
    # the SparkContext down explicitly (even if an export fails) rather than
    # leaving it open until the interpreter exits.
    sc.stop()