# Doodle CNN
A convolutional neural network that predicts what was drawn, given the pixel data of a 28x28 input picture.\
Trained on Google's Quick, Draw! dataset (quickdraw-dataset).

In [None]:
import os
import numpy as np
from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import random
import zipfile
from contextlib import ExitStack
from datetime import datetime
import glob
from functools import partial
import json

### Step 1: Get TFRecord dataset filepaths

In [None]:
print(tf.version.VERSION)  # should be 2.4.1 to stay compatible with the server

Load every class's .npy file, wrap it in a TensorFlow Dataset object, and save it to its own .tfrecord file for easy access in the future. (The data is shuffled later, when the .tfrecord files are read back.)

In [None]:
def create_example_protobuff(image, label):
    """Pack one (image, label) pair into a `tf.train.Example` protobuf.

    The image tensor is serialized with `tf.io.serialize_tensor` and stored as a
    single bytes value under 'image'; the label is stored as a single int64
    value under 'label'.
    """
    # binary string form of the tensor, as required by a BytesList
    serialized_image = tf.io.serialize_tensor(image).numpy()

    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[serialized_image]))
    label_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[label]))

    features = tf.train.Features(
        feature={'image': image_feature, 'label': label_feature})
    return tf.train.Example(features=features)

In [None]:
def write_tfrecords(name, dataset):
    """Serialize every (image, label) pair of `dataset` into `<name>.tfrecord`.

    Each pair is converted with create_example_protobuff() and written as one
    record.

    Args:
        name: Output file name without the '.tfrecord' extension.
        dataset: Iterable of (image, label) pairs, e.g. a tf.data.Dataset.

    Returns:
        The path of the written TFRecord file.
    """
    path = f'{name}.tfrecord'
    with tf.io.TFRecordWriter(path) as writer:
        for image, label in dataset:
            # int(), not np.uint8(): there are more than 256 classes (the model's
            # output layer has 346 units), and a uint8 cast would wrap label
            # ids >= 256 onto lower class ids.
            example = create_example_protobuff(image, int(label))
            writer.write(example.SerializeToString())
    return path

In [None]:
# every raw QuickDraw .npy file in data/ (one file per class)
# NOTE(review): glob.glob returns files in an arbitrary, filesystem-dependent
# order, and load_data() uses that order as the label id, so ids can differ
# between machines - consider persisting class_names next to the TFRecords.
all_files = glob.glob("data/*")


def load_data():
    """Convert every file in `all_files` into its own TFRecord file.

    Each .npy file is reshaped into 28x28 uint8 images, labelled with the
    file's position in `all_files`, and written to `doodle-<class>.tfrecord`
    via write_tfrecords().

    Returns:
        (all_filepaths, class_names): the written TFRecord paths and the class
        names (file stems), both in `all_files` order, so class_names[label]
        names the class of a given label id.
    """
    written_paths = []
    names = []

    for label_id, npy_path in enumerate(all_files):
        pixels = np.load(npy_path)
        # each row is a flat 784-pixel drawing; the CNN needs 28x28 images
        images = pixels.reshape((pixels.shape[0], 28, 28)).astype(np.uint8)
        labels = np.full(images.shape[0], label_id)
        dataset = tf.data.Dataset.from_tensor_slices((images, labels))

        # the class name is the file stem, e.g. 'fork.npy' -> 'fork'
        stem = os.path.splitext(os.path.basename(npy_path))[0]
        names.append(stem)
        print(stem)

        written_paths.append(write_tfrecords(f"doodle-{stem}", dataset))

        # progress message after every 35th file
        processed = label_id + 1
        if processed % 35 == 0:
            print(f'{processed} file npy to tfrecord')

    return written_paths, names


filepaths, class_names = load_data()

