In [7]:
"""
Advanced text splitter with metadata extraction for scientific papers.
Supports boundary-aware chunking with production-grade citation/equation handling.
Features:
  - Robust citation detection (numeric, author-year, LaTeX \\cite{})
  - Multi-line LaTeX equation support
  - Chronological boundary detection (preserves text order)
  - Smart merging of small blocks for coherence
  - Rich metadata extraction for reranking
"""

from typing import List, Optional, Dict, Tuple, NamedTuple
import re
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dataclasses import dataclass

@dataclass
class ChunkMetadata:
    """Typed description of the per-chunk metadata fields.

    NOTE(review): nothing in the visible code instantiates this class; the
    splitter returns plain dicts that carry keys with these same names.
    Confirm whether a caller elsewhere relies on it.
    """
    section: str  # Main section name (e.g. "Abstract", "Introduction")
    subsection: str  # Subsection name, or "" when none was detected
    chunk_index: int  # Position of the chunk in the splitter output
    hierarchy_level: int  # 0 when no heading was detected, otherwise the heading depth
    content_type: str  # "text", "equation", "figure", "table", "reference" or "section_header"
    has_citations: bool  # True when at least one citation was counted
    citation_count: int  # Number of citations counted in the chunk
    has_equations: bool  # True when at least one equation was counted
    equation_count: int  # Number of equations counted in the chunk
    is_abstract: bool  # Chunk belongs to the abstract (scored as more important)
    is_conclusion: bool  # Chunk belongs to the conclusion
    importance_score: float  # Ranking score, capped at 1.0

class BoundaryMatch(NamedTuple):
    """One regex hit that marks a structural boundary in the scanned text.

    ``start`` and ``end`` are the ``re.Match.start()`` / ``re.Match.end()``
    offsets of the hit.
    """
    start: int  # Offset of the first matched character
    end: int  # Offset just past the last matched character
    pattern_type: str  # Label of the pattern that matched (e.g., "heading_markdown", "equation_block", "paragraph_break")

