Skip to content

Commit ce650d6

Browse files
committed
update
1 parent 3740a6f commit ce650d6

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

nlp_class3/wseq2seq.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@
2323

2424
# some config
2525
BATCH_SIZE = 64 # Batch size for training.
26-
EPOCHS = 100 # Number of epochs to train for.
26+
EPOCHS = 40 # Number of epochs to train for.
2727
LATENT_DIM = 256 # Latent dimensionality of the encoding space.
2828
NUM_SAMPLES = 10000 # Number of samples to train on.
29-
MAX_SEQUENCE_LENGTH = 100
3029
MAX_NUM_WORDS = 20000
3130
EMBEDDING_DIM = 100
3231

@@ -219,12 +218,37 @@
219218
# Create the model object
220219
model = Model([encoder_inputs_placeholder, decoder_inputs_placeholder], decoder_outputs)
221220

221+
222+
def custom_loss(y_true, y_pred):
    """Masked categorical cross-entropy for padded sequences.

    Both y_true and y_pred are of shape N x T x K (batch, time, vocab),
    with y_true one-hot. Padding timesteps (where y_true is all zeros)
    contribute nothing to the loss, and the total is normalized by the
    number of real (non-padding) target entries.
    """
    # 1.0 at real one-hot target entries, 0.0 at padding positions.
    mask = K.cast(y_true > 0, dtype='float32')
    # Clip predictions away from 0 before the log: an underflowed
    # probability would give log(0) = -inf, and the masked product
    # 0 * -inf = NaN would poison the entire loss.
    out = mask * y_true * K.log(K.clip(y_pred, K.epsilon(), 1.0))
    return -K.sum(out) / K.sum(mask)
227+
228+
229+
def acc(y_true, y_pred):
    """Token accuracy over non-padding timesteps.

    Both y_true and y_pred have shape N x T x K (batch, time, vocab).
    Vocabulary index 0 is padding; those timesteps are excluded from
    the accuracy average.
    """
    true_ids = K.argmax(y_true, axis=-1)
    pred_ids = K.argmax(y_pred, axis=-1)
    matches = K.cast(K.equal(true_ids, pred_ids), dtype='float32')

    # Padding targets decode to index 0 — mask them out of the count.
    valid = K.cast(K.greater(true_ids, 0), dtype='float32')
    return K.sum(valid * matches) / K.sum(valid)
240+
241+
model.compile(optimizer='adam', loss=custom_loss, metrics=[acc])
242+
222243
# Compile the model and train it
223-
model.compile(
224-
optimizer='rmsprop',
225-
loss='categorical_crossentropy',
226-
metrics=['accuracy']
227-
)
244+
# model.compile(
245+
# optimizer='rmsprop',
246+
# loss='categorical_crossentropy',
247+
# metrics=['accuracy']
248+
# )
249+
250+
251+
228252
r = model.fit(
229253
[encoder_inputs, decoder_inputs], decoder_targets_one_hot,
230254
batch_size=BATCH_SIZE,

0 commit comments

Comments
 (0)