In [1]:
# NOTE: this is a custom cell that contains the common imports I personally
# use; these may or may not be necessary for the following examples

# DL framework
import tensorflow as tf

from datetime import datetime

# common packages
import numpy as np
import os # handling file i/o
import sys
import math
import time # timing epochs

# for ordered dict when building layer components
import collections

# plotting pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib import colors # making colors consistent
from mpl_toolkits.axes_grid1 import make_axes_locatable # colorbar helper

# read image
### from imageio import imread
# + data augmentation
from scipy import ndimage
from scipy import misc

# used for manually saving best params
import pickle

# for shuffling data batches
from sklearn.utils import shuffle

# const
SEED = 42

# Helper to make the output consistent
def reset_graph(seed=SEED):
    """Discard the current default TensorFlow graph and re-seed TF and NumPy.

    Intended to be called before building a model so that repeated runs of
    the notebook start from the same random state.

    Args:
        seed: seed value handed to both ``tf.set_random_seed`` and
            ``np.random.seed``. Defaults to the notebook-wide ``SEED``.
    """
    # Reset first, so the TensorFlow seed below is set after the old graph is gone.
    tf.reset_default_graph()
    np.random.seed(seed)
    tf.set_random_seed(seed)

# helper to create dirs if they don't already exist
def maybe_create_dir(dir_path):
    """Create ``dir_path`` (with any missing parents) unless the path exists.

    Prints a one-line message stating whether the directory was created or
    something was already present at that path.

    Args:
        dir_path: path of the directory to ensure.
    """
    if os.path.exists(dir_path):
        print("{} already exists".format(dir_path))
        return

    os.makedirs(dir_path)
    print("{} created".format(dir_path))
    
# Suppress TensorFlow's C++ log messages, unless they are errors.
# NOTE(review): tensorflow has already been imported by this point -- confirm
# the setting still takes effect when applied after the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Record the interpreter and framework versions this notebook was run with.
print("Python: {}".format(sys.version_info[:]))
print('TensorFlow: {}'.format(tf.__version__))

# Report which GPU (if any) TensorFlow picked up.
if tf.test.gpu_device_name():
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
else:
    print('No GPU')

# Start from a clean, seeded graph.
reset_graph()

Python: (3, 6, 5, 'final', 0)
TensorFlow: 1.8.0
Default GPU Device: /device:GPU:0


In [2]:
# FINAL_DIR will (hopefully) contain our tf_records files
# by the end of this notebook
FINAL_DIR = "../data/224_224/"
maybe_create_dir(FINAL_DIR)

../data/224_224/ created


In [4]:
ROOT_DIR = "./numpy_final/224_224/"

# List the files sitting directly in ROOT_DIR (sorted for a stable printout).
# os.listdir is used instead of looping over os.walk so that `files` cannot end
# up holding the listing of whichever sub-directory was visited last, and so
# that a missing ROOT_DIR fails right here with FileNotFoundError instead of a
# NameError further down.
files = sorted(
    entry for entry in os.listdir(ROOT_DIR)
    if os.path.isfile(os.path.join(ROOT_DIR, entry))
)
for filename in files:
    print(filename)

# Load every split by its explicit file name rather than by its position in
# the sorted listing: a single stray file in ROOT_DIR would otherwise shift the
# indices and silently assign the wrong array to a split.
X_test = np.load(os.path.join(ROOT_DIR, "test.npy"))
y_test = np.load(os.path.join(ROOT_DIR, "test_masks.npy"))

X_train = np.load(os.path.join(ROOT_DIR, "train.npy"))
y_train = np.load(os.path.join(ROOT_DIR, "train_masks.npy"))

X_val = np.load(os.path.join(ROOT_DIR, "validation.npy"))
y_val = np.load(os.path.join(ROOT_DIR, "validation_masks.npy"))

test.npy
test_masks.npy
train.npy
train_masks.npy
validation.npy
validation_masks.npy


In [6]:
def _bytes_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element BytesList.

    Args:
        value: the single item to place in the ``BytesList``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains just ``value``.
    """
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

In [8]:
def numpy_to_tfrecords(features, lables, setType):
    """Serialize paired arrays into ``<setType>.tfrecords`` inside FINAL_DIR.

    Every ``(features[i], lables[i])`` pair becomes one ``tf.train.Example``
    with two bytes features:

    * ``'/image'`` -- the raw bytes of ``features[i]``
    * ``'/label'`` -- the raw bytes of ``lables[i]``

    Only the raw buffers are written (``ndarray.tobytes()``); the code stores
    neither dtype nor shape, so the reader has to know them already.

    NOTE: earlier versions of this function wrote the two keys swapped
    (image bytes under ``'/label'`` and vice versa). Files produced by that
    version must be regenerated, or read with the keys swapped.

    Args:
        features: sequence of NumPy arrays (the images).
        lables: sequence of NumPy arrays (the labels), same length as
            ``features``. The misspelled name is kept so that existing
            keyword callers keep working.
        setType: name of the split; used as the file stem and in the
            progress messages.

    Raises:
        AssertionError: if ``features`` and ``lables`` differ in length.
    """
    assert len(features) == len(lables), "features & labels are not equal in len"
    tfrecords_file_name = str(setType) + '.tfrecords'
    writer = tf.python_io.TFRecordWriter(os.path.join(FINAL_DIR, tfrecords_file_name))

    # try/finally guarantees the file handle is released even if serializing
    # one of the examples raises part-way through.
    try:
        for i, (img, label) in enumerate(zip(features, lables)):
            # tobytes() replaces the deprecated tostring() alias; same output.
            feature = {'/image': _bytes_feature(tf.compat.as_bytes(img.tobytes())),
                       '/label': _bytes_feature(tf.compat.as_bytes(label.tobytes()))}

            # create example protocol buffer
            example = tf.train.Example(features=tf.train.Features(feature=feature))

            writer.write(example.SerializeToString())

            # light-weight progress indicator
            if i % 250 == 0:
                print("{} {} written".format(i, setType))
    finally:
        writer.close()

    print("done")
    sys.stdout.flush()

In [10]:
# Write the validation split to "validation.tfrecords" inside FINAL_DIR.
numpy_to_tfrecords(X_val, y_val, "validation")

0 validation written
250 validation written
done


In [11]:
# Write the training split to "train.tfrecords" inside FINAL_DIR.
numpy_to_tfrecords(X_train, y_train, "train")

0 train written
250 train written
500 train written
750 train written
1000 train written
1250 train written
1500 train written
done
