In [13]:
import tensorflow as tf
import numpy as np
from keras import layers 
import matplotlib.pyplot as plt
import cirq
import sympy
import seaborn as sns
import collections
import os
import cv2
from tensorflow.keras.preprocessing import image as tfimage
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from cirq.contrib.svg import SVGCircuit
import tensorflow_datasets 
from keras.utils import image_dataset_from_directory

In [14]:

# Load the synthetic font images from disk.
# `image_dataset_from_directory` infers one integer class label per
# sub-directory of DATA_DIR. Both subsets below use the same `seed` and
# `validation_split`, so they are complementary slices of the same shuffled
# file list and do not overlap.
DATA_DIR = 'synthetic-data/images'
IMAGE_SIZE = (256, 128)  # (height, width) every image is resized to
SPLIT_SEED = 123
VALIDATION_SPLIT = 0.1


def _load_subset(subset):
    """Load the "training" or "validation" subset of DATA_DIR as a batched tf.data.Dataset."""
    return image_dataset_from_directory(
        directory=DATA_DIR,
        labels='inferred',
        label_mode='int',
        # Single channel: later cells index images as `[..., 0]` and the
        # models declare input shapes whose last dimension is 1.
        color_mode='grayscale',
        batch_size=32,
        image_size=IMAGE_SIZE,
        shuffle=True,
        seed=SPLIT_SEED,
        validation_split=VALIDATION_SPLIT,
        subset=subset)


def _dataset_to_numpy(ds):
    """Materialize a batched (images, labels) dataset into two NumPy arrays."""
    images, labels = [], []
    for image_batch, label_batch in ds:
        images.append(image_batch.numpy())
        labels.append(label_batch.numpy())
    return np.concatenate(images), np.concatenate(labels)


train_ds = _load_subset("training")
test_ds = _load_subset("validation")

# Later cells use boolean masking and `[i, :, :, 0]` indexing, which need
# in-memory arrays rather than a tf.data pipeline.
x_train, y_train = _dataset_to_numpy(train_ds)
x_test, y_test = _dataset_to_numpy(test_ds)

# Rescale the images from [0, 255] to the [0.0, 1.0] range. The binarization
# THRESHOLD and the plots' vmin/vmax assume this range.
x_train, x_test = x_train / 255.0, x_test / 255.0

print("Number of original training examples:", len(x_train))
print("Number of original test examples:", len(x_test))


def filter_36(x, y):
    keep = (y == 3) | (y == 6)
    x, y = x[keep], y[keep]
    y = y == 3
    return x,y


# Restrict to the two target classes; labels become booleans.
# NOTE: this rebinds x_train/y_train, so re-running this cell alone filters
# the already-filtered data again. Re-run from the loading cell instead.
x_train, y_train = filter_36(x_train, y_train)
x_test, y_test = filter_36(x_test, y_test)

print("Number of filtered training examples:", len(x_train))
print("Number of filtered test examples:", len(x_test))
print(y_train[0])

# Each image gets its own figure. Calling plt.imshow twice in one cell would
# draw both images on the same axes and stack two colorbars.
fig, ax = plt.subplots()
im = ax.imshow(x_train[0, :, :, 0])
fig.colorbar(im, ax=ax)
ax.set_title(f"First filtered training image (label={y_train[0]})")
plt.show()

# Downsample to 4x4 so the images are small enough for the "fair" baseline.
x_train_small = tf.image.resize(x_train, (4, 4)).numpy()
x_test_small = tf.image.resize(x_test, (4, 4)).numpy()

print(y_train[0])

fig, ax = plt.subplots()
im = ax.imshow(x_train_small[0, :, :, 0], vmin=0, vmax=1)
fig.colorbar(im, ax=ax)
ax.set_title("First training image downsampled to 4x4")
plt.show()


def remove_contradicting(xs, ys):
    """Deduplicate images and discard any image that carries conflicting labels.

    Images are compared by their flattened pixel values. For every distinct
    image, the set of labels it appears with is collected:
    - images seen with exactly one label are kept once;
    - images seen with more than one label are dropped.
    Summary counts are printed along the way.

    Args:
        xs: Iterable of NumPy arrays (images); each must support `.flatten()`.
        ys: Iterable of labels, paired positionally with `xs`. The "3s"/"6s"
            counts treat a label equal to True as a 3 and a label equal to
            False as a 6.

    Returns:
        Tuple `(new_x, new_y)` of NumPy arrays with one entry per kept image,
        ordered by first appearance in `xs`.
    """
    labels_for = collections.defaultdict(set)
    image_for = {}
    for image, label in zip(xs, ys):
        key = tuple(image.flatten())
        # Later duplicates overwrite earlier ones; their flattened pixels match.
        image_for[key] = image
        labels_for[key].add(label)

    label_sets = list(labels_for.values())

    # Only images with a single, unambiguous label survive.
    kept_keys = [key for key, labels in labels_for.items() if len(labels) == 1]
    new_x = [image_for[key] for key in kept_keys]
    new_y = [next(iter(labels_for[key])) for key in kept_keys]

    num_uniq_3 = sum(1 for labels in label_sets if len(labels) == 1 and True in labels)
    num_uniq_6 = sum(1 for labels in label_sets if len(labels) == 1 and False in labels)
    num_uniq_both = sum(1 for labels in label_sets if len(labels) == 2)

    print("Number of unique images:", len(label_sets))
    print("Number of unique 3s: ", num_uniq_3)
    print("Number of unique 6s: ", num_uniq_6)
    print("Number of unique contradicting labels (both 3 and 6): ", num_uniq_both)
    print()
    print("Initial number of images: ", len(xs))
    print("Remaining non-contradicting unique images: ", len(new_x))

    return np.array(new_x), np.array(new_y)


