Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tensorboard] Fix TensorBoard summary encoding for torch.bfloat16 tensors #108351

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 31 additions & 3 deletions test/test_tensorboard.py
@@ -1,13 +1,13 @@
# Owner(s): ["module: unknown"]

import expecttest
import io
import numpy as np
import os
import shutil
import sys
import tempfile
import unittest
import expecttest

TEST_TENSORBOARD = True
try:
Expand Down Expand Up @@ -43,7 +43,14 @@
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ASAN, TEST_WITH_CROSSREF
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
TestCase,
run_tests,
TEST_WITH_ASAN,
TEST_WITH_CROSSREF,
)

def tensor_N(shape, dtype=float):
numel = np.prod(shape)
Expand Down Expand Up @@ -80,7 +87,7 @@ def tearDown(self):
from torch.utils.tensorboard import summary, SummaryWriter
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
from tensorboard.compat.proto.types_pb2 import DataType
from torch.utils.tensorboard.summary import tensor_proto
from torch.utils.tensorboard.summary import int_to_half, tensor_proto
from torch.utils.tensorboard._convert_np import make_np
from torch.utils.tensorboard._pytorch_graph import graph
from google.protobuf import text_format
Expand Down Expand Up @@ -865,6 +872,25 @@ def test_caffe2_simple_cnnmodel(self):
compare_proto(graph, self)

class TestTensorProtoSummary(BaseTestCase):
@parametrize(
    "tensor_type,proto_type",
    [
        (torch.float16, DataType.DT_HALF),
        (torch.bfloat16, DataType.DT_BFLOAT16),
    ],
)
def test_half_tensor_proto(self, tensor_type, proto_type):
    """Check that 16-bit float tensors get the right dtype tag and half_val payload.

    Builds a summary for a small tensor of ``tensor_type``, decodes the
    integers stored in ``half_val`` with ``int_to_half`` and compares them
    with the input values, then checks the proto's dtype against
    ``proto_type``.
    """
    float_values = [1.0, 2.0, 3.0]
    summary_proto = tensor_proto(
        "dummy",
        torch.tensor(float_values, dtype=tensor_type),
    )
    actual_proto = summary_proto.value[0].tensor
    # half_val stores integers, so decode them before comparing with the floats.
    decoded_values = [int_to_half(x) for x in actual_proto.half_val]
    self.assertSequenceEqual(decoded_values, float_values)
    # assertEqual reports both operands on failure, unlike assertTrue(a == b).
    self.assertEqual(actual_proto.dtype, proto_type)

def test_float_tensor_proto(self):
float_values = [1.0, 2.0, 3.0]
actual_proto = (
Expand Down Expand Up @@ -902,5 +928,7 @@ def test_empty_tensor_proto(self):
actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
self.assertEqual(actual_proto.float_val, [])

instantiate_parametrized_tests(TestTensorProtoSummary)

if __name__ == '__main__':
run_tests()
122 changes: 71 additions & 51 deletions torch/utils/tensorboard/summary.py
@@ -1,7 +1,9 @@
import json
import logging
import os
from typing import Optional
import struct

from typing import Any, List, Optional

import torch
import numpy as np
Expand All @@ -23,6 +25,8 @@
from ._utils import _prepare_video, convert_to_HWC

__all__ = [
"half_to_int",
"int_to_half",
"hparams",
"scalar",
"histogram_raw",
Expand All @@ -46,36 +50,66 @@

logger = logging.getLogger(__name__)

def half_to_int(f: float) -> int:
    """Reinterpret the bytes of a float as an integer.

    ``f`` is packed with the native ``struct`` format ``"f"`` and the
    resulting bytes are read back with the native format ``"i"``. The integer
    obtained this way is what gets written to the ``half_val`` field of a
    TensorProto for elements of half-precision tensors such as `torch.half`
    or `torch.bfloat16`.

    Use int_to_half() to reverse the conversion.

    """
    (bits,) = struct.unpack("i", struct.pack("f", f))
    return bits

def int_to_half(i: int) -> float:
    """Reinterpret the bytes of an integer as a float.

    Inverse of half_to_int(): ``i`` is packed with the native ``struct``
    format ``"i"`` and the resulting bytes are read back with the native
    format ``"f"``, recovering the floating point value that was encoded.

    """
    (value,) = struct.unpack("f", struct.pack("i", i))
    return value

def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    """Encode each element of ``t``, in flattened order, with half_to_int."""
    flat_values = t.flatten().tolist()
    return list(map(half_to_int, flat_values))

def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
return torch.view_as_real(t).flatten().tolist()

def _tensor_to_list(t: torch.Tensor) -> List[Any]:
return t.flatten().tolist()

# Type map: torch dtype -> (TensorProto dtype name, TensorProto value field,
# function that converts a tensor into the list stored in that field).
# tensor_proto() unpacks every entry into exactly these three items, so each
# value must be a 3-tuple and each key should appear only once: in a dict
# literal a repeated key is silently overridden by its last occurrence.
# Alias spellings (e.g. torch.half / torch.float16) are listed side by side
# for readability and always carry identical values.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    # NOTE(review): uint8 is stored in uint32_val, which is the value that
    # previously won among duplicate torch.uint8 keys. TensorProto's field
    # comments appear to list DT_UINT8 under int_val -- confirm which field
    # is intended before changing it.
    torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    # NOTE(review): qint8 is tagged DT_UINT8 here -- confirm this is intended.
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
}


Expand Down Expand Up @@ -375,35 +409,21 @@ def tensor_proto(tag, tensor):
)

if tensor.dtype in _TENSOR_TYPE_MAP:
proto_val_field = _TENSOR_TYPE_MAP[tensor.dtype][1]

if proto_val_field == "scomplex_val" or proto_val_field == "dcomplex_val":
proto_val_contents = torch.view_as_real(tensor).flatten().tolist()
elif tensor.numel() == 1:
proto_val_contents = [tensor.item()]
elif tensor.numel() == 0:
proto_val_contents = []
else:
proto_val_contents = tensor.flatten().tolist()

tensor_proto_args = {
"dtype": _TENSOR_TYPE_MAP[tensor.dtype][0],
"tensor_shape": TensorShapeProto(
dim=[
TensorShapeProto.Dim(size=tensor.shape[i])
for i in range(tensor.dim())
]
),
proto_val_field: proto_val_contents,
}

tensor_proto = TensorProto(**tensor_proto_args)
dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
tensor_proto = TensorProto(
**{
"dtype": dtype,
"tensor_shape": TensorShapeProto(
dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
),
field_name: conversion_fn(tensor),
},
)
else:
raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")

plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
smd = SummaryMetadata(plugin_data=plugin_data)

return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)])


Expand Down