In [89]:
import tensorflow as tf
import numpy as np
import shutil  # NOTE(review): not referenced in any cell visible here; confirm before removing
# Print the installed TensorFlow version so the run environment is recorded with the output.
print(tf.__version__)

1.15.0


In [90]:
# Turn on eager execution: the cells below call .numpy() on tensors and walk
# tf.data datasets with plain Python for-loops.
tf.enable_eager_execution()

# Test Data CSV

In [91]:
%%writefile test.csv
fare_amount,dayofweek,hourofday,pickuplon,pickuplat,dropofflon,dropofflat
28,1,0,-73.0,41.0,-74.0,20.7
12.3,1,0,-72.0,44.0,-75.0,40.6
10,1,0,-71.0,41.0,-71.0,42.9

Overwriting test.csv


# Input Function reading from CSV

In [92]:
# Column order of every CSV record. "fare_amount" is the label; the remaining
# columns are the features.
CSV_COLUMN_NAMES = [
    "fare_amount",
    "dayofweek",
    "hourofday",
    "pickuplon",
    "pickuplat",
    "dropofflon",
    "dropofflat",
]

# One single-element default per column, in the same order as CSV_COLUMN_NAMES.
# These are handed to the CSV decoder as record_defaults; note that the two
# integer entries are deliberately ints, not floats.
CSV_DEFAULTS = [
    [0.0],    # fare_amount
    [1],      # dayofweek
    [0],      # hourofday
    [-74.0],  # pickuplon
    [40.0],   # pickuplat
    [-74.0],  # dropofflon
    [40.7],   # dropofflat
]

In [93]:
def parse_row(row):
    """Split one CSV record into a feature dict and its label.

    Args:
        row: A single CSV line (a Python string or a scalar string tensor)
            whose fields follow the order of CSV_COLUMN_NAMES.

    Returns:
        A (features, label) tuple. features maps every column name except
        "fare_amount" to its parsed tensor; label is the "fare_amount" tensor.
    """
    # tf.io.decode_csv is the non-deprecated endpoint of tf.decode_csv.
    fields = tf.io.decode_csv(records=row, record_defaults=CSV_DEFAULTS)
    features = dict(zip(CSV_COLUMN_NAMES, fields))
    # Take the label out so it is not passed along as a feature.
    label = features.pop("fare_amount")
    return features, label

In [94]:
def read_dataset(csv_path):
    """Build a dataset of (features, label) pairs from one or more CSV files.

    Args:
        csv_path: A file path, a glob pattern (e.g. "data_*.csv"), or a list
            of either. Every matched file is expected to begin with one
            header line.

    Returns:
        A tf.data dataset that yields the (features, label) tuple produced by
        parse_row for each data row.
    """
    # shuffle=False keeps the file order deterministic; any shuffling is left
    # to the caller. list_files also reports a pattern that matches no file.
    filenames = tf.data.Dataset.list_files(file_pattern=csv_path, shuffle=False)
    # Skip the header inside each file. Skipping once on a single combined
    # TextLineDataset would drop only the first file's header, and the other
    # header lines would then fail to parse as numbers.
    lines = filenames.flat_map(
        lambda filename: tf.data.TextLineDataset(filenames=filename).skip(count=1)
    )
    return lines.map(map_func=parse_row)

# parse_row 확인

In [95]:
# Parse one hand-written record and check that the label and one feature come
# back with the expected values.
a_row = "0.0,1,0,-74.0,40.0,-74.0,40.7"
features, labels = parse_row(a_row)

assert labels.numpy() == 0.0
assert features["pickuplon"].numpy() == -74.0

# Show every parsed field: three features per line, then the label.
print("dayofweek:{}   hourofday:{}   pickuplon:{}".format(
    *(features[column].numpy() for column in ("dayofweek", "hourofday", "pickuplon"))))
print("pickuplat:{}   dropofflon:{}   dropofflat:{}".format(
    *(features[column].numpy() for column in ("pickuplat", "dropofflon", "dropofflat"))))
print("fare_amount:{}".format(labels.numpy()))
print("You rock!")

dayofweek:1   hourofday:0   pickuplon:-74.0
pickuplat:40.0   dropofflon:-74.0   dropofflat:40.70000076293945
fare_amount:0.0
You rock!


# read_dataset 확인

In [96]:
# Walk the whole dataset and print each parsed row, separated by a rule of asterisks.
for feature, label in read_dataset("./test.csv"):
    print('*'*100)
    print("dayofweek:{}   hourofday:{}   pickuplon:{}".format(
        *(feature[column].numpy() for column in ("dayofweek", "hourofday", "pickuplon"))))
    print("pickuplat:{}   dropofflon:{}   dropofflat:{}".format(
        *(feature[column].numpy() for column in ("pickuplat", "dropofflon", "dropofflat"))))
    print("fare_amount:{}".format(label.numpy()))

****************************************************************************************************
dayofweek:1   hourofday:0   pickuplon:-73.0
pickuplat:41.0   dropofflon:-74.0   dropofflat:20.700000762939453
fare_amount:28.0
****************************************************************************************************
dayofweek:1   hourofday:0   pickuplon:-72.0
pickuplat:44.0   dropofflon:-75.0   dropofflat:40.599998474121094
fare_amount:12.300000190734863
****************************************************************************************************
dayofweek:1   hourofday:0   pickuplon:-71.0
pickuplat:41.0   dropofflon:-71.0   dropofflat:42.900001525878906
fare_amount:10.0


