In [1]:
from langchain_community.document_loaders import TextLoader
import torch
from transformers import BertTokenizer, BertForTokenClassification
import json

# Load the BERT model and tokenizer
model_path = './nlp_bert_document-segmentation_chinese-base'
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForTokenClassification.from_pretrained(model_path)

def segment_long_text(text, stride=256, max_length=512):
    # Tokenize input
    tokens = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_length, truncation=True)
    input_ids = tokens.input_ids
    
    # Prepare overlapping windows for long texts
    input_chunks = [input_ids[i:i + max_length] for i in range(0, len(input_ids), max_length - stride)]
    
    # Collect segments from each chunk
    segments = []
    for chunk in input_chunks:
        inputs = {'input_ids': torch.tensor([chunk])}
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1)
        tokens = tokenizer.convert_ids_to_tokens(chunk)
        current_segment = []
        for token, prediction in zip(tokens, predictions[0]):
            current_segment.append(token)
            if prediction == 1 and len(current_segment) >= min_segment_length:
                segments.append(''.join(current_segment))
                current_segment = []
        if current_segment:
            segments.append(''.join(current_segment))
    return segments


# Load documents
loader = TextLoader("AZ.txt", encoding="utf-8")
documents = loader.load()

# Segment documents using BERT model
segmented_docs = []
for doc in documents:
    segments = segment_long_text(doc.page_content)
    for segment in segments:
        segmented_docs.append({"page_content": segment, "metadata": doc.metadata})

# Save the segmented documents to a JSON file
with open('segmented_docs.json', 'w', encoding='utf-8') as f:
    json.dump(segmented_docs, f, ensure_ascii=False, indent=4)

  from .autonotebook import tqdm as notebook_tqdm


NameError: name 'segment_text' is not defined