class BoundaryAwareTextSplitter:
    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        boundary_patterns: Optional[List[Tuple]] = None,
        min_chunk_size: int = 100,
        merge_small_chunks: bool = True
    ):
        """
        Initialize boundary-aware text splitter.

        Args:
            chunk_size: Target chunk size in characters
            chunk_overlap: Overlap between chunks produced by the backup splitter
            boundary_patterns: List of (pattern, pattern_type) or
                (pattern, pattern_type, extra_flags) tuples. When omitted (or
                empty) the built-in scientific-paper patterns are used.
            min_chunk_size: Minimum size before merging
            merge_small_chunks: Whether to merge small chunks for coherence
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        self.merge_small_chunks = merge_small_chunks

        # NOTE: _find_all_boundaries searches every pattern with
        # re.MULTILINE | re.DOTALL, so "." also matches newlines here.
        # Patterns that must stay on one line therefore use [^\n] instead of
        # "." -- a greedy ".+$" would otherwise run to the end of the document
        # and swallow every later boundary.
        self.boundary_patterns = boundary_patterns or [
            # Priority 1: Section headers (must be handled first)
            (r"^#{1,6}[ \t]+([^\n]+)$", "heading_markdown"),  # Markdown headers
            (r"^(\d+(?:\.\d+)*)\s+([A-Z].+?)(?:\n|$)", "heading_numbered"),  # Numbered sections

            # Priority 2: Major section boundaries
            (r"^(Abstract|Introduction|Background|Methodology|Method|Results|"
             r"Discussion|Conclusion|Acknowledgments?|References|Appendix)\b", "section_boundary"),
            # "Related Work(s)" in any capitalisation (3-tuple carries the extra flag)
            (r"^(?:related\s+works?)\b", "section_boundary", re.IGNORECASE),

            # Priority 3: Equations (multi-line; the lazy ".+?" relies on DOTALL)
            (r"\\begin\{(?:equation|align|align\*|gather|gather\*|multline|displaymath|equation\*)\}.+?"
             r"\\end\{(?:equation|align|align\*|gather|gather\*|multline|displaymath|equation\*)\}", "equation_block"),
            (r"\$\$.+?\$\$", "equation_inline_display"),
            (r"\\\[.+?\\\]", "equation_latex_brackets"),

            # Priority 4: Tables and figures (caption lines flush the current chunk)
            (r"^\s*(?:Table|Figure)\s+\d+(?:\.\d+)?:[^\n]*$", "caption"),
            (r"^\s*(?:Table|Figure)\s*\d+", "table_figure"),

            # Priority 5: Centered content
            (r"<center>.+?</center>", "centered_text"),

            # Priority 6: Lists (numbered and bulleted)
            (r"^[\s]*[\d]+\.\s+", "numbered_list"),
            (r"^[\s]*[-•*]\s+", "bullet_list"),

            # Priority 7: Paragraph boundaries
            (r"\n\n+", "paragraph_break"),

            # Priority 8: Citations (last so they do not win span ties with the others)
            # A literal backslash is "\\" in a raw regex; the previous "\\\cite"
            # was an invalid escape ("\c") and the pattern never compiled.
            (r"\\cite\{[^}]+\}", "cite_latex"),
            (r"\[(?:\d+(?:,\s*\d+)*|\d+\s*[–-]\s*\d+)\]", "cite_numeric"),
            (r"\((?:[A-Za-z\s]+(?:et\s+al\.)?),?\s*\d{4}[a-z]?\)", "cite_author_year"),
            (r"[A-Z][a-z]+\s+(?:et\s+al\.)?,\s*\d{4}", "cite_author_year_inline"),
        ]

        # LangChain's RecursiveCharacterTextSplitter is the fallback for
        # chunks that are still larger than chunk_size after merging.
        self.backup_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""]
        )

    def _find_all_boundaries(self, text: str) -> List[BoundaryMatch]:
        """
        Find all non-overlapping boundaries, ordered by position in the text.

        Every entry of ``self.boundary_patterns`` is searched with
        ``re.MULTILINE | re.DOTALL`` (plus the entry's own flags when it is a
        3-tuple). Identical spans are reported once, labelled by the first
        pattern that produced them. Where matches overlap, the one that starts
        first wins; on equal starts the longer one wins.

        A pattern that fails to compile is skipped with a ``RuntimeWarning``
        instead of being dropped silently.

        Returns:
            List of BoundaryMatch sorted by start position
        """
        import warnings  # local import: only needed for the invalid-pattern path

        seen_spans = set()
        candidates = []

        for entry in self.boundary_patterns:
            # Entries are (pattern, type) or (pattern, type, extra_flags)
            if len(entry) == 3:
                pattern, pattern_type, extra_flags = entry
            else:
                pattern, pattern_type = entry
                extra_flags = 0

            try:
                compiled = re.compile(pattern, re.MULTILINE | re.DOTALL | extra_flags)
            except re.error as exc:
                warnings.warn(
                    f"Skipping invalid boundary pattern {pattern!r} "
                    f"({pattern_type}): {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue

            for hit in compiled.finditer(text):
                span = hit.span()
                if span in seen_spans:
                    continue  # an earlier pattern already claimed this exact span
                seen_spans.add(span)
                candidates.append(BoundaryMatch(
                    start=span[0],
                    end=span[1],
                    pattern_type=pattern_type
                ))

        # Positional order; on equal starts the longer match comes first
        candidates.sort(key=lambda m: (m.start, -m.end))

        # Because candidates arrive in start order, overlap with any kept
        # match can be decided from the furthest end offset kept so far.
        kept = []
        covered_end = -1  # furthest end of any kept non-empty match
        for candidate in candidates:
            if candidate.start < covered_end:
                continue  # starts inside a kept match
            if candidate.start == candidate.end and candidate.start == covered_end:
                continue  # zero-width hit sitting exactly at the end of a kept match
            kept.append(candidate)
            if candidate.end > candidate.start:
                covered_end = max(covered_end, candidate.end)

        return kept

    def _count_citations_advanced(self, chunk: str) -> int:
        """
        Count citations with support for multiple citation formats.
        Handles: [1], [1,2], [1-3], [1–3], (Smith et al., 2023), Smith (2022), \\cite{...}
        """
        count = 0
        
        # Numeric citations: [1], [1, 2], [1–3], [1-3]
        numeric_citations = re.findall(
            r"\[(?:\d+(?:[,\s]+\d+)*|(?:\d+\s*[–-]\s*\d+))\]",
            chunk
        )
        count += len(numeric_citations)
        
        # Author-year citations: (Smith et al., 2023), (Smith, 2023)
        author_year = re.findall(
            r"\([A-Za-z\s]+(?:et\s+al\.)?[,\s]*\d{4}[a-z]?\)",
            chunk
        )
        count += len(author_year)
        
        # Author-year inline: Smith et al. (2023), Smith (2022)
        author_inline = re.findall(
            r"[A-Z][a-z]+(?:\s+et\s+al\.)?[,\s]*\(\d{4}[a-z]?\)",
            chunk
        )
        count += len(author_inline)
        
        # LaTeX citations: \cite{...}
        latex_cites = re.findall(r"\\\cite\{[^}]+\}", chunk)
        count += len(latex_cites)
        
        return count

    def _count_equations_advanced(self, chunk: str) -> int:
        """
        Count equations with support for multi-line LaTeX environments.
        Handles nested structures and various equation environments.
        """
        count = 0
        
        # Display equations with $$
        count += len(re.findall(r"\$\$", chunk)) // 2
        
        # LaTeX bracket equations \\[...\\]
        count += len(re.findall(r"\\\[.*?\\\]", chunk, re.DOTALL))
        
        # Equation environments (align, gather, multline, etc.)
        count += len(re.findall(
            r"\\begin\{(?:equation|align|align\*|gather|gather\*|multline|displaymath|equation\*)\}",
            chunk
        ))
        
        # Inline math with single $
        inline_math = re.findall(r"(?<!\$)\$(?!\$)[^$]+\$(?!\$)", chunk)
        count += len(inline_math)
        
        return count # Đã xóa logger

    def _extract_section_info(self, text: str, boundary_pattern_type: Optional[str] = None) -> Tuple[str, int]:
        """
        Extract section name and hierarchy level from text.
        Prioritizes boundary pattern type to avoid mis-detection in multi-line content.
        
        Args:
            text: The text chunk to analyze
            boundary_pattern_type: The pattern type from boundary detection (if available)
            
        Returns:
            Tuple of (section_name, hierarchy_level)
        """
        # Priority 1: Use boundary pattern type if available (most reliable)
        if boundary_pattern_type == "section_boundary":
            # Extract section name from text
            match = re.search(
                r"^(?:related\s+works?|related\s+work|abstract|introduction|background|methodology|method|"
                r"results|discussion|conclusion|acknowledgments?|references|appendix)\b",
                text,
                re.MULTILINE | re.IGNORECASE
            )
            if match:
                return match.group(0).title(), 1
        
        # Priority 2: Check for markdown headers
        headers = {
            r"^#\s+(.+)$": 1,
            r"^##\s+(.+)$": 2,
            r"^###\s+(.+)$": 3,
            r"^####\s+(.+)$": 4,
            r"^#####\s+(.+)$": 5,
            r"^######\s+(.+)$": 6,
        }
        
        for pattern, level in headers.items():
            match = re.search(pattern, text, re.MULTILINE)
            if match:
                return match.group(1).strip(), level
        
        # Priority 3: Check for numbered sections
        numbered_match = re.search(r"^(\d+(?:\.\d+)*)\s+(.+)$", text, re.MULTILINE)
        if numbered_match:
            section_name = numbered_match.group(2).strip()
            level = len(numbered_match.group(1).split('.'))
            return section_name, level
        
        return "General", 0

    def _identify_content_type(self, chunk: str) -> str:
        """Identify the type of content in the chunk"""
        if re.search(r"\$\$|\\begin\{|\\end\{|\\\[|\\\]", chunk):
            return "equation"
        elif re.search(r"^(Table|Figure)\s+\d+", chunk, re.MULTILINE):
            return "table" if chunk.startswith("Table") else "figure"
        elif re.search(r"^\[\d+\]\s+", chunk, re.MULTILINE):
            return "reference"
        elif re.search(r"^(Abstract|Introduction|Methodology|Results|Discussion|Conclusion)", chunk):
            return "section_header"
        else:
            return "text"

    def _extract_metadata(
        self,
        chunk: str,
        chunk_index: int,
        current_section: str,
        boundary_pattern_type: Optional[str] = None
    ) -> Dict:
        """
        Build the metadata dict for one chunk.

        Args:
            chunk: The text chunk
            chunk_index: Index of chunk
            current_section: Section inherited from the preceding chunk; used
                when the chunk itself announces no section
            boundary_pattern_type: Type of boundary pattern detected (for accurate section detection)

        Returns:
            Dict with the keys: section, subsection, chunk_index,
            hierarchy_level, content_type, has_citations, citation_count,
            has_equations, equation_count, is_abstract, is_conclusion,
            importance_score, chunk_length, word_count, boundary_pattern_type
        """
        citation_count = self._count_citations_advanced(chunk)
        equation_count = self._count_equations_advanced(chunk)
        content_type = self._identify_content_type(chunk)

        # Pass the boundary type so section detection can trust it
        section_name, hierarchy_level = self._extract_section_info(chunk, boundary_pattern_type)

        # A section announced inside the chunk overrides the inherited one
        announces_section = section_name != "General"
        if announces_section:
            current_section = section_name

        # Judge abstract/conclusion membership from the *resolved* section, so
        # continuation chunks (no heading of their own) are flagged as well.
        # Previously only the chunk containing the heading itself qualified.
        resolved_section = current_section.lower()
        is_abstract = "abstract" in resolved_section
        is_conclusion = "conclusion" in resolved_section

        importance_score = self._calculate_importance_score(
            content_type=content_type,
            hierarchy_level=hierarchy_level,
            is_abstract=is_abstract,
            is_conclusion=is_conclusion,
            citation_count=citation_count
        )

        return {
            "section": current_section,
            "subsection": section_name if announces_section else "",
            "chunk_index": chunk_index,
            "hierarchy_level": hierarchy_level,
            "content_type": content_type,
            "has_citations": citation_count > 0,
            "citation_count": citation_count,
            "has_equations": equation_count > 0,
            "equation_count": equation_count,
            "is_abstract": is_abstract,
            "is_conclusion": is_conclusion,
            "importance_score": importance_score,
            "chunk_length": len(chunk),
            "word_count": len(chunk.split()),
            "boundary_pattern_type": boundary_pattern_type or "unknown"
        }

    def _calculate_importance_score(
        self,
        content_type: str,
        hierarchy_level: int,
        is_abstract: bool,
        is_conclusion: bool,
        citation_count: int
    ) -> float:
        """Calculate importance score for prioritizing chunks"""
        score = 0.5  # Base score
        
        # Content type impact
        content_weights = {
            "abstract": 1.0,
            "conclusion": 0.9,
            "equation": 0.8,
            "section_header": 0.7,
            "reference": 0.4,
            "table": 0.6,
            "figure": 0.6,
            "text": 0.5
        }
        score += content_weights.get(content_type, 0.5) * 0.3
        
        # Hierarchy level impact (lower level = higher importance)
        if hierarchy_level > 0:
            score += max(0, (4 - hierarchy_level) / 4) * 0.2
        
        # Section importance
        if is_abstract:
            score += 0.3
        if is_conclusion:
            score += 0.25
        
        # Citation impact (more citations = more important)
        citation_bonus = min(citation_count / 5, 0.5) * 0.2
        score += citation_bonus
        
        return min(1.0, score)

    def split_text(self, text: str) -> List[Tuple[str, Dict]]:
        """
        Split text into chunks at detected boundaries and attach metadata.

        Steps:
          1. Cut the text at every boundary, keeping document order.
          2. Greedily merge consecutive pieces up to ``chunk_size``. A caption
             piece forces a flush when the pending chunk is already larger
             than ``min_chunk_size``.
          3. Re-split anything still larger than ``chunk_size`` with the
             backup splitter and number all chunks consecutively.

        No text is discarded: a pending chunk that is too small to stand alone
        is carried into the following chunk when ``merge_small_chunks`` is
        true, and emitted as its own chunk otherwise. (Previously such chunks
        were silently dropped.)

        Args:
            text: Document text

        Returns:
            List of tuples (chunk_text, metadata); empty for blank input
        """
        stripped_text = text.strip()
        if not stripped_text:
            return []  # nothing worth chunking or embedding

        merged = []
        current_section = "Introduction"

        def emit(chunk_text, chunk_type):
            """Append one merged chunk and let it update the running section."""
            nonlocal current_section
            metadata = self._extract_metadata(
                chunk_text,
                len(merged),
                current_section,
                chunk_type
            )
            merged.append((chunk_text, metadata))
            if metadata["subsection"]:
                current_section = metadata["subsection"]

        boundaries = self._find_all_boundaries(text)

        if not boundaries:
            # No structure detected: one chunk, still subject to the size
            # limit applied in the final pass below
            merged.append((stripped_text, self._extract_metadata(stripped_text, 0, "General")))
        else:
            # --- Step 1: cut at boundaries, in document order ---
            pieces = []
            cursor = 0
            for boundary in boundaries:
                before = text[cursor:boundary.start].strip()
                if before:
                    pieces.append((before, "text"))
                matched = text[boundary.start:boundary.end].strip()
                if matched:
                    pieces.append((matched, boundary.pattern_type))
                cursor = boundary.end
            tail = text[cursor:].strip()
            if tail:
                pieces.append((tail, "text"))

            # --- Step 2: greedy merge ---
            pending = ""
            pending_type = "text"
            for piece, piece_type in pieces:
                substantial = bool(pending) and len(pending) > self.min_chunk_size
                flush_for_caption = piece_type == "caption" and substantial
                # +1 accounts for the "\n" joiner
                joined_length = len(pending) + len(piece) + (1 if pending else 0)

                if not flush_for_caption and joined_length <= self.chunk_size:
                    if pending:
                        pending += "\n" + piece
                    else:
                        pending, pending_type = piece, piece_type
                    continue

                if substantial or (pending and not self.merge_small_chunks):
                    emit(pending, pending_type)
                    pending, pending_type = piece, piece_type
                elif pending:
                    # Too small to stand alone: keep it with what follows
                    # (e.g. a heading stays attached to its first paragraph).
                    # An oversized result is re-split in step 3.
                    pending += "\n" + piece
                else:
                    pending, pending_type = piece, piece_type

            # Whatever is still pending after the last piece
            if pending:
                small_tail = len(pending) <= self.min_chunk_size
                if small_tail and merged and self.merge_small_chunks:
                    prev_chunk, prev_metadata = merged[-1]
                    combined = prev_chunk + "\n" + pending
                    if len(combined) <= self.chunk_size * 1.2:  # Allow slight overflow for merging
                        merged[-1] = (combined, self._extract_metadata(
                            combined,
                            prev_metadata["chunk_index"],
                            prev_metadata["section"],
                            prev_metadata.get("boundary_pattern_type", "text")
                        ))
                    else:
                        emit(pending, pending_type)
                else:
                    emit(pending, pending_type)

        # --- Step 3: enforce the size limit and renumber ---
        final_chunks = []
        for chunk, metadata in merged:
            if len(chunk) > self.chunk_size:
                parts = self.backup_splitter.split_text(chunk)
            else:
                parts = [chunk]
            for part in parts:
                part_metadata = metadata.copy()
                part_metadata["chunk_index"] = len(final_chunks)
                # Size fields describe the emitted part, not its parent chunk
                part_metadata["chunk_length"] = len(part)
                part_metadata["word_count"] = len(part.split())
                final_chunks.append((part, part_metadata))

        return final_chunks

    def split_documents(self, documents: List[str]) -> List[Tuple[str, Dict]]:
        """
        Split multiple documents while maintaining document boundaries.
        
        Returns:
            List of tuples (chunk_text, metadata)
        """
        all_chunks = []
        for doc in documents:
            chunks = self.split_text(doc)
            all_chunks.extend(chunks)
        return all_chunks

In [8]:
from typing import List, Optional, Dict
import torch
from transformers import AutoTokenizer, AutoModel
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
import numpy as np
import datetime
from pathlib import Path
import sys
import argparse
import os

In [9]:
class DocumentEmbedder:
    def __init__(
        self,
        model_name: str = "BAAI/bge-large-en-v1.5",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        collection_name: str = "scientific_papers",
        dim: int = 1024,
        milvus_host: str = "localhost",
        milvus_port: int = 19530,
        chunk_size: int = 512,
        chunk_overlap: int = 50
    ):
        """
        Load the embedding model, build the splitter and open the collection.

        Args:
            model_name: Name passed to AutoTokenizer / AutoModel
            device: Device the model is moved to
            collection_name: Milvus collection to use (created when missing)
            dim: Embedding dimension used for the collection schema
            milvus_host: Milvus server host
            milvus_port: Milvus server port
            chunk_size: Chunk size handed to the text splitter
            chunk_overlap: Chunk overlap handed to the text splitter
        """
        self.device = device
        self.dim = dim
        self.collection_name = collection_name

        # Embedding model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(device)

        # Chunker
        self.text_splitter = BoundaryAwareTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

        # Vector store: connect, create the collection on first use, then load it
        connections.connect(host=milvus_host, port=milvus_port)
        collection_missing = not utility.has_collection(collection_name)
        if collection_missing:
            self._create_collection(dim)

        self.collection = Collection(collection_name)
        self.collection.load()

    def _create_collection(self, dim: int):
        """
        Create the Milvus collection, its metadata schema and its vector index.

        Args:
            dim: Dimension of the "embeddings" vector field
        """
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
            # Chunk-level metadata for reranking
            FieldSchema(name="section", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="subsection", dtype=DataType.VARCHAR, max_length=256),
            FieldSchema(name="content_type", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="hierarchy_level", dtype=DataType.INT8),
            FieldSchema(name="importance_score", dtype=DataType.FLOAT),
            FieldSchema(name="citation_count", dtype=DataType.INT32),
            FieldSchema(name="equation_count", dtype=DataType.INT32),
            FieldSchema(name="is_abstract", dtype=DataType.BOOL),
            FieldSchema(name="is_conclusion", dtype=DataType.BOOL),
            FieldSchema(name="word_count", dtype=DataType.INT32),
            FieldSchema(name="chunk_index", dtype=DataType.INT32),
            # Source metadata
            FieldSchema(name="source_file", dtype=DataType.VARCHAR, max_length=512),
            FieldSchema(name="processing_timestamp", dtype=DataType.VARCHAR, max_length=30),
        ]
        schema = CollectionSchema(
            fields=fields,
            description="Scientific paper chunks collection with rich metadata"
        )
        collection = Collection(self.collection_name, schema)

        # The constructor calls load() right after this method and search()
        # queries with metric "IP" and an "nprobe" parameter, so the vector
        # field needs a matching IVF index before the collection is usable.
        # NOTE(review): nlist=128 is a generic starting value -- tune it to
        # the collection size.
        collection.create_index(
            field_name="embeddings",
            index_params={
                "index_type": "IVF_FLAT",
                "metric_type": "IP",
                "params": {"nlist": 128},
            },
        )

    def _get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings for a list of text chunks."""
        encoded_input = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(self.device)

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            embeddings = model_output.last_hidden_state[:, 0]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()

    def read_scientific_paper(self, file_path: str) -> str:
        """Read and preprocess the scientific paper from final_text.txt"""
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        return content

    def process_document(self, text: str, metadata: Optional[Dict] = None, batch_size: int = 32):
        """
        Chunk a document, embed the chunks and insert them into the collection.

        Args:
            text: Full document text
            metadata: Optional document-level values; "source_file" and
                "processing_timestamp" are read from it when present
            batch_size: Number of chunks embedded and inserted per round
        """
        chunk_pairs = self.text_splitter.split_text(text)

        # Document-level values shared by every row
        doc_info = metadata or {}
        source_file = doc_info.get("source_file", "unknown")
        processing_timestamp = doc_info.get("processing_timestamp", datetime.datetime.now().isoformat())

        for offset in range(0, len(chunk_pairs), batch_size):
            window = chunk_pairs[offset:offset + batch_size]
            vectors = self._get_embeddings([pair[0] for pair in window])

            rows = [
                {
                    "text": pair[0],
                    "embeddings": vector.tolist(),
                    # Chunk-level metadata
                    "section": pair[1].get("section", "unknown"),
                    "subsection": pair[1].get("subsection", ""),
                    "content_type": pair[1].get("content_type", "text"),
                    "hierarchy_level": pair[1].get("hierarchy_level", 0),
                    "importance_score": pair[1].get("importance_score", 0.5),
                    "citation_count": pair[1].get("citation_count", 0),
                    "equation_count": pair[1].get("equation_count", 0),
                    "is_abstract": pair[1].get("is_abstract", False),
                    "is_conclusion": pair[1].get("is_conclusion", False),
                    "word_count": pair[1].get("word_count", 0),
                    # Falls back to the batch's starting offset
                    "chunk_index": pair[1].get("chunk_index", offset),
                    # Source metadata
                    "source_file": source_file,
                    "processing_timestamp": processing_timestamp,
                }
                for pair, vector in zip(window, vectors)
            ]
            self.collection.insert(rows)

        # Make sure everything inserted above is written out
        self.collection.flush()

    def search(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5
    ) -> List[dict]:
        """
        Return the stored chunks most similar to a query.

        Args:
            query: Query text
            top_k: Maximum number of hits requested from Milvus
            score_threshold: Hits scoring below this value are discarded

        Returns:
            List of dicts with the keys "text", "score", "section" and
            "source_file"
        """
        query_vector = self._get_embeddings([query])[0]

        results = self.collection.search(
            data=[query_vector.tolist()],
            anns_field="embeddings",
            param={"metric_type": "IP", "params": {"nprobe": 10}},
            limit=top_k,
            # Fields returned alongside each hit
            output_fields=["text", "section", "source_file"]
        )

        return [
            {
                "text": hit.entity.get("text"),
                "score": hit.score,
                "section": hit.entity.get("section", "N/A"),
                "source_file": hit.entity.get("source_file", "N/A"),
            }
            for hits in results
            for hit in hits
            if hit.score >= score_threshold
        ]

    def process_scientific_paper(self, file_path: str, paper_metadata: Optional[Dict] = None):
        """
        Read a paper from disk and index it in the vector database.

        Args:
            file_path: Path of the text file to index
            paper_metadata: Document-level metadata; when omitted, a default
                record (source file, paper type, current timestamp) is built
        """
        content = self.read_scientific_paper(file_path)

        metadata = paper_metadata
        if metadata is None:
            metadata = {
                "source_file": file_path,
                "paper_type": "scientific_paper",
                "processing_timestamp": datetime.datetime.now().isoformat()
            }

        self.process_document(content, metadata=metadata)

    def close(self):
        """Release the collection loaded in the constructor and disconnect from Milvus."""
        # Counterpart of the load() call made in __init__
        self.collection.release()
        # NOTE(review): "default" is presumably the alias connect() used,
        # since __init__ passes no alias -- confirm against pymilvus.
        connections.disconnect("default")


