Skip to content

Commit 3e0aa61

Browse files
danilojslDevinTDHa
authored andcommitted
[SPARKNLP-1163] Adding title chunking strategy (#14594)
1 parent 2aae5a5 commit 3e0aa61

File tree

11 files changed

+324
-15
lines changed

11 files changed

+324
-15
lines changed

python/sparknlp/partition/partition_properties.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def setThreshold(self, value):
256256
def getThreshold(self):
257257
return self.getOrDefault(self.threshold)
258258

259-
class HasSemanticChunkerProperties(Params):
259+
class HasChunkerProperties(Params):
260260

261261
chunkingStrategy = Param(
262262
Params._dummy(),
@@ -296,4 +296,24 @@ def setNewAfterNChars(self, value):
296296
)
297297

298298
def setOverlap(self, value):
299-
return self._set(overlap=value)
299+
return self._set(overlap=value)
300+
301+
combineTextUnderNChars = Param(
302+
Params._dummy(),
303+
"combineTextUnderNChars",
304+
"Threshold to merge adjacent small sections",
305+
typeConverter=TypeConverters.toInt
306+
)
307+
308+
def setCombineTextUnderNChars(self, value):
309+
return self._set(combineTextUnderNChars=value)
310+
311+
overlapAll = Param(
312+
Params._dummy(),
313+
"overlapAll",
314+
"Apply overlap context between all sections, not just split chunks",
315+
typeConverter=TypeConverters.toBoolean
316+
)
317+
318+
def setOverlapAll(self, value):
319+
return self._set(overlapAll=value)

python/sparknlp/partition/partition_transformer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class PartitionTransformer(
2323
HasHTMLReaderProperties,
2424
HasPowerPointProperties,
2525
HasTextReaderProperties,
26-
HasSemanticChunkerProperties
26+
HasChunkerProperties
2727
):
2828
"""
2929
The PartitionTransformer annotator allows you to use the Partition feature more smoothly
@@ -194,5 +194,7 @@ def __init__(self, classname="com.johnsnowlabs.partition.PartitionTransformer",
194194
chunkingStrategy="",
195195
maxCharacters=100,
196196
newAfterNChars=-1,
197-
overlap=0
197+
overlap=0,
198+
combineTextUnderNChars=0,
199+
overlapAll=False
198200
)

src/main/scala/com/johnsnowlabs/partition/BasicChunker.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ import com.johnsnowlabs.reader.HTMLElement
1919

2020
import scala.collection.mutable
2121

22-
case class Chunk(elements: List[HTMLElement]) {
23-
def length: Int = elements.map(_.content.length).sum
24-
}
25-
2622
object BasicChunker {
2723

2824
/** Splits a list of [[HTMLElement]]s into chunks constrained by a maximum number of characters.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.johnsnowlabs.partition
2+
3+
import com.johnsnowlabs.reader.HTMLElement
4+
5+
case class Chunk(elements: List[HTMLElement]) {
6+
def length: Int = elements.map(_.content.length).sum
7+
}

src/main/scala/com/johnsnowlabs/partition/HasSemanticChunkerProperties.scala renamed to src/main/scala/com/johnsnowlabs/partition/HasChunkerProperties.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package com.johnsnowlabs.partition
1818
import com.johnsnowlabs.nlp.ParamsAndFeaturesWritable
1919
import org.apache.spark.ml.param.Param
2020

21-
trait HasSemanticChunkerProperties extends ParamsAndFeaturesWritable {
21+
trait HasChunkerProperties extends ParamsAndFeaturesWritable {
2222

2323
val chunkingStrategy = new Param[String](this, "chunkingStrategy", "Set the chunking strategy")
2424

@@ -39,6 +39,26 @@ trait HasSemanticChunkerProperties extends ParamsAndFeaturesWritable {
3939

4040
def setOverlap(value: Int): this.type = set(overlap, value)
4141

42-
setDefault(chunkingStrategy -> "", maxCharacters -> 100, newAfterNChars -> -1, overlap -> 0)
42+
val combineTextUnderNChars =
43+
new Param[Int](this, "combineTextUnderNChars", "Threshold to merge adjacent small sections")
44+
45+
def setComBineTextUnderNChars(value: Int): this.type =
46+
set(combineTextUnderNChars, value)
47+
48+
val overlapAll =
49+
new Param[Boolean](
50+
this,
51+
"overlapAll",
52+
"Apply overlap context between all sections, not just split chunks")
53+
54+
def setOverlapAll(value: Boolean): this.type = set(overlapAll, value)
55+
56+
setDefault(
57+
chunkingStrategy -> "",
58+
maxCharacters -> 100,
59+
newAfterNChars -> -1,
60+
overlap -> 0,
61+
combineTextUnderNChars -> 0,
62+
overlapAll -> false)
4363

4464
}

src/main/scala/com/johnsnowlabs/partition/Partition.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
146146

147147
val partitionResult = reader(path)
148148
if (hasChunkerStrategy) {
149-
val chunker = new SemanticChunker(params.asScala.toMap)
149+
val chunker = new PartitionChunker(params.asScala.toMap)
150150
partitionResult.withColumn(
151151
"chunks",
152152
chunker.chunkUDF()(partitionResult(sparkNLPReader.getOutputColumn)))

src/main/scala/com/johnsnowlabs/partition/SemanticChunker.scala renamed to src/main/scala/com/johnsnowlabs/partition/PartitionChunker.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,20 @@
1616
package com.johnsnowlabs.partition
1717

1818
import com.johnsnowlabs.partition.BasicChunker.chunkBasic
19+
import com.johnsnowlabs.partition.TitleChunker.chunkByTitle
1920
import com.johnsnowlabs.reader.HTMLElement
20-
import com.johnsnowlabs.reader.util.PartitionOptions.{getDefaultInt, getDefaultString}
21+
import com.johnsnowlabs.reader.util.PartitionOptions.{
22+
getDefaultBoolean,
23+
getDefaultInt,
24+
getDefaultString
25+
}
2126
import org.apache.spark.sql.Row
2227
import org.apache.spark.sql.expressions.UserDefinedFunction
2328
import org.apache.spark.sql.functions.udf
2429

2530
import scala.collection.mutable
2631

27-
class SemanticChunker(chunkerOptions: Map[String, String]) extends Serializable {
32+
class PartitionChunker(chunkerOptions: Map[String, String]) extends Serializable {
2833

2934
def chunkUDF(): UserDefinedFunction = {
3035
udf((elements: Seq[Row]) => {
@@ -37,6 +42,14 @@ class SemanticChunker(chunkerOptions: Map[String, String]) extends Serializable
3742

3843
val chunks = getChunkerStrategy match {
3944
case "basic" => chunkBasic(htmlElements, getMaxCharacters, getNewAfterNChars, getOverlap)
45+
case "byTitle" | "by_title" =>
46+
chunkByTitle(
47+
htmlElements,
48+
getMaxCharacters,
49+
getCombineTextUnderNChars,
50+
getOverlap,
51+
getNewAfterNChars,
52+
getOverlapAll)
4053
case _ =>
4154
throw new IllegalArgumentException(s"Unknown chunker strategy: $getChunkerStrategy")
4255
}
@@ -64,4 +77,15 @@ class SemanticChunker(chunkerOptions: Map[String, String]) extends Serializable
6477
default = "none")
6578
}
6679

80+
private def getCombineTextUnderNChars: Int = {
81+
getDefaultInt(
82+
chunkerOptions,
83+
Seq("combineTextUnderNChars", "combine_text_under_n_chars"),
84+
default = 0)
85+
}
86+
87+
private def getOverlapAll: Boolean = {
88+
getDefaultBoolean(chunkerOptions, Seq("overlapAll", "overlap_all"), default = false)
89+
}
90+
6791
}

src/main/scala/com/johnsnowlabs/partition/PartitionTransformer.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class PartitionTransformer(override val uid: String)
8686
with HasPowerPointProperties
8787
with HasTextReaderProperties
8888
with HasPdfProperties
89-
with HasSemanticChunkerProperties {
89+
with HasChunkerProperties {
9090

9191
def this() = this(Identifiable.randomUID("PartitionTransformer"))
9292
protected val logger: Logger = LoggerFactory.getLogger(getClass.getName)
@@ -155,7 +155,9 @@ class PartitionTransformer(override val uid: String)
155155
"chunkingStrategy" -> $(chunkingStrategy),
156156
"maxCharacters" -> $(maxCharacters).toString,
157157
"newAfterNChars" -> $(newAfterNChars).toString,
158-
"overlap" -> $(overlap).toString)
158+
"overlap" -> $(overlap).toString,
159+
"combineTextUnderNChars" -> $(combineTextUnderNChars).toString,
160+
"overlapAll" -> $(overlapAll).toString)
159161
val partitionInstance = new Partition(params.asJava)
160162

161163
val inputColum = if (get(inputCols).isDefined) {
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright 2017-2025 John Snow Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.johnsnowlabs.partition
17+
18+
import com.johnsnowlabs.reader.{ElementType, HTMLElement}
19+
20+
import scala.collection.mutable
21+
22+
object TitleChunker {
23+
24+
/** Splits a list of HTML elements into semantically grouped Chunks based on Title and Table
25+
* markers.
26+
*
27+
* @param elements
28+
* List of input HTML elements to chunk.
29+
* @param maxCharacters
30+
* Maximum length allowed per chunk. Longer sections are split.
31+
* @param combineTextUnderNChars
32+
* Threshold to merge adjacent small sections.
33+
* @param overlap
34+
* Number of characters to repeat between consecutive chunks.
35+
* @param newAfterNChars
36+
* Soft limit to trigger new section if length exceeded, even before maxCharacters.
37+
* @param overlapAll
38+
* Apply overlap context between all sections, not just split chunks.
39+
* @return
40+
* List of Chunks partitioned by title and content heuristics.
41+
*/
42+
def chunkByTitle(
43+
elements: List[HTMLElement],
44+
maxCharacters: Int,
45+
combineTextUnderNChars: Int = 0,
46+
overlap: Int = 0,
47+
newAfterNChars: Int = -1,
48+
overlapAll: Boolean = false): List[Chunk] = {
49+
50+
val softLimit = if (newAfterNChars <= 0) maxCharacters else newAfterNChars
51+
val chunks = mutable.ListBuffer.empty[Chunk]
52+
val sections = mutable.ListBuffer.empty[List[HTMLElement]]
53+
var currentSection = List.empty[HTMLElement]
54+
var currentLength = 0
55+
var currentPage = -1
56+
57+
for (element <- elements) {
58+
val elementLength = element.content.length
59+
val isTable = element.elementType == "Table"
60+
val elementPage = element.metadata.getOrElse("pageNumber", "-1").toInt
61+
62+
val pageChanged = currentPage != -1 && elementPage != currentPage
63+
val softLimitExceeded = currentSection.length >= 2 &&
64+
(currentLength + elementLength > softLimit)
65+
66+
if (isTable) {
67+
if (currentSection.nonEmpty) sections += currentSection
68+
sections += List(element)
69+
currentSection = List.empty
70+
currentLength = 0
71+
currentPage = -1
72+
} else if (pageChanged || softLimitExceeded) {
73+
if (currentSection.nonEmpty) sections += currentSection
74+
currentSection = List(element)
75+
currentLength = elementLength
76+
currentPage = elementPage
77+
} else {
78+
currentSection :+= element
79+
currentLength += elementLength
80+
currentPage = elementPage
81+
}
82+
}
83+
if (currentSection.nonEmpty) sections += currentSection
84+
85+
val mergedSections = sections.foldLeft(List.empty[List[HTMLElement]]) { (acc, section) =>
86+
val sectionLength = section.map(_.content.length).sum
87+
val canMerge = combineTextUnderNChars > 0 &&
88+
sectionLength < combineTextUnderNChars &&
89+
acc.nonEmpty &&
90+
acc.last.exists(_.elementType != "Table") &&
91+
section.exists(_.elementType != "Table")
92+
93+
if (canMerge) {
94+
acc.init :+ (acc.last ++ section)
95+
} else {
96+
acc :+ section
97+
}
98+
}
99+
100+
var lastNarrativeText = ""
101+
for (section <- mergedSections) {
102+
if (section.exists(_.elementType == "Table")) {
103+
chunks += Chunk(section)
104+
lastNarrativeText = ""
105+
} else {
106+
val sectionText = section.map(_.content).mkString(" ")
107+
val content =
108+
if (overlap > 0 && lastNarrativeText.nonEmpty && (overlapAll || sectionText.length > maxCharacters))
109+
lastNarrativeText.takeRight(overlap) + " " + sectionText
110+
else sectionText
111+
112+
val merged = HTMLElement(ElementType.NARRATIVE_TEXT, content.trim, section.head.metadata)
113+
val split = if (content.length > maxCharacters) {
114+
splitHTMLElement(merged, maxCharacters, overlap)
115+
} else List(merged)
116+
117+
chunks ++= split.map(e => Chunk(List(e)))
118+
lastNarrativeText = sectionText
119+
}
120+
}
121+
122+
chunks.toList
123+
}
124+
125+
private def splitHTMLElement(
126+
element: HTMLElement,
127+
maxLen: Int,
128+
overlap: Int): List[HTMLElement] = {
129+
130+
val words = element.content.split(" ")
131+
val buffer = mutable.ListBuffer.empty[HTMLElement]
132+
var chunk = new StringBuilder
133+
134+
for (word <- words) {
135+
if (chunk.length + word.length + 1 > maxLen) {
136+
val text = chunk.toString().trim
137+
buffer += element.copy(content = text)
138+
chunk = new StringBuilder
139+
if (overlap > 0 && text.length >= overlap)
140+
chunk.append(text.takeRight(overlap)).append(" ")
141+
}
142+
chunk.append(word).append(" ")
143+
}
144+
145+
if (chunk.nonEmpty)
146+
buffer += element.copy(content = chunk.toString().trim)
147+
148+
buffer.toList
149+
}
150+
151+
}

src/test/scala/com/johnsnowlabs/partition/PartitionChunkerTest.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.scalatest.flatspec.AnyFlatSpec
2323
class PartitionChunkerTest extends AnyFlatSpec {
2424
import ResourceHelper.spark.implicits._
2525
val txtDirectory = "src/test/resources/reader/txt"
26+
val htmlDirectory = "src/test/resources/reader/html"
2627

2728
"Partition" should "perform basic chunk text" taggedAs FastTest in {
2829
val partitionOptions = Map("contentType" -> "text/plain", "chunkingStrategy" -> "basic")
@@ -38,4 +39,17 @@ class PartitionChunkerTest extends AnyFlatSpec {
3839
assert(chunkDf.count() > 1)
3940
}
4041

42+
it should "perform chunking by title" taggedAs FastTest in {
43+
val partitionOptions = Map(
44+
"contentType" -> "text/html",
45+
"titleFontSize" -> "14",
46+
"chunkingStrategy" -> "byTitle",
47+
"combineTextUnderNChars" -> "50")
48+
val textDf = Partition(partitionOptions).partition(s"$htmlDirectory/fake-html.html")
49+
50+
val partitionDf = textDf.select(explode($"chunks.content"))
51+
partitionDf.show(truncate = false)
52+
assert(partitionDf.count() == 2)
53+
}
54+
4155
}

0 commit comments

Comments
 (0)