In [1]:
import os
import sys
%load_ext autoreload
%autoreload 2


os.getpid()

sys.path.append("../sketchformer")
sys.path.append("../src")


In [2]:
from isketcher import InteractiveSketcher
from tqdm import tqdm
import json
import time
import pandas as pd
import tensorflow as tf
import numpy as np
import datetime

In [3]:
# load dataset
with open('../data/isketcher/train.json', 'r') as f:
    train = json.load(f)
with open('../data/isketcher/valid.json', 'r') as f:
    valid = json.load(f)
with open('../data/isketcher/test.json', 'r') as f:
    test = json.load(f)

print(f"train: {len(train)}, valid: {len(valid)}, test: {len(test)}")


train: 5617, valid: 535, test: 1113


In [4]:
# load sketchformer
from basic_usage.sketchformer import continuous_embeddings
sketchformer = continuous_embeddings.get_pretrained_model()


[run-experiment] resorting checkpoint if exists
[Checkpoint] Restored, step #207536


In [5]:
# load class label
df = pd.read_csv('../outputs/sketchyscene_quickdraw.csv')
df = df.dropna(subset=['quickdraw_label'])
class_names = []
for row in df.itertuples():
    class_names.append(row.quickdraw_label)
class_to_num = dict(zip(class_names, range(0, len(class_names))))

print(len(class_names))
print(class_names)
print(class_to_num)


40
['airplane', 'apple', 'hot air balloon', 'banana', 'basket', 'bee', 'bench', 'bicycle', 'bird', 'wine bottle', 'bucket', 'bus', 'butterfly', 'car', 'cat', 'chair', 'cloud', 'cow', 'cup', 'dog', 'duck', 'fence', 'flower', 'grapes', 'grass', 'horse', 'house', 'moon', 'mountain', 'face', 'pig', 'rabbit', 'sheep', 'star', 'streetlight', 'sun', 'table', 'tree', 'truck', 'umbrella']
{'airplane': 0, 'apple': 1, 'hot air balloon': 2, 'banana': 3, 'basket': 4, 'bee': 5, 'bench': 6, 'bicycle': 7, 'bird': 8, 'wine bottle': 9, 'bucket': 10, 'bus': 11, 'butterfly': 12, 'car': 13, 'cat': 14, 'chair': 15, 'cloud': 16, 'cow': 17, 'cup': 18, 'dog': 19, 'duck': 20, 'fence': 21, 'flower': 22, 'grapes': 23, 'grass': 24, 'horse': 25, 'house': 26, 'moon': 27, 'mountain': 28, 'face': 29, 'pig': 30, 'rabbit': 31, 'sheep': 32, 'star': 33, 'streetlight': 34, 'sun': 35, 'table': 36, 'tree': 37, 'truck': 38, 'umbrella': 39}


In [6]:
# define preprocess
def preprocess(dataset):
    input_batch = []
    label_batch = []
    for scene in tqdm(dataset):
        sketches = list(map(lambda o: o['sketch'], scene))
        sketch_embeddings = sketchformer.get_embeddings(sketches)
        input_scene = []
        labels = []
        for se, obj in zip(sketch_embeddings, scene):
            p = [obj['position'][0] / 750, obj['position'][1] / 750]
            o = se.numpy().tolist() + p
            input_scene.append(o)  # オブジェクトの数が不規則
            labels.append(class_to_num[obj['label']])  # convert to num
        input_batch.append(input_scene)
        label_batch.append(labels)
    return tf.ragged.constant(input_batch).to_tensor(0.), tf.ragged.constant(label_batch).to_tensor(0)


In [7]:
# Create masks

def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

    # アテンション・ロジットにパディングを追加するため
    # さらに次元を追加する
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)


def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)


def create_masks(inp, tar):
    # Encoderパディング・マスク
    enc_padding_mask = create_padding_mask(inp)

    # デコーダーの 2つ目のアテンション・ブロックで使用
    # このパディング・マスクはエンコーダーの出力をマスクするのに使用
    dec_padding_mask = create_padding_mask(inp)

    # デコーダーの 1つ目のアテンション・ブロックで使用
    # デコーダーが受け取った入力のパディングと将来のトークンをマスクするのに使用
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask


def create_combined_mask(tar):

    # デコーダーの 1つ目のアテンション・ブロックで使用
    # デコーダーが受け取った入力のパディングと将来のトークンをマスクするのに使用
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return combined_mask


In [8]:
# hyper parameters
# TODO: adjust parameters

num_layers = 4
d_model = 130
dff = 512
num_heads = 5

target_object_num = 40  # object num
dropout_rate = 0.1

# Optimizer


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)


In [9]:
# create model

interactive_sketcher = InteractiveSketcher(
    num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff,
    object_num=target_object_num, pe_target=100, rate=dropout_rate)


In [10]:
# checkpoint

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=interactive_sketcher,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# チェックポイントが存在したなら、最後のチェックポイントを復元
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')


Latest checkpoint restored!!


In [11]:
# tensorboard

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/' + 'train/' + current_time
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
valid_log_dir = 'logs/' + 'valid/' + current_time
valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

In [None]:
# preprocess
print("Preprocessing train dataset")
x_train, y_train = preprocess(train)
print("Preprocessing valid dataset")
x_valid, y_valid = preprocess(valid)
print("Preprocessing test dataset")
x_test, y_test = preprocess(test)

In [12]:
# training step


scc = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)
mse = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_class_accuracy')

valid_loss = tf.keras.metrics.Mean(name='train_loss')
valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_class_accuracy')

# @tf.functionは高速に実行するためにtrain_stepをTFグラフにトレースコンパイルします。
# この関数は、引数となるテンソルのshapeに特化したものです。
# シーケンスの長さや（最後のバッチが小さくなるなど）バッチサイズが可変となることによって
# 再トレーシングが起きないようにするため、input_signatureを使って、より一般的なshapeを
# 指定します。

