
Minifier fails for wav2vec #1959

Closed
wconstab opened this issue Dec 5, 2022 · 4 comments
Labels: bug (Something isn't working), minifier-did-not-work

wconstab commented Dec 5, 2022

🐛 Describe the bug

Building on pytorch/pytorch#93464, the whole-model repro for wav2vec posted there fails to run with the minifier for me on master @ 2ea32f41f4b4c3d0cbb9834186fdfe404e0d4c2a.

My w2v.py here is copied from pytorch/pytorch#93464, namely:

 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 from datasets import load_dataset
 import torch
 
 # load model and processor
 processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
 model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
 model = torch.compile(model)
     
 # load dummy dataset and read soundfiles
 ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
 
 # tokenize
 input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
 
 # retrieve logits
 logits = model(input_values).logits
 
 # take argmax and decode
 predicted_ids = torch.argmax(logits, dim=-1)
 transcription = processor.batch_decode(predicted_ids)

Error logs

No response

Minified repro

TORCHDYNAMO_REPRO_AFTER="aot" python w2v.py produces a minifier_launcher.py, but running it fails:

python torchdynamo_debug/run_2022_12_05_23_06_50_594983/minifier/minifier_launcher.py

Traceback (most recent call last):
  File "torchdynamo_debug/run_2022_12_05_23_06_50_594983/minifier/minifier_launcher.py", line 205, in <module>
    minifier(
  File "/scratch/whc/work/pytorch/torch/_functorch/fx_minifier.py", line 96, in minifier
    raise RuntimeError("Input graph did not fail the tester")
RuntimeError: Input graph did not fail the tester
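For context on that error: before shrinking anything, a delta-debugging minifier first re-runs the full captured graph to confirm it actually reproduces the failure, and bails out if it does not. A minimal pure-Python sketch of that gate and shrink loop (names and structure are illustrative, not the real `torch._functorch.fx_minifier` implementation):

```python
def minifier(graph, tester):
    """Shrink `graph` (modeled here as a list of op names) to a smaller
    sub-list that still fails `tester`. Mirrors the guard that produced
    the error above: if the full input passes, there is nothing to minify."""
    if tester(graph):  # tester returns True when the graph runs fine
        raise RuntimeError("Input graph did not fail the tester")
    changed = True
    while changed:
        changed = False
        for i in range(len(graph)):
            candidate = graph[:i] + graph[i + 1:]
            # Keep the smaller candidate only if it still fails.
            if candidate and not tester(candidate):
                graph = candidate
                changed = True
                break
    return graph

# Toy tester: the "bug" is triggered whenever op "conv" is present.
failing = ["slice", "conv", "relu", "argmax"]
assert minifier(failing, lambda g: "conv" not in g) == ["conv"]
```

The issue here is exactly that first guard firing: the generated `minifier_launcher.py` re-runs the graph, it no longer fails, and minification aborts.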

TORCHDYNAMO_REPRO_AFTER="dynamo" python w2v.py actually segfaults for me.

@wconstab wconstab added the bug Something isn't working label Dec 5, 2022
wconstab commented Dec 5, 2022

Oops, I forgot to use TORCHDYNAMO_REPRO_LEVEL=4. But that mode fails too:

[2022-12-05 23:25:44,800] torch._dynamo.debug_utils: [WARNING] While minifying the program in accuracy minification mode,ran into a runtime exception which is likely an unrelated issue. Skipping this graph.

full log: https://gist.github.com/wconstab/4fc3caf44b679e3994622ddbc191dfc0

anijain2305 commented:
You don't need level 4. Level 4 is for accuracy. The command that you have is correct.

Your observation is still correct. There is a minifier_launcher.py but it does not fail.

I can assign it to myself.

anijain2305 commented Dec 5, 2022

@wconstab I tried TORCHDYNAMO_REPRO_AFTER="dynamo", and that worked. (The graph is larger than it would be if repro_after="aot" had worked, so I will fix that one as well.)


from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models

# REPLACEABLE COMMENT FOR TESTING PURPOSES

args = [((1, 93680), (93680, 1), torch.float32, 'cpu', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_feature_extractor_conv_layers_0_conv = Conv1d(1, 512, kernel_size=(10,), stride=(5,))



    def forward(self, input_values : torch.Tensor):
        getitem = input_values[(slice(None, None, None), None)];  input_values = None
        self_self_feature_extractor_conv_layers_0_conv = self.self_self_feature_extractor_conv_layers_0_conv(getitem);  getitem = None
        return (self_self_feature_extractor_conv_layers_0_conv,)



mod = Repro()
opt_mod = torch._dynamo.optimize("inductor")(mod)


with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)


anijain2305 commented:
Real bug for the model: pytorch/pytorch#90260

Why does the minifier (repro_after="aot") not work? For the same reason as the issue above: the small difference in code path between the minified repro and the compiler led to this rare divergence.

kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Dec 10, 2022
Fixes pytorch/torchdynamo#1959, pytorch#90260
However, I wasn't able to make the existing stride tests fail before the fix, even though I'm comparing all strides, not just the significant ones.
Separately, running refs on meta tensors produces wrong strides, as shown in pytorch#90260; however, it looks like the meta tests use some other way of computing meta info (I've been running
```
pytest -s -v test/test_meta.py -k test_meta_outplace_expand_cuda_float64
```
and verified that it has a sample input that should fail, and that it indeed compares all the strides, but the produced `meta_rs` results somehow still had correct strides).

Edit: @SherlockNoMad helped me figure out how to make the tests fail, and I've now set the correct ops for checking. `expand` fails for some test inputs because it special-cases the 0-dim input case; modeling it correctly in prims would require a lot of changes, so I'm skipping that for now.

Pull Request resolved: pytorch#90341
Approved by: https://github.com/SherlockNoMad
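On the "all vs. significant strides" distinction in the commit above: strides of size-0 and size-1 dimensions don't affect which memory a tensor touches (any stride value addresses the same elements), so stride checks often compare only the dimensions of size greater than 1. A rough pure-Python sketch of the difference (the helper name is hypothetical, not a PyTorch API):

```python
def significant_strides(shape, strides):
    """Keep only strides for dimensions of size > 1; size-0/1 dimensions
    can carry arbitrary strides without changing memory addressing."""
    return [st for sz, st in zip(shape, strides) if sz > 1]

# A (1, 4) tensor: contiguous strides are (4, 1), but an expanded view
# could report (0, 1) in dim 0 and still address the same elements.
a = ((1, 4), (4, 1))
b = ((1, 4), (0, 1))

all_match = a[1] == b[1]                                        # False
sig_match = significant_strides(*a) == significant_strides(*b)  # True
```

Comparing all strides (as the commit does after the fix) is the stricter check, which is why divergences like pytorch#90260 can hide when only significant strides are compared.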