
Associate FX nodes as origins on IRNode; ensure all nodes going to scheduling have an associated FX node #754

Merged (9 commits) on Aug 11, 2022

Conversation

voznesenskym (Contributor) commented Aug 9, 2022

An example of what these IRNodes look like now, running test_alexnet_prefix_cpu:

StorageBox(
  ComputedBuffer(name=None, layout=FlexibleLayout('cpu', torch.float32, size=[s3, 64, 55, 55], stride=[193600, 3025, 55, 1]), data=Pointwise(
    'cpu',
    torch.float32,
    relu(load(buf0, 193600*i0 + 3025*i1 + 55*i2 + i3, False) + load(primals_1, i1, False)),
    ranges=[s3, 64, 55, 55],
    origins={relu_default}
  ))
)
StorageBox(
  ComputedBuffer(name=None, layout=FlexibleLayout('cpu', torch.float32, size=[s3, 64, 27, 27], stride=[46656, 729, 27, 1]), data=Pointwise(
    'cpu',
    torch.float32,
    maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 112, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 111, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 110, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 57, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 56, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 55, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 2, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 1, False), load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3, False))))))))),
    ranges=[s3, 64, 27, 27],
    origins={max_pool2d_with_indices_default}
  ))
)
Convolution(
  name=buf0,
  layout=FixedLayout('cpu', torch.float32, size=[s3, 64, 55, 55], stride=[193600, 3025, 55, 1]),
  inputs=[InputBuffer(name='primals_3', layout=FixedLayout('cpu', torch.float32, size=[s3, s1, s4, s4], stride=[s1*s4**2, s4**2, s4, 1])), InputBuffer(name='primals_2', layout=FixedLayout('cpu', torch.float32, size=[s0, s1, s2, s2], stride=[s1*s2**2, s2**2, s2, 1]))],
  constant_args=(None, (4, 4), (2, 2), (1, 1), False, (0, 0), 1),
  output_view=None,
  origins={convolution_default}
)
ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[s3, 64, 55, 55], stride=[193600, 3025, 55, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  relu(load(buf0, 193600*i0 + 3025*i1 + 55*i2 + i3, False) + load(primals_1, i1, False)),
  ranges=[s3, 64, 55, 55],
  origins={relu_default}
))
ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[s3, 64, 27, 27], stride=[46656, 729, 27, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 112, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 111, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 110, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 57, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 56, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 55, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 2, False), maximum(load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3 + 1, False), load(buf1, 193600*i0 + 3025*i1 + 110*i2 + 2*i3, False))))))))),
  ranges=[s3, 64, 27, 27],
  origins={max_pool2d_with_indices_default}
))
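Each buffer in the dump above now carries an origins={...} set naming the FX node(s) it was lowered from, which is what lets scheduling map a buffer back to the graph. A minimal sketch of consuming that field, using a hypothetical Buffer stand-in rather than Inductor's actual IR classes (real IRNodes hold fx.Node objects; strings are used here to stay self-contained):

from dataclasses import dataclass, field
from typing import Set


@dataclass
class Buffer:
    # Hypothetical stand-in for an Inductor buffer IR node.
    name: str
    origins: Set[str] = field(default_factory=set)


def report_origins(buffers):
    # Map each scheduled buffer back to the FX node(s) it came from,
    # mirroring the origins={...} fields printed in the dump above.
    for buf in buffers:
        print(f"{buf.name}: {', '.join(sorted(buf.origins)) or '<missing origin>'}")


report_origins([
    Buffer("buf1", {"relu_default"}),
    Buffer("buf2", {"max_pool2d_with_indices_default"}),
])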

@@ -201,7 +201,13 @@ def is_triton(device):


class IRNode(object):
Contributor:

Could we just put everything on the IRNode base class? Having things scattered on specific node constructors doesn't seem great.

Something like:

class IRNode:
    _current_origins = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[fx.Node]):
        old = IRNode._current_origins
        IRNode._current_origins = old | origins
        try:
            yield
        finally:
            IRNode._current_origins = old

    def __init__(self):
        self.origins = set(self._current_origins)

Then we could just throw a:

with IRNode.current_origins({n}):
   ...

Around the entire body of run_node(n).
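Putting the pieces together, a runnable sketch of that proposal (string origins and a free-standing run_node are illustrative stand-ins for fx.Node and Inductor's actual lowering loop):

import contextlib
from typing import Set


class IRNode:
    _current_origins: Set[str] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[str]):
        old = IRNode._current_origins
        IRNode._current_origins = old | origins
        try:
            yield
        finally:
            # Restore the previous origins even if lowering raises.
            IRNode._current_origins = old

    def __init__(self):
        # Snapshot whatever origins are active at construction time.
        self.origins = set(IRNode._current_origins)


def run_node(n: str) -> IRNode:
    # Every IRNode built inside the `with` block records `n` as an origin,
    # with no per-constructor bookkeeping required.
    with IRNode.current_origins({n}):
        return IRNode()


print(run_node("relu_default").origins)  # {'relu_default'}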

Contributor Author:

Yeah, I thought about it, but I wasn't sure it was right for every IRNode to have this property. Is there any case where it would be nonsensical for a subclass of IRNode to have origins?

Contributor:

No, every node can have origins.

@@ -28,6 +28,8 @@
prims = torch.ops.prims
needs_realized_inputs = set()

current_origin = None
Contributor:

Unused?

Contributor Author:

Yes.

@voznesenskym voznesenskym marked this pull request as ready for review August 11, 2022 18:30