How to mix RF code with PT code #1287

Closed
albertz opened this issue Mar 27, 2023 · 5 comments

albertz commented Mar 27, 2023

For reference, see #1120, #1120 (comment).

RF = import returnn.frontend as rf.
RF code = Code which uses functions/classes from rf (rf.Module, rf.matmul etc).
PT code = pure PyTorch code, just using torch.

It is of high priority that mixing pure PT code with RF code is easy, in both directions: when you have some pure PT code/module, it should be simple to embed some RF code/module in it, and vice versa, when you have some RF code/module, it should be simple to embed some PT code/module in it.

I distinguish a bit between just code (function calls) and modules (rf.Module or torch.nn.Module).

I think plain function calls are probably already straightforward. RF functions get Tensor and Dim as arguments and return Tensor and maybe Dim again. You get the raw torch.Tensor by accessing raw_tensor. You can also easily create a Tensor and Dim on-the-fly. So both directions should be simple.

Example using PT inside RF code:

def rf_func(x: rf.Tensor) -> rf.Tensor:
  y = x.copy_template()  # prepare output Tensor for raw PT call
  y.raw_tensor = torch.nn.functional.elu(x.raw_tensor)  # raw PT call
  return y

Example using RF inside PT code:

def pt_func(x: torch.Tensor) -> torch.Tensor:
  # create RETURNN Tensor for RF
  x_rf = rf.Tensor("x", dims=[Dim(None) for i in range(x.ndim)], dtype=str(x.dtype).split(".")[-1], raw_tensor=x)
  y_rf = rf.elu(x_rf)  # RF call
  return y_rf.raw_tensor  # get raw PT tensor

For modules, it is a bit unclear. rf.Module is similar to torch.nn.Module, but they don't share any common base class. rf.Parameter also is different from torch.nn.Parameter. We maybe could have some automatic rf_module_to_pt_module and vice versa?

Example using PT module inside a RF module:

class Module(rf.Module):
  def __init__(self):
    super().__init__()
    self.submod = pt_module_to_rf_module(torch.nn.TransformerEncoderLayer(...))

  def __call__(self, x: rf.Tensor) -> rf.Tensor:
    # just like rf_func
    x = x.copy_permute(...)  # permute such that it is like what submod expects, e.g. B, T, F
    y = x.copy_template()
    y.raw_tensor = self.submod(x.raw_tensor)
    return y

Example using RF module inside a PT module:

class Module(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.submod = rf_module_to_pt_module(rf.TransformerEncoderLayer(...))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    batch_dim = Dim(None)
    time_dim = Dim(None)
    feat_dim = Dim(...)
    x_rf = rf.Tensor("x", dims=(batch_dim, time_dim, feat_dim), dtype=str(x.dtype).split(".")[-1], raw_tensor=x)
    y_rf = self.submod(x_rf)
    y_rf = y_rf.copy_permute(...)  # to whatever the PT module should return
    return y_rf.raw_tensor

albertz commented Mar 27, 2023

There was also the suggestion by @braddockcg to automatically wrap PT code/modules (see #1120 (comment), #1120 (comment)) in a way that automatically adjusts the signatures to convert between our Tensor and torch.Tensor and adds the Dim args. But I'm pessimistic that there is a good automatic way to do this, and I think it would not really provide much benefit, while adding a lot of complexity for such automatic wrapping, which would anyway not really work in all cases.

patrick-wilken (Contributor) commented:

I think you mixed up the two cases, in the rf.Module the submodule should get the rf.Tensor, right?

For rf_module_to_pt_module, would something like this already work?

class TorchModuleFromFrontendModule(torch.nn.Module):
    def __init__(self, rf_module, input_templates: Dict[str, rf.Tensor]):
        super().__init__()
        self._rf_module = rf_module
        self._input_templates = input_templates  # or is there a way to get this from rf_module?
        for name, parameter in rf_module.named_parameters():
            self.register_parameter(name, parameter.raw_tensor)

    def forward(self, raw_inputs: Dict[str, torch.Tensor]):
        inputs = self._input_templates.copy()
        for name, raw_tensor in raw_inputs.items():
            inputs[name].raw_tensor = raw_tensor
        outputs = self._rf_module(inputs)
        return {name: tensor.raw_tensor for name, tensor in outputs.items()}

I think torch.nn.Module is mostly about passing functions like train() and .cuda() recursively to the submodules and their parameters. Simply registering the parameters at the top level, without submodules, should be enough, I guess?

With that, you can use the Frontend to define torch modules that you can then arbitrarily extend with custom PyTorch code.
This direction is kind of the equivalent of extracting the network dict from an rf.Module in the case of the TF layers backend.

The other direction doesn't sound so useful to me. Well, it would be needed to support models where get_model() returns an rf.Module. But I would tend to keep the torch engine independent of the rf.Module class, except for this:

model = get_model_func()
if isinstance(model, rf.Module):
    model = TorchModuleFromFrontendModule(model)
assert isinstance(model, torch.nn.Module)
self._model = model    

Or as an alternative:

get_model_func = self.config.typed_value("get_model")  # -> rf.Module
if get_model_func:
    self._model = TorchModuleFromFrontendModule(get_model_func())
else:
    get_model_func = self.config.typed_value("get_torch_model")  # -> torch.nn.Module
    self._model = get_model_func()


albertz commented Mar 27, 2023

I think you mixed up the two cases, in the rf.Module the submodule should get the rf.Tensor, right?

In "Example using PT module inside a RF module", the submodule is a PT module, so it gets a normal torch.Tensor.

rf_module_to_pt_module would wrap __call__ to forward, but leave all the args as-is, i.e. it just passes *args, **kwargs around. So it would still expect rf.Tensor arguments. But this shouldn't really be a problem. The same goes for any other member functions. The main thing it would do is handle the parameters, i.e. wrap all rf.Parameter as torch.nn.Parameter. I'm not sure whether it would do the same recursively for any sub modules, or how we need to treat sub modules, or parameters of sub modules. Or maybe just register all deep parameters in the top module. But I'm not sure whether it is allowed to register params named like sub1.param1 (with a dot, because it comes from the sub module). The parameters are what actually matters most for other PT code, e.g. the optimizer. Maybe we also need to handle train() or cuda() somehow, as you say.
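
For illustration, a rough sketch of how such flat registration of deep parameters could look. Everything here is hypothetical (class name, escaping scheme), not an actual implementation; it assumes that the torch backend stores a torch.nn.Parameter as the raw_tensor of each rf.Parameter, and note that register_parameter rejects names containing a dot, so some escaping is needed:

import torch
import returnn.frontend as rf


class RFModuleAsPTModule(torch.nn.Module):
  """Hypothetical wrapper: exposes all deep rf.Parameters of an rf.Module as torch parameters."""

  def __init__(self, rf_module: rf.Module):
    super().__init__()
    self._rf_module = rf_module
    for name, rf_param in rf_module.named_parameters():
      # register_parameter does not allow "." in the name, so escape it (arbitrary scheme here).
      pt_name = name.replace(".", "_")
      assert isinstance(rf_param.raw_tensor, torch.nn.Parameter)  # assumption: torch backend
      self.register_parameter(pt_name, rf_param.raw_tensor)

  def forward(self, *args, **kwargs):
    # Args are passed through as-is, i.e. this still expects rf.Tensor arguments.
    return self._rf_module(*args, **kwargs)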

The other direction doesn't sound so useful to me.

You mean using an RF module inside a PT module? Why is this not relevant? Maybe it's less often used, but for example, when you are testing some existing external PT modules for some experiments and just want to swap out their self-attention implementation for ours with our custom relative positional encoding implementation, or something like that, then you can do that.

In any case, it's basically just the same logic, just reversed. It should be quite straightforward. A small question is what dim tags to use for the rf.Parameter, but it could just create unique new dim tags for every single dim.
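
As a rough sketch of what that reverse direction could look like (again purely hypothetical, just to illustrate creating unique new dim tags per parameter dim; the exact rf.Parameter and Dim signatures might differ):

import torch
import returnn.frontend as rf
from returnn.tensor import Dim


class PTModuleAsRFModule(rf.Module):
  """Hypothetical wrapper: exposes a torch.nn.Module (and its parameters) to RF code."""

  def __init__(self, pt_module: torch.nn.Module):
    super().__init__()
    self._pt_module = pt_module
    for name, pt_param in pt_module.named_parameters():
      # Create unique new dim tags for every single dim of the parameter.
      dims = [Dim(int(d), name=f"{name}_dim{i}") for i, d in enumerate(pt_param.shape)]
      rf_param = rf.Parameter(dims, dtype=str(pt_param.dtype).split(".")[-1])
      rf_param.raw_tensor = pt_param  # share the same underlying torch.nn.Parameter
      setattr(self, name.replace(".", "_"), rf_param)

  def __call__(self, *args, **kwargs):
    # Args are passed through as-is, i.e. this still expects raw torch.Tensor arguments.
    return self._pt_module(*args, **kwargs)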


albertz commented Apr 4, 2023

One other set of functions on modules is weight normalization (weight_norm). It does not simply assign values to a param like param-init does (which is not really a problem with what we proposed here), but actually reassigns the param and creates new params. We have an implementation of weight-norm in RC, and PT also has one. RC was only graph-based so far, so this was a bit simpler. PT uses register_forward_pre_hook to explicitly calculate it for every module forward. We actually have issue rwth-i6/returnn_common#250 about how to unify the RC/RF implementation to support both graph-based and eager-based frameworks.

So the question is: using rf_module_to_pt_module, would the PT weight_norm work correctly on the resulting converted module? Or rather, can we make it work? Is this possible in a not-too-hacky way? Or vice versa, using pt_module_to_rf_module, would the RF weight_norm work correctly on the resulting converted module?

Looking at the PT weight_norm, it actually creates two new parameters, <name>_g and <name>_v (via register_parameter), and then, via register_forward_pre_hook, it registers a function which will do this:

setattr(module, self.name, self.compute_weight(module))

So, could our converted PT module support additional register_parameter calls? And more specifically, register_forward_pre_hook?

The PT code is anyway a bit problematic, as this register_forward_pre_hook really only covers module(...) calls, via __call__, but not other function calls. For most modules, this should be fine, though.
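
For concreteness, this is roughly what that looks like on the PT side (plain torch.nn.utils.weight_norm usage, just for illustration):

import torch

lin = torch.nn.Linear(4, 8)
lin = torch.nn.utils.weight_norm(lin, name="weight")

# "weight" is removed from the parameters and replaced by "weight_g" and "weight_v",
# and a forward pre-hook recomputes lin.weight from them on every lin(...) call,
# i.e. only via __call__.
print(sorted(name for name, _ in lin.named_parameters()))  # ['bias', 'weight_g', 'weight_v']
y = lin(torch.randn(2, 4))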

We could also implement our own __setattr__ to catch the setattr calls. But then what? Would this do the corresponding assignment in the original RF module? Similarly, we can override register_parameter (and related functions). But would that always work?

We could also say that this is just not supported. But then it would be good if we could detect such usage instead of just silently ignoring it.

Similar to weight-norm is also weight-dropout, or basically any transformations on the weights.

Maybe custom Python descriptors can be useful?
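
To make the descriptor idea a bit more concrete, a toy sketch (not tied to RF or to the converted-module wrapper at all, purely to illustrate computing the effective weight lazily on attribute access):

import torch


class ComputedWeight:
  """Python descriptor: recomputes the effective weight on every attribute access."""

  def __init__(self, g_name: str, v_name: str):
    self.g_name, self.v_name = g_name, v_name

  def __get__(self, module, objtype=None):
    g = getattr(module, self.g_name)
    v = getattr(module, self.v_name)
    return g * v / v.norm()  # weight-norm-like transform, computed on access


class MyModule(torch.nn.Module):
  weight = ComputedWeight("weight_g", "weight_v")  # class-level descriptor

  def __init__(self):
    super().__init__()
    self.weight_g = torch.nn.Parameter(torch.ones(()))
    self.weight_v = torch.nn.Parameter(torch.randn(8, 4))

  def forward(self, x):
    return torch.nn.functional.linear(x, self.weight)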

albertz added a commit that referenced this issue Apr 8, 2023
Makes mixing RF code with pure PT code easier.
#1287
albertz closed this as completed in c11c15b on Apr 8, 2023

albertz commented Apr 8, 2023

I have now pushed a simple initial implementation.

I also relaxed rf.convert_to_tensor a bit, such that you can directly call rf.convert_to_tensor(pt_tensor) without needing to specify the dims.
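
For example, something like this should work now (a minimal sketch of the intended usage; details such as the dims that get created automatically are not shown here):

import torch
import returnn.frontend as rf

pt_tensor = torch.randn(3, 5)
x = rf.convert_to_tensor(pt_tensor)  # dims are now created automatically if not specified
y = rf.elu(x)  # RF call
raw = y.raw_tensor  # back to a raw torch.Tensor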
