In [75]:
import numpy as np
import tensorflow as tf
from utils.draw import draw_single
print(f'Tensorflow v{tf.__version__}')

# Seed both random number generators so Restart & Run All reproduces the same
# random squares (np.random is used by `generate`) and the same dataset
# shuffles (tf.random supplies the global seed for `Dataset.shuffle`).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Squares

In [76]:
def generate(count, seed=None):
    """Create a toy classification set of 3x3 boards of random values.

    Each sample is 9 integers in [0, 255], read as a 3x3 board in row-major
    order. The label is the index (0, 1 or 2) of the row whose three values
    have the largest sum. On ties, the lowest row index wins, because that is
    what np.argmax returns.

    Args:
        count: number of samples to draw (0 yields empty arrays).
        seed: optional seed for a private NumPy Generator. It makes the draw
            reproducible without touching NumPy's global state. When None (the
            default), the global np.random state is used, as before.

    Returns:
        (X, Y): X is an integer array of shape (count, 9); Y is an integer
        array of shape (count,) with labels in {0, 1, 2}.
    """
    if seed is None:
        X = np.random.randint(0, high=256, size=(count, 9))
    else:
        X = np.random.default_rng(seed).integers(0, 256, size=(count, 9))

    # View each board as 3 rows of 3 squares and sum across each row -> (count, 3).
    row_sums = X.reshape(count, 3, 3).sum(axis=2)
    Y = np.argmax(row_sums, axis=1)
    return X, Y

In [77]:
# Pass the (X, Y) pair for 120 random boards to the project's drawing helper
# (utils.draw.draw_single).
draw_single(*generate(count=120))

# TensorFlow Datasets
Datasets in TensorFlow (`tf.data.Dataset`) are pretty cool: they let you slice, dice, shuffle, batch, repeat and prefetch — pretty much anything you might want to do with data as it is fed into the machine learning process. Here's how I handle the 9 squares:

In [78]:
# from_tensor_slices slices every array in the (X, Y) tuple along its first
# axis, so each dataset element is one (board, label) pair.
train_ds = tf.data.Dataset.from_tensor_slices(tensors=generate(count=10))
train_ds.element_spec

In [79]:
# Iterate the dataset eagerly; each element is one (features, label) pair.
for x, y in train_ds:
    print(x, y, sep='\n')

In [80]:
def convert(x, y):
    """Leave the features untouched and one-hot encode the label with depth 3."""
    label_one_hot = tf.one_hot(y, depth=3)
    return x, label_one_hot

# Apply `convert` lazily to every (x, y) element of the dataset.
mapped_ds = train_ds.map(map_func=convert)
mapped_ds.element_spec

In [81]:
# Same walk as before, but labels are now one-hot vectors.
for x, y in mapped_ds:
    print(x, y, sep='\n')

In [82]:
# batch(2) groups the 10 mapped examples into 5 batches of 2. repeat(2) then
# replays that sequence, so the loop below prints 10 batches in total.
batched_ds = mapped_ds.batch(2).repeat(2)

for idx, (x, y) in enumerate(batched_ds, start=1):
    print(idx)
    print(x)
    print(y)

In [83]:
# shuffle() only mixes elements inside a sliding buffer of `buffer_size`, so a
# buffer much smaller than the dataset barely reorders it. Using the dataset's
# full cardinality gives a uniform shuffle. Shuffling before batch() reorders
# individual examples rather than whole batches, and by default each repeat()
# pass is reshuffled (reshuffle_each_iteration=True).
shuffled_ds = (mapped_ds
               .shuffle(buffer_size=mapped_ds.cardinality())
               .batch(2)
               .repeat(2))

for idx, (x, y) in enumerate(shuffled_ds, start=1):
    print(idx)
    print(x)
    print(y)

# Digits (MNIST)

In [106]:
# Load the raw MNIST image and label arrays via Keras.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values into [0, 1]. Cast to float32 first: dividing the integer
# arrays by a Python float would produce float64, which doubles memory and
# yields float64 tensors downstream (Keras layers default to float32).
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
print(x_train.shape)

In [102]:
# Slicing an (images, labels) tuple yields (image, label) pairs directly. This
# is equivalent to zipping two separately sliced datasets.
mnist_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

mnist_ds.element_spec

In [103]:
# Peek at the first (image, label) pair only.
for digit in mnist_ds.take(1):
    print(digit)

In [104]:
def transform(x, y, num_classes=10):
    """Flatten one image into a 1-D feature vector and one-hot encode its label.

    Args:
        x: a single image tensor (28x28 for MNIST); any shape is flattened.
        y: integer class label.
        num_classes: one-hot depth; defaults to the 10 digit classes.

    Returns:
        (flat_x, one_hot_y) with shapes (number of pixels,) and (num_classes,).
    """
    # [-1] lets TensorFlow infer the length, so the 28*28 constant is not
    # hard-coded. For a statically shaped 28x28 input the result is still (784,).
    return tf.reshape(x, [-1]), tf.one_hot(y, depth=num_classes)

# Training pipeline: shuffle raw examples, flatten/encode them, group them into
# batches, and replay the data for several epochs. Without the shuffle every
# epoch would see the examples, and therefore the batches, in stored order.
BATCH_SIZE = 64
EPOCHS = 10
SHUFFLE_BUFFER = 10_000  # kept well below the training-split size to bound memory

mnist_final = (mnist_ds
               .shuffle(SHUFFLE_BUFFER)
               .map(transform)
               .batch(BATCH_SIZE)
               .repeat(EPOCHS))
mnist_final.element_spec

In [105]:
# Show only the first batch of flattened images and one-hot labels.
for X, Y in mnist_final.take(1):
    print(X, Y, sep='\n')