-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[OpenVINO backend] support tri, triu, and tril #21408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[OpenVINO backend] support tri, triu, and tril #21408
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #21408 +/- ##
==========================================
- Coverage 82.72% 77.20% -5.52%
==========================================
Files 565 565
Lines 55219 55562 +343
Branches 8608 8671 +63
==========================================
- Hits 45682 42899 -2783
- Misses 7427 10601 +3174
+ Partials 2110 2062 -48
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
f178511
to
3bbbf99
Compare
@@ -126,7 +126,6 @@ NumpyOneInputOpsCorrectnessTest::test_trace | |||
NumpyOneInputOpsCorrectnessTest::test_transpose | |||
NumpyOneInputOpsCorrectnessTest::test_tril | |||
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer | |||
NumpyOneInputOpsCorrectnessTest::test_triu |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NumpyDtypeTest::test_tri needs to be removed as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, all are supported and passed tests.
d0c86c4
to
48a910e
Compare
@rkazants |
48a910e
to
96408fa
Compare
dtype = "float32" | ||
|
||
ov_dtype = OPENVINO_DTYPES[dtype] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add comments with explanation for each block
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!, comments added.
c4c0bb6
to
7da7c2c
Compare
keras/src/backend/openvino/numpy.py
Outdated
# Mask for lower triangle (col <= row + k) | ||
k_const = ov_opset.constant(k, Type.i32) | ||
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const)) | ||
mask = ov_opset.convert(mask, ov_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you use tri
for mask computation? This way you can avoid code duplication
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rkazants
Done!
keras/src/backend/openvino/numpy.py
Outdated
row_idx = ov_opset.broadcast(row_idx, target_shape) | ||
col_idx = ov_opset.broadcast(col_idx, target_shape) | ||
mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can create mask
without preliminary boradcasting. less_equal
supports numpy broadcasting internally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pls avoid extra operations and extra memory usage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're absolutely right, apologies for that. That was my old implementation. I've updated it to the more efficient one.
753f15b
to
4349bdc
Compare
keras/src/backend/openvino/numpy.py
Outdated
ov_opset.constant(1, Type.i32), | ||
output_type=Type.i32, | ||
) | ||
return ov_opset.gather(shape, indices, axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks to return the same value as shape = ov_opset.shape_of(x, Type.i32)
. No need in get_shape_dims
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
4349bdc
to
5de788d
Compare
Hi @rkazants
I've supported
tri
,triu
andtril
, and they are ready for review.