# Test of Greg Kamradt's semantic chunking for Chinese

In [None]:
# Load the whole essay into memory. The encoding is given explicitly so that
# the Chinese text decodes the same way regardless of the platform's default
# locale (assumes the file is stored as UTF-8 -- adjust if it is not).
with open('../../../text_files/your_cat_part_3.txt', encoding='utf-8') as file:
    essay = file.read()

### Single-pattern split

In [None]:
import re
# Split right after sentence-ending punctuation (。？；！ and ASCII !), or at a
# newline that follows an ASCII period. The look-behinds keep the punctuation
# attached to the sentence it terminates; the full-width "！" is included
# because it is the exclamation mark normally used in Chinese text.
pattern = r'(?<=[。？；！!])\s*|(?<=\.)\s*\n'
single_sentences_list = re.split(pattern, essay)
print(f"{len(single_sentences_list)} sentences were found")
print(f"max length of sentence: {max(len(sentence) for sentence in single_sentences_list)}")

### Multi-pattern split

In [None]:
import re

# First split on section separators: a run of exactly sixteen '=' characters.
pattern1 = r'={16}'
# Then split right after sentence-ending punctuation (。？；！ and ASCII !), or
# at a newline that follows an ASCII period. The full-width "！" is included
# because it is the exclamation mark normally used in Chinese text.
pattern2 = r'(?<=[。？；！!])\s*|(?<=\.)\s*\n'
# Patterns are applied in this order.
patterns = [pattern1, pattern2]

def split_by_patterns(text, patterns):
    """Split ``text`` successively by every regex in ``patterns``.

    The first pattern splits the whole text; each later pattern is applied
    to every piece produced so far. Pieces are returned in their original
    order, exactly as ``re.split`` yields them (empty strings included).
    With no patterns the result is ``[text]``.
    """
    segments = [text]
    for regex in patterns:
        # Flatten the pieces of every current segment into one new list.
        segments = [piece for chunk in segments for piece in re.split(regex, chunk)]
    return segments


In [None]:
# Apply the patterns in order to break the essay into sentences.
single_sentences_list = split_by_patterns(text=essay, patterns=patterns)

In [None]:
# Report how many sentences the split produced and the longest one (in characters).
print(f"{len(single_sentences_list)} sentences were found")
print(f"max length of sentence: {max(len(sentence) for sentence in single_sentences_list)}")

In [None]:
# Show the first five sentences (fewer if the list is shorter).
for i, sentence in zip(range(5), single_sentences_list):
    print(f"sentence {i}: {sentence}")

In [None]:
# Wrap every sentence in a dict that also records its position in the essay.
sentences = []
for position, text in enumerate(single_sentences_list):
    sentences.append({'sentence': text, 'index': position})

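### Function for combining sentences into groups
Each sentence is joined with its neighbours, e.g. with buffer_size=1: group1: 1,2; group2: 1,2,3; group3: 2,3,4; ...
buffer_size is the number of neighbouring sentences included on each side, so a group holds at most 2 * buffer_size + 1 sentences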

In [None]:
def combine_sentences(sentences, buffer_size=1):
    """Attach to each sentence dict its text joined with its neighbours.

    For the dict at position ``i``, the ``'sentence'`` values of up to
    ``buffer_size`` entries before it, the entry itself, and up to
    ``buffer_size`` entries after it are joined with single spaces and
    stored under ``'combined_sentence'``. Neighbours that would fall outside
    the list are simply left out.

    The dicts are updated in place and the same list object is returned.
    """
    # A negative buffer contributes no neighbours at all.
    reach = max(buffer_size, 0)

    for position, entry in enumerate(sentences):
        # Clamp the lower bound so a negative start does not wrap around;
        # slicing already tolerates an upper bound past the end.
        window = sentences[max(position - reach, 0):position + reach + 1]
        entry['combined_sentence'] = ' '.join(item['sentence'] for item in window)

    return sentences



In [None]:
# Add a 'combined_sentence' entry (one neighbour on each side) to every dict.
sentences = combine_sentences(sentences, buffer_size=1)

In [None]:
# Inspect the first three sentence dicts.
sentences[:3]

In [None]:
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings

# Settings for the sentence-embedding model; inference is pinned to the CPU.
embedding_settings = {
    "model_name": "aspire/acge_text_embedding",
    "model_kwargs": {'device': 'cpu'},
    "show_progress": True,
}
embedding = HuggingFaceEmbeddings(**embedding_settings)

In [None]:
# Embed the neighbour-joined text of every sentence in one call.
# NOTE(review): presumably returns one vector per input, in input order -- confirm.
combined_texts = [item['combined_sentence'] for item in sentences]
embeddings = embedding.embed_documents(combined_texts)

In [None]:
# Store each embedding on the sentence dict at the same position.
for position, record in enumerate(sentences):
    record['combined_sentence_embedding'] = embeddings[position]

# Print the length of the first stored embedding.
first_embedding = sentences[0].get("combined_sentence_embedding")
print(len(first_embedding))

### Calculate cosine distance between consecutive sentences

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def calculate_cosine_distances(sentences):
    """Compute the cosine distance between each sentence and the next one.

    Every dict in ``sentences`` must carry a ``'combined_sentence_embedding'``
    entry. The distance is ``1 - cosine_similarity`` of the two embeddings:
    0 means identical direction, 1 unrelated, 2 opposite.

    Each dict except the last gets a ``'distance_to_next'`` entry (the dicts
    are modified in place); the last sentence has no successor and is left
    untouched.

    Returns a tuple ``(distances, sentences)`` where ``distances`` holds one
    value per adjacent pair, in order.
    """
    distances = []

    # Walk over adjacent (current, next) pairs.
    for current, following in zip(sentences, sentences[1:]):
        vector_current = current['combined_sentence_embedding']
        vector_next = following['combined_sentence_embedding']

        # cosine_similarity works on 2-D inputs and returns a 1x1 matrix here.
        similarity = cosine_similarity([vector_current], [vector_next])[0][0]
        distance = 1 - similarity

        distances.append(distance)
        current['distance_to_next'] = distance

    return distances, sentences

In [None]:
# Compute pairwise distances; the sentence dicts gain 'distance_to_next'.
distances, sentences = calculate_cosine_distances(sentences=sentences)

In [None]:
# Inspect the first three cosine distances.
distances[:3]

In [None]:
import matplotlib.pyplot as plt

# Figure 10 inches wide and 5 inches tall.
figure_size = (10, 5)
fig, ax = plt.subplots(figsize=figure_size)

# Plot the distance between each pair of consecutive sentences.
ax.plot(distances)

In [None]:
import numpy as np

# Plot the distances again, mark the breakpoint threshold and shade the
# chunks that the breakpoints delimit.
fig, ax = plt.subplots(figsize=(15, 5))  # 15 inches wide, 5 inches tall
ax.plot(distances)

y_upper_bound = .3
plt.ylim(0, y_upper_bound)
plt.xlim(0, len(distances))

# A distance counts as a breakpoint (outlier) when it exceeds this percentile
# of all distances. Lower the percentile cutoff to get more chunks.
breakpoint_percentile_threshold = 90
breakpoint_distance_threshold = np.percentile(distances, breakpoint_percentile_threshold)
# Print the percentile that is actually configured instead of a hard-coded value.
print(f"{breakpoint_percentile_threshold}% threshold: {breakpoint_distance_threshold}")
plt.axhline(y=breakpoint_distance_threshold, color='r', linestyle='-')

# Positions in `distances` above the threshold; these are where the text is split.
indices_above_thresh = [i for i, x in enumerate(distances) if x > breakpoint_distance_threshold]

# N breakpoints produce N + 1 chunks.
num_distances_above_theshold = len(indices_above_thresh)
plt.text(x=(len(distances)*.01), y=y_upper_bound/50, s=f"{num_distances_above_theshold + 1} Chunks")

# Shade and label every chunk. Consecutive edges delimit one chunk each, so the
# spans never overlap, the trailing chunk after the last breakpoint is included,
# and a single chunk is still drawn when no distance exceeds the threshold.
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
span_edges = [0, *indices_above_thresh, len(distances)]
for i, (start_index, end_index) in enumerate(zip(span_edges[:-1], span_edges[1:])):
    plt.axvspan(start_index, end_index, facecolor=colors[i % len(colors)], alpha=0.25)
    plt.text(x=np.average([start_index, end_index]),
             y=breakpoint_distance_threshold + (y_upper_bound)/ 20,
             s=f"Chunk #{i}", horizontalalignment='center',
             rotation='vertical')

plt.title("PG Essay Chunks Based On Embedding Breakpoints")
plt.xlabel("Index of sentences in essay (Sentence Position)")
plt.ylabel("Cosine distance between sequential sentences")
plt.show()

## Chunking by breakpoint_index

In [None]:
# Build the text chunks: a breakpoint at position p means the split falls
# right after sentence p, so each chunk runs up to and including that sentence.
chunks = []
start_index = 0

for breakpoint_position in indices_above_thresh:
    stop = breakpoint_position + 1
    chunks.append(' '.join(d['sentence'] for d in sentences[start_index:stop]))
    # The next chunk starts with the sentence after the breakpoint.
    start_index = stop

# Whatever follows the last breakpoint forms the final chunk.
if start_index < len(sentences):
    chunks.append(' '.join(d['sentence'] for d in sentences[start_index:]))

# `chunks` now holds the text of every chunk, in document order.

In [None]:
# Summarise the chunking: count, longest chunk, then every chunk's length.
chunk_lengths = [len(chunk) for chunk in chunks]
print(f"Number of chunks: {len(chunks)}")
print(f"Max length:{max(len(chunk) for chunk in chunks)}")
for length in chunk_lengths:
    print(length)

## show chunks

In [None]:
# for i, chunk in enumerate(chunks[:4]):
#     buffer = 100
    
#     print (f"Chunk #{i} ###############################")
#     print (chunk[:buffer].strip())
#     print ("...")
#     print (chunk[-buffer:].strip())
#     print ("\n")

In [None]:
# Print every chunk in full, each preceded by a numbered banner line.
for i, chunk in enumerate(chunks):
    print(f"Chunk #{i} ########################################################", chunk, sep="\n")