This repository has been archived by the owner on Jun 29, 2020. It is now read-only.
Commit
Merge pull request #1 from rjagerman/fix-updated-cupy
Update 0.2.0
Showing 8 changed files with 254 additions and 125 deletions.
@@ -0,0 +1,62 @@

from chainer import cuda, function


class NDCG(function.Function):
    def __init__(self, k=0):
        self.k = k

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        y, t = inputs

        # Assert arrays have the same shape
        if t.shape != y.shape:
            raise ValueError("Input arrays have different shapes")

        # Computing nDCG on an empty array should just return 0.0
        if t.shape[0] == 0:
            return xp.asarray(0.0),

        # Compute predicted and ideal rankings by arg-sorting
        predicted_indices = xp.argsort(y)
        best_indices = xp.argsort(t)

        # Predicted and theoretically best relevance labels
        # (argsort is ascending, so flip to get descending order)
        predicted_relevance = xp.flip(t[predicted_indices], axis=0)
        best_relevance = xp.flip(t[best_indices], axis=0)

        # Compute needed statistics
        length = predicted_relevance.shape[0]
        arange = xp.arange(length)
        last = min(self.k, length)
        if last < 1:
            last = length

        # Compute regular DCG
        dcg_numerator = 2 ** predicted_relevance[:last] - 1
        dcg_denominator = xp.log2(arange[:last] + 2)
        dcg = xp.sum(dcg_numerator / dcg_denominator)

        # Compute iDCG for normalization
        idcg_numerator = (2 ** best_relevance[:last] - 1)
        idcg_denominator = (xp.log2(arange[:last] + 2))
        idcg = xp.sum(idcg_numerator / idcg_denominator)

        # If the ideal DCG is zero, every ranking is equally good
        if idcg == 0.0:
            return xp.asarray(1.0),

        return xp.asarray(dcg / idcg),


def ndcg(y, t, k=0):
    """
    Computes the nDCG@k for a given list of predicted relevance scores and
    a given list of true relevance labels

    :param y: The predicted relevance scores
    :param t: The ground truth relevance labels
    :param k: The cut-off point (if set to smaller or equal to 0, it does
              not cut off)
    :return: The nDCG@k value
    """
    return NDCG(k=k)(y, t)
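
A minimal usage sketch (not part of the diff; the import path shoelace.evaluation is an assumption based on the repository layout). NDCG is a Chainer function, so ndcg returns a Variable and the scalar value lives in its .data attribute:

import numpy as np
from shoelace.evaluation import ndcg  # assumed module path

y = np.array([0.9, 0.2, 0.5, 0.1])  # predicted relevance scores
t = np.array([2.0, 0.0, 1.0, 0.0])  # ground-truth relevance labels

print(ndcg(y, t, k=3).data)  # prints 1.0: the predicted ranking matches the ideal one

Sorting y in descending order ranks the documents with labels 2, 1, 0, 0 first, which is exactly the ideal ordering, so DCG@3 equals iDCG@3 here.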
@@ -1,102 +1,76 @@

The class-based losses (AbstractListLoss, ListMLELoss, ListNetLoss and
ListPLLoss) are replaced by plain loss functions, and the Plackett-Luce
sampler now draws directly via xp.random.choice instead of converting to
NumPy first. The new version of the file:

import chainer.functions as F
from chainer import cuda
from shoelace.functions.logcumsumexp import logcumsumexp


def listmle(x, t):
    """
    The ListMLE loss as in Xia et al (2008), Listwise Approach to Learning
    to Rank: Theory and Algorithm.

    :param x: The activation of the previous layer
    :param t: The target labels
    :return: The loss
    """
    # Get the ground truth by sorting activations by the relevance labels
    xp = cuda.get_array_module(t)
    t_hat = t[:, 0]
    x_hat = x[xp.flip(xp.argsort(t_hat), axis=0)]

    # Compute MLE loss
    final = logcumsumexp(x_hat)
    return F.sum(final - x_hat)


def listnet(x, t):
    """
    The Top-1 approximated ListNet loss as in Cao et al (2007), Learning to
    Rank: From Pairwise Approach to Listwise Approach.

    :param x: The activation of the previous layer
    :param t: The target labels
    :return: The loss
    """
    # ListNet top-1 reduces to a softmax and simple cross entropy
    st = F.softmax(t, axis=0)
    sx = F.softmax(x, axis=0)
    return -F.mean(st * F.log(sx))


def listpl(x, t, α=15.0):
    """
    The ListPL loss, a stochastic variant of ListMLE that in expectation
    approximates the true ListNet loss.

    :param x: The activation of the previous layer
    :param t: The target labels
    :param α: The smoothing factor
    :return: The loss
    """
    # Sample permutation from PL(t)
    index = _pl_sample(t, α)
    x = x[index]

    # Compute MLE loss
    final = logcumsumexp(x)
    return F.sum(final - x)


def _pl_sample(t, α):
    """
    Sample from the Plackett-Luce distribution directly

    :param t: The target labels
    :return: A random permutation from the Plackett-Luce distribution
             parameterized by the target labels
    """
    xp = cuda.get_array_module(t)
    t = t[:, 0]

    probs = xp.exp(t * α)
    probs /= xp.sum(probs)
    return xp.random.choice(probs.shape[0], probs.shape[0], replace=False,
                            p=probs)
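
A minimal training-style sketch (not part of the diff; the import path shoelace.loss.listwise and the toy linear predictor are assumptions). With the new functional API, the loss consumes the predictor's activations x and the labels t directly, rather than wrapping the predictor in a loss Chain as before:

import numpy as np
import chainer.links as L
from shoelace.loss.listwise import listmle, listnet, listpl  # assumed module path

predictor = L.Linear(None, 1)  # scores one document per feature row
x = predictor(np.random.randn(8, 16).astype(np.float32))  # activations, shape (8, 1)
t = np.random.randint(0, 5, (8, 1)).astype(np.float32)    # relevance labels, shape (8, 1)

loss = listmle(x, t)  # or listnet(x, t), or listpl(x, t, α=15.0)
predictor.cleargrads()
loss.backward()  # gradients flow back into the predictor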