def main():
    """
    Index a scientific paper into Milvus and optionally run a test query.

    All settings are hard-coded below rather than read from the command
    line. The process exits with status 1 when the input file is missing or
    any step fails. The embedder's connection is closed even when a step
    fails.
    """

    # --- Hard-coded settings ---
    data_dir_str = "../../data/processed"
    file_name = "final_text.txt"
    collection_name = "scientific_papers"
    model_name = "BAAI/bge-large-en-v1.5"
    milvus_host = "localhost"
    milvus_port = 19530
    chunk_size = 512
    chunk_overlap = 50

    # Set to True to drop and recreate the collection
    force_recreate = False

    # Query used as a smoke test; set to None to skip it
    query = "What is the main challenge of this research?"
    query_top_k = 5
    # ----------------------------

    # Resolve the input path. __file__ does not exist when this runs inside
    # a notebook, so fall back to the current working directory there.
    try:
        script_dir = Path(__file__).parent.absolute()
    except NameError:
        script_dir = Path.cwd()
    # NOTE(review): both ".parent.parent" and the "../../" in data_dir_str
    # climb two levels (four in total) -- confirm this is the intended root.
    data_dir = script_dir.parent.parent / data_dir_str
    file_path = data_dir / file_name

    # Bail out early when the input file is missing
    if not file_path.exists():
        print(f"Lỗi: Không tìm thấy file: {file_path}")
        if data_dir.exists():
            print("Các file có sẵn trong thư mục dữ liệu:")
            for file in data_dir.glob("*"):
                print(f"  - {file.name}")
        sys.exit(1)

    try:
        # Drop the existing collection when a rebuild was requested
        if force_recreate:
            print(f"Đang xóa collection '{collection_name}' (nếu tồn tại)...")
            connections.connect(
                host=milvus_host,
                port=milvus_port
            )
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)
                print("Đã xóa collection cũ.")
            connections.disconnect("default")

        print("Đang khởi tạo DocumentEmbedder...")
        embedder = DocumentEmbedder(
            model_name=model_name,
            device="cuda" if torch.cuda.is_available() else "cpu",
            collection_name=collection_name,
            dim=1024,  # BGE-large embeddings have 1024 dimensions
            milvus_host=milvus_host,
            milvus_port=milvus_port,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        print("Khởi tạo thành công.")

        try:
            # Document-level metadata
            paper_metadata = {
                "source_file": str(file_path),
                "paper_type": "scientific_paper",
                "processing_timestamp": datetime.datetime.now().isoformat(),
                "file_size_mb": file_path.stat().st_size / 1024 / 1024
            }

            print(f"Đang xử lý file: {file_path}...")
            embedder.process_scientific_paper(str(file_path), paper_metadata)
            print("Hoàn tất xử lý và nạp dữ liệu vào Milvus.")

            # Optional smoke-test query
            if query:
                print(f"\nĐang thực hiện truy vấn kiểm tra: '{query}'")
                results = embedder.search(
                    query=query,
                    top_k=query_top_k,
                    score_threshold=0.0  # keep every hit so all of them can be inspected
                )

                if results:
                    print("\n--- Kết quả tìm kiếm ---")
                    for i, result in enumerate(results, 1):
                        print(f"[{i}] Điểm tương đồng (Score): {result['score']:.4f}")
                        print(f"    Nguồn: {Path(result['source_file']).name}")
                        print(f"    Phần (Section): {result['section']}")
                        preview = result['text'][:200].replace('\n', ' ')
                        print(f"    Nội dung: {preview}...")
                        print("-" * 25)
                else:
                    print("Không tìm thấy kết quả nào cho truy vấn.")
        finally:
            # Runs on success and on failure, so the connection never leaks
            print("\nĐang đóng kết nối...")
            embedder.close()
        print("Hoàn tất.")

    except Exception as e:
        print(f"Đã xảy ra lỗi trong quá trình thực thi: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

In [None]:
def main():
    """
    Embed one scientific paper into Milvus and optionally run a test query.

    All settings are hard-coded in the parameter section below instead of
    being read from command-line arguments.

    Steps:
      1. Optionally drop the existing Milvus collection (``force_recreate``).
      2. Create a ``DocumentEmbedder`` (CUDA when available, otherwise CPU).
      3. Build the paper metadata and call ``process_scientific_paper``.
      4. If ``query`` is truthy, run ``embedder.search`` and print each hit.
      5. Close the embedder, even when step 3 or 4 failed.

    Raises:
        SystemExit: with exit code 1 when any step raises; the error message
            and the traceback are printed first.
    """

    # --- Hard-coded parameters ---
    data_dir_str = "../../data/processed"
    file_name = "final_text.txt"
    collection_name = "scientific_papers"
    model_name = "BAAI/bge-large-en-v1.5"
    milvus_host = "localhost"
    milvus_port = 19530
    chunk_size = 512
    chunk_overlap = 50
    # Build the path with pathlib: plain string concatenation dropped the
    # separator ("processedfinal_text.txt") and produced a str, which has no
    # .stat() for the file-size lookup further down.
    file_path = Path(data_dir_str) / file_name

    # Set to True to drop and recreate the collection
    force_recreate = False

    # Query used as a smoke test; set to None to skip it
    query = "What is the main challenge of this research?"
    query_top_k = 5
    # ----------------------------

    try:
        # Recreate the collection if requested
        if force_recreate:
            print(f"Đang xóa collection '{collection_name}' (nếu tồn tại)...")
            connections.connect(
                host=milvus_host,
                port=milvus_port
            )
            try:
                if utility.has_collection(collection_name):
                    utility.drop_collection(collection_name)
                    print("Đã xóa collection cũ.")
            finally:
                # Release the temporary connection even if the check/drop fails
                connections.disconnect("default")

        # Initialise the DocumentEmbedder
        print("Đang khởi tạo DocumentEmbedder...")
        embedder = DocumentEmbedder(
            model_name=model_name,
            device="cuda" if torch.cuda.is_available() else "cpu",
            collection_name=collection_name,
            dim=1024,  # BGE-large has dim 1024
            milvus_host=milvus_host,
            milvus_port=milvus_port,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        print("Khởi tạo thành công.")

        try:
            # Prepare document-level metadata
            paper_metadata = {
                "source_file": str(file_path),
                "paper_type": "scientific_paper",
                "processing_timestamp": datetime.datetime.now().isoformat(),
                "file_size_mb": file_path.stat().st_size / 1024 / 1024
            }

            # Process the scientific paper
            print(f"Đang xử lý file: {file_path}...")
            embedder.process_scientific_paper(str(file_path), paper_metadata)
            print("Hoàn tất xử lý và nạp dữ liệu vào Milvus.")

            # Run the test query, if one is configured
            if query:
                print(f"\nĐang thực hiện truy vấn kiểm tra: '{query}'")
                results = embedder.search(
                    query=query,
                    top_k=query_top_k,
                    score_threshold=0.0  # keep every hit so all can be inspected
                )

                if results:
                    print("\n--- Kết quả tìm kiếm ---")
                    for i, result in enumerate(results, 1):
                        print(f"[{i}] Điểm tương đồng (Score): {result['score']:.4f}")
                        print(f"    Nguồn: {result['source_file'].split(os.path.sep)[-1]}")
                        print(f"    Phần (Section): {result['section']}")
                        # Flatten newlines so the preview stays on one line
                        preview = result['text'][:200].replace('\n', ' ')
                        print(f"    Nội dung: {preview}...")
                        print("-" * 25)
                else:
                    print("Không tìm thấy kết quả nào cho truy vấn.")
        finally:
            # Always close the embedder so a failure above does not leak it
            print("\nĐang đóng kết nối...")
            embedder.close()
        print("Hoàn tất.")

    except Exception as e:
        # Top-level boundary: report the failure and exit with a non-zero code
        print(f"Đã xảy ra lỗi trong quá trình thực thi: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


from pathlib import Path
import sys
import os
import datetime
import torch
from pymilvus import connections, utility
    
# (The DocumentEmbedder class must be defined above)

# Run the pipeline as soon as this code is executed.
main()

Đang khởi tạo DocumentEmbedder...
