diff --git a/examples/plot_piecewise_linear.py b/examples/plot_piecewise_linear.py
index 552fa49..e624cbe 100644
--- a/examples/plot_piecewise_linear.py
+++ b/examples/plot_piecewise_linear.py
@@ -19,7 +19,9 @@ import matplotlib.pyplot as plt
 import torch

 from td3a_cpp_deep.fcts.piecewise_linear import (
-    PiecewiseLinearFunction, PiecewiseLinearFunctionC)
+    PiecewiseLinearFunction,
+    PiecewiseLinearFunctionC,
+    PiecewiseLinearFunctionCBetter)


 def train_piecewise_linear(x, y, device, cls,
@@ -80,9 +82,21 @@ def train_piecewise_linear(x, y, device, cls,
 print("duration=%f, alpha_neg=%f alpha_pos=%f" %
       (end - begin, alpha_neg, alpha_pos))

+################################
+# C++ implementation, second try
+# ++++++++++++++++++++++++++++++
+
+begin = time.perf_counter()
+losses, alpha_neg, alpha_pos = train_piecewise_linear(
+    x, y, device, PiecewiseLinearFunctionCBetter)
+end = time.perf_counter()
+print("duration=%f, alpha_neg=%f alpha_pos=%f" %
+      (end - begin, alpha_neg, alpha_pos))
+
 #################################
-# The C++ implementation is very close to the python code
-# and is not faster.
+# The first C++ implementation follows the Python code closely
+# and is not faster. The second one is faster because it reuses
+# the tensors it creates instead of allocating new ones.

 ##################################
 # Graphs
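The hunks above only change which autograd Function the example script benchmarks; the new class itself is added in the next two files. As context for the change, the following pure-PyTorch sketch (illustrative and standalone, not part of the patch; names w1 and w2 are hypothetical) shows why the factored weight computed by the new kernel matches the original formulation:

    import torch

    x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
    alpha_neg = torch.tensor([0.5])
    alpha_pos = torch.tensor([1.0])

    sign = (x >= 0).to(torch.float32)
    # original formulation: two products with the alphas
    w1 = sign * alpha_pos + (-sign + 1) * alpha_neg
    # factored formulation used by the new kernel: one temporary fewer
    w2 = sign * (alpha_pos - alpha_neg) + alpha_neg

    assert torch.allclose(w1, w2)
    # piecewise linear output: slope alpha_neg for x < 0, alpha_pos for x >= 0
    print(x * w1)
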
+ """ + + @staticmethod + def forward(ctx, x, alpha_neg, alpha_pos): + outputs = piecewise_linear_forward_better(x, alpha_neg, alpha_pos) + ctx.save_for_backward(*outputs[1:]) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output): + x, sign, weight = ctx.saved_tensors + weight, grad_alpha_neg, grad_alpha_pos = ( + piecewise_linear_backward_better(grad_output, x, sign, weight)) + return weight, grad_alpha_neg, grad_alpha_pos diff --git a/td3a_cpp_deep/fcts/piecewise_linear_c.cpp b/td3a_cpp_deep/fcts/piecewise_linear_c.cpp index dd06d83..4e7ee14 100644 --- a/td3a_cpp_deep/fcts/piecewise_linear_c.cpp +++ b/td3a_cpp_deep/fcts/piecewise_linear_c.cpp @@ -5,18 +5,30 @@ std::vector piecewise_linear_forward( torch::Tensor input, torch::Tensor alpha_neg, torch::Tensor alpha_pos) { - + // python code // sign = (input >= 0).to(torch.float32) // weight = (sign * alpha_pos + (- sign + 1) * alpha_neg) // output = input * weight - + auto sign = (input >= 0).to(torch::kFloat32); auto weight = (sign * alpha_pos) + (- sign + 1) * alpha_neg; return {input * weight, input, sign, weight}; } +std::vector piecewise_linear_forward_better( + torch::Tensor input, + torch::Tensor alpha_neg, + torch::Tensor alpha_pos) { + + auto sign = (input >= 0).to(torch::kFloat32); + auto weight = sign * (alpha_pos - alpha_neg); + weight += alpha_neg; + return {input * weight, input, sign, weight}; +} + + std::vector piecewise_linear_backward( torch::Tensor grad_output, torch::Tensor input, @@ -28,13 +40,29 @@ std::vector piecewise_linear_backward( // grad_alpha_neg = (input * grad_output * (- sign + 1)).sum(dim=0, keepdim=True) // grad_alpha_pos = (input * grad_output * sign).sum(dim=0, keepdim=True) // return grad_input, grad_alpha_neg, grad_alpha_pos - + auto grad_alpha_neg = at::sum((input * grad_output * (- sign + 1)), {0}, true); auto grad_alpha_pos = at::sum((input * grad_output * sign), {0}, true); return {weight, grad_alpha_neg, grad_alpha_pos}; } +std::vector piecewise_linear_backward_better( + torch::Tensor grad_output, + torch::Tensor input, + torch::Tensor sign, + torch::Tensor weight) { + + auto ig = input * grad_output; + auto igs = ig * sign; + auto grad_alpha_pos = at::sum(igs, {0}, true); + igs *= -1; + igs += ig; + auto grad_alpha_neg = at::sum(igs, {0}, true); + return {weight, grad_alpha_neg, grad_alpha_pos}; +} + + PYBIND11_MODULE(piecewise_linear_c, m) { m.doc() = #if defined(__APPLE__) @@ -46,4 +74,9 @@ PYBIND11_MODULE(piecewise_linear_c, m) { m.def("piecewise_linear_forward", &piecewise_linear_forward, "PiecewiseLinearC forward"); m.def("piecewise_linear_backward", &piecewise_linear_backward, "PiecewiseLinearC backward"); + + m.def("piecewise_linear_forward_better", &piecewise_linear_forward_better, + "PiecewiseLinearC improved forward"); + m.def("piecewise_linear_backward_better", &piecewise_linear_backward_better, + "PiecewiseLinearC improved backward"); } diff --git a/tests/test_fcts_piecewise_linear.py b/tests/test_fcts_piecewise_linear.py index e236523..3e22415 100644 --- a/tests/test_fcts_piecewise_linear.py +++ b/tests/test_fcts_piecewise_linear.py @@ -4,12 +4,28 @@ import unittest import torch from td3a_cpp_deep.fcts.piecewise_linear import ( - PiecewiseLinearFunction, PiecewiseLinearFunctionC) + PiecewiseLinearFunction, + PiecewiseLinearFunctionC, + PiecewiseLinearFunctionCBetter) class TestFctsPiecewiseLinear(unittest.TestCase): - def piecewise_linear(self, cls, device): + def test_equal_forward(self): + alpha_pos = torch.tensor([1], dtype=torch.float32) + alpha_neg = 
diff --git a/tests/test_fcts_piecewise_linear.py b/tests/test_fcts_piecewise_linear.py
index e236523..3e22415 100644
--- a/tests/test_fcts_piecewise_linear.py
+++ b/tests/test_fcts_piecewise_linear.py
@@ -4,12 +4,28 @@ import unittest
 import torch

 from td3a_cpp_deep.fcts.piecewise_linear import (
-    PiecewiseLinearFunction, PiecewiseLinearFunctionC)
+    PiecewiseLinearFunction,
+    PiecewiseLinearFunctionC,
+    PiecewiseLinearFunctionCBetter)


 class TestFctsPiecewiseLinear(unittest.TestCase):

-    def piecewise_linear(self, cls, device):
+    def test_equal_forward(self):
+        alpha_pos = torch.tensor([1], dtype=torch.float32)
+        alpha_neg = torch.tensor([0.5], dtype=torch.float32)
+        x = torch.tensor([-2, 1], dtype=torch.float32)
+        res1 = PiecewiseLinearFunction.apply(x, alpha_neg, alpha_pos)
+        res2 = PiecewiseLinearFunctionC.apply(x, alpha_neg, alpha_pos)
+        res3 = PiecewiseLinearFunctionCBetter.apply(x, alpha_neg, alpha_pos)
+        for a, b, c in zip(res1, res2, res3):
+            na = a.cpu().detach().numpy().tolist()
+            nb = b.cpu().detach().numpy().tolist()
+            nc = c.cpu().detach().numpy().tolist()
+            self.assertEqual(na, nb)
+            self.assertEqual(na, nc)
+
+    def piecewise_linear(self, cls, device, verbose=False, max_iter=400):
         x = torch.randn(100, 1, device=device, dtype=torch.float32)
         y = x * 0.2 + (x > 0).to(torch.float32) * x * 1.5
@@ -23,7 +39,7 @@
         learning_rate = 1e-4
         fct = cls.apply

-        for t in range(400):
+        for t in range(max_iter):
             y_pred = fct(x, alpha_neg, alpha_pos)

             loss = (y_pred - y).pow(2).sum()
@@ -31,6 +47,8 @@
             losses.append(loss)

             with torch.no_grad():
+                if verbose:
+                    print(alpha_neg.grad, alpha_pos.grad)
                 alpha_pos -= learning_rate * alpha_pos.grad
                 alpha_neg -= learning_rate * alpha_neg.grad

@@ -38,9 +56,10 @@
                 alpha_pos.grad.zero_()
                 alpha_neg.grad.zero_()

-        self.assertTrue(losses[-1] < 1)
-        self.assertTrue(abs(alpha_neg - 0.2) < 0.2)
-        self.assertTrue(abs(alpha_pos - 1.7) < 0.2)
+        if max_iter > 300:
+            self.assertTrue(losses[-1] < 1)
+            self.assertTrue(abs(alpha_neg - 0.2) < 0.2)
+            self.assertTrue(abs(alpha_pos - 1.7) < 0.2)

     def test_piecewise_linear_cpu(self):
         self.piecewise_linear(PiecewiseLinearFunction, torch.device('cpu'))
@@ -56,6 +75,21 @@ def test_piecewise_linear_c_cpu(self):
     def test_piecewise_linear_c_gpu(self):
         self.piecewise_linear(PiecewiseLinearFunctionC, torch.device("cuda:0"))

+    def test_piecewise_linear_c_cpu_better(self):
+        self.piecewise_linear(
+            PiecewiseLinearFunction, torch.device('cpu'),
+            verbose=True, max_iter=3)
+        self.piecewise_linear(
+            PiecewiseLinearFunctionCBetter, torch.device('cpu'),
+            verbose=True, max_iter=3)
+        self.piecewise_linear(
+            PiecewiseLinearFunctionCBetter, torch.device('cpu'))
+
+    @unittest.skipIf(not torch.cuda.is_available(), reason="no GPU")
+    def test_piecewise_linear_c_gpu_better(self):
+        self.piecewise_linear(
+            PiecewiseLinearFunctionCBetter, torch.device("cuda:0"))
+

 if __name__ == '__main__':
     unittest.main()
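Once the extension is compiled, the three Function classes can also be compared head to head outside the test suite. A rough benchmarking sketch (not part of the patch; the iteration count and tensor shape are arbitrary, and it assumes the compiled module td3a_cpp_deep.fcts.piecewise_linear_c is importable):

    import time
    import torch
    from td3a_cpp_deep.fcts.piecewise_linear import (
        PiecewiseLinearFunction,
        PiecewiseLinearFunctionC,
        PiecewiseLinearFunctionCBetter)

    x = torch.randn(10000, 1, dtype=torch.float32)
    alpha_neg = torch.tensor([0.5], dtype=torch.float32, requires_grad=True)
    alpha_pos = torch.tensor([1.0], dtype=torch.float32, requires_grad=True)

    for cls in (PiecewiseLinearFunction, PiecewiseLinearFunctionC,
                PiecewiseLinearFunctionCBetter):
        begin = time.perf_counter()
        for _ in range(100):
            loss = cls.apply(x, alpha_neg, alpha_pos).pow(2).sum()
            loss.backward()          # exercises the custom backward
            alpha_neg.grad.zero_()
            alpha_pos.grad.zero_()
        print(cls.__name__, time.perf_counter() - begin)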