## TensorFlow读取数据的Dataset API

In [1]:
import tensorflow as tf

### 1. 读取数据构造Dataset

#### 读取列表

In [2]:
# Build a dataset from a nested Python list. As the repr below shows, each
# inner list becomes one dataset element of shape (2,) and dtype int32.
datas = [[1, 2], [3, 4], [5, 6]]
dataset = tf.data.Dataset.from_tensor_slices(datas)
dataset  # bare last expression so the notebook displays the dataset repr

<TensorSliceDataset shapes: (2,), types: tf.int32>

In [3]:
# Iterating the dataset directly yields one tf.Tensor per element.
for row in dataset:
    print(row)

tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([5 6], shape=(2,), dtype=int32)


#### 读取字典

In [4]:
# A dict of lists is sliced per key: each dataset element is a dict holding
# one scalar for every key (see the shapes/types in the repr below).
datas = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7,8,9]}
dataset = tf.data.Dataset.from_tensor_slices(datas)
dataset  # display the element structure

<TensorSliceDataset shapes: {a: (), b: (), c: ()}, types: {a: tf.int32, b: tf.int32, c: tf.int32}>

In [5]:
# as_numpy_iterator() yields plain Python/NumPy values rather than tf.Tensor
# objects, which makes the printed elements easier to read.
for data in dataset.as_numpy_iterator():
    print(data)

{'a': 1, 'b': 4, 'c': 7}
{'a': 2, 'b': 5, 'c': 8}
{'a': 3, 'b': 6, 'c': 9}


#### 读取元组

In [6]:
# Passing a (features, labels) tuple pairs the i-th feature row with the
# i-th label, so every dataset element is a (row, label) tuple.
features = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
labels = [1, 0, 1]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset  # display the element structure

<TensorSliceDataset shapes: ((3,), ()), types: (tf.int32, tf.int32)>

In [7]:
# Each element comes back as a (feature_array, label) tuple.
for data in dataset.as_numpy_iterator():
    print(data)

(array([1, 2, 3]), 1)
(array([4, 5, 6]), 0)
(array([7, 8, 9]), 1)


### 2. 对Dataset执行各种转换

In [8]:
# Base dataset of the integers 0..19; the transformation cells in this
# section all start from this `dataset`.
datas = list(range(20))
dataset = tf.data.Dataset.from_tensor_slices(datas)
list(dataset.as_numpy_iterator())  # materialize to show the contents

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

#### map操作

In [9]:
# map() applies the function to every element and returns a new dataset;
# the base `dataset` itself is left untouched.
dataset_map = dataset.map(lambda x: x + 1)
list(dataset_map.as_numpy_iterator())

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

#### batch操作

In [10]:
# batch(3) groups consecutive elements into arrays of 3. 20 is not divisible
# by 3, so the final batch holds only the 2 leftover elements (see output).
# Per the tf.data docs, batch(..., drop_remainder=True) would discard it.
dataset_batch = dataset.batch(3)
list(dataset_batch.as_numpy_iterator())

[array([0, 1, 2]),
 array([3, 4, 5]),
 array([6, 7, 8]),
 array([ 9, 10, 11]),
 array([12, 13, 14]),
 array([15, 16, 17]),
 array([18, 19])]

#### repeat操作

In [11]:
# repeat(3) replays the 20 values three times back to back (60 elements).
# Batching in fives afterwards gives exactly 4 batches per pass, 12 in total.
dataset_repeat = dataset.repeat(3).batch(5)
list(dataset_repeat.as_numpy_iterator())

[array([0, 1, 2, 3, 4]),
 array([5, 6, 7, 8, 9]),
 array([10, 11, 12, 13, 14]),
 array([15, 16, 17, 18, 19]),
 array([0, 1, 2, 3, 4]),
 array([5, 6, 7, 8, 9]),
 array([10, 11, 12, 13, 14]),
 array([15, 16, 17, 18, 19]),
 array([0, 1, 2, 3, 4]),
 array([5, 6, 7, 8, 9]),
 array([10, 11, 12, 13, 14]),
 array([15, 16, 17, 18, 19])]

#### shuffle操作

In [12]:
# shuffle(5): the argument is the shuffle buffer size (per the tf.data docs),
# so with a buffer much smaller than the dataset the result is only locally
# mixed rather than a uniform permutation.
# NOTE(review): no seed is passed, so the order presumably differs on every
# run and the recorded output is just one sample; consider passing `seed=`
# if reproducible output is wanted.
dataset_shuffle = dataset.shuffle(5)
list(dataset_shuffle.as_numpy_iterator())

[1, 2, 5, 7, 3, 8, 10, 9, 6, 4, 12, 13, 15, 0, 16, 17, 14, 11, 19, 18]

#### 链式操作

In [13]:
# Every transformation returns a new dataset, so calls compose left to right:
# replay the 20 values 3 times (60 elements), shuffle them, then cut the
# stream into batches of 7 -> 8 full batches plus a final batch of 4.
# Because the shuffle runs after the repeat, values from neighbouring passes
# can end up in the same batch (visible in the output below).
dataset_chain = dataset.repeat(3).shuffle(5).batch(7)
list(dataset_chain.as_numpy_iterator())

[array([ 2,  4,  5,  6,  0,  7, 10]),
 array([ 8,  3, 12,  9, 15, 13, 16]),
 array([11, 14, 18,  1, 19,  0,  4]),
 array([17,  6,  7,  2,  5,  8, 10]),
 array([ 3,  1, 14, 15, 13, 16, 12]),
 array([17,  0,  9,  1, 19,  4,  5]),
 array([ 6,  3,  7,  2,  9, 11,  8]),
 array([11, 10, 12, 15, 18, 17, 16]),
 array([19, 14, 13, 18])]