# Sketch Recognition using Simple Neural Network (NN)

In [1]:
# Selecting Tensorflow version v2 (the command is relevant for Colab only).
# %tensorflow_version 2.x

In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import math
import datetime
import platform
import pathlib

print('Python version:', platform.python_version())
print('Tensorflow version:', tf.__version__)
print('Keras version:', tf.keras.__version__)

Python version: 3.7.6
Tensorflow version: 2.1.0
Keras version: 2.2.4-tf


In [3]:
cache_dir = 'tmp';

In [4]:
# Create cache folder.
!mkdir tmp

mkdir: tmp: File exists


## Load labels

In [5]:
def load_quick_draw_labels(cache_dir):
    labels_url = 'https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt'
    labels_file_name = 'categories.txt'
    
    labels_path = tf.keras.utils.get_file(
        fname=labels_file_name, 
        origin=labels_url,
        cache_dir=pathlib.Path(cache_dir).absolute()
    )
    
    labels_file = open(labels_path, 'r')
    labels = labels_file.read().splitlines()
    labels_file.close()
    
    labels = [label.replace(' ', '_') for label in labels]
    
    return np.array(labels)

In [6]:
labels = load_quick_draw_labels(cache_dir)

print()
print('labels:\n', labels)


labels:
 ['aircraft_carrier' 'airplane' 'alarm_clock' 'ambulance' 'angel'
 'animal_migration' 'ant' 'anvil' 'apple' 'arm' 'asparagus' 'axe'
 'backpack' 'banana' 'bandage' 'barn' 'baseball' 'baseball_bat' 'basket'
 'basketball' 'bat' 'bathtub' 'beach' 'bear' 'beard' 'bed' 'bee' 'belt'
 'bench' 'bicycle' 'binoculars' 'bird' 'birthday_cake' 'blackberry'
 'blueberry' 'book' 'boomerang' 'bottlecap' 'bowtie' 'bracelet' 'brain'
 'bread' 'bridge' 'broccoli' 'broom' 'bucket' 'bulldozer' 'bus' 'bush'
 'butterfly' 'cactus' 'cake' 'calculator' 'calendar' 'camel' 'camera'
 'camouflage' 'campfire' 'candle' 'cannon' 'canoe' 'car' 'carrot' 'castle'
 'cat' 'ceiling_fan' 'cello' 'cell_phone' 'chair' 'chandelier' 'church'
 'circle' 'clarinet' 'clock' 'cloud' 'coffee_cup' 'compass' 'computer'
 'cookie' 'cooler' 'couch' 'cow' 'crab' 'crayon' 'crocodile' 'crown'
 'cruise_ship' 'cup' 'diamond' 'dishwasher' 'diving_board' 'dog' 'dolphin'
 'donut' 'door' 'dragon' 'dresser' 'drill' 'drums' 'duck' 'dumbbell' 'e

## Download dataset

In [53]:
def download_quick_draw_dataset(cache_dir, labels):    
    dataset_base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
    dataset_file_paths = []
    
    for label in labels:
        dataset_file_name = label + '.npy'
        dataset_url = dataset_base_url + label.replace('_', '%20') + '.npy'
            
        dataset_file_path = tf.keras.utils.get_file(
            fname=dataset_file_name, 
            origin=dataset_url,
            cache_dir=pathlib.Path(cache_dir).absolute()
        )
        
        dataset_file_paths.append(dataset_file_path)
    
    return dataset_file_paths

In [54]:
dataset_file_paths = download_quick_draw_dataset(cache_dir, labels)

In [55]:
%ls -la ./tmp/datasets

total 80425152
drwxr-xr-x  348 trekhleb  staff      11136 Apr 27 13:51 [34m.[m[m/
drwxr-xr-x    3 trekhleb  staff         96 Apr 27 12:19 [34m..[m[m/
-rw-r--r--    1 trekhleb  staff  105684064 Apr 27 13:44 The_Eiffel_Tower.npy
-rw-r--r--    1 trekhleb  staff  151323840 Apr 27 13:45 The_Great_Wall_of_China.npy
-rw-r--r--    1 trekhleb  staff   95164352 Apr 27 13:45 The_Mona_Lisa.npy
-rw-r--r--    1 trekhleb  staff   91339216 Apr 27 12:19 aircraft_carrier.npy
-rw-r--r--    1 trekhleb  staff  118872512 Apr 27 12:19 airplane.npy
-rw-r--r--    1 trekhleb  staff   96744896 Apr 27 12:19 alarm_clock.npy
-rw-r--r--    1 trekhleb  staff  116035216 Apr 27 12:42 ambulance.npy
-rw-r--r--    1 trekhleb  staff  117393104 Apr 27 12:42 angel.npy
-rw-r--r--    1 trekhleb  staff  108072128 Apr 27 12:42 animal_migration.npy
-rw-r--r--    1 trekhleb  staff   97695888 Apr 27 12:42 ant.npy
-rw-r--r--    1 trekhleb  staff   98965184 Apr 27 12:42 anvil.npy
-rw-r--r--    1 trekhleb  staff  1

In [87]:
def generate_quick_draw_dataset(examples_per_label, labels, cache_dir):
    pixels_per_drawing = 28 * 28
    
    x = np.empty(shape=(0, pixels_per_drawing));
    y = np.empty(shape=(0));
    
    for label_index, label in enumerate(labels[:2]):
        dataset_file_path = cache_dir + '/datasets/' + label + '.npy'
        images = np.load(dataset_file_path, allow_pickle=False)
        
        images = images[:examples_per_label]
        
        (batch_size, pixels) = images.shape
        
        images_labels = np.full(
            shape=batch_size,
            fill_value=label_index
        )

        x = np.concatenate((x, images), axis=0)
        y = np.append(y, images_labels)
    
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
        
#         ds = tf.data.Dataset.from_tensor_slices((images, label))
#         print(list(ds.take(1).as_numpy_iterator()))
        
#         print(label)
#         plt.imshow(images[11])

    return x, y
    
generate_quick_draw_dataset(
    examples_per_label=10,
    labels=labels,
    cache_dir=cache_dir
)

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(10, 784)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(20, 784)


(array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1.]))