In [36]:
# import

import pandas as pd
from quickdraw import QuickDrawData, QuickDrawing
from dataclasses import dataclass, asdict
from PIL import Image, ImageChops, ImageDraw, ImageFont
from pprint import pprint
import numpy as np
import random
import os

In [37]:
label_info = pd.read_csv("../labels.csv")
coco_info = pd.read_csv("../output.csv")

In [38]:
qd = QuickDrawData()
qd.load_all_drawings()

loading aircraft carrier drawings
load complete
loading airplane drawings
load complete
loading alarm clock drawings
load complete
loading ambulance drawings
load complete
loading angel drawings
load complete
loading animal migration drawings
load complete
loading ant drawings
load complete
loading anvil drawings
load complete
loading apple drawings
load complete
loading arm drawings
load complete
loading asparagus drawings
load complete
loading axe drawings
load complete
loading backpack drawings
load complete
loading banana drawings
load complete
loading bandage drawings
load complete
loading barn drawings
load complete
loading baseball bat drawings
load complete
loading baseball drawings
load complete
loading basket drawings
load complete
loading basketball drawings
load complete
loading bat drawings
load complete
loading bathtub drawings
load complete
loading beach drawings
load complete
loading bear drawings
load complete
loading beard drawings
load complete
loading bed drawings
l

In [39]:
# COCO label name -> label info
label_info_from_coco_name = {}
for row in label_info.itertuples():
    label_info_from_coco_name[row.coco_label] = row

In [50]:
@dataclass
class Draw:
    drawing: QuickDrawing
    position: tuple[int, int]


@dataclass
class Draws:
    drawings: list[Draw]
    filename: str

    def get_animation(self, margin=100, down_scale=3, in_text=True):
        canvas = Image.new('RGB', (640, 480), (255, 255, 255))
        canvas = self.__add_margin(canvas, margin, margin, margin, margin, (255, 255, 255))
        images = []
        for draw in self.drawings:
            draw_image = draw.drawing.get_image(stroke_width=2 * down_scale)
            mask_image = self.__get_mask(draw_image, (0, 0, 0))
            draw_image = draw_image.resize((draw_image.width // down_scale, draw_image.height // down_scale))
            mask_image = mask_image.resize((mask_image.width // down_scale, mask_image.height // down_scale))
            canvas.paste(
                draw_image, 
                (draw.position[0] - (draw_image.width // 2) + margin, draw.position[1] - (draw_image.height // 2) + margin),
                mask=mask_image)
            image_draw = ImageDraw.Draw(canvas)
            image_draw.text(
                (draw.position[0] + margin, draw.position[1] + (draw_image.height // 2) + margin),
                draw.drawing.name,
                fill='black',
                font=ImageFont.truetype(font='/System/Library/Fonts/Helvetica.ttc', size=18))
            images.append(canvas.copy())
        return images
    
    def save_animation(self, filename, duration=40, loop=0):
        animation = self.get_animation()
        animation[0].save(filename, save_all=True, append_images=animation[1:], optimize=False, duration=duration, loop=loop)
    
    def save(self, filename):
        animation = self.get_animation()
        animation[-1].save(filename)
    
    def __add_margin(self, pil_img, top, right, bottom, left, color):
        width, height = pil_img.size
        new_width = width + right + left
        new_height = height + top + bottom
        result = Image.new(pil_img.mode, (new_width, new_height), color)
        result.paste(pil_img, (left, top))
        return result
    
    def __get_mask(self, image, color):
        r, g, b = image.split()
        _r = r.point(lambda _: 1 if _ == color[0] else 0, mode="1")
        _g = g.point(lambda _: 1 if _ == color[1] else 0, mode="1")
        _b = b.point(lambda _: 1 if _ == color[2] else 0, mode="1")
        mask = ImageChops.logical_and(_r, _g)
        mask = ImageChops.logical_and(mask, _b)
        return mask

In [41]:
def generate_drawings(index=None, min_objects=2, seed=None, erosion=None):
    # set seed
    random.seed(seed)
    np.random.seed(seed)
    # get random coco info
    coco = list(coco_info.query(f'num>={min_objects}').sample(random_state=seed).itertuples())[0] if not index else list(coco_info[index:index+1].itertuples())[0]
    # get coco image
    image = np.asarray(Image.open(f"../{coco.name}"))
    # iterate objects and create sequence
    drawings = []
    for coco_name in coco.objects.split(", "):
        # get label info
        label = label_info_from_coco_name[coco_name]
        # if erosion
        if erosion:
            # TODO: erosion
            pass
        # get random position of object from image
        pos = random.choice(list(zip(*np.where(image == label.id - 1))))
        # get quick draw strokes
        name = random.choice(label.quick_draw_labels.split(", "))
        drawings.append(Draw(qd.get_drawing(name), (pos[1], pos[0])))
    random.shuffle(drawings)
    return Draws(drawings, coco.name)

In [68]:
draws = generate_drawings(index=888, min_objects=10, seed=67)
pprint(asdict(draws))
filename = os.path.splitext(os.path.basename(draws.filename))[0]
draws.save_animation(f'{filename}.gif', duration=800)
draws.save(f'{filename}.png')

{'drawings': [{'drawing': <quickdraw.data.QuickDrawing object at 0x17aa63670>,
               'position': (407, 83)},
              {'drawing': <quickdraw.data.QuickDrawing object at 0x17aa635e0>,
               'position': (476, 186)},
              {'drawing': <quickdraw.data.QuickDrawing object at 0x17aafecd0>,
               'position': (46, 308)},
              {'drawing': <quickdraw.data.QuickDrawing object at 0x17aafeaf0>,
               'position': (223, 185)}],
 'filename': 'data/stuffthingmaps_trainval2017/val2017/000000450758.png'}


In [67]:
for i in range(1, 101):
    draws = generate_drawings(index=888, min_objects=10, seed=i)
    draws.save(f'{filename}-{i}.png')