In [None]:
# --For Google Colab---
# for tfrecord files already created from the .npy files
# in my case, located in google drive: /content/drive/MyDrive/DoodleData
#
# NOTE(review): class_names (built from this list below) is indexed by the label
# ids stored in the TFRecords, which load_data() assigned from each file's
# position in its glob order. This list is presumably that order captured from
# the machine that wrote the TFRecords, so do not sort or dedupe it without
# regenerating them.
# NOTE(review): 'data/sun.npy copy' gives a second class name 'sun'
# (os.path.splitext splits at the last dot), apparently a duplicate of
# 'data/sun.npy' - confirm.
paths = ['data/lollipop.npy', 'data/binoculars.npy', 'data/garden.npy', 'data/basket.npy', 'data/penguin.npy', 'data/washing machine.npy', 'data/canoe.npy', 'data/eyeglasses.npy', 'data/beach.npy', 'data/screwdriver.npy', 'data/mouse.npy', 'data/apple.npy', 'data/van.npy', 'data/grapes.npy', 'data/grass.npy', 'data/watermelon.npy', 'data/floor lamp.npy', 'data/moon.npy', 'data/zigzag.npy', 'data/nail.npy', 'data/leg.npy', 'data/rollerskates.npy', 'data/goatee.npy', 'data/sun.npy copy', 'data/cup.npy', 'data/anvil.npy', 'data/suitcase.npy', 'data/chair.npy', 'data/drill.npy', 'data/peanut.npy', 'data/squirrel.npy', 'data/matches.npy', 'data/sword.npy', 'data/cat.npy', 'data/toe.npy', 'data/snorkel.npy', 'data/pond.npy', 'data/calculator.npy', 'data/airplane.npy', 'data/squiggle.npy', 'data/blackberry.npy', 'data/ear.npy', 'data/frying pan.npy', 'data/chandelier.npy', 'data/tree.npy', 'data/wine bottle.npy', 'data/peas.npy', 'data/hot tub.npy', 'data/door.npy', 'data/calendar.npy', 'data/wine glass.npy', 'data/stove.npy', 'data/hockey stick.npy', 'data/toothpaste.npy', 'data/moustache.npy', 'data/mountain.npy', 'data/tooth.npy', 'data/firetruck.npy', 'data/cannon.npy', 'data/stereo.npy', 'data/shorts.npy', 'data/cloud.npy', 'data/paintbrush.npy', 'data/pear.npy', 'data/frog.npy', 'data/laptop.npy', 'data/dishwasher.npy', 'data/vase.npy', 'data/diving board.npy', 'data/octagon.npy', 'data/smiley face.npy', 'data/dumbbell.npy', 'data/sweater.npy', 'data/stitches.npy', 'data/tractor.npy', 'data/foot.npy', 'data/basketball.npy', 'data/helmet.npy', 'data/crab.npy', 'data/clock.npy', 'data/diamond.npy', 'data/car.npy', 'data/axe.npy', 'data/traffic light.npy', 'data/sleeping bag.npy', 'data/baseball.npy', 'data/eye.npy', 'data/flower.npy', 'data/hot air balloon.npy', 'data/waterslide.npy', 'data/coffee cup.npy', 'data/bottlecap.npy', 'data/banana.npy', 'data/dresser.npy', 'data/house plant.npy', 'data/skyscraper.npy', 'data/skateboard.npy', 'data/pizza.npy', 
'data/hammer.npy', 'data/teapot.npy', 'data/giraffe.npy', 'data/underwear.npy', 'data/snowman.npy', 'data/monkey.npy', 'data/computer.npy', 'data/pencil.npy', 'data/shovel.npy', 'data/knife.npy', 'data/bat.npy', 'data/compass.npy', 'data/necklace.npy', 'data/bicycle.npy', 'data/teddy-bear.npy', 'data/bucket.npy', 'data/line.npy', 'data/bus.npy', 'data/cello.npy', 'data/ocean.npy', 'data/truck.npy', 'data/camouflage.npy', 'data/harp.npy', 'data/stairs.npy', 'data/telephone.npy', 'data/star.npy', 'data/guitar.npy', 'data/sandwich.npy', 'data/sun.npy', 'data/feather.npy', 'data/leaf.npy', 'data/toilet.npy', 'data/strawberry.npy', 'data/birthday cake.npy', 'data/saxophone.npy', 'data/rake.npy', 'data/broom.npy', 'data/stethoscope.npy', 'data/square.npy', 'data/crown.npy', 'data/fire hydrant.npy', 'data/donut.npy', 'data/jail.npy', 'data/oven.npy', 'data/beard.npy', 'data/syringe.npy', 'data/yoga.npy', 'data/The Eiffel Tower.npy', 'data/camera.npy', 'data/purse.npy', 'data/ice cream.npy', 'data/pig.npy', 'data/trumpet.npy', 'data/table.npy', 'data/bush.npy', 'data/scorpion.npy', 'data/fish.npy', 'data/hot dog.npy', 'data/see saw.npy', 'data/rain.npy', 'data/snail.npy', 'data/sink.npy', 'data/belt.npy', 'data/speedboat.npy', 'data/trombone.npy', 'data/pants.npy', 'data/crocodile.npy', 'data/broccoli.npy', 'data/hedgehog.npy', 'data/rainbow.npy', 'data/bulldozer.npy', 'data/fork.npy', 'data/sock.npy', 'data/snake.npy', 'data/paper clip.npy', 'data/bear.npy', 'data/marker.npy', 'data/tent.npy', 'data/rabbit.npy', 'data/clarinet.npy', 'data/whale.npy', 'data/boomerang.npy', 'data/hospital.npy', 'data/ceiling fan.npy', 'data/pillow.npy', 'data/saw.npy', 'data/fence.npy', 'data/parrot.npy', 'data/duck.npy', 'data/dog.npy', 'data/swing set.npy', 'data/spoon.npy', 'data/fan.npy', 'data/cruise ship.npy', 'data/picture frame.npy', 'data/mushroom.npy', 'data/headphones.npy', 'data/horse.npy', 'data/flying saucer.npy', 'data/skull.npy', 'data/rifle.npy', 'data/train.npy', 
'data/hat.npy', 'data/mouth.npy', 'data/book.npy', 'data/drums.npy', 'data/radio.npy', 'data/roller coaster.npy', 'data/snowflake.npy', 'data/piano.npy', 'data/rhinoceros.npy', 'data/cake.npy', 'data/paint can.npy', 'data/toaster.npy', 'data/knee.npy', 'data/spider.npy', 'data/sea turtle.npy', 'data/popsicle.npy', 'data/pickup truck.npy', 'data/envelope.npy', 'data/remote control.npy', 'data/ambulance.npy', 'data/pliers.npy', 'data/bread.npy', 'data/castle.npy', 'data/river.npy', 'data/bandage.npy', 'data/lion.npy', 'data/postcard.npy', 'data/bench.npy', 'data/parachute.npy', 'data/keyboard.npy', 'data/streetlight.npy', 'data/arm.npy', 'data/police car.npy', 'data/sailboat.npy', 'data/cooler.npy', 'data/bathtub.npy', 'data/hurricane.npy', 'data/campfire.npy', 'data/soccer ball.npy', 'data/potato.npy', 'data/dolphin.npy', 'data/key.npy', 'data/elephant.npy', 'data/tornado.npy', 'data/jacket.npy', 'data/nose.npy', 'data/motorbike.npy', 'data/octopus.npy', 'data/bracelet.npy', 'data/brain.npy', 'data/toothbrush.npy', 'data/The Mona Lisa.npy', 'data/carrot.npy', 'data/barn.npy', 'data/zebra.npy', 'data/microphone.npy', 'data/map.npy', 'data/camel.npy', 'data/wheel.npy', 'data/bridge.npy', 'data/lighthouse.npy', 'data/spreadsheet.npy', 'data/hockey puck.npy', 'data/wristwatch.npy', 'data/helicopter.npy', 'data/swan.npy', 'data/flamingo.npy', 'data/backpack.npy', 'data/lobster.npy', 'data/golf club.npy', 'data/hexagon.npy', 'data/garden hose.npy', 'data/bird.npy', 'data/animal migration.npy', 'data/finger.npy', 'data/steak.npy', 'data/mailbox.npy', 'data/shark.npy', 'data/television.npy', 'data/mermaid.npy', 'data/cow.npy', 'data/crayon.npy', 'data/palm tree.npy', 'data/windmill.npy', 'data/cookie.npy', 'data/kangaroo.npy', 'data/blueberry.npy', 'data/tennis racquet.npy', 'data/tiger.npy', 'data/dragon.npy', 'data/cell phone.npy', 'data/pineapple.npy', 'data/sheep.npy', 'data/candle.npy', 'data/angel.npy', 'data/cactus.npy', 'data/mosquito.npy', 'data/couch.npy', 
'data/church.npy', 'data/The Great Wall of China.npy', 'data/hamburger.npy', 'data/school bus.npy', 'data/lipstick.npy', 'data/light bulb.npy', 'data/flip flops.npy', 'data/alarm clock.npy', 'data/aircraft carrier.npy', 'data/face.npy', 'data/ant.npy', 'data/microwave.npy', 'data/hourglass.npy', 'data/panda.npy', 'data/pool.npy', 'data/circle.npy', 'data/onion.npy', 'data/lighter.npy', 'data/raccoon.npy', 'data/bowtie.npy', 'data/umbrella.npy', 'data/butterfly.npy', 'data/fireplace.npy', 'data/eraser.npy', 'data/bee.npy', 'data/flashlight.npy', 'data/megaphone.npy', 'data/asparagus.npy', 'data/shoe.npy', 'data/ladder.npy', 'data/t-shirt.npy', 'data/passport.npy', 'data/triangle.npy', 'data/hand.npy', 'data/lightning.npy', 'data/mug.npy', 'data/submarine.npy', 'data/violin.npy', 'data/owl.npy', 'data/scissors.npy', 'data/string bean.npy', 'data/baseball bat.npy', 'data/lantern.npy', 'data/house.npy', 'data/elbow.npy', 'data/power outlet.npy', 'data/stop sign.npy', 'data/bed.npy']
# class names are the file stems, in the same order as `paths`
class_names = [os.path.splitext(os.path.basename(path))[0] for path in paths]

