Very high memory usage when training Stage 1 #10

Closed
godspirit00 opened this issue Oct 7, 2023 · 15 comments

@godspirit00

godspirit00 commented Oct 7, 2023

Hello,
Thanks for the great work.
I'm trying to train a model on my dataset using an A5000 (24 GB VRAM). I kept getting OOM errors at the beginning of Stage 1, and after repeatedly reducing the batch size, training would only run once I got down to a batch size of 4.
Is this normal? What hardware were you using?
Thanks!

@yl4579
Owner

yl4579 commented Oct 7, 2023

I used 4 A100s with a batch size of 32 for the paper, and the checkpoint I shared was trained on 4 L40s with a batch size of 16. You can either decrease the batch size to 4 (as you only have one GPU), or you can decrease max_len, which currently corresponds to about 5 seconds of audio. You definitely don't need clips that long for training.
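
For a rough sense of scale, here is a back-of-the-envelope sketch of how max_len (measured in mel frames) maps to seconds; the 24 kHz sample rate and hop size of 300 are assumptions taken from the default Configs/config.yml, so check them against your own config:

    # Hypothetical helper: convert max_len (mel frames) into seconds of audio,
    # assuming sr = 24000 and hop_length = 300 as in the default config.
    def max_len_to_seconds(max_len, sr=24000, hop_length=300):
        return max_len * hop_length / sr

    print(max_len_to_seconds(400))  # 5.0 -> the ~5 second default mentioned above
    print(max_len_to_seconds(200))  # 2.5 -> roughly halves per-sample activation memory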

@yl4579 yl4579 closed this as completed Oct 7, 2023
@yl4579
Owner

yl4579 commented Oct 8, 2023

Also for the first stage, you can try mixed precision; in my experience it doesn't reduce the reconstruction quality. It gives much faster training and roughly half the memory use. All you need is accelerate launch --mixed_precision=fp16 train_first.py --config_path ./Configs/config.yml. This is the stage that takes the most time anyway, so using mixed precision is good practice.

@godspirit00
Author

Also for the first stage, you can try mixed precision; in my experience it doesn't reduce the reconstruction quality. It gives much faster training and roughly half the memory use. All you need is accelerate launch --mixed_precision=fp16 train_first.py --config_path ./Configs/config.yml. This is the stage that takes the most time anyway, so using mixed precision is good practice.

Thanks! I'll give it a try!

@godspirit00
Author

Got another error when using mixed precision:

Traceback (most recent call last):
  File "train_first.py", line 444, in <module>
    main()
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "train_first.py", line 305, in main
    optimizer.step('pitch_extractor')
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 32, in step
    _ = [self._step(key, scaler) for key in keys]
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 32, in <listcomp>
    _ = [self._step(key, scaler) for key in keys]
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 39, in _step
    self.optimizers[key].step()
  File "/root/miniconda3/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

What seems to be the problem? Thanks!

@yl4579
Owner

yl4579 commented Oct 9, 2023

@godspirit00 This usually means the loss became NaN/Inf because you are using mixed precision. I think you may have to change the loss here: https://github.com/yl4579/StyleTTS2/blob/main/losses.py#L22 to return F.l1_loss(y_mag, x_mag).
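
A minimal sketch of that change, assuming the line in question is the Frobenius-norm spectral convergence loss (the norm ratio can overflow to inf under fp16, which then poisons the scaled gradients):

    import torch
    import torch.nn.functional as F

    def spectral_loss(x_mag, y_mag):
        # Assumed original: spectral convergence loss. The ratio of two
        # Frobenius norms can over/underflow to inf or NaN in fp16.
        # return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

        # Suggested replacement: plain L1 on the STFT magnitudes, which stays
        # finite in mixed precision.
        return F.l1_loss(y_mag, x_mag)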

@godspirit00
Author

The error persists after changing the line.

Traceback (most recent call last):
  File "train_first.py", line 444, in <module>
    main()
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "train_first.py", line 305, in main
    optimizer.step('pitch_extractor')
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 32, in step
    _ = [self._step(key, scaler) for key in keys]
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 32, in <listcomp>
    _ = [self._step(key, scaler) for key in keys]
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 39, in _step
    self.optimizers[key].step()
  File "/root/miniconda3/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

It was at Epoch 50.

@yl4579
Owner

yl4579 commented Oct 10, 2023

Could it be related to the discriminators, or is it a TMA issue? Can you set the TMA epoch to a higher number but set start_ds in train_first.py to True, and see if it is a problem with the discriminators?

@godspirit00
Author

Can you set the TMA epoch to a higher number but set start_ds in train_first.py to True

I tried searching for start_ds in train_first.py, but I couldn't find it.

@yl4579
Owner

yl4579 commented Oct 12, 2023

Sorry, I meant you can just modify the code to train the discriminator but not the aligner. If that still doesn't work, you probably have to train without mixed precision. It is unfortunately highly sensitive to batch size, so it probably only works with large enough batches (like 16 or 32).

@stevenhillis

Got another error when using mixed precision:

Traceback (most recent call last):
  File "train_first.py", line 444, in <module>
    main()
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/root/miniconda3/lib/python3.8/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "train_first.py", line 305, in main
    optimizer.step('pitch_extractor')
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 32, in step
    _ = [self._step(key, scaler) for key in keys]
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 32, in <listcomp>
    _ = [self._step(key, scaler) for key in keys]
  File "/root/autodl-tmp/StyleTTS2/optimizers.py", line 39, in _step
    self.optimizers[key].step()
  File "/root/miniconda3/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

What seems to be the problem? Thanks!

I have traced this to optimizer.step('pitch_extractor'). The pitch extractor's optimizer is being asked to step, but the pitch extractor itself is only ever called in a no_grad context, so it never has gradients. The assertion lives in torch's grad_scaler code, which is why it only shows up with mixed precision.

It does make me wonder, @yl4579, if this is expected? Were the parameters of the pitch_extractor meant to be updated at all?
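
For anyone hitting the same assertion, here is a minimal, self-contained reproduction of the mechanism (illustrative names, not the actual StyleTTS2 modules; requires a CUDA device so the GradScaler is actually enabled):

    import torch

    # Two modules standing in for stage 1: one trained normally, and one (the
    # "pitch extractor" here) that is only ever run under torch.no_grad().
    trained = torch.nn.Linear(4, 1).cuda()
    frozen = torch.nn.Linear(4, 1).cuda()

    opt_trained = torch.optim.AdamW(trained.parameters())
    opt_frozen = torch.optim.AdamW(frozen.parameters())
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(2, 4, device="cuda")
    with torch.cuda.amp.autocast():
        loss = trained(x).mean()
    scaler.scale(loss).backward()   # gradients exist only for `trained`

    with torch.no_grad():           # inference only: no graph, no gradients
        _ = frozen(x)

    scaler.step(opt_trained)        # fine: grads were found and inf-checked
    scaler.step(opt_frozen)         # AssertionError: "No inf checks were
                                    # recorded for this optimizer."
    scaler.update()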

@yl4579
Owner

yl4579 commented Oct 28, 2023

@stevenhillis I think you are correct. The pitch_extractor actually shouldn't be updated. Does removing this line fix the mixed precision problem?

@stevenhillis

Good deal. Sure does!

@RillmentGames

I have traced this to optimizer.step('pitch_extractor').

I had the same 'inf' crash problem at epoch 50 with mixed precision and a batch size of 4. I can confirm that removing this line allows training to continue past epoch 50.
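
For anyone else landing here, the working change is simply removing (or commenting out) that call in train_first.py, roughly:

    # The pitch extractor is only run under torch.no_grad() in stage 1, so it
    # never accumulates gradients; skip its optimizer step entirely.
    # optimizer.step('pitch_extractor')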

@Moonmore

Moonmore commented Dec 6, 2023

I used 4 A100s with a batch size of 32 for the paper, and the checkpoint I shared was trained on 4 L40s with a batch size of 16. You can either decrease the batch size to 4 (as you only have one GPU), or you can decrease max_len, which currently corresponds to about 5 seconds of audio. You definitely don't need clips that long for training.

@yl4579 Thanks for your great work.
I used 4 V100s with a batch size of 8 and max_len of 200 on my dataset. When I set the batch size to 16, I ran out of memory.
So should I reduce max_len? How should I balance max_len against batch size, and how should I adjust these parameters?
Thanks a lot.

@yl4579
Owner

yl4579 commented Dec 6, 2023

@Moonmore See #81 for a detailed discussion.
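
As a rough rule of thumb (an assumption, not a measured profile): activation memory in stage 1 scales roughly with batch_size * max_len, so the two can be traded against each other while keeping their product about constant, for example:

    # Rough, assumed scaling: activation memory ~ batch_size * max_len.
    # Numbers are illustrative, not measured.
    def relative_memory(batch_size, max_len):
        return batch_size * max_len

    baseline = relative_memory(8, 200)            # the 4x V100 setting that fits
    print(relative_memory(16, 200) / baseline)    # 2.0 -> likely OOM, as observed
    print(relative_memory(16, 100) / baseline)    # 1.0 -> comparable footprint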
