### Imports


In [1]:
import os
from pathlib import Path
import json
import random
import random
import re
from typing import List, Dict, Tuple, Set
from collections import Counter, defaultdict
import numpy as np
from typing import List, Dict, Tuple, Optional
import numpy as np
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from sentence_transformers.util import cos_sim
import PyPDF2
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

DATA_DIR = "../data/"
OUTPUT_DIR = "./trained_model/"
OUTPUT_RESULTS_DIR = "./results/"
TRAINING_DATA_DIR = "./train_data/"

for path in [ OUTPUT_DIR, OUTPUT_RESULTS_DIR, TRAINING_DATA_DIR]:
    os.makedirs(path, exist_ok=True)

### STEP 1: PDF PROCESSING & CHUNKING

In [3]:

def extract_text_from_pdf(pdf_path: Path) -> str:
    """Extract text from a single PDF"""
    try:
        with open(pdf_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            text = ""
            for page in reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n\n"
        
        text = text.replace('\x00', '')
        text = ' '.join(text.split())
        return text
    except Exception as e:
        print(f"   ‚ö† Error reading {pdf_path.name}: {e}")
        return ""


def chunk_text_semantic(text: str, 
                        max_tokens: int = 256,
                        overlap_tokens: int = 50,
                        min_tokens: int = 30) -> List[str]:
    """
    Semantic chunking using sentence boundaries.
    More suitable for embedding models than character-based splitting.
    """
    from transformers import AutoTokenizer
    
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    
    import re
    sentences = re.split(r'(?<=[.!?])\s+', text)
    
    chunks = []
    current_chunk = []
    current_tokens = 0
    
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        
        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        
        if sentence_tokens > max_tokens:
            words = sentence.split()
            temp_chunk = []
            temp_tokens = 0
            
            for word in words:
                word_tokens = len(tokenizer.encode(word, add_special_tokens=False))
                if temp_tokens + word_tokens <= max_tokens - 10:
                    temp_chunk.append(word)
                    temp_tokens += word_tokens
                else:
                    if temp_chunk:
                        chunks.append(' '.join(temp_chunk))
                    temp_chunk = [word]
                    temp_tokens = word_tokens
            
            if temp_chunk:
                chunks.append(' '.join(temp_chunk))
            continue

        if current_tokens + sentence_tokens > max_tokens:

            if current_chunk:
                chunks.append(' '.join(current_chunk))

            if overlap_tokens > 0 and current_chunk:
                overlap_text = ' '.join(current_chunk)
                overlap_tok = tokenizer.encode(overlap_text, add_special_tokens=False)[-overlap_tokens:]
                overlap_text = tokenizer.decode(overlap_tok)
                current_chunk = [overlap_text, sentence]
                current_tokens = len(tokenizer.encode(' '.join(current_chunk), add_special_tokens=False))
            else:
                current_chunk = [sentence]
                current_tokens = sentence_tokens
        else:
            current_chunk.append(sentence)
            current_tokens += sentence_tokens

    if current_chunk:
        chunks.append(' '.join(current_chunk))
    
    chunks = [c for c in chunks if len(tokenizer.encode(c, add_special_tokens=False)) >= min_tokens]
    
    return chunks


def extract_text_from_pdfs(pdf_folder: str,
                           max_tokens: int = 256,
                           overlap_tokens: int = 50) -> List[Dict]:
    """Extract and chunk text from all PDFs in folder"""
    pdf_path = Path(pdf_folder)
    if not pdf_path.exists():
        raise ValueError(f"PDF folder not found: {pdf_folder}")
    
    pdf_files = list(pdf_path.glob("*.pdf"))
    if not pdf_files:
        raise ValueError(f"No PDF files found in {pdf_folder}")
    
    print(f"Found {len(pdf_files)} PDF files")
    
    all_chunks = []
    
    for pdf_file in tqdm(pdf_files, desc="Processing PDFs"):
        text = extract_text_from_pdf(pdf_file)
        if not text:
            continue

        text_chunks = chunk_text_semantic(text, max_tokens, overlap_tokens)
        
        for i, chunk_text in enumerate(text_chunks):
            all_chunks.append({
                'text': chunk_text,
                'source': pdf_file.name,
                'chunk_id': f"{pdf_file.stem}_chunk_{i}",
                'chunk_index': i,
                'total_chunks': len(text_chunks)
            })
    
    print(f"‚úì Extracted {len(all_chunks)} chunks from {len(pdf_files)} PDFs")
    return all_chunks


### Training Dataset Creation

In [4]:
_nlp = None
_sentence_model = None

def get_nlp():
    global _nlp
    if _nlp is None:
        import spacy
        try:
            _nlp = spacy.load("en_core_web_sm")
        except OSError:
            import subprocess
            subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
            _nlp = spacy.load("en_core_web_sm")
    return _nlp


class QueryGenerator:
    """
    Data-driven query generator that learns patterns from your corpus.
    No hardcoded rules - adapts to domain vocabulary and structure.
    """
    
    def __init__(self):
        self.vocab_stats = None
        self.phrase_patterns = None
        self.question_templates = None
        
    def learn_from_corpus(self, chunks: List[Dict], sample_size: int = 1000):
        """
        Learn vocabulary, patterns, and structure from the corpus itself.
        This replaces all hardcoded rules with data-driven patterns.
        """
        print("Learning corpus patterns...")

        sampled = random.sample(chunks, min(sample_size, len(chunks)))
        
        nlp = get_nlp()

        all_noun_phrases = []
        all_entities = []
        all_verbs = []
        sentence_starts = []
        bigrams = []
        trigrams = []
        
        for chunk in sampled:
            doc = nlp(chunk['text'][:500])

            for chunk_item in doc.noun_chunks:
                text = self._clean_phrase(chunk_item.text)
                if text:
                    all_noun_phrases.append(text.lower())
            
            for ent in doc.ents:
                text = self._clean_phrase(ent.text)
                if text and len(text.split()) <= 4:
                    all_entities.append((text.lower(), ent.label_))
            
            for token in doc:
                if token.pos_ == 'VERB' and not token.is_stop:
                    all_verbs.append(token.lemma_.lower())

            sentences = [s.text.strip() for s in doc.sents if len(s.text.strip()) > 20]
            for sent in sentences[:3]:
                words = sent.split()[:5]
                sentence_starts.append(' '.join(words).lower())

            tokens = [t.text.lower() for t in doc if not t.is_stop and t.is_alpha]
            if len(tokens) >= 2:
                bigrams.extend([f"{tokens[i]} {tokens[i+1]}" for i in range(len(tokens)-1)])
            if len(tokens) >= 3:
                trigrams.extend([f"{tokens[i]} {tokens[i+1]} {tokens[i+2]}" for i in range(len(tokens)-2)])

        self.vocab_stats = {
            'top_noun_phrases': [item for item, count in Counter(all_noun_phrases).most_common(200)],
            'top_entities': [item[0] for item, count in Counter(all_entities).most_common(100)],
            'entity_types': Counter([item[1] for item in all_entities]),
            'top_verbs': [item for item, count in Counter(all_verbs).most_common(50)],
            'common_bigrams': [item for item, count in Counter(bigrams).most_common(100)],
            'common_trigrams': [item for item, count in Counter(trigrams).most_common(100)],
            'sentence_patterns': Counter(sentence_starts).most_common(50)
        }

        self.question_templates = self._generate_templates()
        
        print(f"‚úì Learned {len(self.vocab_stats['top_noun_phrases'])} key phrases")
        print(f"‚úì Learned {len(self.vocab_stats['top_entities'])} entities")
        print(f"‚úì Generated {len(self.question_templates)} query templates")
        
    def _clean_phrase(self, text: str) -> str:
        """Clean and normalize phrases"""
        text = text.strip()

        text = re.sub(r'^(the|a|an|this|that|these|those)\s+', '', text, flags=re.IGNORECASE)
        text = re.sub(r'\s+', ' ', text)
        return text if len(text) > 3 else ''
    
    def _generate_templates(self) -> List[Dict]:
        """
        Generate query templates dynamically based on common linguistic patterns.
        These adapt to your domain's vocabulary and structure.
        """
        templates = []

        question_words = ['what', 'how', 'why', 'when', 'where', 'which', 'who']
        verbs = self.vocab_stats['top_verbs'][:20] if self.vocab_stats['top_verbs'] else ['is', 'are', 'does']
        
        for qw in question_words:
            for verb in verbs[:5]:
                templates.append({
                    'pattern': f"{qw} {verb} {{concept}}",
                    'type': 'question'
                })

        phrase_starters = ['explain', 'describe', 'define', 'overview of', 'details about', 
                          'information on', 'summary of', 'analysis of']
        for starter in phrase_starters:
            templates.append({
                'pattern': f"{starter} {{concept}}",
                'type': 'imperative'
            })

        templates.extend([
            {'pattern': '{concept}', 'type': 'keyword'},
            {'pattern': '{concept} {concept2}', 'type': 'multi_keyword'},
            {'pattern': '{concept} overview', 'type': 'keyword_suffix'},
            {'pattern': '{concept} definition', 'type': 'keyword_suffix'},
        ])
        
        return templates
    
    def _extract_key_phrases(self, text: str, max_phrases: int = 10) -> List[str]:
        """
        Extract most relevant phrases using TF-IDF-like scoring against learned vocabulary.
        """
        nlp = get_nlp()
        doc = nlp(text[:800])
        
        phrases = []
        scores = []

        for chunk in doc.noun_chunks:
            phrase = self._clean_phrase(chunk.text).lower()
            if not phrase:
                continue
            
            score = 0

            if phrase in self.vocab_stats['top_noun_phrases'][:50]:
                score += 10
            elif phrase in self.vocab_stats['top_noun_phrases']:
                score += 5

            word_count = len(phrase.split())
            if 2 <= word_count <= 4:
                score += word_count * 2

            if phrase in self.vocab_stats['common_bigrams']:
                score += 3
            if phrase in self.vocab_stats['common_trigrams']:
                score += 5
            
            phrases.append(phrase)
            scores.append(score)

        for ent in doc.ents:
            phrase = self._clean_phrase(ent.text).lower()
            if phrase and len(phrase.split()) <= 4:
                if phrase in self.vocab_stats['top_entities'][:30]:
                    phrases.append(phrase)
                    scores.append(8)

        if not phrases:
            return []
        
        sorted_phrases = [p for _, p in sorted(zip(scores, phrases), reverse=True)]
        return list(dict.fromkeys(sorted_phrases))[:max_phrases]
    
    def _sentence_to_query(self, sentence: str) -> List[str]:
        """
        Convert informative sentences to queries using learned patterns.
        """
        queries = []
        nlp = get_nlp()
        doc = nlp(sentence[:200])
        
        for token in doc:
            if token.dep_ == 'ROOT' and token.pos_ == 'VERB':
                
                subjects = [child.text for child in token.children if child.dep_ in ['nsubj', 'nsubjpass']]
                if subjects:
                    subject = subjects[0]
                    
                    if token.lemma_ in ['be', 'have']:
                        queries.append(f"What {token.lemma_} {subject}?")
                    else:
                        queries.append(f"How does {subject} {token.lemma_}?")
        
        return queries
    
    def generate_queries(self, chunk_text: str, num_queries: int = 5) -> List[str]:
        """
        Generate diverse queries for a chunk using learned patterns.
        """
        if not self.vocab_stats:
            raise ValueError("Must call learn_from_corpus() first!")
        
        queries = set()
        
        key_phrases = self._extract_key_phrases(chunk_text, max_phrases=8)
        
        if not key_phrases:
            return []
        
        for template in self.question_templates:
            if len(queries) >= num_queries * 3:
                break
            
            pattern = template['pattern']
            
            if '{concept}' in pattern:
                for phrase in key_phrases[:4]:
                    if '{concept2}' in pattern and len(key_phrases) >= 2:

                        for phrase2 in key_phrases[:4]:
                            if phrase != phrase2:
                                query = pattern.format(concept=phrase, concept2=phrase2)
                                queries.add(query)
                    else:

                        query = pattern.format(concept=phrase)
                        queries.add(query)

        sentences = [s.strip() for s in re.split(r'[.!?]+', chunk_text) if 30 < len(s.strip()) < 150]
        for sent in sentences[:2]:
            sent_queries = self._sentence_to_query(sent)
            queries.update(sent_queries)

        for phrase in key_phrases[:3]:
            queries.add(phrase)

        clean_queries = []
        for q in queries:
            q = re.sub(r'\s+', ' ', q.strip())
            if (8 < len(q) < 150 and 
                len(q.split()) >= 2 and
                not q.startswith('the ')):
                clean_queries.append(q)

        random.shuffle(clean_queries)
        return clean_queries[:num_queries]


def create_training_dataset(chunks: List[Dict], 
                           queries_per_chunk: int = 5,
                           learn_sample_size: int = 1000) -> List[Dict]:
    """
    Generate training dataset using data-driven approach.
    
    Args:
        chunks: List of document chunks with 'text' and 'chunk_id'
        queries_per_chunk: Number of queries to generate per chunk
        learn_sample_size: Number of chunks to use for learning patterns
    
    Returns:
        List of {'query', 'positive', 'chunk_id'} training pairs
    """
    from tqdm import tqdm

    generator = QueryGenerator()
    generator.learn_from_corpus(chunks, sample_size=learn_sample_size)

    training_data = []
    skipped = 0
    
    for chunk in tqdm(chunks, desc="Generating queries"):
        try:
            queries = generator.generate_queries(chunk['text'], queries_per_chunk)
            
            if not queries:
                skipped += 1
                continue
            
            for query in queries:
                training_data.append({
                    'query': query,
                    'positive': chunk['text'],
                    'chunk_id': chunk['chunk_id']
                })
        except Exception as e:
            skipped += 1
            continue
    
    print(f"\n Generated {len(training_data)} query-chunk pairs")
    if skipped > 0:
        print(f"Skipped {skipped} chunks")
    
    return training_data

In [5]:

def Training_data_and_chunking():
    """Complete training pipeline"""

    CONFIG = {
        'pdf_folder': DATA_DIR,
        'output_dir': TRAINING_DATA_DIR,
        'base_model': "sentence-transformers/all-MiniLM-L6-v2",
        'queries_per_chunk': 5,
    }

    output_dir = Path(CONFIG['output_dir'])
    output_dir.mkdir(exist_ok=True)
    
    print("="*60)
    print("DOMAIN-SPECIFIC EMBEDDING TRAINING PIPELINE")
    print("="*60)
    
    print("\nSTEP 1: Extracting and chunking PDFs...")
    chunks = extract_text_from_pdfs(CONFIG['pdf_folder'])
    
    # Save chunks
    with open(output_dir / "chunks.json", 'w') as f:
        json.dump(chunks, f, indent=2)
    print(f"‚úì Saved {len(chunks)} chunks")

    print("\n STEP 2: Generating synthetic queries...")

    
    training_data = create_training_dataset(
        chunks,
        queries_per_chunk=CONFIG['queries_per_chunk'],
        learn_sample_size=min(1000, len(chunks))
    )

    with open(output_dir / "training_data_raw.json", 'w') as f:
        json.dump(training_data, f, indent=2)

    return training_data, chunks

training_data, chunks = Training_data_and_chunking()

DOMAIN-SPECIFIC EMBEDDING TRAINING PIPELINE

STEP 1: Extracting and chunking PDFs...
Found 5 PDF files


Processing PDFs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:05<00:00,  1.04s/it]


