# Convert custom data to TFRecord

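This notebook demonstrates how to convert custom data (here, a CSV file of consumer complaints) to the TFRecord format and then ingest the result with TFX's `ImportExampleGen` component.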


In [None]:
%config IPCompleter.greedy=True

In [None]:
!pip install tensorflow
!pip install tfx

In [None]:
import tensorflow as tf 
import csv
import os, pwd
from tfx.utils.dsl_utils import external_input
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.components import (
    FileBasedExampleGen,
    ImportExampleGen
)

## Helper Functions

In [None]:
def _bytes_feature(value):
    """Wrap a single string or bytes value in a ``tf.train.Feature``.

    Args:
        value: A ``str`` (encoded to UTF-8 before storing) or a ``bytes``
            object (stored unchanged).

    Returns:
        A ``tf.train.Feature`` whose ``BytesList`` holds exactly one element.
    """
    # bytes has no .encode(); pass it through instead of raising
    # AttributeError. Anything else still goes through .encode() as before.
    encoded = value if isinstance(value, bytes) else value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded]))

def _float_feature(value):
    """Return a ``tf.train.Feature`` whose ``FloatList`` contains only *value*."""
    single_float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=single_float_list)

def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose ``Int64List`` contains only *value*."""
    single_int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=single_int_list)

In [None]:
def clean_rows(row):
    """Fill in missing values of a CSV row so every feature can be serialized.

    The row is modified in place and also returned, so callers can write
    ``row = clean_rows(row)``.

    * A missing, empty, or whitespace-only ``zip_code`` is replaced with the
      placeholder ``"99999"``.
    * Any other value that is ``None`` becomes ``""``. ``csv.DictReader``
      produces ``None`` (its default ``restval``) for columns absent from a
      short row, and such values cannot be string-encoded downstream.

    Args:
        row: Mapping of column name to value, e.g. as yielded by
            ``csv.DictReader``. Must contain a ``"zip_code"`` key.

    Returns:
        The same ``row`` object, cleaned.

    Raises:
        KeyError: If ``row`` has no ``"zip_code"`` key (raised before any
            mutation happens).
    """
    zip_code = row["zip_code"]
    # Build the replacement mapping first, then apply it, so the dict is not
    # mutated while it is being iterated.
    row.update({key: "" for key, value in row.items() if value is None})
    # Whitespace-only strings are truthy but int() rejects them, so treat
    # them like an empty zip code. Non-string values keep the original
    # truthiness check.
    if not zip_code or (isinstance(zip_code, str) and not zip_code.strip()):
        row["zip_code"] = "99999"
    return row

def convert_zipcode_to_int(zipcode):
    """Convert a possibly masked ZIP code to an integer.

    Masked digits written as ``X`` or ``x`` (e.g. ``"123XX"`` or ``"12XXX"``)
    are each replaced with ``0`` before conversion, so ``"123XX"`` becomes
    ``12300``. Non-string inputs are passed straight to ``int``.

    Note that leading zeros are not preserved: ``"01234"`` becomes ``1234``.

    Args:
        zipcode: ZIP code as a string (optionally masked) or a number.

    Returns:
        The ZIP code as an ``int``.

    Raises:
        ValueError: If the unmasked string is still not a valid integer.
    """
    if isinstance(zipcode, str):
        # Replace every masked digit individually: replacing the pair "XX"
        # left odd-length masks half-converted ("12XXX" -> "1200X").
        zipcode = zipcode.replace("X", "0").replace("x", "0")
    return int(zipcode)

## Convert the CSV file to TFRecord format

In [None]:
base_dir = pwd.getpwuid(os.getuid()).pw_dir
data_dir_str = 'Github/building-machine-learning-pipelines/data'
data_dir = os.path.join(base_dir, data_dir_str)
original_data_file = os.path.join(data_dir, 'consumer_complaints_with_narrative.csv')

# Give the converted data a directory of its own and point ImportExampleGen
# at it. The TFRecord used to be written to the current working directory
# while the component was pointed at data_dir, which holds the source CSV.
# With its default input config, ImportExampleGen reads every file under its
# input base as TFRecord, so it would have picked up the CSV instead.
tfrecord_dir = os.path.abspath('consumer-complaints-tfrecord')
os.makedirs(tfrecord_dir, exist_ok=True)
tfrecord_filename = os.path.join(tfrecord_dir, 'consumer-complaints.tfrecord')


def _row_to_example(row):
    """Build a ``tf.train.Example`` from one cleaned CSV row.

    Every column is stored as a UTF-8 bytes feature, except ``zip_code``,
    which is unmasked and stored as an int64 feature.
    """
    feature = {
        "product": _bytes_feature(row["product"]),
        "sub_product": _bytes_feature(row["sub_product"]),
        "issue": _bytes_feature(row["issue"]),
        "sub_issue": _bytes_feature(row["sub_issue"]),
        "state": _bytes_feature(row["state"]),
        "zip_code": _int64_feature(convert_zipcode_to_int(row["zip_code"])),
        "company": _bytes_feature(row["company"]),
        "company_response": _bytes_feature(row["company_response"]),
        "consumer_complaint_narrative": _bytes_feature(row["consumer_complaint_narrative"]),
        "timely_response": _bytes_feature(row["timely_response"]),
        "consumer_disputed": _bytes_feature(row["consumer_disputed"]),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# newline='' is required by the csv module so that line breaks inside quoted
# fields (e.g. multi-line complaint narratives) are parsed correctly. Both
# the CSV file and the writer are context managers, so each is closed even
# if a row fails to convert.
with open(original_data_file, newline='', encoding='utf-8') as csv_file:
    with tf.io.TFRecordWriter(tfrecord_filename) as tfrecord_writer:
        reader = csv.DictReader(csv_file, delimiter=",", quotechar='"')
        for row in reader:
            row = clean_rows(row)
            example = _row_to_example(row)
            tfrecord_writer.write(example.SerializeToString())

context = InteractiveContext()

# Ingest only the directory that contains the freshly written TFRecord file.
examples = external_input(tfrecord_dir)
example_gen = ImportExampleGen(input=examples)
context.run(example_gen)
