-
Notifications
You must be signed in to change notification settings - Fork 6
/
model_class.py
290 lines (243 loc) · 10.5 KB
/
model_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sketch-RNN Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
# internal imports
import numpy as np
import tensorflow as tf
from magenta.models.sketch_rnn import rnn
def copy_hparams(hparams):
  """Build a fresh HParams object carrying the same values as `hparams`.

  Args:
    hparams: an HParams instance; its `values()` result is expanded into
      keyword arguments for the new object.

  Returns:
    A new `tf.contrib.training.HParams` built from those values.
  """
  current_values = hparams.values()
  duplicate = tf.contrib.training.HParams(**current_values)
  return duplicate
def get_default_hparams():
  """Construct the HParams object holding the default sketch-rnn settings.

  Returns:
    A `tf.contrib.training.HParams` populated with the defaults listed below.
  """
  defaults = {
      # --- Data and training schedule ---
      'data_set': ['aaron_sheep.npz'],  # Dataset file(s) to use.
      'num_steps': 10000000,  # Total training steps; keep this large.
      'save_every': 500,  # Batches between checkpoint saves.
      'max_seq_len': 250,  # Placeholder value; overwritten by the model.
      # --- Network sizes and cell types ---
      'dec_rnn_size': 512,  # Decoder RNN size.
      'dec_model': 'lstm',  # Decoder cell: lstm, layer_norm or hyper.
      'enc_rnn_size': 256,  # Encoder RNN size.
      'enc_model': 'lstm',  # Encoder cell: lstm, layer_norm or hyper.
      'z_size': 128,  # Latent vector size; 32, 64 or 128 recommended.
      # --- KL term ---
      'kl_weight': 0.5,  # KL weight in the loss; 0.5 or 1.0 recommended.
      'kl_weight_start': 0.01,  # Initial KL weight when annealing.
      'kl_tolerance': 0.2,  # KL loss level below which KL is not optimized.
      # --- Optimization ---
      'batch_size': 100,  # Minibatch size; recommended to leave at 100.
      'grad_clip': 1.0,  # Gradient clipping; recommended to leave at 1.0.
      'num_mixture': 20,  # Number of components in the Gaussian mixture.
      'learning_rate': 0.001,  # Initial learning rate.
      'decay_rate': 0.9999,  # Per-minibatch learning rate decay.
      'kl_decay_rate': 0.99995,  # Per-minibatch KL annealing decay.
      'min_learning_rate': 0.00001,  # Lower bound on the learning rate.
      # --- Dropout (each *_prob is a keep probability) ---
      'use_recurrent_dropout': True,  # Recurrent dropout; recommended.
      'recurrent_dropout_prob': 0.90,  # Keep probability, recurrent dropout.
      'use_input_dropout': False,  # Input dropout; recommended off.
      'input_dropout_prob': 0.90,  # Keep probability, input dropout.
      'use_output_dropout': False,  # Output dropout; recommended off.
      'output_dropout_prob': 0.90,  # Keep probability, output dropout.
      # --- Data augmentation ---
      'random_scale_factor': 0.15,  # Random scaling augmentation proportion.
      'augment_stroke_prob': 0.10,  # Point-dropping augmentation proportion.
      # --- Mode switches ---
      'conditional': True,  # False selects the unconditional, decoder-only path.
      'is_training': True,  # Whether the model trains; recommended True.
  }
  return tf.contrib.training.HParams(**defaults)
