-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
abe1411
commit 4fff7d5
Showing
3 changed files
with
171 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* author: kohill | ||
*/ | ||
#include "mobula_op.h" | ||
#ifdef USING_CUDA | ||
|
||
#if USING_CUDA | ||
#include<cuda.h> | ||
#else | ||
#define __device__ | ||
#include <cmath> | ||
#include <algorithm> | ||
using std::exp; | ||
using std::max; | ||
using std::pow; | ||
using std::log; | ||
#endif | ||
|
||
#endif | ||
namespace mobula { | ||
|
||
#define UNUSED(expr) do { (void)(expr); } while (0) | ||
|
||
template <typename T> | ||
__device__ inline T sigmoid(T x){ | ||
T max_val = max(static_cast<T>(0), -1 * x); | ||
T v0 = exp(0-max_val); | ||
return v0 / (v0 + exp(-x - max_val)); | ||
} | ||
|
||
template <typename T> | ||
__device__ inline T log_sigmoid(T x){ | ||
T max_val = max(static_cast<T>(0), -1 * x); | ||
return -1 * max_val - log(exp(0-max_val) + exp(-x - max_val)); | ||
} | ||
|
||
template <typename T> | ||
MOBULA_KERNEL focal_loss_forward_kernel(const int out_size, T alpha, T gamma, T* logits, T* targets, T* outputs) { | ||
parfor(out_size, [&](int index){ | ||
T y = targets[index]; | ||
T x = logits[index]; | ||
T sigmoid_x = sigmoid(x); | ||
T sigmoid_neg_x = sigmoid(-x); // 1 - sigmoid(x) | ||
T output = alpha * y * pow(sigmoid_neg_x, gamma) * log_sigmoid(x); | ||
output += (1 - alpha) * (1 - y) * log_sigmoid(-x) * pow(sigmoid_x, gamma); | ||
output *= -1; | ||
outputs[index] = output; | ||
}); | ||
} // focal_loss_forward_kernel | ||
|
||
template <typename T> | ||
MOBULA_KERNEL focal_loss_backward_kernel(const int out_size, T alpha, T gamma, T* logits, T* targets, T* outputs) { | ||
parfor(out_size, [&](int index){ | ||
T y = targets[index]; | ||
T x = logits[index]; | ||
T sigmoid_x = sigmoid(x); | ||
T sigmoid_neg_x = sigmoid(-x); // 1 - sigmoid(x) | ||
T output = (alpha- 1 - alpha * y) * pow(sigmoid_x, 1 + gamma); | ||
output += alpha * y * pow(sigmoid_neg_x, gamma + 1); | ||
output += (alpha - 1) * gamma * (y - 1) * sigmoid_neg_x * pow(sigmoid_x, gamma) * log_sigmoid(-x); | ||
output -= alpha * gamma * sigmoid_x * y * pow(sigmoid_neg_x, gamma) * log_sigmoid(x); | ||
output += sigmoid_x * y * pow(sigmoid_x, gamma); | ||
output *= -1; | ||
outputs[index] = output; | ||
}); | ||
} // focal_loss_forward_kernel | ||
|
||
|
||
} // namespace mobula |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import mobula | ||
from mobula.const import req | ||
import numpy as np | ||
|
||
@mobula.op.register | ||
class FocalLoss: | ||
def __init__(self, alpha=0.25, gamma=2): | ||
self.alpha = alpha | ||
self.gamma = gamma | ||
|
||
def forward(self, logits, targets): | ||
if self.req[0] == req.null: | ||
return | ||
out = self.y | ||
out_size = np.prod(out.size()) if callable(out.size) else out.size | ||
alpha = self.alpha | ||
gamma = self.gamma | ||
if self.req[0] == req.add: | ||
out_temp = self.F.zeros_like(out) | ||
mobula.func.focal_loss_forward(out_size=out_size, alpha=alpha, gamma=gamma, logits=logits, targets=targets, | ||
outputs=out_temp) | ||
self.y[:] += out_temp | ||
else: | ||
self.y[:] = 0 | ||
mobula.func.focal_loss_forward(out_size=out_size, alpha=alpha, gamma=gamma, logits=logits, targets=targets, | ||
outputs=self.y) | ||
|
||
def backward(self, dy): | ||
assert self.req[1] == "null" | ||
alpha = self.alpha | ||
gamma = self.gamma | ||
logits = self.X[0] | ||
targets = self.X[1] | ||
out_size = np.prod(targets.size()) if callable(targets.size) else targets.size | ||
if self.req[0] == req.add: | ||
out_temp = self.F.zeros_like(self.dX[0]) | ||
mobula.func.focal_loss_forward(out_size=out_size, alpha=alpha, gamma=gamma, logits=logits, targets=targets, | ||
outputs=out_temp) | ||
self.dX[0] += out_temp | ||
else: | ||
self.dX[0][:] = 0 | ||
mobula.func.focal_loss_backward(out_size=out_size, alpha=alpha, gamma=gamma, logits=logits, targets=targets, | ||
outputs=self.dX[0]) | ||
self.dX[0][:] = self.dX[0][:] * dy | ||
|
||
def infer_shape(self, in_shape): | ||
assert len(in_shape) == 2 | ||
assert in_shape[0] == in_shape[1] | ||
return in_shape, [in_shape[0]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import mobula | ||
|
||
mobula.op.load('FocalLoss') | ||
import mxnet as mx | ||
import mxnet.autograd as ag | ||
from mobula.testing import assert_almost_equal | ||
|
||
|
||
def BCEFocalLoss(x, target, alpha=.25, gamma=2): | ||
p = x.sigmoid() | ||
loss = alpha * target * ((1 - p) ** gamma) * mx.nd.log(p) | ||
loss = loss + (1 - alpha) * (1 - target) * (p ** gamma) * mx.nd.log(1 - p) | ||
return -loss | ||
|
||
|
||
def test_FocalLoss_mx_cpu(): | ||
ctx = mx.cpu() | ||
x = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx) | ||
y = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx) | ||
x1 = x.copy() | ||
y1 = y.copy() | ||
|
||
x.attach_grad() | ||
x1.attach_grad() | ||
|
||
with ag.record(): | ||
fl = BCEFocalLoss(x, y, alpha=.25, gamma=2) | ||
fl_mobula = mobula.op.FocalLoss(alpha=.25, gamma=2, logits=x1, targets=y1) | ||
fl.backward() | ||
fl_mobula.backward() | ||
|
||
assert_almost_equal(x.grad.asnumpy(), x1.grad.asnumpy()) | ||
assert_almost_equal(fl.asnumpy(), fl_mobula.asnumpy()) | ||
|
||
|
||
def test_FocalLoss_mx_cuda(): | ||
ctx = mx.gpu() | ||
x = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx) | ||
y = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx) | ||
x1 = x.copy() | ||
y1 = y.copy() | ||
|
||
x.attach_grad() | ||
x1.attach_grad() | ||
|
||
with ag.record(): | ||
fl = BCEFocalLoss(x, y, alpha=.25, gamma=2) | ||
fl_moubula = mobula.op.FocalLoss(alpha=.25, gamma=2, logits=x1, targets=y1) | ||
fl.backward() | ||
fl_moubula.backward() | ||
|
||
assert_almost_equal(x.grad, x1.grad) | ||
assert_almost_equal(fl, fl_moubula) |