Skip to content

Commit

Permalink
Add support for broadcasting arithmetic ops (#3269)
Browse files Browse the repository at this point in the history
Summary:
PyTorch arithmetic ops support multidirectional [broadcasting](https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md) so add this to the loader

Documentation:
doxygen
Pull Request resolved: #3269

Test Plan: added tests

Differential Revision: D16565146

Pulled By: jackm321

fbshipit-source-id: d95a837eb997fad81cd060781b9b6b42c3098490
  • Loading branch information
jackm321 authored and facebook-github-bot committed Jul 31, 2019
1 parent 94d10eb commit a6f8c39
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 6 deletions.
14 changes: 10 additions & 4 deletions torch_glow/src/PyTorchModelLoader.cpp
Expand Up @@ -254,7 +254,9 @@ llvm::Error PyTorchModelLoader::loadMul(const torch::jit::Node *ptNode) {
glow::NodeValue rhs;
ASSIGN_VALUE_OR_RETURN_ERR(rhs, getGlowNodeValue(inputs[1]));

glow::MulNode *glowNode = F_.createMul("mul", lhs, rhs);
glow::MulNode *glowNode =
F_.createNodeWithBroadcast<glow::MulNode>("mul", /*axis*/ -1, lhs, rhs);

return addGlowNodeValue(outputs[0], glowNode->getResult());
}

Expand All @@ -268,7 +270,9 @@ llvm::Error PyTorchModelLoader::loadDiv(const torch::jit::Node *ptNode) {
glow::NodeValue rhs;
ASSIGN_VALUE_OR_RETURN_ERR(rhs, getGlowNodeValue(inputs[1]));

glow::DivNode *glowNode = F_.createDiv("div", lhs, rhs);
glow::DivNode *glowNode =
F_.createNodeWithBroadcast<glow::DivNode>("div", /*axis*/ -1, lhs, rhs);

return addGlowNodeValue(outputs[0], glowNode->getResult());
}

Expand All @@ -294,7 +298,8 @@ llvm::Error PyTorchModelLoader::loadAdd(const torch::jit::Node *ptNode) {
glow::NodeValue rhs;
ASSIGN_VALUE_OR_RETURN_ERR(rhs, getGlowNodeValue(inputs[1]));

glow::AddNode *glowNode = F_.createAdd("add", lhs, rhs);
glow::AddNode *glowNode =
F_.createNodeWithBroadcast<glow::AddNode>("add", /*axis*/ -1, lhs, rhs);
return addGlowNodeValue(outputs[0], glowNode->getResult());
}

Expand All @@ -320,7 +325,8 @@ llvm::Error PyTorchModelLoader::loadSub(const torch::jit::Node *ptNode) {
glow::NodeValue rhs;
ASSIGN_VALUE_OR_RETURN_ERR(rhs, getGlowNodeValue(inputs[1]));

glow::SubNode *glowNode = F_.createSub("sub", lhs, rhs);
glow::SubNode *glowNode =
F_.createNodeWithBroadcast<glow::SubNode>("sub", /*axis*/ -1, lhs, rhs);
return addGlowNodeValue(outputs[0], glowNode->getResult());
}

Expand Down
44 changes: 44 additions & 0 deletions torch_glow/tests/nodes/add_test.py
@@ -1,3 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch_glow

Expand Down Expand Up @@ -29,3 +31,45 @@ def add_inplace(a, b):
y = torch.randn(4)

jitVsGlow(add_inplace, x, y)

# Test of the PyTorch add Node on Glow with broadcasting.


def test_add_broadcast_1():

def test_f(a, b):
c = a.add(b)
return c.add(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(4, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch add Node on Glow with broadcasting.


def test_add_broadcast_2():

def test_f(a, b):
c = a.add(b)
return c.add(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(1, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch add Node on Glow with broadcasting.


def test_add_broadcast_3():

def test_f(a, b):
c = a.add(b)
return c.add(c)

x = torch.randn(4, 2)
y = torch.randn(8, 3, 4, 2)

jitVsGlow(test_f, x, y)
44 changes: 44 additions & 0 deletions torch_glow/tests/nodes/div_test.py
@@ -1,3 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch_glow

Expand All @@ -15,3 +17,45 @@ def div_basic(a, b):
y = torch.randn(4)

jitVsGlow(div_basic, x, y)

# Test of the PyTorch div Node on Glow with broadcasting.


def test_div_broadcast_1():

def test_f(a, b):
c = a.div(b)
return c.div(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(4, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch div Node on Glow with broadcasting.


def test_div_broadcast_2():

def test_f(a, b):
c = a.div(b)
return c.div(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(1, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch div Node on Glow with broadcasting.


def test_div_broadcast_3():

def test_f(a, b):
c = a.div(b)
return c.div(c)

x = torch.randn(4, 2)
y = torch.randn(8, 3, 4, 2)

jitVsGlow(test_f, x, y)
44 changes: 44 additions & 0 deletions torch_glow/tests/nodes/mul_test.py
@@ -1,3 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch_glow

Expand All @@ -15,3 +17,45 @@ def mul_basic(a, b):
y = torch.randn(4)

jitVsGlow(mul_basic, x, y)

# Test of the PyTorch mul Node on Glow with broadcasting.


def test_mul_broadcast_1():

def test_f(a, b):
c = a.mul(b)
return c.mul(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(4, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch mul Node on Glow with broadcasting.


def test_mul_broadcast_2():

def test_f(a, b):
c = a.mul(b)
return c.mul(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(1, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch mul Node on Glow with broadcasting.


def test_mul_broadcast_3():

def test_f(a, b):
c = a.mul(b)
return c.mul(c)

x = torch.randn(4, 2)
y = torch.randn(8, 3, 4, 2)

jitVsGlow(test_f, x, y)
44 changes: 44 additions & 0 deletions torch_glow/tests/nodes/sub_test.py
@@ -1,3 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch_glow

Expand All @@ -15,3 +17,45 @@ def sub_basic(a, b):
y = torch.randn(4)

jitVsGlow(sub_basic, x, y)

# Test of the PyTorch sub Node on Glow with broadcasting.


def test_sub_broadcast_1():

def test_f(a, b):
c = a.sub(b)
return c.sub(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(4, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch sub Node on Glow with broadcasting.


def test_sub_broadcast_2():

def test_f(a, b):
c = a.sub(b)
return c.sub(c)

x = torch.randn(8, 3, 4, 2)
y = torch.randn(1, 2)

jitVsGlow(test_f, x, y)

# Test of the PyTorch sub Node on Glow with broadcasting.


def test_sub_broadcast_3():

def test_f(a, b):
c = a.sub(b)
return c.sub(c)

x = torch.randn(4, 2)
y = torch.randn(8, 3, 4, 2)

jitVsGlow(test_f, x, y)
4 changes: 2 additions & 2 deletions utils/format.sh
Expand Up @@ -28,7 +28,7 @@ print_usage() {
}

fix_format() {
find lib tests/unittests/ tools/ include examples \
find lib tests/unittests/ tools/ include examples torch_glow \
-name \*.h -print0 \
-o -name \*.hpp -print0 \
-o -name \*.c -print0 \
Expand All @@ -37,7 +37,7 @@ fix_format() {
-o -name \*.cl -print0 \
| xargs -0 -P8 -n1 $CLANG_COMMAND -i;

autopep8 -i -r -j -1 --exclude="*.eggs" torch_glow utils
autopep8 -i -r -j -1 --exclude="*.eggs" --indent-size=4 torch_glow utils
}

check_format() {
Expand Down

0 comments on commit a6f8c39

Please sign in to comment.