class Model(object):
  """Sketch classifier built from the SketchRNN bi-directional encoder.

  The graph encodes a batch of stroke sequences with a bi-directional RNN
  (when `hps.conditional` is True), projects the result through one linear
  layer to `n_out` logits, and trains against integer labels with sparse
  softmax cross-entropy.

  Attributes set while building the graph:
    hps: the HParams object given to the constructor.
    sequence_lengths: int32 placeholder, shape [batch_size].
    input_data: float32 placeholder, shape [batch_size, max_seq_len + 1, 5].
    y_labels: int32 placeholder, shape [batch_size].
    output_x: `input_data` with the first time step dropped.
    batch_z: encoder output, or zeros of shape [batch_size, z_size] when
      `hps.conditional` is False.
    output: logits, shape [batch_size, n_out].
    ce_loss: mean sparse softmax cross-entropy of `output` vs `y_labels`.
    global_step, lr, cost, train_op: only created when `hps.is_training`.
  """

  def __init__(self, hps, gpu_mode=True, reuse=False):
    """Initializer for the SketchRNN model.

    Args:
      hps: a HParams object containing model hyperparameters
      gpu_mode: a boolean that when True, uses GPU mode.
      reuse: a boolean that when true, attempts to reuse variables.
    """
    self.hps = hps
    with tf.variable_scope('vector_rnn', reuse=reuse):
      if not gpu_mode:
        # Pin every op created by build_model to the CPU.
        with tf.device('/cpu:0'):
          tf.logging.info('Model using cpu.')
          self.build_model(hps)
      else:
        tf.logging.info('Model using gpu.')
        self.build_model(hps)

  def encoder(self, batch, sequence_lengths):
    """Define the bi-directional encoder module of sketch-rnn.

    Args:
      batch: batch-major input tensor (`time_major=False` is passed below).
      sequence_lengths: per-example lengths forwarded to the dynamic RNN.

    Returns:
      The forward and backward final outputs concatenated along axis 1.
      NOTE(review): build_model sizes the output layer on the assumption
      that this is 2 * enc_rnn_size wide -- confirm against the cells'
      `get_output` in the `rnn` module.
    """
    unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
        self.enc_cell_fw,
        self.enc_cell_bw,
        batch,
        sequence_length=sequence_lengths,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32,
        scope='ENC_RNN')
    last_state_fw, last_state_bw = last_states
    last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
    last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
    last_h = tf.concat([last_h_fw, last_h_bw], 1)
    return last_h

  def build_model(self, hps):
    """Define model architecture.

    Args:
      hps: a HParams object; `__init__` passes the same object it stores as
        `self.hps`.

    Raises:
      ValueError: if `hps.enc_model` is not 'lstm', 'layer_norm' or 'hyper'.
    """
    if hps.is_training:
      self.global_step = tf.Variable(0, name='global_step', trainable=False)

    if hps.enc_model == 'lstm':
      enc_cell_fn = rnn.LSTMCell
    elif hps.enc_model == 'layer_norm':
      enc_cell_fn = rnn.LayerNormLSTMCell
    elif hps.enc_model == 'hyper':
      enc_cell_fn = rnn.HyperLSTMCell
    else:
      # Raise rather than `assert False`: asserts are stripped under -O,
      # which would leave enc_cell_fn undefined further down.
      raise ValueError('please choose a respectable cell')

    if hps.conditional:  # vae mode:
      # Every supported cell type takes the same constructor arguments, so
      # a single code path builds both directions.
      self.enc_cell_fw = enc_cell_fn(
          hps.enc_rnn_size,
          use_recurrent_dropout=hps.use_recurrent_dropout,
          dropout_keep_prob=hps.recurrent_dropout_prob)
      self.enc_cell_bw = enc_cell_fn(
          hps.enc_rnn_size,
          use_recurrent_dropout=hps.use_recurrent_dropout,
          dropout_keep_prob=hps.recurrent_dropout_prob)

    self.sequence_lengths = tf.placeholder(
        dtype=tf.int32, shape=[hps.batch_size])
    self.input_data = tf.placeholder(
        dtype=tf.float32, shape=[hps.batch_size, hps.max_seq_len + 1, 5])
    self.y_labels = tf.placeholder(dtype=tf.int32, shape=[hps.batch_size])

    # The target/expected vectors of strokes
    self.output_x = self.input_data[:, 1:hps.max_seq_len + 1, :]

    # either do vae-bit and get z, or do unconditional, decoder-only
    if hps.conditional:  # vae mode:
      self.batch_z = self.encoder(self.output_x, self.sequence_lengths)
      z_dim = 2 * hps.enc_rnn_size
    else:  # unconditional, decoder-only generation
      self.batch_z = tf.zeros(
          (hps.batch_size, hps.z_size), dtype=tf.float32)
      # batch_z is only z_size wide here, so the projection must match it;
      # sizing it as 2 * enc_rnn_size breaks xw_plus_b with the defaults.
      z_dim = hps.z_size

    # One logit per class.
    n_out = 16  # num_classes

    with tf.variable_scope('RNN'):
      output_w = tf.get_variable('output_w', [z_dim, n_out])
      output_b = tf.get_variable('output_b', [n_out])

    output = tf.nn.xw_plus_b(self.batch_z, output_w, output_b)
    self.output = output

    self.ce_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=output, labels=self.y_labels))

    if hps.is_training:
      self.lr = tf.Variable(hps.learning_rate, trainable=False)
      optimizer = tf.train.AdamOptimizer(self.lr)

      self.cost = self.ce_loss

      gvs = optimizer.compute_gradients(self.cost)
      g = hps.grad_clip
      # compute_gradients may pair a variable with a None gradient when the
      # loss does not depend on it; clip_by_value cannot take None.
      capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                    for grad, var in gvs if grad is not None]
      self.train_op = optimizer.apply_gradients(
          capped_gvs, global_step=self.global_step, name='train_step')
