21
21
import collections
22
22
import gzip
23
23
import os
24
+ import urllib
24
25
25
26
import numpy
26
- from six.moves import urllib
27
- from six.moves import xrange # pylint: disable=redefined-builtin
28
-
29
- from tensorflow.python.framework import dtypes
30
- from tensorflow.python.framework import random_seed
27
+ from tensorflow .python .framework import dtypes , random_seed
31
28
from tensorflow .python .platform import gfile
32
29
from tensorflow .python .util .deprecation import deprecated
33
30
@@ -46,16 +43,16 @@ def _read32(bytestream):
46
43
def _extract_images (f ):
47
44
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth].
48
45
49
- Args:
50
- f: A file object that can be passed into a gzip reader.
46
+ Args:
47
+ f: A file object that can be passed into a gzip reader.
51
48
52
- Returns:
53
- data: A 4D uint8 numpy array [index, y, x, depth].
49
+ Returns:
50
+ data: A 4D uint8 numpy array [index, y, x, depth].
54
51
55
- Raises:
56
- ValueError: If the bytestream does not start with 2051.
52
+ Raises:
53
+ ValueError: If the bytestream does not start with 2051.
57
54
58
- """
55
+ """
59
56
print ("Extracting" , f .name )
60
57
with gzip .GzipFile (fileobj = f ) as bytestream :
61
58
magic = _read32 (bytestream )
@@ -86,17 +83,17 @@ def _dense_to_one_hot(labels_dense, num_classes):
86
83
def _extract_labels (f , one_hot = False , num_classes = 10 ):
87
84
"""Extract the labels into a 1D uint8 numpy array [index].
88
85
89
- Args:
90
- f: A file object that can be passed into a gzip reader.
91
- one_hot: Does one hot encoding for the result.
92
- num_classes: Number of classes for the one hot encoding.
86
+ Args:
87
+ f: A file object that can be passed into a gzip reader.
88
+ one_hot: Does one hot encoding for the result.
89
+ num_classes: Number of classes for the one hot encoding.
93
90
94
- Returns:
95
- labels: a 1D uint8 numpy array.
91
+ Returns:
92
+ labels: a 1D uint8 numpy array.
96
93
97
- Raises:
98
- ValueError: If the bystream doesn't start with 2049.
99
- """
94
+ Raises:
95
+ ValueError: If the bystream doesn't start with 2049.
96
+ """
100
97
print ("Extracting" , f .name )
101
98
with gzip .GzipFile (fileobj = f ) as bytestream :
102
99
magic = _read32 (bytestream )
@@ -115,8 +112,8 @@ def _extract_labels(f, one_hot=False, num_classes=10):
115
112
class _DataSet :
116
113
"""Container class for a _DataSet (deprecated).
117
114
118
- THIS CLASS IS DEPRECATED.
119
- """
115
+ THIS CLASS IS DEPRECATED.
116
+ """
120
117
121
118
@deprecated (
122
119
None ,
@@ -135,21 +132,21 @@ def __init__(
135
132
):
136
133
"""Construct a _DataSet.
137
134
138
- one_hot arg is used only if fake_data is true. `dtype` can be either
139
- `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
140
- `[0, 1]`. Seed arg provides for convenient deterministic testing.
141
-
142
- Args:
143
- images: The images
144
- labels: The labels
145
- fake_data: Ignore inages and labels, use fake data.
146
- one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
147
- False).
148
- dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
149
- range [0,255]. float32 output has range [0,1].
150
- reshape: Bool. If True returned images are returned flattened to vectors.
151
- seed: The random seed to use.
152
- """
135
+ one_hot arg is used only if fake_data is true. `dtype` can be either
136
+ `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
137
+ `[0, 1]`. Seed arg provides for convenient deterministic testing.
138
+
139
+ Args:
140
+ images: The images
141
+ labels: The labels
142
+ fake_data: Ignore inages and labels, use fake data.
143
+ one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
144
+ False).
145
+ dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
146
+ range [0,255]. float32 output has range [0,1].
147
+ reshape: Bool. If True returned images are returned flattened to vectors.
148
+ seed: The random seed to use.
149
+ """
153
150
seed1 , seed2 = random_seed .get_seed (seed )
154
151
# If op level seed is not set, use whatever graph level seed is returned
155
152
numpy .random .seed (seed1 if seed is None else seed2 )
@@ -206,8 +203,8 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
206
203
else :
207
204
fake_label = 0
208
205
return (
209
- [fake_image for _ in xrange (batch_size)],
210
- [fake_label for _ in xrange (batch_size)],
206
+ [fake_image for _ in range (batch_size )],
207
+ [fake_label for _ in range (batch_size )],
211
208
)
212
209
start = self ._index_in_epoch
213
210
# Shuffle for the first epoch
@@ -250,19 +247,19 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
250
247
def _maybe_download (filename , work_directory , source_url ):
251
248
"""Download the data from source url, unless it's already here.
252
249
253
- Args:
254
- filename: string, name of the file in the directory.
255
- work_directory: string, path to working directory.
256
- source_url: url to download from if file doesn't exist.
250
+ Args:
251
+ filename: string, name of the file in the directory.
252
+ work_directory: string, path to working directory.
253
+ source_url: url to download from if file doesn't exist.
257
254
258
- Returns:
259
- Path to resulting file.
260
- """
255
+ Returns:
256
+ Path to resulting file.
257
+ """
261
258
if not gfile .Exists (work_directory ):
262
259
gfile .MakeDirs (work_directory )
263
260
filepath = os .path .join (work_directory , filename )
264
261
if not gfile .Exists (filepath ):
265
- urllib.request.urlretrieve(source_url, filepath)
262
+ urllib .request .urlretrieve (source_url , filepath ) # noqa: S310
266
263
with gfile .GFile (filepath ) as f :
267
264
size = f .size ()
268
265
print ("Successfully downloaded" , filename , size , "bytes." )
@@ -328,15 +325,16 @@ def fake():
328
325
329
326
if not 0 <= validation_size <= len (train_images ):
330
327
raise ValueError (
331
- f"Validation size should be between 0 and {len(train_images)}. Received: {validation_size}."
328
+ f"Validation size should be between 0 and { len (train_images )} . "
329
+ f"Received: { validation_size } ."
332
330
)
333
331
334
332
validation_images = train_images [:validation_size ]
335
333
validation_labels = train_labels [:validation_size ]
336
334
train_images = train_images [validation_size :]
337
335
train_labels = train_labels [validation_size :]
338
336
339
- options = dict( dtype= dtype, reshape= reshape, seed= seed)
337
+ options = { " dtype" : dtype , " reshape" : reshape , " seed" : seed }
340
338
341
339
train = _DataSet (train_images , train_labels , ** options )
342
340
validation = _DataSet (validation_images , validation_labels , ** options )
0 commit comments