# iterator로 한 건씩 읽기

In [97]:
dataset = read_dataset("./test.csv")
# Dataset.make_one_shot_iterator() is deprecated; the compat.v1 function is
# the supported way to obtain an explicit iterator.
dataset_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

# test.csv holds three data rows, so fetch exactly three records; asking for
# more would run past the end of the dataset.
for i in range(3):
    features, labels = dataset_iterator.get_next()
    print('*'*100)
    print("dayofweek:{}   hourofday:{}   pickuplon:{}".format(features["dayofweek"].numpy(), features["hourofday"].numpy(), features["pickuplon"].numpy()))
    print("pickuplat:{}   dropofflon:{}   dropofflat:{}".format(features["pickuplat"].numpy(), features["dropofflon"].numpy(), features["dropofflat"].numpy()))
    print("fare_amount:{}".format(labels.numpy()))


****************************************************************************************************
dayofweek:1   hourofday:0   pickuplon:-73.0
pickuplat:41.0   dropofflon:-74.0   dropofflat:20.700000762939453
fare_amount:28.0
****************************************************************************************************
dayofweek:1   hourofday:0   pickuplon:-72.0
pickuplat:44.0   dropofflon:-75.0   dropofflat:40.599998474121094
fare_amount:12.300000190734863
****************************************************************************************************
dayofweek:1   hourofday:0   pickuplon:-71.0
pickuplat:41.0   dropofflon:-71.0   dropofflat:42.900001525878906
fare_amount:10.0


# Train_input_fn Define

In [129]:
def train_input_fn(csv_path, batch_size=3, shuffle_buffer_size=1000, seed=None):
    """Build the training dataset: shuffled, repeated and batched.

    Args:
        csv_path: Location of the CSV data, passed straight to read_dataset.
        batch_size: Number of rows per batch.
        shuffle_buffer_size: Size of the shuffle buffer. It should be at least
            the number of rows for a uniform shuffle; the default keeps the
            previously hard-coded value.
        seed: Optional shuffle seed. Leave as None for a different order on
            every run, or set it to make the batches reproducible.

    Returns:
        A tf.data dataset of batched (features, label) pairs.
    """
    dataset = read_dataset(csv_path)
    # Shuffle before repeating so each pass over the data is reshuffled.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=seed)
    dataset = dataset.repeat(count=None)
    dataset = dataset.batch(batch_size=batch_size)
    # Prepare the next batch while the current one is being consumed; this
    # does not change the elements produced.
    return dataset.prefetch(buffer_size=1)

In [130]:
# Preview the first 10 training batches. The loop has to stop itself because
# the training dataset keeps producing batches.
for i, data in enumerate(train_input_fn('test.csv')):
    if i >= 10:
        break
    print('{} >> pickuplat:{}   pickuplon:{}   label:{}'.format(
        i,
        *(tensor.numpy() for tensor in (data[0]["pickuplat"], data[0]["pickuplon"], data[1]))))

0 >> pickuplat:[41. 41. 44.]   pickuplon:[-73. -71. -72.]   label:[28.  10.  12.3]
1 >> pickuplat:[41. 44. 41.]   pickuplon:[-73. -72. -71.]   label:[28.  12.3 10. ]
2 >> pickuplat:[44. 41. 41.]   pickuplon:[-72. -71. -73.]   label:[12.3 10.  28. ]
3 >> pickuplat:[44. 41. 41.]   pickuplon:[-72. -73. -71.]   label:[12.3 28.  10. ]
4 >> pickuplat:[44. 41. 41.]   pickuplon:[-72. -73. -71.]   label:[12.3 28.  10. ]
5 >> pickuplat:[41. 41. 44.]   pickuplon:[-71. -73. -72.]   label:[10.  28.  12.3]
6 >> pickuplat:[41. 44. 41.]   pickuplon:[-71. -72. -73.]   label:[10.  12.3 28. ]
7 >> pickuplat:[41. 44. 41.]   pickuplon:[-73. -72. -71.]   label:[28.  12.3 10. ]
8 >> pickuplat:[41. 41. 44.]   pickuplon:[-71. -73. -72.]   label:[10.  28.  12.3]
9 >> pickuplat:[44. 41. 41.]   pickuplon:[-72. -71. -73.]   label:[12.3 10.  28. ]


# Eval_input_fn Define (shuffle, repeat하지 않는다.)

In [131]:
def eval_input_fn(csv_path, batch_size = 3):
    """Build the evaluation dataset: batched only.

    No shuffle and no repeat are applied, so the rows from read_dataset are
    batched in the order they are read.

    Args:
        csv_path: Location of the CSV data, passed straight to read_dataset.
        batch_size: Number of rows per batch.

    Returns:
        A tf.data dataset of batched (features, label) pairs.
    """
    return read_dataset(csv_path).batch(batch_size = batch_size)

In [132]:
# Preview up to 10 evaluation batches; the loop ends sooner if the dataset
# runs out first.
for i, data in enumerate(eval_input_fn('test.csv')):
    if i >= 10:
        break
    print('{} >> pickuplat:{}   pickuplon:{}   label:{}'.format(
        i,
        *(tensor.numpy() for tensor in (data[0]["pickuplat"], data[0]["pickuplon"], data[1]))))

0 >> pickuplat:[41. 44. 41.]   pickuplon:[-73. -72. -71.]   label:[28.  12.3 10. ]
