Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 113 additions & 10 deletions tensorflow_io/core/kernels/image_dicom_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"
#include "absl/strings/str_split.h"
#include "absl/strings/numbers.h"

// clang-format on

Expand Down Expand Up @@ -319,7 +321,6 @@ class DecodeDICOMDataOp : public OpKernel {

const Tensor *in_tags;
OP_REQUIRES_OK(context, context->input("tags", &in_tags));
auto in_tags_flat = in_tags->flat<uint32>();

// Create an output tensor
Tensor *out_tag_values = NULL;
Expand All @@ -337,26 +338,128 @@ class DecodeDICOMDataOp : public OpKernel {
OFCondition cond = dfile.read(dataBuf);
dfile.transferEnd();

DcmDataset *dset = dfile.getDataset();
DcmItem *item = static_cast<DcmItem *>(dfile.getDataset());
DcmMetaInfo *meta = dfile.getMetaInfo();

for (unsigned int tag_i = 0; tag_i < in_tags_flat.size(); ++tag_i) {
uint32 tag_value = in_tags_flat(tag_i);
uint16 tag_group_number = (uint16)((tag_value & 0xFFFF0000) >> 16);
uint16 tag_element_number = (uint16)((tag_value & 0x0000FFFF) >> 0);
DcmTag tag(tag_group_number, tag_element_number);
for (int64 tag_i = 0; tag_i < in_tags->NumElements(); ++tag_i) {
DcmTag tag;
if (in_tags->dtype() == DT_STRING) {
uint16 tag_group_number, tag_element_number;
const auto &in_tags_flat = in_tags->flat<tstring>();
const tstring &tag_sequence = in_tags_flat(tag_i);
std::vector<absl::string_view> tag_sequence_views;
if (tag_sequence[0] == '[' &&
tag_sequence[tag_sequence.size() - 1] == ']') {
tag_sequence_views =
absl::StrSplit(absl::string_view(tag_sequence.data() + 1,
tag_sequence.size() - 2),
"][");
} else {
tag_sequence_views.push_back(
absl::string_view(tag_sequence.data(), tag_sequence.size()));
}

OP_REQUIRES(
context, (tag_sequence_views.size() % 2 == 1),
errors::InvalidArgument(
"tag sequences should have 2xn + 1 elements, received: ",
tag_sequence_views.size()));

// Walk through before the last element of value
for (size_t i = 0; i < tag_sequence_views.size() - 1; i += 2) {
absl::string_view tag_value(tag_sequence_views[i].data(),
tag_sequence_views[i].size());
OP_REQUIRES_OK(context, GetDcmTag(tag_value, &tag));

absl::string_view number_value(tag_sequence_views[i + 1].data(),
tag_sequence_views[i + 1].size());
uint32 number = 0;
OP_REQUIRES(
context,
absl::numbers_internal::safe_strtou32_base(number_value, &number,
0),
errors::InvalidArgument("number should be an integer, received ",
number_value));

DcmItem *lookup;
OFCondition condition =
item->findAndGetSequenceItem(tag, lookup, number);
OP_REQUIRES(context, condition.good(),
errors::InvalidArgument("item findAndGetSequenceItem: ",
condition.text()));
item = lookup;
}

// The last element of value
absl::string_view tag_value(tag_sequence_views.back().data(),
tag_sequence_views.back().size());
OP_REQUIRES_OK(context, GetDcmTag(tag_value, &tag));
} else {
const auto &in_tags_flat = in_tags->flat<uint32>();
uint32 tag_value = in_tags_flat(tag_i);
OP_REQUIRES_OK(context, GetDcmTag(tag_value, &tag));
}

OFString val;
if (dset->tagExists(tag)) {
dset->findAndGetOFStringArray(tag, val);
if (item->tagExists(tag)) {
OFCondition condition = item->findAndGetOFStringArray(tag, val);
OP_REQUIRES(context, condition.good(),
errors::InvalidArgument("item findAndGetOFStringArray: ",
condition.text()));
} else if (meta->tagExists(tag)) {
meta->findAndGetOFStringArray(tag, val);
OFCondition condition = meta->findAndGetOFStringArray(tag, val);
OP_REQUIRES(context, condition.good(),
errors::InvalidArgument("meta findAndGetOFStringArray: ",
condition.text()));
} else {
val = OFString("");
}
out_tag_values_flat(tag_i) = val.c_str();
}
}

private:
Status GetDcmTag(const uint32 tag_value, DcmTag *tag) {
uint16 tag_group_number = (uint16)((tag_value & 0xFFFF0000) >> 16);
uint16 tag_element_number = (uint16)((tag_value & 0x0000FFFF) >> 0);
*tag = DcmTag(tag_group_number, tag_element_number);
return Status::OK();
}
Status GetDcmTag(const absl::string_view tag_value, DcmTag *tag) {
std::vector<absl::string_view> number_views =
absl::StrSplit(tag_value, ',');
if (number_views.size() != 2) {
return errors::InvalidArgument(
"sequence should consist of group and "
"element numbers, received ",
tag_value);
}
uint32 number = 0;
if (!absl::numbers_internal::safe_strtou32_base(number_views[0], &number,
0)) {
return errors::InvalidArgument(
"group number should be an integer, received ", number_views[0]);
}
if (number > std::numeric_limits<short>::max()) {
return errors::InvalidArgument("group number should be uint16, received ",
number_views[0]);
}
uint16 tag_group_number = number;

if (!absl::numbers_internal::safe_strtou32_base(number_views[1], &number,
0)) {
return errors::InvalidArgument(
"element number should be an integer, received ", number_views[1]);
}
if (number > std::numeric_limits<short>::max()) {
return errors::InvalidArgument(
"element number should be uint16, received ", number_views[1]);
}
uint16 tag_element_number = number;

*tag = DcmTag(tag_group_number, tag_element_number);
return Status::OK();
}
};

// Register the CPU kernels.
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_io/core/ops/image_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ loads a dicom image file and returns its pixel information in the specified outp

REGISTER_OP("IO>DecodeDICOMData")
.Input("contents: string")
.Input("tags: uint32")
.Input("tags: dtype")
.Attr("dtype: {uint32,string}")
.Output("tag_values: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(1));
Expand Down
Loading