In [1]:
import sys

# Extend the module search path with the sibling "../sketchformer" directory.
# The entry is relative, so it is resolved against the kernel's current
# working directory at import time.
# NOTE(review): presumably needed for the `basic_usage.sketchformer` import
# later in this notebook — confirm.
sys.path.append("../sketchformer")


In [2]:
from tqdm import tqdm
import json
import pandas as pd
import tensorflow as tf
import numpy as np


In [3]:
# load dataset
def _read_json(path):
    """Parse the JSON file at `path` and return the decoded object."""
    with open(path, 'r') as handle:
        return json.load(handle)


# One JSON file per split; each is read fully into memory.
train = _read_json('../data/isketcher/train.json')
valid = _read_json('../data/isketcher/valid.json')
test = _read_json('../data/isketcher/test.json')

# Report how many top-level entries each split contains.
print(f"train: {len(train)}, valid: {len(valid)}, test: {len(test)}")


train: 5617, valid: 535, test: 1113


In [4]:
# load class label
# Read the label table and keep only rows that have a 'quickdraw_label' value.
df = pd.read_csv('../outputs/sketchyscene_quickdraw.csv').dropna(subset=['quickdraw_label'])

# Class names in table order; a class's integer id is its position in this list.
class_names = [record.quickdraw_label for record in df.itertuples()]
class_to_num = {label: index for index, label in enumerate(class_names)}

for summary in (len(class_names), class_names, class_to_num):
    print(summary)


40
['airplane', 'apple', 'hot air balloon', 'banana', 'basket', 'bee', 'bench', 'bicycle', 'bird', 'wine bottle', 'bucket', 'bus', 'butterfly', 'car', 'cat', 'chair', 'cloud', 'cow', 'cup', 'dog', 'duck', 'fence', 'flower', 'grapes', 'grass', 'horse', 'house', 'moon', 'mountain', 'face', 'pig', 'rabbit', 'sheep', 'star', 'streetlight', 'sun', 'table', 'tree', 'truck', 'umbrella']
{'airplane': 0, 'apple': 1, 'hot air balloon': 2, 'banana': 3, 'basket': 4, 'bee': 5, 'bench': 6, 'bicycle': 7, 'bird': 8, 'wine bottle': 9, 'bucket': 10, 'bus': 11, 'butterfly': 12, 'car': 13, 'cat': 14, 'chair': 15, 'cloud': 16, 'cow': 17, 'cup': 18, 'dog': 19, 'duck': 20, 'fence': 21, 'flower': 22, 'grapes': 23, 'grass': 24, 'horse': 25, 'house': 26, 'moon': 27, 'mountain': 28, 'face': 29, 'pig': 30, 'rabbit': 31, 'sheep': 32, 'star': 33, 'streetlight': 34, 'sun': 35, 'table': 36, 'tree': 37, 'truck': 38, 'umbrella': 39}


In [5]:
# load sketchformer
# NOTE(review): this import presumably resolves through the "../sketchformer"
# entry appended to sys.path at the top of the notebook — confirm.
from basic_usage.sketchformer import continuous_embeddings
# Built once at notebook level; preprocess() reads this global and calls
# its get_embeddings() on every scene.
sketchformer = continuous_embeddings.get_pretrained_model()


[run-experiment] resorting checkpoint if exists
[Checkpoint] Restored, step #207536


In [6]:
# define preprocess
def preprocess(dataset, canvas_size=750, label_pad=0):
    """Turn a collection of scenes into padded feature and label tensors.

    Every scene is a sequence of objects, and every object is a mapping with
    the keys 'sketch', 'position' and 'label'. All sketches of one scene are
    embedded with a single call to the notebook-level `sketchformer` model.
    The first two entries of 'position', each divided by `canvas_size`, are
    appended to the object's embedding. Labels are converted to integer ids
    through the notebook-level `class_to_num` mapping.

    Scenes contain different numbers of objects, so both results are first
    built as ragged tensors and then padded to dense tensors.

    Args:
        dataset: Iterable of scenes (e.g. one of the parsed JSON splits).
        canvas_size: Divisor applied to both position coordinates. Defaults
            to 750, the value that used to be hard-coded here.
            NOTE(review): presumably the scene canvas side length — confirm.
        label_pad: Value written into the padded label slots. Defaults to 0
            so existing callers get the same output as before. Be aware that
            0 is also a real class id (the first entry of `class_names`), so
            with the default, padding cannot be told apart from that class
            by looking at the labels alone; pass e.g. -1 to get a distinct
            sentinel.

    Returns:
        A tuple `(inputs, labels)` of dense tensors. `inputs` is padded with
        0. and `labels` with `label_pad`. The object axis is padded to the
        largest scene found in `dataset`, so its length can differ between
        separately preprocessed splits.

    Raises:
        KeyError: If an object lacks one of the expected keys or its label is
            not present in `class_to_num`.
    """
    input_batch = []
    label_batch = []
    for scene in tqdm(dataset):
        # Embed all sketches of the scene in one model call.
        sketches = [obj['sketch'] for obj in scene]
        sketch_embeddings = sketchformer.get_embeddings(sketches)
        input_scene = []
        labels = []
        for se, obj in zip(sketch_embeddings, scene):
            position = obj['position']
            scaled_position = [position[0] / canvas_size, position[1] / canvas_size]
            # Feature vector = sketch embedding followed by the scaled (x, y).
            input_scene.append(se.numpy().tolist() + scaled_position)
            labels.append(class_to_num[obj['label']])  # convert to num
        # The number of objects varies from scene to scene.
        input_batch.append(input_scene)
        label_batch.append(labels)
    # Ragged -> dense: shorter scenes are filled up with the padding values.
    inputs = tf.ragged.constant(input_batch).to_tensor(0.)
    targets = tf.ragged.constant(label_batch).to_tensor(label_pad)
    return inputs, targets


In [7]:
# preprocess
def _announce_and_preprocess(message, split):
    """Print `message`, then return `preprocess(split)`."""
    print(message)
    return preprocess(split)


# Each call embeds every scene of the split, so expect this cell to be slow.
x_train, y_train = _announce_and_preprocess("Preprocessing train dataset", train)
x_valid, y_valid = _announce_and_preprocess("Preprocessing valid dataset", valid)
x_test, y_test = _announce_and_preprocess("Preprocessing test dataset", test)

Preprocessing train dataset


100%|██████████| 5617/5617 [03:44<00:00, 24.97it/s]


Preprocessing valid dataset


100%|██████████| 535/535 [00:21<00:00, 24.97it/s]


Preprocessing test dataset


100%|██████████| 1113/1113 [00:45<00:00, 24.64it/s]


In [8]:
# Persist all six arrays into one compressed .npz archive, keyed by name.
arrays_to_save = {
    'x_train': x_train,
    'y_train': y_train,
    'x_valid': x_valid,
    'y_valid': y_valid,
    'x_test': x_test,
    'y_test': y_test,
}
np.savez_compressed('../data/isketcher/dataset.npz', **arrays_to_save)


In [9]:
# Reopen the archive written above and list what it contains.
dataset = np.load('../data/isketcher/dataset.npz')
print(dataset.files)

# Show the shape of every stored array, inputs and labels per split.
for array_name in ('x_train', 'y_train', 'x_valid', 'y_valid', 'x_test', 'y_test'):
    print(dataset[array_name].shape)


['x_train', 'y_train', 'x_valid', 'y_valid', 'x_test', 'y_test']
(5617, 92, 130)
(5617, 92)
(535, 55, 130)
(535, 55)
(1113, 43, 130)
(1113, 43)
