Skip to content

Commit 3740a6f

Browse files
committed
update
1 parent 4de23e3 commit 3740a6f

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

nlp_class3/attention.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -324,18 +324,42 @@ def stack_and_transpose(x):
324324
outputs=outputs
325325
)
326326

327+
328+
def custom_loss(y_true, y_pred):
    """Masked categorical cross-entropy, averaged over non-padding entries.

    Both arguments are expected to be of shape N x T x K (per the original
    author's note), with ``y_true`` holding the targets and ``y_pred`` the
    predicted probabilities. Entries where ``y_true`` is not positive are
    treated as padding and contribute nothing to the loss.

    Returns a scalar tensor: the negative sum of ``y_true * log(y_pred)``
    over unmasked entries, divided by the number of unmasked entries.
    """
    mask = K.cast(y_true > 0, dtype='float32')

    # Clip before the log: a probability that underflows to exactly 0 would
    # give log(0) = -inf, and 0 * -inf = NaN at masked positions, which
    # would turn the whole summed loss into NaN.
    safe_pred = K.clip(y_pred, K.epsilon(), 1.0)
    out = mask * y_true * K.log(safe_pred)

    # The mask sum is an integer-valued count, so flooring it at 1 only
    # changes the all-padding case (0/0 -> 0) and nothing else.
    n_unmasked = K.maximum(K.sum(mask), 1.0)
    return -K.sum(out) / n_unmasked
333+
334+
335+
def acc(y_true, y_pred):
    """Accuracy over non-padding timesteps.

    Both arguments are expected to be of shape N x T x K (per the original
    author's note). The class index of each timestep is taken with an
    argmax over the last axis; timesteps whose target index is 0 are
    treated as padding and excluded from both numerator and denominator.

    Returns a scalar tensor: correct predictions / counted timesteps.
    """
    targ = K.argmax(y_true, axis=-1)
    pred = K.argmax(y_pred, axis=-1)
    correct = K.cast(K.equal(targ, pred), dtype='float32')

    # 0 is padding, don't include those
    mask = K.cast(K.greater(targ, 0), dtype='float32')
    n_correct = K.sum(mask * correct)

    # The mask sum is an integer-valued count, so flooring it at 1 leaves
    # every normal batch untouched and makes an all-padding batch report
    # 0 instead of NaN (0/0).
    n_total = K.maximum(K.sum(mask), 1.0)
    return n_correct / n_total
346+
347+
327348
# compile the model
328-
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
349+
model.compile(optimizer='adam', loss=custom_loss, metrics=[acc])
350+
# model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
329351

330352
# train the model
331-
z = np.zeros((NUM_SAMPLES, LATENT_DIM_DECODER)) # initial [s, c]
353+
z = np.zeros((len(encoder_inputs), LATENT_DIM_DECODER)) # initial [s, c]
332354
r = model.fit(
333355
[encoder_inputs, decoder_inputs, z, z], decoder_targets_one_hot,
334356
batch_size=BATCH_SIZE,
335357
epochs=EPOCHS,
336358
validation_split=0.2
337359
)
338360

361+
362+
339363
# plot some data
340364
plt.plot(r.history['loss'], label='loss')
341365
plt.plot(r.history['val_loss'], label='val_loss')
@@ -371,6 +395,9 @@ def stack_and_transpose(x):
371395
# combine context with last word
372396
decoder_lstm_input = context_last_word_concat_layer([context, decoder_inputs_single_x])
373397

398+
399+
400+
374401
# lstm and final dense
375402
o, s, c = decoder_lstm(decoder_lstm_input, initial_state=[initial_s, initial_c])
376403
decoder_outputs = decoder_dense(o)

0 commit comments

Comments
 (0)