In [None]:
pip install quickdraw

Collecting quickdraw
  Downloading quickdraw-1.0.0-py3-none-any.whl.metadata (1.3 kB)
Downloading quickdraw-1.0.0-py3-none-any.whl (11 kB)
Installing collected packages: quickdraw
Successfully installed quickdraw-1.0.0


In [None]:
from quickdraw import QuickDrawData
from quickdraw import QuickDrawDataGroup
from pathlib import Path

image_size = (256, 256)

def generate_class_images(name, max_drawings, recognized):
    directory = Path("dataset/" + name)

    if not directory.exists():
        directory.mkdir(parents=True)

    images = QuickDrawDataGroup(name, max_drawings=max_drawings, recognized=recognized)
    for img in images.drawings:
        filename = directory.as_posix() + "/" + str(img.key_id) + ".png"
        img.get_image(stroke_width=3).resize(image_size).save(filename)

for label in QuickDrawData().drawing_names:
    generate_class_images(label, max_drawings=1200, recognized=True)

downloading aircraft carrier from https://storage.googleapis.com/quickdraw_dataset/full/binary/aircraft carrier.bin
download complete
loading aircraft carrier drawings
load complete
downloading airplane from https://storage.googleapis.com/quickdraw_dataset/full/binary/airplane.bin
download complete
loading airplane drawings
load complete
downloading alarm clock from https://storage.googleapis.com/quickdraw_dataset/full/binary/alarm clock.bin
download complete
loading alarm clock drawings
load complete
downloading ambulance from https://storage.googleapis.com/quickdraw_dataset/full/binary/ambulance.bin
download complete
loading ambulance drawings
load complete
downloading angel from https://storage.googleapis.com/quickdraw_dataset/full/binary/angel.bin
download complete
loading angel drawings
load complete
downloading animal migration from https://storage.googleapis.com/quickdraw_dataset/full/binary/animal migration.bin
download complete
loading animal migration drawings
load complete
d

KeyboardInterrupt: 

In [None]:
import datetime, os
import tensorflow as tf
from tensorflow import summary

from pathlib import Path
from matplotlib import pyplot as plt
from quickdraw import QuickDrawDataGroup, QuickDrawData

from tensorflow.keras.preprocessing import image_dataset_from_directory

from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import sparse_categorical_accuracy
from tensorflow.keras.layers import Rescaling
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout, BatchNormalization

from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ReduceLROnPlateau

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"GPU memory growth enabled for {len(gpus)} GPU(s)")

        policy = tf.keras.mixed_precision.Policy('mixed_float16')
        tf.keras.mixed_precision.set_global_policy(policy)
        print("Mixed precision enabled for A100 GPU")

    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
else:
    print("No GPUs found, using CPU")

image_size = (256, 256)
batch_size = 32  
epochs = 80

train_ds = image_dataset_from_directory(
    "dataset",
    validation_split=0.2,
    subset="training",
    seed=123,
    color_mode="grayscale",
    image_size=image_size,
    batch_size=batch_size
)

val_ds = image_dataset_from_directory(
    "dataset",
    validation_split=0.2,
    subset="validation",
    seed=123,
    color_mode="grayscale",
    image_size=image_size,
    batch_size=batch_size
)

class_names = train_ds.class_names
n_classes = len(class_names)

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)  
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

print("Dataset optimization applied: prefetching enabled (cache disabled to save RAM)")

print(f"Found {n_classes} classes:")
for i, class_name in enumerate(class_names):
    print(f"  {i}: {class_name}")

class_to_index = {class_name: i for i, class_name in enumerate(class_names)}
index_to_class = {i: class_name for i, class_name in enumerate(class_names)}

print(f"\nClass mapping saved for later use:")
print(f"class_to_index: {class_to_index}")

input_shape = (256, 256, 1)

model = Sequential([
    Rescaling(1. / 255, input_shape=input_shape),

    Conv2D(32, kernel_size=(3, 3), padding="same", activation="relu"),
    Conv2D(32, kernel_size=(3, 3), padding="same", activation="relu"),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(64, kernel_size=(3, 3), padding="same", activation="relu"),
    Conv2D(64, kernel_size=(3, 3), padding="same", activation="relu"),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(128, kernel_size=(3, 3), padding="same", activation="relu"),
    Conv2D(128, kernel_size=(3, 3), padding="same", activation="relu"),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(256, kernel_size=(3, 3), padding="same", activation="relu"),
    MaxPooling2D(pool_size=(2, 2)),

    Flatten(),

    Dense(512, activation="relu"),
    BatchNormalization(),
    Dropout(0.5),

    Dense(256, activation="relu"),
    BatchNormalization(),
    Dropout(0.5),

    Dense(n_classes, activation="softmax")
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

model.summary()

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = TensorBoard(logdir, histogram_freq=0)

early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=12,
    restore_best_weights=True,
    verbose=1
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,  
    patience=4,
    min_lr=1e-7,
    verbose=1
)

callbacks = [tensorboard_callback, early_stopping, reduce_lr]

print(f"\nStarting training with:")
print(f"- Image size: {image_size}")
print(f"- Batch size: {batch_size}")
print(f"- Max epochs: {epochs}")
print(f"- Runtime: A100 GPU")
print(f"- Mixed precision: Enabled")
print(f"- Expected time: ~2.5-4 hours total (2min/epoch)")
print(f"- Early stopping patience: 12 epochs")
print(f"- Cache disabled to prevent RAM overflow")

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    verbose=1,
    callbacks=callbacks
)

