- Reference: https://ithelp.ithome.com.tw/articles/10241789

In [1]:
import tensorflow as tf

## 單一CSV檔案

In [2]:
# Download the Titanic training CSV into the Keras download cache and keep the
# local file path so the CSV reader below can open it.
titanic_file_path = tf.keras.utils.get_file(
    fname="train.csv",
    origin="https://storage.googleapis.com/tf-datasets/titanic/train.csv",
)

Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv


In [4]:
# Build a batched tf.data pipeline straight from the Titanic CSV.
# Each element is a (features, label) pair: `features` maps column name -> batch
# of values, and `label` is the 'survived' column split off from the features.
# num_epochs is not passed, so it keeps its default (None): the dataset repeats
# indefinitely and should be consumed with .take(n), not iterated to the end.
titanic_csv_ds = tf.data.experimental.make_csv_dataset(
    file_pattern=titanic_file_path,
    batch_size=5,
    label_name='survived',
    ignore_errors=True,
)

# make_csv_dataset parameters:
# file_pattern: path or glob pattern of the CSV file(s) to read
# batch_size: number of records combined into a single batch
# column_names: column names; inferred from the header row when not given
# column_defaults: per-column default values / dtypes, e.g. float32, float64, int32, int64, string
# label_name: name of the column returned separately as the label
# select_columns: subset of columns to keep
# field_delim: field separator, "," by default
# use_quote_delim: defaults to True; if False, double quotes are read as ordinary characters
# na_value: string value to recognize as NA / NaN
# header: whether the first row of each file is a header row
# num_epochs: how many times to repeat the data; None (the default) repeats forever
# shuffle: whether to shuffle the records
# shuffle_buffer_size: shuffle buffer size; larger shuffles better but uses more memory
# shuffle_seed: random seed used for shuffling
# prefetch_buffer_size: number of batches to prefetch; auto-tuned by default
# num_parallel_reads: number of threads used to read records; defaults to 1
# sloppy: if True, reads faster at the cost of a non-deterministic record order
# num_rows_for_inference: rows used for type inference, 100 by default; None reads all rows
# compression_type: no compression by default; "ZLIB" and "GZIP" are supported
# ignore_errors: skip records that fail to parse (e.g. malformed lines) instead of raising

In [10]:
# Pull a single batch and print every feature column, followed by its labels.
for features, labels in titanic_csv_ds.take(1):
    for column_name, column_values in features.items():
        print(f"{column_name:1s}:{column_values}")
    print(f"{'label':1s}:{labels}")

sex:[b'female' b'female' b'male' b'male' b'female']
age:[35.  22.  22.  32.5 24. ]
n_siblings_spouses:[1 0 1 1 0]
parch:[0 2 0 0 3]
fare:[83.475  49.5     7.25   30.0708 19.2583]
class:[b'First' b'First' b'Third' b'Second' b'Third']
deck:[b'C' b'B' b'unknown' b'unknown' b'unknown']
embark_town:[b'Southampton' b'Cherbourg' b'Southampton' b'Cherbourg' b'Cherbourg']
alone:[b'n' b'n' b'n' b'n' b'n']
label:[1 1 0 0 1]


## 單一GZ檔

In [18]:
# Download the gzip-compressed Metro Interstate Traffic Volume CSV into
# ./traffic/ (cache_dir='.' combined with cache_subdir='traffic').
# NOTE: "volumn" is a typo, but the variable name is kept because the next
# cell reads it under this name.
traffic_volumn_csv_gz = tf.keras.utils.get_file(
    fname='Metro_Interstate_Traffic_Volume.csv.gz',
    origin="https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz",
    cache_dir='.',
    cache_subdir='traffic',
)

In [22]:
# Read the gzip-compressed CSV directly (compression_type="GZIP") and split
# 'traffic_volume' off as the label. num_epochs=1 makes the dataset finite:
# exactly one pass over the file.
traffic_volume_csv_gz_ds = tf.data.experimental.make_csv_dataset(
    traffic_volumn_csv_gz,
    compression_type="GZIP",
    label_name='traffic_volume',
    batch_size=256,
    num_epochs=1,
)

# Peek at the first batch: the first five values of each feature column,
# a blank line, then the first five labels.
for features, labels in traffic_volume_csv_gz_ds.take(1):
    for column_name, column_values in features.items():
        print(f"{column_name:20s}:{column_values[:5]}")
    print()
    print(f"{'label':20s}:{labels[:5]}")
        

holiday             :[b'None' b'None' b'None' b'None' b'None']
temp                :[264.15 290.26 277.61 272.08 280.1 ]
rain_1h             :[0. 0. 0. 0. 0.]
snow_1h             :[0. 0. 0. 0. 0.]
clouds_all          :[75  0 20 90 92]
weather_main        :[b'Clouds' b'Mist' b'Clouds' b'Mist' b'Rain']
weather_description :[b'broken clouds' b'mist' b'few clouds' b'mist' b'light rain']
date_time           :[b'2012-12-24 18:00:00' b'2013-06-25 06:00:00' b'2012-11-17 12:00:00'
 b'2013-04-12 09:00:00' b'2012-10-19 23:00:00']

label               :[2536 6386 4753 5197 2051]


In [23]:
# Display the dataset's repr to inspect its element structure: an
# (OrderedDict of feature tensors, label tensor) pair, each with a batch dimension of None.
traffic_volume_csv_gz_ds

<PrefetchDataset shapes: (OrderedDict([(holiday, (None,)), (temp, (None,)), (rain_1h, (None,)), (snow_1h, (None,)), (clouds_all, (None,)), (weather_main, (None,)), (weather_description, (None,)), (date_time, (None,))]), (None,)), types: (OrderedDict([(holiday, tf.string), (temp, tf.float32), (rain_1h, tf.float32), (snow_1h, tf.float32), (clouds_all, tf.int32), (weather_main, tf.string), (weather_description, tf.string), (date_time, tf.string)]), tf.int32)>

### 透過快取(Caching)或快照(Snapshot)處理數據

- repeat()：將資料集重複指定的次數（例如 `repeat(20)` 代表跑 20 個 epoch）；不指定次數時會無限重複。
- prefetch()：在訓練目前批次的同時，預先讀取並轉換下一批資料。
- cache()：第一次完整讀取後，將資料保留在記憶體（或指定的檔案）中，之後的 epoch 可直接重複使用，不必重新讀取與解析。

In [25]:
%%time

for i, (batch,label) in enumerate(traffic_volume_csv_gz_ds.repeat(20)):
    if i % 40 == 0:
        print('.',end = '')
print()

...............................................................................................
Wall time: 5.87 s


In [28]:
%%time
#快取(Caching)將數據在第一次epoch就做快取


caching = traffic_volume_csv_gz_ds.cache().shuffle(1000)

for i,(batch,label) in enumerate(caching.shuffle(1000).repeat(20)):
    if i % 40 == 0:
        print('.',end=  '')
print()

...............................................................................................
Wall time: 905 ms
