
Conversation

davidberard98 (Contributor) commented Dec 1, 2021

Stack from ghstack:

JIT optimization passes are part of the CPU-only build (i.e., the necessary GPU flags are not passed when it is compiled). This change separates the implementation of frozen_conv_add_relu_fusion so that the GPU-enabled implementation is registered at runtime, if it is available.
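
The actual pass lives in PyTorch's C++ code; as a rough Python analogy of the runtime-registration idea (all names below are illustrative, not PyTorch APIs), the CPU-only build installs a no-op implementation and the GPU-enabled code overwrites it when it is loaded:

```
# Hypothetical sketch only; these names are not PyTorch APIs.

def _fuse_conv_add_relu_noop(graph):
    # CPU-only default: leave the graph unchanged.
    return graph

# Registration slot, initialized to the CPU-only no-op.
_fuse_conv_add_relu_impl = _fuse_conv_add_relu_noop

def register_fuse_conv_add_relu(impl):
    # Called once at load time by the GPU-enabled code, if it is present.
    global _fuse_conv_add_relu_impl
    _fuse_conv_add_relu_impl = impl

def fuse_frozen_conv_add_relu(graph):
    # The optimization pipeline always calls this entry point; whether it
    # actually fuses depends on what was registered at runtime.
    return _fuse_conv_add_relu_impl(graph)
```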

Test Plan:
In the following script, conv_add_relu fusion is not observed without this change, but is observed when this change is added.

```
from typing import List, Optional

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand((3, 3, 7, 7), device="cuda"))
        self.add_tensor = torch.nn.Parameter(torch.rand((3, 3, 7, 7), device="cuda"))

    def forward(
        self,
        inp: torch.Tensor,
        bias: Optional[torch.Tensor],
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        groups: int,
    ):
        # weight = torch.zeros((3, 3, 7, 7), device="cuda")
        inp = inp.to("cuda")
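        # conv2d -> in-place add_ -> in-place relu_ below is the pattern the
        # frozen conv-add-relu fusion matches (CUDA tensors only), so a fused
        # op should replace these three calls after optimize_for_inference.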
        conv_result = torch.conv2d(
            inp, self.weight, bias, stride, padding, dilation, groups
        )
        add_result = conv_result.add_(self.add_tensor)
        return add_result.relu_()

    @torch.jit.export
    def make_prediction(self, inp: torch.Tensor):
        bias = None
        groups = 1
        stride = [1, 1]
        padding = [0, 0]
        dilation = [1, 1]

        return self.forward(inp, bias, stride, padding, dilation, groups)

if __name__ == "__main__":
    # generate some sample input
    groups = 1
    channels_in = 3
    channels_out = 3
    kernel_size = (7, 7)
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    inp = torch.rand((64, 3, 432, 432))
    weight = torch.rand(
        (channels_out, channels_in, kernel_size[0], kernel_size[1]), device="cuda"
    )
    bias = None

    model = Model()
    model.eval()
    script = torch.jit.script(model)
    script = torch.jit.freeze(script)
    script = torch.jit.optimize_for_inference(script)

    print("~~~~ FORWARD ~~~~")
    print(script.graph)

    print("with preserved_attrs")
    print(torch.sum(script.forward(inp, bias, stride, padding, dilation, groups)))
```
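
As a quick pass/fail check (a sketch that could be appended to the `__main__` block above; the fused op name is an assumption based on the pass targeting cuDNN's fused kernel):

```
# With the fix (and CUDA available), the fused op should appear in the
# frozen, optimized graph; without it, the original conv2d/add_/relu_ remain.
print("fused:", "cudnn_convolution_add_relu" in str(script.graph))
```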

fbshipit-source-id: c0f10da4b9540c588819efe3ec540baa0fae4b35

pytorch-probot bot commented Dec 1, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/16e51bfd7f1779d5eedc245073f678f5769e514a/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Triggered Workflows

| Workflow | Labels (bold = enabled) | Status |
| --- | --- | --- |
| linux-bionic-cuda11.5-py3.6-gcc7 | ciflow/all, ciflow/cuda, **ciflow/default**, ciflow/linux | ✅ triggered |
| linux-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/linux, ciflow/noarch, ciflow/xla | ✅ triggered |
| linux-vulkan-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/linux, ciflow/vulkan | ✅ triggered |
| linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, **ciflow/default**, ciflow/linux | ✅ triggered |
| linux-xenial-py3-clang5-mobile-build | ciflow/all, **ciflow/default**, ciflow/linux, ciflow/mobile | ✅ triggered |
| linux-xenial-py3-clang5-mobile-custom-build-static | ciflow/all, **ciflow/default**, ciflow/linux, ciflow/mobile | ✅ triggered |
| linux-xenial-py3.6-clang7-asan | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/linux, ciflow/sanitizers | ✅ triggered |
| linux-xenial-py3.6-clang7-onnx | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/linux, ciflow/onnx | ✅ triggered |
| linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7 | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, **ciflow/default**, ciflow/linux | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single | ciflow/all, ciflow/android, ciflow/cpu, **ciflow/default**, ciflow/linux | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit | ciflow/all, ciflow/android, ciflow/cpu, **ciflow/default**, ciflow/linux | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, **ciflow/default**, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, **ciflow/default**, ciflow/win | ✅ triggered |

Skipped Workflows

| Workflow | Labels | Status |
| --- | --- | --- |
| caffe2-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| docker-builds | ciflow/all | 🚫 skipped |
| ios-12-5-1-arm64 | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-coreml | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-custom-ops | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-full-jit | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-metal | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64 | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64-coreml | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64-full-jit | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| libtorch-linux-bionic-cuda11.5-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| libtorch-linux-xenial-cuda10.2-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow | 🚫 skipped |
| macos-10-15-py3-arm64 | ciflow/all, ciflow/macos | 🚫 skipped |
| macos-10-15-py3-lite-interpreter-x86-64 | ciflow/all, ciflow/macos | 🚫 skipped |
| macos-11-py3-x86-64 | ciflow/all, ciflow/macos | 🚫 skipped |
| parallelnative-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck | 🚫 skipped |
| periodic-linux-xenial-cuda11.1-py3.6-gcc7-debug | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.1-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |

You can add a comment to the PR and tag @pytorchbot with the following commands:

```
# ciflow rerun; "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is
# equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow
```

For more information, please take a look at the CI Flow Wiki.

facebook-github-bot (Contributor) commented Dec 1, 2021

💊 CI failures summary and remediations

As of commit 16e51bf (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


davidberard98 added a commit that referenced this pull request Dec 1, 2021
ghstack-source-id: e921804
Pull Request resolved: #69253
facebook-github-bot added the oncall: jit label Dec 1, 2021
facebook-github-bot deleted the gh/davidberard98/22/head branch January 1, 2022 15:18