# every file in the tfrecords/ directory.
# NOTE(review): write_tfrecords() writes doodle-<class>.tfrecord into the current
# directory, so the files are presumably moved into tfrecords/ by hand - confirm.
filepaths = glob.glob("tfrecords/*")

In [None]:
class DoodleDataset:
    """Create a shuffled, batched TFRecordDataset from doodle TFRecord files.

    Records are parsed by `preprocess`, which expects each serialized
    `tf.train.Example` to hold an 'image' bytes feature (a tensor serialized
    with `tf.io.serialize_tensor`, decoded as uint8 and reshaped to 28x28) and
    an int64 'label' feature.
    """

    def __init__(self, filepaths, shuffle_buffer_size, batch_size=128,
                 seed=None, reshuffle_each_iteration=False):
        """Store the pipeline settings.

        Args:
            filepaths: TFRecord file paths to read; all are read in parallel.
            shuffle_buffer_size: Number of records held in the shuffle buffer.
            batch_size: Number of examples per batch (default 128).
            seed: Optional shuffle seed, for an order that is reproducible
                across runs.
            reshuffle_each_iteration: Whether every pass over the dataset draws
                a new shuffle order. Defaults to False because this notebook
                carves validation, test and training splits out of a single
                dataset with take()/skip(). With per-iteration reshuffling, each
                split would contain different records on every pass, leaking
                validation/test examples into training.
        """
        self.filepaths = filepaths
        self.shuffle_buffer_size = shuffle_buffer_size
        self.batch_size = batch_size
        self.seed = seed
        self.reshuffle_each_iteration = reshuffle_each_iteration

    def preprocess(self, tfrecord):
        """Parse one serialized Example protobuf into an (image, label) pair.

        Returns:
            A uint8 image tensor reshaped to [28, 28] and the int64 label.
        """
        # parsing needs the feature description of the protobuf
        feature_description = {
            'image': tf.io.FixedLenFeature([], tf.string, default_value=""),
            'label': tf.io.FixedLenFeature([], tf.int64, default_value=1),
        }
        parsed_example = tf.io.parse_single_example(tfrecord, feature_description)
        image = tf.io.parse_tensor(parsed_example['image'], out_type=tf.uint8)
        # restore the 28x28 image shape
        image = tf.reshape(image, [28, 28])
        return image, parsed_example['label']

    def create_dataset(self):
        """Build the pipeline: read -> shuffle -> parse -> batch -> prefetch.

        Raises:
            ValueError: if no file paths were given (e.g. an empty glob result).
        """
        if not self.filepaths:
            raise ValueError("DoodleDataset: no TFRecord file paths were given")

        # read all files in parallel, interleaving their records
        dataset = tf.data.TFRecordDataset(self.filepaths,
                                          num_parallel_reads=len(self.filepaths))

        dataset = dataset.shuffle(self.shuffle_buffer_size,
                                  seed=self.seed,
                                  reshuffle_each_iteration=self.reshuffle_each_iteration)

        # parse the serialized Examples
        dataset = dataset.map(self.preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(self.batch_size)
        # keep one batch ready ahead of the consumer
        return dataset.prefetch(1)


In [None]:
# Build the shuffled, batched pipeline over every TFRecord file. A very large
# shuffle buffer mixes records from the different class files well, at the
# cost of a slow start while the buffer fills.
doodle_dataset = DoodleDataset(filepaths, shuffle_buffer_size=4000000)
full_set = doodle_dataset.create_dataset()

# Carve out validation, then test, and train on the remainder.
# NOTE(review): full_set is already batched (DoodleDataset's default
# batch_size=128), so these counts are batches, not images - confirm intent.
# NOTE(review): this take/skip split is only disjoint if full_set yields the
# same order on every iteration; confirm the shuffle in
# DoodleDataset.create_dataset does not reshuffle per iteration.
holdout_batches = 50000
val_set = full_set.take(holdout_batches)
tmp_set = full_set.skip(holdout_batches)
test_set = tmp_set.take(holdout_batches)
train_set = tmp_set.skip(holdout_batches)

In [None]:
def create_dataset(filepath):
    """Return a shuffled, parsed, unbatched dataset for a single TFRecord file.

    Records are shuffled with a 4000-element buffer and decoded with
    `doodle_dataset.preprocess` into (28x28 image, label) pairs. Unlike
    DoodleDataset.create_dataset, no batching is applied, so prefetch(1)
    keeps one single example ready ahead of the consumer.
    """
    shuffled = tf.data.TFRecordDataset(filepath).shuffle(4000)
    parsed = shuffled.map(doodle_dataset.preprocess)
    return parsed.prefetch(1)


# one per-class dataset, in the same order as `filepaths`
datasets = [create_dataset(path) for path in filepaths]

## Generate label data
Go through samples of all 346 classes and pick one image from each class to show the user when they click the help button on the website.

In [None]:
def _print_pixels(pixels):
    """Print a 2-D grid of pixel values, each left-aligned in a 4-character column."""
    for row in pixels:
        print("".join(f"{int(value):<4}" for value in row))


def get_label_data():
    """Interactively choose one representative image per class.

    For each per-class dataset in the module-level `datasets` list (whose order
    matches `filepaths`), draw one sample, print its pixel values, and ask
    whether to keep it. Answering 'y' or 'Y' accepts the sample; any other
    answer draws another one. A class whose dataset yields no samples is
    reported and skipped instead of looping forever.

    Returns:
        dict mapping class name (the TFRecord file stem without its 'doodle-'
        prefix) to the chosen image as nested Python lists.
    """
    label_data = {}

    for idx, dataset in enumerate(datasets):
        # the class name depends only on the file, so compute it once per class
        stem, _ = os.path.splitext(os.path.basename(filepaths[idx]))
        class_name = stem[len("doodle-"):]  # trim off 'doodle-'

        while True:
            # each take(1) starts a new pass over the shuffled dataset, so
            # retries usually show a different sample
            sample = next(iter(dataset.take(1)), None)
            if sample is None:
                print(f"No samples found for {class_name}; skipping it.")
                break

            x_npy = np.array(sample[0])
            _print_pixels(x_npy)

            answer = input(f"Are you satisfied with {class_name}? (y/N): ")
            if answer.strip().lower() == 'y':
                label_data[class_name] = x_npy.tolist()
                break

    return label_data



label_data = get_label_data()

# persist the chosen image for every class as JSON (used by the website's help button)
with open('label_data.json', 'w') as label_file:
    json.dump(label_data, label_file)


### Now that we can save and load TFRecord files, let's look at some of the images we loaded!

In [None]:
import matplotlib.pyplot as plt

# show the first 5 images of one test batch, titled with their class names
for images, labels in test_set.take(1):
    for position in range(5):
        plt.subplot(1, 5, position + 1)
        plt.imshow(images[position].numpy(), cmap="binary")
        plt.axis("off")
        plt.title(class_names[labels[position].numpy()])

## Step 2
Machine learning model

In [None]:
# start from a clean Keras state so earlier models in this session don't interfere
keras.backend.clear_session()

# fix the TensorFlow and NumPy random seeds for repeatable runs
tf.random.set_seed(42)
np.random.seed(42)


class Standardization(keras.layers.Layer):
    """Scale each feature to zero mean and unit variance.

    Call `adapt()` with a NumPy sample first: it stores the per-feature mean
    and standard deviation computed over axis 0. `call()` then returns
    (inputs - mean) / (std + epsilon); Keras' epsilon avoids dividing by zero
    for constant features.
    """

    def adapt(self, data_sample):
        self.means_ = np.mean(data_sample, axis=0, keepdims=True)
        self.stds_ = np.std(data_sample, axis=0, keepdims=True)

    def call(self, inputs):
        centered = inputs - self.means_
        return centered / (self.stds_ + keras.backend.epsilon())


standardization = Standardization(input_shape=[28, 28, 1])

# Adapt to the data so the layer knows each feature's mean and std dev:
# use the images (labels dropped) from the first 1000 batches of train_set.
# NOTE(review): this layer is never added to `model` in the next cell, so the
# network sees raw 0-255 pixels - confirm whether that is intended.
sample_image_batches = train_set.take(1000).map(lambda img, _label: img)
image_batches = list(sample_image_batches.as_numpy_iterator())
sample_images = np.concatenate(image_batches, axis=0).astype(np.float32)
standardization.adapt(sample_images)

### Architecture

Using the ResNet-34 architecture (He et al., "Deep Residual Learning for Image Recognition", Microsoft Research, 2015)

In [None]:
# DefaultConv2D: keras.layers.Conv2D pre-configured with a 3x3 kernel, stride 1,
# "SAME" padding and no bias (every conv below is followed by BatchNormalization).
DefaultConv2D = partial(keras.layers.Conv2D, kernel_size=3, strides=1,
                        padding="SAME", use_bias=False)


class ResidualUnit(keras.layers.Layer):
    """ResNet residual block: conv-BN-act-conv-BN plus a shortcut connection.

    When `strides` > 1 the shortcut is a strided 1x1 convolution + batch norm,
    so its output matches the downsampled main path. Otherwise the input is
    added back unchanged. The sum goes through `activation`.
    """

    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        self.filters = filters
        self.strides = strides
        self.main_layers = [
            DefaultConv2D(filters, strides=strides),
            keras.layers.BatchNormalization(),
            self.activation,
            DefaultConv2D(filters),
            keras.layers.BatchNormalization(),
        ]
        self.skip_layers = (
            [DefaultConv2D(filters, kernel_size=1, strides=strides),
             keras.layers.BatchNormalization()]
            if strides > 1 else []
        )

    def get_config(self):
        return {
            **super().get_config(),
            'filters': self.filters,
            'strides': self.strides,
            'activation': self.activation,
        }

    def call(self, inputs):
        main_out = inputs
        for layer in self.main_layers:
            main_out = layer(main_out)
        shortcut = inputs
        for layer in self.skip_layers:
            shortcut = layer(shortcut)
        return self.activation(main_out + shortcut)


model = keras.models.Sequential()
# 28x28 input images get an explicit channel axis for Conv2D
model.add(tf.keras.layers.Reshape((28, 28, 1), input_shape=(28, 28)))

# stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool
model.add(DefaultConv2D(64, kernel_size=7, strides=2))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
model.add(keras.layers.MaxPool2D(pool_size=3, strides=2, padding="SAME"))

# ResNet-34 stages as (filters, number of residual units); the first unit of a
# stage whose filter count changes uses stride 2.
# NOTE(review): the Standardization layer adapted earlier is not added here.
prev_filters = 64
for filters, unit_count in [(64, 3), (128, 4), (256, 6), (512, 3)]:
    for _ in range(unit_count):
        model.add(ResidualUnit(filters, strides=1 if filters == prev_filters else 2))
        prev_filters = filters

model.add(keras.layers.GlobalAvgPool2D())
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(346, activation="softmax"))