def sample(sess, model, seq_len=250, temperature=1.0, greedy_mode=False,
           z=None):
  """Samples a sequence from a pre-trained model.

  NOTE(review): this function reads `model.initial_state`, `model.input_x`,
  `model.pi`, `model.mu1`, `model.mu2`, `model.sigma1`, `model.sigma2`,
  `model.corr`, `model.pen` and `model.final_state`. The `Model` class in
  this file does not define these attributes -- confirm which model type is
  meant to be passed here.

  Args:
    sess: session object whose `run` evaluates the model tensors.
    model: model exposing the attributes listed above plus `hps`,
      `sequence_lengths` and `batch_z`.
    seq_len: number of steps to generate.
    temperature: sampling temperature for the mixture weights, the pen
      state and the 2D Gaussian.
    greedy_mode: when True, take the argmax / mean instead of sampling.
    z: optional latent vector; drawn from a standard normal with shape
      (1, model.hps.z_size) when None. Not fed if the model is
      unconditional.

  Returns:
    A tuple `(strokes, mixture_params)`: `strokes` is a float32 array of
    shape (seq_len, 5) with rows [x, y, pen0, pen1, pen2]; `mixture_params`
    is a list with one `[pi, mu1, mu2, sigma1, sigma2, corr, pen]` entry
    per step.
  """

  def adjust_temp(pi_pdf, temp):
    """Re-weights a pdf by `temp` in log space and renormalizes it."""
    pi_pdf = np.log(pi_pdf) / temp
    pi_pdf -= pi_pdf.max()  # shift so the largest exponent is 0
    pi_pdf = np.exp(pi_pdf)
    pi_pdf /= pi_pdf.sum()
    return pi_pdf

  def get_pi_idx(x, pdf, temp=1.0, greedy=False):
    """Samples from a pdf, optionally greedily."""
    if greedy:
      return np.argmax(pdf)
    pdf = adjust_temp(np.copy(pdf), temp)
    accumulate = 0
    for i in range(0, pdf.size):
      accumulate += pdf[i]
      if accumulate >= x:
        return i
    tf.logging.info('Error with sampling ensemble.')
    # Callers use the result as an index, so -1 selects the last entry.
    return -1

  def sample_gaussian_2d(mu1, mu2, s1, s2, rho, temp=1.0, greedy=False):
    """Draws one (x, y) point from a 2D Gaussian, or its mean if greedy."""
    if greedy:
      return mu1, mu2
    mean = [mu1, mu2]
    # The caller passes sqrt(temperature) as `temp`, so each standard
    # deviation is multiplied by the temperature itself.
    s1 *= temp * temp
    s2 *= temp * temp
    cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]

  conditional = model.hps.conditional

  prev_x = np.zeros((1, 1, 5), dtype=np.float32)
  prev_x[0, 0, 2] = 1  # initially, we want to see beginning of new stroke
  if z is None:
    z = np.random.randn(1, model.hps.z_size)  # not used if unconditional

  if not conditional:
    prev_state = sess.run(model.initial_state)
  else:
    prev_state = sess.run(model.initial_state, feed_dict={model.batch_z: z})

  strokes = np.zeros((seq_len, 5), dtype=np.float32)
  mixture_params = []

  for i in range(seq_len):
    feed = {
        model.input_x: prev_x,
        model.sequence_lengths: [1],
        model.initial_state: prev_state
    }
    if conditional:
      feed[model.batch_z] = z

    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
     next_state] = sess.run([
         model.pi, model.mu1, model.mu2, model.sigma1, model.sigma2,
         model.corr, model.pen, model.final_state
     ], feed)

    # Pick the mixture component first, then the pen state; this order of
    # random draws is kept fixed.
    idx = get_pi_idx(random.random(), o_pi[0], temperature, greedy_mode)

    idx_eos = get_pi_idx(random.random(), o_pen[0], temperature, greedy_mode)
    eos = [0, 0, 0]
    eos[idx_eos] = 1

    next_x1, next_x2 = sample_gaussian_2d(o_mu1[0][idx], o_mu2[0][idx],
                                          o_sigma1[0][idx], o_sigma2[0][idx],
                                          o_corr[0][idx],
                                          np.sqrt(temperature), greedy_mode)

    strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]

    mixture_params.append([
        o_pi[0], o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0],
        o_pen[0]
    ])

    # The point just produced becomes the input of the next step.
    prev_x = np.zeros((1, 1, 5), dtype=np.float32)
    prev_x[0][0] = np.array(
        [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
    prev_state = next_state

  return strokes, mixture_params