‚úì Extracted 407 chunks from 5 PDFs
‚úì Saved 407 chunks

 STEP 2: Generating synthetic queries...
Learning corpus patterns...
‚úì Learned 200 key phrases
‚úì Learned 100 entities
‚úì Generated 47 query templates


Generating queries: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 407/407 [00:17<00:00, 23.33it/s]


 Generated 2030 query-chunk pairs
Skipped 1 chunks





### STEP 2: HARD NEGATIVE MINING

In [6]:
def mine_hard_negatives_efficient(
    training_data: List[Dict], 
    chunks: List[Dict], 
    base_model: SentenceTransformer, 
    num_negatives: int = 3,
    use_cross_encoder: bool = True,
    cross_encoder_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
) -> List[Dict]:

    from rank_bm25 import BM25Okapi
    
    cross_encoder = None
    if use_cross_encoder:
        from sentence_transformers import CrossEncoder
        print("   Loading cross-encoder for reranking...")
        cross_encoder = CrossEncoder(cross_encoder_model)
    
    # Prepare BM25 index
    corpus = [c['text'] for c in chunks]
    tokenized_corpus = [doc.lower().split() for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    
    # Embed all chunks once with bi-encoder
    print("   Encoding chunks with bi-encoder...")
    chunk_embeddings = base_model.encode(corpus, convert_to_tensor=True, show_progress_bar=True)
    
    enriched_data = []
    
    for item in tqdm(training_data, desc="Mining hard negatives"):
        query = item['query']
        positive_id = item['chunk_id']
        positive_text = item['positive']
        
        # ========== STAGE 1: Fast Candidate Retrieval ==========
        # BM25 retrieval (lexical)
        tokenized_query = query.lower().split()
        bm25_scores = bm25.get_scores(tokenized_query)
        top_bm25_indices = np.argsort(bm25_scores)[-30:][::-1]  # Top 30
        
        # Semantic retrieval (bi-encoder)
        query_emb = base_model.encode(query, convert_to_tensor=True)
        semantic_scores = cos_sim(query_emb, chunk_embeddings)[0]
        top_sem_indices = torch.argsort(semantic_scores, descending=True)[:30].tolist()
        
        # Combine candidates and filter out the positive
        candidate_indices = list(set(top_bm25_indices.tolist()) | set(top_sem_indices))
        candidate_indices = [i for i in candidate_indices if chunks[i]['chunk_id'] != positive_id]
        
        if not candidate_indices:
            continue

        if cross_encoder is not None and len(candidate_indices) > 0:

            pairs = [[query, corpus[idx]] for idx in candidate_indices]
        
            ce_scores = cross_encoder.predict(pairs)
        
            bi_scores = semantic_scores[candidate_indices].cpu().numpy()
            
            bi_normalized = (bi_scores - bi_scores.min()) / (bi_scores.max() - bi_scores.min() + 1e-8)
            ce_normalized = (ce_scores - ce_scores.min()) / (ce_scores.max() - ce_scores.min() + 1e-8)

            hardness_scores = bi_normalized - ce_normalized

            hard_negative_indices = []
            sorted_indices = np.argsort(hardness_scores)[::-1]
            
            for idx in sorted_indices:
                candidate_idx = candidate_indices[idx]

                if ce_scores[idx] < 0.5 and len(hard_negative_indices) < num_negatives:
                    hard_negative_indices.append(candidate_idx)

            if len(hard_negative_indices) < num_negatives:
                remaining = num_negatives - len(hard_negative_indices)
                for idx in sorted_indices:
                    candidate_idx = candidate_indices[idx]
                    if candidate_idx not in hard_negative_indices:
                        hard_negative_indices.append(candidate_idx)
                        if len(hard_negative_indices) >= num_negatives:
                            break
            
            negatives = [corpus[idx] for idx in hard_negative_indices[:num_negatives]]
        
        else:
            sorted_by_similarity = sorted(
                candidate_indices, 
                key=lambda i: semantic_scores[i].item(), 
                reverse=True
            )
            negatives = [corpus[idx] for idx in sorted_by_similarity[:num_negatives]]

        if negatives:
            enriched_data.append({
                'query': query,
                'positive': positive_text,
                'negatives': negatives,
                'chunk_id': positive_id
            })
    
    print(f"   ‚úì Generated {len(enriched_data)} training examples with hard negatives")
    return enriched_data



In [8]:
def Negative_samples(training_data: List[Dict], chunks: List[Dict]) -> List[Dict]:
    """Complete training pipeline"""

    CONFIG = {
        'pdf_folder': DATA_DIR,
        'output_dir': TRAINING_DATA_DIR,
        'base_model': "sentence-transformers/all-MiniLM-L6-v2",
        'queries_per_chunk': 5,
        'num_hard_negatives': 3,
    }

    output_dir = Path(CONFIG['output_dir'])
    output_dir.mkdir(exist_ok=True)
    
    print("\n STEP 3: Mining hard negatives...")
    base_model = SentenceTransformer(CONFIG['base_model'])
    
    training_data_enriched = mine_hard_negatives_efficient(
        training_data,
        chunks,
        base_model,
        num_negatives=CONFIG['num_hard_negatives'],
    )

    with open(output_dir / "training_data_enriched.json", 'w') as f:
        json.dump(training_data_enriched, f, indent=2)
    return training_data_enriched

training_data_enriched =  Negative_samples(training_data, chunks)


 STEP 3: Mining hard negatives...
   Loading cross-encoder for reranking...
   Encoding chunks with bi-encoder...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:02<00:00,  5.82it/s]
