# Writing Kanji with a time distributed MDN-RNN 

- This notebook is the same as the Kanji MDN-RNN except that it trains on predictions made over the whole sequence length.
- The MDN-RNN is also written in Keras' functional API for good measure!


In [None]:
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import backend as K
from context import * # imports the MDN layer 
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 
%matplotlib inline

# Only for GPU use:
# Restrict CUDA to a single device and create a TensorFlow session whose GPU
# memory allocation grows on demand, then hand that session to Keras.
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"  # only the device with index 1 is made visible to CUDA
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory as needed instead of all at once
sess = tf.Session(config=config)
K.set_session(sess)  # make the Keras backend use the session configured above
# NOTE(review): tf.ConfigProto / tf.Session are TensorFlow 1.x names; under
# TensorFlow 2.x they are only available via tf.compat.v1 - confirm the
# TensorFlow version this notebook is meant to run on.

### Download the Dataset:

In [None]:
# Train from David Ha's Kanji dataset from Sketch-RNN: https://github.com/hardmaru/sketch-rnn-datasets
# Other datasets in "Sketch 3" format should also work.
import os
import urllib.request

url = 'https://github.com/hardmaru/sketch-rnn-datasets/raw/master/kanji/kanji.rdp25.npz'
dataset_path = './kanji.rdp25.npz'

# Only download when the file is not already present, so that re-running the
# notebook (Restart & Run All) does not fetch the dataset again.
if not os.path.exists(dataset_path):
    # Download to a temporary name and rename afterwards so an interrupted
    # download never leaves a truncated file under the final name.
    partial_path = dataset_path + '.part'
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, dataset_path)

### Dataset:

Includes about 11000 handwritten kanji characters divided into training, validation, and testing sets.

For creative purposes, we may not need the validation or testing sets, and can just focus on the training set.

In [None]:
# Read the three dataset splits out of the downloaded .npz archive and report
# how many kanji each one contains.
# allow_pickle=True lets NumPy unpickle object arrays; unpickling can execute
# arbitrary code, so only load archives obtained from a trusted source.
# NOTE(review): presumably each split is an object array of variable-length
# stroke sequences, which is why pickling is required - confirm.
with np.load('./kanji.rdp25.npz', allow_pickle=True) as data:
    train_set = data['train']
    valid_set = data['valid']
    test_set = data['test']

print("Training kanji:", len(train_set))
print("Validation kanji:", len(valid_set))
print("Testing kanji:", len(test_set))

### Setup an MDN RNN

In [None]:
# Training Hyperparameters:
SEQ_LEN = 50  # number of time steps in every training window
BATCH_SIZE = 64
HIDDEN_UNITS = 256  # units in each of the two LSTM layers
EPOCHS = 100
SEED = 2345  # set random seed for reproducibility
random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): only Python's and NumPy's generators are seeded here;
# TensorFlow's own RNG is not seeded in this cell, so weight initialisation
# may still differ between runs - confirm whether that matters.
OUTPUT_DIMENSION = 3  # values per point; presumably (dx, dy, pen state) - confirm
NUMBER_MIXTURES = 10  # number of mixture components given to the MDN layer

# Functional-API model: two stacked LSTMs that both return the full sequence,
# followed by an MDN layer applied at every time step via TimeDistributed, so
# the model emits mixture parameters for each of the SEQ_LEN steps.
inputs = keras.layers.Input(shape=(SEQ_LEN,OUTPUT_DIMENSION), name='inputs')
lstm1_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm1', return_sequences=True)(inputs)
lstm2_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm2', return_sequences=True)(lstm1_out)
mdn_out = keras.layers.TimeDistributed(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES, name='mdn_outputs'), name='td_mdn')(lstm2_out)

model = keras.models.Model(inputs=inputs, outputs=mdn_out)
# The loss function comes from the MDN module and is built for the same output
# dimension and mixture count as the layer above.
model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer='adam')
model.summary()

## Process the Data and Train the Model