model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam", metrics=["accuracy"])

In [None]:
# where ModelCheckpoint writes the best model (HDF5); this path assumes Google
# Drive is mounted in Colab
model_filepath = "/content/drive/MyDrive/DoodleData/my_doodle_model.h5"

# timestamped log directory for TensorBoard, one per run
logs = os.path.join(os.curdir, "my_logs", "run_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logs, histogram_freq=1, profile_batch=10, update_freq=10000)
# stop if the monitored metric (EarlyStopping's default, val_loss) doesn't
# improve for 2 epochs
early_stopping_cb = keras.callbacks.EarlyStopping(patience=2)
# save the best model after each epoch, judged on *validation* accuracy.
# Training accuracy keeps rising even while the model overfits, so it would
# make save_best_only keep the latest rather than the best-generalizing model.
model_checkpoint_cb = keras.callbacks.ModelCheckpoint(filepath=model_filepath,
                                                      save_best_only=True,
                                                      save_freq='epoch',
                                                      monitor='val_accuracy')

callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]



In [None]:
# train for up to 5 epochs, validating on val_set after each one
history = model.fit(train_set, validation_data=val_set, epochs=5, callbacks=callbacks, verbose=1)

In [None]:
# Load the TensorBoard notebook extension and display the logs written under ./my_logs
%load_ext tensorboard
%tensorboard --logdir=./my_logs --port=6006