Mining hard negatives: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2030/2030 [21:23<00:00,  1.58it/s]

   ‚úì Generated 2030 training examples with hard negatives





### STEP 3: MODEL TRAINING

In [9]:
from typing import List, Dict, Optional, Tuple
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
import os
import csv


def train_embedding_model(
    train_data: List[Dict],
    base_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    output_path: str = "./custom_embedding_model",
    val_data: Optional[List[Dict]] = None,
    epochs: int = 3,
    batch_size: int = 16,
    warmup_ratio: float = 0.1,
    use_amp: bool = True
) -> Tuple[SentenceTransformer, Dict]:
    """
    Train custom embedding model with hard negatives.
    
    Returns:
        model: Trained model
        training_stats: Dictionary with training metrics
    """
    print(f"Loading base model: {base_model_name}")
    model = SentenceTransformer(base_model_name)

    print("Preparing training data...")
    train_examples = []
    
    for item in train_data:
        query = item['query']
        positive = item['positive']
        negatives = item.get('negatives', [])
    
        if negatives:
            for neg in negatives:
                train_examples.append(InputExample(texts=[query, positive, neg]))
        else:
            train_examples.append(InputExample(texts=[query, positive]))
    
    print(f"Created {len(train_examples)} training examples")

    train_dataloader = DataLoader(
        train_examples, 
        shuffle=True, 
        batch_size=batch_size
    )

    if any('negatives' in item for item in train_data):
        print("Using MultipleNegativesRankingLoss (with hard negatives)")
        train_loss = losses.MultipleNegativesRankingLoss(model)
    else:
        print("Using ContrastiveLoss")
        train_loss = losses.ContrastiveLoss(model)

    evaluator = None
    if val_data:
        print("Setting up validation evaluators...")

        val_examples = []
        for item in val_data:
            query = item['query']
            positive = item['positive']
            negatives = item.get('negatives', [])
            
            if negatives:

                val_examples.append(InputExample(texts=[query, positive, negatives[0]]))
            else:
                val_examples.append(InputExample(texts=[query, positive]))

        triplet_evaluator = evaluation.TripletEvaluator.from_input_examples(
            val_examples,
            name="validation_loss",
            write_csv=True
        )
        
        val_queries = {f"q{i}": item['query'] for i, item in enumerate(val_data)}
        val_corpus = {f"d{i}": item['positive'] for i, item in enumerate(val_data)}
        val_relevant = {f"q{i}": {f"d{i}"} for i in range(len(val_data))}
        
        ir_evaluator = evaluation.InformationRetrievalEvaluator(
            queries=val_queries,
            corpus=val_corpus,
            relevant_docs=val_relevant,
            name="validation_metrics",
            write_csv=True
        )
        
        evaluator = evaluation.SequentialEvaluator(
            [triplet_evaluator, ir_evaluator],
            main_score_function=lambda scores: scores[1] 
        )

    print(f"\nTraining for {epochs} epochs...")
    warmup_steps = int(len(train_dataloader) * warmup_ratio)
    
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        output_path=output_path,
        show_progress_bar=True,
        use_amp=use_amp,
        evaluator=evaluator,
        evaluation_steps=len(train_dataloader) if evaluator else 0,
        save_best_model=True
    )
    
    print(f"‚úì Training complete. Model saved to {output_path}")

    training_stats = {
        'epochs': epochs,
        'total_steps': len(train_dataloader) * epochs,
        'warmup_steps': warmup_steps
    }
    
    if val_data:
        try:

            loss_csv_path = os.path.join(output_path, "validation_loss_evaluation_results.csv")
            if os.path.exists(loss_csv_path):
                with open(loss_csv_path, 'r') as f:
                    reader = csv.DictReader(f)
                    loss_data = list(reader)
                    training_stats['validation_loss_history'] = [
                        float(row['cosine_accuracy']) for row in loss_data
                    ]

            metrics_csv_path = os.path.join(output_path, "validation_metrics_evaluation_results.csv")
            if os.path.exists(metrics_csv_path):
                with open(metrics_csv_path, 'r') as f:
                    reader = csv.DictReader(f)
                    metrics_data = list(reader)
                    if metrics_data:
                        last_metrics = metrics_data[-1]
                        training_stats['final_metrics'] = {
                            'accuracy@1': float(last_metrics.get('cosine_accuracy@1', 0)),
                            'accuracy@10': float(last_metrics.get('cosine_accuracy@10', 0)),
                            'ndcg@10': float(last_metrics.get('cosine_ndcg@10', 0)),
                            'mrr@10': float(last_metrics.get('cosine_mrr@10', 0)),
                            'map@100': float(last_metrics.get('cosine_map@100', 0))
                        }
        except Exception as e:
            print(f"Warning: Could not read validation metrics: {e}")
    
    return model, training_stats

