feat(task): add option to cache training metadata to disk
Co-authored-by: Herve Bredin <hbredin@users.noreply.github.com>
clement-pages and hbredin committed Jan 12, 2024
1 parent d41ce0a commit 9e4ec5f
Showing 16 changed files with 949 additions and 496 deletions.
6 changes: 3 additions & 3 deletions .github/ISSUE_TEMPLATE/bug_report.yml
@@ -10,15 +10,15 @@ body:
  - The golden rule is to **always open *one* issue for *one* bug**. If you notice several bugs and want to report them, make sure to create one new issue for each of them.
  - Search [open](https://github.com/pyannote/pyannote-audio/issues) and [closed](https://github.com/pyannote/pyannote-audio/issues?q=is%3Aissue+is%3Aclosed) issues to ensure it has not already been reported. If you don't find a relevant match or if you're unsure, don't hesitate to **open a new issue**. The bugsquad will handle it from there if it's a duplicate.
  - Please always check if your issue is reproducible in the latest version – it may already have been fixed!
- - If you use a custom build, please test if your issue is reproducible in official releases too.
+ - If you use a custom build, please test if your issue is reproducible in official releases too.
  - type: textarea
  attributes:
  label: Tested versions
  description: |
  To properly fix a bug, we need to identify if the bug was recently introduced in the engine, or if it was always present.
  - Please specify the pyannote.audio version you found the issue in, including the **Git commit hash** if using a development build.
- - If you can, **please test earlier pyannote.audio versions** and, if applicable, newer versions (development branch). Mention whether the bug is reproducible or not in the versions you tested.
+ - If you can, **please test earlier pyannote.audio versions** and, if applicable, newer versions (development branch). Mention whether the bug is reproducible or not in the versions you tested.
  - The aim is for us to identify whether a bug is a **regression**, i.e. an issue that didn't exist in a previous version, but was introduced later on, breaking existing functionality. For example, if a bug is reproducible in 3.2 but not in 3.0, we would like you to test intermediate 3.1 to find which version is the first one where the issue can be reproduced.
  placeholder: |
  - Reproducible in: 3.1, 3.2, and later
@@ -33,7 +33,7 @@ body:
  - Specify the OS version, and when relevant hardware information.
  - For issues that are likely OS-specific and/or GPU-related, please specify the GPU model and architecture.
  - **Bug reports not including the required information may be closed at the maintainers' discretion.** If in doubt, always include all the requested information; it's better to include too much information than not enough information.
- placeholder: macOS 13.6 - pyannote.audio 3.1.1 - M1 Pro
+ placeholder: macOS 13.6 - pyannote.audio 3.1.1 - M1 Pro
  validations:
  required: true

5 changes: 2 additions & 3 deletions .github/ISSUE_TEMPLATE/config.yml
@@ -2,7 +2,7 @@ blank_issues_enabled: false

  contact_links:

- - name: Feature request
+ - name: Feature request
  url: https://github.com/pyannote/pyannote-audio/discussions
  about: Suggest an idea for this project.

@@ -12,5 +12,4 @@ contact_links:

  - name: Premium models
  url: https://forms.gle/eKhn7H2zTa68sMMx8
- about: We are considering selling premium models, extensions, or services around pyannote.audio.
-
+ about: We are considering selling premium models, extensions, or services around pyannote.audio.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -4,13 +4,18 @@

  ### New features

+ - feat(task): add option to cache task training metadata to speed up training
  - feat(pipeline): add `Waveform` and `SampleRate` preprocessors
  - feat(model): add `num_frames` and `receptive_field` to segmentation models

  ### Fixes

  - fix(task): fix random generators

+ ## Breaking changes
+
+ - BREAKING(task): custom tasks need to be updated (see "Add your own task" tutorial)
+
  ## Version 3.1.1 (2023-12-01)

  ### TL;DR
23 changes: 16 additions & 7 deletions pyannote/audio/core/model.py
@@ -142,8 +142,8 @@ def specifications(self) -> Union[Specifications, Tuple[Specifications]]:
              except AttributeError as e:
                  raise UnknownSpecificationsError(
                      "Task specifications are not available. This is most likely because they depend on "
-                     "the content of the training subset. Use `model.task.setup()` to go over the training "
-                     "subset and fix this, or let lightning trainer do that for you in `trainer.fit(model)`."
+                     "the content of the training subset. Use `model.prepare_data()` and `model.setup()` "
+                     "to go over the training subset and fix this, or let lightning trainer do that for you in `trainer.fit(model)`."
                  ) from e

          return specifications
@@ -217,9 +217,19 @@ def __example_output(
              self.specifications, __example_output, example_output
          )

+     def prepare_data(self):
+         self.task.prepare_data()
+
      def setup(self, stage=None):
          if stage == "fit":
-             self.task.setup_metadata()
+             # let the task know about the trainer (e.g for broadcasting
+             # cache path between multi-GPU training processes).
+             self.task.trainer = self.trainer
+
+         # setup the task if defined (only on training and validation stages,
+         # but not for basic inference)
+         if self.task:
+             self.task.setup(stage)

          # list of layers before adding task-dependent layers
          before = set((name, id(module)) for name, module in self.named_modules())
@@ -252,7 +262,7 @@ def setup(self, stage=None):
              module.to(self.device)

          # add (trainable) loss function (e.g. ArcFace has its own set of trainable weights)
-         if stage == "fit":
+         if self.task:
              # let task know about the model
              self.task.model = self
              # setup custom loss function
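
Note on the two `setup` hunks above: a Lightning trainer calls `prepare_data()` before `setup(stage)`, and runs `prepare_data()` on a single process while `setup()` runs on every process. The sketch below illustrates that split with a disk cache, which is what the commit title suggests. It is only an illustration under assumptions: `ToyTask`, `ToyModel`, `run_fit`, the JSON cache format, and the metadata fields are hypothetical and are not pyannote.audio's implementation.

```python
import json
import tempfile
from pathlib import Path


class ToyTask:
    """Hypothetical stand-in for a task; not pyannote.audio's implementation."""

    def __init__(self, cache_path):
        self.cache_path = Path(cache_path)
        self.metadata = None
        self.trainer = None
        self.scans = 0  # how many times the expensive scan actually ran

    def prepare_data(self):
        # Expensive step: scan the training subset once and cache the result.
        if self.cache_path.exists():
            return
        self.scans += 1
        metadata = {"classes": ["speaker_a", "speaker_b"], "num_files": 2}
        self.cache_path.write_text(json.dumps(metadata))

    def setup(self, stage=None):
        # Cheap step: every process loads what prepare_data() wrote to disk.
        self.metadata = json.loads(self.cache_path.read_text())


class ToyModel:
    """Hypothetical model mirroring the delegation shown in the diff."""

    def __init__(self, task=None):
        self.task = task
        self.trainer = None

    def prepare_data(self):
        self.task.prepare_data()

    def setup(self, stage=None):
        if stage == "fit":
            # let the task know about the trainer
            self.task.trainer = self.trainer
        if self.task:
            self.task.setup(stage)


def run_fit(model):
    # Mimics the order in which a trainer invokes the two hooks.
    model.prepare_data()
    model.setup("fit")
    return model.task.metadata
```

Calling `run_fit` twice on the same cache path scans the data only once; the second call finds the cache file and goes straight to loading it.
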
@@ -468,9 +478,8 @@ def __by_name(
          if isinstance(modules, str):
              modules = [modules]

-         for name, module in ModelSummary(self, max_depth=-1).named_modules:
-             if name not in modules:
-                 continue
+         for name in modules:
+             module = getattr(self, name)

              for parameter in module.parameters(recurse=True):
                  parameter.requires_grad = requires_grad
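
Note on the `__by_name` hunk above: replacing the filtered loop with `getattr(self, name)` changes how names are resolved. Going by how Python's `getattr` works, an unknown name now raises `AttributeError` instead of being skipped by `continue`, and a dotted name such as `"encoder.inner"` is not resolved. The sketch below shows this with torch-free stand-ins. `Param`, `Module`, and `set_requires_grad` are hypothetical names chosen for illustration; they only mimic the `parameters(recurse=True)` and `requires_grad` interface used in the diff.

```python
class Param:
    """Torch-free stand-in for a trainable parameter (hypothetical)."""

    def __init__(self):
        self.requires_grad = True


class Module:
    """Torch-free stand-in for a module with one weight and named children."""

    def __init__(self, **children):
        self.weight = Param()
        for name, child in children.items():
            setattr(self, name, child)

    def parameters(self, recurse=True):
        yield self.weight
        if recurse:
            for value in vars(self).values():
                if isinstance(value, Module):
                    yield from value.parameters(recurse=True)


def set_requires_grad(model, modules, requires_grad):
    # Accept a single name or a list of names, as in the diff.
    if isinstance(modules, str):
        modules = [modules]

    for name in modules:
        # Direct attribute lookup: raises AttributeError on unknown names.
        module = getattr(model, name)

        for parameter in module.parameters(recurse=True):
            parameter.requires_grad = requires_grad


model = Module(encoder=Module(inner=Module()), head=Module())
set_requires_grad(model, "encoder", False)  # freezes encoder and encoder.inner
```
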