- Chop up the data into slices of the correct length, generate `X` and `y` for the training process.
- Very similar process to the previous RNN examples!
- We end up with roughly 330,000 examples - a pretty healthy dataset.

In [None]:
# Functions for slicing up data
def slice_sequence_examples(sequence, num_steps):
    """Return every contiguous window of ``num_steps`` items from ``sequence``.

    Windows start at each index from 0 up to and including
    ``len(sequence) - num_steps``, so a sequence of exactly ``num_steps``
    items yields one window and a shorter sequence yields an empty list.

    The previous loop bound, ``len(sequence) - num_steps - 1``, skipped the
    final two valid windows, so the last items of every sequence never
    appeared at the end of an example.

    Args:
        sequence: Any sliceable sequence that supports ``len()``.
        num_steps: Length of each window.

    Returns:
        A list of slices of ``sequence``, in order of their start index.
    """
    # +1 because the last valid start index is len(sequence) - num_steps.
    window_count = len(sequence) - num_steps + 1
    return [sequence[start: start + num_steps] for start in range(window_count)]

def seq_to_overlapping_format(examples):
    """Split each example into an input and a one-step-shifted target.

    For every example of length seq_len+1, the input is the example without
    its last item and the target is the example without its first item, so
    both have length seq_len and overlap by seq_len-1 items.

    Returns a tuple ``(inputs, targets)`` of two lists, in the same order as
    ``examples``.
    """
    # Build both halves in a single pass over the examples.
    pairs = [(window[:-1], window[1:]) for window in examples]
    inputs = [head for head, _ in pairs]
    targets = [tail for _, tail in pairs]
    return (inputs, targets)

# Build the training inputs (X) and next-step targets (y).
# Every window is SEQ_LEN+1 points long so that it can be split into an input
# of SEQ_LEN points and a target shifted forward by one point.
slices = []
for seq in train_set:
    slices.extend(slice_sequence_examples(seq, SEQ_LEN+1))

X, y = seq_to_overlapping_format(slices)
X, y = np.array(X), np.array(y)

print("Number of training examples:")
print("X:", X.shape)
print("y:", y.shape)

In [None]:
# Prepare validation data as X and Y.
# Same windowing as for the training set: windows of SEQ_LEN+1 points are
# split into an input and a target shifted forward by one point.
slices = []
for seq in valid_set:
    slices +=  slice_sequence_examples(seq, SEQ_LEN+1)
Xval, yval = seq_to_overlapping_format(slices)

Xval = np.array(Xval)
yval = np.array(yval)

# The header previously said "training" although these are validation shapes.
print("Number of validation examples:")
print("X:", Xval.shape)
print("y:", yval.shape)

## Do the training!

- We're not going to train in the tutorial!
- These settings take about 220 seconds per epoch, about 6 hours for the whole training run.

In [None]:
# Fit the model
# Checkpoint file name template; {epoch:02d} is filled in with the epoch number.
filepath="kanji_mdnrnn-{epoch:02d}.h5"
# Save weights only, and only when the monitored quantity improves (lower is
# better). No monitor is passed, so the callback's default is used.
checkpoint = keras.callbacks.ModelCheckpoint(filepath, save_weights_only=True, verbose=1, save_best_only=True, mode='min')
# Stop training once val_loss has not improved for 10 consecutive epochs.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)
callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]
history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_data=(Xval,yval))
# NOTE(review): this saves the model as it is when training ends, which is not
# necessarily the best checkpoint (restore_best_weights is not set) - confirm
# that this is intended.
model.save('kanji_mdnrnn_model_time_distributed.h5')  # creates the HDF5 file 'kanji_mdnrnn_model_time_distributed.h5'

