Add FocalLoss into opzoo
kohillyang committed Aug 13, 2020
1 parent abe1411 commit 4fff7d5
Showing 3 changed files with 171 additions and 0 deletions.
69 changes: 69 additions & 0 deletions opzoo/FocalLoss/FocalLoss.cpp
@@ -0,0 +1,69 @@
/*
* author: kohill
*/
#include "mobula_op.h"
#if USING_CUDA
#include <cuda.h>
#else
// CPU build: __device__ becomes a no-op and the math functions come from the
// C++ standard library. (The original outer #ifdef USING_CUDA guard made this
// branch unreachable on CPU builds, so __device__ was never defined there.)
#define __device__
#include <cmath>
#include <algorithm>
using std::exp;
using std::max;
using std::pow;
using std::log;
#endif

namespace mobula {

#define UNUSED(expr) do { (void)(expr); } while (0)

template <typename T>
__device__ inline T sigmoid(T x) {
  // Numerically stable sigmoid: shifting both exponents by max(0, -x) keeps
  // every exp argument non-positive, so nothing overflows for large |x|.
  T max_val = max(static_cast<T>(0), -x);
  T v0 = exp(-max_val);
  return v0 / (v0 + exp(-x - max_val));
}

template <typename T>
__device__ inline T log_sigmoid(T x) {
  // Numerically stable log(sigmoid(x)) using the same shift:
  // log_sigmoid(x) = -max(0, -x) - log(exp(-max(0, -x)) + exp(-x - max(0, -x))).
  T max_val = max(static_cast<T>(0), -x);
  return -max_val - log(exp(-max_val) + exp(-x - max_val));
}

template <typename T>
MOBULA_KERNEL focal_loss_forward_kernel(const int out_size, T alpha, T gamma, T* logits, T* targets, T* outputs) {
  parfor(out_size, [&](int index) {
    T y = targets[index];
    T x = logits[index];
    T sigmoid_x = sigmoid(x);
    T sigmoid_neg_x = sigmoid(-x);  // 1 - sigmoid(x)
    // loss = -[alpha * y * (1-p)^gamma * log(p) + (1-alpha) * (1-y) * p^gamma * log(1-p)]
    T output = alpha * y * pow(sigmoid_neg_x, gamma) * log_sigmoid(x);
    output += (1 - alpha) * (1 - y) * log_sigmoid(-x) * pow(sigmoid_x, gamma);
    output *= -1;
    outputs[index] = output;
  });
}  // focal_loss_forward_kernel

template <typename T>
MOBULA_KERNEL focal_loss_backward_kernel(const int out_size, T alpha, T gamma, T* logits, T* targets, T* outputs) {
  parfor(out_size, [&](int index) {
    T y = targets[index];
    T x = logits[index];
    T sigmoid_x = sigmoid(x);
    T sigmoid_neg_x = sigmoid(-x);  // 1 - sigmoid(x)
    // d(loss)/d(logits); one summand per line, derived with dp/dx = p * (1 - p).
    T output = (alpha - 1 - alpha * y) * pow(sigmoid_x, 1 + gamma);
    output += alpha * y * pow(sigmoid_neg_x, gamma + 1);
    output += (alpha - 1) * gamma * (y - 1) * sigmoid_neg_x * pow(sigmoid_x, gamma) * log_sigmoid(-x);
    output -= alpha * gamma * sigmoid_x * y * pow(sigmoid_neg_x, gamma) * log_sigmoid(x);
    output += sigmoid_x * y * pow(sigmoid_x, gamma);
    output *= -1;
    outputs[index] = output;
  });
}  // focal_loss_backward_kernel


} // namespace mobula
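
For reference (not part of the commit), with p = sigmoid(x) the two kernels above compute the binary focal loss and its gradient. The forward kernel evaluates

$$L(x, y) = -\left[\alpha\, y\, (1-p)^{\gamma}\log p + (1-\alpha)(1-y)\, p^{\gamma}\log(1-p)\right],$$

and, using dp/dx = p(1-p), the backward kernel evaluates its derivative with one summand per line of code:

$$\frac{\partial L}{\partial x} = -\left[(\alpha - 1 - \alpha y)\,p^{\gamma+1} + \alpha y\,(1-p)^{\gamma+1} + (\alpha-1)\,\gamma\,(y-1)(1-p)\,p^{\gamma}\log(1-p) - \alpha\gamma y\,p\,(1-p)^{\gamma}\log p + y\,p^{\gamma+1}\right].$$
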
49 changes: 49 additions & 0 deletions opzoo/FocalLoss/FocalLoss.py
@@ -0,0 +1,49 @@
import mobula
from mobula.const import req
import numpy as np


@mobula.op.register
class FocalLoss:
    def __init__(self, alpha=0.25, gamma=2):
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        if self.req[0] == req.null:
            return
        out = self.y
        out_size = np.prod(out.size()) if callable(out.size) else out.size
        alpha = self.alpha
        gamma = self.gamma
        if self.req[0] == req.add:
            out_temp = self.F.zeros_like(out)
            mobula.func.focal_loss_forward(out_size=out_size, alpha=alpha, gamma=gamma,
                                           logits=logits, targets=targets, outputs=out_temp)
            self.y[:] += out_temp
        else:
            self.y[:] = 0
            mobula.func.focal_loss_forward(out_size=out_size, alpha=alpha, gamma=gamma,
                                           logits=logits, targets=targets, outputs=self.y)

    def backward(self, dy):
        # No gradient w.r.t. targets is computed.
        assert self.req[1] == req.null
        alpha = self.alpha
        gamma = self.gamma
        logits = self.X[0]
        targets = self.X[1]
        out_size = np.prod(targets.size()) if callable(targets.size) else targets.size
        if self.req[0] == req.add:
            out_temp = self.F.zeros_like(self.dX[0])
            # The backward kernel (not the forward one) computes d(loss)/d(logits).
            mobula.func.focal_loss_backward(out_size=out_size, alpha=alpha, gamma=gamma,
                                            logits=logits, targets=targets, outputs=out_temp)
            # Scale only the new contribution by dy so previously accumulated
            # gradients are left intact.
            self.dX[0][:] += out_temp * dy
        else:
            self.dX[0][:] = 0
            mobula.func.focal_loss_backward(out_size=out_size, alpha=alpha, gamma=gamma,
                                            logits=logits, targets=targets, outputs=self.dX[0])
            self.dX[0][:] = self.dX[0] * dy

    def infer_shape(self, in_shape):
        assert len(in_shape) == 2
        assert in_shape[0] == in_shape[1]
        return in_shape, [in_shape[0]]
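
A minimal usage sketch (not part of this commit; it assumes MXNet and MobulaOP are installed and is run from a directory where mobula.op.load can find opzoo/FocalLoss, as the test below does):

import mobula
import mxnet as mx

mobula.op.load('FocalLoss')  # compile and register the operator defined above

logits = mx.nd.random.randn(4, 10, dtype="float64")
targets = mx.nd.one_hot(mx.nd.array([1, 3, 5, 7]), 10).astype("float64")
loss = mobula.op.FocalLoss(alpha=0.25, gamma=2, logits=logits, targets=targets)
print(loss.mean().asscalar())  # per-element losses, same shape as logits
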
53 changes: 53 additions & 0 deletions opzoo/FocalLoss/test_focalloss.py
@@ -0,0 +1,53 @@
import mobula

mobula.op.load('FocalLoss')
import mxnet as mx
import mxnet.autograd as ag
from mobula.testing import assert_almost_equal


def BCEFocalLoss(x, target, alpha=.25, gamma=2):
    p = x.sigmoid()
    loss = alpha * target * ((1 - p) ** gamma) * mx.nd.log(p)
    loss = loss + (1 - alpha) * (1 - target) * (p ** gamma) * mx.nd.log(1 - p)
    return -loss


def test_FocalLoss_mx_cpu():
    ctx = mx.cpu()
    x = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx)
    # Random targets suffice here: checking that the two implementations agree
    # does not require the targets to be binary.
    y = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx)
    x1 = x.copy()
    y1 = y.copy()

    x.attach_grad()
    x1.attach_grad()

    with ag.record():
        fl = BCEFocalLoss(x, y, alpha=.25, gamma=2)
        fl_mobula = mobula.op.FocalLoss(alpha=.25, gamma=2, logits=x1, targets=y1)
    fl.backward()
    fl_mobula.backward()

    assert_almost_equal(x.grad.asnumpy(), x1.grad.asnumpy())
    assert_almost_equal(fl.asnumpy(), fl_mobula.asnumpy())


def test_FocalLoss_mx_cuda():
    ctx = mx.gpu()
    x = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx)
    y = mx.nd.random.randn(300, 300, dtype="float64", ctx=ctx)
    x1 = x.copy()
    y1 = y.copy()

    x.attach_grad()
    x1.attach_grad()

    with ag.record():
        fl = BCEFocalLoss(x, y, alpha=.25, gamma=2)
        fl_mobula = mobula.op.FocalLoss(alpha=.25, gamma=2, logits=x1, targets=y1)
    fl.backward()
    fl_mobula.backward()

    assert_almost_equal(x.grad.asnumpy(), x1.grad.asnumpy())
    assert_almost_equal(fl.asnumpy(), fl_mobula.asnumpy())
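
As a side note (not part of the commit), the numerically stable log-sigmoid that FocalLoss.cpp relies on can be cross-checked against the naive formula with NumPy; the max-shift keeps both exp arguments non-positive, so nothing overflows even for large |x|:

import numpy as np

def log_sigmoid(x):
    # Mirrors the C++ helper: -max(0, -x) - log(exp(-max(0, -x)) + exp(-x - max(0, -x))).
    m = np.maximum(0.0, -x)
    return -m - np.log(np.exp(-m) + np.exp(-x - m))

x = np.array([-30.0, -1.0, 0.0, 1.0, 30.0])
naive = np.log(1.0 / (1.0 + np.exp(-x)))  # safe at these magnitudes in float64
assert np.allclose(log_sigmoid(x), naive)
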