In [10]:
def Training_model(training_data_enriched: List[Dict]):
    """Complete training pipeline"""

    CONFIG = {
        'output_dir': OUTPUT_DIR,
        'base_model': "sentence-transformers/all-MiniLM-L6-v2",
        'epochs': 3,
        'batch_size': 16,
        'train_val_test_split': [0.7, 0.15, 0.15]
    }

    output_dir = Path(CONFIG['output_dir'])
    output_dir.mkdir(exist_ok=True)

    print("\n STEP 4: Splitting train/val/test...")
    random.shuffle(training_data_enriched)
    
    n = len(training_data_enriched)
    train_size = int(CONFIG['train_val_test_split'][0] * n)
    val_size = int(CONFIG['train_val_test_split'][1] * n)
    
    train_data = training_data_enriched[:train_size]
    val_data = training_data_enriched[train_size:train_size+val_size]
    test_data = training_data_enriched[train_size+val_size:]
    
    print(f"Train: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

    print("\nüèãÔ∏è STEP 5: Training custom embedding model...")
    custom_model, training_stats = train_embedding_model(
        train_data,
        base_model_name=CONFIG['base_model'],
        output_path=str(output_dir / "custom_model"),
        val_data=val_data,
        epochs=CONFIG['epochs'],
        batch_size=CONFIG['batch_size']
    )

    return training_stats

training_stats = Training_model(training_data_enriched)


 STEP 4: Splitting train/val/test...
Train: 1421 | Val: 304 | Test: 305

üèãÔ∏è STEP 5: Training custom embedding model...
Loading base model: sentence-transformers/all-MiniLM-L6-v2
Preparing training data...
Created 4263 training examples
Using MultipleNegativesRankingLoss (with hard negatives)
Setting up validation evaluators...

Training for 3 epochs...


                                                                     

Step,Training Loss,Validation Loss,Validation Loss Cosine Accuracy,Validation Metrics Cosine Accuracy@1,Validation Metrics Cosine Accuracy@3,Validation Metrics Cosine Accuracy@5,Validation Metrics Cosine Accuracy@10,Validation Metrics Cosine Precision@1,Validation Metrics Cosine Precision@3,Validation Metrics Cosine Precision@5,Validation Metrics Cosine Precision@10,Validation Metrics Cosine Recall@1,Validation Metrics Cosine Recall@3,Validation Metrics Cosine Recall@5,Validation Metrics Cosine Recall@10,Validation Metrics Cosine Ndcg@10,Validation Metrics Cosine Mrr@10,Validation Metrics Cosine Map@100,Sequential Score
267,No log,No log,0.891447,0.075658,0.207237,0.276316,0.430921,0.075658,0.069079,0.055263,0.043092,0.075658,0.207237,0.276316,0.430921,0.228736,0.167315,0.188066,0.228736
534,1.696400,No log,0.914474,0.095395,0.203947,0.3125,0.453947,0.095395,0.067982,0.0625,0.045395,0.095395,0.203947,0.3125,0.453947,0.249853,0.187439,0.20896,0.249853
801,1.696400,No log,0.914474,0.095395,0.207237,0.319079,0.440789,0.095395,0.069079,0.063816,0.044079,0.095395,0.207237,0.319079,0.440789,0.244668,0.184523,0.207653,0.244668


‚úì Training complete. Model saved to trained_model/custom_model


### STEP 4: EVALUATION & VISUALIZATION

In [11]:
def evaluate_retrieval(
    model: SentenceTransformer,
    test_data: List[Dict],
    chunks: List[Dict],
    k_values: List[int] = [1, 3, 5, 10]
) -> Dict:
    """
    Comprehensive retrieval evaluation.
    
    Metrics:
    - Recall@K: What fraction of queries retrieve the correct document in top-K
    - MRR: Mean Reciprocal Rank
    - MAP: Mean Average Precision
    """
    corpus = [c['text'] for c in chunks]
    chunk_id_to_idx = {c['chunk_id']: i for i, c in enumerate(chunks)}
    
    print("Encoding corpus for evaluation...")
    corpus_embeddings = model.encode(corpus, convert_to_tensor=True, show_progress_bar=True)
    
    metrics = {f'recall@{k}': [] for k in k_values}
    metrics['mrr'] = []
    metrics['map'] = []
    
    for item in tqdm(test_data, desc="Evaluating"):
        query = item['query']
        target_id = item['chunk_id']
        
        target_idx = chunk_id_to_idx.get(target_id)
        if target_idx is None:
            continue

        query_emb = model.encode(query, convert_to_tensor=True)
        scores = cos_sim(query_emb, corpus_embeddings)[0]
        top_indices = torch.argsort(scores, descending=True).cpu().numpy()
    
        rank = np.where(top_indices == target_idx)[0][0] + 1

        for k in k_values:
            metrics[f'recall@{k}'].append(1.0 if rank <= k else 0.0)

        metrics['mrr'].append(1.0 / rank)

        metrics['map'].append(1.0 / rank)

    results = {key: np.mean(values) for key, values in metrics.items()}
    
    return results


def compare_models(
    base_model: SentenceTransformer,
    custom_model: SentenceTransformer,
    test_data: List[Dict],
    chunks: List[Dict]
) -> Dict:
    """Compare baseline vs custom model"""
    print("\n" + "="*60)
    print("EVALUATING BASELINE MODEL")
    print("="*60)
    base_results = evaluate_retrieval(base_model, test_data, chunks)
    
    print("\n" + "="*60)
    print("EVALUATING CUSTOM MODEL")
    print("="*60)
    custom_results = evaluate_retrieval(custom_model, test_data, chunks)
    
    return {'baseline': base_results, 'custom': custom_results}


def plot_comparison(results: Dict, save_path: str = "model_comparison.png"):
    """Plot model comparison"""
    metrics = list(results['baseline'].keys())
    baseline_scores = [results['baseline'][m] for m in metrics]
    custom_scores = [results['custom'][m] for m in metrics]
    
    x = np.arange(len(metrics))
    width = 0.35
    
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(x - width/2, baseline_scores, width, label='Baseline', alpha=0.8)
    ax.bar(x + width/2, custom_scores, width, label='Custom', alpha=0.8)
    
    ax.set_xlabel('Metrics', fontsize=12, fontweight='bold')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Baseline vs Custom Embedding Model', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, rotation=45, ha='right')
    ax.legend()
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"‚úì Comparison plot saved: {save_path}")


