
Commit 6e782e1

[SPARKNLP-1119] Adding XML reader
1 parent 2b37fb0 commit 6e782e1

11 files changed: +358 -3 lines changed

python/sparknlp/reader/sparknlp_reader.py

Lines changed: 45 additions & 0 deletions
@@ -322,4 +322,49 @@ def txt(self, docPath):
         if not isinstance(docPath, str):
             raise TypeError("docPath must be a string")
         jdf = self._java_obj.txt(docPath)
+        return self.getDataFrame(self.spark, jdf)
+
+    def xml(self, docPath):
+        """Reads XML files and returns a Spark DataFrame.
+
+        Parameters
+        ----------
+        docPath : str
+            Path to an XML file or a directory containing XML files.
+
+        Returns
+        -------
+        pyspark.sql.DataFrame
+            A DataFrame containing parsed XML content.
+
+        Examples
+        --------
+        >>> from sparknlp.reader import SparkNLPReader
+        >>> xml_df = SparkNLPReader(spark).xml("home/user/xml-directory")
+
+        You can use SparkNLP in one line of code:
+
+        >>> import sparknlp
+        >>> xml_df = sparknlp.read().xml("home/user/xml-directory")
+        >>> xml_df.show(truncate=False)
+        +-----------------------------------------------------------+
+        |xml                                                        |
+        +-----------------------------------------------------------+
+        |[{Title, John Smith, {elementId -> ..., tag -> title}}]    |
+        +-----------------------------------------------------------+
+
+        >>> xml_df.printSchema()
+        root
+         |-- path: string (nullable = true)
+         |-- xml: array (nullable = true)
+         |    |-- element: struct (containsNull = true)
+         |    |    |-- elementType: string (nullable = true)
+         |    |    |-- content: string (nullable = true)
+         |    |    |-- metadata: map (nullable = true)
+         |    |    |    |-- key: string
+         |    |    |    |-- value: string (valueContainsNull = true)
+        """
+        if not isinstance(docPath, str):
+            raise TypeError("docPath must be a string")
+        jdf = self._java_obj.xml(docPath)
         return self.getDataFrame(self.spark, jdf)

python/test/sparknlp_test.py

Lines changed: 15 additions & 1 deletion
@@ -125,4 +125,18 @@ def runTest(self):
         txt_df = sparknlp.read().txt(self.txt_file)
         txt_df.show()
 
-        self.assertTrue(txt_df.select("txt").count() > 0)
+        self.assertTrue(txt_df.select("txt").count() > 0)
+
+
+@pytest.mark.fast
+class SparkNLPTestXMLFilesSpec(unittest.TestCase):
+
+    def setUp(self):
+        self.data = SparkContextForTest.data
+        self.xml_files = f"file:///{os.getcwd()}/../src/test/resources/reader/xml"
+
+    def runTest(self):
+        xml_df = sparknlp.read().xml(self.xml_files)
+        xml_df.show()
+
+        self.assertTrue(xml_df.select("xml").count() > 0)
src/main/scala/com/johnsnowlabs/partition/HasXmlReaderProperties.scala

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2017-2025 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.johnsnowlabs.partition
+
+import com.johnsnowlabs.nlp.ParamsAndFeaturesWritable
+import org.apache.spark.ml.param.Param
+
+trait HasXmlReaderProperties extends ParamsAndFeaturesWritable {
+
+  val xmlKeepTags = new Param[Boolean](
+    this,
+    "xmlKeepTags",
+    "Whether to include XML tag names as metadata in the output.")
+
+  def setXmlKeepTags(value: Boolean): this.type = set(xmlKeepTags, value)
+
+  val onlyLeafNodes = new Param[Boolean](
+    this,
+    "onlyLeafNodes",
+    "If true, only processes XML leaf nodes (no nested children).")
+
+  def setOnlyLeafNodes(value: Boolean): this.type = set(onlyLeafNodes, value)
+
+  setDefault(xmlKeepTags -> false, onlyLeafNodes -> true)
+}
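
Downstream, these two flags reach the reader as string-valued entries in the Partition params map (picked up by getXmlKeepTags and getOnlyLeafNodes in SparkNLPReader, further below). A minimal sketch of overriding the defaults that way, assuming Partition forwards its params to SparkNLPReader as it does for the other readers; the directory path reuses the test resources location and is otherwise illustrative:

import com.johnsnowlabs.partition.Partition

// Sketch only: keep tag names in metadata and also emit non-leaf nodes,
// overriding the defaults (xmlKeepTags = false, onlyLeafNodes = true).
val xmlOptions = Map(
  "content_type" -> "application/xml",
  "xmlKeepTags" -> "true",
  "onlyLeafNodes" -> "false")

val xmlDf = Partition(xmlOptions).partition("src/test/resources/reader/xml")
xmlDf.select("xml").show(truncate = false)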

src/main/scala/com/johnsnowlabs/partition/Partition.scala

Lines changed: 3 additions & 0 deletions
@@ -188,6 +188,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
       "application/vnd.openxmlformats-officedocument.presentationml.presentation" =>
         sparkNLPReader.ppt
       case "application/pdf" => sparkNLPReader.pdf
+      case "application/xml" => sparkNLPReader.xml
       case _ => throw new IllegalArgumentException(s"Unsupported content type: $contentType")
     }
   }
@@ -199,6 +200,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
       case "text/plain" => sparkNLPReader.txtToHTMLElement
       case "text/html" => sparkNLPReader.htmlToHTMLElement
       case "url" => sparkNLPReader.urlToHTMLElement
+      case "application/xml" => sparkNLPReader.xmlToHTMLElement
       case _ => throw new IllegalArgumentException(s"Unsupported content type: $contentType")
     }
   }
@@ -234,6 +236,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
       case "xls" | "xlsx" => sparkNLPReader.xls
       case "ppt" | "pptx" => sparkNLPReader.ppt
       case "pdf" => sparkNLPReader.pdf
+      case "xml" => sparkNLPReader.xml
       case _ => throw new IllegalArgumentException(s"Unsupported file type: $extension")
     }
   }
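
With these three cases in place, XML is reachable through the same entry points as the other formats: an explicit application/xml content type (for files or raw string content), or plain extension detection for .xml files. A small sketch of the extension-based path; the file name is hypothetical:

import com.johnsnowlabs.partition.Partition

// No content_type given: the ".xml" extension selects sparkNLPReader.xml,
// mirroring the new case in the extension match above.
val xmlDf = Partition(Map.empty[String, String])
  .partition("src/test/resources/reader/xml/books.xml")
xmlDf.show(truncate = false)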

src/main/scala/com/johnsnowlabs/partition/PartitionTransformer.scala

Lines changed: 4 additions & 1 deletion
@@ -86,6 +86,7 @@ class PartitionTransformer(override val uid: String)
     with HasPowerPointProperties
     with HasTextReaderProperties
     with HasPdfProperties
+    with HasXmlReaderProperties
     with HasChunkerProperties {

   def this() = this(Identifiable.randomUID("PartitionTransformer"))
@@ -157,7 +158,9 @@ class PartitionTransformer(override val uid: String)
       "newAfterNChars" -> $(newAfterNChars).toString,
       "overlap" -> $(overlap).toString,
       "combineTextUnderNChars" -> $(combineTextUnderNChars).toString,
-      "overlapAll" -> $(overlapAll).toString)
+      "overlapAll" -> $(overlapAll).toString,
+      "xmlKeepTags" -> $(xmlKeepTags).toString,
+      "onlyLeafNodes" -> $(onlyLeafNodes).toString)
     val partitionInstance = new Partition(params.asJava)

     val inputColum = if (get(inputCols).isDefined) {
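
Because PartitionTransformer now mixes in HasXmlReaderProperties, the XML options can be set on the transformer itself and are forwarded to Partition through the params map built above. A minimal sketch showing only the setters introduced by this commit; the usual input/output column wiring is omitted:

import com.johnsnowlabs.partition.PartitionTransformer

val partitionXml = new PartitionTransformer()
  .setXmlKeepTags(true)    // forwarded as "xmlKeepTags"
  .setOnlyLeafNodes(false) // forwarded as "onlyLeafNodes"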

src/main/scala/com/johnsnowlabs/reader/SparkNLPReader.scala

Lines changed: 65 additions & 1 deletion
@@ -296,7 +296,6 @@ class SparkNLPReader(
    * |-- width_dimension: integer (nullable = true)
    * |-- content: binary (nullable = true)
    * |-- exception: string (nullable = true)
-   * |-- pagenum: integer (nullable = true)
    * }}}
    *
    * @param params
@@ -642,4 +641,69 @@ class SparkNLPReader(
       default = BLOCK_SPLIT_PATTERN)
   }

+  /** Instantiates class to read XML files.
+    *
+    * xmlPath: this is a path to a directory of XML files or a path to an XML file. E.g.,
+    * "path/xml/files"
+    *
+    * ==Example==
+    * {{{
+    * val xmlPath = "home/user/xml-directory"
+    * val sparkNLPReader = new SparkNLPReader()
+    * val xmlDf = sparkNLPReader.xml(xmlPath)
+    * }}}
+    *
+    * ==Example 2==
+    * You can use SparkNLP in one line of code:
+    * {{{
+    * val xmlDf = SparkNLP.read.xml(xmlPath)
+    * }}}
+    *
+    * {{{
+    * xmlDf.select("xml").show(false)
+    * +------------------------------------------------------------------------------------------------------------------+
+    * |xml                                                                                                               |
+    * +------------------------------------------------------------------------------------------------------------------+
+    * |[{Title, John Smith, {elementId -> ..., tag -> title}}, {UncategorizedText, Some content..., {elementId -> ...}}] |
+    * +------------------------------------------------------------------------------------------------------------------+
+    *
+    * xmlDf.printSchema()
+    * root
+    *  |-- path: string (nullable = true)
+    *  |-- xml: array (nullable = true)
+    *  |    |-- element: struct (containsNull = true)
+    *  |    |    |-- elementType: string (nullable = true)
+    *  |    |    |-- content: string (nullable = true)
+    *  |    |    |-- metadata: map (nullable = true)
+    *  |    |    |    |-- key: string
+    *  |    |    |    |-- value: string (valueContainsNull = true)
+    * }}}
+    *
+    * @param xmlPath
+    *   Path to the XML file or directory
+    * @return
+    *   A DataFrame with parsed XML as structured elements
+    */
+  def xml(xmlPath: String): DataFrame = {
+    val xmlReader = new XMLReader(getStoreContent, getXmlKeepTags, getOnlyLeafNodes)
+    xmlReader.read(xmlPath)
+  }
+
+  def xmlToHTMLElement(xml: String): Seq[HTMLElement] = {
+    val xmlReader = new XMLReader(getStoreContent, getXmlKeepTags, getOnlyLeafNodes)
+    xmlReader.parseXml(xml)
+  }
+
+  private def getXmlKeepTags: Boolean = {
+    getDefaultBoolean(params.asScala.toMap, Seq("xmlKeepTags", "xml_keep_tags"), default = false)
+  }
+
+  private def getOnlyLeafNodes: Boolean = {
+    getDefaultBoolean(
+      params.asScala.toMap,
+      Seq("onlyLeafNodes", "only_leaf_nodes"),
+      default = true)
+  }
+
 }
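
Besides the file-based xml() entry point, xmlToHTMLElement parses an in-memory XML string straight into HTMLElement objects; this is the hook Partition uses for string content. A short sketch with an illustrative snippet, assuming the default reader settings:

import com.johnsnowlabs.reader.SparkNLPReader

val reader = new SparkNLPReader()
val elements = reader.xmlToHTMLElement(
  """<book>
    |  <title>The Alchemist</title>
    |  <author>Paulo Coelho</author>
    |</book>""".stripMargin)

// With the defaults (onlyLeafNodes = true, xmlKeepTags = false), the <book>
// wrapper is skipped and two Title elements are returned, each carrying an
// elementId and its parent's id in the metadata map.
elements.foreach(println)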
src/main/scala/com/johnsnowlabs/reader/XMLReader.scala

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2017-2025 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.johnsnowlabs.reader
+
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import com.johnsnowlabs.nlp.util.io.ResourceHelper.validFile
+import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, udf}
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.xml.{Elem, Node, XML}
+
+class XMLReader(
+    storeContent: Boolean = false,
+    xmlKeepTags: Boolean = false,
+    onlyLeafNodes: Boolean = true)
+    extends Serializable {
+
+  private lazy val spark = ResourceHelper.spark
+
+  private var outputColumn = "xml"
+
+  def setOutputColumn(value: String): this.type = {
+    require(value.nonEmpty, "Output column name cannot be empty.")
+    outputColumn = value
+    this
+  }
+
+  def read(inputSource: String): DataFrame = {
+    if (validFile(inputSource)) {
+      val xmlDf = datasetWithTextFile(spark, inputSource)
+        .withColumn(outputColumn, parseXmlUDF(col("content")))
+      if (storeContent) xmlDf.select("path", "content", outputColumn)
+      else xmlDf.select("path", outputColumn)
+    } else throw new IllegalArgumentException(s"Invalid inputSource: $inputSource")
+  }
+
+  private val parseXmlUDF = udf((xml: String) => {
+    parseXml(xml)
+  })
+
+  def parseXml(xmlString: String): List[HTMLElement] = {
+    val xml = XML.loadString(xmlString)
+    val elements = ListBuffer[HTMLElement]()
+
+    def traverse(node: Node, parentId: Option[String]): Unit = {
+      node match {
+        case elem: Elem =>
+          val tagName = elem.label.toLowerCase
+          val textContent = elem.text.trim
+          val elementId = hash(tagName + textContent)
+
+          val isLeaf = !elem.child.exists(_.isInstanceOf[Elem])
+
+          if (!onlyLeafNodes || isLeaf) {
+            val elementType = tagName match {
+              case "title" | "author" => ElementType.TITLE
+              case _ => ElementType.UNCATEGORIZED_TEXT
+            }
+
+            val metadata = mutable.Map[String, String]("elementId" -> elementId)
+            if (xmlKeepTags) metadata += ("tag" -> tagName)
+            parentId.foreach(id => metadata += ("parentId" -> id))
+
+            val content = if (isLeaf) textContent else ""
+            elements += HTMLElement(elementType, content, metadata)
+          }
+
+          // Traverse children
+          elem.child.foreach(traverse(_, Some(elementId)))
+
+        case _ => // Ignore other types
+      }
+    }
+
+    traverse(xml, None)
+    elements.toList
+  }
+
+  def hash(s: String): String = {
+    java.security.MessageDigest
+      .getInstance("MD5")
+      .digest(s.getBytes)
+      .map("%02x".format(_))
+      .mkString
+  }
+
+}
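
The reader also works standalone, outside Partition and SparkNLPReader. A minimal sketch using the constructor arguments and setOutputColumn defined above; the directory path points at the test resources and is otherwise illustrative:

import com.johnsnowlabs.reader.XMLReader

val xmlReader = new XMLReader(storeContent = true, xmlKeepTags = true, onlyLeafNodes = false)
  .setOutputColumn("xml_elements")

// With storeContent = true the result keeps "path", "content" and the
// configured output column holding the parsed elements.
val df = xmlReader.read("src/test/resources/reader/xml")
df.printSchema()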
New XML test fixture under src/test/resources/reader/xml

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+<library>
+  <section name="Fiction">
+    <shelf number="1">
+      <book>
+        <title>The Alchemist</title>
+        <author>Paulo Coelho</author>
+        <year>1988</year>
+      </book>
+    </shelf>
+  </section>
+  <section name="Science">
+    <shelf number="2">
+      <book>
+        <title>A Brief History of Time</title>
+        <author>Stephen Hawking</author>
+        <year>1988</year>
+      </book>
+    </shelf>
+  </section>
+</library>
New XML test fixture under src/test/resources/reader/xml

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+<bookstore>
+  <book category="children">
+    <title lang="en">Harry Potter</title>
+    <author>J K. Rowling</author>
+    <year>2005</year>
+    <price>29.99</price>
+  </book>
+  <book category="web">
+    <title lang="en">Learning XML</title>
+    <author>Erik T. Ray</author>
+    <year>2003</year>
+    <price>39.95</price>
+  </book>
+</bookstore>
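
For orientation, this is roughly what XMLReader's defaults produce for fixtures like the two above: only leaf elements are emitted (onlyLeafNodes = true), title and author map to Title, everything else to UncategorizedText. A hedged sketch of that expectation against a one-book excerpt:

import com.johnsnowlabs.reader.XMLReader

// A one-book excerpt of the bookstore fixture above, inlined as a String.
val bookXml =
  """<bookstore>
    |  <book category="children">
    |    <title lang="en">Harry Potter</title>
    |    <author>J K. Rowling</author>
    |    <year>2005</year>
    |    <price>29.99</price>
    |  </book>
    |</bookstore>""".stripMargin

val elements = new XMLReader().parseXml(bookXml)

// With the defaults (onlyLeafNodes = true, xmlKeepTags = false) only the four
// leaf elements are emitted: title and author as Title, year and price as
// UncategorizedText, each with an elementId and its parent <book>'s id.
assert(elements.length == 4)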

src/test/scala/com/johnsnowlabs/partition/PartitionTest.scala

Lines changed: 8 additions & 0 deletions
@@ -32,6 +32,7 @@ class PartitionTest extends AnyFlatSpec {
   val emailDirectory = "src/test/resources/reader/email"
   val htmlDirectory = "src/test/resources/reader/html"
   val pdfDirectory = "src/test/resources/reader/pdf"
+  val xmlDirectory = "src/test/resources/reader/xml"

   "Partition" should "work with text content_type" taggedAs FastTest in {
     val textDf = Partition(Map("content_type" -> "text/plain")).partition(txtDirectory)
@@ -181,4 +182,11 @@ class PartitionTest extends AnyFlatSpec {
     assert(elements == expectedElements)
   }

+  it should "work with XML content_type" taggedAs FastTest in {
+    val xmlDf = Partition(Map("content_type" -> "application/xml")).partition(xmlDirectory)
+    xmlDf.show()
+
+    assert(!xmlDf.select(col("xml")).isEmpty)
+  }
+
 }
