In [1]:
import pandas as pd
import tensorflow as tf
from types import SimpleNamespace
from typing import Dict,List,Tuple,Optional

#  #####################数据ETL##########################
RData = SimpleNamespace()
RData.TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
RData.TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
RData.TARGET = 'Species'
RData.FEATURE_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
RData.CSV_COLUMN_NAMES = RData.FEATURE_NAMES + [RData.TARGET]
RData.SPECIES = ['Setosa', 'Versicolor', 'Virginica']
RData.CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]

ETLData = SimpleNamespace()


# extract
def maybe_download():
    train_path = tf.keras.utils.get_file(RData.TRAIN_URL.split('/')[-1], RData.TRAIN_URL)
    test_path = tf.keras.utils.get_file(RData.TEST_URL.split('/')[-1], RData.TEST_URL)
    return train_path, test_path


# load dataframe
def load_data()->Tuple[Tuple[pd.DataFrame, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
    train_path, test_path = maybe_download()
    train = pd.read_csv(train_path, names=RData.CSV_COLUMN_NAMES, header=0)
    test = pd.read_csv(test_path, names=RData.CSV_COLUMN_NAMES, header=0)
    train_x, train_y = train, train.pop(RData.TARGET)
    test_x, test_y = test, test.pop(RData.TARGET)
    return (train_x, train_y), (test_x, test_y)


def parse_csv_row(csv_row)->Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    columns = tf.decode_csv(csv_row, record_defaults=RData.CSV_TYPES)
    features = dict(zip(RData.CSV_COLUMN_NAMES, columns))
    target = features.pop(RData.TARGET)
    return features, target




def csv_input_fn(file_name_pattern, mode=tf.estimator.ModeKeys.EVAL, skip_header_lines=0,
                 num_epochs=None, batch_size=200)->tf.data.Dataset:
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    file_names = tf.matching_files(file_name_pattern)
    dataset = tf.data.TextLineDataset(file_names)
    dataset = dataset.skip(skip_header_lines)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)
    dataset = dataset.map(parse_csv_row)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    return dataset

if __name__=="__main__":
    train_path, test_path = maybe_download()
    dataset = csv_input_fn(train_path)




  from ._conv import register_converters as _register_converters


In [8]:
dataset.shard

<bound method Dataset.shard of <RepeatDataset shapes: ({SepalLength: (?,), SepalWidth: (?,), PetalLength: (?,), PetalWidth: (?,)}, (?,)), types: ({SepalLength: tf.float32, SepalWidth: tf.float32, PetalLength: tf.float32, PetalWidth: tf.float32}, tf.int32)>>