# Deduplicate the downsampled training set and drop images with both labels.
x_train_nocon, y_train_nocon = remove_contradicting(x_train_small, y_train)

# Binarize pixels: values above THRESHOLD become 1.0, everything else 0.0.
THRESHOLD = 0.5
x_train_bin = (x_train_nocon > THRESHOLD).astype(np.float32)
x_test_bin = (x_test_small > THRESHOLD).astype(np.float32)

# Re-run only for its printed report: how many images collide after binarizing.
_ = remove_contradicting(x_train_bin, y_train_nocon)


def create_classical_model(input_shape=(28, 28, 1)):
    """Build a small LeNet-style CNN that outputs a single logit.

    Adapted from the Keras MNIST CNN example (https://keras.io/examples/mnist_cnn/).

    Architecture: Conv(32) -> Conv(64) -> MaxPool -> Dropout -> Dense(128)
    -> Dropout -> Dense(1).

    Args:
        input_shape: Shape of one input image as (height, width, channels).
            Defaults to (28, 28, 1), the MNIST shape of the original example.
            Pass your images' shape to train on them without resizing.

    Returns:
        An uncompiled `tf.keras.Sequential` model. The last Dense(1) layer has
        no activation, so compile it with a `from_logits=True` loss.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(32, [3, 3], activation='relu',
                                     input_shape=tuple(input_shape)))
    model.add(tf.keras.layers.Conv2D(64, [3, 3], activation='relu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(tf.keras.layers.Dropout(0.25))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(1))
    return model


# Baseline 1: the full classical CNN.
model = create_classical_model()
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

model.summary()

# The CNN declares a fixed input shape, but the font images are loaded at a
# different resolution. Resize them to the model's (height, width) so that
# `fit`/`evaluate` do not fail on an input-shape mismatch.
# Assumes the images already have the channel count the model expects
# (1, i.e. grayscale) — TODO confirm against the loading cell.
cnn_input_hw = tuple(model.input_shape[1:3])
x_train_cnn = tf.image.resize(x_train, cnn_input_hw)
x_test_cnn = tf.image.resize(x_test, cnn_input_hw)

model.fit(x_train_cnn,
          y_train,
          batch_size=128,
          epochs=1,
          verbose=1,
          validation_data=(x_test_cnn, y_test))

cnn_results = model.evaluate(x_test_cnn, y_test)


def create_fair_classical_model(input_shape=(4, 4, 1)):
    """Build a tiny fully connected classifier for the downsampled images.

    This is a minimal multilayer perceptron, not a convolutional/LeNet model:
    Flatten -> Dense(2, relu) -> Dense(1). With the default 4x4x1 input it has
    16*2 + 2 + 2*1 + 1 = 37 trainable parameters.

    Args:
        input_shape: Shape of one input image as (height, width, channels).
            Defaults to (4, 4, 1), matching the 4x4 downsampled images.

    Returns:
        An uncompiled `tf.keras.Sequential` model. The last Dense(1) layer has
        no activation, so compile it with a `from_logits=True` loss.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=tuple(input_shape)))
    model.add(tf.keras.layers.Dense(2, activation='relu'))
    model.add(tf.keras.layers.Dense(1))
    return model


# Baseline 2: the "fair" tiny MLP, trained on the binarized 4x4 images.
model = create_fair_classical_model()
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

model.summary()

model.fit(x_train_bin,
          y_train_nocon,
          batch_size=128,
          epochs=20,
          verbose=2,
          validation_data=(x_test_bin, y_test))

fair_nn_results = model.evaluate(x_test_bin, y_test)

# With metrics=['accuracy'], `evaluate` returns [loss, accuracy].
cnn_accuracy = cnn_results[1]
fair_nn_accuracy = fair_nn_results[1]

fig, ax = plt.subplots()
sns.barplot(x=["Classical, full", "Classical, fair"],
            y=[cnn_accuracy, fair_nn_accuracy],
            ax=ax)
# A fixed 0-1 range keeps the two accuracies visually comparable across runs.
ax.set(title="Test accuracy of classical baselines",
       ylabel="Test accuracy",
       ylim=(0, 1))
plt.show()


Found 0 files belonging to 0 classes.
Using 0 files for training.


ValueError: No images found in directory dataset/. Allowed formats: ('.bmp', '.gif', '.jpeg', '.jpg', '.png')