Ignore the updates when weights are 0s and return the default value (pytorch#283)

Summary:
Pull Request resolved: pytorch#283

For a multi-task multi-label (MTML) model, we sometimes intentionally set weights = 0 so that the model effectively ignores the data. For metrics calculation, we should:
- ignore an update entirely if the weights for all tasks are 0
- ignore the metric result and output 0 (the metric's default value) if the weights for a task are 0

Previously, if weights = 0, the metrics could produce NaN values, which triggered metric-health alerts. This change fixes that.
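
As a rough illustration of the intended masking behavior, here is a minimal sketch in plain PyTorch (not the actual TorchRec API; the weighted-mean "metric" and the helper name are hypothetical stand-ins):

import torch

def check_nonempty_weights(weights: torch.Tensor) -> torch.Tensor:
    # True for each task (row) that has at least one non-zero weight.
    return torch.gt(torch.count_nonzero(weights, dim=-1), 0)

# Two tasks, batch of 4; all of task 1's weights are 0.
weights = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                        [0.0, 0.0, 0.0, 0.0]])
labels = torch.tensor([[1.0, 0.0, 1.0, 1.0],
                       [1.0, 1.0, 0.0, 0.0]])

has_valid_weights = check_nonempty_weights(weights)  # tensor([True, False])

# A naive weighted mean yields NaN for task 1 (0 / 0) ...
raw_metric = (labels * weights).sum(dim=-1) / weights.sum(dim=-1)

# ... so mask it with the default value (0) instead of reporting NaN.
safe_metric = torch.where(
    has_valid_weights, raw_metric, torch.zeros_like(raw_metric)
)
print(safe_metric)  # tensor([0.7500, 0.0000])

With all-zero weights for task 1, the naive ratio is 0/0 = NaN; checking the weights first lets the module report the metric's default value instead.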

Differential Revision: D36114064

fbshipit-source-id: bb684849f737fa9a68eeae7c76509c5656818b34
yachyv7 authored and facebook-github-bot committed May 4, 2022
1 parent 2f30b55 commit 01d77e0
Showing 2 changed files with 87 additions and 35 deletions.
74 changes: 63 additions & 11 deletions torchrec/metrics/rec_metric.py
@@ -138,6 +138,13 @@ def __init__(
             self._batch_window_buffers = {}
         else:
             self._batch_window_buffers = None
+        self._add_state(
+            "has_valid_update",
+            torch.zeros(self._n_tasks, dtype=torch.uint8),
+            add_window_state=False,
+            dist_reduce_fx=lambda x: torch.any(x, dim=0).byte(),
+            persistent=True,
+        )
 
     @staticmethod
     def get_window_state_name(state_name: str) -> str:
@@ -339,16 +346,39 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
         for metric_report in getattr(
             self._metrics_computations[0], compute_scope + "compute"
         )():
-            for task, metric_value in zip(self._tasks, metric_report.value):
-                yield task, metric_report.name, metric_value, compute_scope + metric_report.metric_prefix.value
+            for task, metric_value, has_valid_update in zip(
+                self._tasks,
+                metric_report.value,
+                self._metrics_computations[0].has_valid_update,
+            ):
+                # The attribute has_valid_update is a tensor whose length equals the
+                # number of tasks. Each value corresponds to whether that task has
+                # received valid updates.
+                # If a task has no valid updates, the computed metric_value is
+                # meaningless, so we mask it with the default value, i.e. 0.
+                valid_metric_value = (
+                    metric_value
+                    if has_valid_update > 0
+                    else torch.zeros_like(metric_value)
+                )
+                yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value
 
     def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
         for task, metric_computation in zip(self._tasks, self._metrics_computations):
             metric_computation.pre_compute()
             for metric_report in getattr(
                 metric_computation, compute_scope + "compute"
             )():
-                yield task, metric_report.name, metric_report.value, compute_scope + metric_report.metric_prefix.value
+                # The attribute has_valid_update is a tensor with a single value
+                # corresponding to whether this task has received valid updates.
+                # If there is no valid update, the computed metric_report.value is
+                # meaningless, so we mask it with the default value, i.e. 0.
+                valid_metric_value = (
+                    metric_report.value
+                    if metric_computation.has_valid_update[0] > 0
+                    else torch.zeros_like(metric_report.value)
+                )
+                yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value
 
     def _fuse_update_buffers(self) -> Dict[str, RecModelOutput]:
         def fuse(outputs: List[RecModelOutput]) -> RecModelOutput:
@@ -398,6 +428,9 @@ def _create_default_weights(self, predictions: torch.Tensor) -> torch.Tensor:
             self._default_weights[predictions.size()] = weights
         return weights
 
+    def _check_nonempty_weights(self, weights: torch.Tensor) -> torch.Tensor:
+        return torch.gt(torch.count_nonzero(weights, dim=-1), 0)
+
     def _update(
         self,
         *,
@@ -408,6 +441,7 @@ def _update(
         with torch.no_grad():
             if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
                 assert isinstance(predictions, torch.Tensor)
+                # Reshape the predictions to size([len(self._tasks), self._batch_size])
                 predictions = predictions.view(-1, self._batch_size)
                 assert isinstance(labels, torch.Tensor)
                 labels = labels.view(-1, self._batch_size)
@@ -416,9 +450,19 @@ def _update(
                 else:
                     assert isinstance(weights, torch.Tensor)
                     weights = weights.view(-1, self._batch_size)
-                self._metrics_computations[0].update(
-                    predictions=predictions, labels=labels, weights=weights
-                )
+                # has_valid_weights is a tensor of bools whose length equals the
+                # number of tasks. Each value corresponds to whether the weights
+                # for that task are valid, i.e. set to non-zero values in this
+                # update. If has_valid_weights is False for every task, we simply
+                # ignore this update.
+                has_valid_weights = self._check_nonempty_weights(weights)
+                if torch.any(has_valid_weights):
+                    self._metrics_computations[0].update(
+                        predictions=predictions, labels=labels, weights=weights
+                    )
+                    self._metrics_computations[0].has_valid_update.logical_or_(
+                        has_valid_weights
+                    ).byte()
             else:
                 for task, metric_ in zip(self._tasks, self._metrics_computations):
                     if task.name not in predictions:
@@ -427,17 +471,25 @@ def _update(
                         assert torch.numel(labels[task.name]) == 0
                         assert weights is None or torch.numel(weights[task.name]) == 0
                         continue
+                    # Reshape the predictions to size([1, self._batch_size])
                     task_predictions = predictions[task.name].view(1, -1)
                     task_labels = labels[task.name].view(1, -1)
                     if weights is None:
                         task_weights = self._create_default_weights(task_predictions)
                     else:
                         task_weights = weights[task.name].view(1, -1)
-                    metric_.update(
-                        predictions=task_predictions,
-                        labels=task_labels,
-                        weights=task_weights,
-                    )
+                    # has_valid_weights is a tensor with a single value corresponding
+                    # to whether the weights are valid, i.e. set to non-zero values
+                    # for this task in this update.
+                    # If has_valid_weights[0] is False, we simply ignore this update.
+                    has_valid_weights = self._check_nonempty_weights(task_weights)
+                    if has_valid_weights[0]:
+                        metric_.update(
+                            predictions=task_predictions,
+                            labels=task_labels,
+                            weights=task_weights,
+                        )
+                        metric_.has_valid_update.logical_or_(has_valid_weights).byte()
 
     def update(
         self,
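A note on the dist_reduce_fx declared for has_valid_update above: when metric states are synced across ranks, the stacked per-rank flags are reduced with a logical OR, so a task counts as having valid updates if any rank saw non-zero weights for it. A small sketch of just that reduction (plain torch; the per-rank values are invented):

import torch

# Per-rank has_valid_update flags, one uint8 tensor of shape [n_tasks] per rank.
rank_states = torch.stack([
    torch.tensor([1, 0, 0], dtype=torch.uint8),  # rank 0 updated task 0
    torch.tensor([0, 0, 1], dtype=torch.uint8),  # rank 1 updated task 2
])

# Same reduction as dist_reduce_fx=lambda x: torch.any(x, dim=0).byte()
reduced = torch.any(rank_states, dim=0).byte()
print(reduced)  # tensor([1, 0, 1], dtype=torch.uint8)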
48 changes: 24 additions & 24 deletions torchrec/metrics/tests/test_metric_module.py
@@ -10,7 +10,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Dict, Optional
+from typing import Dict, List, Optional
 from unittest.mock import MagicMock, patch
 
 import torch
@@ -25,20 +25,15 @@
     StateMetricEnum,
 )
 from torchrec.metrics.metrics_config import (
-    _DEFAULT_WINDOW_SIZE,
     DefaultMetricsConfig,
-    EmptyMetricsConfig,
     DefaultTaskInfo,
+    RecMetricEnum,
+    EmptyMetricsConfig,
     RecMetricDef,
+    _DEFAULT_WINDOW_SIZE,
 )
-from torchrec.metrics.model_utils import (
-    parse_task_model_outputs,
-)
-from torchrec.metrics.rec_metric import (
-    RecMetricList,
-    RecTaskInfo,
-    RecMetricEnum,
-)
+from torchrec.metrics.model_utils import parse_task_model_outputs
+from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
 from torchrec.metrics.tests.test_utils import gen_test_batch, get_launch_config
 from torchrec.metrics.throughput import ThroughputMetric
 
@@ -215,7 +210,10 @@ def _run_trainer_checkpointing() -> None:
     state_dict = metric_module.state_dict()
     for k, v in state_dict.items():
         if k.startswith("rec_metrics."):
-            tc.assertEqual(v.item(), value * world_size)
+            if k.endswith("has_valid_update"):
+                tc.assertEqual(v.item(), 1)
+            else:
+                tc.assertEqual(v.item(), value * world_size)
 
     # 2. Test unsync()
     metric_module.unsync()
@@ -373,12 +371,13 @@ def test_ne_memory_usage(self) -> None:
             device=torch.device("cpu"),
         )
         # Default NEMetric's dtype is
-        # float64 (8 bytes) * 16 tensors of size 1 = 128 bytes
+        # float64 (8 bytes) * 16 tensors of size 1 + uint8 (1 byte) * 2 tensors of size 1 = 130 bytes
         # Tensors in NeMetricComputation:
-        # 8 in _default, 8 specific attributes: 4 attributes, 4 window
-        self.assertEqual(metric_module.get_memory_usage(), 128)
+        # NE metric specific attributes: 8 in _default, 8 actual attribute values: 4 attributes, 4 window
+        # RecMetric's has_valid_update attribute: 1 in _default, 1 actual attribute value
+        self.assertEqual(metric_module.get_memory_usage(), 130)
         metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 160)
+        self.assertEqual(metric_module.get_memory_usage(), 162)
 
     def test_calibration_memory_usage(self) -> None:
         mock_optimizer = MockOptimizer()
@@ -400,12 +399,13 @@ def test_calibration_memory_usage(self) -> None:
             device=torch.device("cpu"),
         )
         # Default calibration metric dtype is
-        # float64 (8 bytes) * 8 tensors, size 1 = 64 bytes
+        # float64 (8 bytes) * 8 tensors of size 1 + uint8 (1 byte) * 2 tensors of size 1 = 66 bytes
         # Tensors in CalibrationMetricComputation:
-        # 4 in _default, 4 specific attributes: 2 attribute, 2 window
-        self.assertEqual(metric_module.get_memory_usage(), 64)
+        # Calibration metric attributes: 4 in _default, 4 actual attribute values: 2 attributes, 2 window
+        # RecMetric's has_valid_update attribute: 1 in _default, 1 actual attribute value
+        self.assertEqual(metric_module.get_memory_usage(), 66)
         metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 80)
+        self.assertEqual(metric_module.get_memory_usage(), 82)
 
     def test_auc_memory_usage(self) -> None:
         mock_optimizer = MockOptimizer()
@@ -426,11 +426,11 @@ def test_auc_memory_usage(self) -> None:
             state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
             device=torch.device("cpu"),
         )
-        # 3 (tensors) * 8 (double)
-        self.assertEqual(metric_module.get_memory_usage(), 24)
+        # 3 (tensors) * 8 (double) + 2 (uint8 tensors) * 1 (byte)
+        self.assertEqual(metric_module.get_memory_usage(), 26)
         metric_module.update(gen_test_batch(128))
-        # 24 (initial states) + 3 (tensors) * 128 (batch_size) * 8 (double)
-        self.assertEqual(metric_module.get_memory_usage(), 3096)
+        # 24 (initial states) + 3 (tensors) * 128 (batch_size) * 8 (double) + 2 (uint8 tensors) * 1 (byte)
+        self.assertEqual(metric_module.get_memory_usage(), 3098)
 
     def test_check_memory_usage(self) -> None:
         mock_optimizer = MockOptimizer()
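The expected values in the three memory-usage tests above are plain byte arithmetic over the metric state tensors (float64 = 8 bytes, uint8 = 1 byte); a standalone sanity check of the numbers (a sketch, not part of the test suite):

import torch

F64 = torch.tensor([], dtype=torch.float64).element_size()  # 8 bytes
U8 = torch.tensor([], dtype=torch.uint8).element_size()     # 1 byte

print(16 * F64 + 2 * U8)            # 130 (NE, before update)
print(8 * F64 + 2 * U8)             # 66 (calibration, before update)
print(3 * F64 + 2 * U8)             # 26 (AUC, before update)
print(24 + 3 * 128 * F64 + 2 * U8)  # 3098 (AUC, after one update of batch 128)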
