Skip to content

Commit

Permalink
Update on "fix embedding_backward_dense decomp with broadcasting"
Browse files Browse the repository at this point in the history
Fixes #95182

cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire ngimel for another decomp fix. For this one, I tried auditing the CPU and CUDA kernels for `embedding_backward_dense` and just could not figure out where the `unsqueeze(1)` was supposed to be coming from. In the failing example, our tensor shapes are `(2, 4, 3)` and `(2, 4)`, and so I just assumed that the existing decomp had a typo - we should be unsqueezing the last dim, instead of dim index 1. That fixes the repro, and the existing decomp + meta tests appear to be passing.




cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
  • Loading branch information
bdhirsh committed Feb 27, 2023
2 parents fe31fb2 + a30b96a commit 4919634
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies {
implementation 'androidx.appcompat:appcompat:1.0.0'
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
implementation 'com.google.code.findbugs:jsr305:3.0.1'
implementation 'com.facebook.soloader:nativeloader:0.10.4'
implementation 'com.facebook.soloader:nativeloader:0.10.5'

implementation 'junit:junit:' + rootProject.junitVersion
implementation 'androidx.test:core:' + rootProject.coreVersion
Expand Down
12 changes: 11 additions & 1 deletion .github/scripts/trymerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,17 @@ def find_matching_merge_rule(
reject_reason = f"Rejecting the merge as no rules are defined for the repository in {MERGE_RULE_PATH}"
raise RuntimeError(reject_reason)
checks = get_combined_checks_from_pr_and_land_validation(pr, land_check_commit)
checks = get_classifications(pr.last_commit()['oid'], pr.get_merge_base(), checks, flaky_rules)
base_rev = None
try:
# is allowed to fail if git is not available
base_rev = pr.get_merge_base()
except Exception as e:
print(
f"Failed fetching base git revision for {pr.pr_num}. Skipping additional classifications.\n"
f"{type(e)}\n{e}"
)
if base_rev is not None:
checks = get_classifications(pr.last_commit()['oid'], base_rev, checks, flaky_rules)

# PRs can fail multiple merge rules, but it only needs to pass one rule to be approved.
# If it fails all rules, we need to find the rule that it came closest to passing and report
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/mac-mps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}

macos-12-py3-arm64-mps-test:
if: false
name: macos-12-py3-arm64-mps
uses: ./.github/workflows/_mac-test-mps.yml
needs: macos-12-py3-arm64-build
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,13 @@ jobs:
name: macos-12-py3-arm64-mps
uses: ./.github/workflows/_mac-test-mps.yml
needs: macos-12-py3-arm64-build
if: needs.macos-12-py3-arm64-build.outputs.build-outcome == 'success'
if: false && needs.macos-12-py3-arm64-build.outputs.build-outcome == 'success'
with:
sync-tag: macos-12-py3-arm64-mps-test
build-environment: macos-12-py3-arm64

macos-12-py3-arm64-test:
if: false
name: macos-12-py3-arm64
uses: ./.github/workflows/_mac-test.yml
needs: macos-12-py3-arm64-build
Expand Down
4 changes: 2 additions & 2 deletions android/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ dependencies {
implementation(name:'pytorch_android', ext:'aar')
implementation(name:'pytorch_android_torchvision', ext:'aar')
...
implementation 'com.facebook.soloader:nativeloader:0.10.4'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
}
```
We also have to add all transitive dependencies of our aars.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.4'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
(In case of using maven dependencies they are added automatically from `pom.xml`).

You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
Expand Down
2 changes: 1 addition & 1 deletion android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ allprojects {
junitVersion = "4.12"

fbjniJavaOnlyVersion = "0.2.2"
soLoaderNativeLoaderVersion = "0.10.4"
soLoaderNativeLoaderVersion = "0.10.5"
}

repositories {
Expand Down
4 changes: 2 additions & 2 deletions android/test_app/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ tasks.all { task ->

dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.facebook.soloader:nativeloader:0.10.4'
implementation 'com.facebook.soloader:nativeloader:0.10.5'

localImplementation project(':pytorch_android')
localImplementation project(':pytorch_android_torchvision')
Expand All @@ -154,7 +154,7 @@ dependencies {

aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
aarImplementation 'com.facebook.soloader:nativeloader:0.10.4'
aarImplementation 'com.facebook.soloader:nativeloader:0.10.5'
aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'

def camerax_version = "1.0.0-alpha05"
Expand Down
5 changes: 1 addition & 4 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@
namespace {
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
const auto& schema = op.schema();
TORCH_INTERNAL_ASSERT(
!schema.hasAnyAliasInfo(),
"mutating and aliasing ops should all have codegen'd kernels. op name: ",
op.operator_name().name, ".", op.operator_name().overload_name);
TORCH_INTERNAL_ASSERT(!schema.hasAnyAliasInfo(), "mutating and aliasing ops should all have codegen'd kernels");
const auto num_arguments = schema.arguments().size();
const auto arguments_begin = stack->size() - num_arguments;
auto arguments = torch::jit::last(stack, num_arguments);
Expand Down

0 comments on commit 4919634

Please sign in to comment.