Skip to content

Commit

Permalink
add list of int, float feature in TFRecordSampleWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
justHungryMan committed Nov 11, 2022
1 parent c6df199 commit e2b5138
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions img2dataset/writer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
""""writer module handle writing the images to disk"""

import webdataset as wds
import json
import pyarrow.parquet as pq
import pyarrow as pa
import fsspec
import os

import fsspec
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import webdataset as wds


class BufferedParquetWriter:
Expand Down Expand Up @@ -150,15 +151,10 @@ def __init__(
try:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow_io as _ # pylint: disable=import-outside-toplevel
from tensorflow.python.lib.io.tf_record import \
TFRecordWriter # pylint: disable=import-outside-toplevel
from tensorflow.python.training.training import ( # pylint: disable=import-outside-toplevel
BytesList,
Int64List,
FloatList,
Example,
Features,
Feature,
)
from tensorflow.python.lib.io.tf_record import TFRecordWriter # pylint: disable=import-outside-toplevel
BytesList, Example, Feature, Features, FloatList, Int64List)

self._BytesList = BytesList # pylint: disable=invalid-name
self._Int64List = Int64List # pylint: disable=invalid-name
Expand Down Expand Up @@ -203,7 +199,9 @@ def close(self):

def _feature(self, value):
"""Convert to proper feature type"""
if isinstance(value, int):
if isinstance(value, list):
return self._list_feature(value)
elif isinstance(value, int):
return self._int64_feature(value)
elif isinstance(value, float):
return self._float_feature(value)
Expand All @@ -226,6 +224,18 @@ def _int64_feature(self, value):
"""Returns an int64_list from a bool / enum / int / uint."""
return self._Feature(int64_list=self._Int64List(value=[value]))

def _list_feature(self, value):
"""Returns an int64_list from a bool / enum / int / uint or float_list from a float / double."""
if isinstance(value[0], int):
return self._Feature(int64_list=self._Int64List(value=value))
elif isinstance(value[0], float):
return self._Feature(float_list=self._FloatList(value=value))
else:
raise NotImplementedError(
"list feature can only have int64 or floats"
)



class FilesSampleWriter:
"""FilesSampleWriter is a caption+image writer to files"""
Expand Down

0 comments on commit e2b5138

Please sign in to comment.