
[BUG] Enabling multiple gpu causes AssertionError #1385

Closed
heury opened this issue Nov 25, 2022 · 6 comments · Fixed by #1509
Labels: bug (Something isn't working)

Comments


heury commented Nov 25, 2022

Describe the bug
Enabling multiple GPUs works fine when fitting the model, but it causes an error in predict, historical_forecasts, and backtest.
A single GPU works fine in both stages.

pytorch_lightning/overrides/distributed.py", line 91, in __init__
    self.num_samples = len(range(self.rank, len(self.dataset), self.num_replicas))
    self.total_size = len(self.dataset)
    assert self.num_samples >= 1 or self.total_size == 0

To Reproduce
I have 8 GPUs, so I enabled the GPU option as follows:

pl_trainer_kwargs={
"accelerator": "gpu",
"devices": [0, 1, 2, 3, 4, 5, 6, 7]
}
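
For context, here is a minimal sketch of how these kwargs are typically passed to a Darts model. The original report does not show the model or data, so NBEATSModel and AirPassengersDataset below are assumptions chosen only to make the snippet self-contained:

from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series = AirPassengersDataset().load()

model = NBEATSModel(
    input_chunk_length=12,
    output_chunk_length=1,
    n_epochs=1,
    pl_trainer_kwargs={
        "accelerator": "gpu",
        "devices": [0, 1, 2, 3, 4, 5, 6, 7],
    },
)

model.fit(series=series)  # fitting works fine on multiple GPUs
model.predict(n=5)        # predict / historical_forecasts / backtest hit the AssertionError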

Expected behavior
It should run without any problem on either a single GPU or multiple GPUs.

System (please complete the following information):
Python 3.9.13 (main, Aug 25 2022, 23:26:10)
[GCC 11.2.0] :: Anaconda, Inc. on linux

open-clip-torch 2.7.0
pytorch-lightning 1.8.3.post0
torch 1.13.0+cu117
torchaudio 0.13.0+cu117
torchmetrics 0.9.3
torchvision 0.14.0+cu117
darts 0.22.0


heury added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Nov 25, 2022

heury commented Nov 26, 2022

This error goes away when I pass in at least as many time series as the GPU count (8 in my case).

valid_data = [
    valid_scaled[:100], valid_scaled[1:101], valid_scaled[2:102],
    valid_scaled[3:103], valid_scaled[4:104], valid_scaled[5:105],
    valid_scaled[6:106], valid_scaled[7:107], valid_scaled[8:108],
]

pred_series = trained_model.predict(n=5, num_loader_workers=8, series=valid_data, verbose=True, n_jobs=-1)

File "/home/jovyan/anaconda3/envs/dreambooth/lib/python3.9/site-packages/darts-0.22.0-py3.9.egg/darts/models/forecasting/torch_forecasting_model.py", line 1184, in predict_from_dataset
    return [ts for batch in predictions for ts in batch]
TypeError: 'NoneType' object is not iterable

However, the predicting DataLoader takes a very long time, and predict then fails with the None error above.
I have two questions about it:

Do I have to pass in more time series than the GPU count when multiple GPUs are enabled?
Why does the DataLoader take so long in this case, and why does predict return None?
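
For reference, a minimal sketch that generalizes the workaround described above, i.e. supplying at least as many target series as configured devices so every device's shard of the prediction dataset is non-empty. n_devices is a hypothetical variable; valid_scaled, the 100-step windows, and the predict kwargs follow the snippet above:

n_devices = 8  # should match the number of devices in pl_trainer_kwargs
# Build n_devices + 1 sliding windows, mirroring the list above.
valid_data = [valid_scaled[i : i + 100] for i in range(n_devices + 1)]

pred_series = trained_model.predict(
    n=5,
    series=valid_data,
    num_loader_workers=8,
    verbose=True,
    n_jobs=-1,
)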

hrzn removed the triage (Issue waiting for triaging) label on Nov 30, 2022
hrzn added this to To do in darts via automation on Nov 30, 2022

erik-hasse commented Dec 6, 2022

I'm having the same issue with a TFT model. As a minimal reproduction:

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel

def main():
    series = AirPassengersDataset().load()
    model = TFTModel(
        input_chunk_length=12,
        output_chunk_length=1,
        n_epochs=1,
        pl_trainer_kwargs={
            "accelerator": "gpu",
            "devices": 4,
        },
        add_relative_index=True
    )

    model.fit(series=series)
    model.predict(12) # AssertionError: assert self.num_samples >= 1 or self.total_size == 0

if __name__ == "__main__":
    main()

I'm using an AWS g4dn.12xlarge instance, which has 4 GPUs. Here's the full traceback:

Traceback (most recent call last):
  File "/home/ubuntu/mre.py", line 25, in <module>
    main()
  File "/home/ubuntu/mre.py", line 22, in main
    model.predict(10)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/darts/models/forecasting/tft_model.py", line 1145, in predict
    return super().predict(n, *args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/darts/utils/torch.py", line 112, in decorator
    return decorated(self, *args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 1051, in predict
    predictions = self.predict_from_dataset(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/darts/utils/torch.py", line 112, in decorator
    return decorated(self, *args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 1181, in predict_from_dataset
    predictions = self.trainer.predict(self.model, pred_loader)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 850, in predict
    return call._call_and_handle_interrupt(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 113, in launch
    mp.start_processes(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 897, in _predict_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1139, in _run_stage
    return self._run_predict()
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1189, in _run_predict
    self.reset_predict_dataloader(self.lightning_module)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1627, in reset_predict_dataloader
    self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 377, in _reset_eval_dataloader
    dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 377, in <listcomp>
    dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 283, in _prepare_dataloader
    sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 300, in _resolve_sampler
    sampler = self._get_distributed_sampler(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 339, in _get_distributed_sampler
    sampler = cls(dataloader.sampler, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/overrides/distributed.py", line 116, in __init__
    super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/pytorch_lightning/overrides/distributed.py", line 90, in __init__
    assert self.num_samples >= 1 or self.total_size == 0
AssertionError

erik-hasse commented:

While digging I found that this is actually tied to the devices trainer kwarg, not GPUs specifically. The same error occurs on CPU if you set devices>1 and leave accelerator unset (or set to "cpu").
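
A minimal sketch of that CPU reproduction, adapted from the earlier TFT snippet (the accelerator and devices values below are assumptions consistent with the description above):

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel

series = AirPassengersDataset().load()
model = TFTModel(
    input_chunk_length=12,
    output_chunk_length=1,
    n_epochs=1,
    pl_trainer_kwargs={
        "accelerator": "cpu",  # or leave accelerator unset
        "devices": 4,          # any value > 1 triggers the same AssertionError
    },
    add_relative_index=True,
)

model.fit(series=series)
model.predict(12)  # AssertionError: assert self.num_samples >= 1 or self.total_size == 0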

Here's a super hack-y workaround by overwriting the model's trainer after fitting:

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel
import pytorch_lightning as pl

def main():
    series = AirPassengersDataset().load()
    model = TFTModel(
        input_chunk_length=12,
        output_chunk_length=1,
        n_epochs=1,
        pl_trainer_kwargs={
            "accelerator": "gpu",
            "devices": 4,
        },
        add_relative_index=True
    )

    model.fit(series=series)
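    # Replace the multi-device trainer with a fresh single-device one (reusing the
    # model's stored trainer_params) so predict() does not use a distributed sampler.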
    model.trainer = pl.Trainer(**{**model.trainer_params, "devices": 1})
    model.predict(12)

if __name__ == "__main__":
    main()

This will work for experimentation, but it's definitely not ideal.


hrzn commented Jan 4, 2023

Is this still an issue with 0.23.0?

erik-hasse commented:

I just confirmed that it's still happening in 0.23:

# ... training logs snipped

darts.__version__='0.23.0'
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

Traceback (most recent call last):
  File "mre.py", line 22, in <module>
    main()
  File "mre.py", line 19, in main
    model.predict(12) # AssertionError: assert self.num_samples >= 1 or self.total_size == 0
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/darts/models/forecasting/tft_model.py", line 1147, in predict
    return super().predict(n, *args, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/darts/utils/torch.py", line 112, in decorator
    return decorated(self, *args, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 1067, in predict
    predictions = self.predict_from_dataset(
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/darts/utils/torch.py", line 112, in decorator
    return decorated(self, *args, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 1198, in predict_from_dataset
    predictions = self.trainer.predict(self.model, pred_loader)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 949, in predict
    return self._call_and_handle_interrupt(
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 648, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 107, in launch
    mp.start_processes(
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 133, in _wrapping_function
    results = function(*args, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in _predict_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1251, in _run_stage
    return self._run_predict()
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_predict
    self.reset_predict_dataloader(self.lightning_module)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1961, in reset_predict_dataloader
    self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 385, in _reset_eval_dataloader
    dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 385, in <listcomp>
    dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 295, in _prepare_dataloader
    sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 308, in _resolve_sampler
    sampler = self._get_distributed_sampler(
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 347, in _get_distributed_sampler
    sampler = cls(dataloader.sampler, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/overrides/distributed.py", line 165, in __init__
    super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs)
  File "/Users/Erik/.pyenv/versions/multistep-forecaster/lib/python3.8/site-packages/pytorch_lightning/overrides/distributed.py", line 85, in __init__
    assert self.num_samples >= 1 or self.total_size == 0
AssertionError

With this script:

import darts
from darts.datasets import AirPassengersDataset
from darts.models import TFTModel

def main():
    series = AirPassengersDataset().load()
    model = TFTModel(
        input_chunk_length=12,
        output_chunk_length=1,
        n_epochs=1,
        pl_trainer_kwargs={
            "devices": 4,
        },
        add_relative_index=True
    )

    model.fit(series=series)
    print(f"{darts.__version__=}")
    model.predict(12) # AssertionError: assert self.num_samples >= 1 or self.total_size == 0

if __name__ == "__main__":
    main()

solalatus commented:

There is a temporary workaround here: #1287 (comment)

darts automation moved this from To do to Done on Feb 21, 2023
dennisbader moved this from Done to Released in darts on May 23, 2023