def visualize_embeddings(
    model: SentenceTransformer,
    chunks: List[Dict],
    sample_size: int = 500,
    save_path: str = "embedding_visualization.png"
):
    """
    Visualize embeddings using t-SNE and compute clustering metrics.
    """

    sampled = random.sample(chunks, min(sample_size, len(chunks)))
    texts = [c['text'] for c in sampled]
    sources = [c['source'] for c in sampled]
    
    print(f"Computing embeddings for {len(texts)} samples...")
    embeddings = model.encode(texts, show_progress_bar=True)

    print("Computing t-SNE projection...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    embeddings_2d = tsne.fit_transform(embeddings)

    unique_sources = list(set(sources))
    source_labels = [unique_sources.index(s) for s in sources]
    
    if len(unique_sources) > 1:
        silhouette = silhouette_score(embeddings, source_labels)
        print(f"‚úì Silhouette Score: {silhouette:.4f}")
    else:
        silhouette = None

    fig, ax = plt.subplots(figsize=(12, 8))

    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_sources)))
    
    for i, source in enumerate(unique_sources):
        mask = np.array(sources) == source
        ax.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            c=[colors[i]],
            label=source[:20],
            alpha=0.6,
            s=50
        )
    
    ax.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax.set_ylabel('t-SNE Dimension 2', fontsize=12)
    title = 'Embedding Space Visualization'
    if silhouette:
        title += f' (Silhouette: {silhouette:.3f})'
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"‚úì Visualization saved: {save_path}")
    
    return {'silhouette_score': silhouette}


In [27]:
def evaluate():
    """Complete training pipeline"""
    
    # Configuration
    CONFIG = {
        'pdf_folder': DATA_DIR,
        'output_dir': TRAINING_DATA_DIR,
        'base_model': "sentence-transformers/all-MiniLM-L6-v2",
        'epochs': 3,
        'batch_size': 16,
        'queries_per_chunk': 5,
        'num_hard_negatives': 3,
        'train_val_test_split': [0.7, 0.15, 0.15]
    }
    
    # Create output directory
    output_dir = Path(CONFIG['output_dir'])
    output_dir.mkdir(exist_ok=True)
    base_model = SentenceTransformer(CONFIG['base_model'])

    custom_model = SentenceTransformer("./trained_model/custom_model/")
    n = len(training_data_enriched)
    train_size = int(CONFIG['train_val_test_split'][0] * n)
    val_size = int(CONFIG['train_val_test_split'][1] * n)
    
    test_data = training_data_enriched[train_size+val_size:]
    # Step 6: Evaluate
    print("\nüìä STEP 6: Evaluating models...")
    results = compare_models(base_model, custom_model, test_data, chunks)
    
    # Save results
    with open(output_dir / "evaluation_results.json", 'w') as f:
        json.dump(results, f, indent=2)
    
    # Print results
    print("\n" + "="*60)
    print("FINAL RESULTS")
    print("="*60)
    
    for model_name in ['baseline', 'custom']:
        print(f"\n{model_name.upper()} MODEL:")
        for metric, value in results[model_name].items():
            print(f"  {metric:12s}: {value:.4f}")
    
    # Step 7: Visualizations
    print("\nüìà STEP 7: Creating visualizations...")
    
    # Model comparison
    plot_comparison(results, str(output_dir / "model_comparison.png"))
    
    # Baseline embeddings
    print("\nVisualizing baseline embeddings...")
    baseline_metrics = visualize_embeddings(
        base_model, 
        chunks, 
        save_path=str(output_dir / "baseline_embeddings.png")
    )
    
    # Custom embeddings
    print("\nVisualizing custom embeddings...")
    custom_metrics = visualize_embeddings(
        custom_model, 
        chunks, 
        save_path=str(output_dir / "custom_embeddings.png")
    )
    
    # Summary report
    print("\n" + "="*60)
    print("SUMMARY REPORT")
    print("="*60)
    print(f"\nTotal chunks processed: {len(chunks)}")
    print(f"Training examples generated: {len(training_data_enriched)}")
    print(f"\nClustering Quality:")
    print(f"  Baseline Silhouette: {baseline_metrics.get('silhouette_score', 'N/A')}")
    print(f"  Custom Silhouette: {custom_metrics.get('silhouette_score', 'N/A')}")
    
    print(f"\n‚úÖ All outputs saved to: {output_dir}")
    print(f"‚úÖ Custom model saved to: {output_dir / 'custom_model'}")

