
Support recursive data type in TorchScript #42487

Open
yf225 opened this issue Aug 3, 2020 · 2 comments
Labels: high priority, oncall: jit, triaged

Comments

@yf225
Contributor

yf225 commented Aug 3, 2020

🚀 Feature

It would be really awesome to support recursive data types in TorchScript. For example:

import torch
from typing import Dict

class TypedDataDict(object):
  def __init__(self):
    self.str_to_dict: Dict[str, 'TypedDataDict'] = {}

  def set_str_to_dict(self, value: Dict[str, 'TypedDataDict']):
    self.str_to_dict = value


class TestModule(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, input):
    return TypedDataDict().set_str_to_dict({"123": TypedDataDict()})


m = TestModule()
m_scripted = torch.jit.script(m)
m_scripted(torch.tensor(1.))

'''
Currently throws:

RuntimeError: 
Assignment to attribute 'str_to_dict' cannot be of a type that contains class '__torch__.TypedDataDict'.
Classes that recursively contain instances of themselves are not yet supported:
  File "test_yf225.py", line 6
  def __init__(self):
    self.str_to_dict: Dict[str, 'TypedDataDict'] = {}
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'TypedDataDict.__init__' is being compiled since it was called from '__torch__.TypedDataDict'
  File "test_yf225.py", line 17
  def forward(self, input):
    return TypedDataDict().set_str_to_dict({"123": TypedDataDict()})
           ~~~~~~~~~~~~~ <--- HERE
'__torch__.TypedDataDict' is being compiled since it was called from 'TestModule.forward'
  File "test_yf225.py", line 17
  def forward(self, input):
    return TypedDataDict().set_str_to_dict({"123": TypedDataDict()})
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'''

Motivation

Many DPER3 modules take inputs of nested data types like Dict[str, Dict[str, Dict[str, torch.Tensor]]], and it would be really hard to maintain if we had to add support for every variation of those nested data types to TypedDataDict (DPER3's typed dictionary class). But if recursive data types were supported, we would only need Dict[str, TypedDataDict] and Dict[str, torch.Tensor] to cover all nesting possibilities.

Having recursive data type support will greatly speed up the work of moving all PyPer models to 100% TorchScript, demonstrating TorchScript's production readiness for large-scale ranking models.
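To make the maintenance burden concrete, here is a small sketch (using float as a stand-in for torch.Tensor, with hypothetical alias names) of what covering each nesting depth by hand looks like, versus the single self-referential type that recursive support would allow:

```python
from typing import Dict

# Without recursive types, every nesting depth needs its own alias (or its
# own method variant on TypedDataDict) -- this does not scale past a few levels.
Depth1 = Dict[str, float]        # stand-in for Dict[str, torch.Tensor]
Depth2 = Dict[str, Depth1]       # Dict[str, Dict[str, Tensor]]
Depth3 = Dict[str, Depth2]       # Dict[str, Dict[str, Dict[str, Tensor]]]

# With recursive types, one self-referential class would cover any depth:
# class TypedDataDict:
#     str_to_dict: Dict[str, 'TypedDataDict']
#     str_to_tensor: Dict[str, torch.Tensor]

d: Depth3 = {"a": {"b": {"c": 1.0}}}
```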

cc. @wanchaol @suo

cc @ezyang @gchanan @zou3519 @suo @gmagogsfm

@yf225 added the oncall: jit label Aug 3, 2020
@github-actions bot added this to Need triage in JIT Triage Aug 3, 2020
@SplitInfinity moved this from Need triage to HIGH PRIORITY in JIT Triage Aug 3, 2020
@SplitInfinity added the triaged label Aug 3, 2020
@gmagogsfm gmagogsfm assigned suo and unassigned gmagogsfm Aug 11, 2020
@gmagogsfm
Contributor

Discussed with @suo offline, we will need to sync with yf225 more to figure out a solution.

@suo suo removed this from HIGH PRIORITY in JIT Triage Aug 17, 2020
@0xJchen

0xJchen commented Dec 29, 2021

Discussed with @suo offline, we will need to sync with yf225 more to figure out a solution.

Hi, guys. I wonder if there are any workarounds to avoid this problem?

I found that recursive data types are common in many algorithms (as mentioned in PEP 484). Take tree search as an example: we have a Node class, and each node also has children (say, a mapping from index to child nodes), self.children: Dict[int, 'Node'] = {}. I hit the above problem when trying to convert a tree-search algorithm to TorchScript, and I wonder if there is any way to avoid it?

This is an example:

import torch
from typing import Dict, List

@torch.jit.script
class Node(object):

    def __init__(self, prior: float):
        self.prior = prior
        self.children: Dict[int, 'Node'] = {}

    def expanded(self) -> bool:
        return len(self.children) > 0

    def expand(self, priors: List[float]):
        for id, prior in enumerate(priors):
            self.children[id] = Node(prior)

Also, I wonder: if I wrote a Cython script to manage the above tree-search logic (I can directly call those Cython functions after compilation from a normal Python script), could I still reuse them in TorchScript?
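For what it's worth, one workaround for this class of problem (a sketch, not something proposed in this thread; Tree and its methods are hypothetical names) is to store all nodes in a flat table and refer to children by integer index, so that no class ever contains an instance of its own type. Shown in plain Python for brevity; in TorchScript you would decorate the class with @torch.jit.script:

```python
from typing import Dict, List

class Tree:
    """Index-based tree: node i's data lives at position i of each list."""

    def __init__(self):
        self.priors: List[float] = []             # per-node data
        self.children: List[Dict[int, int]] = []  # action id -> child node index

    def add_node(self, prior: float) -> int:
        # Append a new leaf and return its index.
        self.priors.append(prior)
        self.children.append({})
        return len(self.priors) - 1

    def expand(self, node: int, priors: List[float]) -> None:
        # Create one child per action, keyed by action id.
        for action, prior in enumerate(priors):
            self.children[node][action] = self.add_node(prior)

    def expanded(self, node: int) -> bool:
        return len(self.children[node]) > 0
```

The trade-off is that callers pass integer node handles around instead of Node objects, but the types involved (List, Dict, int, float) are all TorchScript-friendly.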

cc @yf225 @suo @gmagogsfm

Projects
None yet
Development

No branches or pull requests

6 participants