Ignore the updates when weights are 0s and return the default value #283

Closed
wants to merge 1 commit into from
74 changes: 63 additions & 11 deletions torchrec/metrics/rec_metric.py
@@ -138,6 +138,13 @@ def __init__(
self._batch_window_buffers = {}
else:
self._batch_window_buffers = None
self._add_state(
"has_valid_update",
torch.zeros(self._n_tasks, dtype=torch.uint8),
add_window_state=False,
dist_reduce_fx=lambda x: torch.any(x, dim=0).byte(),
persistent=True,
)
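For reference, a minimal sketch of how this dist_reduce_fx behaves when the has_valid_update states from several ranks are stacked; the rank values below are hypothetical:

import torch

# Hypothetical has_valid_update states gathered from 3 ranks for 2 tasks.
stacked = torch.tensor([[1, 0], [0, 0], [1, 0]], dtype=torch.uint8)
reduced = torch.any(stacked, dim=0).byte()
print(reduced)  # tensor([1, 0], dtype=torch.uint8): task 0 had a valid update on some rank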

@staticmethod
def get_window_state_name(state_name: str) -> str:
@@ -339,16 +346,39 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
for metric_report in getattr(
self._metrics_computations[0], compute_scope + "compute"
)():
for task, metric_value in zip(self._tasks, metric_report.value):
yield task, metric_report.name, metric_value, compute_scope + metric_report.metric_prefix.value
for task, metric_value, has_valid_update in zip(
self._tasks,
metric_report.value,
self._metrics_computations[0].has_valid_update,
):
# The attribute has_valid_update is a tensor whose length equals the
# number of tasks. Each value indicates whether the corresponding task
# has received any valid updates.
# If a task has no valid updates, the computed metric_value is
# meaningless, so we mask it with the default value, i.e. 0.
valid_metric_value = (
metric_value
if has_valid_update > 0
else torch.zeros_like(metric_value)
)
yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value
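As a small illustration of the masking above, with hypothetical values:

import torch

metric_value = torch.tensor(0.75)
has_valid_update = torch.tensor(0, dtype=torch.uint8)  # the task never saw non-zero weights
valid_metric_value = (
    metric_value if has_valid_update > 0 else torch.zeros_like(metric_value)
)
print(valid_metric_value)  # tensor(0.): the meaningless value is replaced by the default 0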

def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
for task, metric_computation in zip(self._tasks, self._metrics_computations):
metric_computation.pre_compute()
for metric_report in getattr(
metric_computation, compute_scope + "compute"
)():
yield task, metric_report.name, metric_report.value, compute_scope + metric_report.metric_prefix.value
# The attribute has_valid_update is a tensor with a single value
# indicating whether the task has received any valid updates.
# If there has been no valid update, the computed metric_report.value
# is meaningless, so we mask it with the default value, i.e. 0.
valid_metric_value = (
metric_report.value
if metric_computation.has_valid_update[0] > 0
else torch.zeros_like(metric_report.value)
)
yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value

def _fuse_update_buffers(self) -> Dict[str, RecModelOutput]:
def fuse(outputs: List[RecModelOutput]) -> RecModelOutput:
@@ -398,6 +428,9 @@ def _create_default_weights(self, predictions: torch.Tensor) -> torch.Tensor:
self._default_weights[predictions.size()] = weights
return weights

def _check_nonempty_weights(self, weights: torch.Tensor) -> torch.Tensor:
return torch.gt(torch.count_nonzero(weights, dim=-1), 0)
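A quick sketch of what this helper returns for a per-task weight matrix, with hypothetical values:

import torch

# Weights for 2 tasks with batch size 4; task 1 has all-zero weights.
weights = torch.tensor([[0.5, 1.0, 0.0, 2.0],
                        [0.0, 0.0, 0.0, 0.0]])
print(torch.count_nonzero(weights, dim=-1))               # tensor([3, 0])
print(torch.gt(torch.count_nonzero(weights, dim=-1), 0))  # tensor([ True, False])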

def _update(
self,
*,
@@ -408,6 +441,7 @@ def _update(
with torch.no_grad():
if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
assert isinstance(predictions, torch.Tensor)
# Reshape the predictions to size([len(self._tasks), self._batch_size])
predictions = predictions.view(-1, self._batch_size)
assert isinstance(labels, torch.Tensor)
labels = labels.view(-1, self._batch_size)
@@ -416,9 +450,19 @@
else:
assert isinstance(weights, torch.Tensor)
weights = weights.view(-1, self._batch_size)
self._metrics_computations[0].update(
predictions=predictions, labels=labels, weights=weights
)
# has_valid_weights is a bool tensor whose length equals the number
# of tasks. Each value indicates whether the weights are valid, i.e.
# contain non-zero values for that task in this update.
# If has_valid_weights is False for every task, we simply ignore this
# update.
has_valid_weights = self._check_nonempty_weights(weights)
if torch.any(has_valid_weights):
self._metrics_computations[0].update(
predictions=predictions, labels=labels, weights=weights
)
self._metrics_computations[0].has_valid_update.logical_or_(
has_valid_weights
).byte()
else:
for task, metric_ in zip(self._tasks, self._metrics_computations):
if task.name not in predictions:
@@ -427,17 +471,25 @@
assert torch.numel(labels[task.name]) == 0
assert weights is None or torch.numel(weights[task.name]) == 0
continue
# Reshape the predictions to size([1, self._batch_size])
task_predictions = predictions[task.name].view(1, -1)
task_labels = labels[task.name].view(1, -1)
if weights is None:
task_weights = self._create_default_weights(task_predictions)
else:
task_weights = weights[task.name].view(1, -1)
metric_.update(
predictions=task_predictions,
labels=task_labels,
weights=task_weights,
)
# has_valid_weights is a tensor with a single value indicating
# whether the weights are valid, i.e. contain non-zero values for
# this task in this update.
# If has_valid_weights[0] is False, we simply ignore this update.
has_valid_weights = self._check_nonempty_weights(task_weights)
if has_valid_weights[0]:
metric_.update(
predictions=task_predictions,
labels=task_labels,
weights=task_weights,
)
metric_.has_valid_update.logical_or_(has_valid_weights).byte()
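To illustrate the gating in this hunk, a minimal sketch of how has_valid_weights drives both the skip decision and the has_valid_update flag; the weights below are hypothetical, and the fused path keeps one flag per task:

import torch

# Fused-task example: task 0 has non-zero weights, task 1 does not.
weights = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
has_valid_weights = torch.gt(torch.count_nonzero(weights, dim=-1), 0)
print(torch.any(has_valid_weights))  # tensor(True): the update is not skipped

has_valid_update = torch.zeros(2, dtype=torch.uint8)
has_valid_update.logical_or_(has_valid_weights)
print(has_valid_update)  # tensor([1, 0], dtype=torch.uint8)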

def update(
self,
48 changes: 24 additions & 24 deletions torchrec/metrics/tests/test_metric_module.py
@@ -10,7 +10,7 @@
import os
import tempfile
import unittest
from typing import List, Dict, Optional
from typing import Dict, List, Optional
from unittest.mock import MagicMock, patch

import torch
@@ -25,20 +25,15 @@
StateMetricEnum,
)
from torchrec.metrics.metrics_config import (
_DEFAULT_WINDOW_SIZE,
DefaultMetricsConfig,
EmptyMetricsConfig,
DefaultTaskInfo,
RecMetricEnum,
EmptyMetricsConfig,
RecMetricDef,
_DEFAULT_WINDOW_SIZE,
)
from torchrec.metrics.model_utils import (
parse_task_model_outputs,
)
from torchrec.metrics.rec_metric import (
RecMetricList,
RecTaskInfo,
RecMetricEnum,
)
from torchrec.metrics.model_utils import parse_task_model_outputs
from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
from torchrec.metrics.tests.test_utils import gen_test_batch, get_launch_config
from torchrec.metrics.throughput import ThroughputMetric

@@ -215,7 +210,10 @@ def _run_trainer_checkpointing() -> None:
state_dict = metric_module.state_dict()
for k, v in state_dict.items():
if k.startswith("rec_metrics."):
tc.assertEqual(v.item(), value * world_size)
if k.endswith("has_valid_update"):
tc.assertEqual(v.item(), 1)
else:
tc.assertEqual(v.item(), value * world_size)

# 2. Test unsync()
metric_module.unsync()
@@ -373,12 +371,13 @@ def test_ne_memory_usage(self) -> None:
device=torch.device("cpu"),
)
# Default NEMetric's dtype is
# float64 (8 bytes) * 16 tensors of size 1 = 128 bytes
# float64 (8 bytes) * 16 tensors of size 1 + uint8 (1 byte) * 2 tensors of size 1 = 130 bytes
# Tensors in NeMetricComputation:
# 8 in _default, 8 specific attributes: 4 attributes, 4 window
self.assertEqual(metric_module.get_memory_usage(), 128)
# NE metric specific attributes: 8 in _default, 8 actual attribute values: 4 attributes, 4 window
# RecMetric's has_valid_update attribute: 1 in _default, 1 actual attribute value
self.assertEqual(metric_module.get_memory_usage(), 130)
metric_module.update(gen_test_batch(128))
self.assertEqual(metric_module.get_memory_usage(), 160)
self.assertEqual(metric_module.get_memory_usage(), 162)
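The byte counts in this test can be reproduced with a small sketch, assuming get_memory_usage simply sums element_size() * numel() over the metric's state tensors:

import torch

# 16 float64 scalars for NE state (8 defaults + 4 attributes + 4 window buffers)
ne_states = [torch.zeros(1, dtype=torch.float64) for _ in range(16)]
# 2 uint8 scalars for has_valid_update (1 default + 1 actual value)
flags = [torch.zeros(1, dtype=torch.uint8) for _ in range(2)]
print(sum(t.element_size() * t.numel() for t in ne_states + flags))  # 130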

def test_calibration_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
@@ -400,12 +399,13 @@ def test_calibration_memory_usage(self) -> None:
device=torch.device("cpu"),
)
# Default calibration metric dtype is
# float64 (8 bytes) * 8 tensors, size 1 = 64 bytes
# float64 (8 bytes) * 8 tensors of size 1 + uint8 (1 byte) * 2 tensors of size 1 = 66 bytes
# Tensors in CalibrationMetricComputation:
# 4 in _default, 4 specific attributes: 2 attribute, 2 window
self.assertEqual(metric_module.get_memory_usage(), 64)
# Calibration metric attributes: 4 in _default, 4 actual attribute values: 2 attribute, 2 window
# RecMetric's has_valid_update attribute: 1 in _default, 1 actual attribute value
self.assertEqual(metric_module.get_memory_usage(), 66)
metric_module.update(gen_test_batch(128))
self.assertEqual(metric_module.get_memory_usage(), 80)
self.assertEqual(metric_module.get_memory_usage(), 82)

def test_auc_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()
@@ -426,11 +426,11 @@ def test_auc_memory_usage(self) -> None:
state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
device=torch.device("cpu"),
)
# 3 (tensors) * 8 (double)
self.assertEqual(metric_module.get_memory_usage(), 24)
# 3 (tensors) * 8 (double) + 1 (tensor) * 2 (uint8)
self.assertEqual(metric_module.get_memory_usage(), 26)
metric_module.update(gen_test_batch(128))
# 24 (initial states) + 3 (tensors) * 128 (batch_size) * 8 (double)
self.assertEqual(metric_module.get_memory_usage(), 3096)
# 24 (initial states) + 3 (tensors) * 128 (batch_size) * 8 (double) + 1 (tensor) * 2 (uint8)
self.assertEqual(metric_module.get_memory_usage(), 3098)
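The AUC numbers follow the same accounting; a sketch of the arithmetic, under the same assumption about how get_memory_usage counts bytes:

initial = 3 * 8 + 1 * 2               # 3 float64 scalars + the uint8 flag = 26
after_update = initial + 3 * 128 * 8  # plus 3 tensors of 128 doubles stored by the update = 3098
print(initial, after_update)          # 26 3098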

def test_check_memory_usage(self) -> None:
mock_optimizer = MockOptimizer()