In [None]:
# summarize history for loss
# Label both curves and the axes so the figure can be read on its own.
plt.plot(history.history['loss'], label='training loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.title("Kanji MDRNN Training")
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

In [None]:
!ls

## Try out the model! Generate some Kanji!

We need to create a decoding model with batch size 1 and sequence length 1.

In [None]:
# Decoding Model
# Same layer stack as the training model, but built with a fixed batch size of
# 1 and a sequence length of 1, and with stateful LSTMs so that the recurrent
# state carries over between successive predict() calls.
# Unlike the training model, the second LSTM does not return sequences and the
# MDN layer is not wrapped in TimeDistributed, so each call yields a single
# set of mixture parameters.

decoder = keras.Sequential()
decoder.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(1,1,OUTPUT_DIMENSION), return_sequences=True, stateful=True))
decoder.add(keras.layers.LSTM(HIDDEN_UNITS, stateful=True))
decoder.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))
decoder.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())
decoder.summary()

# Load the weights saved at the end of training into the decoder.
# NOTE(review): the decoder is not structurally identical to the training
# model (no TimeDistributed wrapper) - confirm the saved weights map onto
# these layers as expected.
# Alternative: load a per-epoch checkpoint written by ModelCheckpoint, e.g.
#decoder.load_weights('kanji_mdnrnn-99.h5')
decoder.load_weights('kanji_mdnrnn_model_time_distributed.h5')

## Generating drawings

- First need some helper functions to view the output.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

def zero_start_position():
    """Return the initial pen position as a float32 array of shape (1, 1, 3).

    The first two values are zero and the third value (the pen flag) is
    set to 1.
    """
    return np.array([[[0, 0, 1]]], dtype=np.float32)

def generate_sketch(model, start_pos, num_points=100, temperature=1.5, sigma_temp=0.01):
    """Sample a sketch from a stateful decoder, one point at a time.

    This was previously a stub that returned ``None``. It now performs the
    same sampling loop as the notebook's final cell: the current point is fed
    to ``model.predict``, the next point is drawn from the returned mixture
    parameters with ``mdn.sample_from_output``, and the loop repeats.

    Relies on the notebook-level names ``mdn``, ``OUTPUT_DIMENSION`` and
    ``NUMBER_MIXTURES``.

    Args:
        model: Decoder with ``predict`` and ``reset_states`` methods that
            accepts input of shape (1, 1, OUTPUT_DIMENSION).
        start_pos: Starting point; anything that can be reshaped to
            (1, 1, OUTPUT_DIMENSION), e.g. ``zero_start_position()``.
        num_points: Number of points to sample after the start point.
        temperature: Passed to ``mdn.sample_from_output`` as ``temp``.
        sigma_temp: Passed to ``mdn.sample_from_output`` as ``sigma_temp``.

    Returns:
        Array of shape (num_points + 1, OUTPUT_DIMENSION); row 0 is the start
        point. The pen column is returned as sampled (no thresholding).
    """
    point = np.asarray(start_pos, dtype=np.float32).reshape(1, 1, OUTPUT_DIMENSION)
    sketch = [point.reshape(OUTPUT_DIMENSION)]
    for _ in range(num_points):
        params = model.predict(point.reshape(1, 1, OUTPUT_DIMENSION))
        point = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES,
                                       temp=temperature, sigma_temp=sigma_temp)
        sketch.append(point.reshape(OUTPUT_DIMENSION))
    # Clear the recurrent state so the next sketch starts fresh.
    model.reset_states()
    return np.array(sketch)

def cutoff_stroke(x):
    """Threshold values at 0.5: entries greater than 0.5 become 1.0, all others 0.0."""
    above_threshold = np.greater(x, 0.5)
    return above_threshold.astype(np.float64)

def plot_sketch(sketch_array):
    """Quick preview of a sketch as a single red line.

    Columns 0 and 1 of ``sketch_array`` are accumulated into absolute
    positions (with y negated) and plotted on an 8x8 inch figure. Column 2 is
    stored in the frame as 'z' but is not used for drawing.
    """
    columns = sketch_array.T
    sketch_df = pd.DataFrame({'x': columns[0], 'y': columns[1], 'z': columns[2]})
    # Turn per-step offsets into running positions; negate y for plotting.
    sketch_df['x'] = sketch_df['x'].cumsum()
    sketch_df['y'] = -1 * sketch_df['y'].cumsum()

    figure = plt.figure(figsize=(8, 8))
    axes = figure.add_subplot(111)
    axes.plot(sketch_df['x'], sketch_df['y'], 'r-')
    plt.show()

