Skip to content

Commit

Permalink
feat: add weigh_by_cardinality option (#1256)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Feb 13, 2023
1 parent 6dab904 commit 1f83e0b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
20 changes: 19 additions & 1 deletion pyannote/audio/tasks/segmentation/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class Segmentation(SegmentationTaskMixin, Task):
Maximum number of (overlapping) speakers per frame.
Setting this value to 1 or more enables `powerset multi-class` training.
Default behavior is to use `multi-label` training.
weigh_by_cardinality: bool, optional
Weigh each powerset class by the size of the corresponding speaker set.
In other words, {0, 1} powerset class weight is 2x bigger than that of {0}
or {1} powerset classes. Note that empty (non-speech) powerset class is
assigned the same weight as mono-speaker classes. Defaults to False (i.e. use
same weight for every class). Has no effect with `multi-label` training.
warm_up : float or (float, float), optional
Use that many seconds on the left- and rightmost parts of each chunk
to warm up the model. While the model does process those left- and right-most
Expand Down Expand Up @@ -119,6 +125,7 @@ def __init__(
duration: float = 2.0,
max_speakers_per_chunk: int = None,
max_speakers_per_frame: int = None,
weigh_by_cardinality: bool = False,
warm_up: Union[float, Tuple[float, float]] = 0.0,
balance: Text = None,
weight: Text = None,
Expand Down Expand Up @@ -165,6 +172,7 @@ def __init__(

self.max_speakers_per_chunk = max_speakers_per_chunk
self.max_speakers_per_frame = max_speakers_per_frame
self.weigh_by_cardinality = weigh_by_cardinality
self.balance = balance
self.weight = weight
self.vad_loss = vad_loss
Expand Down Expand Up @@ -291,8 +299,18 @@ def segmentation_loss(
"""

if self.specifications.powerset:

# `clamp_min` is needed to set non-speech weight to 1.
class_weight = (
torch.clamp_min(self.model.powerset.cardinality, 1.0)
if self.weigh_by_cardinality
else None
)
seg_loss = nll_loss(
permutated_prediction, torch.argmax(target, dim=-1), weight=weight
permutated_prediction,
torch.argmax(target, dim=-1),
class_weight=class_weight,
weight=weight,
)
else:
seg_loss = binary_cross_entropy(
Expand Down
9 changes: 8 additions & 1 deletion pyannote/audio/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def mse_loss(


def nll_loss(
prediction: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None
prediction: torch.Tensor,
target: torch.Tensor,
class_weight: torch.Tensor = None,
weight: torch.Tensor = None,
) -> torch.Tensor:
"""Frame-weighted negative log-likelihood loss
Expand All @@ -139,6 +142,8 @@ def nll_loss(
Prediction with shape (batch_size, num_frames, num_classes).
target : torch.Tensor
Target with shape (batch_size, num_frames)
class_weight : (num_classes, ) torch.Tensor, optional
Class weight with shape (num_classes, )
weight : (batch_size, num_frames, 1) torch.Tensor, optional
Frame weight with shape (batch_size, num_frames, 1).
Expand All @@ -154,6 +159,8 @@ def nll_loss(
# (batch_size x num_frames, num_classes)
target.view(-1),
# (batch_size x num_frames, )
weight=class_weight,
# (num_classes, )
reduction="none",
).view(target.shape)
# (batch_size, num_frames)
Expand Down
14 changes: 12 additions & 2 deletions pyannote/audio/utils/powerset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def __init__(self, num_classes: int, max_set_size: int):
self.num_classes = num_classes
self.max_set_size = max_set_size

mapping = self.build_mapping()
self.register_buffer("mapping", mapping, persistent=False)
self.register_buffer("mapping", self.build_mapping(), persistent=False)
self.register_buffer("cardinality", self.build_cardinality(), persistent=False)

@cached_property
def num_powerset_classes(self) -> int:
Expand All @@ -78,6 +78,16 @@ def build_mapping(self) -> torch.Tensor:

return mapping

def build_cardinality(self) -> torch.Tensor:
    """Compute size of each powerset class

    Returns
    -------
    cardinality : (num_powerset_classes, ) torch.Tensor
        For each powerset class (enumerated in the same order as
        `build_mapping`: increasing set size, then lexicographic
        within a size), the number of speakers in the corresponding
        speaker set (0 for the empty / non-speech class).
    """
    # Enumerate speaker sets exactly like `build_mapping` does, so that
    # index k here matches powerset class k in the mapping.
    sizes = [
        float(set_size)
        for set_size in range(self.max_set_size + 1)
        for _ in combinations(range(self.num_classes), set_size)
    ]
    return torch.tensor(sizes)

def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor:
"""Convert (hard) predictions from powerset to multi-label
Expand Down

0 comments on commit 1f83e0b

Please sign in to comment.