evaluate()
    


üìä STEP 6: Evaluating models...

EVALUATING BASELINE MODEL
Encoding corpus for evaluation...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:02<00:00,  5.16it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 305/305 [00:02<00:00, 110.23it/s]



EVALUATING CUSTOM MODEL
Encoding corpus for evaluation...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:02<00:00,  4.88it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 305/305 [00:02<00:00, 106.00it/s]



FINAL RESULTS

BASELINE MODEL:
  recall@1    : 0.0492
  recall@3    : 0.1082
  recall@5    : 0.1344
  recall@10   : 0.1869
  mrr         : 0.1043
  map         : 0.1043

CUSTOM MODEL:
  recall@1    : 0.0885
  recall@3    : 0.1869
  recall@5    : 0.2754
  recall@10   : 0.4492
  mrr         : 0.1929
  map         : 0.1929

üìà STEP 7: Creating visualizations...
‚úì Comparison plot saved: train_data/model_comparison.png

Visualizing baseline embeddings...
Computing embeddings for 407 samples...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:02<00:00,  4.68it/s]


Computing t-SNE projection...
‚úì Silhouette Score: 0.0452
‚úì Visualization saved: train_data/baseline_embeddings.png

Visualizing custom embeddings...
Computing embeddings for 407 samples...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:02<00:00,  4.60it/s]


Computing t-SNE projection...
‚úì Silhouette Score: 0.0574
‚úì Visualization saved: train_data/custom_embeddings.png

SUMMARY REPORT

Total chunks processed: 407
Training examples generated: 2030

Clustering Quality:
  Baseline Silhouette: 0.045152705162763596
  Custom Silhouette: 0.05740627273917198

‚úÖ All outputs saved to: train_data
‚úÖ Custom model saved to: train_data/custom_model


