
[POC] Potential ways for making Transforms V2 classes JIT-scriptable #6711

Closed · datumbox opened this issue Oct 5, 2022 · 5 comments

datumbox (Contributor) commented Oct 5, 2022

🚀 The feature

Note: This is an exploratory proof-of-concept to discuss potential workarounds for offering limited JIT support in our Transforms V2 classes. I am NOT advocating for this approach. I'm hoping we can kick off a discussion about alternative, simpler approaches.

Currently the Transforms V2 classes are not JIT-scriptable. This breaks BC and will make the rollout of the new API harder. Here are some of the design choices that are incompatible with JIT:

  1. We wanted to support an arbitrary number of inputs.
  2. We rely on Tensor subclassing to dispatch to the right kernel.
  3. We use real typing information on the inputs, which often includes types that are not scriptable.
  4. We opted for more Pythonic idioms (such as for ... else).

Points 3 & 4 could be addressed by (painful) refactoring; points 1 & 2, however, are our main blockers.
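
To make point 4 concrete, here is a minimal sketch (a made-up module, not taken from our code base) of the kind of idiom TorchScript's parser rejects:

import torch

class FindFirstPositive(torch.nn.Module):
    # Hypothetical example: `for ... else` is valid Python, but
    # torch.jit.script cannot parse the `else` clause of a loop.
    def forward(self, xs: torch.Tensor) -> bool:
        for x in xs.unbind():
            if bool(x > 0):
                break
        else:
            return False
        return True

# torch.jit.script(FindFirstPositive())  # fails: for/else is not scriptable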

To ensure our users can still do inference using JIT, we offer presets/transforms attached to each model's weights. Those will remain JIT-scriptable. In addition, we applied a workaround (#6553) to keep the F dispatchers JIT-scriptable for plain Tensors. Hopefully these mitigations will make it easier for most users to migrate to the new API.
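
For example (a minimal sketch, assuming the prototype import path at the time of writing), the dispatcher can still be scripted and applied to plain tensors:

import torch
from torchvision.prototype.transforms import functional as F

# Thanks to the workaround in #6553, the dispatcher stays scriptable
# and works on plain Tensors.
scripted_resize = torch.jit.script(F.resize)
out = scripted_resize(torch.randn(1, 3, 224, 224), size=[32, 32])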

But what if they don't? Many downstream users might want to keep relying on transforms such as Resize, CenterCrop, Pad, etc. for inference. In that case, one option could be to offer JIT-scriptable alternatives that work only for pure tensors. Another is a utility that modifies the existing implementations on the fly, updating key functions to make them JIT-scriptable.

Motivation, pitch

This is a proof-of-concept of how such a utility could work. It only supports a handful of transforms (due to points 3 & 4 above), but it can be extended to support more.

Two approaches are showcased below:

  1. We use ast to replace problematic idioms in the Transform classes on the fly. Since JIT also uses ast internally, we need to make the updated code available to JIT during scripting.
  2. We replace forward() to remove the packing/unpacking of an arbitrary number of inputs. We also hardcode plain tensors as the only accepted input type.
import ast
import inspect
import tempfile
import torch
import types

from torchvision import transforms as V1
from torchvision.prototype import transforms as V2
from torchvision.prototype import features


class JITWrapper(torch.nn.Module):

    def __init__(self, cls, *args, **kwargs):
        super().__init__()
        # Patch the _transform type annotations; this could be avoided by defining JIT-scriptable types directly
        code = inspect.getsource(cls)
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                node.name = f"{cls.__name__}JIT"
            elif isinstance(node, ast.FunctionDef):
                if node.name == "_transform":
                    node.args.args[1].annotation.id = "features.InputTypeJIT"
                    node.returns.id = "features.InputTypeJIT"
        source = ast.unparse(tree)

        # Write the source to a temp file; needed for JIT's inspect calls to work properly
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp:
            temp.write(source)
            filename = temp.name

        # Compile the modified class from source
        code = compile(source, filename, "exec")
        mod = {}
        exec(code, vars(inspect.getmodule(cls)), mod)
        cls = next(iter(mod.values()))

        # Instantiate the transform
        transform = cls(*args, **kwargs)

        # Patch forward
        if hasattr(transform, "_jit_forward"):
            # Use the one defined in the class if available
            transform.forward = transform._jit_forward
        else:
            # Use the default implementation
            setattr(transform, "forward", types.MethodType(JITWrapper.__default_jit_forward, transform))

        self._wrapped = transform

    @staticmethod
    def __default_jit_forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        params = self._get_params(inputs)
        result = self._transform(inputs, params)
        return result

    def forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        return self._wrapped.forward(inputs)


def assert_jit_scriptable(t, inpt):
    torch.manual_seed(0)
    eager_out = t(inpt)

    t_scripted = torch.jit.script(t)
    with tempfile.NamedTemporaryFile(delete=False) as temp:
        t_scripted.save(temp.name)
        t_scripted = torch.jit.load(temp.name)

    torch.manual_seed(0)
    script_out = t_scripted(inpt)
    torch.testing.assert_close(eager_out, script_out)
    return script_out


img = torch.randn((1, 3, 224, 224))

t = V1.Resize((32, 32))
out1 = assert_jit_scriptable(t, img)
print("T1: OK")

t = JITWrapper(V2.Resize, (32, 32))
out2 = assert_jit_scriptable(t, img)
print("T2: OK")

torch.testing.assert_close(out1, out2)
print("T1 == T2: OK")

The above works on our latest main without modifications:

T1: OK
T2: OK
T1 == T2: OK

This approach can currently support only a handful of simple Transforms: those that don't require overwriting forward() and that contain most of their logic inside their _get_params() and _transform() methods. Many otherwise simple transforms are still not supported because they inherit from _RandomApplyTransform, which makes the random call in its forward() (this could be refactored to move it into _get_params(); see the sketch below). The rest of the existing inference transforms can be supported by addressing points 3 & 4 above.
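
As a rough sketch of that refactoring (hypothetical code, not the actual _RandomApplyTransform), the coin flip would move out of forward() so that the default JIT forward above can drive random transforms as well:

from typing import Any, Dict

import torch


class RandomApplySketch(torch.nn.Module):
    # Hypothetical stand-in for _RandomApplyTransform: the random decision
    # lives in _get_params() instead of forward().

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def _get_params(self, inpt: torch.Tensor) -> Dict[str, Any]:
        # The coin flip happens here instead of in forward()
        return {"apply": float(torch.rand(())) < self.p}

    def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> torch.Tensor:
        # Placeholder transformation; a real subclass would override this
        return inpt.flip(-1) if params["apply"] else inpt

    def forward(self, inpt: torch.Tensor) -> torch.Tensor:
        return self._transform(inpt, self._get_params(inpt))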

The above approach is very over-engineered, brittle, and opaque, because it tries to fix the JIT-scriptability issues without any modifications to the code base for the selected example. If we accept minor refactoring of the existing classes, we can drop the ast logic. We could also avoid the default JIT-compatible forward by explicitly defining such a method on the original class. Here is one potential simplified version that would require changes to our current API:

class JITWrapper(torch.nn.Module):

    def __init__(self, transform: Transform):
        super().__init__()
        # Patch forward
        if hasattr(transform, "_jit_forward"):
            # Use the one defined in the class if available; it should reuse `_get_params` and `_transform`
            transform.forward = transform._jit_forward
        else:
            raise TypeError(f"The {type(transform).__name__} transform doesn't support scripting")

        self._wrapped = transform

    def forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        return self._wrapped.forward(inputs)


class Resize(Transform):
    # __init__() and _get_params() go here

    def _transform(self, inpt: features.InputTypeJIT, params: Dict[str, Any]) -> features.InputTypeJIT:
        # Only the types changed; everything else in the method stays the same
        ...

    def _jit_forward(self, inputs: features.InputTypeJIT) -> features.InputTypeJIT:
        params = self._get_params(inputs)
        result = self._transform(inputs, params)
        return result
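
Usage would then look like this (assuming Resize defines _jit_forward as above):

t = JITWrapper(Resize((32, 32)))
t_scripted = torch.jit.script(t)
out = t_scripted(torch.randn(1, 3, 224, 224))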

Alternatives

There are several other alternatives we could follow. One could be to offer JIT-scriptable versions of a limited number of Transforms that are commonly used during inference. Another could be to make some of our transforms FX-traceable instead of JIT-scriptable. Though not all classes can become traceable (because their behaviour branches based on the input), making the rest compatible would future-proof us for PyTorch 2.
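
For illustration, here is a minimal sketch of the FX route; the transform below is made up, and it is the absence of input-dependent Python branching that makes it traceable:

import torch
import torch.fx


class ScaleTransform(torch.nn.Module):
    # Illustrative transform: pure tensor ops with no data-dependent control
    # flow, so symbolic tracing can capture the whole forward() as a graph
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, inpt: torch.Tensor) -> torch.Tensor:
        return inpt * self.factor


traced = torch.fx.symbolic_trace(ScaleTransform(0.5))
print(traced.graph)  # the captured dataflow graph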

Additional context

No response

cc @vfdev-5 @bjuncek @pmeier

YosuaMichael (Contributor) commented:

@datumbox I think this wrapper approach is quite similar to how detectron2 did it. They have this torchscript_patch.py, which I found via this article about TorchScript.

pmeier self-assigned this Jan 24, 2023

pmeier commented Jan 25, 2023

Depending on how long we want to support scriptability, this might be pretty easy. Assuming that we only support it until v2 is stable and deprecate and remove it together with v1 afterwards, we can simply do:

class MyTransformV2(transforms.Transform):
    def __init__(self, foo, bar):
        super().__init__()
        self.foo = foo
        self.bar = bar

    ...

    def __prepare_scriptable__(self):
        # This hook is called early by `torch.jit.script`. See
        # https://github.com/pytorch/pytorch/blob/a6ac922eabee8fce7a48dedac81e82ac8cfe9a45/torch/jit/_script.py#L1284-L1288
        # https://github.com/pytorch/pytorch/blob/a6ac922eabee8fce7a48dedac81e82ac8cfe9a45/torch/jit/_script.py#L982
        # If this method exists, its return value is used over the original object for scripting.
        return MyTransformV1(self.foo, self.bar)
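
Scripting then works transparently on the v2 object (a sketch, assuming MyTransformV1 is the scriptable v1 counterpart):

t = MyTransformV2(foo=1, bar=2)
# torch.jit.script detects __prepare_scriptable__ and scripts the returned
# MyTransformV1 instance in place of the v2 object.
t_scripted = torch.jit.script(t)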

If we want to support scriptability for longer, we can still use the hook, but we need to return something custom. One option here is, of course, to copy-paste the important v1 code in there.

I'm currently exploring using this hook together with AST rewriting proposed in the top comment to automate that.

NicolasHug (Member) commented:

Thanks Philip, that looks very promising and could potentially remove a major roadblock toward migration to v2. As discussed offline, if the __prepare_scriptable__ hook works out as we think it does, we can consider implementing it on all v2 transforms and keep the v1 transforms around for as long as needed. Possibly in a __deprecated_and_unmaintained_transforms_AVOID_USING_THIS_MODULE module.


pmeier commented Jan 25, 2023

The name has the same energy as the pytest option disable_test_id_escaping_and_forfeit_all_rights_to_community_support 😆


pmeier commented Jan 31, 2023

Closed in #7135.

pmeier closed this as completed Jan 31, 2023