I am running the TFRecord converter example using spark-connector_2.11-1.10.0.jar (pyspark --jars ~/spark-connector_2.11-1.10.0.jar), and it works. When I add a string vector field (an array of strings) called 'StrVectorCol', I get the exception below. The Spark version is 2.4.0-cdh6.1.1.
hadoo-002, executor 1): org.apache.spark.SparkException: Task failed while writing rows
at org.apache.spark.internal.io.SparkHadoopWriter$.org$apache$spark$internal$io$SparkHadoopWriter$$executeTask(SparkHadoopWriter.scala:155)
at org.apache.spark.internal.io.SparkHadoopWriter$$anonfun$3.apply(SparkHadoopWriter.scala:83)
at org.apache.spark.internal.io.SparkHadoopWriter$$anonfun$3.apply(SparkHadoopWriter.scala:78)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$11.apply(Executor.scala:407)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:413)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.ClassCastException: java.lang.String cannot be cast to org.apache.spark.unsafe.types.UTF8String
at org.apache.spark.sql.catalyst.util.GenericArrayData.getUTF8String(GenericArrayData.scala:75)
at org.apache.spark.sql.catalyst.InternalRow$$anonfun$getAccessor$8.apply(InternalRow.scala:136)
at org.apache.spark.sql.catalyst.InternalRow$$anonfun$getAccessor$8.apply(InternalRow.scala:136)
at org.apache.spark.sql.catalyst.util.ArrayData.toArray(ArrayData.scala:178)
at org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder$.org$tensorflow$spark$datasources$tfrecords$serde$DefaultTfRecordRowEncoder$$encodeFeature(DefaultTfRecordRowEncoder.scala:132)
at org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder$$anonfun$encodeExample$1.apply(DefaultTfRecordRowEncoder.scala:64)
at org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder$$anonfun$encodeExample$1.apply(DefaultTfRecordRowEncoder.scala:61)
at scala.collection.immutable.List.foreach(List.scala:392)
at org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder$.encodeExample(DefaultTfRecordRowEncoder.scala:61)
at org.tensorflow.spark.datasources.tfrecords.DefaultSource$$anonfun$2.apply(DefaultSource.scala:59)
at org.tensorflow.spark.datasources.tfrecords.DefaultSource$$anonfun$2.apply(DefaultSource.scala:56)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
at org.apache.spark.internal.io.SparkHadoopWriter$$anonfun$4.apply(SparkHadoopWriter.scala:129)
at org.apache.spark.internal.io.SparkHadoopWriter$$anonfun$4.apply(SparkHadoopWriter.scala:127)
at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1394)
at org.apache.spark.internal.io.SparkHadoopWriter$.org$apache$spark$internal$io$SparkHadoopWriter$$executeTask(SparkHadoopWriter.scala:139)
... 10 more
Here is the solution I found: cast the string columns to BinaryType (byte arrays) before writing. I used this function just before writing to TFRecord:
from pyspark.sql import functions as fn
from pyspark.sql.types import BinaryType, StringType


def strings_to_binary(df):
    """
    This function casts all StringType columns in a Spark DataFrame to BinaryType.
    The Spark-Tensorflow connector does not accept Array(StringType) columns when writing TFRecords.
    It expects Array(BinaryType), hence the need to cast string columns to binary before collecting them into lists
    and writing to TFRecords.
    """
    for col_name in df.columns:
        # Only top-level StringType columns are cast; columns that are already arrays are left unchanged.
        if isinstance(df.schema[col_name].dataType, StringType):
            df = df.withColumn(col_name, fn.col(col_name).cast(BinaryType()))
    return df
I kept the columns as strings so I could manipulate them with functions that expected StringType and converted them to BinaryType just before writing.