model_filename = './models/model_' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + '.h5'
model.save(model_filename)

import json
labels_filename = './models/class_labels_' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + '.json'
with open(labels_filename, 'w') as f:
    json.dump({
        'class_names': class_names,
        'class_to_index': class_to_index,
        'index_to_class': index_to_class,
        'n_classes': n_classes
    }, f, indent=2)

print(f"\nModel saved to: {model_filename}")
print(f"Class labels saved to: {labels_filename}")

def predict_with_labels(model, image, class_names):
    """
    Make a prediction and return both the class index and class name
    """
    predictions = model.predict(image)
    predicted_index = tf.argmax(predictions[0]).numpy()
    predicted_class = class_names[predicted_index]
    confidence = predictions[0][predicted_index]

    return predicted_index, predicted_class, confidence

print(f"\nTo use the saved model later:")
print(f"# Load the model")
print(f"model = tf.keras.models.load_model('{model_filename}')")
print(f"# Load the class labels")
print(f"with open('{labels_filename}', 'r') as f:")
print(f"    label_data = json.load(f)")
print(f"class_names = label_data['class_names']")
print(f"# Then use predict_with_labels() function for predictions")

GPU memory growth enabled for 1 GPU(s)
Mixed precision enabled for A100 GPU
Found 414000 files belonging to 345 classes.
Using 331200 files for training.
Found 414000 files belonging to 345 classes.
Using 82800 files for validation.
Dataset optimization applied: prefetching enabled (cache disabled to save RAM)
Found 345 classes:
  0: The Eiffel Tower
  1: The Great Wall of China
  2: The Mona Lisa
  3: aircraft carrier
  4: airplane
  5: alarm clock
  6: ambulance
  7: angel
  8: animal migration
  9: ant
  10: anvil
  11: apple
  12: arm
  13: asparagus
  14: axe
  15: backpack
  16: banana
  17: bandage
  18: barn
  19: baseball
  20: baseball bat
  21: basket
  22: basketball
  23: bat
  24: bathtub
  25: beach
  26: bear
  27: beard
  28: bed
  29: bee
  30: belt
  31: bench
  32: bicycle
  33: binoculars
  34: bird
  35: birthday cake
  36: blackberry
  37: blueberry
  38: book
  39: boomerang
  40: bottlecap
  41: bowtie
  42: bracelet
  43: brain
  44: bread
  45: bridge
  46: b

  super().__init__(**kwargs)



Starting training with:
- Image size: (256, 256)
- Batch size: 32
- Max epochs: 80
- Runtime: A100 GPU
- Mixed precision: Enabled
- Expected time: ~2.5-4 hours total (2min/epoch)
- Early stopping patience: 12 epochs
- Cache disabled to prevent RAM overflow
Epoch 1/80
[1m10350/10350[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m463s[0m 41ms/step - accuracy: 0.0693 - loss: 5.0083 - val_accuracy: 0.0127 - val_loss: 13.0651 - learning_rate: 1.0000e-04
Epoch 2/80
[1m10350/10350[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m415s[0m 40ms/step - accuracy: 0.2363 - loss: 3.4931 - val_accuracy: 0.0282 - val_loss: 10.6459 - learning_rate: 1.0000e-04
Epoch 3/80
[1m10350/10350[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m415s[0m 40ms/step - accuracy: 0.2952 - loss: 3.1315 - val_accuracy: 0.0578 - val_loss: 8.2519 - learning_rate: 1.0000e-04
Epoch 4/80
[1m10350/10350[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m415s[0m 40ms/step - accuracy: 0.3269 - loss: 2.9402 - val_accuracy: 0.0590




Model saved to: ./models/model_20250707-231404.h5
Class labels saved to: ./models/class_labels_20250707-231405.json

To use the saved model later:
# Load the model
model = tf.keras.models.load_model('./models/model_20250707-231404.h5')
# Load the class labels
with open('./models/class_labels_20250707-231405.json', 'r') as f:
    label_data = json.load(f)
class_names = label_data['class_names']
# Then use predict_with_labels() function for predictions