train_step_signature = [
    tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
    tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
]

def loss_function(c_real, x_real, y_real, c_pred, x_pred, y_pred):
    # class loss
    # クラスラベルはカテゴリカルクロスエントロピー
    c_loss_ = scc(c_real, c_pred)

    # position loss
    # 位置座標は平均二乗誤差
    p_loss_ = tf.math.square(x_real - x_pred) + \
        tf.math.square(y_real - y_pred)

    # mask padded object
    # パディングしたオブジェクトの部分を損失に加えないようにマスクする
    mask = tf.math.logical_not(tf.math.equal(c_real, 0))
    c_mask = tf.cast(mask, dtype=c_loss_.dtype)
    c_loss = tf.reduce_mean(c_loss_ * c_mask)
    p_mask = tf.cast(mask, dtype=p_loss_.dtype)
    p_loss = tf.reduce_mean(p_loss_ * p_mask)
    
    return c_loss + p_loss

# @tf.function(input_signature=train_step_signature) # なぜか通らないため一旦コメントアウト
def train_step(tar, labels):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    x_real, y_real = tar_real[:, :, -2], tar_real[:, :, -1]

    labels_inp = labels[:, :-1]
    labels_real = labels[:, 1:]

    # パディングしたオブジェクトの位置はlabelsが0の位置のため、そこからマスクを作成
    combined_mask = create_combined_mask(labels_inp)

    with tf.GradientTape() as tape:
        c_out, x_out, y_out, _ = interactive_sketcher(
            tar_inp, True, combined_mask)
        
        loss = loss_function(labels_real, x_real, y_real, c_out, x_out, y_out)

    gradients = tape.gradient(loss, interactive_sketcher.trainable_variables)
    optimizer.apply_gradients(
        zip(gradients, interactive_sketcher.trainable_variables))

    train_loss(loss)
    train_accuracy(labels_real, c_out)

def valid_step(tar, labels):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    x_real, y_real = tar_real[:, :, -2], tar_real[:, :, -1]

    labels_inp = labels[:, :-1]
    labels_real = labels[:, 1:]

    # パディングしたオブジェクトの位置はlabelsが0の位置のため、そこからマスクを作成
    combined_mask = create_combined_mask(labels_inp)

    c_out, x_out, y_out, _ = interactive_sketcher(
        tar_inp, False, combined_mask)
    
    loss = loss_function(labels_real, x_real, y_real, c_out, x_out, y_out)

    valid_loss(loss)
    valid_accuracy(labels_real, c_out)


In [None]:
# training

EPOCHS = 10
BATCH_SIZE = 16

def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx: min(ndx + n, l)]


for epoch in range(EPOCHS):
    start = time.time()

    train_loss.reset_states()
    train_accuracy.reset_states()
    valid_loss.reset_states()
    valid_accuracy.reset_states()

    # train
    for i, (x_batch, y_batch) in enumerate(batch(zip(x_train, y_train), BATCH_SIZE)):
        train_step(x_batch, y_batch)

        if (i + 1) % 50 == 0:
            print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                epoch + 1, i + 1, train_loss.result(), train_accuracy.result()))

    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

    # valid
    valid_step(x_valid, y_valid)

    with valid_summary_writer.as_default():
        tf.summary.scalar('val_loss', valid_loss.result(), step=epoch)
        tf.summary.scalar('val_accuracy', valid_accuracy.result(), step=epoch)

    if (epoch + 1) % 1 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))

    print('Epoch {} Loss {:.4f} Accuracy {:.4f} Valid_Loss {:.4f} Valid_Accuracy {:.4f}'.format(
        epoch + 1, train_loss.result(), train_accuracy.result(), valid_loss.result(), valid_accuracy.result()))

    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))


In [None]:
# test model

scene_index = 0

print("----", 0, "-------")
print("real label: ", test[scene_index][0]["label"])
print("real position: ", test[scene_index][0]["position"])

for i in range(1, len(test[scene_index])):

    c_out, x_out, y_out, _ = interactive_sketcher(
        x[:, :i], training=False, look_ahead_mask=None)

    print("----", i, "-------")
    print("real label: ", test[scene_index][i]["label"])
    print("pred label: ", class_names[tf.argmax(c_out[scene_index][0])])

    print("real position: ", test[scene_index][i]["position"])
    print("pred position:  [{0}, {1}]".format(round(x_out[scene_index][0].numpy(
    ) * 750, 2), round(y_out[scene_index][0].numpy() * 750, 2)))


---- 0 -------
real label:  cloud
real position:  [80, 55]
---- 1 -------
real label:  cloud
pred label:  grass
real position:  [330, 55]
pred position:  [373.57, 471.98]
---- 2 -------
real label:  fence
pred label:  grass
real position:  [480, 335]
pred position:  [372.19, 492.52]
---- 3 -------
real label:  fence
pred label:  grass
real position:  [720, 335]
pred position:  [360.95, 467.97]
---- 4 -------
real label:  mountain
pred label:  grass
real position:  [275, 180]
pred position:  [345.43, 438.94]
---- 5 -------
real label:  cloud
pred label:  grass
real position:  [610, 70]
pred position:  [347.45, 439.75]
---- 6 -------
real label:  house
pred label:  grass
real position:  [620, 320]
pred position:  [391.66, 485.33]
---- 7 -------
real label:  fence
pred label:  grass
real position:  [60, 335]
pred position:  [396.86, 482.74]
---- 8 -------
real label:  car
pred label:  grass
real position:  [415, 535]
pred position:  [400.46, 486.73]
---- 9 -------
real label:  fence
pred 