Skip to content

Commit

Permalink
setup: switch to pytorch-lightning >= 1.8 (#1210)
Browse files Browse the repository at this point in the history
  • Loading branch information
entn-at committed Jan 20, 2023
1 parent 2a4e3bf commit cb7bfd5
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyannote/audio/cli/train.py
Expand Up @@ -37,7 +37,7 @@
RichProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from lightning_lite.utilities.seed import seed_everything
from torch_audiomentations.utils.config import from_dict as get_augmentation

from pyannote.audio.core.io import get_torchaudio_info
Expand Down
2 changes: 1 addition & 1 deletion pyannote/audio/core/model.py
Expand Up @@ -36,7 +36,7 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError
from pyannote.core import SlidingWindow
from pytorch_lightning.utilities.cloud_io import load as pl_load
from lightning_lite.utilities.cloud_io import _load as pl_load
from pytorch_lightning.utilities.model_summary import ModelSummary
from semver import VersionInfo
from torch.utils.data import DataLoader
Expand Down
3 changes: 2 additions & 1 deletion pyannote/audio/pipelines/voice_activity_detection.py
Expand Up @@ -346,7 +346,8 @@ def configure_optimizers(model):
with tempfile.TemporaryDirectory() as default_root_dir:
trainer = Trainer(
max_epochs=self.num_epochs,
gpus=1,
accelerator="gpu",
devices=1,
callbacks=[GraduallyUnfreeze(epochs_per_stage=self.num_epochs + 1)],
enable_checkpointing=False,
default_root_dir=default_root_dir,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -9,7 +9,7 @@ pyannote.core >=4.4,<5.0
pyannote.database >=4.1.1,<5.0
pyannote.metrics >=3.2,<4.0
pyannote.pipeline >=2.3,<3.0
pytorch_lightning >=1.5.4,<1.7
pytorch_lightning >=1.8.0,<1.9
pytorch_metric_learning >=1.0.0,<2.0
rich >= 12.0.0
semver >=2.10.2,<3.0
Expand Down
8 changes: 4 additions & 4 deletions tests/tasks/test_reproducibility.py
@@ -1,4 +1,4 @@
import pytorch_lightning as pl
from lightning_lite.utilities.seed import seed_everything
import torch

from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel
Expand Down Expand Up @@ -30,7 +30,7 @@ def get_next5(dl):

def test_seeding_ensures_data_loaders():
"Setting a global seed for the dataloaders ensures that we get data back in the same order"
pl.seed_everything(1)
seed_everything(1)

for task in [VoiceActivityDetection, MultiLabelSegmentation]:
protocol, vad = setup_tasks(task)
Expand All @@ -50,12 +50,12 @@ def test_different_seeds():

for task in [VoiceActivityDetection, MultiLabelSegmentation]:
protocol, vad = setup_tasks(task)
pl.seed_everything(4)
seed_everything(4)
dl = create_dl(SimpleSegmentationModel, vad)
last5a = get_next5(dl)

protocol, vad = setup_tasks(task)
pl.seed_everything(5)
seed_everything(5)
dl = create_dl(SimpleSegmentationModel, vad)
last5b = get_next5(dl)

Expand Down
8 changes: 4 additions & 4 deletions tutorials/training_a_model.ipynb
Expand Up @@ -180,7 +180,7 @@
],
"source": [
"import pytorch_lightning as pl\n",
"trainer = pl.Trainer(gpus=1, max_epochs=1)\n",
"trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=1)\n",
"trainer.fit(vad_model)"
]
},
Expand Down Expand Up @@ -545,7 +545,7 @@
}
],
"source": [
"trainer = pl.Trainer(gpus=1, max_epochs=1)\n",
"trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=1)\n",
"trainer.fit(finetuned)"
]
},
Expand Down Expand Up @@ -781,7 +781,7 @@
}
],
"source": [
"trainer = pl.Trainer(gpus=1, max_epochs=1)\n",
"trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=1)\n",
"trainer.fit(osd_model)"
]
},
Expand Down Expand Up @@ -854,7 +854,7 @@
"We also benefit from all the nice things [`pytorch-lightning`](ttps://pytorch-lightning.readthedocs.io) has to offer (like multi-gpu training, for instance).\n",
"\n",
"```python\n",
"trainer = Trainer(gpus=4, strategy='ddp')\n",
"trainer = Trainer(devices=4, accelerator=\"gpu\", strategy='ddp')\n",
"trainer.fit(model)\n",
"```\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion tutorials/voice_activity_detection.ipynb
Expand Up @@ -273,7 +273,7 @@
],
"source": [
"import pytorch_lightning as pl\n",
"trainer = pl.Trainer(gpus=1, max_epochs=2)\n",
"trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=2)\n",
"trainer.fit(model)"
]
},
Expand Down

0 comments on commit cb7bfd5

Please sign in to comment.