
Support unpacking python dictionary in torch.jit.trace() #81623

Closed

Conversation

@tangleintel (Contributor) commented Jul 18, 2022

Support unpacking python dictionary in torch.jit.trace()

Problem statement & Motivation

Problem 1 (usability):

Say you have a model whose forward method is defined as follows:
def forward(self, key1=value1, key2=value2, key3=value3)
And you have a dataset in which each data point is a Python dict:
data = {key1:value1, key3:value3, key2:value2}

The problem is that if you want to trace the model with the dict data from the given dataset, you must manually unpack the dictionary, reorder its values, and build a tuple data_tuple = (value1, value2, value3) to use as the example_inputs parameter of torch.jit.trace(). This marshalling process is not user friendly.
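Concretely, the workaround today looks like the following minimal sketch (model, value1, value2, and value3 are the illustrative names from above, not real objects):

```python
# Hypothetical illustration of the manual marshalling Problem 1 describes.
data = {'key1': value1, 'key3': value3, 'key2': value2}

# The user has to know forward()'s declaration order and rebuild a tuple:
data_tuple = (data['key1'], data['key2'], data['key3'])
traced = torch.jit.trace(model, data_tuple)
```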

Problem 2 (feasibility):

Say you have a model whose forward method is defined as follows:
def forward(self, key1=None, key2=None, key3=None) -> The default value is None
And you have a dataset in which each data point is a Python dict:
data = {key1:value1, key3:value3} -> Only part of the values required by forward() are given; the rest use the default value.

The problem is that tracing the model with this dict data is not feasible at all, because you can pass neither a tuple T1 = (value1, value3) nor T2 = (value1, None, value3). T1 mismatches value3 with key2, and T2 includes a None, which is blocked by the tracer's type checking. (Of course you can pass T3 = (value1,) to make the trace function finish without an exception, but the traced model you get is probably not what you expect, since different inputs may produce different traced results.)
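The two failing attempts, as a minimal sketch (model, value1, and value3 are the illustrative names from above):

```python
# T1 traces, but silently binds value3 to key2 instead of key3:
T1 = (value1, value3)
traced = torch.jit.trace(model, T1)

# T2 is rejected: the tracer's type checking does not accept None inputs:
T2 = (value1, None, value3)
traced = torch.jit.trace(model, T2)
```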

These problems come up with HuggingFace's PyTorch models, especially in text-classification tasks with datasets such as MRPC, MNLI, etc.

Solution

To address these two issues, we propose supporting a new type, the Python dict, as the example_inputs parameter for torch.jit.trace(). We can use the runtime type information of the example_inputs object to determine whether to fall back to the original tuple path or take the new dictionary path. Both problem 1 and problem 2 can then be solved by utilizing the "**" operator.
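With that in place, Problem 2 traces directly; a minimal sketch (the module and value names are the illustrative ones from above):

```python
# The partial dict is passed as-is; key2 keeps its default value (None).
data = {'key1': value1, 'key3': value3}
traced = torch.jit.trace(model, data, strict=False)

# The traced model is then called with keyword arguments:
out = traced(**data)
```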

Limitation & Mitigation

  1. If we use a dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (We will probably change the order of the input parameters' debug names in the TorchScript IR, so we cannot assume the traced model's input parameter order is the same as the original model's.) We should also highlight this in the documentation to mitigate the problem.

    For example:

```python
# Fetch a data point from the dataloader; the data is a dictionary
# like {key1: value1, key3: value3, key2: value2}, and forward() is
# def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# Use the dictionary to trace the model. The IR will now be
# graph(%self : __torch__.module.___torch_mangle_n.Mymodule,
#       %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)
jit_model = torch.jit.freeze(jit_model)

# It's OK to use a dict as the parameter for the traced model:
jit_model(**example_inputs_dict)

example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order:
jit_model(*example_inputs_tuple)
```

Note

  1. This PR will make some UTs introduced in #39601 fail; under our solution those cases should be classified as unpacking a tuple containing a single dictionary element.
  2. I think there is some ambiguity: torch.jit.trace()'s documentation currently only specifies passing a tuple or a single Tensor as the example_inputs parameter, yet it seems a dictionary can still be passed.

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @mingfeima @XiaobingSuper @ashokei @jingxu10

@facebook-github-bot (Contributor) commented Jul 18, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit bf3e897 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@tangleintel (Contributor, Author)

Hi,
@davidberard98
@jamesr66a
@wanchaol
@xw285cornell
@gmagogsfm
@suo
@zdevito
Could you pls help review this PR? Thx!

@yanbing-j yanbing-j added the intel This tag is for PR from Intel label Jul 18, 2022
@bdhirsh bdhirsh requested a review from eellison July 20, 2022 01:57
@bdhirsh bdhirsh added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 20, 2022
@yanbing-j yanbing-j added the intel priority matters to intel architecture from performance wise label Jul 27, 2022
@tangleintel tangleintel changed the title Proposal for a more general usage of the parameter "example_inputs" in torch.jit.trace() Proposal for an automatic parameter matching mechanism for "example_inputs" argument in torch.jit.trace(). Aug 10, 2022
@tangleintel tangleintel changed the title Proposal for an automatic parameter matching mechanism for "example_inputs" argument in torch.jit.trace(). Support unpacking python dictionary in **torch.jit.trace()** Aug 11, 2022
@tangleintel tangleintel changed the title Support unpacking python dictionary in **torch.jit.trace()** Support unpacking python dictionary in torch.jit.trace() Aug 11, 2022
@yanbing-j (Collaborator)

Hi @eellison , could you please help review this PR?

@eellison (Contributor)

Hi, thank you for the PR! I think traditionally we haven't done this because if you change the keys it might silently succeed without error. Would it be possible to add something similar? Could you add a test?

Will do more thorough review tomorrow.

@tangleintel (Contributor, Author) commented Aug 24, 2022

> Hi, thank you for the PR! I think traditionally we haven't done this because if you change the keys it might silently succeed without error. Would it be possible to add something similar? Could you add a test?
>
> Will do more thorough review tomorrow.

Hi, @eellison
Thanks for your reply! Can you be more specific about "change the keys" and the corresponding case that would silently succeed without error? If you mean changing the key names, I think the tracer will report an error message saying the name does not match the argument names, since we capture and record forward()'s argument names. If you mean changing the keys' (serialization) order, I am not sure whether that is the same situation I noted in the Limitation & Mitigation section.
I have added one UT corresponding to problem 2 mentioned above. I can add more if you think it is not informative enough.

@tangleintel (Contributor, Author)

Hi, @eellison @davidberard98
Any comments about this PR?

@eellison (Contributor)

Hi, do you mind adding a test like the following?

```python
def foo(x, y):
    return x + y

out = torch.jit.trace(foo, {'x': torch.rand([2]), 'y': torch.rand([2])})
# the following should fail
out(**{'z': torch.rand([2]), 'y': torch.rand([2])})
```


@tangleintel (Contributor, Author) commented Sep 5, 2022

> Hi, do you mind adding a test like the following?

```python
def foo(x, y):
    return x + y

out = torch.jit.trace(foo, {'x': torch.rand([2]), 'y': torch.rand([2])})
# the following should fail
out(**{'z': torch.rand([2]), 'y': torch.rand([2])})
```

@eellison
Sure, I have added the example you mentioned above. Please check it (the UT).

@tangleintel (Contributor, Author) commented Sep 10, 2022

Hi, @eellison

```python
def test_dictionary_as_example_inputs_for_jit_trace(self):
    class TestModule_v1(torch.nn.Module):
        def __init__(self):
            super(TestModule_v1, self).__init__()

        def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
            return key1 + key2 + key3

    class TestModule_v2(torch.nn.Module):
        def __init__(self):
            super(TestModule_v2, self).__init__()

        def forward(self, x, y):
            return x + y

    model_1 = TestModule_v1()
    model_2 = TestModule_v2()
    value1 = torch.ones(1)
    value2 = torch.ones(1)
    value3 = torch.ones(1)
    example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
    traced_model_1 = torch.jit.trace(model_1, example_input_dict, strict=False)
    traced_model_2 = torch.jit.trace(model_2, {'x': torch.rand([2]), 'y': torch.rand([2])})
    res_1 = traced_model_1(**example_input_dict)
    self.assertEqual(res_1, 3 * torch.ones(1))  # Positive
    with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
        res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})  # Negative
    with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
        res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})  # Negative
```

The code example above is the complete unit test I added for this feature. The TestModule_v1 class shows the positive example, while TestModule_v2 shows the negative example, which I think is what you meant. Is that right?

@tangleintel (Contributor, Author)

Hi, @eellison
Sorry for pinging again. May I know when the UT and the PR can be reviewed?

@davidberard98 (Contributor)

@pytorchbot rebase -s

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased support_dict_for_jit_trace onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout support_dict_for_jit_trace && git pull --rebase)

@pytorch-bot (bot) commented Sep 29, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/81623

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 073861c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor)

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 (Contributor) left a comment

Hey @tangleintel , sorry this has taken so long to review. I appreciate the contribution and this does look like a very helpful feature!

I added a few requests inline. In particular I think this will break some use cases (the ones that you already had to patch in the tests). I think it would be better if we guard this behavior with an extra parameter to avoid breaking existing use cases.

Comment on lines 2359 to 2362
```python
example_input = list()
example_input.append(input_map)
example_input = tuple(example_input)
traced_model = torch.jit.trace(model, example_inputs=example_input)
```
@davidberard98 (Contributor):

Can we avoid breaking this type of use case? This same usage would break for other users with similar usage patterns (i.e. trace(model, example_inputs={'forward': dict})).

Instead, can we provide a flag that disables the dict -> tuple(dict) conversion? e.g. torch.jit.trace(model, inputs, unpack_input_dict=True). And unpack_input_dict would default to False.

@tangleintel (Contributor, Author):

Hi, @davidberard98
Thanks for your review!
For this use case, I personally think it is not appropriate. My pushback comes from the two points below:

  1. The current PyTorch documentation for torch.jit.trace() specifies that the example_inputs parameter should be a tuple or a torch.Tensor. I'm not sure whether these two types are a requirement or just a suggestion. If they are a requirement, this usage shouldn't have been possible in the first place; at the least, this undocumented behavior shouldn't be viewed as a standard.
  2. Our solution addresses how to pass variable positional args and kwargs to torch.jit.trace() for a given module. In this respect, a tuple is used to pack positional args while a dict is used to pass keyword arguments, and the dict used in this example should be viewed as an element of a tuple, since the tracer supports tuples with nested collection types. So our dict is not at the same level as this one.

For your suggestion of an additional argument to preserve backward compatibility:

  1. I'm not sure it is important for PyTorch to preserve this undocumented behavior behind such a flag.
  2. With this API design, we would need to make users aware that a dict carries two substantially different meanings, and that the semantics of the example_inputs parameter depend on another parameter, namely unpack_input_dict. In my opinion, keeping the semantics of different parameters as orthogonal as possible gives users a cleaner API.
  3. Inside the tracer, we would have to handle both the single-tensor case and the dict case. If we give users a unified front-end interface, perhaps we don't have to deal with such corner cases.

Anyway, the above is just my opinion, and the decision on how to design the interface is up to the PyTorch team.

@davidberard98 (Contributor):

@tangleintel I definitely see your point; however I discussed with some other pytorch maintainers who agree with me - i.e. even though this isn't documented, torch.jit.trace is widely used and we'd prefer to avoid breaking use cases that depend on the current behavior.

We can probably deprecate the old behavior with dict inputs, e.g. see the 180 days + 2 releases policy. However I'd probably need your help pushing an update after 180 days, since I'd probably forget :)

@tangleintel (Contributor, Author):

@davidberard98
OK, then let's follow your suggestion to maintain backward compatibility and follow PyTorch's BC-breaking rules to mark the undocumented behavior as deprecated. Is that OK with you?
BTW, according to the BC-breaking rules, if we want to deprecate an old behavior, we need to specify in a warning message the PyTorch version in which the deprecation first takes effect. So, which version do you think this feature can land in (and the deprecation takes effect at least two versions later, right? I put 2.1 in the updated patch's warning message)?
Yes, we will keep tracking this feature and provide the update when the time is right.
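A minimal sketch of the kind of deprecation warning being discussed (the wording and the 2.1 target are illustrative, not the final text):

```python
import warnings

# Illustrative only: the exact message and target version were still under
# discussion in this thread.
warnings.warn(
    "Passing a dict as example_inputs to torch.jit.trace() currently means "
    "a tuple containing a single dict. This behavior is deprecated and is "
    "planned to change in PyTorch 2.1, when dicts will be unpacked as "
    "keyword arguments instead.",
    FutureWarning,
)
```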

torch/csrc/jit/python/script_init.cpp (outdated, resolved)
Comment on lines 93 to 116
```cpp
// The argument_names parameter is parsed in Python and its order
// matches the arguments' declaration order in the forward() method.
// These names shall be added to the graph as debug names, and the order
// should align with the traceable stack we generated from the Python dict.
std::vector<std::string> reordered_argument_names;
for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
  for (size_t i = 0; i < argument_names.size(); i++) {
    if (py::cast<std::string>(it->first) == argument_names[i]) {
      reordered_argument_names.push_back(argument_names[i]);
      break;
    }
  }
}
```
@davidberard98 (Contributor):

Could we reorder the inputs instead of the argument names? This might make more sense to people when they trace a model, so that their arguments don't get reordered compared to the method definition.

@tangleintel (Contributor, Author):

Hi, @davidberard98
I think we have to do this reordering.
According to this code snippet:

```cpp
// When enough argument name hints are provided, use them as debug names
// for traced function/modules.
// Here argument_names is allowed to have more names than needed because
// some arguments may have valid default values, therefore they don't need
// example inputs.
if (argument_names.size() >= inputs.size()) {
  for (size_t i = 0, e = inputs.size(); i < e; ++i) {
    IValue& input = inputs[i];
    input = addInput(
        state,
        input,
        input.type(),
        state->graph->addInput(argument_names[i]));
  }
} else {
  for (IValue& input : inputs) {
    input = addInput(state, input, input.type(), state->graph->addInput());
  }
}
```

PyTorch sets the debug names of the traced TorchScript inputs according to the real inputs' order; in fact, I think what matters is the number of real inputs rather than their order. Using the example from the UT and from Problem 2 in the RFC's first section, there were missing arguments, which caused the inputs and argument_names to mismatch. So one intention here is to compact argument_names so they align with the inputs in length (I think it is impossible, or at least I haven't come up with a decent method, to expand the inputs with the default values of the missing args), and the other is to reorder.
The debug names are used in _check_trace() after the tracing procedure.
This is more or less a workaround for the issue. Do you have any suggestions?

@davidberard98 (Contributor):

Hmm, I think I understand... Can you do something like this instead?

```cpp
for (size_t i = 0; i < argument_names.size(); ++i) {
  if (inputs_dict.contains(argument_names[i])) {
    reordered_argument_names.push_back(argument_names[i]);
  }
}
```

Or am I misunderstanding the problem?

@tangleintel (Contributor, Author):

Hi, @davidberard98
This way does not reorder the arguments, it just compacts them. But I think it makes sense to users, so I followed it; it also requires modifying the traced inputs at the same time.

Comment on lines 977 to 1035
```python
if isinstance(example_inputs, dict):
    # Raise an exception when the user-provided key names are not aligned
    # with the forward() method's argument names.
    for key in example_inputs:
        if key not in argument_names:
            valid_arguments = "[" + ','.join(argument_names) + "]"
```
@davidberard98 (Contributor):

also commented above, but I think it would be better if this behavior was guarded by an additional parameter passed into torch.jit.trace.

@davidberard98 (Contributor):

Also, please add documentation for this flag to the trace_method documentation.

And, if you want to expose this flag in torch.jit.trace(), I think it would be preferable to also support unpacking python dicts in torch.jit.trace for functions as well. But if you don't want to add that, I think it's fine; in that case, don't add the flag to torch.jit.trace.

@tangleintel (Contributor, Author):

Yes, once we reach agreement on the API design, I will make the corresponding modifications and update the documentation.

```python
value1 = torch.ones(1)
value2 = torch.ones(1)
value3 = torch.ones(1)
example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
```
@davidberard98 (Contributor):

just curious, how does it pick up on the fact that key4, key5, key6 are not needed? Is it based on the presence of a default value?

@tangleintel (Contributor, Author):

Yes, it is.
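For illustration, a minimal sketch (not the PR's code) of how arguments with default values can be identified from forward()'s signature, using the TestModule_v1 class from the unit test above:

```python
import inspect

import torch

class TestModule_v1(torch.nn.Module):
    def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
        return key1 + key2 + key3

# Arguments that declare a default value do not need an example input.
sig = inspect.signature(TestModule_v1().forward)
has_default = [
    name for name, p in sig.parameters.items()
    if p.default is not inspect.Parameter.empty
]
print(has_default)  # ['key2', 'key3', 'key4', 'key5', 'key1', 'key6']
```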

@davidberard98 (Contributor)

Also, there's a failing lint job; please fix that (I know you couldn't see it earlier because the check required an approval to run).

@facebook-github-bot (Contributor)

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 (Contributor)

@tangleintel fyi in case you have this issue again with broken tests on master, you can instead rebase on top of the viable/strict branch, which points to the latest commit for which all the tests have passed.

@tangleintel (Contributor, Author)

> @tangleintel fyi in case you have this issue again with broken tests on master, you can instead rebase on top of the viable/strict branch, which points to the latest commit for which all the tests have passed.

Thx @davidberard98 ! I just rebased.

@facebook-github-bot (Contributor)

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 (Contributor) left a comment

looks good!

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 15, 2022
@davidberard98 (Contributor)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions (bot)

Hey @tangleintel.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@tangleintel (Contributor, Author)

@pytorchbot label jit

@pytorch-bot (bot) commented Nov 17, 2022

Didn't find following labels among repository labels: jit

@tangleintel (Contributor, Author)

@pytorchbot label module: python frontend

@pytorch-bot (bot) commented Nov 17, 2022

Didn't find following labels among repository labels: module:,frontend

@pytorch-bot pytorch-bot bot added python Pull requests that update Python code release notes: jit release notes category labels Nov 17, 2022
neggles pushed a commit to neggles/pytorch that referenced this pull request Mar 9, 2023: Support unpacking python dictionary in torch.jit.trace() (pytorch#99). The commit message restates the PR description above.
Pull Request resolved: pytorch#81623
Approved by: https://github.com/davidberard98

Co-authored-by: tangleintel <lei1.tang@intel.com>
Labels
ciflow/trunk · cla signed · intel priority · intel · Merged · oncall: jit · open source · python · release notes: jit · triaged
9 participants