In [29]:
"""
Clustering Analysis & Visualization for Custom vs Baseline Embeddings
UPDATED: Loads chunks from extracted_chunks.json
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
import pandas as pd
from collections import Counter
import json
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# LOAD MODELS & DATA (UPDATED)
# ============================================================================

def load_models_and_data(base_model_path: str, custom_model_path: str, 
                         chunks_file: str ):
    """Load both models and chunks from JSON"""
    print(f"üì• Loading data from {chunks_file}...")
    
    base_model = SentenceTransformer(base_model_path)
    custom_model = SentenceTransformer(custom_model_path)
    
    # Load chunks from JSON
    with open(chunks_file, 'r') as f:
        chunks = json.load(f)
    
    texts = [c['text'][:300] for c in chunks]  # Truncate for speed
    sources = [c.get('source', 'unknown') for c in chunks]
    
    print(f"   ‚úì Loaded {len(texts)} chunks from {len(set(sources))} sources")
    
    return base_model, custom_model, texts, sources, chunks


# ============================================================================
# EMBEDDING GENERATION
# ============================================================================

def generate_embeddings(model: SentenceTransformer, texts: list) -> np.ndarray:
    """Generate embeddings with progress bar"""
    return model.encode(texts, show_progress_bar=True, convert_to_numpy=True)


# ============================================================================
# METRIC 1: CLUSTERING QUALITY METRICS
# ============================================================================

def compute_clustering_metrics(embeddings: np.ndarray, n_clusters: int = 5):
    """Compute Silhouette, Calinski-Harabasz, Davies-Bouldin scores"""
    
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    labels = kmeans.fit_predict(embeddings)
    
    metrics = {
        'silhouette': silhouette_score(embeddings, labels),
        'calinski_harabasz': calinski_harabasz_score(embeddings, labels),
        'davies_bouldin': davies_bouldin_score(embeddings, labels),
        'inertia': kmeans.inertia_,
        'labels': labels
    }
    
    return metrics


# ============================================================================
# METRIC 2: INTRA-CLUSTER vs INTER-CLUSTER DISTANCE
# ============================================================================

def compute_cluster_separation(embeddings: np.ndarray, labels: np.ndarray):
    """Calculate average intra-cluster vs inter-cluster distances"""
    
    unique_labels = np.unique(labels)
    cluster_centers = []
    
    # Compute cluster centers
    for label in unique_labels:
        cluster_points = embeddings[labels == label]
        cluster_centers.append(cluster_points.mean(axis=0))
    
    cluster_centers = np.array(cluster_centers)
    
    # Intra-cluster distance (average distance within clusters)
    intra_distances = []
    for label in unique_labels:
        cluster_points = embeddings[labels == label]
        center = cluster_centers[label]
        distances = np.linalg.norm(cluster_points - center, axis=1)
        intra_distances.extend(distances)
    
    # Inter-cluster distance (distance between cluster centers)
    inter_distances = []
    for i in range(len(cluster_centers)):
        for j in range(i+1, len(cluster_centers)):
            dist = np.linalg.norm(cluster_centers[i] - cluster_centers[j])
            inter_distances.append(dist)
    
    return {
        'avg_intra_distance': np.mean(intra_distances),
        'avg_inter_distance': np.mean(inter_distances),
        'separation_ratio': np.mean(inter_distances) / np.mean(intra_distances)
    }


# ============================================================================
# METRIC 3: NEIGHBOR COHERENCE CHECK
# ============================================================================

def neighbor_coherence_check(embeddings: np.ndarray, sources: list, k: int = 10):
    """Check if k-nearest neighbors come from the same source document"""
    
    from sklearn.neighbors import NearestNeighbors
    
    nbrs = NearestNeighbors(n_neighbors=k+1, metric='cosine').fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)
    
    coherence_scores = []
    
    for i in range(len(embeddings)):
        neighbor_indices = indices[i][1:]  # Exclude self
        neighbor_sources = [sources[idx] for idx in neighbor_indices]
        
        # Calculate percentage of neighbors from same source
        same_source_count = sum(1 for s in neighbor_sources if s == sources[i])
        coherence_scores.append(same_source_count / k)
    
    return {
        'avg_coherence': np.mean(coherence_scores),
        'median_coherence': np.median(coherence_scores),
        'std_coherence': np.std(coherence_scores)
    }


# ============================================================================
# METRIC 4: PAIRWISE SIMILARITY DISTRIBUTION
# ============================================================================

def pairwise_similarity_distribution(embeddings: np.ndarray, sample_size: int = 1000):
    """Compute distribution of pairwise cosine similarities"""
    
    # Sample for efficiency
    if len(embeddings) > sample_size:
        indices = np.random.choice(len(embeddings), sample_size, replace=False)
        sampled_embeddings = embeddings[indices]
    else:
        sampled_embeddings = embeddings
    
    # Compute pairwise cosine similarities
    from sklearn.metrics.pairwise import cosine_similarity
    similarities = cosine_similarity(sampled_embeddings)
    
    # Get upper triangle (exclude diagonal)
    triu_indices = np.triu_indices_from(similarities, k=1)
    similarity_values = similarities[triu_indices]
    
    return {
        'mean_similarity': np.mean(similarity_values),
        'std_similarity': np.std(similarity_values),
        'min_similarity': np.min(similarity_values),
        'max_similarity': np.max(similarity_values),
        'similarities': similarity_values
    }


# ============================================================================
# VISUALIZATION 1: UMAP/t-SNE 2D PROJECTION
# ============================================================================

def plot_embeddings_2d(base_embeddings: np.ndarray, custom_embeddings: np.ndarray, 
                       labels_base: np.ndarray, labels_custom: np.ndarray,
                       method: str = 'umap', save_path: str = './results/embeddings_2d.png'):
    """Plot 2D projections of embeddings"""
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    reducer = TSNE(n_components=2, random_state=42, perplexity=30)

    base_2d = reducer.fit_transform(base_embeddings)
    axes[0].scatter(base_2d[:, 0], base_2d[:, 1], c=labels_base, cmap='tab10', alpha=0.6, s=20)
    axes[0].set_title(f'Baseline Model - {method.upper()}', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Component 1')
    axes[0].set_ylabel('Component 2')
    
    # Custom model
    reducer = TSNE(n_components=2, random_state=42, perplexity=30)
    
    custom_2d = reducer.fit_transform(custom_embeddings)
    axes[1].scatter(custom_2d[:, 0], custom_2d[:, 1], c=labels_custom, cmap='tab10', alpha=0.6, s=20)
    axes[1].set_title(f'Custom Model - {method.upper()}', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Component 1')
    axes[1].set_ylabel('Component 2')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")


# ============================================================================
# VISUALIZATION 2: CLUSTERING METRICS COMPARISON
# ============================================================================

def plot_metrics_comparison(base_metrics: dict, custom_metrics: dict, 
                            save_path: str = './results/metrics_comparison.png'):
    """Bar plot comparing clustering metrics"""
    
    metrics_to_plot = ['silhouette', 'calinski_harabasz', 'davies_bouldin']
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    for idx, metric in enumerate(metrics_to_plot):
        base_val = base_metrics[metric]
        custom_val = custom_metrics[metric]
        
        # For Davies-Bouldin, lower is better
        if metric == 'davies_bouldin':
            better = 'Custom' if custom_val < base_val else 'Base'
            colors = ['#d62728', '#2ca02c'] if custom_val < base_val else ['#2ca02c', '#d62728']
        else:
            better = 'Custom' if custom_val > base_val else 'Base'
            colors = ['#d62728', '#2ca02c'] if custom_val > base_val else ['#2ca02c', '#d62728']
        
        axes[idx].bar(['Baseline', 'Custom'], [base_val, custom_val], color=colors, alpha=0.7)
        axes[idx].set_title(metric.replace('_', ' ').title(), fontweight='bold')
        axes[idx].set_ylabel('Score')
        axes[idx].grid(axis='y', alpha=0.3)
        
        # Add improvement text
        improvement = ((custom_val - base_val) / base_val) * 100
        axes[idx].text(0.5, max(base_val, custom_val) * 1.05, 
                      f'{improvement:+.1f}%', ha='center', fontsize=10, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")


# ============================================================================
# VISUALIZATION 3: SIMILARITY DISTRIBUTION HISTOGRAM
# ============================================================================

def plot_similarity_distribution(base_sims: np.ndarray, custom_sims: np.ndarray,
                                 save_path: str = './results/similarity_distribution.png'):
    """Histogram of pairwise similarity distributions"""
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Base model
    axes[0].hist(base_sims, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
    axes[0].axvline(base_sims.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {base_sims.mean():.3f}')
    axes[0].set_title('Baseline Model - Similarity Distribution', fontweight='bold')
    axes[0].set_xlabel('Cosine Similarity')
    axes[0].set_ylabel('Frequency')
    axes[0].legend()
    axes[0].grid(alpha=0.3)
    
    # Custom model
    axes[1].hist(custom_sims, bins=50, alpha=0.7, color='coral', edgecolor='black')
    axes[1].axvline(custom_sims.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {custom_sims.mean():.3f}')
    axes[1].set_title('Custom Model - Similarity Distribution', fontweight='bold')
    axes[1].set_xlabel('Cosine Similarity')
    axes[1].set_ylabel('Frequency')
    axes[1].legend()
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")


# ============================================================================
# VISUALIZATION 4: ELBOW PLOT (Optimal K)
# ============================================================================

def plot_elbow_curve(base_embeddings: np.ndarray, custom_embeddings: np.ndarray,
                     k_range: range = range(2, 11), save_path: str = './results/elbow_plot.png'):
    """Elbow method to find optimal number of clusters"""
    
    base_inertias = []
    custom_inertias = []
    base_silhouettes = []
    custom_silhouettes = []
    
    for k in k_range:
        # Base model
        kmeans_base = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels_base = kmeans_base.fit_predict(base_embeddings)
        base_inertias.append(kmeans_base.inertia_)
        base_silhouettes.append(silhouette_score(base_embeddings, labels_base))
        
        # Custom model
        kmeans_custom = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels_custom = kmeans_custom.fit_predict(custom_embeddings)
        custom_inertias.append(kmeans_custom.inertia_)
        custom_silhouettes.append(silhouette_score(custom_embeddings, labels_custom))
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Inertia plot
    axes[0].plot(k_range, base_inertias, marker='o', label='Baseline', linewidth=2)
    axes[0].plot(k_range, custom_inertias, marker='s', label='Custom', linewidth=2)
    axes[0].set_title('Elbow Plot - Inertia', fontweight='bold')
    axes[0].set_xlabel('Number of Clusters (k)')
    axes[0].set_ylabel('Inertia')
    axes[0].legend()
    axes[0].grid(alpha=0.3)
    
    # Silhouette plot
    axes[1].plot(k_range, base_silhouettes, marker='o', label='Baseline', linewidth=2)
    axes[1].plot(k_range, custom_silhouettes, marker='s', label='Custom', linewidth=2)
    axes[1].set_title('Silhouette Score vs K', fontweight='bold')
    axes[1].set_xlabel('Number of Clusters (k)')
    axes[1].set_ylabel('Silhouette Score')
    axes[1].legend()
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")


# ============================================================================
# VISUALIZATION 5: NEIGHBOR COHERENCE HEATMAP
# ============================================================================

def plot_neighbor_coherence(base_embeddings: np.ndarray, custom_embeddings: np.ndarray,
                            sources: list, k_values: list = [5, 10, 20],
                            save_path: str = './results/neighbor_coherence.png'):
    """Plot neighbor coherence for different k values"""
    
    base_coherences = []
    custom_coherences = []
    
    for k in k_values:
        base_result = neighbor_coherence_check(base_embeddings, sources, k)
        custom_result = neighbor_coherence_check(custom_embeddings, sources, k)
        base_coherences.append(base_result['avg_coherence'])
        custom_coherences.append(custom_result['avg_coherence'])
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x = np.arange(len(k_values))
    width = 0.35
    
    bars1 = ax.bar(x - width/2, base_coherences, width, label='Baseline', alpha=0.8, color='steelblue')
    bars2 = ax.bar(x + width/2, custom_coherences, width, label='Custom', alpha=0.8, color='coral')
    
    ax.set_xlabel('k (Number of Neighbors)', fontweight='bold')
    ax.set_ylabel('Neighbor Coherence Score', fontweight='bold')
    ax.set_title('Same-Source Neighbor Coherence', fontweight='bold', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels([f'k={k}' for k in k_values])
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{height:.3f}', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")


# ============================================================================
# MAIN ANALYSIS PIPELINE (UPDATED)
# ============================================================================

def run_complete_analysis(base_model_path: str = "sentence-transformers/all-MiniLM-L6-v2",
                         custom_model_path: str = "./custom_embedding_model", 
                         chunks_file: str = "extracted_chunks.json",
                         n_clusters: int = 5):
    """Run complete clustering and visualization analysis"""
    
    print("="*70)
    print("CLUSTERING & VISUALIZATION ANALYSIS")
    print("="*70)
    
    # Load data
    base_model, custom_model, texts, sources, chunks = load_models_and_data(
        base_model_path, custom_model_path, chunks_file
    )
    
    # Generate embeddings
    print("\nüî¢ Generating embeddings...")
    print("   Baseline model...")
    base_embeddings = generate_embeddings(base_model, texts)
    print("   Custom model...")
    custom_embeddings = generate_embeddings(custom_model, texts)
    
    # Compute metrics
    print(f"\nüìä Computing clustering metrics (k={n_clusters})...")
    base_metrics = compute_clustering_metrics(base_embeddings, n_clusters)
    custom_metrics = compute_clustering_metrics(custom_embeddings, n_clusters)
    
    print("\n   Baseline Metrics:")
    for k, v in base_metrics.items():
        if k != 'labels':
            print(f"      {k}: {v:.4f}")
    
    print("\n   Custom Metrics:")
    for k, v in custom_metrics.items():
        if k != 'labels':
            print(f"      {k}: {v:.4f}")
    
    # Cluster separation
    print("\nüìè Computing cluster separation...")
    base_separation = compute_cluster_separation(base_embeddings, base_metrics['labels'])
    custom_separation = compute_cluster_separation(custom_embeddings, custom_metrics['labels'])
    
    print(f"\n   Baseline - Separation Ratio: {base_separation['separation_ratio']:.4f}")
    print(f"   Custom - Separation Ratio: {custom_separation['separation_ratio']:.4f}")
    
    # Neighbor coherence
    print("\nüéØ Checking neighbor coherence...")
    base_coherence = neighbor_coherence_check(base_embeddings, sources, k=10)
    custom_coherence = neighbor_coherence_check(custom_embeddings, sources, k=10)
    
    print(f"\n   Baseline - Avg Coherence: {base_coherence['avg_coherence']:.4f}")
    print(f"   Custom - Avg Coherence: {custom_coherence['avg_coherence']:.4f}")
    
    # Similarity distribution
    print("\nüìà Computing similarity distributions...")
    base_sim_dist = pairwise_similarity_distribution(base_embeddings)
    custom_sim_dist = pairwise_similarity_distribution(custom_embeddings)
    
    # Generate visualizations
    print("\nüé® Generating visualizations...")
    
    plot_embeddings_2d(base_embeddings, custom_embeddings, 
                      base_metrics['labels'], custom_metrics['labels'],
                      method='umap', save_path='./results/embeddings_umap.png')
    
    plot_metrics_comparison(base_metrics, custom_metrics, 
                           save_path='./results/metrics_comparison.png')
    
    plot_similarity_distribution(base_sim_dist['similarities'], 
                                custom_sim_dist['similarities'],
                                save_path='./results/similarity_distribution.png')
    
    plot_elbow_curve(base_embeddings, custom_embeddings,
                    save_path='./results/elbow_plot.png')
    
    plot_neighbor_coherence(base_embeddings, custom_embeddings, sources,
                           save_path='./results/neighbor_coherence.png')
    
    # Summary report
    print("\n" + "="*70)
    print("SUMMARY REPORT")
    print("="*70)
    
    improvements = {
        'Silhouette Score': ((custom_metrics['silhouette'] - base_metrics['silhouette']) / base_metrics['silhouette'] * 100),
        'Calinski-Harabasz': ((custom_metrics['calinski_harabasz'] - base_metrics['calinski_harabasz']) / base_metrics['calinski_harabasz'] * 100),
        'Davies-Bouldin (lower is better)': ((base_metrics['davies_bouldin'] - custom_metrics['davies_bouldin']) / base_metrics['davies_bouldin'] * 100),
        'Separation Ratio': ((custom_separation['separation_ratio'] - base_separation['separation_ratio']) / base_separation['separation_ratio'] * 100),
        'Neighbor Coherence': ((custom_coherence['avg_coherence'] - base_coherence['avg_coherence']) / base_coherence['avg_coherence'] * 100)
    }
    
    print("\nüìà Custom Model Improvements:")
    for metric, improvement in improvements.items():
        symbol = "‚úì" if improvement > 0 else "‚úó"
        print(f"   {symbol} {metric}: {improvement:+.2f}%")
    
    print("\n‚úÖ Analysis complete! All plots saved.")
    print("="*70)
    
    # Return metrics for report generation
    return {
        'base_metrics': base_metrics,
        'custom_metrics': custom_metrics,
        'base_separation': base_separation,
        'custom_separation': custom_separation,
        'base_coherence': base_coherence,
        'custom_coherence': custom_coherence
    }


# ============================================================================
# USAGE
# ============================================================================

if __name__ == "__main__":
    results = run_complete_analysis(
        base_model_path="sentence-transformers/all-MiniLM-L6-v2",
        custom_model_path="./trained_model/custom_model/",
        chunks_file="./train_data/chunks.json",  # YOUR FILE
        n_clusters=5
    )

CLUSTERING & VISUALIZATION ANALYSIS
üì• Loading data from ./train_data/chunks.json...
   ‚úì Loaded 407 chunks from 5 sources

üî¢ Generating embeddings...
   Baseline model...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:02<00:00,  6.25it/s]


   Custom model...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13/13 [00:01<00:00,  7.01it/s]



üìä Computing clustering metrics (k=5)...

   Baseline Metrics:
      silhouette: 0.0658
      calinski_harabasz: 17.5987
      davies_bouldin: 3.3111
      inertia: 247.6582

   Custom Metrics:
      silhouette: 0.0608
      calinski_harabasz: 17.3790
      davies_bouldin: 3.5625
      inertia: 326.2460

üìè Computing cluster separation...

   Baseline - Separation Ratio: 0.7627
   Custom - Separation Ratio: 0.6934

üéØ Checking neighbor coherence...

   Baseline - Avg Coherence: 0.6157
   Custom - Avg Coherence: 0.6531

üìà Computing similarity distributions...

üé® Generating visualizations...
‚úì Saved: ./results/embeddings_umap.png
‚úì Saved: ./results/metrics_comparison.png
‚úì Saved: ./results/similarity_distribution.png
‚úì Saved: ./results/elbow_plot.png
‚úì Saved: ./results/neighbor_coherence.png

SUMMARY REPORT

üìà Custom Model Improvements:
   ‚úó Silhouette Score: -7.63%
   ‚úó Calinski-Harabasz: -1.25%
   ‚úó Davies-Bouldin (lower is better): -7.59%
   ‚úó Separat