## SVG Drawing Function

Here are Hardmaru's drawing functions from _write-rnn-tensorflow_. Big hat tip to Hardmaru for this!

Here's the source: https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py


In [None]:
# Hardmaru's Drawing Functions from write-rnn-tensorflow
# Big hat tip
# Here's the source:
# https://github.com/hardmaru/write-rnn-tensorflow/blob/master/utils.py

import svgwrite
from IPython.display import SVG, display

def get_bounds(data, factor):
    """Return (min_x, max_x, min_y, max_y) of the path traced by ``data``.

    Columns 0 and 1 of each row are treated as offsets, divided by ``factor``
    and accumulated from the origin. The origin itself counts towards the
    bounds, so every minimum is <= 0 and every maximum is >= 0.
    """
    lo_x = hi_x = lo_y = hi_y = 0
    pos_x = pos_y = 0

    for idx in range(len(data)):
        step_x = float(data[idx, 0]) / factor
        step_y = float(data[idx, 1]) / factor
        pos_x += step_x
        pos_y += step_y
        lo_x, lo_y = min(lo_x, pos_x), min(lo_y, pos_y)
        hi_x, hi_y = max(hi_x, pos_x), max(hi_y, pos_y)

    return (lo_x, hi_x, lo_y, hi_y)

def draw_strokes(data, factor=1, svg_filename='sample.svg'):
    """Render a sketch as an SVG, save it to ``svg_filename`` and display it.

    Each row of ``data`` holds an x offset, a y offset and a pen flag. Offsets
    are divided by ``factor``. A row whose predecessor has pen flag 1 is
    emitted as a relative move; other rows are emitted as relative line
    segments. The canvas is sized to the sketch's bounds plus a 25 unit
    margin on every side.
    """
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    width = 50 + max_x - min_x
    height = 50 + max_y - min_y
    dims = (width, height)

    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))

    # The path starts at the origin shifted by the margin.
    origin_x = 25 - min_x
    origin_y = 25 - min_y
    segments = ["M%s,%s " % (origin_x, origin_y)]

    pen_lifted = 1  # the first point is always reached with a move
    command = "m"
    for idx in range(len(data)):
        if pen_lifted == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            # Previous command was already "l": omit the letter.
            command = ""
        dx = float(data[idx, 0]) / factor
        dy = float(data[idx, 1]) / factor
        pen_lifted = data[idx, 2]
        segments.append(command + str(dx) + "," + str(dy) + " ")

    path_data = "".join(segments)
    dwg.add(dwg.path(path_data).stroke("black", 2).fill("none"))

    dwg.save()
    display(SVG(dwg.tostring()))

In [None]:
# Predict a character and plot the result.
# Both values below are passed to mdn.sample_from_output as temp / sigma_temp.
temperature = 1.5 # the original note says a rather high temperature (2.5) also seems to work well
sigma_temp = 0.01

# Start from the zeroed start position and keep it as the first sketch point.
p = zero_start_position()
sketch = [p.reshape(3,)]

# Sample 100 points: feed the current point to the stateful decoder, draw the
# next point from the predicted mixture parameters, and repeat.
for i in range(100):
    params = decoder.predict(p.reshape(1,1,3))
    p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=temperature, sigma_temp=sigma_temp)
    sketch.append(p.reshape((3,)))

sketch = np.array(sketch)
# Clear the decoder's recurrent state so the next run of this cell starts fresh.
decoder.reset_states()

# Threshold the sampled pen column to 0/1 before drawing.
sketch.T[2] = cutoff_stroke(sketch.T[2])
draw_strokes(sketch, factor=0.5)
#plot_sketch(sketch)