Skip to content

Commit

Permalink
Merge pull request #8 from paxcema/reset_hn_eval
Browse files Browse the repository at this point in the history
Reset states at evaluation time
  • Loading branch information
paxcema authored Oct 17, 2020
2 parents 186fb01 + 1721877 commit 5bb069a
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions model/gru4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ def get_metrics(model, args, train_generator_map, recall_k=20, mrr_k=20):
print("Evaluating model...")
for feat, label, mask in test_generator:

gru_layer = model.get_layer(name="GRU")
hidden_states = gru_layer.states[0].numpy()
for elt in mask:
hidden_states[elt, :] = 0
gru_layer.reset_states(states=hidden_states)

target_oh = to_categorical(label, num_classes=args.train_n_items)
input_oh = to_categorical(feat, num_classes=args.train_n_items)
input_oh = np.expand_dims(input_oh, axis=1)
Expand Down Expand Up @@ -242,14 +248,15 @@ def train_model(model, args):

if not args.eval_all_epochs:
(rec, rec_k), (mrr, mrr_k) = get_metrics(model_to_train, args, train_dataset.itemmap)
print("\t - Recall@{} epoch {}: {:5f}".format(rec_k, epoch, rec))
print("\t - MRR@{} epoch {}: {:5f}\n".format(mrr_k, epoch, mrr))
print("\t - Recall@{} epoch {}: {:5f}".format(rec_k, args.epochs, rec))
print("\t - MRR@{} epoch {}: {:5f}\n".format(mrr_k, args.epochs, mrr))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Keras GRU4REC: session-based recommendations')
parser.add_argument('--resume', type=str, help='stored model path to continue training')
parser.add_argument('--train-path', type=str, default='../../processedData/rsc15_train_tr.txt')
parser.add_argument('--eval-only', type=bool, default=False)
parser.add_argument('--dev-path', type=str, default='../../processedData/rsc15_train_valid.txt')
parser.add_argument('--test-path', type=str, default='../../processedData/rsc15_test.txt')
parser.add_argument('--batch-size', type=str, default=512)
Expand Down Expand Up @@ -277,5 +284,10 @@ def train_model(model, args):
else:
model = create_model(args)

train_model(model, args)

if args.eval_only:
train_dataset = SessionDataset(args.train_data)
(rec, rec_k), (mrr, mrr_k) = get_metrics(model, args, train_dataset.itemmap)
print("\t - Recall@{} epoch {}: {:5f}".format(rec_k, -1, rec))
print("\t - MRR@{} epoch {}: {:5f}\n".format(mrr_k, -1, mrr))
else:
train_model(model, args)

0 comments on commit 5bb069a

Please sign in to comment.