This repository has been archived by the owner on Jun 29, 2020. It is now read-only.

Update readme to reflect new loss code
rjagerman committed Jun 28, 2017
1 parent 668150c · commit c38f0fd
Showing 1 changed file with 14 additions and 12 deletions.
README.md
@@ -30,36 +30,38 @@ Additionally, we provide minibatch iterators for Learning to Rank datasets. Thes

Currently we provide implementations for the following loss functions

- * Top-1 ListNet: `shoelace.loss.listwise.ListNetLoss`
- * ListMLE: `shoelace.loss.listwise.ListMLELoss`
- * ListPL: `shoelace.loss.listwise.ListPLLoss`
+ * Top-1 ListNet: `shoelace.loss.listwise.listnet`
+ * ListMLE: `shoelace.loss.listwise.listmle`
+ * ListPL: `shoelace.loss.listwise.listpl`
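
In short, this change exposes the losses as plain functions rather than wrapper classes, so a loss value is obtained by calling the function on a batch of predicted scores and their relevance labels. A minimal before/after sketch of the call pattern (the minibatch shapes and the direct call below are illustrative assumptions, not taken from the diff):

    import numpy as np
    from chainer import links
    from shoelace.loss.listwise import listnet

    predictor = links.Linear(None, 1)              # single-layer scorer, as in the example below
    x = np.random.rand(8, 5).astype(np.float32)    # hypothetical minibatch: 8 documents, 5 features each
    t = np.random.rand(8, 1).astype(np.float32)    # hypothetical relevance labels for those documents

    # Old API (removed by this commit): the loss object wrapped the predictor
    #   from shoelace.loss.listwise import ListNetLoss
    #   loss = ListNetLoss(predictor)

    # New API: the loss is an ordinary function applied to scores and targets
    loss = listnet(predictor(x), t)                # a chainer Variable holding the scalar loss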

## Example

Here is an example script that will train up a single-layer linear neural network with a ListNet loss function:

  from shoelace.dataset import LtrDataset
  from shoelace.iterator import LtrIterator
- from shoelace.loss.listwise import ListNetLoss
- from chainer import training, optimizers, links
+ from shoelace.loss.listwise import listnet
+ from chainer import training, optimizers, links, Chain
  from chainer.training import extensions

  # Load data and set up iterator
- with open('./path/to/svmrank.txt', 'r') as f:
+ with open('./path/to/ranksvm.txt', 'r') as f:
      training_set = LtrDataset.load_txt(f)
  training_iterator = LtrIterator(training_set, repeat=True, shuffle=True)
- # Create neural network with chainer and apply our loss function
+
+ # Create neural network with chainer and apply loss function
  predictor = links.Linear(None, 1)
- loss = ListNetLoss(predictor)
+
+ class Ranker(Chain):
+     def __call__(self, x, t):
+         return listnet(self.predictor(x), t)
+ loss = Ranker(predictor=predictor)

  # Build optimizer, updater and trainer
  optimizer = optimizers.Adam()
  optimizer.setup(loss)
  updater = training.StandardUpdater(training_iterator, optimizer)
  trainer = training.Trainer(updater, (40, 'epoch'))
  trainer.extend(extensions.ProgressBar())

  # Train neural network
  trainer.run()
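
The same wrapper pattern should carry over to the other losses in the list; a hedged sketch, assuming `listmle` and `listpl` take the same `(scores, targets)` arguments that `listnet` does above:

    from chainer import Chain, links
    from shoelace.loss.listwise import listmle, listpl

    class ListMLERanker(Chain):
        # Same structure as the Ranker above, but scored with the ListMLE loss
        def __call__(self, x, t):
            return listmle(self.predictor(x), t)

    class ListPLRanker(Chain):
        # Variant using the ListPL loss
        def __call__(self, x, t):
            return listpl(self.predictor(x), t)

    # Reuse the single-layer predictor and hand the wrapper to the optimizer,
    # exactly as in the training script above; nothing else needs to change.
    predictor = links.Linear(None, 1)
    loss = ListMLERanker(predictor=predictor)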
