From c38f0fd56d6124f65698f7177cb16f7ba87c36f4 Mon Sep 17 00:00:00 2001 From: Rolf Jagerman Date: Wed, 28 Jun 2017 18:40:19 +0200 Subject: [PATCH] Update readme to reflect new loss code --- README.md | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index a98044d..c4c0c61 100644 --- a/README.md +++ b/README.md @@ -30,9 +30,9 @@ Additionally, we provide minibatch iterators for Learning to Rank datasets. Thes Currently we provide implementations for the following loss functions - * Top-1 ListNet: `shoelace.loss.listwise.ListNetLoss` - * ListMLE: `shoelace.loss.listwise.ListMLELoss` - * ListPL: `shoelace.loss.listwise.ListPLLoss` + * Top-1 ListNet: `shoelace.loss.listwise.listnet` + * ListMLE: `shoelace.loss.listwise.listmle` + * ListPL: `shoelace.loss.listwise.listpl` ## Example @@ -40,26 +40,28 @@ Here is an example script that will train up a single-layer linear neural networ from shoelace.dataset import LtrDataset from shoelace.iterator import LtrIterator - from shoelace.loss.listwise import ListNetLoss - from chainer import training, optimizers, links + from shoelace.loss.listwise import listnet + from chainer import training, optimizers, links, Chain from chainer.training import extensions # Load data and set up iterator - with open('./path/to/svmrank.txt', 'r') as f: + with open('./path/to/ranksvm.txt', 'r') as f: training_set = LtrDataset.load_txt(f) training_iterator = LtrIterator(training_set, repeat=True, shuffle=True) - - # Create neural network with chainer and apply our loss function + + # Create neural network with chainer and apply loss function predictor = links.Linear(None, 1) - loss = ListNetLoss(predictor) - + class Ranker(Chain): + def __call__(self, x, t): + return listnet(self.predictor(x), t) + loss = Ranker(predictor=predictor) + # Build optimizer, updater and trainer optimizer = optimizers.Adam() optimizer.setup(loss) updater = training.StandardUpdater(training_iterator, optimizer) trainer = training.Trainer(updater, (40, 'epoch')) trainer.extend(extensions.ProgressBar()) - + # Train neural network trainer.run() -