@@ -324,18 +324,42 @@ def stack_and_transpose(x):
324
324
outputs = outputs
325
325
)
326
326
327
+
328
def custom_loss(y_true, y_pred):
    """Cross-entropy summed over positive targets, divided by their count.

    Entries where ``y_true`` is not positive are masked out, so they add
    nothing to either the numerator or the denominator.

    Args:
        y_true: target tensor; assumed N x T x K with all-zero rows at
            padded timesteps (TODO confirm against how the targets are built).
        y_pred: predicted probabilities, same shape as ``y_true``.

    Returns:
        Scalar tensor: ``-sum(mask * y_true * log(y_pred)) / sum(mask)``.
    """
    # both are of shape N x T x K
    mask = K.cast(y_true > 0, dtype='float32')

    # Floor the prediction before taking the log: a probability of exactly 0
    # gives log = -inf, and 0 * -inf is NaN, which would leak through the mask.
    log_pred = K.log(K.maximum(y_pred, K.epsilon()))
    out = mask * y_true * log_pred

    # mask only contains 0/1, so its sum is a whole-number count. Flooring it
    # at 1 leaves every non-empty batch unchanged and turns 0/0 into 0.
    n_active = K.maximum(K.sum(mask), 1.0)
    return -K.sum(out) / n_active
333
+
334
+
335
def acc(y_true, y_pred):
    """Accuracy over timesteps whose target class index is greater than 0.

    The class index is taken as the argmax over the last axis of each tensor.
    Timesteps whose target argmax is 0 are treated as padding and excluded.

    Args:
        y_true: target tensor; assumed N x T x K (TODO confirm).
        y_pred: prediction tensor, same shape as ``y_true``.

    Returns:
        Scalar tensor: correct non-padding predictions divided by the number
        of non-padding timesteps (0 when the batch has none).
    """
    # both are of shape N x T x K
    targ = K.argmax(y_true, axis=-1)
    pred = K.argmax(y_pred, axis=-1)
    correct = K.cast(K.equal(targ, pred), dtype='float32')

    # 0 is padding, don't include those
    mask = K.cast(K.greater(targ, 0), dtype='float32')
    n_correct = K.sum(mask * correct)

    # The mask sum is a whole-number count; flooring it at 1 keeps the result
    # identical for non-empty batches and avoids 0/0 when all steps are padding.
    n_total = K.maximum(K.sum(mask), 1.0)
    return n_correct / n_total
346
+
347
+
327
348
# compile the model
328
- model .compile (optimizer = 'rmsprop' , loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
349
+ model .compile (optimizer = 'adam' , loss = custom_loss , metrics = [acc ])
350
+ # model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
329
351
330
352
# train the model
331
- z = np .zeros ((NUM_SAMPLES , LATENT_DIM_DECODER )) # initial [s, c]
353
+ z = np .zeros ((len ( encoder_inputs ) , LATENT_DIM_DECODER )) # initial [s, c]
332
354
r = model .fit (
333
355
[encoder_inputs , decoder_inputs , z , z ], decoder_targets_one_hot ,
334
356
batch_size = BATCH_SIZE ,
335
357
epochs = EPOCHS ,
336
358
validation_split = 0.2
337
359
)
338
360
361
+
362
+
339
363
# plot some data
340
364
plt .plot (r .history ['loss' ], label = 'loss' )
341
365
plt .plot (r .history ['val_loss' ], label = 'val_loss' )
@@ -371,6 +395,9 @@ def stack_and_transpose(x):
371
395
# combine context with last word
372
396
decoder_lstm_input = context_last_word_concat_layer ([context , decoder_inputs_single_x ])
373
397
398
+
399
+
400
+
374
401
# lstm and final dense
375
402
o , s , c = decoder_lstm (decoder_lstm_input , initial_state = [initial_s , initial_c ])
376
403
decoder_outputs = decoder_dense (o )
0 commit comments