Unsqueeze op #1236
Conversation
Codecov Report

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1236      +/-   ##
==========================================
+ Coverage   84.41%   84.49%    +0.07%
==========================================
  Files         549      563       +14
  Lines       61952    63340     +1388
==========================================
+ Hits        52295    53517     +1222
- Misses       9657     9823      +166
==========================================

View full report in Codecov by Sentry.
Thanks for adding this new op!
There are minor issues that need to be cleaned up before merging. I haven't verified whether all cases are covered; it seems the implementation was not trivial.
burn-import/src/onnx/from_onnx.rs
//this is an extremely hacky temporary solution while I figure out how to properly handle this
//situation
Will this be handled in this PR or later? If it isn't handled in this PR, we should replace it with a TODO with instructions.
Also, please explain the context for why this is needed and any potential issues it might cause down the road.
Down the road, the primary issue is build speed for larger graphs. I'm working on a better solution in #1296.
The new function is necessary because the axes to insert at aren't available unless the rhs of an unsqueeze op is constant. When it isn't, the only ways to determine what the arguments of unsqueeze should be are:
- from the output shape, if it's explicit in the graph, though in that case it's better to just remap to a reshape and avoid runtime inference of a shape we already know; or
- at runtime, which isn't covered here because (correct me if I'm wrong) I don't think burn supports runtime inference of shapes yet.
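To illustrate the constant-rhs case: when the axes input is a known constant, the output shape can be computed at import time from the input shape alone. This is a hedged sketch, not burn-import's actual code; the function name `unsqueeze_output_shape` is hypothetical, and it assumes the axes values are unique (as the ONNX spec intends) and normalized against the output rank:

```rust
// Hypothetical sketch: given a constant `axes` input and the input shape,
// compute the unsqueezed output shape following ONNX Unsqueeze semantics.
// Assumes unique axes; negative axes count from the end of the OUTPUT rank.
fn unsqueeze_output_shape(input_shape: &[usize], axes: &[i64]) -> Vec<usize> {
    let out_rank = input_shape.len() + axes.len();
    // Normalize negative axes relative to the output rank, then sort.
    let mut norm: Vec<usize> = axes
        .iter()
        .map(|&a| if a < 0 { (a + out_rank as i64) as usize } else { a as usize })
        .collect();
    norm.sort_unstable();
    let mut out = Vec::with_capacity(out_rank);
    let mut src = input_shape.iter();
    for d in 0..out_rank {
        if norm.binary_search(&d).is_ok() {
            out.push(1); // inserted singleton axis
        } else {
            out.push(*src.next().unwrap()); // carried over from the input
        }
    }
    out
}
```

With the shape known, the op can then be emitted directly as a reshape, which is the remapping described above.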
Added @nathanielsimard as a reviewer for the new tensor op.
Just commented on the added op, which should be simplified.
/// }
/// ```
pub fn unsqueeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
    let mut new_dims = [1; D2];
I feel like this can be implemented in a simpler way:
let mut new_dims = [1; D2];
let mut counter = 0;
for i in 0..D2 {
    if !dims.contains(&(i as isize)) {
        new_dims[i] = old_dims[counter];
        counter += 1;
    }
}
tensor.reshape(new_dims);
No allocation, and contains is extremely fast on small slices (fewer than 10 elements).
I don't think that would work, but there might be a way to avoid the allocation.
The reason I say that is that dims can have negative values, which need to be converted to usize.
I'm using a Vec because slices don't carry a compile-time length, and I haven't figured out a way to infer the length otherwise.
I could avoid the allocation by making the second argument &'op mut [isize] and then mutating the new dims in place.
I just realized that the dims can also contain duplicates. The ONNX spec never specified that the values had to be unique, so I wrote it so that there could be duplicates.
Essentially, the thinking was that unsqueeze_dims is equivalent to doing a series of single-dim unsqueezes, and if you executed those unsqueezes, it wouldn't matter in what order the operations happened: two values of -1 would just result in two axes at the end of the resulting tensor.
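The sequential interpretation described here can be sketched on shapes alone. This is a hypothetical illustration (the function name `unsqueeze_shape` is mine, not burn's API): each axis value is normalized against the current, growing rank, so negative and duplicate values are both well-defined:

```rust
// Hypothetical sketch of unsqueeze_dims as a sequence of single-axis
// unsqueezes applied to a shape. Negative axes count from the end of the
// shape as it stands at that step, so duplicate -1 values stack at the end.
fn unsqueeze_shape(shape: &[usize], dims: &[isize]) -> Vec<usize> {
    let mut out: Vec<usize> = shape.to_vec();
    for &d in dims {
        let rank = out.len() as isize;
        // -1 means "append after the current last axis".
        let idx = if d < 0 { d + rank + 1 } else { d } as usize;
        out.insert(idx, 1);
    }
    out
}
```

For example, under this interpretation a shape [2, 3] with dims [-1, -1] becomes [2, 3, 1, 1], matching the "two axes at the end" behavior described above.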
From my end it looks good.
Thank you for the improvements and OP addition. Burn ONNX is getting better!
Pull Request Template
Checklist
The run-checks all script has been executed.
Related Issues/PRs
None, just adding unsqueeze_dims to get us closer to full ONNX support.
Changes
Adds a new function unsqueeze_dims to burn-tensor/src/tensor/api/base, which takes a slice of isizes as the second argument (to make it compatible with the ONNX Unsqueeze op). For burn-import, if the rhs of the op node is constant, the output shape is calculated; if it isn't and the output shape already has an explicit value, the op is replaced with a reshape.
It might be desirable to just remap to reshape in every case where the output shape is explicit, and/or to truncate multiple unsqueezes to a single unsqueeze.
As we were discussing on the Discord a while ago, this implementation doesn't yet support the third case, where the dimensions of the output are symbolic (determined by the outputs of previous steps at runtime), but I'm not sure supporting that would be possible in burn right now for any op.
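The three-way decision described in the changes above can be summarized as a small dispatch. This is a hedged sketch only; the enum and function names are illustrative and are not burn-import's actual API:

```rust
// Hypothetical sketch of the import-time lowering decision for Unsqueeze.
enum UnsqueezeLowering {
    // rhs (axes) is constant: axes are known, keep the unsqueeze op.
    Unsqueeze { axes: Vec<i64> },
    // Axes unknown, but the graph gives an explicit output shape:
    // remap to a reshape with that shape.
    Reshape { shape: Vec<usize> },
    // Symbolic output shape: not supported yet.
    Unsupported,
}

fn lower_unsqueeze(
    const_axes: Option<Vec<i64>>,
    explicit_output_shape: Option<Vec<usize>>,
) -> UnsqueezeLowering {
    match (const_axes, explicit_output_shape) {
        (Some(axes), _) => UnsqueezeLowering::Unsqueeze { axes },
        (None, Some(shape)) => UnsqueezeLowering::Reshape { shape },
        (None, None) => UnsqueezeLowering::Unsupported,
    }
}
```

The unsupported arm corresponds to the symbolic-shape case mentioned above, which would require runtime shape inference.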
Testing
There was a discrepancy between the unsqueeze function for torch and ONNX: torch only supported a single axis argument, while ONNX supported multiple, so I wrote some code to generate the ONNX model directly through the ONNX helper and runtime. We might want to move it into its own directory as a Python module if we want to use it for other operations.
Right now the ONNX model takes a second input argument that is not present in the burn forward function, but I needed the second op to test remapping to a reshape node.
Graph: