In [1]:
# 1 Import necessary libraries
import pandas as pd
import numpy as np
import re
import string
import time
import os
import warnings
from collections import defaultdict
import torch
import torch.nn.functional as F
from functools import lru_cache
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import distance metrics
from Levenshtein import distance as levenshtein_distance
from Levenshtein import jaro_winkler, ratio as levenshtein_ratio
import textdistance
from fuzzywuzzy import fuzz
import jellyfish
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Try importing transformers with graceful fallback
try:
    from transformers import AutoTokenizer, AutoModel
    from sentence_transformers import SentenceTransformer
    transformers_available = True
    logger.info("Transformers library available for BERT embeddings")
except ImportError:
    transformers_available = False
    logger.warning("Transformers library not available. Will use TF-IDF fallback.")

# Try importing pyahocorasick with fallback
try:
    import pyahocorasick
    aho_corasick_available = True
    logger.info("pyahocorasick is available")
except ImportError:
    logger.warning("pyahocorasick not available. Using fallback implementation.")
    aho_corasick_available = False

# Import visualization libraries if available
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    visualization_available = True
except ImportError:
    visualization_available = False
    logger.warning("Visualization libraries not available. Plots will be skipped.")

# Set up device for pytorch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Suppress warnings
warnings.filterwarnings('ignore')

# Constants for the algorithm
SEMANTIC_SIMILARITY_THRESHOLD = 0.85  # Threshold for considering semantic similarity high
STRING_SIMILARITY_THRESHOLD = 0.80    # Threshold for string similarity
ACRONYM_SIMILARITY_THRESHOLD = 0.75   # Threshold for acronym formation similarity
DEFAULT_BATCH_SIZE = 64               # Default batch size for processing

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

logger.info("All libraries imported successfully!")

2025-04-07 23:54:40,241 - INFO - Transformers library available for BERT embeddings
2025-04-07 23:54:40,836 - INFO - Using device: cpu
2025-04-07 23:54:40,842 - INFO - All libraries imported successfully!


In [3]:
# Cell 2: Advanced BERT Embedder

class AdvancedBERTEmbedder:
    """
    Advanced BERT embedder using state-of-the-art models with improved pooling strategies,
    caching, and domain adaptation specifically optimized for merchant name matching.
    """
    
    def __init__(self, model_name=None, pooling_strategy='mean', cache_size=5000, device=None):
        """
        Initialize advanced BERT embedder with specified model and pooling strategy.
        
        Args:
            model_name (str): Name of the pre-trained model to use
            pooling_strategy (str): Pooling strategy ('mean', 'cls', 'max', or 'attention')
            cache_size (int): Size of the LRU cache for embeddings
            device: Device to run the model on (cuda or cpu)
        """
        # Select the best model based on availability and performance
        self.model_candidates = [
            'sentence-transformers/all-mpnet-base-v2',         # Best performance but slower
            'sentence-transformers/all-distilroberta-v1',      # Good balance of performance/speed
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',  # Good for international merchants
            'sentence-transformers/all-MiniLM-L12-v2'          # Fast but less accurate
        ]
        
        # If no model specified, use the first available one
        if model_name is None:
            self.model_name = self.model_candidates[0]
        else:
            self.model_name = model_name
            
        self.pooling_strategy = pooling_strategy
        self.max_sequence_length = 512  # BERT's limit
        self.cache_size = cache_size
        
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
            
        self.initialized = False
        self.domain_adapted = False
        self.domain_data = None
        
        # Initialize pre-trained model if transformers available
        if transformers_available:
            try:
                logger.info(f"Loading advanced BERT model '{self.model_name}'...")
                
                # Use SentenceTransformer for better performance
                self.model = SentenceTransformer(self.model_name).to(self.device)
                
                # Also initialize a tokenizer for fine-tuning
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                
                self.initialized = True
                logger.info(f"Advanced BERT model loaded successfully on {self.device}")
                
                # Set up LRU cache for embeddings
                self.encode = lru_cache(maxsize=self.cache_size)(self._encode_uncached)
                
            except Exception as e:
                logger.error(f"Error initializing BERT model: {e}")
                self.initialized = False
        
        # Initialize TF-IDF fallback with improved settings
        if not self.initialized:
            # Using character n-grams for better handling of typos and abbreviations
            self.tfidf_vectorizer = TfidfVectorizer(
                analyzer='char_wb', 
                ngram_range=(2, 5),  # Increased n-gram range
                min_df=2,            # Ignore very rare n-grams
                max_df=0.95          # Ignore very common n-grams
            )
            self.tfidf_fitted = False
            logger.info("Using enhanced TF-IDF fallback for embeddings")
    
    def adapt_to_domain(self, examples_df, epochs=5, learning_rate=2e-5):
        """
        Perform extensive domain adaptation to improve merchant name understanding.
        
        Args:
            examples_df (DataFrame): DataFrame with matched merchant names
            epochs (int): Number of adaptation epochs
            learning_rate (float): Learning rate for adaptation
        """
        if not self.initialized or not transformers_available:
            logger.warning("Domain adaptation skipped - model not initialized with transformers")
            return
        
        if self.domain_adapted and self.domain_data is not None:
            logger.info("Model already adapted to domain. Skipping.")
            return
        
        # Store domain data for potential reuse
        self.domain_data = examples_df.copy()
        
        # Extract positive pairs (matching merchant names)
        positive_pairs = []
        
        # Check different column combinations to find matched pairs
        if 'Enhanced_Score' in examples_df.columns:
            for _, row in examples_df.iterrows():
                if row['Enhanced_Score'] >= 0.7:  # Reduced threshold to get more examples
                    positive_pairs.append((row['Acronym'], row['Full_Name']))
        elif 'Expected_Match' in examples_df.columns:
            for _, row in examples_df.iterrows():
                if row['Expected_Match']:
                    positive_pairs.append((row['Acronym'], row['Full_Name']))
        else:
            # Just use all pairs as examples if no scoring columns exist
            for _, row in examples_df.iterrows():
                positive_pairs.append((row['Acronym'], row['Full_Name']))
        
        # Skip if not enough examples
        if len(positive_pairs) < 5:
            logger.warning("Not enough high-quality examples for adaptation")
            return
        
        # Create negative pairs for contrastive learning
        # (randomly pair non-matching merchant names)
        all_names = examples_df['Acronym'].tolist() + examples_df['Full_Name'].tolist()
        negative_pairs = []
        
        for i in range(min(len(positive_pairs), 100)):  # Limit to 100 negative pairs
            while True:
                name1 = np.random.choice(all_names)
                name2 = np.random.choice(all_names)
                pair = (name1, name2)
                # Check that this isn't actually a positive pair
                if pair not in positive_pairs and (name2, name1) not in positive_pairs and name1 != name2:
                    negative_pairs.append(pair)
                    break
        
        # Implement enhanced adaptation with contrastive learning
        logger.info(f"Starting domain adaptation with {len(positive_pairs)} positive and {len(negative_pairs)} negative pairs...")
        
        # Create a fine-tunable version of the model
        if hasattr(self.model, 'auto_model'):
            # For SentenceTransformer models
            model_for_tuning = self.model.auto_model
        else:
            # Fallback to directly using the model
            model_for_tuning = self.model
        
        model_for_tuning.train()
        optimizer = torch.optim.AdamW(model_for_tuning.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        
        batch_size = 8  # Small batch size for fine-tuning
        
        # Training loop
        for epoch in range(epochs):
            total_loss = 0
            
            # Process positive pairs (push closer together)
            for i in range(0, len(positive_pairs), batch_size):
                batch_pairs = positive_pairs[i:i+batch_size]
                
                # Prepare anchor and positive texts
                anchors = [pair[0] for pair in batch_pairs]
                positives = [pair[1] for pair in batch_pairs]
                
                # Get embeddings using the model
                if hasattr(self.model, 'encode'):
                    # For SentenceTransformer models
                    anchor_embeddings = self.model.encode(anchors, convert_to_tensor=True)
                    positive_embeddings = self.model.encode(positives, convert_to_tensor=True)
                else:
                    # If using raw transformer model
                    # Tokenize
                    anchor_inputs = self.tokenizer(anchors, return_tensors='pt', padding=True, 
                                            truncation=True, max_length=self.max_sequence_length).to(self.device)
                    positive_inputs = self.tokenizer(positives, return_tensors='pt', padding=True, 
                                            truncation=True, max_length=self.max_sequence_length).to(self.device)
                    
                    # Forward pass
                    with torch.no_grad():
                        anchor_outputs = model_for_tuning(**anchor_inputs)
                        positive_outputs = model_for_tuning(**positive_inputs)
                    
                    # Get embeddings (CLS token)
                    anchor_embeddings = anchor_outputs.last_hidden_state[:, 0, :]
                    positive_embeddings = positive_outputs.last_hidden_state[:, 0, :]
                
                # Normalize embeddings
                anchor_embeddings = F.normalize(anchor_embeddings, p=2, dim=1)
                positive_embeddings = F.normalize(positive_embeddings, p=2, dim=1)
                
                # Contrastive loss (push matching names closer)
                similarity = F.cosine_similarity(anchor_embeddings, positive_embeddings)
                positive_loss = (1.0 - similarity).mean()
                
                # Process negative pairs if available (push further apart)
                negative_loss = 0
                if negative_pairs:
                    j = i % max(1, len(negative_pairs) - batch_size)
                    neg_batch_pairs = negative_pairs[j:j+batch_size]
                    
                    neg_anchors = [pair[0] for pair in neg_batch_pairs]
                    neg_samples = [pair[1] for pair in neg_batch_pairs]
                    
                    # Get embeddings
                    if hasattr(self.model, 'encode'):
                        neg_anchor_embeddings = self.model.encode(neg_anchors, convert_to_tensor=True)
                        neg_sample_embeddings = self.model.encode(neg_samples, convert_to_tensor=True)
                    else:
                        # Tokenize
                        neg_anchor_inputs = self.tokenizer(neg_anchors, return_tensors='pt', padding=True, 
                                                truncation=True, max_length=self.max_sequence_length).to(self.device)
                        neg_sample_inputs = self.tokenizer(neg_samples, return_tensors='pt', padding=True, 
                                                truncation=True, max_length=self.max_sequence_length).to(self.device)
                        
                        # Forward pass
                        with torch.no_grad():
                            neg_anchor_outputs = model_for_tuning(**neg_anchor_inputs)
                            neg_sample_outputs = model_for_tuning(**neg_sample_inputs)
                        
                        # Get embeddings
                        neg_anchor_embeddings = neg_anchor_outputs.last_hidden_state[:, 0, :]
                        neg_sample_embeddings = neg_sample_outputs.last_hidden_state[:, 0, :]
                    
                    # Normalize embeddings
                    neg_anchor_embeddings = F.normalize(neg_anchor_embeddings, p=2, dim=1)
                    neg_sample_embeddings = F.normalize(neg_sample_embeddings, p=2, dim=1)
                    
                    # Push negative pairs apart (with margin)
                    neg_similarity = F.cosine_similarity(neg_anchor_embeddings, neg_sample_embeddings)
                    negative_loss = F.relu(neg_similarity - 0.3).mean()  # Push apart with margin of 0.3
                
                # Combined loss
                loss = positive_loss + negative_loss
                
                # Backward pass and optimize
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model_for_tuning.parameters(), 1.0)  # Gradient clipping
                optimizer.step()
                
                total_loss += loss.item()
            
            # Step the scheduler
            scheduler.step()
            
            # Calculate average loss
            avg_loss = total_loss / (len(positive_pairs) // batch_size + 1)
            logger.info(f"  Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")
        
        # Reset model to evaluation mode
        model_for_tuning.eval()
        
        # Update internal state
        self.domain_adapted = True
        
        # Clear the embedding cache after adaptation
        if hasattr(self, 'encode') and hasattr(self.encode, 'cache_clear'):
            self.encode.cache_clear()
            
        logger.info(f"Domain adaptation completed successfully")
    
    def _mean_pooling(self, model_output, attention_mask):
        """Mean pooling - take average of all token embeddings"""
        token_embeddings = model_output[0]  # First element contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
    def _cls_pooling(self, model_output, attention_mask):
        """CLS pooling - use the [CLS] token embedding"""
        return model_output[0][:, 0]
    
    def _max_pooling(self, model_output, attention_mask):
        """Max pooling - take max of all token embeddings"""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
        return torch.max(token_embeddings, 1)[0]
    
    def _attention_pooling(self, model_output, attention_mask):
        """Attention pooling - use attention weights to create weighted average"""
        token_embeddings = model_output[0]  # [batch_size, seq_len, hidden_size]
        
        # Create a simple attention mechanism
        attention_scores = torch.matmul(
            token_embeddings, 
            token_embeddings.mean(dim=1).unsqueeze(-1)
        ).squeeze(-1)  # [batch_size, seq_len]
        
        # Apply mask
        attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e9)
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Apply attention weights to token embeddings
        weighted_sum = torch.bmm(
            attention_weights.unsqueeze(1), 
            token_embeddings
        ).squeeze(1)  # [batch_size, hidden_size]
        
        return weighted_sum
    
    def _get_pooled_embeddings(self, model_output, attention_mask):
        """Apply the selected pooling strategy"""
        if self.pooling_strategy == 'mean':
            return self._mean_pooling(model_output, attention_mask)
        elif self.pooling_strategy == 'cls':
            return self._cls_pooling(model_output, attention_mask)
        elif self.pooling_strategy == 'max':
            return self._max_pooling(model_output, attention_mask)
        elif self.pooling_strategy == 'attention':
            return self._attention_pooling(model_output, attention_mask)
        else:
            # Default to mean pooling
            return self._mean_pooling(model_output, attention_mask)
    
    def fit(self, texts):
        """Fit the TF-IDF vectorizer on a corpus of texts (only needed for TF-IDF fallback)"""
        if not self.initialized:
            # Fit TF-IDF vectorizer
            self.tfidf_vectorizer.fit(texts)
            self.tfidf_fitted = True
            logger.info("TF-IDF vectorizer fitted on corpus")
    
    def _encode_uncached(self, texts, batch_size=32, show_progress=False):
        """
        Internal method to encode texts into embeddings without caching.
        This is wrapped by the LRU cache decorator in the initializer.
        """
        # Handle single text input
        if isinstance(texts, str):
            texts = [texts]
        
        # Return empty array for empty input
        if len(texts) == 0:
            return np.array([])
        
        # Use pre-trained BERT if available
        if self.initialized:
            try:
                # Use SentenceTransformer's encode method directly if available
                if hasattr(self.model, 'encode'):
                    # Faster encoding with batching
                    embeddings = self.model.encode(
                        texts, 
                        batch_size=batch_size, 
                        show_progress_bar=show_progress,
                        convert_to_numpy=True
                    )
                    return embeddings
                else:
                    # Manual batching with raw transformer model
                    all_embeddings = []
                    
                    for i in range(0, len(texts), batch_size):
                        if show_progress and i % (batch_size * 10) == 0:
                            logger.info(f"Processing batch {i//batch_size + 1}/{(len(texts)//batch_size) + 1}")
                        
                        batch_texts = texts[i:i+batch_size]
                        
                        # Tokenize
                        encoded_input = self.tokenizer(
                            batch_texts, 
                            padding=True, 
                            truncation=True, 
                            max_length=self.max_sequence_length,
                            return_tensors='pt'
                        ).to(self.device)
                        
                        # Compute token embeddings
                        with torch.no_grad():
                            model_output = self.model(**encoded_input)
                            batch_embeddings = self._get_pooled_embeddings(model_output, encoded_input['attention_mask'])
                            all_embeddings.append(batch_embeddings.cpu().numpy())
                    
                    return np.vstack(all_embeddings)
                
            except Exception as e:
                logger.error(f"Error in BERT encoding: {e}")
                logger.info("Falling back to TF-IDF encoding")
                # Fall back to TF-IDF
                self.initialized = False
        
        # Use TF-IDF fallback
        if not self.tfidf_fitted:
            self.fit(texts)
        
        return self.tfidf_vectorizer.transform(texts).toarray()
    
    def compute_similarity(self, text1, text2):
        """
        Compute cosine similarity between two texts using the pre-trained model
        
        Args:
            text1: First text
            text2: Second text
            
        Returns:
            float: Cosine similarity score
        """
        # Empty text check
        if not text1 or not text2:
            return 0.0
            
        # Get embeddings for both texts
        emb1 = self.encode([text1])[0]
        emb2 = self.encode([text2])[0]
        
        # Compute cosine similarity
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        
        if norm1 == 0 or norm2 == 0:
            return 0.0
            
        return np.dot(emb1, emb2) / (norm1 * norm2)

# Initialize advanced BERT embedder with state-of-the-art model
bert_embedder = AdvancedBERTEmbedder(
    model_name='sentence-transformers/all-mpnet-base-v2',
    pooling_strategy='mean',
    cache_size=10000,  # Increased cache size
    device=device
)
logger.info("Advanced BERT embedder initialized successfully!")

2025-04-07 23:54:45,331 - INFO - Loading advanced BERT model 'sentence-transformers/all-mpnet-base-v2'...
2025-04-07 23:54:45,332 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
2025-04-07 23:54:46,443 - INFO - Use pytorch device: cpu
2025-04-07 23:54:46,783 - INFO - Advanced BERT model loaded successfully on cpu
2025-04-07 23:54:46,785 - INFO - Advanced BERT embedder initialized successfully!


In [6]:
# Cell 3: Merchant Name Preprocessing
class MerchantPreprocessor:
    """
    Enhanced merchant name preprocessor with industry-specific pattern handling 
    and extensive abbreviation dictionaries.
    """
    
    def __init__(self):
        """Initialize merchant preprocessor with dictionaries and patterns"""
        # Load comprehensive abbreviation dictionaries
        self.abbreviations = self._load_abbreviation_dictionary()
        self.domain_abbreviations = self._load_domain_abbreviations()
        self.stopwords = self._load_stopwords()
        self.domain_stopwords = self._load_domain_stopwords()
        
        # Compile common business name patterns
        self.business_suffixes = self._compile_business_suffixes()
        self.business_patterns = self._compile_business_patterns()
    
    def _load_abbreviation_dictionary(self):
        """Load comprehensive abbreviation dictionary with merchant-specific terms"""
        return {
            # Banking & Financial Institutions (expanded)
            'bofa': 'bank of america', 'b of a': 'bank of america',
            'boa': 'bank of america', 'bac': 'bank of america',
            'jpm': 'jpmorgan chase', 'jpm chase': 'jpmorgan chase',
            'jpmc': 'jpmorgan chase', 'chase': 'jpmorgan chase',
            'wf': 'wells fargo', 'wfb': 'wells fargo bank',
            'citi': 'citibank', 'citi bank': 'citibank',
            'gs': 'goldman sachs', 'ms': 'morgan stanley',
            'db': 'deutsche bank', 'hsbc': 'hongkong and shanghai banking corporation',
            'amex': 'american express', 'amx': 'american express',
            'usb': 'us bank', 'rbc': 'royal bank of canada',
            'pnc': 'pnc financial services', 'td': 'toronto dominion bank',
            'bny': 'bank of new york', 'bnyc': 'bank of new york mellon',
            'cba': 'commonwealth bank of australia', 'nab': 'national australia bank',
            'rba': 'reserve bank of australia', 'westpac': 'western pacific bank',
            'fargo': 'wells fargo', 'chase bank': 'jpmorgan chase',
            'usaa': 'united services automobile association',
            'disc': 'discover', 'discover card': 'discover',
            'cap1': 'capital one', 'cap one': 'capital one',
            'citi': 'citibank', 'fidelity': 'fidelity investments',
            'schwab': 'charles schwab', 'etrade': 'e-trade financial',
            
            # Fast Food & Restaurant Chains (expanded)
            'mcd': 'mcdonalds', 'mcds': 'mcdonalds', 'md': 'mcdonalds',
            'mickey ds': 'mcdonalds', 'maccas': 'mcdonalds',
            'bk': 'burger king', 'kfc': 'kentucky fried chicken',
            'sbux': 'starbucks', 'sb': 'starbucks', 'starbks': 'starbucks',
            'tb': 'taco bell', 'wen': 'wendys', 'wendys': 'wendys restaurant',
            'dq': 'dairy queen', 'ph': 'pizza hut', 'pzh': 'pizza hut',
            'dnkn': 'dunkin donuts', 'cfa': 'chick fil a',
            'cmg': 'chipotle mexican grill', 'chipotle': 'chipotle mexican grill',
            'ihop': 'international house of pancakes',
            'tgi': 'tgi fridays', 'tgif': 'tgi fridays',
            'bww': 'buffalo wild wings', 'pjs': 'papa johns',
            'ljs': 'long john silvers', 'popeyes': 'popeyes louisiana kitchen',
            'subway': 'subway restaurant', 'dominos': 'dominos pizza',
            'panera': 'panera bread', 'five guys': 'five guys burgers and fries',
            'applebees': 'applebees neighborhood grill',
            'arbys': 'arbys restaurant', 'hooters': 'hooters restaurant',
            'cracker barrel': 'cracker barrel old country store',
            
            # Retail & Supermarkets (expanded)
            'wmt': 'walmart', 'wal mart': 'walmart', 'walm': 'walmart',
            'tgt': 'target', 'targ': 'target', 'tarjay': 'target',
            'costco': 'costco wholesale', 'cost': 'costco',
            'kroger': 'the kroger co', 'krg': 'kroger',
            'cvs': 'cvs pharmacy', 'cvs rx': 'cvs pharmacy',
            'wag': 'walgreens', 'walg': 'walgreens',
            'hd': 'home depot', 'thd': 'the home depot',
            'low': 'lowes', 'lowes': 'lowes home improvement',
            'bby': 'best buy', 'bb': 'best buy',
            'amzn': 'amazon', 'amazon': 'amazon.com',
            '711': '7-eleven', '7-11': '7-eleven',
            'tjx': 'tj maxx', 'tjmaxx': 'tj maxx',
            'wfm': 'whole foods market', 'whole foods': 'whole foods market',
            'sams': 'sams club', 'sams club': 'sams club',
            'macys': 'macys department store', 'nordstrom': 'nordstrom',
            'kohls': 'kohls department store', 'publix': 'publix super markets',
            'safeway': 'safeway supermarket', 'albertsons': 'albertsons supermarket',
            'aldi': 'aldi supermarket', 'lidl': 'lidl supermarket',
            'trader joes': 'trader joes', 'tj': 'trader joes',
            
            # Tech Companies (expanded)
            'msft': 'microsoft', 'ms': 'microsoft', 'aapl': 'apple',
            'goog': 'google', 'googl': 'google', 'alphabet': 'google',
            'amzn': 'amazon', 'fb': 'facebook', 'meta': 'meta platforms',
            'nflx': 'netflix', 'tsla': 'tesla motors', 'tsla': 'tesla',
            'ibm': 'international business machines', 'csco': 'cisco systems',
            'orcl': 'oracle', 'intc': 'intel', 'amd': 'advanced micro devices',
            'nvda': 'nvidia', 'adbe': 'adobe', 'crm': 'salesforce',
            'dell': 'dell technologies', 'hpe': 'hewlett packard enterprise',
            'hp': 'hewlett packard', 'vmw': 'vmware', 'twtr': 'twitter',
            'ebay': 'ebay', 'pypl': 'paypal', 'sqsp': 'squarespace',
            'lyft': 'lyft', 'uber': 'uber technologies',
            
            # Common address components (expanded)
            'rd': 'road', 'st': 'street', 'ave': 'avenue', 
            'blvd': 'boulevard', 'ctr': 'center', 'ln': 'lane', 
            'dr': 'drive', 'pl': 'place', 'ct': 'court',
            'hwy': 'highway', 'pkwy': 'parkway', 'sq': 'square',
            'cir': 'circle', 'ter': 'terrace', 'expy': 'expressway',
            'fwy': 'freeway', 'tpke': 'turnpike', 'crk': 'creek',
            'hvn': 'haven', 'xing': 'crossing', 'vlg': 'village',
            'spgs': 'springs', 'mtn': 'mountain', 'lk': 'lake',
            'n': 'north', 's': 'south', 'e': 'east', 'w': 'west',
            'ne': 'northeast', 'nw': 'northwest', 'se': 'southeast',
            'sw': 'southwest', 'apt': 'apartment', 'ste': 'suite',
            'bldg': 'building', 'fl': 'floor', 'rm': 'room',
            
            # Expanded location-based prefixes
            'norcal': 'northern california', 'socal': 'southern california',
            'nyc': 'new york city', 'la': 'los angeles', 'sf': 'san francisco',
            'chi': 'chicago', 'atl': 'atlanta', 'hou': 'houston',
            'dfw': 'dallas fort worth', 'mia': 'miami', 'bos': 'boston',
            'dc': 'washington dc', 'sea': 'seattle', 'den': 'denver',
            'phx': 'phoenix', 'pdx': 'portland', 'msp': 'minneapolis saint paul',
            'stl': 'saint louis', 'pit': 'pittsburgh', 'cin': 'cincinnati',
            'cle': 'cleveland', 'slc': 'salt lake city', 'nola': 'new orleans',
            'kcmo': 'kansas city', 'nash': 'nashville', 'ind': 'indianapolis',
            
            # Gas stations and convenience stores
            'bp': 'british petroleum', 'chevron': 'chevron gas station',
            'shell': 'shell gas station', 'exxon': 'exxon gas station',
            'mobil': 'mobil gas station', 'texaco': 'texaco gas station',
            'conoco': 'conoco gas station', 'marathon': 'marathon gas station',
            'sunoco': 'sunoco gas station', 'valero': 'valero gas station',
            'speedway': 'speedway gas station', 'circle k': 'circle k convenience store',
            'quiktrip': 'quiktrip convenience store', 'qt': 'quiktrip',
            'cumbys': 'cumberland farms', 'wawa': 'wawa convenience store',
            'sheetz': 'sheetz convenience store', 'racetrac': 'racetrac gas station',
            'caseys': 'caseys general store', 'murphy': 'murphy usa gas station',
            
            # Airlines
            'aa': 'american airlines', 'aal': 'american airlines', 
            'dal': 'delta airlines', 'dl': 'delta air lines',
            'ual': 'united airlines', 'ua': 'united airlines',
            'luv': 'southwest airlines', 'wn': 'southwest airlines',
            'jblu': 'jetblue airways', 'b6': 'jetblue airways',
            'alk': 'alaska airlines', 'as': 'alaska airlines',
            'save': 'spirit airlines', 'nk': 'spirit airlines',
            'ulcc': 'frontier airlines', 'f9': 'frontier airlines',
            
            # Hospitality
            'mar': 'marriott', 'hilton': 'hilton hotels',
            'hyatt': 'hyatt hotels', 'wyndham': 'wyndham hotels',
            'ihg': 'intercontinental hotels group', 'choice': 'choice hotels',
            'bwi': 'best western', 'accor': 'accor hotels',
            'mhgc': 'mgm resorts', 'mgm': 'mgm resorts',
            'lvs': 'las vegas sands', 'wynn': 'wynn resorts',
            'h': 'hyatt', 'hot': 'starwood hotels'
        }
    
    def _load_domain_abbreviations(self):
        """Load domain-specific abbreviation dictionaries"""
        return {
            'Medical': {
                'dr': 'doctor', 'hosp': 'hospital', 'med': 'medical',
                'clin': 'clinic', 'pharm': 'pharmacy', 'lab': 'laboratory',
                'dept': 'department', 'ctr': 'center', 'inst': 'institute',
                'er': 'emergency room', 'icu': 'intensive care unit',
                'ob': 'obstetrics', 'gyn': 'gynecology', 'peds': 'pediatrics',
                'ortho': 'orthopedics', 'onc': 'oncology', 'neuro': 'neurology',
                'radiol': 'radiology', 'card': 'cardiology', 'derm': 'dermatology',
                'ent': 'ear nose throat', 'opht': 'ophthalmology', 'uro': 'urology',
                'gastro': 'gastroenterology', 'hem': 'hematology', 'oncol': 'oncology',
                'rheum': 'rheumatology', 'endo': 'endocrinology', 'pulm': 'pulmonology',
                'rx': 'prescription', 'surg': 'surgical', 'rehab': 'rehabilitation',
                'path': 'pathology', 'anesth': 'anesthesiology', 'psych': 'psychiatry'
            },
            'Government': {
                'govt': 'government', 'dept': 'department', 'admin': 'administration',
                'auth': 'authority', 'fed': 'federal', 'natl': 'national',
                'comm': 'commission', 'sec': 'secretary', 'org': 'organization',
                'div': 'division', 'bur': 'bureau', 'off': 'office',
                'min': 'ministry', 'reg': 'regional', 'dist': 'district',
                'cncl': 'council', 'cmte': 'committee', 'subcmte': 'subcommittee',
                'doj': 'department of justice', 'dod': 'department of defense',
                'dos': 'department of state', 'dot': 'department of transportation',
                'doe': 'department of energy', 'dol': 'department of labor',
                'dhs': 'department of homeland security', 'epa': 'environmental protection agency',
                'fbi': 'federal bureau of investigation', 'cia': 'central intelligence agency',
                'irs': 'internal revenue service', 'usps': 'united states postal service',
                'sba': 'small business administration', 'va': 'veterans affairs',
                'omb': 'office of management and budget', 'gao': 'government accountability office'
            },
            'Education': {
                'univ': 'university', 'coll': 'college', 'acad': 'academy',
                'elem': 'elementary', 'sch': 'school', 'inst': 'institute',
                'dept': 'department', 'lib': 'library', 'lab': 'laboratory',
                'fac': 'faculty', 'prof': 'professor', 'assoc': 'associate',
                'asst': 'assistant', 'adm': 'administration', 'stdnt': 'student',
                'grad': 'graduate', 'undergrad': 'undergraduate', 'alum': 'alumni',
                'edu': 'education', 'hs': 'high school', 'ms': 'middle school',
                'jr': 'junior', 'sr': 'senior', 'tech': 'technical',
                'comm': 'community', 'uni': 'university', 'poly': 'polytechnic',
                'prsch': 'preschool', 'k12': 'kindergarten through 12th grade'
            },
            'Financial': {
                'fin': 'financial', 'svcs': 'services', 'mgmt': 'management',
                'assoc': 'associates', 'intl': 'international', 'grp': 'group',
                'corp': 'corporation', 'cap': 'capital', 'inv': 'investment',
                'asset': 'asset management', 'sec': 'securities', 'adv': 'advisors',
                'tr': 'trust', 'port': 'portfolio', 'acct': 'account',
                'bal': 'balance', 'stmt': 'statement', 'equ': 'equity',
                'div': 'dividend', 'ret': 'retirement', 'pens': 'pension',
                'ira': 'individual retirement account', '401k': '401k retirement plan',
                'hdg': 'holdings', 'fx': 'foreign exchange', 'mm': 'money market',
                'pe': 'private equity', 'vc': 'venture capital', 'am': 'asset management',
                'wm': 'wealth management', 'pm': 'portfolio management',
                'cc': 'credit card', 'mtg': 'mortgage', 'ins': 'insurance'
            },
            'Restaurant': {
                'rest': 'restaurant', 'cafe': 'cafeteria', 'grill': 'grillery',
                'brew': 'brewery', 'bar': 'bar and grill', 'bbq': 'barbecue',
                'deli': 'delicatessen', 'stk': 'steakhouse', 'bf': 'breakfast',
                'din': 'dinner', 'chs': 'cheese', 'ckn': 'chicken',
                'brgr': 'burger', 'piz': 'pizza', 'mex': 'mexican',
                'ital': 'italian', 'chin': 'chinese', 'jpn': 'japanese',
                'thai': 'thai', 'ind': 'indian', 'korea': 'korean',
                'viet': 'vietnamese', 'med': 'mediterranean', 'seafd': 'seafood',
                'sushi': 'sushi restaurant', 'taco': 'taco restaurant',
                'sand': 'sandwich', 'bagel': 'bagel shop', 'bkry': 'bakery',
                'coff': 'coffee', 'juice': 'juice bar', 'smth': 'smoothie'
            },
            'Retail': {
                'dept': 'department', 'str': 'store', 'suprmkt': 'supermarket',
                'groc': 'grocery', 'pharm': 'pharmacy', 'disc': 'discount',
                'whs': 'warehouse', 'outlet': 'outlet store', 'fash': 'fashion',
                'appl': 'appliance', 'elect': 'electronics', 'furn': 'furniture',
                'hdw': 'hardware', 'jwlr': 'jewelry', 'btq': 'boutique',
                'cloth': 'clothing', 'shoes': 'shoe store', 'sport': 'sporting goods',
                'toy': 'toy store', 'book': 'book store', 'music': 'music store',
                'auto': 'auto parts', 'pet': 'pet supplies', 'craft': 'craft store',
                'opt': 'optical', 'cosm': 'cosmetics', 'lux': 'luxury',
                'bargn': 'bargain', 'homegd': 'home goods', 'gard': 'garden',
                'off': 'office supplies', 'tech': 'technology'
            }
        }
    
    def _load_stopwords(self):
        """Load general stopwords for preprocessing"""
        return {
            'inc', 'llc', 'co', 'ltd', 'corp', 'plc', 'na', 'the', 
            'and', 'of', 'for', 'in', 'a', 'an', 'by', 'to', 'at',
            'corporation', 'incorporated', 'company', 'limited',
            'with', 'from', 'as', 'on', 'group', 'services', 'international',
            'enterprises', 'holdings', 'global', 'worldwide', 'partners',
            'associates', 'industries', 'solutions', 'systems', 'technologies'
        }
    
    def _load_domain_stopwords(self):
        """Load domain-specific stopwords"""
        return {
            'Medical': {
                'center', 'healthcare', 'medical', 'health', 'care', 'services', 
                'clinic', 'hospital', 'center', 'associates', 'physicians',
                'specialists', 'group', 'practice', 'medicine', 'diagnostic',
                'treatment', 'therapy', 'rehabilitation', 'wellness', 'urgent',
                'emergency', 'family', 'primary', 'specialty', 'surgical',
                'memorial', 'regional', 'community', 'university', 'medical center'
            },
            'Government': {
                'department', 'office', 'agency', 'bureau', 'division', 'authority', 
                'administration', 'commission', 'board', 'committee', 'council',
                'district', 'federal', 'state', 'county', 'city', 'municipal',
                'regional', 'national', 'international', 'public', 'affairs',
                'services', 'development', 'program', 'project', 'initiative',
                'regulatory', 'oversight', 'enforcement', 'compliance'
            },
            'Education': {
                'university', 'college', 'school', 'institute', 'academy', 'education', 
                'learning', 'campus', 'department', 'faculty', 'studies',
                'research', 'sciences', 'arts', 'technology', 'engineering',
                'mathematics', 'business', 'law', 'medicine', 'public',
                'community', 'state', 'technical', 'vocational', 'elementary',
                'middle', 'high', 'district', 'board', 'trustees'
            },
            'Financial': {
                'financial', 'services', 'management', 'capital', 'investment', 
                'banking', 'advisor', 'wealth', 'asset', 'fund', 'equity',
                'securities', 'brokerage', 'exchange', 'traders', 'trading',
                'markets', 'partners', 'associates', 'advisors', 'planners',
                'retirement', 'insurance', 'trust', 'savings', 'loan',
                'credit', 'mortgage', 'real estate', 'properties'
            },
            'Restaurant': {
                'restaurant', 'cafe', 'diner', 'eatery', 'grill', 'kitchen', 
                'bar', 'house', 'bistro', 'pizzeria', 'steakhouse', 'bakery',
                'coffeehouse', 'pub', 'tavern', 'chophouse', 'seafood',
                'cuisine', 'dining', 'food', 'catering', 'express', 'fast food',
                'buffet', 'garden', 'parlor', 'house', 'shack', 'joint'
            },
            'Retail': {
                'store', 'shop', 'market', 'mart', 'outlet', 'center', 'warehouse',
                'superstore', 'supermarket', 'department', 'boutique', 'emporium',
                'gallery', 'corner', 'supply', 'supplies', 'goods', 'products',
                'retail', 'discount', 'factory', 'co-op', 'cooperative', 'shoppe',
                'super', 'mega', 'mini', 'express', 'neighborhood', 'local'
            }
        }
    
    def _compile_business_suffixes(self):
        """
        Compile comprehensive regex patterns for common business name suffixes
        with more variations and international coverage
        """
        suffix_patterns = {
            r'\bco(?:mpany)?\b': '',             # Company/Co.
            r'\binc(?:orporated)?\b': '',        # Incorporated/Inc.
            r'\bltd\b': '',                      # Ltd
            r'\bllc\b': '',                      # LLC
            r'\bcorp(?:oration)?\b': '',         # Corporation/Corp
            r'\blimited\b': '',                  # Limited
            r'\bplc\b': '',                      # PLC
            r'\bltda\b': '',                     # LTDA (Latin America)
            r'\bgmbh\b': '',                     # GmbH (German)
            r'\bs\.?a\.?\b': '',                 # S.A. (international)
            r'\bs\.?p\.?a\.?\b': '',             # S.p.A. (Italian)
            r'\bs\.?a\.?r\.?l\.?\b': '',         # S.A.R.L. (French)
            r'\bs\.?l\.?\b': '',                 # S.L. (Spanish)
            r'\ba\.?g\.?\b': '',                 # A.G. (German)
            r'\ba\.?s\.?\b': '',                 # A.S. (Scandinavian)
            r'\bpty\b': '',                      # PTY (Australia)
            r'\bpvt\b': '',                      # PVT (private)
            r'\bp\.?t\.?\b': '',                 # P.T. (Indonesian)
            r'\bc\.?v\.?\b': '',                 # C.V. (Dutch)
            r'\bgroup\b': '',                    # Group
            r'\bholdings?\b': '',                # Holdings/Holding
            r'\binternational\b': '',            # International
            r'\bworldwide\b': '',                # Worldwide
            r'\bglobal\b': '',                   # Global
            r'\benterprises?\b': '',             # Enterprises/Enterprise
            r'\bassociates?\b': '',              # Associates/Associate
            r'\bpartners?\b': '',                # Partners/Partner
            r'\bindustries?\b': '',              # Industries/Industry
            r'\bsolutions?\b': '',               # Solutions/Solution
            r'\bservices?\b': '',                # Services/Service
            r'\btechnologies?\b': '',            # Technologies/Technology
            r'\bsystems?\b': ''                  # Systems/System
        }
        
        # Compile all regex patterns
        compiled_patterns = {re.compile(pattern, re.IGNORECASE): replacement 
                             for pattern, replacement in suffix_patterns.items()}
        
        return compiled_patterns
    
    def _compile_business_patterns(self):
        """Compile regex patterns for common business name structures"""
        business_patterns = {
            # Location prefix patterns (e.g., "North Shore Hospital")
            r'(north|south|east|west|central|downtown|uptown|midtown)\s+(.+)': 
                lambda m: m.group(2),  # Keep main business name
                
            # Location suffix patterns (e.g., "Walmart Downtown")
            r'(.+?)\s+(north|south|east|west|central|downtown|uptown|midtown)$': 
                lambda m: m.group(1),  # Keep main business name
                
            # Branch/location patterns (e.g., "Chase Bank - Brooklyn Branch")
            r'(.+?)\s*[\-\–\—\#]\s*(.+?)\s+(branch|location|store|outlet)$': 
                lambda m: m.group(1),  # Keep main entity
                
            # Store number patterns (e.g., "Target #1234")
            r'(.+?)\s*#\s*\d+$': 
                lambda m: m.group(1),  # Keep store name, remove number
                
            # Franchise patterns (e.g., "McDonald's Franchise")
            r'(.+?)\s+(franchise|franchisee)$': 
                lambda m: m.group(1),  # Keep main brand
                
            # Multiple locations pattern (e.g., "Starbucks - New York & Boston")
            r'(.+?)\s*[\-\–\—]\s*(.+?\s+(?:&|and)\s+.+?)$': 
                lambda m: m.group(1)   # Keep main entity
        }
        
        # Compile all regex patterns
        compiled_patterns = {re.compile(pattern, re.IGNORECASE): replacement 
                             for pattern, replacement in business_patterns.items()}
        
        return compiled_patterns
    
    def preprocess(self, text, domain=None, remove_suffixes=True, normalize_mcdonald=True):
        """
        Enhanced preprocessing with better merchant-specific handling
        
        Args:
            text (str): Text to preprocess
            domain (str, optional): Domain for specialized preprocessing
            remove_suffixes (bool): Whether to remove business suffixes
            normalize_mcdonald (bool): Whether to normalize McDonald's variants
            
        Returns:
            str: Preprocessed text
        """
        if not isinstance(text, str):
            return ""
        
        # Convert to lowercase
        text = text.lower().strip()
        
        # Save original for special case handling
        original_text = text
        
        # Handle special cases before general processing
        
        # McDonald's special handling
        if normalize_mcdonald and ('mcdonald' in text or 'mcd' in text or 'mcdonalds' in text):
            # Normalize all McDonald's variants
            for pattern in ['mcdonald\'s', 'mcdonalds', 'mcdonald', 'mcd ', 'mcd\'s']:
                if pattern in text:
                    text = 'mcdonalds'
                    break
        
        # Starbucks special handling
        if 'starbucks' in text or 'sbux' in text:
            for pattern in ['starbucks coffee', 'starbucks coffee company', 'starbucks corp']:
                if pattern in text:
                    text = 'starbucks'
                    break
        
        # Walmart special handling
        if 'walmart' in text or 'wal' in text or 'wmt' in text:
            for pattern in ['wal-mart', 'wal mart', 'walmart supercenter', 'walmart neighborhood market']:
                if pattern in text:
                    text = 'walmart'
                    break
        
        # If we've made a special case substitution, return it directly
        if text != original_text.lower().strip() and len(text) > 0:
            return text
        
        # Better handling of punctuation - preserve apostrophes in business names
        text = re.sub(r'([^a-z0-9\'\.\&\-])', ' ', text)
        
        # Special handling for business name apostrophes with multiple patterns
        text = re.sub(r'\'s\b', 's', text)  # Convert McDonald's to McDonalds
        text = re.sub(r'\'', '', text)      # Remove remaining apostrophes
        
        # Handle special characters in company names
        text = re.sub(r'&amp;', '&', text)  # Fix HTML entities
        text = re.sub(r'@', 'at', text)     # Replace @ with 'at'
        
        # Normalize spaces and remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Apply business suffix removal if requested
        if remove_suffixes:
            for pattern, replacement in self.business_suffixes.items():
                text = pattern.sub(replacement, text)
        
        # Apply business pattern normalization
        for pattern, replacement_func in self.business_patterns.items():
            match = pattern.search(text)
            if match:
                text = replacement_func(match).strip()
        
        # Split words for further processing
        words = text.split()
        
        # Apply general abbreviation expansion
        words = [self.abbreviations.get(word, word) for word in words]
        
        # Apply domain-specific abbreviation expansion if domain is provided
        if domain and domain in self.domain_abbreviations:
            words = [self.domain_abbreviations[domain].get(word, word) for word in words]
        
        # Create expanded stopwords for this domain
        expanded_stopwords = self.stopwords.copy()
        if domain and domain in self.domain_stopwords:
            expanded_stopwords.update(self.domain_stopwords[domain])
        
        # Remove stopwords
        words = [word for word in words if word not in expanded_stopwords]
        
        # Rejoin words and remove extra spaces
        text = ' '.join(words)
        text = re.sub(r'\s+', ' ', text).strip()
        
        return text
    
    def preprocess_pair(self, acronym, full_name, domain=None):
        """
        Preprocess acronym and full name pair with domain-specific handling
        
        Args:
            acronym (str): Acronym or short name
            full_name (str): Full name
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            tuple: (preprocessed_acronym, preprocessed_full_name)
        """
        acronym_clean = self.preprocess(acronym, domain)
        full_name_clean = self.preprocess(full_name, domain)
        return acronym_clean, full_name_clean

# Initialize merchant preprocessor
merchant_preprocessor = MerchantPreprocessor()
logger.info("Enhanced merchant preprocessor initialized!")

2025-04-08 00:24:27,121 - INFO - Enhanced merchant preprocessor initialized!


In [8]:
# Cell 4: Similarity Algorithms
class SimilarityAlgorithms:
    """
    Enhanced similarity algorithms specifically optimized for merchant name matching.
    Includes traditional string-based methods and specialized merchant-specific algorithms.
    """
    
    def __init__(self, preprocessor=None, bert_embedder=None):
        """
        Initialize similarity algorithms with preprocessor and BERT embedder
        
        Args:
            preprocessor: Merchant name preprocessor
            bert_embedder: BERT embedder for semantic similarity
        """
        self.preprocessor = preprocessor or MerchantPreprocessor()
        self.bert_embedder = bert_embedder
        
        # Initialize TF-IDF vectorizer for fallback
        self.tfidf = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 5))
        self.tfidf_fitted = False
        
        # Thresholds for different similarity types
        self.exact_match_threshold = 0.95
        self.high_match_threshold = 0.85
        self.medium_match_threshold = 0.70
        self.low_match_threshold = 0.50
    
    def jaro_winkler_similarity(self, s1, s2, domain=None):
        """
        Calculate Jaro-Winkler similarity with enhanced preprocessing
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Jaro-Winkler similarity score between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Apply Jaro-Winkler algorithm
        return jaro_winkler(s1_clean, s2_clean)
    
    def damerau_levenshtein_similarity(self, s1, s2, domain=None):
        """
        Calculate Damerau-Levenshtein similarity, better for handling transpositions
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Damerau-Levenshtein similarity score between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Calculate Damerau-Levenshtein distance
        max_len = max(len(s1_clean), len(s2_clean))
        if max_len == 0:
            return 0
        
        distance = textdistance.damerau_levenshtein.distance(s1_clean, s2_clean)
        similarity = 1 - (distance / max_len)
        return max(0, similarity)  # Ensure non-negative
    
    def tfidf_cosine_similarity(self, s1, s2, domain=None):
        """
        Calculate TF-IDF Cosine similarity for keyword matching
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: TF-IDF cosine similarity score between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Fit and transform with TF-IDF
        try:
            tfidf_matrix = self.tfidf.fit_transform([s1_clean, s2_clean])
            similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
            return float(max(0, similarity))  # Ensure non-negative
        except:
            return 0
    
    def jaccard_ngram_similarity(self, s1, s2, domain=None, n=2):
        """
        Calculate Jaccard n-gram similarity for character overlaps
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            n (int): Size of n-grams
            
        Returns:
            float: Jaccard n-gram similarity score between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Create n-grams
        def get_ngrams(text, n):
            return [text[i:i+n] for i in range(len(text)-(n-1))]
        
        s1_ngrams = set(get_ngrams(s1_clean, n))
        s2_ngrams = set(get_ngrams(s2_clean, n))
        
        # Calculate Jaccard similarity
        union_size = len(s1_ngrams.union(s2_ngrams))
        if union_size == 0:
            return 0
        
        intersection_size = len(s1_ngrams.intersection(s2_ngrams))
        return intersection_size / union_size
    
    def token_set_ratio(self, s1, s2, domain=None):
        """
        Calculate token set ratio using fuzzywuzzy
        Handles word order and partial matching well
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Token set ratio between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Calculate Token Set Ratio
        return fuzz.token_set_ratio(s1_clean, s2_clean) / 100
    
    def token_sort_ratio(self, s1, s2, domain=None):
        """
        Calculate token sort ratio using fuzzywuzzy
        Handles word order differences well
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Token sort ratio between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Calculate Token Sort Ratio
        return fuzz.token_sort_ratio(s1_clean, s2_clean) / 100
    
    def contains_ratio(self, s1, s2, domain=None):
        """
        Check if one string contains the other, useful for acronym-full name matching
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Containment ratio between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Special case for McDonald's
        if ('mcdonald' in s1_clean and 'mcdonald' in s2_clean) or ('mcd' in s1_clean and 'mcdonald' in s2_clean) or ('mcdonald' in s1_clean and 'mcd' in s2_clean):
            return 1.0
            
        # Special case for Starbucks
        if ('starbuck' in s1_clean and 'starbuck' in s2_clean) or ('sbux' in s1_clean and 'starbuck' in s2_clean) or ('starbuck' in s1_clean and 'sbux' in s2_clean):
            return 1.0
            
        # Special case for Walmart
        if ('walmart' in s1_clean and 'walmart' in s2_clean) or ('wmt' in s1_clean and 'walmart' in s2_clean) or ('walmart' in s1_clean and 'wmt' in s2_clean):
            return 1.0
        
        # Check if one string fully contains the other
        if s1_clean in s2_clean:
            return 1.0
        elif s2_clean in s1_clean:
            return 1.0
            
        # Check for partial containment at word level
        s1_words = set(s1_clean.split())
        s2_words = set(s2_clean.split())
        
        # If all words in the shorter string are in the longer string
        if s1_words.issubset(s2_words):
            return 0.9
        elif s2_words.issubset(s1_words):
            return 0.9
        
        # Check overlap at the word level
        intersection = s1_words.intersection(s2_words)
        shorter_len = min(len(s1_words), len(s2_words))
        
        if shorter_len == 0:
            # Fall back to character-level check for short strings
            return self._char_contains_ratio(s1_clean, s2_clean)
        
        return len(intersection) / shorter_len
    
    def _char_contains_ratio(self, s1, s2):
        """Helper method for character-level containment ratio"""
        s1_chars = list(s1)
        s2_chars = list(s2)
        
        matches = 0
        for char in s1_chars:
            if char in s2_chars:
                matches += 1
                s2_chars.remove(char)  # Remove matched char
        
        return matches / len(s1_chars) if len(s1_chars) > 0 else 0
    
    def acronym_similarity(self, acronym, full_name, domain=None):
        """
        Calculate how well the acronym is formed from the full name
        
        Args:
            acronym (str): The acronym
            full_name (str): The full name
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Acronym similarity score between 0 and 1
        """
        # Special case handling for common acronyms
        acronym_lower = acronym.lower() if isinstance(acronym, str) else ""
        full_name_lower = full_name.lower() if isinstance(full_name, str) else ""
        
        # Check known acronym-full name pairs
        known_pairs = {
            'bofa': 'bank of america',
            'jp': 'jpmorgan',
            'jpm': 'jpmorgan',
            'wf': 'wells fargo',
            'gs': 'goldman sachs',
            'ms': 'morgan stanley',
            'citi': 'citibank',
            'amex': 'american express',
            'mcd': 'mcdonalds',
            'sbux': 'starbucks',
            'wmt': 'walmart',
            'hd': 'home depot',
            'low': 'lowes',
            'msft': 'microsoft',
            'goog': 'google',
            'amzn': 'amazon',
            'aapl': 'apple'
        }
        
        # Check for known acronym pairs
        for known_acr, known_full in known_pairs.items():
            if known_acr in acronym_lower and known_full in full_name_lower:
                return 1.0  # Perfect match for known pairs
        
        # Preprocess inputs
        acronym_clean, full_name_clean = self.preprocessor.preprocess_pair(acronym, full_name, domain)
        
        # Check if strings are empty
        if not acronym_clean or not full_name_clean:
            return 0
        
        # Extract first letters from each word in full name
        words = full_name_clean.split()
        if not words:
            return 0
        
        # Standard acronym formation - first letter of each word
        first_letters = ''.join([word[0] for word in words if word])
        
        # Check if acronym matches first letters pattern
        acronym_lower = acronym_clean.lower()
        first_letters_lower = first_letters.lower()
        
        # Perfect first-letter acronym match
        if acronym_lower == first_letters_lower:
            return 1.0
        
        # Check if acronym contains first letters in sequence
        acronym_pos = 0
        for char in first_letters_lower:
            if acronym_pos < len(acronym_lower) and char == acronym_lower[acronym_pos]:
                acronym_pos += 1
                
        if acronym_pos == len(acronym_lower):
            # All acronym chars found in sequence
            return 0.9
        
        # Check for matching first letters (in any order)
        acronym_chars = set(acronym_lower)
        first_letters_chars = set(first_letters_lower)
        
        if acronym_chars.issubset(first_letters_chars):
            # All acronym chars are first letters
            return 0.8
        
        # Check if acronym formed from first letters of some words (partial match)
        matches = 0
        remaining_acronym = acronym_lower
        
        for word in words:
            if not word or not remaining_acronym:
                continue
                
            if word[0] == remaining_acronym[0]:
                matches += 1
                remaining_acronym = remaining_acronym[1:]
        
        if len(acronym_lower) > 0:
            return matches / len(acronym_lower)
        
        return 0
    
    def specialized_acronym_similarity(self, acronym, full_name, domain=None):
        """
        Enhanced acronym similarity with special handling for various business naming patterns
        
        Args:
            acronym (str): The acronym
            full_name (str): The full name
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Specialized acronym similarity score between 0 and 1
        """
        # Get base acronym similarity
        base_score = self.acronym_similarity(acronym, full_name, domain)
        
        # If base score is high, no need for specialized handling
        if base_score > 0.8:
            return base_score
        
        acronym_clean, full_name_clean = self.preprocessor.preprocess_pair(acronym, full_name, domain)
        
        # Check if strings are empty
        if not acronym_clean or not full_name_clean:
            return 0
        
        # Special case for "Mc" prefixes (common in restaurant names)
        if full_name_clean.startswith('mc') and len(acronym_clean) >= 1 and acronym_clean[0] == 'm':
            # McDonalds -> MCD pattern
            modified_full_name = full_name_clean[2:]  # Remove "mc"
            remaining_chars = acronym_clean[1:]  # Remove "m"
            
            # For "MCD" -> "McDonalds" pattern
            if remaining_chars and len(modified_full_name) > 0:
                # Check if remaining chars match consonants in the name
                consonants = ''.join([c for c in modified_full_name if c not in 'aeiou'])
                if remaining_chars in consonants:
                    return 0.95
                
                # Check first letters after "Mc"
                words = modified_full_name.split()
                if words:
                    first_letters = ''.join([word[0] for word in words if word])
                    if remaining_chars in first_letters:
                        return 0.90
        
        # Handle initialism patterns (e.g., "IBM" -> "International Business Machines")
        if len(acronym_clean) >= 2 and len(full_name_clean.split()) >= len(acronym_clean):
            words = full_name_clean.split()
            
            # Check if acronym consists of first letters of consecutive words
            for i in range(len(words) - len(acronym_clean) + 1):
                subset = words[i:i+len(acronym_clean)]
                initials = ''.join([word[0] for word in subset if word])
                
                if initials.lower() == acronym_clean.lower():
                    return 0.95
        
        # Handle "Bank of X" vs "X Bank" pattern
        if ('bank' in acronym_clean and 'bank' in full_name_clean) or ('banking' in acronym_clean and 'bank' in full_name_clean):
            # Extract the main entity name (before/after "Bank")
            if 'bank of' in full_name_clean:
                entity = full_name_clean.split('bank of')[1].strip()
                if entity and entity[0].lower() in acronym_clean.lower():
                    return 0.85
            
            # Check for "X Bank" pattern
            for word in acronym_clean.split():
                if word != 'bank' and word in full_name_clean:
                    return 0.8
        
        # Check for abbreviated words (like "Intl" -> "International")
        abbr_dict = {
            'intl': 'international',
            'natl': 'national',
            'amer': 'america',
            'assn': 'association',
            'assoc': 'associates',
            'corp': 'corporation',
            'univ': 'university',
            'tech': 'technology',
            'inst': 'institute'
        }
        
        for abbr, full in abbr_dict.items():
            if abbr in acronym_clean and full in full_name_clean:
                return 0.85
        
        # Return base score if no specialized patterns matched
        return base_score
    
    def phonetic_similarity(self, s1, s2, domain=None):
        """
        Calculate phonetic similarity using Soundex algorithm.
        Especially useful for similar-sounding business names.
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Phonetic similarity score between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # If either string is empty, return 0
        if not s1_clean or not s2_clean:
            return 0.0
        
        # Get the soundex codes for words
        try:
            # For multi-word strings, get soundex for each word
            s1_words = s1_clean.split()
            s2_words = s2_clean.split()
            
            # Get soundex codes for each word
            s1_codes = [jellyfish.soundex(word) for word in s1_words if len(word) > 1]
            s2_codes = [jellyfish.soundex(word) for word in s2_words if len(word) > 1]
            
            # Calculate matches between codes
            matches = 0
            total = max(len(s1_codes), len(s2_codes))
            
            if total == 0:
                return 0.0
            
            # Count matched codes
            for code in s1_codes:
                if code in s2_codes:
                    matches += 1
                    # Remove the matched code to avoid double counting
                    s2_codes.remove(code)
            
            return matches / total
        except Exception as e:
            logger.warning(f"Error in phonetic similarity: {e}")
            # Fallback if there's an error with the soundex calculation
            return 0.0
    
    def metaphone_similarity(self, s1, s2, domain=None):
        """
        Calculate metaphone similarity (better than Soundex for English).
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Metaphone similarity score between 0 and 1
        """
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # If either string is empty, return 0
        if not s1_clean or not s2_clean:
            return 0.0
        
        try:
            # For multi-word strings, get metaphone for each word
            s1_words = s1_clean.split()
            s2_words = s2_clean.split()
            
            # Get metaphone codes for each word
            s1_codes = [jellyfish.metaphone(word) for word in s1_words if len(word) > 1]
            s2_codes = [jellyfish.metaphone(word) for word in s2_words if len(word) > 1]
            
            # Calculate matches between codes
            matches = 0
            total = max(len(s1_codes), len(s2_codes))
            
            if total == 0:
                return 0.0
            
            # Count matched codes
            for code in s1_codes:
                if code in s2_codes:
                    matches += 1
                    # Remove the matched code to avoid double counting
                    s2_codes.remove(code)
            
            return matches / total
        except Exception as e:
            logger.warning(f"Error in metaphone similarity: {e}")
            return 0.0
    
    def semantic_similarity(self, s1, s2, domain=None):
        """
        Calculate semantic similarity using BERT embeddings
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Semantic similarity score between 0 and 1
        """
        if not self.bert_embedder:
            return 0.0
            
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0.0
        
        try:
            return self.bert_embedder.compute_similarity(s1_clean, s2_clean)
        except Exception as e:
            logger.warning(f"Error in semantic similarity: {e}")
            return 0.0
    
    def multi_algorithm_similarity(self, s1, s2, domain=None):
        """
        Calculate a weighted average of multiple similarity algorithms
        with weights optimized for merchant name matching
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            float: Multi-algorithm similarity score between 0 and 1
        """
        # Preprocess inputs
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty
        if not s1_clean or not s2_clean:
            return 0
        
        # Base weights
        weights = {
            'jaro_winkler': 0.2,
            'token_set_ratio': 0.15,
            'contains_ratio': 0.15,
            'acronym': 0.1,
            'semantic': 0.25,
            'phonetic': 0.05,
            'metaphone': 0.05,
            'jaccard_ngram': 0.05
        }
        
        # Calculate each similarity score
        scores = {
            'jaro_winkler': self.jaro_winkler_similarity(s1, s2, domain),
            'token_set_ratio': self.token_set_ratio(s1, s2, domain),
            'contains_ratio': self.contains_ratio(s1, s2, domain),
            'acronym': self.specialized_acronym_similarity(s1, s2, domain),
            'semantic': self.semantic_similarity(s1, s2, domain) if self.bert_embedder else 0,
            'phonetic': self.phonetic_similarity(s1, s2, domain),
            'metaphone': self.metaphone_similarity(s1, s2, domain),
            'jaccard_ngram': self.jaccard_ngram_similarity(s1, s2, domain, n=3)
        }
        
        # Adjust weights based on domain if provided
        if domain:
            if domain == 'Banking' or domain == 'Financial':
                weights['acronym'] = 0.20
                weights['contains_ratio'] = 0.20
                weights['jaro_winkler'] = 0.15
                weights['semantic'] = 0.25
                weights['token_set_ratio'] = 0.10
            elif domain == 'Restaurant':
                weights['jaro_winkler'] = 0.25
                weights['semantic'] = 0.30
                weights['phonetic'] = 0.15
                weights['token_set_ratio'] = 0.20
                weights['acronym'] = 0.05
            elif domain == 'Retail':
                weights['contains_ratio'] = 0.25
                weights['semantic'] = 0.30
                weights['jaro_winkler'] = 0.20
                weights['token_set_ratio'] = 0.15
                weights['acronym'] = 0.05
        
        # Calculate weighted average
        weighted_sum = sum(weights[algo] * score for algo, score in scores.items())
        total_weight = sum(weights.values())
        
        return weighted_sum / total_weight if total_weight > 0 else 0

# Initialize similarity algorithms with the preprocessor and BERT embedder
similarity_algorithms = SimilarityAlgorithms(
    preprocessor=merchant_preprocessor, 
    bert_embedder=bert_embedder
)
logger.info("Enhanced similarity algorithms initialized!")

2025-04-08 00:24:30,125 - INFO - Enhanced similarity algorithms initialized!


In [10]:
# Cell 5: Advanced Pattern Recognition

class PatternRecognition:
    """
    Advanced pattern recognition for merchant names with industry-specific rules.
    Detects common patterns in merchant naming conventions across different industries
    with enhanced accuracy and specialized handling.
    """
    
    def __init__(self, preprocessor=None, similarity_algorithms=None):
        """
        Initialize pattern recognition with preprocessor and similarity algorithms
        
        Args:
            preprocessor: Merchant name preprocessor
            similarity_algorithms: Similarity algorithms for merchant matching
        """
        self.preprocessor = preprocessor or MerchantPreprocessor()
        self.similarity_algorithms = similarity_algorithms or SimilarityAlgorithms(self.preprocessor)
        
        # Load all pattern dictionaries
        self.common_merchant_patterns = self._load_merchant_patterns()
        self.known_equivalents = self._load_known_equivalents()
        self.abbreviation_patterns = self._load_abbreviation_patterns()
        
        # Initialize regex pattern compilation
        self._compile_regex_patterns()
        
        # Initialize Aho-Corasick automaton for fast pattern matching if available
        self.aho_corasick_finder = None
        if aho_corasick_available:
            self._initialize_aho_corasick()
    
    def _load_merchant_patterns(self):
        """Load comprehensive merchant name patterns by industry"""
        return {
            'banking': {
                'bank_name_patterns': [
                    # Basic bank naming patterns
                    (r'(\w+)\s+bank', r'bank\s+of\s+(\w+)'),  # "Chase Bank" vs "Bank of America"
                    (r'(\w+)\s+banking', r'(\w+)\s+financial'),  # "Western Banking" vs "Western Financial"
                    (r'(\w+)\s+credit\s+union', r'(\w+)\s+cu'),  # "State Credit Union" vs "State CU"
                    (r'(\w+)\s+savings\s+bank', r'(\w+)\s+savings'),  # "Community Savings Bank" vs "Community Savings"
                    (r'(\w+)\s+trust\s+company', r'(\w+)\s+trust'),  # "First Trust Company" vs "First Trust"
                    (r'(\w+)\s+national\s+bank', r'(\w+)\s+natl'),  # "First National Bank" vs "First Natl"
                    
                    # International bank patterns
                    (r'banco\s+(\w+)', r'bank\s+of\s+(\w+)'),  # "Banco Popular" vs "Bank of Popular"
                    (r'banque\s+(\w+)', r'bank\s+of\s+(\w+)'),  # "Banque Nationale" vs "National Bank"
                    (r'deutsche\s+(\w+)', r'german\s+(\w+)'),  # "Deutsche Bank" vs "German Bank"
                ],
                'branch_patterns': [
                    (r'(\w+)\s+branch', r'\1'),  # "Downtown Branch" -> "Downtown"
                    (r'(\w+)\s+location', r'\1'),  # "Main Location" -> "Main"
                    (r'(\w+)\s+-\s+(\w+)', r'\1'),  # "Chase - Brooklyn" -> "Chase"
                    (r'(\w+)\s+banking\s+center', r'\1'),  # "Chase Banking Center" -> "Chase"
                    (r'(\w+)\s+financial\s+center', r'\1'),  # "Wells Fargo Financial Center" -> "Wells Fargo"
                    (r'(\w+)\s+atm', r'\1'),  # "Chase ATM" -> "Chase"
                ]
            },
            'retail': {
                'store_patterns': [
                    (r'(\w+)\s+#\d+', r'\1'),  # "Walmart #1234" -> "Walmart"
                    (r'(\w+)\s+store', r'\1'),  # "Target Store" -> "Target"
                    (r'(\w+)\s+superstore', r'\1'),  # "Walmart Superstore" -> "Walmart"
                    (r'(\w+)\s+supermarket', r'\1'),  # "Kroger Supermarket" -> "Kroger"
                    (r'(\w+)\s+market', r'\1'),  # "Whole Foods Market" -> "Whole Foods"
                    (r'(\w+)\s+express', r'\1'),  # "Safeway Express" -> "Safeway"
                    (r'(\w+)\s+neighborhood\s+market', r'\1'),  # "Walmart Neighborhood Market" -> "Walmart"
                    (r'(\w+)\s+supercenter', r'\1'),  # "Walmart Supercenter" -> "Walmart"
                    (r'(\w+)\s+warehouse', r'\1'),  # "Costco Warehouse" -> "Costco"
                    (r'(\w+)\s+dept\s+store', r'\1'),  # "Macy's Dept Store" -> "Macy's"
                    (r'(\w+)\s+department\s+store', r'\1'),  # "Macy's Department Store" -> "Macy's"
                    (r'(\w+)\s+outlet', r'\1'),  # "Nike Outlet" -> "Nike"
                ],
                'location_patterns': [
                    (r'(\w+)\s+at\s+(\w+\s*\w*)', r'\1'),  # "Target at Springfield Mall" -> "Target"
                    (r'(\w+)\s+in\s+(\w+\s*\w*)', r'\1'),  # "Walmart in Chicago" -> "Walmart"
                    (r'(\w+)\s+on\s+(\w+\s*\w*)', r'\1'),  # "Kroger on Main Street" -> "Kroger"
                    (r'(\w+)\s+-\s+(\w+\s*\w*)', r'\1'),  # "Best Buy - Downtown" -> "Best Buy"
                ],
                'specific_store_types': [
                    (r'(\w+)\s+pharmacy', r'\1'),  # "CVS Pharmacy" -> "CVS"
                    (r'(\w+)\s+drug\s+store', r'\1'),  # "Walgreens Drug Store" -> "Walgreens"
                    (r'(\w+)\s+hardware', r'\1'),  # "Ace Hardware" -> "Ace"
                    (r'(\w+)\s+home\s+improvement', r'\1'),  # "Lowe's Home Improvement" -> "Lowe's"
                    (r'(\w+)\s+home\s+depot', r'home\s+depot'),  # Any "Home Depot" variation
                ],
            },
            'restaurant': {
                'location_patterns': [
                    (r'(\w+)\s+restaurant', r'\1'),  # "McDonald's Restaurant" -> "McDonald's"
                    (r'(\w+)\s+cafe', r'\1'),  # "Starbucks Cafe" -> "Starbucks"
                    (r'(\w+)\s+grill', r'\1'),  # "Applebee's Grill" -> "Applebee's"
                    (r'(\w+)\s+-\s+(\w+)', r'\1'),  # "McDonald's - Downtown" -> "McDonald's"
                    (r'(\w+)\s+kitchen', r'\1'),  # "Chipotle Mexican Kitchen" -> "Chipotle Mexican"
                    (r'(\w+)\s+bar\s+&?\s*grill', r'\1'),  # "Chili's Bar & Grill" -> "Chili's"
                    (r'(\w+)\s+eatery', r'\1'),  # "Panera Eatery" -> "Panera"
                    (r'(\w+)\s+diner', r'\1'),  # "Denny's Diner" -> "Denny's"
                ],
                'chain_name_patterns': [
                    (r'mcdonald\'?s', r'mcdonalds'),  # Normalize McDonald's variations
                    (r'dunkin\'?\s*donuts?', r'dunkin'),  # "Dunkin' Donuts" -> "Dunkin"
                    (r'starbucks\s+coffee', r'starbucks'),  # "Starbucks Coffee" -> "Starbucks"
                    (r'kfc|kentucky\s+fried\s+chicken', r'kfc'),  # KFC variations
                    (r'burger\s+king', r'bk'),  # "Burger King" -> "BK"
                    (r'pizza\s+hut', r'pizzahut'),  # "Pizza Hut" -> "PizzaHut"
                    (r'taco\s+bell', r'tacobell'),  # "Taco Bell" -> "TacoBell"
                ],
                'food_type_patterns': [
                    (r'(\w+)\s+pizzeria', r'\1'),  # "Domino's Pizzeria" -> "Domino's"
                    (r'(\w+)\s+steakhouse', r'\1'),  # "Outback Steakhouse" -> "Outback"
                    (r'(\w+)\s+sushi', r'\1'),  # "Tokyo Sushi" -> "Tokyo"
                    (r'(\w+)\s+bakery', r'\1'),  # "Panera Bakery" -> "Panera"
                    (r'(\w+)\s+bbq', r'\1'),  # "Famous Dave's BBQ" -> "Famous Dave's"
                    (r'(\w+)\s+taco', r'\1'),  # "Chronic Taco" -> "Chronic"
                ],
            },
            'hotel': {
                'property_patterns': [
                    (r'(\w+)\s+hotel', r'\1'),  # "Marriott Hotel" -> "Marriott"
                    (r'(\w+)\s+inn', r'\1'),  # "Holiday Inn" -> "Holiday"
                    (r'(\w+)\s+suites', r'\1'),  # "Comfort Suites" -> "Comfort"
                    (r'(\w+)\s+resort', r'\1'),  # "Wynn Resort" -> "Wynn"
                    (r'(\w+)\s+lodge', r'\1'),  # "Pine Lodge" -> "Pine"
                    (r'(\w+)\s+motel', r'\1'),  # "Super 8 Motel" -> "Super 8"
                    (r'(\w+)\s+&\s+suites', r'\1'),  # "Hampton & Suites" -> "Hampton"
                    (r'(\w+)\s+hotel\s+&\s+resort', r'\1'),  # "Hilton Hotel & Resort" -> "Hilton"
                    (r'(\w+)\s+by\s+(\w+)', r'\2'),  # "Fairfield by Marriott" -> "Marriott"
                ],
                'chain_patterns': [
                    (r'hyatt\s+regency', r'hyatt'),  # "Hyatt Regency" -> "Hyatt"
                    (r'marriott\s+courtyard', r'marriott'),  # "Marriott Courtyard" -> "Marriott"
                    (r'hilton\s+garden\s+inn', r'hilton'),  # "Hilton Garden Inn" -> "Hilton"
                    (r'holiday\s+inn\s+express', r'holiday\s+inn'),  # "Holiday Inn Express" -> "Holiday Inn"
                    (r'four\s+seasons', r'four\s+seasons'),  # Preserve "Four Seasons"
                    (r'best\s+western\s+plus', r'best\s+western'),  # "Best Western Plus" -> "Best Western"
                ],
            },
            'gas_station': {
                'station_patterns': [
                    (r'(\w+)\s+gas', r'\1'),  # "Shell Gas" -> "Shell"
                    (r'(\w+)\s+gas\s+station', r'\1'),  # "Chevron Gas Station" -> "Chevron"
                    (r'(\w+)\s+fuel', r'\1'),  # "BP Fuel" -> "BP"
                    (r'(\w+)\s+oil', r'\1'),  # "Mobil Oil" -> "Mobil"
                    (r'(\w+)\s+service\s+station', r'\1'),  # "Texaco Service Station" -> "Texaco"
                    (r'(\w+)\s+petroleum', r'\1'),  # "Phillips Petroleum" -> "Phillips"
                ],
                'convenience_patterns': [
                    (r'(\w+)\s+convenience', r'\1'),  # "7-Eleven Convenience" -> "7-Eleven"
                    (r'(\w+)\s+mart', r'\1'),  # "Speedway Mart" -> "Speedway"
                    (r'(\w+)\s+corner\s+store', r'\1'),  # "Shell Corner Store" -> "Shell"
                ],
            },
            'general': {
                'location_prefixes': [
                    (r'(north|south|east|west|downtown|midtown|uptown|central)\s+(\w+)', r'\2'),
                    (r'(\w+)\s+(north|south|east|west|downtown|midtown|uptown|central)', r'\1'),
                    (r'(n|s|e|w)\s+(\w+)', r'\2'),  # "N Target" -> "Target"
                    (r'(\w+)\s+(n|s|e|w)', r'\1'),  # "Target S" -> "Target"
                ],
                'franchise_patterns': [
                    (r'(\w+)\s+franchise', r'\1'),
                    (r'(\w+)\s+franchisee', r'\1'),
                    (r'(\w+)\s+franchisor', r'\1'),
                    (r'(\w+)\s+licensed', r'\1'),
                ],
                'subsidiary_patterns': [
                    (r'(\w+),\s+a\s+(\w+)\s+company', r'\1'),
                    (r'(\w+),\s+subsidiary\s+of\s+(\w+)', r'\1'),
                    (r'(\w+)\s+division\s+of\s+(\w+)', r'\1'),
                    (r'(\w+)\s+by\s+(\w+)', r'\1'),
                    (r'(\w+)\s+owned\s+by\s+(\w+)', r'\1'),
                ],
                'numbered_locations': [
                    (r'(\w+)\s+\d+', r'\1'),  # "Starbucks 123" -> "Starbucks"
                    (r'(\w+)\s+no\.?\s*\d+', r'\1'),  # "McDonald's No. 456" -> "McDonald's"
                    (r'(\w+)\s+number\s*\d+', r'\1'),  # "Subway Number 789" -> "Subway"
                    (r'(\w+)\s+store\s*\d+', r'\1'),  # "Target Store 101" -> "Target"
                ],
                'special_character_patterns': [
                    (r'(\w+)\s*[&@]\s*(\w+)', r'\1 and \2'),  # "AT&T" -> "AT and T", "M@cys" -> "M and cys"
                    (r'(\w+)\s*\+\s*(\w+)', r'\1 plus \2'),  # "Bed+Bath" -> "Bed plus Bath"
                ],
            },
        }
    
    def _load_known_equivalents(self):
        """Load known equivalent merchant names"""
        return {
            # Banking and financial
            'bofa': ['bank of america', 'bankofamerica', 'bank america', 'b of a'],
            'chase': ['jpmorgan chase', 'jp morgan', 'chase bank', 'jpmorgan'],
            'wells fargo': ['wf', 'wellsfargo', 'wells'],
            'citi': ['citibank', 'citigroup', 'citicorp'],
            'amex': ['american express', 'americanexpress'],
            'discover': ['discover card', 'discover financial'],
            'capital one': ['capitalone', 'cap1', 'cap one'],
            
            # Retail
            'walmart': ['wal-mart', 'wal mart', 'wmt', 'walmart supercenter', 'walmart neighborhood market'],
            'target': ['target store', 'super target', 'target superstore'],
            'costco': ['costco wholesale', 'costco warehouse', 'price costco'],
            'amazon': ['amazon.com', 'amazon marketplace', 'amzn'],
            'home depot': ['thd', 'the home depot', 'homedepot'],
            'lowes': ['lowes home improvement', 'lowe\'s'],
            'best buy': ['bestbuy', 'bby'],
            'cvs': ['cvs pharmacy', 'cvs caremark', 'cvs health'],
            'walgreens': ['walgreen', 'wag'],
            
            # Fast food and restaurants
            'mcdonalds': ['mcdonald\'s', 'mcd', 'mcds', 'micky ds'],
            'starbucks': ['starbucks coffee', 'sbux', 'starbucks coffee company'],
            'subway': ['subway restaurant', 'subway sandwiches'],
            'taco bell': ['tacobell', 'bell'],
            'burger king': ['bk', 'burgerking'],
            'wendys': ['wendy\'s', 'wendys old fashioned hamburgers'],
            'kfc': ['kentucky fried chicken', 'kentucky fried'],
            'chipotle': ['chipotle mexican grill', 'cmg'],
            'dominos': ['domino\'s pizza', 'dominos pizza'],
            'pizza hut': ['pizzahut', 'the hut'],
            
            # Gas stations
            'shell': ['shell oil', 'shell gas', 'shell gas station'],
            'exxon': ['exxonmobil', 'exxon mobil', 'esso'],
            'bp': ['british petroleum', 'bp gas', 'bp gas station'],
            'chevron': ['chevron gas', 'chevron gas station', 'chevron texaco'],
            'mobil': ['mobil gas', 'mobil gas station', 'exxon mobil'],
            
            # Hotels
            'marriott': ['marriott hotel', 'marriott international', 'jw marriott'],
            'hilton': ['hilton hotel', 'hilton worldwide', 'hilton hotels'],
            'hyatt': ['hyatt hotel', 'hyatt regency', 'grand hyatt'],
            'holiday inn': ['holiday inn express', 'ihg', 'holiday'],
            'sheraton': ['sheraton hotel', 'sheraton resort'],
            
            # Tech companies
            'apple': ['apple store', 'apple inc', 'apple computer'],
            'microsoft': ['msft', 'microsoft store', 'microsoft corporation'],
            'google': ['google inc', 'alphabet', 'google play'],
            'facebook': ['fb', 'meta', 'facebook inc'],
            'amazon': ['amazon web services', 'aws', 'amazon.com'],
        }
    
    def _load_abbreviation_patterns(self):
        """Load patterns for handling common abbreviations in merchant names"""
        return {
            # Business type abbreviations
            'inc': ['incorporated', 'incorporation'],
            'corp': ['corporation'],
            'co': ['company'],
            'ltd': ['limited'],
            'llc': ['limited liability company'],
            'lp': ['limited partnership'],
            'intl': ['international'],
            'assoc': ['associates', 'association'],
            'svcs': ['services'],
            'mgmt': ['management'],
            'tech': ['technology', 'technologies'],
            'sys': ['systems', 'system'],
            'sol': ['solutions', 'solution'],
            
            # Industry abbreviations
            'ctr': ['center'],
            'dept': ['department'],
            'pharm': ['pharmacy', 'pharmaceutical'],
            'rest': ['restaurant'],
            'cafe': ['cafeteria', 'café'],
            'mkt': ['market'],
            'groc': ['grocery', 'groceries'],
            'disc': ['discount'],
            'whs': ['warehouse'],
            
            # Location abbreviations
            'st': ['street'],
            'ave': ['avenue'],
            'blvd': ['boulevard'],
            'rd': ['road'],
            'hwy': ['highway'],
            'cty': ['city'],
            'twn': ['town'],
            'vlg': ['village'],
            'plz': ['plaza'],
            'sq': ['square'],
            'stn': ['station'],
            
            # Direction abbreviations
            'n': ['north'],
            's': ['south'],
            'e': ['east'],
            'w': ['west'],
            'ne': ['northeast'],
            'nw': ['northwest'],
            'se': ['southeast'],
            'sw': ['southwest'],
        }
    
    def _compile_regex_patterns(self):
        """Compile all regex patterns for faster matching"""
        self.compiled_patterns = {}
        
        for domain, pattern_groups in self.common_merchant_patterns.items():
            self.compiled_patterns[domain] = {}
            
            for pattern_type, patterns in pattern_groups.items():
                self.compiled_patterns[domain][pattern_type] = [
                    (re.compile(pattern, re.IGNORECASE), replacement)
                    for pattern, replacement in patterns
                ]
    
    def _initialize_aho_corasick(self):
        """Initialize Aho-Corasick automaton for fast pattern matching of known equivalents"""
        if not aho_corasick_available:
            return
            
        try:
            self.aho_corasick_finder = pyahocorasick.Automaton()
            
            # Add all merchant names and their equivalents
            for key, equivalents in self.known_equivalents.items():
                for eq in equivalents:
                    self.aho_corasick_finder.add_str(eq, (key, eq))
                self.aho_corasick_finder.add_str(key, (key, key))
                
            # Finalize the automaton
            self.aho_corasick_finder.make_automaton()
            logger.info("Aho-Corasick automaton initialized successfully")
        except Exception as e:
            logger.warning(f"Failed to initialize Aho-Corasick automaton: {e}")
            self.aho_corasick_finder = None
    
    def detect_merchant_patterns(self, s1, s2, domain=None):
        """
        Detect merchant name patterns in a pair of strings
        
        Args:
            s1 (str): First string
            s2 (str): Second string
            domain (str, optional): Domain for specialized pattern detection
            
        Returns:
            dict: Detected patterns and their confidence scores
        """
        if not s1 or not s2:
            return {}
            
        # Preprocess the strings
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty after preprocessing
        if not s1_clean or not s2_clean:
            return {}
        
        patterns = {}
        
        # First check for known equivalents - highest confidence
        equiv_match = self._detect_known_equivalents(s1_clean, s2_clean)
        if equiv_match:
            patterns['known_equivalent'] = {
                'pattern': equiv_match,
                'confidence': 0.95,
                'explanation': f"Known equivalent pair found: '{equiv_match[0]}' and '{equiv_match[1]}'"
            }
            # If known equivalents found, we can return immediately as it's high confidence
            return patterns
        
        # Check for abbreviation patterns
        abbrev_patterns = self._detect_abbreviation_patterns(s1_clean, s2_clean)
        if abbrev_patterns:
            patterns['abbreviation'] = {
                'pattern': abbrev_patterns,
                'confidence': 0.85,
                'explanation': f"Abbreviation pattern detected: {abbrev_patterns}"
            }
        
        # Detect domain-specific patterns
        if domain:
            # Check specific domain patterns if provided
            if domain.lower() in self.common_merchant_patterns:
                domain_patterns = self._detect_domain_specific_patterns(s1_clean, s2_clean, domain.lower())
                if domain_patterns:
                    patterns[f'{domain.lower()}_specific'] = {
                        'pattern': domain_patterns,
                        'confidence': 0.9,
                        'explanation': f"{domain} specific pattern detected: {domain_patterns}"
                    }
        else:
            # Try all domains if not specified
            for domain_name in self.common_merchant_patterns:
                if domain_name == 'general':
                    continue  # Skip general, we'll check it separately
                
                domain_patterns = self._detect_domain_specific_patterns(s1_clean, s2_clean, domain_name)
                if domain_patterns:
                    patterns[f'{domain_name}_pattern'] = {
                        'pattern': domain_patterns,
                        'confidence': 0.8,
                        'explanation': f"{domain_name} pattern detected: {domain_patterns}"
                    }
        
        # Always check general patterns
        general_patterns = self._detect_domain_specific_patterns(s1_clean, s2_clean, 'general')
        if general_patterns:
            patterns['general_pattern'] = {
                'pattern': general_patterns,
                'confidence': 0.7,
                'explanation': f"General pattern detected: {general_patterns}"
            }
        
        # Detect character-level patterns
        char_patterns = self._detect_character_patterns(s1_clean, s2_clean)
        if char_patterns:
            patterns['character_pattern'] = {
                'pattern': char_patterns,
                'confidence': 0.65,
                'explanation': f"Character-level pattern detected: {char_patterns}"
            }
        
        # Detect grammatical structure patterns
        structure_patterns = self._detect_structure_patterns(s1_clean, s2_clean)
        if structure_patterns:
            patterns['structure_pattern'] = {
                'pattern': structure_patterns,
                'confidence': 0.75,
                'explanation': f"Grammatical structure pattern detected: {structure_patterns}"
            }
        
        return patterns
    
    def _detect_known_equivalents(self, s1, s2):
        """
        Check if the merchant names are known equivalents using optimized matching
        
        Args:
            s1 (str): First preprocessed string
            s2 (str): Second preprocessed string
            
        Returns:
            tuple or None: Tuple of (canonical_name, equivalent) if found, otherwise None
        """
        # Use Aho-Corasick for fast matching if available
        if self.aho_corasick_finder:
            # Find matches in both strings
            s1_matches = list(self.aho_corasick_finder.iter(s1))
            s2_matches = list(self.aho_corasick_finder.iter(s2))
            
            # Check if we have matches in both strings
            if s1_matches and s2_matches:
                # Get canonical names from matches
                s1_canonicals = set(match[1][0] for match in s1_matches)
                s2_canonicals = set(match[1][0] for match in s2_matches)
                
                # Check for intersection
                common_canonicals = s1_canonicals.intersection(s2_canonicals)
                if common_canonicals:
                    # Return the first common canonical name and the original strings
                    return (next(iter(common_canonicals)), (s1, s2))
        
        # Fallback to dictionary lookup if Aho-Corasick is not available
        for canonical, equivalents in self.known_equivalents.items():
            # Check if both strings are related to the same canonical name
            s1_matches = canonical == s1 or s1 in equivalents
            s2_matches = canonical == s2 or s2 in equivalents
            
            if s1_matches and s2_matches:
                return (canonical, (s1, s2))
            
            # Try fuzzy matching for each equivalent
            all_variants = equivalents + [canonical]
            for variant in all_variants:
                jw_s1 = self.similarity_algorithms.jaro_winkler_similarity(variant, s1)
                jw_s2 = self.similarity_algorithms.jaro_winkler_similarity(variant, s2)
                
                # If high similarity with both strings, they're likely equivalent
                if jw_s1 > 0.85 and jw_s2 > 0.85:
                    return (canonical, (s1, s2))
        
        return None
    
    def _detect_abbreviation_patterns(self, s1, s2):
        """
        Detect if one string is an abbreviation of the other
        
        Args:
            s1 (str): First preprocessed string
            s2 (str): Second preprocessed string
            
        Returns:
            dict or None: Abbreviation pattern details if found, otherwise None
        """
        # Check for common abbreviation patterns
        s1_words = s1.split()
        s2_words = s2.split()
        
        abbrev_found = False
        expanded_found = False
        abbrev_word = None
        expanded_word = None
        
        # Check each word against known abbreviations
        for word in s1_words:
            if word in self.abbreviation_patterns:
                # Check if any expanded form exists in s2
                for expanded in self.abbreviation_patterns[word]:
                    if expanded in s2 or any(expanded in w for w in s2_words):
                        abbrev_found = True
                        abbrev_word = word
                        expanded_word = expanded
                        break
        
        # Check in reverse direction too
        if not abbrev_found:
            for word in s2_words:
                if word in self.abbreviation_patterns:
                    # Check if any expanded form exists in s1
                    for expanded in self.abbreviation_patterns[word]:
                        if expanded in s1 or any(expanded in w for w in s1_words):
                            abbrev_found = True
                            abbrev_word = word
                            expanded_word = expanded
                            break
        
        if abbrev_found:
            return {
                'type': 'abbreviation_expansion',
                'abbreviation': abbrev_word,
                'expansion': expanded_word,
                'direction': 's1_to_s2' if abbrev_word in s1_words else 's2_to_s1'
            }
        
        # Check for acronyms (first letters of words)
        # See if one string is an acronym of the other
        s1_is_acronym = len(s1_words) == 1 and len(s1.replace('.', '')) <= 5
        s2_is_acronym = len(s2_words) == 1 and len(s2.replace('.', '')) <= 5
        
        if s1_is_acronym and not s2_is_acronym:
            # Check if s1 is acronym of s2
            first_letters = ''.join([word[0] for word in s2_words if word]).lower()
            if s1.lower() == first_letters:
                return {
                    'type': 'acronym_expansion',
                    'acronym': s1,
                    'expansion': s2,
                    'direction': 's1_to_s2'
                }
        elif s2_is_acronym and not s1_is_acronym:
            # Check if s2 is acronym of s1
            first_letters = ''.join([word[0] for word in s1_words if word]).lower()
            if s2.lower() == first_letters:
                return {
                    'type': 'acronym_expansion',
                    'acronym': s2,
                    'expansion': s1,
                    'direction': 's2_to_s1'
                }
        
        # Check for partial abbreviations (like "Intl" for "International")
        for i, word1 in enumerate(s1_words):
            for j, word2 in enumerate(s2_words):
                # Skip very short words or exact matches
                if len(word1) < 3 or len(word2) < 3 or word1 == word2:
                    continue
                
                # Check if one word is a prefix of the other
                if (word1.startswith(word2) or word2.startswith(word1)) and abs(len(word1) - len(word2)) > 2:
                    shorter = word1 if len(word1) < len(word2) else word2
                    longer = word2 if len(word1) < len(word2) else word1
                    
                    # Check if shorter is at least 3 chars and longer is at least 50% longer
                    if len(shorter) >= 3 and len(longer) > len(shorter) * 1.5:
                        return {
                            'type': 'partial_abbreviation',
                            'abbreviation': shorter,
                            'expansion': longer,
                            'direction': 's1_to_s2' if shorter == word1 else 's2_to_s1'
                        }
        
        return None
    
    def _detect_domain_specific_patterns(self, s1, s2, domain):
        """
        Detect industry-specific patterns based on the provided domain
        
        Args:
            s1 (str): First preprocessed string
            s2 (str): Second preprocessed string
            domain (str): Domain for pattern detection
            
        Returns:
            list: List of detected patterns for the domain
        """
        if domain not in self.common_merchant_patterns:
            return []
        
        detected_patterns = []
        
        for pattern_type, patterns in self.compiled_patterns[domain].items():
            for pattern, replacement in patterns:
                # Check if pattern matches first string
                s1_match = pattern.search(s1)
                if s1_match:
                    s1_normalized = pattern.sub(replacement, s1)
                    # Check if normalized version now matches second string (exact or high similarity)
                    if s1_normalized == s2 or self.similarity_algorithms.jaro_winkler_similarity(s1_normalized, s2) > 0.85:
                        detected_patterns.append({
                            'domain': domain,
                            'pattern_type': pattern_type,
                            'original': s1,
                            'normalized': s1_normalized,
                            'match': s2,
                            'regex': pattern.pattern
                        })
                
                # Check in reverse direction
                s2_match = pattern.search(s2)
                if s2_match:
                    s2_normalized = pattern.sub(replacement, s2)
                    # Check if normalized version now matches first string (exact or high similarity)
                    if s2_normalized == s1 or self.similarity_algorithms.jaro_winkler_similarity(s2_normalized, s1) > 0.85:
                        detected_patterns.append({
                            'domain': domain,
                            'pattern_type': pattern_type,
                            'original': s2,
                            'normalized': s2_normalized,
                            'match': s1,
                            'regex': pattern.pattern
                        })
        
        return detected_patterns
    
    def _detect_character_patterns(self, s1, s2):
        """
        Detect character-level patterns like typos, character substitutions, etc.
        
        Args:
            s1 (str): First preprocessed string
            s2 (str): Second preprocessed string
            
        Returns:
            list: List of detected character-level patterns
        """
        patterns = []
        
        # Detect common character substitutions
        char_subs = {
            'zero_o': ('0', 'o'),
            'one_l': ('1', 'l'),
            'at_a': ('@', 'a'),
            'plus_t': ('+', 't'),
            'ampersand_and': ('&', 'and'),
            'dollar_s': ('$', 's'),
            'exclamation_i': ('!', 'i'),
        }
        
        for sub_name, (char1, char2) in char_subs.items():
            # Check if s1 contains char1 and s2 contains char2 or vice versa
            if (char1 in s1 and char2 in s2) or (char2 in s1 and char1 in s2):
                # Create normalized versions by replacing the characters
                s1_norm = s1.replace(char1, char2).replace(char2, char2)
                s2_norm = s2.replace(char1, char2).replace(char2, char2)
                
                # See if normalized versions are more similar
                orig_sim = self.similarity_algorithms.jaro_winkler_similarity(s1, s2)
                norm_sim = self.similarity_algorithms.jaro_winkler_similarity(s1_norm, s2_norm)
                
                if norm_sim > orig_sim:
                    patterns.append({
                        'type': 'character_substitution',
                        'subtype': sub_name,
                        'original_similarity': orig_sim,
                        'normalized_similarity': norm_sim,
                        'improvement': norm_sim - orig_sim
                    })
        
        # Detect transpositions (swapped adjacent characters)
        s1_chars = list(s1)
        for i in range(len(s1_chars) - 1):
            # Try swapping adjacent characters
            swapped = s1_chars.copy()
            swapped[i], swapped[i+1] = swapped[i+1], swapped[i]
            s1_swapped = ''.join(swapped)
            
            # Check if swapped version is more similar to s2
            orig_sim = self.similarity_algorithms.jaro_winkler_similarity(s1, s2)
            swap_sim = self.similarity_algorithms.jaro_winkler_similarity(s1_swapped, s2)
            
            if swap_sim > orig_sim + 0.1:  # Significant improvement
                patterns.append({
                    'type': 'transposition',
                    'position': i,
                    'original': s1,
                    'swapped': s1_swapped,
                    'original_similarity': orig_sim,
                    'swapped_similarity': swap_sim
                })
        
        # Detect common typos in merchant names
        typo_corrections = {
            'walmart': ['wallmart', 'walmark', 'walmrt'],
            'target': ['targat', 'tarket', 'targt'],
            'starbucks': ['starbuks', 'starbuck', 'startbucks'],
            'mcdonalds': ['macdonalds', 'mcdonlds', 'medonalds'],
            'amazon': ['amazan', 'amizon', 'amazom'],
            'costco': ['cosco', 'casco', 'cotsco'],
            'pizza': ['piza', 'pizzza', 'pizaa'],
            'restaurant': ['resturant', 'restrant', 'restarant']
        }
        
        for correct, typos in typo_corrections.items():
            if correct in s1 and any(typo in s2 for typo in typos):
                patterns.append({
                    'type': 'common_typo',
                    'correct': correct,
                    'typo': next((typo for typo in typos if typo in s2), None),
                    'direction': 's1_to_s2'
                })
            elif correct in s2 and any(typo in s1 for typo in typos):
                patterns.append({
                    'type': 'common_typo',
                    'correct': correct,
                    'typo': next((typo for typo in typos if typo in s1), None),
                    'direction': 's2_to_s1'
                })
        
        return patterns
    
    def _detect_structure_patterns(self, s1, s2):
        """
        Detect grammatical structure patterns between merchant names
        
        Args:
            s1 (str): First preprocessed string
            s2 (str): Second preprocessed string
            
        Returns:
            list: List of detected structural patterns
        """
        patterns = []
        
        # Check for word order differences
        s1_words = set(s1.split())
        s2_words = set(s2.split())
        
        # If the sets of words are the same but the strings are different,
        # then it's a word order difference
        if s1_words == s2_words and s1 != s2:
            patterns.append({
                'type': 'word_order',
                'words': sorted(list(s1_words)),
                'original_s1': s1,
                'original_s2': s2
            })
        
        # Check for compound vs. separate words
        # e.g., "walmart" vs "wal mart"
        s1_nospace = s1.replace(' ', '')
        s2_nospace = s2.replace(' ', '')
        
        if s1_nospace == s2_nospace and s1 != s2:
            patterns.append({
                'type': 'compound_vs_separate',
                'compound': s1 if ' ' not in s1 else s2,
                'separate': s2 if ' ' in s2 else s1
            })
        
        # Check for possessive forms
        # e.g., "McDonald's" vs "McDonalds"
        s1_noposs = s1.replace('\'s', 's')
        s2_noposs = s2.replace('\'s', 's')
        
        if s1_noposs == s2_noposs and s1 != s2:
            patterns.append({
                'type': 'possessive_form',
                'with_apostrophe': s1 if '\'' in s1 else s2,
                'without_apostrophe': s2 if '\'' not in s2 else s1
            })
        
        # Check for plural vs. singular forms
        singular_endings = ['s', 'es', 'ies']
        for word1 in s1.split():
            for word2 in s2.split():
                # Skip short words
                if len(word1) < 4 or len(word2) < 4:
                    continue
                    
                # Check if one is a plural of the other
                for ending in singular_endings:
                    if word1 == word2 + ending:
                        patterns.append({
                            'type': 'plural_singular',
                            'plural': word1,
                            'singular': word2,
                            'direction': 's1_to_s2'
                        })
                    elif word2 == word1 + ending:
                        patterns.append({
                            'type': 'plural_singular',
                            'plural': word2,
                            'singular': word1,
                            'direction': 's2_to_s1'
                        })
        
        return patterns
    
    def get_canonical_name(self, merchant_name, domain=None):
        """
        Get the canonical name for a merchant name
        
        Args:
            merchant_name (str): Merchant name to normalize
            domain (str, optional): Domain for specialized processing
            
        Returns:
            str: Canonical merchant name
        """
        if not merchant_name:
            return ""
            
        # Preprocess the input
        clean_name = self.preprocessor.preprocess(merchant_name, domain)
        
        # Check for known equivalents
        if self.aho_corasick_finder:
            matches = list(self.aho_corasick_finder.iter(clean_name))
            if matches:
                # Get the canonical name from the longest match
                longest_match = max(matches, key=lambda x: x[0] - x[1][1].find(clean_name))
                return longest_match[1][0]  # Return canonical name
        
        # Fallback to dictionary lookup
        for canonical, equivalents in self.known_equivalents.items():
            if clean_name == canonical or clean_name in equivalents:
                return canonical
            
            # Try fuzzy matching
            for variant in equivalents + [canonical]:
                jw_sim = self.similarity_algorithms.jaro_winkler_similarity(variant, clean_name)
                if jw_sim > 0.9:
                    return canonical
        
        # Apply domain-specific normalization if domain is provided
        if domain and domain.lower() in self.common_merchant_patterns:
            # Try applying all patterns for this domain
            for pattern_type, patterns in self.compiled_patterns[domain.lower()].items():
                for pattern, replacement in patterns:
                    if pattern.search(clean_name):
                        normalized = pattern.sub(replacement, clean_name)
                        if normalized != clean_name:
                            return normalized
        
        # Apply general patterns
        for pattern_type, patterns in self.compiled_patterns['general'].items():
            for pattern, replacement in patterns:
                if pattern.search(clean_name):
                    normalized = pattern.sub(replacement, clean_name)
                    if normalized != clean_name:
                        return normalized
        
        # If no patterns matched, return the preprocessed name
        return clean_name

In [12]:
# Cell 6: Advanced BERT-Based Semantic Analysis

class BertSemanticAnalyzer:
    """
    Advanced semantic analyzer using BERT and domain-specific tuning for
    merchant name matching with enhanced accuracy.
    """
    
    def __init__(self, bert_embedder=None, similarity_algorithms=None, preprocessor=None):
        """
        Initialize semantic analyzer with BERT embedder and supporting algorithms
        
        Args:
            bert_embedder: BERT embedder for semantic similarity
            similarity_algorithms: Similarity algorithms for merchant matching
            preprocessor: Merchant name preprocessor
        """
        self.bert_embedder = bert_embedder or AdvancedBERTEmbedder()
        self.similarity_algorithms = similarity_algorithms or SimilarityAlgorithms()
        self.preprocessor = preprocessor or MerchantPreprocessor()
        
        # Thresholds for semantic analysis
        self.high_similarity_threshold = 0.85
        self.medium_similarity_threshold = 0.70
        self.low_similarity_threshold = 0.55
        
        # Cached semantic clusters for merchant names
        self.merchant_clusters = {}
        self.cluster_embeddings = {}
        
        # Track if domain adaptation has been performed
        self.adapted_to_domain = False
    
    def analyze_semantic_match(self, s1, s2, domain=None):
        """
        Analyze the semantic match between two merchant names
        
        Args:
            s1 (str): First merchant name
            s2 (str): Second merchant name
            domain (str, optional): Domain for specialized analysis
            
        Returns:
            dict: Semantic analysis results
        """
        if not s1 or not s2:
            return {
                'semantic_similarity': 0.0,
                'match_level': 'no_match',
                'analysis': 'Empty input provided'
            }
        
        # Preprocess inputs
        s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
        
        # Check if strings are empty after preprocessing
        if not s1_clean or not s2_clean:
            return {
                'semantic_similarity': 0.0,
                'match_level': 'no_match',
                'analysis': 'Empty strings after preprocessing'
            }
        
        # Calculate semantic similarity
        semantic_similarity = self.bert_embedder.compute_similarity(s1_clean, s2_clean)
        
        # Determine match level
        match_level = self._get_match_level(semantic_similarity)
        
        # Get enhanced analysis
        analysis = self._get_enhanced_analysis(s1_clean, s2_clean, semantic_similarity, domain)
        
        return {
            'semantic_similarity': semantic_similarity,
            'match_level': match_level,
            'analysis': analysis
        }
    
    def _get_match_level(self, similarity_score):
        """Determine the match level based on similarity score"""
        if similarity_score >= self.high_similarity_threshold:
            return 'high_match'
        elif similarity_score >= self.medium_similarity_threshold:
            return 'medium_match'
        elif similarity_score >= self.low_similarity_threshold:
            return 'low_match'
        else:
            return 'no_match'
    
    def _get_enhanced_analysis(self, s1, s2, similarity, domain=None):
        """
        Get enhanced analysis of semantic match with domain-specific insights
        
        Args:
            s1 (str): First preprocessed merchant name
            s2 (str): Second preprocessed merchant name
            similarity (float): Semantic similarity score
            domain (str, optional): Domain for specialized analysis
            
        Returns:
            dict: Enhanced analysis results
        """
        analysis = {
            'semantic_matching_points': [],
            'context_similarity': 0.0,
            'potential_relationship': None,
            'confidence': 0.0
        }
        
        # Analyze word-level matching
        s1_words = s1.split()
        s2_words = s2.split()
        
        matching_words = set(s1_words).intersection(set(s2_words))
        
        # Calculate word match ratio
        s1_match_ratio = len(matching_words) / len(s1_words) if s1_words else 0
        s2_match_ratio = len(matching_words) / len(s2_words) if s2_words else 0
        
        analysis['word_match_ratio'] = (s1_match_ratio + s2_match_ratio) / 2
        analysis['matching_words'] = list(matching_words)
        
        # Calculate context similarity (how words are used together)
        if len(s1_words) > 1 and len(s2_words) > 1:
            s1_bigrams = set(zip(s1_words[:-1], s1_words[1:]))
            s2_bigrams = set(zip(s2_words[:-1], s2_words[1:]))
            
            matching_bigrams = s1_bigrams.intersection(s2_bigrams)
            
            s1_bigram_ratio = len(matching_bigrams) / len(s1_bigrams) if s1_bigrams else 0
            s2_bigram_ratio = len(matching_bigrams) / len(s2_bigrams) if s2_bigrams else 0
            
            analysis['context_similarity'] = (s1_bigram_ratio + s2_bigram_ratio) / 2
        
        # Identify semantic matching points (key concepts that match)
        # Use BERT to identify important semantic units in each name
        if hasattr(self.bert_embedder, 'model') and hasattr(self.bert_embedder.model, 'tokenizer'):
            try:
                # Get token contributions to semantic meaning
                s1_tokens = self.bert_embedder.model.tokenizer.tokenize(s1)
                s2_tokens = self.bert_embedder.model.tokenizer.tokenize(s2)
                
                # Filter out special tokens and find matching semantic tokens
                s1_tokens = [t for t in s1_tokens if t not in ['[CLS]', '[SEP]']]
                s2_tokens = [t for t in s2_tokens if t not in ['[CLS]', '[SEP]']]
                
                matching_tokens = set(s1_tokens).intersection(set(s2_tokens))
                
                # Add matching tokens as semantic matching points
                analysis['semantic_matching_points'] = list(matching_tokens)
            except:
                # Fallback to word-level matching if tokenizer isn't available
                analysis['semantic_matching_points'] = list(matching_words)
        else:
            # Fallback to word-level matching if tokenizer isn't available
            analysis['semantic_matching_points'] = list(matching_words)
        
        # Determine potential relationship
        if similarity >= self.high_similarity_threshold:
            if s1 == s2:
                analysis['potential_relationship'] = 'exact_match'
                analysis['confidence'] = 0.99
            elif len(s1) > len(s2) * 1.5:
                analysis['potential_relationship'] = 'expansion'
                analysis['confidence'] = 0.90
            elif len(s2) > len(s1) * 1.5:
                analysis['potential_relationship'] = 'abbreviation'
                analysis['confidence'] = 0.90
            else:
                analysis['potential_relationship'] = 'variant'
                analysis['confidence'] = 0.85
        elif similarity >= self.medium_similarity_threshold:
            if self.similarity_algorithms.acronym_similarity(s1, s2) > 0.7:
                analysis['potential_relationship'] = 'acronym'
                analysis['confidence'] = 0.80
            elif self.similarity_algorithms.token_set_ratio(s1, s2) > 0.8:
                analysis['potential_relationship'] = 'reordered'
                analysis['confidence'] = 0.75
            else:
                analysis['potential_relationship'] = 'related'
                analysis['confidence'] = 0.65
        elif similarity >= self.low_similarity_threshold:
            analysis['potential_relationship'] = 'loosely_related'
            analysis['confidence'] = 0.50
        else:
            analysis['potential_relationship'] = 'unrelated'
            analysis['confidence'] = 0.90
        
        # Add domain-specific analysis if domain is provided
        if domain:
            domain_analysis = self._get_domain_specific_analysis(s1, s2, domain)
            if domain_analysis:
                analysis['domain_specific'] = domain_analysis
        
        return analysis
    
    def _get_domain_specific_analysis(self, s1, s2, domain):
        """
        Get domain-specific analysis for semantic matching
        
        Args:
            s1 (str): First preprocessed merchant name
            s2 (str): Second preprocessed merchant name
            domain (str): Domain for specialized analysis
            
        Returns:
            dict: Domain-specific analysis or None if domain not supported
        """
        domain_lower = domain.lower()
        
        if domain_lower == 'banking' or domain_lower == 'financial':
            return self._analyze_financial_merchants(s1, s2)
        elif domain_lower == 'retail':
            return self._analyze_retail_merchants(s1, s2)
        elif domain_lower == 'restaurant' or domain_lower == 'food':
            return self._analyze_restaurant_merchants(s1, s2)
        elif domain_lower == 'hotel' or domain_lower == 'hospitality':
            return self._analyze_hotel_merchants(s1, s2)
        elif domain_lower == 'gas_station' or domain_lower == 'fuel':
            return self._analyze_gas_station_merchants(s1, s2)
        else:
            return None
    
    def _analyze_financial_merchants(self, s1, s2):
        """Specialized analysis for financial institutions"""
        analysis = {
            'industry': 'financial',
            'subtype': None,
            'notes': []
        }
        
        # Detect bank type
        bank_indicators = ['bank', 'credit union', 'financial', 'invest', 'capital', 'trust']
        credit_indicators = ['card', 'credit', 'loan', 'lending', 'mortgage', 'finance']
        payment_indicators = ['pay', 'wallet', 'transfer', 'remit', 'money']
        
        for indicator in bank_indicators:
            if indicator in s1 or indicator in s2:
                analysis['subtype'] = 'banking'
                analysis['notes'].append(f"Banking institution detected via keyword: '{indicator}'")
                break
                
        if not analysis['subtype']:
            for indicator in credit_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'credit'
                    analysis['notes'].append(f"Credit institution detected via keyword: '{indicator}'")
                    break
        
        if not analysis['subtype']:
            for indicator in payment_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'payment'
                    analysis['notes'].append(f"Payment service detected via keyword: '{indicator}'")
                    break
        
        # If still not classified, use default
        if not analysis['subtype']:
            analysis['subtype'] = 'general_financial'
            
        # Check for branch patterns
        branch_patterns = ['branch', 'location', 'center', 'office', 'atm']
        for pattern in branch_patterns:
            if pattern in s1 and pattern not in s2:
                analysis['notes'].append(f"'{s1}' appears to be a branch of '{s2}'")
                break
            elif pattern in s2 and pattern not in s1:
                analysis['notes'].append(f"'{s2}' appears to be a branch of '{s1}'")
                break
        
        return analysis
    
    def _analyze_retail_merchants(self, s1, s2):
        """Specialized analysis for retail merchants"""
        analysis = {
            'industry': 'retail',
            'subtype': None,
            'notes': []
        }
        
        # Detect retail type
        grocery_indicators = ['grocery', 'market', 'supermarket', 'food', 'mart']
        department_indicators = ['department', 'store', 'mall', 'center', 'warehouse']
        specialty_indicators = ['electronics', 'furniture', 'clothing', 'apparel', 'hardware', 'pharmacy']
        
        for indicator in grocery_indicators:
            if indicator in s1 or indicator in s2:
                analysis['subtype'] = 'grocery'
                analysis['notes'].append(f"Grocery retailer detected via keyword: '{indicator}'")
                break
                
        if not analysis['subtype']:
            for indicator in department_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'department'
                    analysis['notes'].append(f"Department store detected via keyword: '{indicator}'")
                    break
        
        if not analysis['subtype']:
            for indicator in specialty_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'specialty'
                    analysis['notes'].append(f"Specialty retailer detected via keyword: '{indicator}'")
                    break
        
        # If still not classified, use default
        if not analysis['subtype']:
            analysis['subtype'] = 'general_retail'
            
        # Check for location patterns
        location_patterns = ['#', 'store', 'location', 'supercenter', 'express']
        for pattern in location_patterns:
            if pattern in s1 and pattern not in s2:
                analysis['notes'].append(f"'{s1}' appears to be a specific location of '{s2}'")
                break
            elif pattern in s2 and pattern not in s1:
                analysis['notes'].append(f"'{s2}' appears to be a specific location of '{s1}'")
                break
        
        return analysis
    
    def _analyze_restaurant_merchants(self, s1, s2):
        """Specialized analysis for restaurant merchants"""
        analysis = {
            'industry': 'restaurant',
            'subtype': None,
            'notes': []
        }
        
        # Detect restaurant type
        fast_food_indicators = ['mcdonald', 'burger', 'taco', 'pizza', 'wendy', 'kfc', 'subway']
        cafe_indicators = ['coffee', 'cafe', 'starbucks', 'tea', 'bakery', 'dunkin']
        dining_indicators = ['restaurant', 'grill', 'kitchen', 'house', 'steakhouse', 'bistro']
        
        for indicator in fast_food_indicators:
            if indicator in s1 or indicator in s2:
                analysis['subtype'] = 'fast_food'
                analysis['notes'].append(f"Fast food restaurant detected via keyword: '{indicator}'")
                break
                
        if not analysis['subtype']:
            for indicator in cafe_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'cafe'
                    analysis['notes'].append(f"Cafe detected via keyword: '{indicator}'")
                    break
        
        if not analysis['subtype']:
            for indicator in dining_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'dining'
                    analysis['notes'].append(f"Dining restaurant detected via keyword: '{indicator}'")
                    break
        
        # If still not classified, use default
        if not analysis['subtype']:
            analysis['subtype'] = 'general_restaurant'
            
        # Check for location patterns
        location_patterns = ['#', 'restaurant', 'location', 'express', 'drive']
        for pattern in location_patterns:
            if pattern in s1 and pattern not in s2:
                analysis['notes'].append(f"'{s1}' appears to be a specific location of '{s2}'")
                break
            elif pattern in s2 and pattern not in s1:
                analysis['notes'].append(f"'{s2}' appears to be a specific location of '{s1}'")
                break
        
        return analysis
    
    def _analyze_hotel_merchants(self, s1, s2):
        """Specialized analysis for hotel merchants"""
        analysis = {
            'industry': 'hotel',
            'subtype': None,
            'notes': []
        }
        
        # Detect hotel type
        luxury_indicators = ['resort', 'spa', 'luxury', 'grand', 'palace', 'ritz']
        budget_indicators = ['inn', 'motel', 'lodge', 'suites', 'stay', 'sleep']
        chain_indicators = ['marriott', 'hilton', 'hyatt', 'holiday inn', 'sheraton', 'westin']
        
        for indicator in luxury_indicators:
            if indicator in s1 or indicator in s2:
                analysis['subtype'] = 'luxury'
                analysis['notes'].append(f"Luxury hotel detected via keyword: '{indicator}'")
                break
                
        if not analysis['subtype']:
            for indicator in budget_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'budget'
                    analysis['notes'].append(f"Budget hotel detected via keyword: '{indicator}'")
                    break
        
        if not analysis['subtype']:
            for indicator in chain_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'chain'
                    analysis['notes'].append(f"Hotel chain detected via keyword: '{indicator}'")
                    break
        
        # If still not classified, use default
        if not analysis['subtype']:
            analysis['subtype'] = 'general_hotel'
            
        # Check for property patterns
        property_patterns = ['hotel', 'resort', 'suites', 'inn', 'by']
        for pattern in property_patterns:
            if pattern in s1 and pattern not in s2:
                analysis['notes'].append(f"'{s1}' appears to be a specific property of '{s2}'")
                break
            elif pattern in s2 and pattern not in s1:
                analysis['notes'].append(f"'{s2}' appears to be a specific property of '{s1}'")
                break
        
        return analysis
    
    def _analyze_gas_station_merchants(self, s1, s2):
        """Specialized analysis for gas station merchants"""
        analysis = {
            'industry': 'gas_station',
            'subtype': None,
            'notes': []
        }
        
        # Detect gas station type
        major_indicators = ['shell', 'exxon', 'mobil', 'bp', 'chevron', 'texaco']
        convenience_indicators = ['7-eleven', 'circle k', 'speedway', 'am/pm', 'quicktrip']
        service_indicators = ['service', 'auto', 'repair', 'maintenance', 'lube']
        
        for indicator in major_indicators:
            if indicator in s1 or indicator in s2:
                analysis['subtype'] = 'major_brand'
                analysis['notes'].append(f"Major gas station brand detected via keyword: '{indicator}'")
                break
                
        if not analysis['subtype']:
            for indicator in convenience_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'convenience'
                    analysis['notes'].append(f"Convenience store gas station detected via keyword: '{indicator}'")
                    break
        
        if not analysis['subtype']:
            for indicator in service_indicators:
                if indicator in s1 or indicator in s2:
                    analysis['subtype'] = 'service'
                    analysis['notes'].append(f"Service station detected via keyword: '{indicator}'")
                    break
        
        # If still not classified, use default
        if not analysis['subtype']:
            analysis['subtype'] = 'general_gas_station'
            
        # Check for location patterns
        location_patterns = ['gas', 'station', 'fuel', 'mart', 'convenience']
        for pattern in location_patterns:
            if pattern in s1 and pattern not in s2:
                analysis['notes'].append(f"'{s1}' appears to be a specific location of '{s2}'")
                break
            elif pattern in s2 and pattern not in s1:
                analysis['notes'].append(f"'{s2}' appears to be a specific location of '{s1}'")
                break
        
        return analysis
    
    def adapt_to_domain(self, examples_df, domain=None):
        """
        Adapt the semantic analyzer to a specific domain using example data
        
        Args:
            examples_df (DataFrame): DataFrame with example merchant name pairs
            domain (str, optional): Domain for adaptation
            
        Returns:
            bool: True if adaptation was successful, False otherwise
        """
        if not isinstance(examples_df, pd.DataFrame) or len(examples_df) < 5:
            logger.warning("Insufficient data for domain adaptation")
            return False
        
        logger.info(f"Adapting semantic analyzer to domain: {domain or 'general'}")
        
        # First, adapt the BERT embedder if available
        if hasattr(self.bert_embedder, 'adapt_to_domain'):
            try:
                self.bert_embedder.adapt_to_domain(examples_df, epochs=5)
            except Exception as e:
                logger.warning(f"Failed to adapt BERT embedder to domain: {e}")
        
        # Build merchant clusters from the examples
        try:
            self._build_merchant_clusters(examples_df, domain)
        except Exception as e:
            logger.warning(f"Failed to build merchant clusters: {e}")
        
        self.adapted_to_domain = True
        return True
    
    def _build_merchant_clusters(self, examples_df, domain=None):
        """
        Build merchant name clusters from example data
        
        Args:
            examples_df (DataFrame): DataFrame with example merchant name pairs
            domain (str, optional): Domain for clustering
        """
        # Extract all unique merchant names
        all_names = []
        
        # Check different column combinations to extract merchant names
        if 'Acronym' in examples_df.columns and 'Full_Name' in examples_df.columns:
            all_names.extend(examples_df['Acronym'].tolist())
            all_names.extend(examples_df['Full_Name'].tolist())
        elif 'input_name' in examples_df.columns and 'matched_name' in examples_df.columns:
            all_names.extend(examples_df['input_name'].tolist())
            all_names.extend(examples_df['matched_name'].tolist())
        else:
            # Just take all string columns as potential merchant names
            for col in examples_df.columns:
                if examples_df[col].dtype == 'object':
                    all_names.extend(examples_df[col].tolist())
        
        # Filter out non-string values and deduplicate
        all_names = [name for name in all_names if isinstance(name, str) and name.strip()]
        unique_names = list(set(all_names))
        
        # Preprocess all names
        processed_names = [self.preprocessor.preprocess(name, domain) for name in unique_names]
        
        # Get embeddings for all names
        try:
            embeddings = self.bert_embedder.encode(processed_names)
            
            # Create clusters using hierarchical clustering
            from sklearn.cluster import AgglomerativeClustering
            
            # Determine optimal number of clusters
            max_clusters = min(20, len(processed_names) // 2)
            if max_clusters < 2:
                max_clusters = 2
                
            clustering = AgglomerativeClustering(
                n_clusters=max_clusters,
                affinity='euclidean',
                linkage='ward'
            )
            
            cluster_labels = clustering.fit_predict(embeddings)
            
            # Store clusters
            self.merchant_clusters = {}
            for i, (name, processed, embedding, label) in enumerate(
                zip(unique_names, processed_names, embeddings, cluster_labels)
            ):
                if label not in self.merchant_clusters:
                    self.merchant_clusters[label] = []
                    self.cluster_embeddings[label] = []
                
                self.merchant_clusters[label].append((name, processed))
                self.cluster_embeddings[label].append(embedding)
            
            logger.info(f"Built {len(self.merchant_clusters)} merchant clusters from {len(unique_names)} names")
            
        except Exception as e:
            logger.warning(f"Failed to create merchant clusters: {e}")
    
    def get_semantically_similar_merchants(self, merchant_name, top_k=5, domain=None):
        """
        Get semantically similar merchants from the adapted clusters
        
        Args:
            merchant_name (str): Merchant name to find similar matches for
            top_k (int): Number of top similar merchants to return
            domain (str, optional): Domain for specialized processing
            
        Returns:
            list: List of tuples with (merchant_name, similarity_score)
        """
        if not self.merchant_clusters:
            logger.warning("No merchant clusters available. Call adapt_to_domain first.")
            return []
            
        # Preprocess the input
        processed_name = self.preprocessor.preprocess(merchant_name, domain)
        
        # Get embedding for the input name
        query_embedding = self.bert_embedder.encode([processed_name])[0]
        
        # Find the closest cluster
        closest_cluster = None
        max_similarity = -1
        
        for label, cluster_embeddings in self.cluster_embeddings.items():
            # Calculate average similarity to cluster
            cluster_similarities = [
                np.dot(query_embedding, emb) / (np.linalg.norm(query_embedding) * np.linalg.norm(emb))
                for emb in cluster_embeddings
            ]
            avg_similarity = np.mean(cluster_similarities)
            
            if avg_similarity > max_similarity:
                max_similarity = avg_similarity
                closest_cluster = label
        
        # If no close cluster found, return empty list
        if closest_cluster is None:
            return []
            
        # Calculate similarity to all merchants in the closest cluster and other high-similarity clusters
        all_similarities = []
        
        # Check the closest cluster and nearby clusters
        clusters_to_check = [closest_cluster]
        
        # Add other potential clusters if they're reasonably similar
        for label in self.merchant_clusters:
            if label != closest_cluster:
                # Calculate similarity to cluster centroid
                centroid = np.mean(self.cluster_embeddings[label], axis=0)
                similarity = np.dot(query_embedding, centroid) / (np.linalg.norm(query_embedding) * np.linalg.norm(centroid))
                
                if similarity > 0.7:  # Check other reasonably similar clusters
                    clusters_to_check.append(label)
        
        # Calculate similarities to all merchants in selected clusters
        for label in clusters_to_check:
            for i, (name, processed) in enumerate(self.merchant_clusters[label]):
                embedding = self.cluster_embeddings[label][i]
                similarity = np.dot(query_embedding, embedding) / (np.linalg.norm(query_embedding) * np.linalg.norm(embedding))
                all_similarities.append((name, similarity))
        
        # Sort by similarity score and return top-k
        all_similarities.sort(key=lambda x: x[1], reverse=True)
        return all_similarities[:top_k]

In [14]:
# Cell 7: Enhanced Merchant Matching System



In [16]:
# 7.1: Core Matcher Class Definition and Initialization

class EnhancedMerchantMatcher:
    """
    Advanced merchant matching system that combines multiple algorithms with machine learning
    for high-accuracy merchant name matching across various domains.
    
    Key features:
    - Multi-algorithm approach combining string similarity, semantic similarity, and pattern matching
    - Domain-specific customization for different industries
    - Machine learning integration for optimal feature weighting
    - Comprehensive explanation of matching decisions
    - Batch processing capabilities for large datasets
    """
    
    def __init__(self, 
                 preprocessor=None, 
                 similarity_algorithms=None,
                 pattern_recognition=None,
                 semantic_analyzer=None,
                 bert_embedder=None,
                 weights=None,
                 thresholds=None):
        """
        Initialize enhanced merchant matcher with configurable components
        
        Args:
            preprocessor: Merchant name preprocessor for text normalization
            similarity_algorithms: Collection of string similarity algorithms
            pattern_recognition: Pattern recognition for merchant naming patterns
            semantic_analyzer: Semantic analysis using language models
            bert_embedder: BERT embedder for semantic similarity
            weights (dict): Custom weights for different matching algorithms
            thresholds (dict): Custom thresholds for match level determination
        """
        # Initialize BERT embedder first since other components depend on it
        self.bert_embedder = bert_embedder or AdvancedBERTEmbedder(
            model_name='sentence-transformers/all-mpnet-base-v2',  # Using high-quality pretrained model
            pooling_strategy='mean',  # Better for merchant name comparisons than CLS token
            cache_size=10000,  # Increased cache for better performance on large datasets
            device=device  # Use GPU if available
        )
        
        # Initialize preprocessing component
        self.preprocessor = preprocessor or MerchantPreprocessor()
        
        # Initialize similarity algorithms
        self.similarity_algorithms = similarity_algorithms or SimilarityAlgorithms(
            preprocessor=self.preprocessor,
            bert_embedder=self.bert_embedder
        )
        
        # Initialize pattern recognition
        self.pattern_recognition = pattern_recognition or PatternRecognition(
            preprocessor=self.preprocessor,
            similarity_algorithms=self.similarity_algorithms
        )
        
        # Initialize semantic analyzer
        self.semantic_analyzer = semantic_analyzer or BertSemanticAnalyzer(
            bert_embedder=self.bert_embedder,
            similarity_algorithms=self.similarity_algorithms,
            preprocessor=self.preprocessor
        )
        
        # Set match thresholds (customizable)
        self.thresholds = thresholds or {
            'high': 0.85,   # Threshold for high confidence matches
            'medium': 0.75, # Threshold for medium confidence matches
            'low': 0.60     # Threshold for low confidence matches
        }
        
        # Set feature weights (customizable)
        self.weights = weights or {
            'string_similarity': 0.20,    # Traditional string matching (Jaro-Winkler)
            'token_set_similarity': 0.15, # Word-level matching regardless of order
            'semantic_similarity': 0.30,  # Deep meaning similarity via BERT
            'pattern_match': 0.25,        # Industry-specific patterns
            'contains_check': 0.05,       # One name contains the other
            'acronym_check': 0.05         # Acronym/abbreviation relationship
        }
        
        # Initialize model attributes for ML-enhanced matching
        self.ml_model = None              # Will hold trained classifier model
        self.feature_names = None         # Feature names for model input
        self.feature_importance = None    # Importance of each feature after training
        self.model_trained = False        # Flag to track if model is ready
        
        # Performance tracking
        self.cache = {}                   # Cache for previously computed matches
        self.cache_hits = 0
        self.cache_size = 1000            # Max cache size
        
        logger.info("Enhanced Merchant Matcher initialized with all components")

In [18]:
# 7.2: Core Matching Algorithm Implementation

def match_merchants(self, s1, s2, domain=None, return_details=False, use_cache=True):
    """
    Match two merchant names using an enhanced multi-algorithm approach
    
    Args:
        s1 (str): First merchant name
        s2 (str): Second merchant name
        domain (str, optional): Domain for specialized matching (e.g., banking, retail)
        return_details (bool): Whether to return detailed match information
        use_cache (bool): Whether to use and update the cache
        
    Returns:
        float or dict: Match score (0-1) or detailed match information dictionary
    """
    # Check cache first if enabled
    if use_cache:
        cache_key = f"{s1}|{s2}|{domain}"
        if cache_key in self.cache:
            self.cache_hits += 1
            return self.cache[cache_key] if not return_details else self.cache[cache_key + "_details"]
    
    # Check for empty inputs
    if not s1 or not s2:
        result = 0.0
        details = {
            'match_score': 0.0,
            'match_level': 'no_match',
            'explanation': 'Empty input provided'
        }
        if use_cache:
            self._update_cache(s1, s2, domain, result, details)
        return details if return_details else result
    
    # Preprocess inputs based on domain-specific rules
    s1_clean, s2_clean = self.preprocessor.preprocess_pair(s1, s2, domain)
    
    # Check if strings are empty after preprocessing
    if not s1_clean or not s2_clean:
        result = 0.0
        details = {
            'match_score': 0.0,
            'match_level': 'no_match',
            'explanation': 'Empty strings after preprocessing'
        }
        if use_cache:
            self._update_cache(s1, s2, domain, result, details)
        return details if return_details else result
    
    # Exact match fast path (significant performance optimization)
    if s1_clean == s2_clean:
        result = 1.0
        details = {
            'match_score': 1.0,
            'match_level': 'exact_match',
            'explanation': 'Exact match after preprocessing'
        }
        if use_cache:
            self._update_cache(s1, s2, domain, result, details)
        return details if return_details else result
    
    # Length check optimization - if lengths are too different, they're unlikely to match
    len_ratio = min(len(s1_clean), len(s2_clean)) / max(len(s1_clean), len(s2_clean))
    if len_ratio < 0.3:  # Very different lengths suggest different entities
        # Quick check for acronyms before rejecting
        acronym_score = self.similarity_algorithms.specialized_acronym_similarity(s1_clean, s2_clean, domain)
        if acronym_score < 0.8:  # Not a strong acronym relationship
            result = 0.2 * acronym_score  # Low score with some influence from acronym check
            details = {
                'match_score': result,
                'match_level': 'no_match',
                'explanation': 'Very different lengths and not a clear acronym relationship'
            }
            if use_cache:
                self._update_cache(s1, s2, domain, result, details)
            return details if return_details else result
    
    # Calculate individual similarity scores with optimized order (fastest first)
    # First, calculate string similarity scores (faster than semantic)
    string_sim = self.similarity_algorithms.jaro_winkler_similarity(s1_clean, s2_clean)
    token_set_sim = self.similarity_algorithms.token_set_ratio(s1_clean, s2_clean)
    contains_score = self.similarity_algorithms.contains_ratio(s1_clean, s2_clean)
    acronym_score = self.similarity_algorithms.specialized_acronym_similarity(s1_clean, s2_clean, domain)
    
    # Early decision based on string metrics (optimization for clear non-matches)
    if max(string_sim, token_set_sim, contains_score, acronym_score) < 0.4:
        result = max(string_sim, token_set_sim, contains_score, acronym_score)
        details = {
            'match_score': result,
            'match_level': 'no_match',
            'explanation': 'Low similarity across all string metrics'
        }
        if use_cache:
            self._update_cache(s1, s2, domain, result, details)
        return details if return_details else result
    
    # Early decision based on string metrics (optimization for very high confidence matches)
    high_string_match = (string_sim > 0.95 or token_set_sim > 0.95 or 
                         contains_score > 0.95 or acronym_score > 0.95)
    if high_string_match:
        result = max(string_sim, token_set_sim, contains_score, acronym_score)
        details = {
            'match_score': result,
            'match_level': 'high_match',
            'explanation': 'Very high string similarity'
        }
        if use_cache:
            self._update_cache(s1, s2, domain, result, details)
        return details if return_details else result
    
    # Only compute more expensive pattern and semantic metrics if needed
    patterns = self.pattern_recognition.detect_merchant_patterns(s1_clean, s2_clean, domain)
    pattern_score = max([p.get('confidence', 0) for p in patterns.values()]) if patterns else 0.0
    
    semantic_sim = self.semantic_analyzer.analyze_semantic_match(
        s1_clean, s2_clean, domain
    )['semantic_similarity']
    
    # Collect all features for either ML model or weighted scoring
    features = {
        'string_similarity': string_sim,
        'token_set_similarity': token_set_sim,
        'semantic_similarity': semantic_sim,
        'contains_score': contains_score,
        'acronym_score': acronym_score,
        'pattern_score': pattern_score,
        # Add contextual features for more robust matching
        'length_ratio': len_ratio,
        'word_count_ratio': min(len(s1_clean.split()), len(s2_clean.split())) / 
                           max(len(s1_clean.split()), len(s2_clean.split())) 
                           if max(len(s1_clean.split()), len(s2_clean.split())) > 0 else 0,
        'common_prefix_length': self._common_prefix_length(s1_clean, s2_clean) / 
                               min(len(s1_clean), len(s2_clean)) if min(len(s1_clean), len(s2_clean)) > 0 else 0,
    }
    
    # Determine score using ML model if trained, otherwise use weighted average
    if self.model_trained and self.ml_model is not None:
        try:
            # Prepare features for the model
            feature_vector = [features[feat] for feat in self.feature_names]
            feature_vector = np.array(feature_vector).reshape(1, -1)
            
            # Get model prediction (confidence of match)
            score = float(self.ml_model.predict_proba(feature_vector)[0, 1])
            explanation = "Score determined by machine learning model"
        except Exception as e:
            logger.warning(f"Error using ML model for prediction: {e}")
            # Fall back to weighted average
            score = self._calculate_weighted_score(features)
            explanation = "Score determined by weighted algorithm average (ML model failed)"
    else:
        # Use weighted average of different similarity measures
        score = self._calculate_weighted_score(features)
        explanation = "Score determined by weighted algorithm average"
    
    # Determine match level
    match_level = self._determine_match_level(score)
    
    # Generate comprehensive explanation
    full_explanation = self._generate_explanation(
        s1, s2, s1_clean, s2_clean, 
        features, patterns, match_level, domain,
        explanation
    )
    
    # Prepare result
    result = score
    details = {
        'match_score': score,
        'match_level': match_level,
        'explanation': full_explanation,
        'features': features,
        'patterns': patterns,
        'processed_s1': s1_clean,
        'processed_s2': s2_clean
    }
    
    # Update cache
    if use_cache:
        self._update_cache(s1, s2, domain, result, details)
    
    return details if return_details else result

def _update_cache(self, s1, s2, domain, result, details):
    """Update the cache with match results"""
    cache_key = f"{s1}|{s2}|{domain}"
    # Also cache the reverse pair to maximize cache hits
    reverse_key = f"{s2}|{s1}|{domain}"
    
    # Enforce cache size limit using simple LRU strategy
    if len(self.cache) >= self.cache_size:
        # Remove oldest item (first key in dict)
        self.cache.pop(next(iter(self.cache)))
    
    self.cache[cache_key] = result
    self.cache[cache_key + "_details"] = details
    self.cache[reverse_key] = result
    self.cache[reverse_key + "_details"] = details

def _common_prefix_length(self, s1, s2):
    """Calculate the length of the common prefix between two strings"""
    common_len = 0
    for c1, c2 in zip(s1, s2):
        if c1 == c2:
            common_len += 1
        else:
            break
    return common_len

In [20]:
# 7.3: Score Calculation and Match Level Determination

def _calculate_weighted_score(self, features):
    """
    Calculate weighted score based on individual feature scores
    with adaptive weighting based on feature values
    
    Args:
        features (dict): Dictionary of feature names and values
        
    Returns:
        float: Weighted score between 0 and 1
    """
    weighted_sum = 0.0
    total_weight = 0.0
    
    # Dynamic weights adjustment based on feature values
    adjusted_weights = self.weights.copy()
    
    # Boost importance of high-value features
    for feature_name, value in features.items():
        if feature_name in adjusted_weights and value > 0.9:
            # Boost weights of very strong signals
            adjusted_weights[feature_name] *= 1.5
        elif feature_name in adjusted_weights and value < 0.3:
            # Reduce weights of very weak signals
            adjusted_weights[feature_name] *= 0.5
    
    # Special case: if acronym score is high, boost its importance
    if features.get('acronym_score', 0) > 0.8:
        adjusted_weights['acronym_check'] = max(
            adjusted_weights.get('acronym_check', 0) * 2.0,
            0.3  # Ensure significant weight
        )
    
    # Apply all weights
    for feature_name, weight in adjusted_weights.items():
        if feature_name in features:
            weighted_sum += features[feature_name] * weight
            total_weight += weight
    
    # Add other features with small weights
    for feature_name, value in features.items():
        if feature_name not in adjusted_weights:
            weighted_sum += value * 0.05
            total_weight += 0.05
    
    # Normalize
    normalized_score = weighted_sum / total_weight if total_weight > 0 else 0.0
    
    # Apply sigmoid scaling for better distribution of scores
    # This helps separate clear matches from borderline cases
    return self._sigmoid_scale(normalized_score)

def _sigmoid_scale(self, score, steepness=10, midpoint=0.5):
    """
    Apply sigmoid scaling to concentrate scores at the extremes
    
    Args:
        score (float): Input score between 0 and 1
        steepness (float): Controls how steep the sigmoid curve is
        midpoint (float): Point around which to center the sigmoid
        
    Returns:
        float: Scaled score between 0 and 1
    """
    # Skip scaling for extreme values
    if score > 0.95:
        return 1.0
    if score < 0.05:
        return 0.0
        
    # Apply sigmoid transformation
    scaled = 1 / (1 + np.exp(-steepness * (score - midpoint)))
    
    # Rescale from [0.27, 0.73] to [0, 1]
    min_sigmoid = 1 / (1 + np.exp(-steepness * (0 - midpoint)))
    max_sigmoid = 1 / (1 + np.exp(-steepness * (1 - midpoint)))
    
    return (scaled - min_sigmoid) / (max_sigmoid - min_sigmoid)

def _determine_match_level(self, score):
    """
    Determine the match level based on the score
    
    Args:
        score (float): Match score between 0 and 1
        
    Returns:
        str: Match level category
    """
    if score >= self.thresholds['high']:
        return 'high_match'
    elif score >= self.thresholds['medium']:
        return 'medium_match'
    elif score >= self.thresholds['low']:
        return 'low_match'
    else:
        return 'no_match'

In [22]:
# 7.4: Explanation Generation

def _generate_explanation(self, s1, s2, s1_clean, s2_clean, features, patterns, match_level, domain, base_explanation):
    """
    Generate a comprehensive human-readable explanation of the match
    
    Args:
        s1 (str): Original first merchant name
        s2 (str): Original second merchant name
        s1_clean (str): Preprocessed first merchant name
        s2_clean (str): Preprocessed second merchant name
        features (dict): Feature scores dictionary
        patterns (dict): Detected patterns dictionary
        match_level (str): Determined match level
        domain (str): Domain context
        base_explanation (str): Base explanation about scoring method
        
    Returns:
        str: Detailed explanation of the match
    """
    explanations = [base_explanation]
    
    # Add preprocessed form explanation if different from input
    if s1 != s1_clean or s2 != s2_clean:
        if s1 != s1_clean:
            explanations.append(f"Preprocessed '{s1}' to '{s1_clean}'")
        if s2 != s2_clean:
            explanations.append(f"Preprocessed '{s2}' to '{s2_clean}'")
    
    # Add key feature explanations with formatting based on strength
    string_sim = features['string_similarity']
    if string_sim > 0.9:
        explanations.append(f"STRONG string similarity: {string_sim:.2f}")
    elif string_sim > 0.7:
        explanations.append(f"GOOD string similarity: {string_sim:.2f}")
    else:
        explanations.append(f"String similarity: {string_sim:.2f}")
    
    # Add token similarity explanation
    token_sim = features['token_set_similarity']
    if token_sim > string_sim + 0.2:
        explanations.append(f"Word-level matching ({token_sim:.2f}) much stronger than character-level, " +
                           "suggesting possible word reordering or extra words")
    
    # Add semantic similarity explanation
    semantic_sim = features['semantic_similarity']
    if semantic_sim > 0.9:
        explanations.append(f"STRONG semantic similarity: {semantic_sim:.2f}")
    elif semantic_sim > 0.7:
        explanations.append(f"GOOD semantic similarity: {semantic_sim:.2f}")
    elif semantic_sim > string_sim + 0.2:
        explanations.append(f"Semantic similarity ({semantic_sim:.2f}) much stronger than string similarity, " +
                           "suggesting conceptually similar entities with different naming")
    
    # Add pattern match explanations
    if patterns:
        for pattern_type, pattern_info in patterns.items():
            confidence = pattern_info.get('confidence', 0)
            if confidence > 0.8:
                explanations.append(f"STRONG pattern match: {pattern_type} with confidence {confidence:.2f}")
            elif confidence > 0.6:
                explanations.append(f"Pattern match: {pattern_type} with confidence {confidence:.2f}")
    
    # Add specific feature explanations
    contains_score = features['contains_score']
    if contains_score > 0.9:
        explanations.append("One name fully contains the other")
    elif contains_score > 0.7:
        explanations.append("One name partially contains the other")
    
    acronym_score = features['acronym_score'] 
    if acronym_score > 0.9:
        explanations.append("Strong acronym relationship detected")
    elif acronym_score > 0.7:
        explanations.append("Possible acronym relationship detected")
    
    # Add length difference explanation if significant
    len_ratio = features.get('length_ratio', 0)
    if len_ratio < 0.5:
        explanations.append(f"Significant length difference (ratio: {len_ratio:.2f})")
    
    # Add domain-specific explanation if available
    if domain:
        explanations.append(f"Analysis performed in {domain} industry context")
    
    # Add match level explanation with appropriate emphasis
    if match_level == 'high_match':
        explanations.append("✓ HIGH CONFIDENCE MATCH: Names very likely refer to the same merchant")
    elif match_level == 'medium_match':
        explanations.append("MEDIUM CONFIDENCE MATCH: Names probably refer to the same merchant")
    elif match_level == 'low_match':
        explanations.append("LOW CONFIDENCE MATCH: Names might refer to the same merchant")
    else:
        explanations.append("✗ NO MATCH: Names likely refer to different merchants")
    
    # Add examples of similar merchants if using ML model and confidence is medium
    if self.model_trained and match_level == 'medium_match':
        explanations.append("This match is based on patterns observed in the training data.")
    
    return "\n".join(explanations)

In [24]:
# 7.5: Batch Processing with Parallelization

def match_merchant_batch(self, pairs_df, s1_col='s1', s2_col='s2', 
                         domain_col=None, return_details=False, 
                         batch_size=1000, use_multiprocessing=True,
                         n_jobs=-1):
    """
    Match a batch of merchant name pairs with parallel processing
    
    Args:
        pairs_df (DataFrame): DataFrame with merchant name pairs
        s1_col (str): Column name for first merchant names
        s2_col (str): Column name for second merchant names
        domain_col (str, optional): Column name for domain information
        return_details (bool): Whether to return detailed match info
        batch_size (int): Size of batches for processing
        use_multiprocessing (bool): Whether to use multiprocessing
        n_jobs (int): Number of jobs for parallel processing (-1 for all cores)
        
    Returns:
        DataFrame: Original DataFrame with added match scores and levels
    """
    import time
    start_time = time.time()
    
    # Copy input DataFrame
    result_df = pairs_df.copy()
    
    # Add match columns
    result_df['match_score'] = 0.0
    result_df['match_level'] = 'no_match'
    
    if return_details:
        result_df['match_details'] = None
    
    # Check for empty DataFrame
    if len(result_df) == 0:
        logger.warning("Empty DataFrame provided for batch matching")
        return result_df
    
    # Convert to records for faster access during parallel processing
    records = result_df.to_dict('records')
    total_records = len(records)
    
    # Function to process a single record
    def process_record(record):
        s1 = record[s1_col]
        s2 = record[s2_col]
        
        # Skip invalid entries
        if not isinstance(s1, str) or not isinstance(s2, str):
            return {
                'match_score': 0.0,
                'match_level': 'no_match',
                'match_details': None if return_details else None
            }
        
        # Get domain if column provided
        domain = record[domain_col] if domain_col and domain_col in record else None
        
        # Match merchants
        match_result = self.match_merchants(s1, s2, domain, return_details=return_details)
        
        if return_details:
            # Return score, level and details
            return {
                'match_score': match_result['match_score'],
                'match_level': match_result['match_level'],
                'match_details': match_result
            }
        else:
            # Just return score and derived level
            score = match_result
            level = self._determine_match_level(score)
            return {
                'match_score': score,
                'match_level': level
            }
    
    # Process in parallel if requested
    results = []
    
    if use_multiprocessing and total_records > 100:
        try:
            from joblib import Parallel, delayed
            
            # Determine number of jobs
            if n_jobs == -1:
                import multiprocessing
                n_jobs = multiprocessing.cpu_count()
            
            # Process in batches to save memory
            for i in range(0, total_records, batch_size):
                batch = records[i:min(i+batch_size, total_records)]
                logger.info(f"Processing batch {i//batch_size + 1}/{(total_records-1)//batch_size + 1} " +
                           f"({len(batch)} records)")
                
                batch_results = Parallel(n_jobs=n_jobs)(
                    delayed(process_record)(record) for record in batch
                )
                results.extend(batch_results)
                
        except Exception as e:
            logger.warning(f"Parallel processing failed with error: {e}. Falling back to sequential processing.")
            results = [process_record(record) for record in records]
    else:
        # Sequential processing
        for i, record in enumerate(records):
            if i % 100 == 0:
                logger.info(f"Processing record {i+1}/{total_records}")
            results.append(process_record(record))
    
    # Update the DataFrame with results
    for i, result in enumerate(results):
        for key, value in result.items():
            if key in result_df.columns:
                result_df.iloc[i, result_df.columns.get_loc(key)] = value
    
    end_time = time.time()
    processing_time = end_time - start_time
    records_per_second = total_records / processing_time if processing_time > 0 else 0
    
    logger.info(f"Batch processing completed in {processing_time:.2f} seconds " +
               f"({records_per_second:.2f} records/second)")
    logger.info(f"Cache performance: {self.cache_hits} hits")
    
    return result_df

In [26]:
# 7.6: Advanced Machine Learning Integration

def train_model(self, training_data, target_column='is_match', 
                test_size=0.2, random_state=42, use_advanced_model=True,
                optimize_hyperparams=False):
    """
    Train a machine learning model to improve matching accuracy,
    with options for advanced models and hyperparameter optimization
    
    Args:
        training_data (DataFrame): DataFrame with merchant name pairs and match labels
        target_column (str): Column name containing match labels (1=match, 0=no match)
        test_size (float): Proportion of data to use for testing
        random_state (int): Random seed for reproducibility
        use_advanced_model (bool): Whether to use a more advanced ensemble model
        optimize_hyperparams (bool): Whether to optimize hyperparameters
        
    Returns:
        dict: Training results and metrics
    """
    try:
        # Import required libraries
        import xgboost as xgb
        from sklearn.model_selection import train_test_split, GridSearchCV
        from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
        from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
        from sklearn.preprocessing import StandardScaler
    except ImportError:
        logger.error("Required packages not installed. Install xgboost and scikit-learn.")
        return {'success': False, 'error': 'Required packages not installed'}
    
    try:
        # Check if training data has required columns
        required_cols = ['s1', 's2', target_column]
        if not all(col in training_data.columns for col in required_cols):
            return {
                'success': False, 
                'error': f"Training data missing required columns: {required_cols}"
            }
        
        # Extract features for each merchant pair
        logger.info(f"Extracting features from training data with {len(training_data)} rows...")
        features = []
        labels = []
        
        domain_col = 'domain' if 'domain' in training_data.columns else None
        
        # Use more informative progress reporting for large datasets
        total_rows = len(training_data)
        for i, row in enumerate(training_data.iterrows()):
            idx, row_data = row
            s1 = row_data['s1']
            s2 = row_data['s2']
            domain = row_data[domain_col] if domain_col else None
            
            # Progress reporting
            if i % max(1, total_rows // 20) == 0:
                logger.info(f"Processing row {i+1}/{total_rows} ({(i+1)/total_rows*100:.1f}%)")
            
            # Skip rows with empty values
            if not isinstance(s1, str) or not isinstance(s2, str):
                continue
            
            # Extract features - use detailed version to get all features
            match_details = self.match_merchants(s1, s2, domain, return_details=True)
            
            # Get feature vector
            feature_dict = match_details['features']
            
            # Add derived features for better model performance
            feature_dict['string_token_diff'] = abs(feature_dict['string_similarity'] - 
                                                  feature_dict.get('token_set_similarity', 0))
            feature_dict['string_semantic_diff'] = abs(feature_dict['string_similarity'] - 
                                                     feature_dict.get('semantic_similarity', 0))
            feature_dict['has_pattern'] = 1.0 if match_details.get('patterns', {}) else 0.0
            
            features.append(feature_dict)
            labels.append(int(row_data[target_column]))
        
        # Convert features to DataFrame for better handling
        feature_df = pd.DataFrame(features)
        
        # Store feature names for later use
        self.feature_names = list(feature_df.columns)
        
        # Convert to numpy arrays
        X = feature_df.values
        y = np.array(labels)
        
        # Scale features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X_scaled, y, test_size=test_size, random_state=random_state, stratify=y
        )
        
        # Report class distribution
        train_pos = sum(y_train)
        train_neg = len(y_train) - train_pos
        test_pos = sum(y_test)
        test_neg = len(y_test) - test_pos
        
        logger.info(f"Training set: {len(y_train)} samples, {train_pos} positive, {train_neg} negative " +
                   f"({train_pos/len(y_train)*100:.1f}% positive)")
        logger.info(f"Test set: {len(y_test)} samples, {test_pos} positive, {test_neg} negative " +
                   f"({test_pos/len(y_test)*100:.1f}% positive)")
        
        # Select model
        if use_advanced_model:
            # Use a more advanced ensemble approach with multiple models
            models = {
                'xgboost': xgb.XGBClassifier(
                    objective='binary:logistic',
                    n_estimators=100,
                    max_depth=5,
                    learning_rate=0.1,
                    subsample=0.8,
                    colsample_bytree=0.8,
                    random_state=random_state
                ),
                'gradient_boosting': GradientBoostingClassifier(
                    n_estimators=100,
                    learning_rate=0.1,
                    max_depth=5,
                    random_state=random_state
                ),
                'random_forest': RandomForestClassifier(
                    n_estimators=100,
                    max_depth=10,
                    random_state=random_state
                )
            }
            
            # Train and evaluate all models
            model_performances = {}
            for name, model in models.items():
                logger.info(f"Training {name} model...")
                
                # Hyperparameter optimization if requested
                if optimize_hyperparams:
                    if name == 'xgboost':
                        param_grid = {
                            'n_estimators': [50, 100, 200],
                            'max_depth': [3, 5, 7],
                            'learning_rate': [0.01, 0.1, 0.2],
                            'subsample': [0.8, 1.0],
                            'colsample_bytree': [0.8, 1.0]
                        }
                    elif name == 'gradient_boosting':
                        param_grid = {
                            'n_estimators': [50, 100, 200],
                            'max_depth': [3, 5, 7],
                            'learning_rate': [0.01, 0.1, 0.2],
                            'subsample': [0.8, 1.0]
                        }
                    elif name == 'random_forest':
                        param_grid = {
                            'n_estimators': [50, 100, 200],
                            'max_depth': [5, 10, 15],
                            'min_samples_split': [2, 5, 10]
                        }
                    
                    logger.info(f"Optimizing hyperparameters for {name}...")
                    grid_search = GridSearchCV(
                        model, param_grid, cv=3, scoring='roc_auc', n_jobs=-1
                    )
                    grid_search.fit(X_train, y_train)
                    model = grid_search.best_estimator_
                    logger.info(f"Best parameters for {name}: {grid_search.best_params_}")
                else:
                    # Train with default parameters
                    model.fit(X_train, y_train)
                
                # Evaluate
                y_pred = model.predict(X_test)
                y_pred_proba = model.predict_proba(X_test)[:, 1]
                
                # Calculate metrics
                auc_score = roc_auc_score(y_test, y_pred_proba)
                classification_metrics = classification_report(y_test, y_pred, output_dict=True)
                conf_matrix = confusion_matrix(y_test, y_pred)
                
                logger.info(f"{name} model - AUC: {auc_score:.4f}, " +
                           f"Accuracy: {classification_metrics['accuracy']:.4f}")
                
                # Store performance
                model_performances[name] = {
                    'model': model,
                    'auc': auc_score,
                    'metrics': classification_metrics,
                    'confusion_matrix': conf_matrix
                }
            
            # Select best model based on AUC
            best_model_name = max(model_performances, key=lambda k: model_performances[k]['auc'])
            best_model_data = model_performances[best_model_name]
            self.ml_model = best_model_data['model']
            
            logger.info(f"Selected {best_model_name} as best model with AUC: {best_model_data['auc']:.4f}")
            
            # Get feature importance
            if hasattr(self.ml_model, 'feature_importances_'):
                feature_importance = dict(zip(self.feature_names, self.ml_model.feature_importances_))
                # Sort by importance
                feature_importance = {k: v for k, v in sorted(
                    feature_importance.items(), key=lambda item: item[1], reverse=True
                )}
                self.feature_importance = feature_importance
            else:
                self.feature_importance = {feature: 1.0/len(self.feature_names) 
                                         for feature in self.feature_names}
            
            # Update weights based on feature importance
            self._update_weights_from_model()
            
            # Set flag to indicate model is trained
            self.model_trained = True
            
            # Return results
            return {
                'success': True,
                'model_name': best_model_name,
                'auc_score': best_model_data['auc'],
                'classification_metrics': best_model_data['metrics'],
                'confusion_matrix': best_model_data['confusion_matrix'],
                'feature_importance': self.feature_importance,
                'feature_names': self.feature_names,
                'model_performances': {k: {
                    'auc': v['auc'], 
                    'metrics': v['metrics']
                } for k, v in model_performances.items()}
            }
            
        else:
            # Use simpler XGBoost model
            logger.info("Training XGBoost model...")
            self.ml_model = xgb.XGBClassifier(
                objective='binary:logistic',
                n_estimators=100,
                max_depth=5,
                learning_rate=0.1,
                subsample=0.8,
                colsample_bytree=0.8,
                random_state=random_state
            )
            
            # Train model
            self.ml_model.fit(X_train, y_train)
            
            # Evaluate on test set
            y_pred = self.ml_model.predict(X_test)
            y_pred_proba = self.ml_model.predict_proba(X_test)[:, 1]
            
            # Calculate metrics
            classification_metrics = classification_report(y_test, y_pred, output_dict=True)
            auc_score = roc_auc_score(y_test, y_pred_proba)
            conf_matrix = confusion_matrix(y_test, y_pred)
            
            # Get feature importance
            feature_importance = dict(zip(self.feature_names, self.ml_model.feature_importances_))
            # Sort by importance
            feature_importance = {k: v for k, v in sorted(
                feature_importance.items(), key=lambda item: item[1], reverse=True
            )}
            self.feature_importance = feature_importance
            
            # Update weights based on feature importance
            self._update_weights_from_model()
            
            # Set flag to indicate model is trained
            self.model_trained = True
            
            logger.info(f"Model trained successfully. AUC: {auc_score:.4f}")
            
            return {
                'success': True,
                'model_name': 'xgboost',
                'auc_score': auc_score,
                'classification_metrics': classification_metrics,
                'confusion_matrix': conf_matrix,
                'feature_importance': self.feature_importance,
                'feature_names': self.feature_names
            }
            
    except Exception as e:
        logger.error(f"Error training model: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return {'success': False, 'error': str(e)}

def _update_weights_from_model(self):
    """Update weights based on learned feature importance"""
    if not self.feature_importance:
        return
    
    # Map feature importance to weights
    importance_to_weight = {
        'string_similarity': 'string_similarity',
        'token_set_similarity': 'token_set_similarity',
        'semantic_similarity': 'semantic_similarity',
        'pattern_score': 'pattern_match',
        'contains_score': 'contains_check',
        'acronym_score': 'acronym_check'
    }
    
    # Calculate new weights
    new_weights = {}
    for feature, importance in self.feature_importance.items():
        if feature in importance_to_weight:
            weight_key = importance_to_weight[feature]
            new_weights[weight_key] = importance
    
    # Normalize weights
    total = sum(new_weights.values())
    if total > 0:
        new_weights = {k: v/total for k, v in new_weights.items()}
        
        # Update weights that exist in both
        for k in set(self.weights.keys()) & set(new_weights.keys()):
            # Blend original and learned weights (70% learned, 30% original)
            self.weights[k] = 0.7 * new_weights[k] + 0.3 * self.weights[k]
        
        logger.info(f"Updated weights based on model importance: {self.weights}")

In [28]:
# 7.7: Advanced Finding and Threshold Tuning

def find_merchant_matches(self, query, candidates, domain=None, top_k=5, 
                          threshold=0.6, return_details=False, use_cache=True):
    """
    Find best matches for a query merchant name from a list of candidates,
    with optimized search for large candidate lists
    
    Args:
        query (str): Query merchant name
        candidates (list): List of candidate merchant names
        domain (str, optional): Domain for specialized matching
        top_k (int): Number of top matches to return
        threshold (float): Minimum score threshold
        return_details (bool): Whether to return detailed match info
        use_cache (bool): Whether to use the match cache
        
    Returns:
        list: List of tuples with match information
    """
    # Check for empty query
    if not query:
        return []
    
    # Check for empty candidates list
    if not candidates:
        return []
    
    # Preprocess query
    query_clean = self.preprocessor.preprocess(query, domain)
    
    # Check if query is empty after preprocessing
    if not query_clean:
        return []
    
    # For very large candidate lists, use optimization strategies
    if len(candidates) > 1000 and self.bert_embedder.initialized:
        return self._find_merchant_matches_optimized(
            query, query_clean, candidates, domain, top_k, threshold, return_details
        )
    
    # Calculate match scores for all candidates
    matches = []
    
    for candidate in candidates:
        # Skip empty candidates
        if not candidate:
            continue
        
        # Match with the query
        if return_details:
            result = self.match_merchants(query_clean, candidate, domain, return_details=True, use_cache=use_cache)
            score = result['match_score']
            
            # Add to matches if above threshold
            if score >= threshold:
                matches.append((candidate, score, result['match_level'], result))
        else:
            score = self.match_merchants(query_clean, candidate, domain, return_details=False, use_cache=use_cache)
            
            # Add to matches if above threshold
            if score >= threshold:
                level = self._determine_match_level(score)
                matches.append((candidate, score, level))
    
    # Sort by score descending and take top-k
    matches.sort(key=lambda x: x[1], reverse=True)
    return matches[:top_k]

def _find_merchant_matches_optimized(self, query, query_clean, candidates, 
                                     domain, top_k, threshold, return_details):
    """
    Optimized version of find_merchant_matches for large candidate lists
    using semantic embeddings for pre-filtering
    
    Args:
        query (str): Original query merchant name
        query_clean (str): Preprocessed query
        candidates (list): List of candidate merchant names
        domain (str): Domain for specialized matching
        top_k (int): Number of top matches to return
        threshold (float): Minimum score threshold
        return_details (bool): Whether to return detailed match info
        
    Returns:
        list: List of tuples with match information
    """
    import time
    start_time = time.time()
    
    # Step 1: Preprocess all candidates
    logger.info(f"Preprocessing {len(candidates)} candidates...")
    candidates_clean = [self.preprocessor.preprocess(c, domain) for c in candidates]
    
    # Filter out empty candidates after preprocessing
    valid_indices = [i for i, c in enumerate(candidates_clean) if c]
    valid_candidates = [candidates[i] for i in valid_indices]
    valid_candidates_clean = [candidates_clean[i] for i in valid_indices]
    
    # Step 2: Use semantic embeddings for initial filtering
    logger.info("Computing semantic embeddings for candidates...")
    try:
        # Get query embedding
        query_embedding = self.bert_embedder.encode([query_clean])[0]
        
        # Get candidate embeddings in batches
        batch_size = 1000
        all_candidate_embeddings = []
        
        for i in range(0, len(valid_candidates_clean), batch_size):
            batch = valid_candidates_clean[i:i+batch_size]
            batch_embeddings = self.bert_embedder.encode(batch)
            all_candidate_embeddings.append(batch_embeddings)
        
        candidate_embeddings = np.vstack(all_candidate_embeddings)
        
        # Calculate cosine similarities
        # Normalize query embedding
        query_embedding_norm = query_embedding / np.linalg.norm(query_embedding)
        
        # Normalize candidate embeddings
        candidate_norms = np.linalg.norm(candidate_embeddings, axis=1, keepdims=True)
        normalized_candidates = candidate_embeddings / np.maximum(candidate_norms, 1e-10)
        
        # Calculate similarities
        similarities = np.dot(normalized_candidates, query_embedding_norm)
        
        # Step 3: Select candidates with similarity above relaxed threshold
        # Use a more relaxed threshold for pre-filtering to avoid missing matches
        relaxed_threshold = max(0.4, threshold - 0.2)
        prefilter_indices = np.where(similarities >= relaxed_threshold)[0]
        
        logger.info(f"Pre-filtered to {len(prefilter_indices)} candidates in {time.time() - start_time:.2f} seconds")
        
        # If too few candidates pass the filter, take the top candidates
        if len(prefilter_indices) < top_k * 2:
            # Take at least top_k*2 candidates
            top_indices = np.argsort(similarities)[::-1][:top_k * 2]
            prefilter_indices = top_indices
        
        # Step 4: Run full matching algorithm on pre-filtered candidates
        prefiltered_candidates = [valid_candidates[i] for i in prefilter_indices]
        
        # Use the standard matching for the reduced candidate set
        standard_matches = []
        
        for candidate in prefiltered_candidates:
            if return_details:
                result = self.match_merchants(query_clean, candidate, domain, return_details=True)
                score = result['match_score']
                
                if score >= threshold:
                    standard_matches.append((candidate, score, result['match_level'], result))
            else:
                score = self.match_merchants(query_clean, candidate, domain, return_details=False)
                
                if score >= threshold:
                    level = self._determine_match_level(score)
                    standard_matches.append((candidate, score, level))
        
        # Sort by score and take top-k
        standard_matches.sort(key=lambda x: x[1], reverse=True)
        
        logger.info(f"Found {len(standard_matches)} matches above threshold {threshold} " +
                   f"in {time.time() - start_time:.2f} seconds")
        
        return standard_matches[:top_k]
        
    except Exception as e:
        logger.warning(f"Optimized matching failed with error: {e}. Falling back to standard approach.")
        # Fall back to standard approach
        return self.find_merchant_matches(
            query, candidates, domain, top_k, threshold, return_details, use_cache=True
        )

def tune_thresholds(self, validation_data, target_column='is_match', domain_col=None):
    """
    Tune matching thresholds based on validation data to optimize
    precision/recall tradeoff
    
    Args:
        validation_data (DataFrame): DataFrame with merchant pairs and ground truth
        target_column (str): Column containing match labels (1=match, 0=no match)
        domain_col (str, optional): Column containing domain information
        
    Returns:
        dict: Optimized thresholds and metrics
    """
    try:
        from sklearn.metrics import precision_recall_curve, f1_score, roc_curve, auc, confusion_matrix
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("Required packages not installed. Install scikit-learn and matplotlib.")
        return {'success': False, 'error': 'Required packages not installed'}
    
    try:
        # Check if validation data has required columns
        required_cols = ['s1', 's2', target_column]
        if not all(col in validation_data.columns for col in required_cols):
            return {
                'success': False, 
                'error': f"Validation data missing required columns: {required_cols}"
            }
        
        # Get scores for all pairs
        logger.info(f"Calculating match scores for {len(validation_data)} validation pairs...")
        
        y_true = []
        y_scores = []
        domains = []
        
        for i, row in validation_data.iterrows():
            s1 = row['s1']
            s2 = row['s2']
            
            # Skip rows with empty values
            if not isinstance(s1, str) or not isinstance(s2, str):
                continue
            
            # Get domain if column provided
            domain = row[domain_col] if domain_col and domain_col in validation_data.columns else None
            domains.append(domain)
            
            # Get ground truth
            true_label = int(row[target_column])
            y_true.append(true_label)
            
            # Get score
            score = self.match_merchants(s1, s2, domain, return_details=False)
            y_scores.append(score)
        
        # Convert to numpy arrays
        y_true = np.array(y_true)
        y_scores = np.array(y_scores)
        
        # Calculate precision-recall curve
        precision, recall, thresholds_pr = precision_recall_curve(y_true, y_scores)
        
        # Find threshold that maximizes F1 score
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
        best_f1_idx = np.argmax(f1_scores[:-1])  # Exclude the last point which has no threshold
        best_f1_threshold = thresholds_pr[best_f1_idx]
        best_f1 = f1_scores[best_f1_idx]
        
        # Calculate ROC curve
        fpr, tpr, thresholds_roc = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)
        
        # Find threshold for different precision/recall targets
        high_precision_idx = np.argmax(precision >= 0.95)
        high_precision_threshold = thresholds_pr[high_precision_idx] if high_precision_idx < len(thresholds_pr) else 0.9
        
        high_recall_idx = np.argmax(recall[::-1] >= 0.95)
        high_recall_threshold = thresholds_pr[len(thresholds_pr) - 1 - high_recall_idx] if high_recall_idx < len(thresholds_pr) else 0.6
        
        # Calculate confusion matrices at different thresholds
        def get_confusion_matrix(threshold):
            y_pred = (y_scores >= threshold).astype(int)
            return confusion_matrix(y_true, y_pred)
        
        # Set new thresholds
        new_thresholds = {
            'high': high_precision_threshold,
            'medium': best_f1_threshold,
            'low': high_recall_threshold
        }
        
        # Generate performance metrics at different thresholds
        threshold_metrics = {}
        for name, threshold in new_thresholds.items():
            y_pred = (y_scores >= threshold).astype(int)
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            
            if tp + fp > 0:
                prec = tp / (tp + fp)
            else:
                prec = 0
                
            if tp + fn > 0:
                rec = tp / (tp + fn)
            else:
                rec = 0
                
            if prec + rec > 0:
                f1 = 2 * (prec * rec) / (prec + rec)
            else:
                f1 = 0
                
            threshold_metrics[name] = {
                'threshold': threshold,
                'precision': prec,
                'recall': rec,
                'f1_score': f1,
                'true_positives': int(tp),
                'false_positives': int(fp),
                'true_negatives': int(tn),
                'false_negatives': int(fn)
            }
        
        # Update the matcher's thresholds
        self.thresholds = new_thresholds
        logger.info(f"Updated thresholds: {self.thresholds}")
        
        # If the matcher has domain-specific data, analyze by domain
        domain_analysis = {}
        if domain_col and len(set(domains) - {None}) > 1:
            unique_domains = list(set(d for d in domains if d))
            
            for domain in unique_domains:
                domain_indices = [i for i, d in enumerate(domains) if d == domain]
                if len(domain_indices) < 10:  # Skip domains with too few samples
                    continue
                    
                domain_y_true = y_true[domain_indices]
                domain_y_scores = y_scores[domain_indices]
                
                # Calculate domain-specific precision-recall curve
                domain_precision, domain_recall, domain_thresholds = precision_recall_curve(
                    domain_y_true, domain_y_scores
                )
                
                # Find domain-specific optimal threshold
                if len(domain_thresholds) > 0:
                    domain_f1_scores = 2 * (domain_precision[:-1] * domain_recall[:-1]) / (domain_precision[:-1] + domain_recall[:-1] + 1e-10)
                    domain_best_idx = np.argmax(domain_f1_scores)
                    domain_best_threshold = domain_thresholds[domain_best_idx]
                    domain_best_f1 = domain_f1_scores[domain_best_idx]
                    
                    domain_analysis[domain] = {
                        'optimal_threshold': domain_best_threshold,
                        'f1_score': domain_best_f1,
                        'sample_count': len(domain_indices)
                    }
        
        # Generate and return comprehensive results
        result = {
            'success': True,
            'thresholds': self.thresholds,
            'metrics': threshold_metrics,
            'roc_auc': roc_auc,
            'best_f1': best_f1,
            'domain_analysis': domain_analysis if domain_analysis else None
        }
        
        # Generate visualizations if matplotlib is available
        if 'matplotlib.pyplot' in sys.modules:
            try:
                # Create a figure with precision-recall and ROC curves
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
                
                # Precision-Recall curve
                ax1.plot(recall, precision, 'b-', label=f'Precision-Recall (F1={best_f1:.3f})')
                ax1.scatter([recall[best_f1_idx]], [precision[best_f1_idx]], marker='o', color='red', s=100,
                           label=f'Best F1 threshold: {best_f1_threshold:.3f}')
                ax1.set_xlabel('Recall')
                ax1.set_ylabel('Precision')
                ax1.set_title('Precision-Recall Curve')
                ax1.legend()
                ax1.grid(True)
                
                # ROC curve
                ax2.plot(fpr, tpr, 'g-', label=f'ROC (AUC={roc_auc:.3f})')
                ax2.plot([0, 1], [0, 1], 'k--')  # Diagonal line
                ax2.set_xlabel('False Positive Rate')
                ax2.set_ylabel('True Positive Rate')
                ax2.set_title('ROC Curve')
                ax2.legend()
                ax2.grid(True)
                
                plt.tight_layout()
                
                # Save as BytesIO object to include in results
                import io
                buf = io.BytesIO()
                plt.savefig(buf, format='png')
                buf.seek(0)
                
                # Convert to base64 for easy display
                import base64
                result['visualization_base64'] = base64.b64encode(buf.read()).decode('utf-8')
                
                plt.close()
                
            except Exception as e:
                logger.warning(f"Failed to generate visualization: {e}")
        
        return result
        
    except Exception as e:
        logger.error(f"Error tuning thresholds: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return {'success': False, 'error': str(e)}

In [30]:
# 7.8: Domain Adaptation and Utility Methods

def adapt_to_domain(self, examples_df, domain=None, epochs=5):
    """
    Adapt all components to a specific domain using example data
    
    Args:
        examples_df (DataFrame): DataFrame with example merchant name pairs
        domain (str, optional): Domain name for specialization
        epochs (int): Number of training epochs for adaptation
        
    Returns:
        dict: Adaptation results including success metrics
    """
    if not isinstance(examples_df, pd.DataFrame):
        return {
            'success': False,
            'error': 'Invalid input: examples_df must be a pandas DataFrame'
        }
    
    if len(examples_df) < 5:
        return {
            'success': False,
            'error': 'Insufficient data: at least 5 examples needed for domain adaptation'
        }
    
    logger.info(f"Adapting merchant matcher to domain: {domain or 'general'} with {len(examples_df)} examples")
    
    results = {
        'success': True,
        'component_results': {}
    }
    
    # 1. Adapt BERT embedder
    if hasattr(self.bert_embedder, 'adapt_to_domain'):
        try:
            start_time = time.time()
            self.bert_embedder.adapt_to_domain(examples_df, epochs=epochs)
            bert_time = time.time() - start_time
            
            results['component_results']['bert_embedder'] = {
                'success': True,
                'adaptation_time': bert_time
            }
        except Exception as e:
            logger.warning(f"Failed to adapt BERT embedder: {e}")
            results['component_results']['bert_embedder'] = {
                'success': False,
                'error': str(e)
            }
    
    # 2. Adapt semantic analyzer
    try:
        start_time = time.time()
        self.semantic_analyzer.adapt_to_domain(examples_df, domain)
        semantic_time = time.time() - start_time
        
        results['component_results']['semantic_analyzer'] = {
            'success': True,
            'adaptation_time': semantic_time
        }
    except Exception as e:
        logger.warning(f"Failed to adapt semantic analyzer: {e}")
        results['component_results']['semantic_analyzer'] = {
            'success': False,
            'error': str(e)
        }
    
    # 3. Tune weights based on the domain if we have labeled data
    if 'is_match' in examples_df.columns or 'expected_match' in examples_df.columns:
        try:
            # Determine which column has match labels
            target_column = 'is_match' if 'is_match' in examples_df.columns else 'expected_match'
            
            # Only use a sample for weight tuning if the dataset is large
            tuning_data = examples_df.sample(min(500, len(examples_df))) if len(examples_df) > 500 else examples_df
            
            # Prepare s1/s2 columns if not already present
            if not ('s1' in tuning_data.columns and 's2' in tuning_data.columns):
                if 'Acronym' in tuning_data.columns and 'Full_Name' in tuning_data.columns:
                    tuning_data['s1'] = tuning_data['Acronym']
                    tuning_data['s2'] = tuning_data['Full_Name']
                elif 'input_name' in tuning_data.columns and 'matched_name' in tuning_data.columns:
                    tuning_data['s1'] = tuning_data['input_name']
                    tuning_data['s2'] = tuning_data['matched_name']
            
            # Tune weights if we have valid data
            if 's1' in tuning_data.columns and 's2' in tuning_data.columns:
                start_time = time.time()
                
                # Calculate optimal weights through grid search
                self._tune_weights(tuning_data, target_column, domain)
                
                weight_time = time.time() - start_time
                
                results['component_results']['weight_tuning'] = {
                    'success': True,
                    'adaptation_time': weight_time,
                    'new_weights': self.weights
                }
        except Exception as e:
            logger.warning(f"Failed to tune weights: {e}")
            results['component_results']['weight_tuning'] = {
                'success': False,
                'error': str(e)
            }
    
    # 4. Clear cache after adaptation
    self.cache = {}
    self.cache_hits = 0
    
    logger.info(f"Domain adaptation completed with results: {results}")
    return results

def _tune_weights(self, tuning_data, target_column, domain=None):
    """
    Tune weights for different similarity algorithms based on example data
    
    Args:
        tuning_data (DataFrame): DataFrame with merchant name pairs and match labels
        target_column (str): Column containing match labels
        domain (str, optional): Domain context
    """
    # Define weight combinations to try
    weight_combinations = []
    
    # Generate a range of weight combinations
    for semantic_w in [0.2, 0.3, 0.4]:
        for string_w in [0.1, 0.2, 0.3]:
            for pattern_w in [0.1, 0.2, 0.3]:
                for token_w in [0.1, 0.15, 0.2]:
                    # Calculate remaining weight
                    other_w = 1.0 - semantic_w - string_w - pattern_w - token_w
                    
                    # Ensure weights sum to 1.0 and are positive
                    if other_w >= 0:
                        # Split remaining weight between contains and acronym
                        acronym_w = other_w / 2
                        contains_w = other_w / 2
                        
                        weight_combinations.append({
                            'semantic_similarity': semantic_w,
                            'string_similarity': string_w,
                            'pattern_match': pattern_w,
                            'token_set_similarity': token_w,
                            'contains_check': contains_w,
                            'acronym_check': acronym_w
                        })
    
    # Add domain-specific weight combinations
    if domain:
        if domain.lower() in ['banking', 'financial']:
            weight_combinations.append({
                'semantic_similarity': 0.25,
                'string_similarity': 0.15,
                'pattern_match': 0.20,
                'token_set_similarity': 0.10,
                'contains_check': 0.10,
                'acronym_check': 0.20  # Higher for banking acronyms
            })
        elif domain.lower() in ['retail', 'store']:
            weight_combinations.append({
                'semantic_similarity': 0.30,
                'string_similarity': 0.25,
                'pattern_match': 0.25,
                'token_set_similarity': 0.10,
                'contains_check': 0.05,
                'acronym_check': 0.05
            })
        elif domain.lower() in ['restaurant', 'food']:
            weight_combinations.append({
                'semantic_similarity': 0.35,
                'string_similarity': 0.25,
                'pattern_match': 0.20,
                'token_set_similarity': 0.15,
                'contains_check': 0.05,
                'acronym_check': 0.0
            })
    
    # Evaluate each weight combination
    best_score = 0
    best_weights = None
    
    # Save original weights
    original_weights = self.weights.copy()
    
    # Try each combination
    for weights in weight_combinations:
        # Set weights
        self.weights = weights
        
        # Evaluate on tuning data
        scores = []
        true_labels = []
        
        for i, row in tuning_data.iterrows():
            s1 = row['s1']
            s2 = row['s2']
            
            # Skip rows with empty values
            if not isinstance(s1, str) or not isinstance(s2, str):
                continue
            
            # Get ground truth
            true_label = int(row[target_column])
            true_labels.append(true_label)
            
            # Calculate score with current weights
            score = self.match_merchants(s1, s2, domain, return_details=False, use_cache=False)
            scores.append(score)
        
        # Convert to numpy arrays
        y_true = np.array(true_labels)
        y_scores = np.array(scores)
        
        # Calculate AUC-ROC as the evaluation metric
        from sklearn.metrics import roc_auc_score
        try:
            auc_score = roc_auc_score(y_true, y_scores)
            
            if auc_score > best_score:
                best_score = auc_score
                best_weights = weights.copy()
        except:
            # Skip if AUC calculation fails (e.g., only one class)
            pass
    
    # Set weights to the best combination or restore original if no improvement
    if best_weights and best_score > 0.5:
        self.weights = best_weights
        logger.info(f"Weight tuning complete. Best AUC: {best_score:.4f} with weights: {best_weights}")
    else:
        self.weights = original_weights
        logger.warning("Weight tuning did not improve performance. Reverting to original weights.")

def save_model(self, filepath):
    """
    Save the trained matcher model to a file
    
    Args:
        filepath (str): Path to save the model
        
    Returns:
        bool: True if successful, False otherwise
    """
    try:
        import pickle
        import os
        
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        
        # Prepare model data
        model_data = {
            'weights': self.weights,
            'thresholds': self.thresholds,
            'feature_names': self.feature_names,
            'feature_importance': self.feature_importance,
            'model_trained': self.model_trained
        }
        
        # Check if ML model is trained and can be pickled
        if self.model_trained and self.ml_model is not None:
            try:
                # Test if model can be pickled
                pickle.dumps(self.ml_model)
                model_data['ml_model'] = self.ml_model
            except:
                logger.warning("ML model could not be pickled. Saving weights only.")
                model_data['ml_model'] = None
        
        # Save to file
        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)
        
        logger.info(f"Model saved to {filepath}")
        return True
        
    except Exception as e:
        logger.error(f"Error saving model: {str(e)}")
        return False

def load_model(self, filepath):
    """
    Load a trained matcher model from a file
    
    Args:
        filepath (str): Path to the saved model
        
    Returns:
        bool: True if successful, False otherwise
    """
    try:
        import pickle
        
        # Load from file
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)
        
        # Update model attributes
        self.weights = model_data.get('weights', self.weights)
        self.thresholds = model_data.get('thresholds', self.thresholds)
        self.feature_names = model_data.get('feature_names', None)
        self.feature_importance = model_data.get('feature_importance', None)
        self.model_trained = model_data.get('model_trained', False)
        
        # Load ML model if available
        if 'ml_model' in model_data and model_data['ml_model'] is not None:
            self.ml_model = model_data['ml_model']
        
        logger.info(f"Model loaded from {filepath}")
        return True
        
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return False

In [32]:
# 7.9: Interactive Evaluation and Visualization

def evaluate_match_interactively(self, s1, s2, domain=None):
    """
    Evaluate a merchant name match with detailed visualization
    of why the match was made and component contributions
    
    Args:
        s1 (str): First merchant name
        s2 (str): Second merchant name
        domain (str, optional): Domain for specialized matching
        
    Returns:
        dict: Comprehensive evaluation results
    """
    # Get detailed match results
    match_details = self.match_merchants(s1, s2, domain, return_details=True)
    
    # Extract key information
    score = match_details['match_score']
    level = match_details['match_level']
    features = match_details['features']
    patterns = match_details['patterns']
    explanation = match_details['explanation']
    
    # Get preprocessed forms
    s1_clean = match_details['processed_s1']
    s2_clean = match_details['processed_s2']
    
    # Create visualization if matplotlib is available
    visualization_data = None
    try:
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches
        import io
        import base64
        
        # Create a figure with multiple subplots
        fig = plt.figure(figsize=(12, 9))
        gs = fig.add_gridspec(3, 2)
        
        # 1. Feature contribution plot
        ax1 = fig.add_subplot(gs[0, 0])
        feature_names = []
        feature_values = []
        for fname, value in features.items():
            if fname in self.weights:
                feature_names.append(fname)
                feature_values.append(value)
        
        colors = ['#3498db' if v >= 0.7 else '#e74c3c' for v in feature_values]
        ax1.barh(feature_names, feature_values, color=colors)
        ax1.set_xlim(0, 1)
        ax1.set_title('Feature Contributions')
        ax1.axvline(x=0.7, color='gray', linestyle='--', alpha=0.5)
        
        # 2. Match summary with score visualization
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.axis('off')
        
        # Create a horizontal gauge
        gauge_height = 0.3
        background = patches.Rectangle((0, 0.5-gauge_height/2), 1, gauge_height, 
                                       facecolor='#ecf0f1', edgecolor='#bdc3c7')
        ax2.add_patch(background)
        
        # Add color regions
        low_region = patches.Rectangle((0, 0.5-gauge_height/2), self.thresholds['low'], gauge_height, 
                                      facecolor='#e74c3c', edgecolor=None, alpha=0.7)
        med_region = patches.Rectangle((self.thresholds['low'], 0.5-gauge_height/2), 
                                      self.thresholds['medium']-self.thresholds['low'], 
                                      gauge_height, facecolor='#f39c12', edgecolor=None, alpha=0.7)
        high_region = patches.Rectangle((self.thresholds['medium'], 0.5-gauge_height/2), 
                                       self.thresholds['high']-self.thresholds['medium'], 
                                       gauge_height, facecolor='#2ecc71', edgecolor=None, alpha=0.7)
        v_high_region = patches.Rectangle((self.thresholds['high'], 0.5-gauge_height/2), 
                                         1-self.thresholds['high'], 
                                         gauge_height, facecolor='#27ae60', edgecolor=None, alpha=0.7)
        
        ax2.add_patch(low_region)
        ax2.add_patch(med_region)
        ax2.add_patch(high_region)
        ax2.add_patch(v_high_region)
        
        # Add score marker
        marker_width = 0.02
        score_marker = patches.Rectangle((score-marker_width/2, 0.5-gauge_height/2-0.05), 
                                        marker_width, gauge_height+0.1, 
                                        facecolor='black', edgecolor=None)
        ax2.add_patch(score_marker)
        
        # Add text
        ax2.text(0.5, 0.8, f"Match Score: {score:.3f}", ha='center', fontsize=14, weight='bold')
        ax2.text(0.5, 0.2, f"Match Level: {level.upper().replace('_', ' ')}", ha='center', fontsize=12)
        
        # Add threshold labels
        ax2.text(self.thresholds['low'], 0.5-gauge_height/2-0.1, 'Low', ha='center', fontsize=8)
        ax2.text(self.thresholds['medium'], 0.5-gauge_height/2-0.1, 'Medium', ha='center', fontsize=8)
        ax2.text(self.thresholds['high'], 0.5-gauge_height/2-0.1, 'High', ha='center', fontsize=8)
        
        # 3. Text comparison visualization
        ax3 = fig.add_subplot(gs[1, :])
        ax3.axis('off')
        
        # Split explanation into lines
        explanation_lines = explanation.split('\n')
        
        # Display original and preprocessed names
        ax3.text(0.02, 0.95, f"Original: '{s1}' vs '{s2}'", fontsize=12, weight='bold')
        ax3.text(0.02, 0.85, f"Preprocessed: '{s1_clean}' vs '{s2_clean}'", fontsize=12)
        
        # Display explanation text
        for i, line in enumerate(explanation_lines):
            if i < 10:  # Limit to 10 lines to avoid overflow
                y_pos = 0.75 - i * 0.05
                ax3.text(0.02, y_pos, line, fontsize=10)
        
        # 4. Pattern visualization if patterns exist
        ax4 = fig.add_subplot(gs[2, 0])
        if patterns:
            ax4.axis('off')
            ax4.text(0.5, 0.9, "Detected Patterns", ha='center', fontsize=12, weight='bold')
            
            y_pos = 0.8
            for pattern_type, pattern_info in patterns.items():
                if y_pos > 0.1:  # Avoid overflow
                    ax4.text(0.02, y_pos, f"• {pattern_type}", fontsize=10, weight='bold')
                    confidence = pattern_info.get('confidence', 0)
                    ax4.text(0.02, y_pos-0.05, f"  Confidence: {confidence:.2f}", fontsize=9)
                    if 'explanation' in pattern_info:
                        ax4.text(0.02, y_pos-0.1, f"  {pattern_info['explanation']}", fontsize=9)
                        y_pos -= 0.15
                    else:
                        y_pos -= 0.1
        else:
            ax4.axis('off')
            ax4.text(0.5, 0.5, "No patterns detected", ha='center', va='center', fontsize=12)
        
        # 5. Word similarity visualization (if text is not too long)
        ax5 = fig.add_subplot(gs[2, 1])
        if len(s1_clean.split()) <= 10 and len(s2_clean.split()) <= 10:
            ax5.axis('off')
            ax5.text(0.5, 0.9, "Word-level Similarity", ha='center', fontsize=12, weight='bold')
            
            # Get word-by-word similarities
            s1_words = s1_clean.split()
            s2_words = s2_clean.split()
            
            # Calculate similarity matrix
            import numpy as np
            sim_matrix = np.zeros((len(s1_words), len(s2_words)))
            
            for i, w1 in enumerate(s1_words):
                for j, w2 in enumerate(s2_words):
                    # Use Jaro-Winkler for word similarity
                    sim_matrix[i, j] = jaro_winkler(w1, w2)
            
            # Display words with connections
            for i, word in enumerate(s1_words):
                if i < 5:  # Limit to 5 words to avoid overflow
                    ax5.text(0.1, 0.8 - i*0.12, word, fontsize=10, ha='right')
            
            for j, word in enumerate(s2_words):
                if j < 5:  # Limit to 5 words to avoid overflow
                    ax5.text(0.9, 0.8 - j*0.12, word, fontsize=10, ha='left')
            
            # Draw connections for high similarities
            for i, w1 in enumerate(s1_words):
                if i >= 5:  # Skip if beyond display limit
                    continue
                    
                for j, w2 in enumerate(s2_words):
                    if j >= 5:  # Skip if beyond display limit
                        continue
                        
                    sim = sim_matrix[i, j]
                    if sim > 0.7:  # Only show strong connections
                        # Draw line with alpha based on similarity
                        ax5.plot([0.11, 0.89], [0.8 - i*0.12, 0.8 - j*0.12], 
                               alpha=sim*0.7, color='#3498db', linewidth=sim*2)
        else:
            ax5.axis('off')
            ax5.text(0.5, 0.5, "Texts too long for\nword visualization", 
                   ha='center', va='center', fontsize=12)
        
        plt.tight_layout()
        
        # Save as BytesIO object
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        
        # Convert to base64 for easy display
        visualization_data = base64.b64encode(buf.read()).decode('utf-8')
        
        plt.close()
        
    except Exception as e:
        logger.warning(f"Failed to generate visualization: {e}")
        import traceback
        logger.debug(traceback.format_exc())
    
    # Return comprehensive evaluation
    result = {
        'match_score': score,
        'match_level': level,
        'explanation': explanation,
        'features': features,
        'patterns': patterns,
        's1_original': s1,
        's2_original': s2,
        's1_processed': s1_clean,
        's2_processed': s2_clean,
        'thresholds': self.thresholds,
        'weights': self.weights,
        'visualization_base64': visualization_data
    }
    
    return result

In [34]:
# Cell 8: Main Execution Pipeline

class GMARTMerchantMatchingPipeline:
    """
    Comprehensive pipeline for merchant name matching that integrates all components
    into a unified workflow with configuration, preprocessing, matching, and evaluation.
    
    Key features:
    - Configuration management for different matching scenarios
    - Integrated preprocessing, matching, and evaluation
    - Error handling and robust logging
    - Support for different deployment environments
    - Performance metrics and reporting
    """
    
    def __init__(self, config_path=None, domain=None, use_ml=True, 
                 debug_mode=False, log_level='INFO'):
        """
        Initialize the merchant matching pipeline with configurable components
        
        Args:
            config_path (str, optional): Path to configuration file
            domain (str, optional): Default domain for specialized matching
            use_ml (bool): Whether to use machine learning enhancement
            debug_mode (bool): Enable additional logging and diagnostics
            log_level (str): Logging level (DEBUG, INFO, WARNING, ERROR)
        """
        # Configure logging
        self._setup_logging(log_level)
        
        # Set instance attributes
        self.config_path = config_path
        self.domain = domain
        self.use_ml = use_ml
        self.debug_mode = debug_mode
        
        # Load configuration if provided
        self.config = self._load_configuration(config_path)
        
        # Initialize component tracker for lazy loading
        self._initialized_components = {}
        self._merchant_matcher = None
        
        logger.info(f"GMART Merchant Matching Pipeline initialized with domain: {domain}, ML: {use_ml}")
    
    def _setup_logging(self, log_level):
        """Configure logging based on specified level"""
        log_format = '%(asctime)s - %(levelname)s - %(message)s'
        
        # Set log level based on input
        numeric_level = getattr(logging, log_level.upper(), None)
        if not isinstance(numeric_level, int):
            numeric_level = logging.INFO
            
        # Configure root logger
        logging.basicConfig(level=numeric_level, format=log_format)
        
        # Create a file handler for persistent logging
        try:
            file_handler = logging.FileHandler('gmart_matching.log')
            file_handler.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(file_handler)
        except:
            logging.warning("Could not create log file. Continuing with console logging only.")
    
    def _load_configuration(self, config_path):
        """
        Load configuration from file with fallback to defaults
        
        Args:
            config_path (str): Path to configuration file (JSON or YAML)
            
        Returns:
            dict: Configuration settings
        """
        # Default configuration
        default_config = {
            "thresholds": {
                "high": 0.85,
                "medium": 0.75,
                "low": 0.60
            },
            "weights": {
                "string_similarity": 0.20,
                "token_set_similarity": 0.15,
                "semantic_similarity": 0.30,
                "pattern_match": 0.25,
                "contains_check": 0.05,
                "acronym_check": 0.05
            },
            "preprocessing": {
                "remove_business_suffixes": True,
                "normalize_special_merchants": True,
                "expand_abbreviations": True
            },
            "feature_extraction": {
                "use_semantic": True,
                "use_phonetic": True,
                "use_acronym_detection": True
            },
            "performance": {
                "batch_size": 1000,
                "use_multiprocessing": True,
                "cache_size": 10000
            },
            "domains": {
                "banking": {
                    "weights": {
                        "acronym_check": 0.15,
                        "semantic_similarity": 0.30
                    }
                },
                "retail": {
                    "weights": {
                        "pattern_match": 0.30,
                        "string_similarity": 0.25
                    }
                },
                "restaurant": {
                    "weights": {
                        "semantic_similarity": 0.35,
                        "pattern_match": 0.25
                    }
                }
            }
        }
        
        # Return defaults if no config path
        if not config_path:
            return default_config
            
        try:
            # Load from file based on extension
            if config_path.endswith('.json'):
                import json
                with open(config_path, 'r') as f:
                    loaded_config = json.load(f)
            elif config_path.endswith(('.yaml', '.yml')):
                try:
                    import yaml
                    with open(config_path, 'r') as f:
                        loaded_config = yaml.safe_load(f)
                except ImportError:
                    logger.warning("YAML package not available. Falling back to default configuration.")
                    return default_config
            else:
                logger.warning(f"Unsupported configuration file format: {config_path}")
                return default_config
                
            # Merge with defaults (recursive)
            return self._merge_configs(default_config, loaded_config)
            
        except Exception as e:
            logger.error(f"Error loading configuration from {config_path}: {str(e)}")
            return default_config
    
    def _merge_configs(self, default_config, custom_config):
        """
        Recursively merge custom configuration with defaults
        
        Args:
            default_config (dict): Default configuration dictionary
            custom_config (dict): Custom configuration to override defaults
            
        Returns:
            dict: Merged configuration
        """
        result = default_config.copy()
        
        for key, value in custom_config.items():
            # If both are dicts, merge recursively
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = self._merge_configs(result[key], value)
            # Otherwise override with custom value
            else:
                result[key] = value
                
        return result
    
    def _get_merchant_matcher(self):
        """
        Lazy loading of the EnhancedMerchantMatcher to save resources
        
        Returns:
            EnhancedMerchantMatcher: The initialized matcher
        """
        if self._merchant_matcher is None:
            # Initialize all components from configuration
            thresholds = self.config.get('thresholds', {})
            weights = self.config.get('weights', {})
            
            # Adjust weights based on domain if specified
            if self.domain and self.domain in self.config.get('domains', {}):
                domain_weights = self.config['domains'][self.domain].get('weights', {})
                # Update only specified weights, keeping others
                for k, v in domain_weights.items():
                    weights[k] = v
            
            # Initialize BERT embedder
            bert_embedder = AdvancedBERTEmbedder(
                pooling_strategy='mean',
                cache_size=self.config.get('performance', {}).get('cache_size', 10000)
            )
            
            # Initialize preprocessor with config
            preproc_config = self.config.get('preprocessing', {})
            preprocessor = MerchantPreprocessor()
            
            # Initialize similarity algorithms
            sim_algorithms = SimilarityAlgorithms(
                preprocessor=preprocessor,
                bert_embedder=bert_embedder
            )
            
            # Initialize pattern recognition
            pattern_recognition = PatternRecognition(
                preprocessor=preprocessor,
                similarity_algorithms=sim_algorithms
            )
            
            # Initialize semantic analyzer
            semantic_analyzer = BertSemanticAnalyzer(
                bert_embedder=bert_embedder,
                similarity_algorithms=sim_algorithms,
                preprocessor=preprocessor
            )
            
            # Initialize the merchant matcher
            self._merchant_matcher = EnhancedMerchantMatcher(
                preprocessor=preprocessor, 
                similarity_algorithms=sim_algorithms,
                pattern_recognition=pattern_recognition,
                semantic_analyzer=semantic_analyzer,
                bert_embedder=bert_embedder,
                weights=weights,
                thresholds=thresholds
            )
            
            # Track components
            self._initialized_components = {
                'preprocessor': preprocessor,
                'bert_embedder': bert_embedder,
                'similarity_algorithms': sim_algorithms,
                'pattern_recognition': pattern_recognition,
                'semantic_analyzer': semantic_analyzer
            }
            
            logger.info("Merchant matcher initialized with all components")
            
            # Training if ML model available
            if self.use_ml and hasattr(self, '_training_data') and self._training_data is not None:
                try:
                    logger.info("Training ML model with provided data...")
                    self._merchant_matcher.train_model(self._training_data)
                except Exception as e:
                    logger.warning(f"Error training ML model: {e}")
        
        return self._merchant_matcher
    
    def match_merchants(self, s1, s2, domain=None, return_details=False):
        """
        Match two merchant names
        
        Args:
            s1 (str): First merchant name
            s2 (str): Second merchant name
            domain (str, optional): Domain for specialized matching
            return_details (bool): Whether to return detailed match info
            
        Returns:
            float or dict: Match score or detailed match information
        """
        # Get domain (use instance default if not specified)
        effective_domain = domain if domain is not None else self.domain
        
        # Get matcher and perform match
        matcher = self._get_merchant_matcher()
        return matcher.match_merchants(s1, s2, effective_domain, return_details)
    
    def match_batch(self, data, s1_col='s1', s2_col='s2', domain_col=None, output_path=None):
        """
        Process a batch of merchant name pairs
        
        Args:
            data (DataFrame or str): Pandas DataFrame or path to CSV file
            s1_col (str): Column name for first merchant names
            s2_col (str): Column name for second merchant names
            domain_col (str, optional): Column name for domain information
            output_path (str, optional): Path to save results
            
        Returns:
            DataFrame: Results with match scores and levels
        """
        # Load data if string path provided
        if isinstance(data, str):
            try:
                data = pd.read_csv(data)
                logger.info(f"Loaded data from {data} with {len(data)} rows")
            except Exception as e:
                logger.error(f"Error loading data from {data}: {e}")
                return None
        
        # Ensure required columns exist
        if s1_col not in data.columns or s2_col not in data.columns:
            logger.error(f"Required columns not found in data: {s1_col}, {s2_col}")
            return None
            
        # Get matcher and process batch
        matcher = self._get_merchant_matcher()
        
        # Get batch size from config
        batch_size = self.config.get('performance', {}).get('batch_size', 1000)
        use_multiprocessing = self.config.get('performance', {}).get('use_multiprocessing', True)
        
        # Process batch
        try:
            results = matcher.match_merchant_batch(
                data, 
                s1_col=s1_col, 
                s2_col=s2_col,
                domain_col=domain_col,
                return_details=False,
                batch_size=batch_size,
                use_multiprocessing=use_multiprocessing
            )
            
            # Save results if output path provided
            if output_path:
                results.to_csv(output_path, index=False)
                logger.info(f"Results saved to {output_path}")
                
            return results
            
        except Exception as e:
            logger.error(f"Error processing batch: {e}")
            import traceback
            logger.debug(traceback.format_exc())
            return None
    
    def find_matches(self, query, candidates, top_k=5, threshold=0.6, domain=None):
        """
        Find top matches for a query merchant name from a list of candidates
        
        Args:
            query (str): Query merchant name
            candidates (list): List of candidate merchant names
            top_k (int): Number of top matches to return
            threshold (float): Minimum score threshold
            domain (str, optional): Domain for specialized matching
            
        Returns:
            list: Top matches with scores and levels
        """
        # Get domain (use instance default if not specified)
        effective_domain = domain if domain is not None else self.domain
        
        # Get matcher and find matches
        matcher = self._get_merchant_matcher()
        return matcher.find_merchant_matches(
            query, 
            candidates, 
            effective_domain, 
            top_k, 
            threshold
        )
    
    def train_with_data(self, training_data, target_column='is_match', test_size=0.2):
        """
        Train the matcher with labeled data to improve accuracy
        
        Args:
            training_data (DataFrame): DataFrame with merchant pairs and match labels
            target_column (str): Column name containing match labels
            test_size (float): Proportion of data to use for testing
            
        Returns:
            dict: Training results and metrics
        """
        # Store training data for lazy initialization
        self._training_data = training_data
        
        # If matcher already initialized, train it directly
        if self._merchant_matcher is not None:
            return self._merchant_matcher.train_model(
                training_data, 
                target_column=target_column,
                test_size=test_size,
                use_advanced_model=True
            )
        
        # Otherwise training will happen on first matcher access
        return {"status": "pending", "message": "Training will occur when matcher is initialized"}
    
    def adapt_to_domain(self, examples_df, domain=None):
        """
        Adapt the matcher to a specific domain based on examples
        
        Args:
            examples_df (DataFrame): DataFrame with example merchant name pairs
            domain (str, optional): Domain name for specialization
            
        Returns:
            dict: Adaptation results
        """
        effective_domain = domain if domain is not None else self.domain
        
        # Get matcher and adapt
        matcher = self._get_merchant_matcher()
        return matcher.adapt_to_domain(examples_df, effective_domain)
    
    def tune_thresholds(self, validation_data, target_column='is_match'):
        """
        Tune matching thresholds based on validation data
        
        Args:
            validation_data (DataFrame): DataFrame with merchant pairs and ground truth
            target_column (str): Column containing match labels (1=match, 0=no match)
            
        Returns:
            dict: Optimized thresholds and metrics
        """
        # Get matcher and tune
        matcher = self._get_merchant_matcher()
        
        domain_col = None
        if self.domain is not None and 'domain' in validation_data.columns:
            domain_col = 'domain'
            
        return matcher.tune_thresholds(validation_data, target_column, domain_col)
    
    def save_model(self, filepath):
        """
        Save the trained model to a file
        
        Args:
            filepath (str): Path to save the model
            
        Returns:
            bool: Success status
        """
        matcher = self._get_merchant_matcher()
        return matcher.save_model(filepath)
    
    def load_model(self, filepath):
        """
        Load a trained model from a file
        
        Args:
            filepath (str): Path to the saved model
            
        Returns:
            bool: Success status
        """
        matcher = self._get_merchant_matcher()
        return matcher.load_model(filepath)
    
    def evaluate_match(self, s1, s2, domain=None):
        """
        Generate detailed evaluation of a merchant match
        
        Args:
            s1 (str): First merchant name
            s2 (str): Second merchant name
            domain (str, optional): Domain for specialized matching
            
        Returns:
            dict: Detailed evaluation results
        """
        effective_domain = domain if domain is not None else self.domain
        matcher = self._get_merchant_matcher()
        return matcher.evaluate_match_interactively(s1, s2, effective_domain)
    
    def get_component_status(self):
        """
        Get initialization status of all components
        
        Returns:
            dict: Component status information
        """
        # Basic status
        status = {
            "pipeline_initialized": True,
            "merchant_matcher_initialized": self._merchant_matcher is not None,
            "components_initialized": list(self._initialized_components.keys()) if self._initialized_components else [],
            "ml_model_trained": False,
            "domain": self.domain,
            "config_loaded": self.config_path is not None
        }
        
        # Add ML model status if matcher initialized
        if self._merchant_matcher is not None:
            status["ml_model_trained"] = getattr(self._merchant_matcher, "model_trained", False)
            
        return status

In [36]:
# Cell 9: Interactive Merchant Name Matching

class InteractiveMerchantMatcher:
    """
    Interactive interface for testing and analyzing merchant name matches
    with rich visualization and explanation capabilities.
    
    This class provides a user-friendly way to explore merchant name matching,
    analyze match decisions, and visualize the matching process.
    """
    
    def __init__(self, pipeline=None, config_path=None, domain=None):
        """
        Initialize the interactive matcher
        
        Args:
            pipeline (GMARTMerchantMatchingPipeline, optional): Existing pipeline
            config_path (str, optional): Path to configuration file
            domain (str, optional): Default domain for specialized matching
        """
        # Use existing pipeline or create a new one
        if pipeline:
            self.pipeline = pipeline
        else:
            self.pipeline = GMARTMerchantMatchingPipeline(
                config_path=config_path,
                domain=domain,
                debug_mode=True
            )
        
        # Store recent matches for history
        self.match_history = []
        self.max_history = 20
    
    def match(self, s1, s2, domain=None, visualize=True):
        """
        Match two merchant names and visualize the results
        
        Args:
            s1 (str): First merchant name
            s2 (str): Second merchant name
            domain (str, optional): Domain for specialized matching
            visualize (bool): Whether to generate visualization
            
        Returns:
            dict: Match results with visualization
        """
        # Check inputs
        if not s1 or not s2:
            print("Please provide non-empty merchant names")
            return None
        
        # Perform detailed match evaluation
        result = self.pipeline.evaluate_match(s1, s2, domain)
        
        # Add to history
        self._add_to_history(s1, s2, result['match_score'], result['match_level'], domain)
        
        # Display results
        self._display_results(result, visualize)
        
        return result
    
    def _add_to_history(self, s1, s2, score, level, domain):
        """Add match to history"""
        self.match_history.append({
            's1': s1,
            's2': s2,
            'score': score,
            'level': level,
            'domain': domain,
            'timestamp': time.time()
        })
        
        # Limit history size
        if len(self.match_history) > self.max_history:
            self.match_history = self.match_history[-self.max_history:]
    
    def _display_results(self, result, visualize):
        """Display match results with optional visualization"""
        # Print header
        print("\n" + "="*50)
        print(f"MERCHANT MATCH ANALYSIS")
        print("="*50)
        
        # Print basic match info
        print(f"\nOriginal Names:")
        print(f"  Name 1: '{result['s1_original']}'")
        print(f"  Name 2: '{result['s2_original']}'")
        
        print(f"\nPreprocessed Names:")
        print(f"  Name 1: '{result['s1_processed']}'")
        print(f"  Name 2: '{result['s2_processed']}'")
        
        print(f"\nMatch Results:")
        print(f"  Score: {result['match_score']:.3f}")
        print(f"  Level: {result['match_level'].upper().replace('_', ' ')}")
        
        # Print thresholds
        print("\nMatch Thresholds:")
        for level, threshold in result['thresholds'].items():
            print(f"  {level.capitalize()}: {threshold:.2f}")
        
        # Print top feature scores
        print("\nFeature Scores:")
        for name, score in sorted(result['features'].items(), key=lambda x: x[1], reverse=True)[:5]:
            print(f"  {name}: {score:.3f}")
        
        # Print explanation
        print("\nExplanation:")
        for line in result['explanation'].split('\n'):
            print(f"  {line}")
        
        # Display visualization if available and requested
        if visualize and result.get('visualization_base64'):
            try:
                self._display_visualization(result['visualization_base64'])
            except Exception as e:
                print(f"\nVisualization error: {e}")
    
    def _display_visualization(self, base64_image):
        """Display the visualization image"""
        try:
            import IPython.display as display
            from IPython.display import HTML
            import base64
            
            # Create HTML for image display
            html = f"""
            <div style="background-color: white; padding: 10px; border-radius: 5px;">
                <img src="data:image/png;base64,{base64_image}" style="max-width: 100%">
            </div>
            """
            
            # Display in notebook
            display.display(HTML(html))
            
        except ImportError:
            print("\nVisualization available but IPython display not available.")
            print("Run in Jupyter notebook to see visualization.")
    
    def compare_matches(self, merchant_name, candidates, domain=None, top_k=5, threshold=0.6):
        """
        Compare a merchant name against multiple candidates
        
        Args:
            merchant_name (str): Reference merchant name
            candidates (list): List of candidate merchant names
            domain (str, optional): Domain for specialized matching
            top_k (int): Number of top matches to display
            threshold (float): Minimum score threshold
            
        Returns:
            list: Top matches with scores
        """
        if not merchant_name or not candidates:
            print("Please provide a merchant name and candidates list")
            return []
        
        # Find matches
        matches = self.pipeline.find_matches(
            merchant_name, 
            candidates, 
            top_k=top_k, 
            threshold=threshold, 
            domain=domain
        )
        
        # Display results
        print("\n" + "="*50)
        print(f"TOP MERCHANT MATCHES FOR: '{merchant_name}'")
        print("="*50)
        
        print(f"\nFound {len(matches)} matches above threshold {threshold}")
        
        # Print matches in table format
        print("\n{:<5} {:<40} {:<10} {:<15}".format("Rank", "Merchant Name", "Score", "Match Level"))
        print("-"*75)
        
        for i, (name, score, level) in enumerate(matches, 1):
            print("{:<5} {:<40} {:<10.3f} {:<15}".format(
                i, name[:38], score, level.upper()
            ))
            
        return matches
    
    def show_match_history(self, limit=10):
        """
        Display recent match history
        
        Args:
            limit (int): Number of recent matches to display
            
        Returns:
            list: Recent match history
        """
        history = self.match_history[-limit:]
        
        if not history:
            print("No match history available")
            return []
        
        print("\n" + "="*80)
        print("MERCHANT MATCH HISTORY")
        print("="*80)
        
        # Print matches in table format
        print("\n{:<20} {:<20} {:<10} {:<15} {:<15}".format(
            "Merchant 1", "Merchant 2", "Score", "Match Level", "Domain"
        ))
        print("-"*80)
        
        for entry in reversed(history):
            print("{:<20} {:<20} {:<10.3f} {:<15} {:<15}".format(
                entry['s1'][:18], 
                entry['s2'][:18], 
                entry['score'], 
                entry['level'].upper(),
                entry['domain'] or 'None'
            ))
            
        return history
    
    def analyze_patterns(self, s1, s2, domain=None):
        """
        Analyze merchant name patterns in detail
        
        Args:
            s1 (str): First merchant name
            s2 (str): Second merchant name
            domain (str, optional): Domain for specialized matching
            
        Returns:
            dict: Pattern analysis
        """
        # Get component from pipeline
        matcher = self.pipeline._get_merchant_matcher()
        pattern_recognition = self.pipeline._initialized_components.get('pattern_recognition')
        
        if not pattern_recognition:
            print("Pattern recognition component not initialized")
            return None
        
        # Get patterns
        patterns = pattern_recognition.detect_merchant_patterns(s1, s2, domain)
        
        # Display results
        print("\n" + "="*50)
        print(f"PATTERN ANALYSIS")
        print("="*50)
        
        print(f"\nMerchant Names:")
        print(f"  Name 1: '{s1}'")
        print(f"  Name 2: '{s2}'")
        
        if not patterns:
            print("\nNo patterns detected between these merchant names")
            return {}
        
        print(f"\nDetected Patterns ({len(patterns)}):")
        
        for pattern_type, pattern_info in patterns.items():
            print(f"\n• {pattern_type.upper()}")
            print(f"  Confidence: {pattern_info.get('confidence', 0):.2f}")
            
            if 'explanation' in pattern_info:
                print(f"  Explanation: {pattern_info['explanation']}")
                
            # Print pattern details based on type
            if pattern_type == 'known_equivalent':
                print(f"  Canonical: {pattern_info.get('pattern', ('', ''))[0]}")
            elif 'pattern' in pattern_info:
                if isinstance(pattern_info['pattern'], list):
                    for i, p in enumerate(pattern_info['pattern']):
                        if isinstance(p, dict):
                            print(f"  Pattern {i+1}: {p.get('domain', '')} - {p.get('pattern_type', '')}")
                            if 'original' in p and 'normalized' in p:
                                print(f"    '{p['original']}' → '{p['normalized']}'")
        
        return patterns
    
    def analyze_semantic_similarity(self, s1, s2, domain=None):
        """
        Analyze semantic similarity in detail
        
        Args:
            s1 (str): First merchant name
            s2 (str): Second merchant name
            domain (str, optional): Domain for specialized matching
            
        Returns:
            dict: Semantic analysis
        """
        # Get components from pipeline
        semantic_analyzer = self.pipeline._initialized_components.get('semantic_analyzer')
        
        if not semantic_analyzer:
            print("Semantic analyzer component not initialized")
            return None
        
        # Analyze semantic match
        semantic_results = semantic_analyzer.analyze_semantic_match(s1, s2, domain)
        
        # Display results
        print("\n" + "="*50)
        print(f"SEMANTIC ANALYSIS")
        print("="*50)
        
        print(f"\nMerchant Names:")
        print(f"  Name 1: '{s1}'")
        print(f"  Name 2: '{s2}'")
        
        print(f"\nSemantic Similarity: {semantic_results['semantic_similarity']:.3f}")
        print(f"Match Level: {semantic_results['match_level'].upper().replace('_', ' ')}")
        
        # Print semantic matching details
        analysis = semantic_results.get('analysis', {})
        
        if 'word_match_ratio' in analysis:
            print(f"\nWord Match Ratio: {analysis['word_match_ratio']:.3f}")
            
        if 'matching_words' in analysis:
            print(f"Matching Words: {', '.join(analysis['matching_words'])}")
            
        if 'context_similarity' in analysis:
            print(f"Context Similarity: {analysis['context_similarity']:.3f}")
            
        if 'semantic_matching_points' in analysis:
            print(f"Semantic Matching Points: {', '.join(analysis['semantic_matching_points'])}")
            
        if 'potential_relationship' in analysis:
            print(f"\nPotential Relationship: {analysis['potential_relationship'].replace('_', ' ').title()}")
            print(f"Confidence: {analysis['confidence']:.3f}")
            
        # Print domain-specific analysis if available
        if 'domain_specific' in analysis:
            domain_info = analysis['domain_specific']
            print(f"\nDomain-Specific Analysis ({domain_info.get('industry', 'unknown')}):")
            print(f"  Subtype: {domain_info.get('subtype', 'general')}")
            
            for note in domain_info.get('notes', []):
                print(f"  • {note}")
        
        return semantic_results
    
    def explain_preprocessing(self, merchant_name, domain=None):
        """
        Explain preprocessing steps for a merchant name
        
        Args:
            merchant_name (str): Merchant name to preprocess
            domain (str, optional): Domain for specialized preprocessing
            
        Returns:
            dict: Preprocessing details
        """
        # Get preprocessor from pipeline
        preprocessor = self.pipeline._initialized_components.get('preprocessor')
        
        if not preprocessor:
            print("Preprocessor component not initialized")
            return None
        
        # Track original name
        original = merchant_name
        
        # Process step by step
        steps = []
        
        # Step 1: Basic cleanup
        step1 = merchant_name.lower().strip()
        if step1 != original:
            steps.append(("Convert to lowercase and trim", step1))
        
        # Step 2: Handle punctuation
        step2 = re.sub(r'([^a-z0-9\'\.\&\-])', ' ', step1)
        if step2 != step1:
            steps.append(("Remove most punctuation", step2))
        
        # Step 3: Handle apostrophes
        step3 = re.sub(r'\'s\b', 's', step2)
        step3 = re.sub(r'\'', '', step3)
        if step3 != step2:
            steps.append(("Normalize apostrophes", step3))
        
        # Step 4: Normalize spaces
        step4 = re.sub(r'\s+', ' ', step3).strip()
        if step4 != step3:
            steps.append(("Normalize spaces", step4))
        
        # Step 5: Remove business suffixes
        step5 = step4
        for pattern, replacement in preprocessor.business_suffixes.items():
            step5 = pattern.sub(replacement, step5)
        if step5 != step4:
            steps.append(("Remove business suffixes", step5))
        
        # Step 6: Apply business patterns
        step6 = step5
        for pattern, replacement_func in preprocessor.business_patterns.items():
            match = pattern.search(step6)
            if match:
                step6 = replacement_func(match).strip()
                break
        if step6 != step5:
            steps.append(("Apply business patterns", step6))
        
        # Step 7: Expand abbreviations
        words = step6.split()
        expanded_words = [preprocessor.abbreviations.get(word, word) for word in words]
        step7 = ' '.join(expanded_words)
        if step7 != step6:
            steps.append(("Expand abbreviations", step7))
        
        # Step 8: Remove stopwords
        filtered_words = [word for word in step7.split() if word not in preprocessor.stopwords]
        step8 = ' '.join(filtered_words)
        if step8 != step7:
            steps.append(("Remove stopwords", step8))
        
        # Final result
        final = preprocessor.preprocess(merchant_name, domain)
        if final != step8 and steps:
            steps.append(("Domain-specific processing", final))
        
        # Display results
        print("\n" + "="*50)
        print(f"PREPROCESSING ANALYSIS FOR: '{original}'")
        print("="*50)
        
        if not steps:
            print("\nNo preprocessing changes were applied")
        else:
            print("\nPreprocessing Steps:")
            for i, (description, result) in enumerate(steps, 1):
                print(f"\n{i}. {description}")
                print(f"   '{result}'")
        
        print(f"\nFinal preprocessed result: '{final}'")
        
        # Return preprocessing details
        return {
            'original': original,
            'final': final,
            'steps': steps,
            'domain': domain
        }

In [40]:
# Cell 10: Batch Processing and PySpark Adaptation
# I'll create a comprehensive implementation for batch processing and PySpark adaptation that integrates with the existing merchant matching system.

In [42]:
# 10.1: Scalable Batch Processing Framework

class BatchProcessor:
    """
    High-performance batch processing framework for merchant name matching
    with support for various input formats and parallel execution.
    
    Key features:
    - Multi-format support (CSV, Excel, JSON, Parquet)
    - Chunked processing for memory efficiency
    - Progress tracking and reporting
    - Automatic error recovery
    - Result aggregation and export
    """
    
    def __init__(self, matcher=None, chunk_size=10000, n_jobs=-1, 
                 output_format='csv', temp_dir=None):
        """
        Initialize batch processor with configurable parameters
        
        Args:
            matcher: Merchant matcher instance (EnhancedMerchantMatcher or pipeline)
            chunk_size (int): Number of records per processing chunk
            n_jobs (int): Number of parallel jobs (-1 for all cores)
            output_format (str): Format for output files ('csv', 'excel', 'json', 'parquet')
            temp_dir (str): Directory for temporary files
        """
        self.matcher = matcher
        self.chunk_size = chunk_size
        self.n_jobs = n_jobs
        self.output_format = output_format.lower()
        self.temp_dir = temp_dir or os.path.join(os.getcwd(), 'temp')
        
        # Create temp directory if it doesn't exist
        if not os.path.exists(self.temp_dir):
            os.makedirs(self.temp_dir)
            
        # Configure logging
        self.logger = logging.getLogger('BatchProcessor')
        self.logger.setLevel(logging.INFO)
        
        # Track processing metrics
        self.metrics = {
            'total_processed': 0,
            'successful_matches': 0,
            'failed_matches': 0,
            'processing_time': 0,
            'matches_per_second': 0
        }
    
    def process_file(self, input_file, output_file=None, 
                     s1_col='s1', s2_col='s2', domain_col=None, 
                     id_col=None, return_detailed=False):
        """
        Process a file containing merchant name pairs
        
        Args:
            input_file (str): Path to input file
            output_file (str, optional): Path to output file (generated if None)
            s1_col (str): Column name for first merchant names
            s2_col (str): Column name for second merchant names
            domain_col (str, optional): Column name for domain information
            id_col (str, optional): Column name for unique identifier
            return_detailed (bool): Whether to return detailed match info
            
        Returns:
            str: Path to output file with results
        """
        start_time = time.time()
        self.logger.info(f"Starting batch processing of {input_file}")
        
        # Generate output file name if not provided
        if output_file is None:
            file_base = os.path.splitext(os.path.basename(input_file))[0]
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_file = f"{file_base}_matched_{timestamp}.{self.output_format}"
            # If relative path, put in current directory
            if not os.path.isabs(output_file):
                output_file = os.path.join(os.getcwd(), output_file)
        
        # Determine file format and read data
        try:
            data_chunks = self._read_input_file(input_file, self.chunk_size)
        except Exception as e:
            self.logger.error(f"Error reading input file: {e}")
            raise
        
        # Process chunks with parallel execution
        processed_chunks = []
        chunk_paths = []
        
        self.logger.info("Processing data in chunks...")
        
        for i, chunk in enumerate(data_chunks):
            self.logger.info(f"Processing chunk {i+1}")
            
            # Process chunk
            if self.n_jobs == 1:
                # Sequential processing
                processed_chunk = self._process_chunk(
                    chunk, s1_col, s2_col, domain_col, id_col, return_detailed
                )
                processed_chunks.append(processed_chunk)
            else:
                # Parallel processing - save chunk to temp file
                chunk_path = os.path.join(self.temp_dir, f"chunk_{i}.csv")
                chunk.to_csv(chunk_path, index=False)
                chunk_paths.append(chunk_path)
        
        # If parallel processing, use joblib to process all chunks
        if self.n_jobs != 1 and chunk_paths:
            try:
                from joblib import Parallel, delayed
                
                # Process all chunks in parallel
                processed_chunks = Parallel(n_jobs=self.n_jobs)(
                    delayed(self._process_chunk_file)(
                        path, s1_col, s2_col, domain_col, id_col, return_detailed
                    ) for path in chunk_paths
                )
                
                # Clean up temp files
                for path in chunk_paths:
                    try:
                        os.remove(path)
                    except:
                        pass
                        
            except ImportError:
                self.logger.warning(
                    "joblib not available for parallel processing. Using sequential processing."
                )
                # Fall back to sequential processing
                processed_chunks = []
                for path in chunk_paths:
                    chunk = pd.read_csv(path)
                    processed_chunk = self._process_chunk(
                        chunk, s1_col, s2_col, domain_col, id_col, return_detailed
                    )
                    processed_chunks.append(processed_chunk)
                    
                    # Clean up temp file
                    try:
                        os.remove(path)
                    except:
                        pass
        
        # Combine all processed chunks
        if processed_chunks:
            result_df = pd.concat(processed_chunks, ignore_index=True)
            
            # Export results
            self._write_output_file(result_df, output_file)
            
            # Update metrics
            self.metrics['total_processed'] = len(result_df)
            self.metrics['successful_matches'] = len(
                result_df[result_df['match_level'] != 'no_match']
            )
            self.metrics['failed_matches'] = len(
                result_df[result_df['match_level'] == 'no_match']
            )
            self.metrics['processing_time'] = time.time() - start_time
            
            if self.metrics['processing_time'] > 0:
                self.metrics['matches_per_second'] = (
                    self.metrics['total_processed'] / self.metrics['processing_time']
                )
            
            self.logger.info(f"Batch processing completed: {len(result_df)} records processed")
            self.logger.info(
                f"Processing time: {self.metrics['processing_time']:.2f} seconds "
                f"({self.metrics['matches_per_second']:.2f} records/second)"
            )
            
            return output_file
        else:
            self.logger.error("No data processed")
            return None
    
    def _read_input_file(self, file_path, chunk_size):
        """
        Read input file in chunks with format auto-detection
        
        Args:
            file_path (str): Path to input file
            chunk_size (int): Chunk size for reading
            
        Returns:
            list: List of DataFrame chunks
        """
        file_ext = os.path.splitext(file_path)[1].lower()
        
        if file_ext == '.csv':
            # Read CSV in chunks
            chunks = []
            for chunk in pd.read_csv(file_path, chunksize=chunk_size):
                chunks.append(chunk)
            return chunks
            
        elif file_ext in ['.xlsx', '.xls']:
            # Read Excel file
            df = pd.read_excel(file_path)
            # Split into chunks
            return [df[i:i + chunk_size] for i in range(0, len(df), chunk_size)]
            
        elif file_ext == '.json':
            # Read JSON file
            df = pd.read_json(file_path)
            # Split into chunks
            return [df[i:i + chunk_size] for i in range(0, len(df), chunk_size)]
            
        elif file_ext == '.parquet':
            # Read Parquet file
            try:
                import pyarrow.parquet as pq
                df = pq.read_table(file_path).to_pandas()
                # Split into chunks
                return [df[i:i + chunk_size] for i in range(0, len(df), chunk_size)]
            except ImportError:
                # Fall back to pandas
                df = pd.read_parquet(file_path)
                # Split into chunks
                return [df[i:i + chunk_size] for i in range(0, len(df), chunk_size)]
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")
    
    def _process_chunk(self, chunk, s1_col, s2_col, domain_col, id_col, return_detailed):
        """
        Process a chunk of data
        
        Args:
            chunk (DataFrame): DataFrame chunk to process
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            domain_col (str, optional): Column name for domain
            id_col (str, optional): Column name for ID
            return_detailed (bool): Whether to return detailed results
            
        Returns:
            DataFrame: Processed DataFrame with match results
        """
        result_df = chunk.copy()
        
        # Ensure matcher is initialized
        if self.matcher is None:
            self.logger.error("Matcher not initialized")
            return result_df
        
        # Add result columns if they don't exist
        if 'match_score' not in result_df.columns:
            result_df['match_score'] = 0.0
        if 'match_level' not in result_df.columns:
            result_df['match_level'] = 'no_match'
        if return_detailed and 'match_details' not in result_df.columns:
            result_df['match_details'] = None
        
        # Check if columns exist
        if s1_col not in result_df.columns or s2_col not in result_df.columns:
            self.logger.error(f"Required columns not found: {s1_col}, {s2_col}")
            return result_df
        
        # Process each row
        for idx, row in result_df.iterrows():
            try:
                s1 = row[s1_col]
                s2 = row[s2_col]
                
                # Skip if empty
                if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                    continue
                
                # Get domain if available
                domain = row[domain_col] if domain_col and domain_col in row else None
                
                # Match merchants
                if hasattr(self.matcher, 'match_merchants'):
                    # EnhancedMerchantMatcher
                    match_result = self.matcher.match_merchants(
                        s1, s2, domain, return_details=return_detailed
                    )
                else:
                    # GMARTMerchantMatchingPipeline
                    match_result = self.matcher.match(
                        s1, s2, domain, return_details=return_detailed
                    )
                
                # Update results
                if return_detailed:
                    result_df.at[idx, 'match_score'] = match_result['match_score']
                    result_df.at[idx, 'match_level'] = match_result['match_level']
                    result_df.at[idx, 'match_details'] = str(match_result)
                else:
                    score = match_result
                    level = 'high_match' if score >= 0.85 else (
                        'medium_match' if score >= 0.75 else (
                            'low_match' if score >= 0.6 else 'no_match'
                        )
                    )
                    result_df.at[idx, 'match_score'] = score
                    result_df.at[idx, 'match_level'] = level
            
            except Exception as e:
                self.logger.warning(f"Error processing row {idx}: {e}")
                continue
        
        return result_df
    
    def _process_chunk_file(self, chunk_path, s1_col, s2_col, domain_col, id_col, return_detailed):
        """
        Process a chunk file (for parallel processing)
        
        Args:
            chunk_path (str): Path to chunk file
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            domain_col (str, optional): Column name for domain
            id_col (str, optional): Column name for ID
            return_detailed (bool): Whether to return detailed results
            
        Returns:
            DataFrame: Processed DataFrame with match results
        """
        # Read chunk
        chunk = pd.read_csv(chunk_path)
        
        # Process chunk
        return self._process_chunk(chunk, s1_col, s2_col, domain_col, id_col, return_detailed)
    
    def _write_output_file(self, df, output_path):
        """
        Write output file in the specified format
        
        Args:
            df (DataFrame): DataFrame to write
            output_path (str): Path to output file
            
        Returns:
            bool: Success flag
        """
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
            
            # Write based on format
            file_ext = os.path.splitext(output_path)[1].lower()
            
            if file_ext == '.csv':
                df.to_csv(output_path, index=False)
                
            elif file_ext in ['.xlsx', '.xls']:
                df.to_excel(output_path, index=False)
                
            elif file_ext == '.json':
                # Handle non-serializable objects
                if 'match_details' in df.columns:
                    df['match_details'] = df['match_details'].apply(
                        lambda x: str(x) if x is not None else None
                    )
                df.to_json(output_path, orient='records')
                
            elif file_ext == '.parquet':
                # Remove complex objects for parquet
                if 'match_details' in df.columns:
                    df = df.drop(columns=['match_details'])
                
                try:
                    import pyarrow as pa
                    import pyarrow.parquet as pq
                    
                    # Convert to pyarrow table and write
                    table = pa.Table.from_pandas(df)
                    pq.write_table(table, output_path)
                except ImportError:
                    # Fall back to pandas
                    df.to_parquet(output_path, index=False)
            else:
                # Default to CSV
                csv_path = os.path.splitext(output_path)[0] + '.csv'
                df.to_csv(csv_path, index=False)
                self.logger.warning(
                    f"Unsupported output format: {file_ext}. Saved as CSV: {csv_path}"
                )
                
            self.logger.info(f"Results saved to {output_path}")
            return True
            
        except Exception as e:
            self.logger.error(f"Error writing output file: {e}")
            # Try to save to a backup location
            try:
                backup_path = os.path.join(
                    os.getcwd(), 
                    f"merchant_matches_backup_{time.strftime('%Y%m%d_%H%M%S')}.csv"
                )
                df.to_csv(backup_path, index=False)
                self.logger.info(f"Backup saved to {backup_path}")
            except:
                pass
            return False
    
    def get_metrics(self):
        """Get processing metrics"""
        return self.metrics.copy()
    
    def merge_results(self, input_paths, output_path, how='inner', on=None):
        """
        Merge multiple result files
        
        Args:
            input_paths (list): List of input file paths
            output_path (str): Output file path
            how (str): Merge method ('inner', 'outer', 'left', 'right')
            on (str or list): Column(s) to join on
            
        Returns:
            str: Path to merged output file
        """
        try:
            # Read all input files
            dfs = []
            for path in input_paths:
                file_ext = os.path.splitext(path)[1].lower()
                
                if file_ext == '.csv':
                    df = pd.read_csv(path)
                elif file_ext in ['.xlsx', '.xls']:
                    df = pd.read_excel(path)
                elif file_ext == '.json':
                    df = pd.read_json(path)
                elif file_ext == '.parquet':
                    df = pd.read_parquet(path)
                else:
                    self.logger.warning(f"Unsupported file format: {file_ext}")
                    continue
                    
                dfs.append(df)
            
            if not dfs:
                self.logger.error("No valid input files")
                return None
                
            # Merge all dataframes
            result_df = dfs[0]
            for df in dfs[1:]:
                result_df = pd.merge(result_df, df, how=how, on=on)
            
            # Write merged result
            self._write_output_file(result_df, output_path)
            
            self.logger.info(
                f"Merged {len(dfs)} files with {len(result_df)} rows to {output_path}"
            )
            return output_path
            
        except Exception as e:
            self.logger.error(f"Error merging results: {e}")
            return None

In [44]:
# 10.2: PySpark Integration for Distributed Processing

class SparkMerchantMatcher:
    """
    PySpark wrapper for distributed merchant name matching at scale.
    
    This class enables distributed processing of merchant matching tasks
    using Apache Spark, allowing for processing of very large datasets
    across a cluster of machines.
    
    Key features:
    - Distributed matching using Spark executors
    - Flexible partitioning strategies
    - Built-in checkpoint and recovery
    - Support for various input/output formats
    - Performance optimization for large datasets
    """
    
    def __init__(self, spark_session=None, matcher_config=None, 
                 checkpoint_dir=None, partition_size=100000):
        """
        Initialize Spark merchant matcher with configuration
        
        Args:
            spark_session: Existing SparkSession or None to create new
            matcher_config (dict): Configuration for matcher components
            checkpoint_dir (str): Directory for checkpointing
            partition_size (int): Target size for partitions
        """
        self.matcher_config = matcher_config or {}
        self.partition_size = partition_size
        
        # Initialize Spark session if not provided
        if spark_session is None:
            try:
                from pyspark.sql import SparkSession
                
                # Create Spark session
                self.spark = SparkSession.builder \
                    .appName("MerchantMatcher") \
                    .config("spark.executor.memory", "4g") \
                    .config("spark.driver.memory", "4g") \
                    .config("spark.default.parallelism", "8") \
                    .config("spark.sql.shuffle.partitions", "8") \
                    .config("spark.executor.cores", "2") \
                    .getOrCreate()
                    
                self.logger = self.spark.sparkContext._jvm.org.apache.log4j.LogManager \
                    .getLogger("SparkMerchantMatcher")
                    
            except ImportError:
                raise ImportError("PySpark is required for SparkMerchantMatcher")
        else:
            self.spark = spark_session
            self.logger = None  # Use Python logging
            
        # Set checkpoint directory if provided
        if checkpoint_dir:
            self.spark.sparkContext.setCheckpointDir(checkpoint_dir)
            
        # Track initialization state
        self._matcher_broadcast = None
        self._thresholds_broadcast = None
        self._matcher_initialized = False
    
    def _initialize_matcher_broadcast(self):
        """Initialize and broadcast the matcher to all executors"""
        if self._matcher_initialized:
            return
            
        try:
            import pickle
            import base64
            
            # Create a serializable version of matcher components
            matcher_serialized = {
                "weights": self.matcher_config.get("weights", {
                    "string_similarity": 0.20,
                    "token_set_similarity": 0.15,
                    "semantic_similarity": 0.30,
                    "pattern_match": 0.25,
                    "contains_check": 0.05,
                    "acronym_check": 0.05
                }),
                "thresholds": self.matcher_config.get("thresholds", {
                    "high": 0.85,
                    "medium": 0.75,
                    "low": 0.60
                })
            }
            
            # Broadcast configuration to all executors
            self._thresholds_broadcast = self.spark.sparkContext.broadcast(
                matcher_serialized["thresholds"]
            )
            
            # Note: We don't broadcast the full matcher object because:
            # 1. It may contain non-serializable components (like BERT model)
            # 2. It's more efficient to initialize the matcher on each executor
            
            self._matcher_initialized = True
            
        except Exception as e:
            if self.logger:
                self.logger.error(f"Failed to initialize matcher broadcast: {e}")
            else:
                print(f"Failed to initialize matcher broadcast: {e}")
    
    def process_dataframe(self, df, s1_col='s1', s2_col='s2', domain_col=None,
                          id_col=None, output_path=None, output_format='csv'):
        """
        Process a Spark DataFrame with merchant name pairs
        
        Args:
            df: Spark DataFrame with merchant name pairs
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            domain_col (str, optional): Column name for domain
            id_col (str, optional): Column name for record ID
            output_path (str, optional): Path to save results
            output_format (str): Output format (csv, parquet, json)
            
        Returns:
            DataFrame: Spark DataFrame with match results
        """
        from pyspark.sql.functions import udf, col, lit
        from pyspark.sql.types import DoubleType, StringType, StructType, StructField
        
        # Initialize matcher configuration broadcast
        self._initialize_matcher_broadcast()
        
        # Define the schema for the result
        result_schema = StructType([
            StructField("match_score", DoubleType(), True),
            StructField("match_level", StringType(), True)
        ])
        
        # Define matching UDF for Spark
        def match_merchants_udf(s1, s2, domain=None):
            """UDF for merchant matching in Spark executors"""
            try:
                # Import here for executor scope
                import re
                import numpy as np
                from Levenshtein import distance as levenshtein_distance
                from Levenshtein import jaro_winkler, ratio as levenshtein_ratio
                
                # If both inputs are empty, return no match
                if not s1 or not s2 or not isinstance(s1, str) or not isinstance(s2, str):
                    return (0.0, "no_match")
                
                # Simplified preprocessing
                def preprocess(text):
                    if not isinstance(text, str):
                        return ""
                    # Lowercase and trim
                    text = text.lower().strip()
                    # Remove most punctuation and normalize spaces
                    text = re.sub(r'[^a-z0-9\s]', ' ', text)
                    # Normalize spaces
                    text = re.sub(r'\s+', ' ', text).strip()
                    return text
                
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                
                # If preprocessed strings are empty, return no match
                if not s1_clean or not s2_clean:
                    return (0.0, "no_match")
                
                # Fast exact match
                if s1_clean == s2_clean:
                    return (1.0, "high_match")
                
                # Get thresholds from broadcast
                thresholds = _thresholds_bc.value
                
                # Calculate string similarities
                jw_similarity = jaro_winkler(s1_clean, s2_clean)
                
                # Token set similarity
                def token_set_ratio(s1, s2):
                    s1_words = set(s1.split())
                    s2_words = set(s2.split())
                    
                    intersection = s1_words.intersection(s2_words)
                    union = s1_words.union(s2_words)
                    
                    if not union:
                        return 0.0
                    
                    return len(intersection) / len(union)
                
                token_sim = token_set_ratio(s1_clean, s2_clean)
                
                # Contains check
                contains_score = 0.0
                if s1_clean in s2_clean or s2_clean in s1_clean:
                    contains_score = 1.0
                
                # Combine scores with simple weights
                score = (jw_similarity * 0.5) + (token_sim * 0.3) + (contains_score * 0.2)
                
                # Determine match level
                if score >= thresholds["high"]:
                    level = "high_match"
                elif score >= thresholds["medium"]:
                    level = "medium_match"
                elif score >= thresholds["low"]:
                    level = "low_match"
                else:
                    level = "no_match"
                
                return (score, level)
                
            except Exception as e:
                # In case of any error, return no match
                return (0.0, "no_match")
        
        # Register broadcast variables in UDF
        _thresholds_bc = self._thresholds_broadcast
        
        # Create UDF
        matcher_udf = udf(match_merchants_udf, result_schema)
        
        # Apply matching to DataFrame
        if domain_col and domain_col in df.columns:
            result_df = df.withColumn(
                "match_result",
                matcher_udf(col(s1_col), col(s2_col), col(domain_col))
            )
        else:
            result_df = df.withColumn(
                "match_result",
                matcher_udf(col(s1_col), col(s2_col), lit(None))
            )
        
        # Extract fields from struct
        result_df = result_df \
            .withColumn("match_score", col("match_result.match_score")) \
            .withColumn("match_level", col("match_result.match_level")) \
            .drop("match_result")
        
        # Save results if output path provided
        if output_path:
            if output_format == 'csv':
                result_df.write.csv(output_path, header=True, mode="overwrite")
            elif output_format == 'parquet':
                result_df.write.parquet(output_path, mode="overwrite")
            elif output_format == 'json':
                result_df.write.json(output_path, mode="overwrite")
            else:
                # Default to parquet
                result_df.write.parquet(output_path, mode="overwrite")
        
        return result_df
    
    def process_file(self, input_path, output_path=None, s1_col='s1', s2_col='s2',
                    domain_col=None, output_format='csv', options=None):
        """
        Process a file containing merchant name pairs using Spark
        
        Args:
            input_path (str): Path to input file
            output_path (str, optional): Path to output directory
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            domain_col (str, optional): Column name for domain
            output_format (str): Output format (csv, parquet, json)
            options (dict): Additional options for reading input
            
        Returns:
            DataFrame: Spark DataFrame with match results
        """
        # Determine file format from extension
        file_format = input_path.split('.')[-1].lower()
        read_options = options or {}
        
        # Read input file
        if file_format == 'csv':
            df = self.spark.read.csv(input_path, header=True, inferSchema=True, **read_options)
        elif file_format in ['xls', 'xlsx']:
            # Excel requires additional libraries
            try:
                # Read using pandas and convert to Spark DataFrame
                import pandas as pd
                pandas_df = pd.read_excel(input_path, **read_options)
                df = self.spark.createDataFrame(pandas_df)
            except ImportError:
                raise ImportError("pandas is required for reading Excel files")
        elif file_format == 'json':
            df = self.spark.read.json(input_path, **read_options)
        elif file_format == 'parquet':
            df = self.spark.read.parquet(input_path, **read_options)
        else:
            raise ValueError(f"Unsupported file format: {file_format}")
        
        # Process the DataFrame
        return self.process_dataframe(
            df, s1_col, s2_col, domain_col, output_path=output_path, output_format=output_format
        )
    
    def find_matches(self, query_df, candidate_df, query_col='merchant_name', 
                     candidate_col='merchant_name', threshold=0.6, top_k=5,
                     domain_col=None, output_path=None, output_format='csv'):
        """
        Find best matches for each query merchant from candidates
        
        Args:
            query_df: Spark DataFrame with query merchants
            candidate_df: Spark DataFrame with candidate merchants
            query_col (str): Column name for query merchant names
            candidate_col (str): Column name for candidate merchant names
            threshold (float): Minimum score threshold
            top_k (int): Maximum number of matches to return per query
            domain_col (str, optional): Column name for domain
            output_path (str, optional): Path to save results
            output_format (str): Output format (csv, parquet, json)
            
        Returns:
            DataFrame: Spark DataFrame with match results
        """
        from pyspark.sql.functions import udf, col, lit, row_number, desc
        from pyspark.sql.types import DoubleType, StringType, StructType, StructField
        from pyspark.sql.window import Window
        
        # Initialize matcher configuration broadcast
        self._initialize_matcher_broadcast()
        
        # Define the schema for the result
        result_schema = StructType([
            StructField("match_score", DoubleType(), True),
            StructField("match_level", StringType(), True)
        ])
        
        # Define simplified matching UDF for Spark
        def match_merchants_udf(s1, s2, domain=None):
            """UDF for merchant matching in Spark executors"""
            try:
                # Import here for executor scope
                import re
                import numpy as np
                from Levenshtein import distance as levenshtein_distance
                from Levenshtein import jaro_winkler, ratio as levenshtein_ratio
                
                # If both inputs are empty, return no match
                if not s1 or not s2 or not isinstance(s1, str) or not isinstance(s2, str):
                    return (0.0, "no_match")
                
                # Simplified preprocessing
                def preprocess(text):
                    if not isinstance(text, str):
                        return ""
                    # Lowercase and trim
                    text = text.lower().strip()
                    # Remove most punctuation and normalize spaces
                    text = re.sub(r'[^a-z0-9\s]', ' ', text)
                    # Normalize spaces
                    text = re.sub(r'\s+', ' ', text).strip()
                    return text
                
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                
                # If preprocessed strings are empty, return no match
                if not s1_clean or not s2_clean:
                    return (0.0, "no_match")
                
                # Fast exact match
                if s1_clean == s2_clean:
                    return (1.0, "high_match")
                
                # Get thresholds from broadcast
                thresholds = _thresholds_bc.value
                
                # Calculate string similarities
                jw_similarity = jaro_winkler(s1_clean, s2_clean)
                
                # Token set similarity
                def token_set_ratio(s1, s2):
                    s1_words = set(s1.split())
                    s2_words = set(s2.split())
                    
                    intersection = s1_words.intersection(s2_words)
                    union = s1_words.union(s2_words)
                    
                    if not union:
                        return 0.0
                    
                    return len(intersection) / len(union)
                
                token_sim = token_set_ratio(s1_clean, s2_clean)
                
                # Contains check
                contains_score = 0.0
                if s1_clean in s2_clean or s2_clean in s1_clean:
                    contains_score = 1.0
                
                # Combine scores with simple weights
                score = (jw_similarity * 0.5) + (token_sim * 0.3) + (contains_score * 0.2)
                
                # Determine match level
                if score >= thresholds["high"]:
                    level = "high_match"
                elif score >= thresholds["medium"]:
                    level = "medium_match"
                elif score >= thresholds["low"]:
                    level = "low_match"
                else:
                    level = "no_match"
                
                return (score, level)
                
            except Exception as e:
                # In case of any error, return no match
                return (0.0, "no_match")
        
        # Register broadcast variables in UDF
        _thresholds_bc = self._thresholds_broadcast
        
        # Create UDF
        matcher_udf = udf(match_merchants_udf, result_schema)
        
        # Create cross join with query and candidate DataFrames
        cross_df = query_df.crossJoin(
            candidate_df.select(
                col(candidate_col).alias("candidate_name"), 
                *[col(c) for c in candidate_df.columns if c != candidate_col]
            )
        )
        
        # Apply matching to cross join
        if domain_col and domain_col in cross_df.columns:
            match_df = cross_df.withColumn(
                "match_result",
                matcher_udf(col(query_col), col("candidate_name"), col(domain_col))
            )
        else:
            match_df = cross_df.withColumn(
                "match_result",
                matcher_udf(col(query_col), col("candidate_name"), lit(None))
            )
        
        # Extract fields from struct
        match_df = match_df \
            .withColumn("match_score", col("match_result.match_score")) \
            .withColumn("match_level", col("match_result.match_level")) \
            .drop("match_result")
        
        # Filter by threshold
        match_df = match_df.filter(col("match_score") >= threshold)
        
        # Get top-k matches for each query
        window_spec = Window.partitionBy(query_col).orderBy(desc("match_score"))
        top_matches_df = match_df \
            .withColumn("rank", row_number().over(window_spec)) \
            .filter(col("rank") <= top_k) \
            .drop("rank")
        
        # Save results if output path provided
        if output_path:
            if output_format == 'csv':
                top_matches_df.write.csv(output_path, header=True, mode="overwrite")
            elif output_format == 'parquet':
                top_matches_df.write.parquet(output_path, mode="overwrite")
            elif output_format == 'json':
                top_matches_df.write.json(output_path, mode="overwrite")
            else:
                # Default to parquet
                top_matches_df.write.parquet(output_path, mode="overwrite")
        
        return top_matches_df
    
    def stop(self):
        """Stop the Spark session"""
        if hasattr(self, 'spark') and self.spark:
            self.spark.stop()

In [46]:
# 10.3: Performance Optimization and Distributed Analysis

class MerchantMatchingDistributor:
    """
    Advanced distributed workflow manager for merchant name matching
    with sophisticated workload partitioning and monitoring.
    
    This class provides high-level functions to efficiently distribute
    merchant matching tasks across computation resources, with automatic
    performance tuning and workflow management.
    
    Key features:
    - Adaptive partitioning based on data characteristics
    - Hybrid execution model (local + distributed)
    - Incremental processing with checkpoints
    - Smart resource allocation
    - Performance monitoring dashboard
    """
    
    def __init__(self, config=None, use_spark=False, use_dask=False,
                 min_partition_size=1000, max_workers=None, dashboard=False):
        """
        Initialize the distributor with configuration options
        
        Args:
            config (dict): Configuration dictionary
            use_spark (bool): Whether to use PySpark for distributed processing
            use_dask (bool): Whether to use Dask for distributed processing
            min_partition_size (int): Minimum partition size
            max_workers (int): Maximum number of workers
            dashboard (bool): Enable performance dashboard
        """
        self.config = config or {}
        self.use_spark = use_spark
        self.use_dask = use_dask
        self.min_partition_size = min_partition_size
        self.max_workers = max_workers
        self.dashboard = dashboard
        
        # Track performance metrics
        self.metrics = {
            'processing_time': 0,
            'records_processed': 0,
            'throughput': 0,
            'partitions': 0,
            'partition_distribution': {},
            'start_time': None,
            'end_time': None,
        }
        
        # Configure logging
        self.logger = logging.getLogger('MerchantDistributor')
        self.logger.setLevel(logging.INFO)
        
        # Initialize components based on configuration
        self._initialize_components()
    
    def _initialize_components(self):
        """Initialize distributed processing components"""
        self.matcher = None
        self.spark_matcher = None
        self.dask_client = None
        
        # Initialize matcher
        matcher_config = self.config.get('matcher', {})
        
        # Create EnhancedMerchantMatcher instance
        if not hasattr(self, 'matcher') or self.matcher is None:
            try:
                # Use reference to EnhancedMerchantMatcher from previously imported code
                # Assuming it has been properly defined in previous cells
                self.matcher = EnhancedMerchantMatcher(
                    weights=matcher_config.get('weights'),
                    thresholds=matcher_config.get('thresholds')
                )
                self.logger.info("Initialized EnhancedMerchantMatcher")
            except Exception as e:
                self.logger.warning(f"Failed to initialize EnhancedMerchantMatcher: {e}")
                self.matcher = None
        
        # Initialize BatchProcessor
        self.batch_processor = BatchProcessor(
            matcher=self.matcher,
            chunk_size=self.config.get('batch_size', 10000),
            n_jobs=self.max_workers or -1
        )
        
        # Initialize Spark if requested
        if self.use_spark:
            try:
                from pyspark.sql import SparkSession
                
                # Create or get Spark session
                spark = SparkSession.builder \
                    .appName("MerchantMatching") \
                    .config("spark.executor.memory", "4g") \
                    .config("spark.driver.memory", "4g") \
                    .getOrCreate()
                
                # Initialize Spark matcher
                self.spark_matcher = SparkMerchantMatcher(
                    spark_session=spark,
                    matcher_config=matcher_config,
                    checkpoint_dir=self.config.get('checkpoint_dir')
                )
                
                self.logger.info("Initialized Spark-based distributed processing")
                
            except ImportError:
                self.logger.warning("PySpark not available. Falling back to local processing.")
                self.use_spark = False
        
        # Initialize Dask if requested
        if self.use_dask:
            try:
                import dask
                import dask.dataframe as dd
                from dask.distributed import Client, LocalCluster
                
                # Create local cluster if not using existing
                if self.config.get('dask_scheduler_address'):
                    self.dask_client = Client(self.config['dask_scheduler_address'])
                else:
                    # Create local cluster with specified workers
                    cluster = LocalCluster(
                        n_workers=self.max_workers or 4,
                        threads_per_worker=2,
                        memory_limit='4GB'
                    )
                    self.dask_client = Client(cluster)
                
                if self.dashboard:
                    self.logger.info(f"Dask dashboard available at: {self.dask_client.dashboard_link}")
                
                self.logger.info("Initialized Dask-based distributed processing")
                
            except ImportError:
                self.logger.warning("Dask not available. Falling back to local processing.")
                self.use_dask = False
    
    def _analyze_dataset(self, data_source):
        """
        Analyze dataset to determine optimal partitioning strategy
        
        Args:
            data_source: DataFrame or file path
            
        Returns:
            dict: Dataset characteristics
        """
        # Initialize results
        analysis = {
            'record_count': 0,
            'estimated_size_mb': 0,
            'recommended_partitions': 0,
            'columns': [],
            'has_domain_column': False,
            'string_columns': [],
            'domain_distribution': {}
        }
        
        try:
            # Handle different input types
            if isinstance(data_source, str):
                # Analyze file size
                import os
                file_size = os.path.getsize(data_source) / (1024 * 1024)  # Size in MB
                analysis['estimated_size_mb'] = file_size
                
                # Sample file to analyze structure
                import pandas as pd
                
                # Determine file format
                file_ext = os.path.splitext(data_source)[1].lower()
                
                if file_ext == '.csv':
                    # Read sample of CSV
                    sample_df = pd.read_csv(data_source, nrows=10000)
                elif file_ext in ['.xlsx', '.xls']:
                    # Read sample of Excel
                    sample_df = pd.read_excel(data_source, nrows=10000)
                elif file_ext == '.json':
                    # Read sample of JSON
                    sample_df = pd.read_json(data_source, lines=True, nrows=10000)
                elif file_ext == '.parquet':
                    # Read sample of Parquet
                    sample_df = pd.read_parquet(data_source)
                    # Limit to 10000 rows
                    if len(sample_df) > 10000:
                        sample_df = sample_df.iloc[:10000]
                else:
                    self.logger.warning(f"Unsupported file format for analysis: {file_ext}")
                    return analysis
                
                # Use the sample for column analysis
                analysis['columns'] = list(sample_df.columns)
                analysis['string_columns'] = [
                    col for col in sample_df.columns 
                    if sample_df[col].dtype == 'object'
                ]
                
                # Check for domain column
                if 'domain' in sample_df.columns:
                    analysis['has_domain_column'] = True
                    # Get domain distribution
                    domain_counts = sample_df['domain'].value_counts()
                    analysis['domain_distribution'] = domain_counts.to_dict()
                
                # Estimate total records for file types
                if file_ext == '.csv':
                    # Estimate total lines
                    with open(data_source, 'r') as f:
                        for i, _ in enumerate(f):
                            if i >= 100000:  # Limit to prevent slow performance
                                break
                        total_lines = i + 1
                    
                    # Adjust for header
                    analysis['record_count'] = total_lines - 1
                else:
                    # For other formats, we might not know exact count
                    # Use the sample ratio to estimate
                    avg_row_size = file_size / len(sample_df) if len(sample_df) > 0 else 0
                    if avg_row_size > 0:
                        analysis['record_count'] = int(file_size / avg_row_size)
                    else:
                        analysis['record_count'] = 0
            
            elif hasattr(data_source, 'shape'):
                # Pandas DataFrame
                analysis['record_count'] = data_source.shape[0]
                analysis['columns'] = list(data_source.columns)
                analysis['string_columns'] = [
                    col for col in data_source.columns 
                    if data_source[col].dtype == 'object'
                ]
                
                # Check for domain column
                if 'domain' in data_source.columns:
                    analysis['has_domain_column'] = True
                    # Get domain distribution
                    domain_counts = data_source['domain'].value_counts()
                    analysis['domain_distribution'] = domain_counts.to_dict()
                
                # Estimate size
                analysis['estimated_size_mb'] = data_source.memory_usage(deep=True).sum() / (1024 * 1024)
            
            else:
                self.logger.warning(f"Unsupported data source type: {type(data_source)}")
                return analysis
            
            # Calculate recommended partitions
            if analysis['record_count'] > 0:
                # Base recommendation on records and estimated size
                size_based = max(1, int(analysis['estimated_size_mb'] / 100))  # ~100MB per partition
                count_based = max(1, int(analysis['record_count'] / self.min_partition_size))
                
                analysis['recommended_partitions'] = max(size_based, count_based)
                
                # Limit by available workers
                if self.max_workers:
                    analysis['recommended_partitions'] = min(
                        analysis['recommended_partitions'], 
                        self.max_workers * 2
                    )
            
            return analysis
            
        except Exception as e:
            self.logger.error(f"Error analyzing dataset: {e}")
            return analysis
    
    def process_dataset(self, data_source, output_path=None, s1_col='s1', s2_col='s2',
                        domain_col=None, id_col=None, return_detailed=False):
        """
        Process a dataset with optimized distributed execution
        
        Args:
            data_source: DataFrame or file path
            output_path (str, optional): Path to save results
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            domain_col (str, optional): Column name for domain
            id_col (str, optional): Column name for record ID
            return_detailed (bool): Whether to return detailed match info
            
        Returns:
            DataFrame or str: Results DataFrame or path to results file
        """
        # Analyze dataset for optimal partitioning
        self.metrics['start_time'] = time.time()
        
        analysis = self._analyze_dataset(data_source)
        self.logger.info(
            f"Dataset analysis: {analysis['record_count']} records, "
            f"{analysis['estimated_size_mb']:.2f} MB, "
            f"{analysis['recommended_partitions']} recommended partitions"
        )
        
        # Choose execution strategy based on dataset and available resources
        if self.use_spark and analysis['record_count'] > 100000:
            # Use Spark for very large datasets
            self.logger.info("Using Spark for distributed processing")
            return self._process_with_spark(
                data_source, output_path, s1_col, s2_col, domain_col, id_col
            )
        elif self.use_dask and analysis['record_count'] > 50000:
            # Use Dask for large datasets
            self.logger.info("Using Dask for distributed processing")
            return self._process_with_dask(
                data_source, output_path, s1_col, s2_col, domain_col, id_col, return_detailed
            )
        else:
            # Use batch processor for smaller datasets
            self.logger.info("Using batch processor for local processing")
            return self._process_with_batch(
                data_source, output_path, s1_col, s2_col, domain_col, id_col, return_detailed
            )
    
    def _process_with_batch(self, data_source, output_path, s1_col, s2_col, 
                           domain_col, id_col, return_detailed):
        """Process using BatchProcessor"""
        try:
            # Handle different input types
            if isinstance(data_source, str):
                # Process file
                result_path = self.batch_processor.process_file(
                    data_source, output_path, s1_col, s2_col, domain_col, id_col, return_detailed
                )
                
                # Update metrics
                self.metrics.update(self.batch_processor.get_metrics())
                
                # Return path to results
                return result_path
            else:
                # Process DataFrame
                result_df = self.batch_processor._process_chunk(
                    data_source, s1_col, s2_col, domain_col, id_col, return_detailed
                )
                
                # Save if output path provided
                if output_path:
                    self.batch_processor._write_output_file(result_df, output_path)
                
                # Update metrics
                self.metrics['records_processed'] = len(result_df)
                self.metrics['processing_time'] = time.time() - self.metrics['start_time']
                if self.metrics['processing_time'] > 0:
                    self.metrics['throughput'] = self.metrics['records_processed'] / self.metrics['processing_time']
                
                return result_df
                
        except Exception as e:
            self.logger.error(f"Error in batch processing: {e}")
            raise
        finally:
            self.metrics['end_time'] = time.time()
    
    def _process_with_spark(self, data_source, output_path, s1_col, s2_col, domain_col, id_col):
        """Process using Spark"""
        try:
            if self.spark_matcher is None:
                self.logger.error("Spark matcher not initialized")
                return None
            
            # Process with Spark
            start_time = time.time()
            
            if isinstance(data_source, str):
                # Process file
                result_df = self.spark_matcher.process_file(
                    data_source, output_path, s1_col, s2_col, domain_col
                )
            else:
                # Convert to Spark DataFrame if needed
                if hasattr(data_source, 'toPandas'):
                    # Already a Spark DataFrame
                    spark_df = data_source
                else:
                    # Convert Pandas DataFrame to Spark DataFrame
                    spark_df = self.spark_matcher.spark.createDataFrame(data_source)
                
                # Process DataFrame
                result_df = self.spark_matcher.process_dataframe(
                    spark_df, s1_col, s2_col, domain_col, id_col, output_path
                )
            
            # Update metrics
            self.metrics['records_processed'] = result_df.count()
            self.metrics['processing_time'] = time.time() - start_time
            if self.metrics['processing_time'] > 0:
                self.metrics['throughput'] = self.metrics['records_processed'] / self.metrics['processing_time']
            
            return result_df
            
        except Exception as e:
            self.logger.error(f"Error in Spark processing: {e}")
            raise
        finally:
            self.metrics['end_time'] = time.time()
    
    def _process_with_dask(self, data_source, output_path, s1_col, s2_col, 
                          domain_col, id_col, return_detailed):
        """Process using Dask"""
        try:
            if self.dask_client is None:
                self.logger.error("Dask client not initialized")
                return None
            
            import dask.dataframe as dd
            import pandas as pd
            
            # Convert input to Dask DataFrame
            if isinstance(data_source, str):
                # Read file with Dask
                file_ext = os.path.splitext(data_source)[1].lower()
                
                if file_ext == '.csv':
                    dask_df = dd.read_csv(data_source)
                elif file_ext in ['.xlsx', '.xls']:
                    # Dask doesn't support Excel natively, use pandas
                    pandas_df = pd.read_excel(data_source)
                    dask_df = dd.from_pandas(pandas_df, npartitions=self.max_workers or 4)
                elif file_ext == '.json':
                    dask_df = dd.read_json(data_source, lines=True)
                elif file_ext == '.parquet':
                    dask_df = dd.read_parquet(data_source)
                else:
                    self.logger.error(f"Unsupported file format for Dask: {file_ext}")
                    return None
            elif isinstance(data_source, pd.DataFrame):
                # Convert Pandas DataFrame to Dask DataFrame
                dask_df = dd.from_pandas(data_source, npartitions=self.max_workers or 4)
            elif isinstance(data_source, dd.DataFrame):
                # Already a Dask DataFrame
                dask_df = data_source
            else:
                self.logger.error(f"Unsupported data source type for Dask: {type(data_source)}")
                return None
            
            # Define matching function for Dask
            def match_merchants_func(df):
                """Function to apply matcher to a partition"""
                result_df = df.copy()
                
                # Add result columns if they don't exist
                if 'match_score' not in result_df.columns:
                    result_df['match_score'] = 0.0
                if 'match_level' not in result_df.columns:
                    result_df['match_level'] = 'no_match'
                if return_detailed and 'match_details' not in result_df.columns:
                    result_df['match_details'] = None
                
                # Process each row
                for idx, row in result_df.iterrows():
                    try:
                        s1 = row[s1_col]
                        s2 = row[s2_col]
                        
                        # Skip if empty
                        if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                            continue
                        
                        # Get domain if available
                        domain = row[domain_col] if domain_col and domain_col in row else None
                        
                        # Match with simplified algorithm
                        from Levenshtein import jaro_winkler
                        
                        # Simplified preprocessing
                        def preprocess(text):
                            if not isinstance(text, str):
                                return ""
                            # Lowercase and trim
                            text = text.lower().strip()
                            # Remove punctuation and normalize spaces
                            import re
                            text = re.sub(r'[^a-z0-9\s]', ' ', text)
                            text = re.sub(r'\s+', ' ', text).strip()
                            return text
                        
                        s1_clean = preprocess(s1)
                        s2_clean = preprocess(s2)
                        
                        # If preprocessed strings are empty, skip
                        if not s1_clean or not s2_clean:
                            continue
                        
                        # Fast exact match
                        if s1_clean == s2_clean:
                            result_df.at[idx, 'match_score'] = 1.0
                            result_df.at[idx, 'match_level'] = 'high_match'
                            continue
                        
                        # Calculate similarity
                        jw_sim = jaro_winkler(s1_clean, s2_clean)
                        
                        # Simple token similarity
                        def token_sim(s1, s2):
                            s1_words = set(s1.split())
                            s2_words = set(s2.split())
                            if not s1_words or not s2_words:
                                return 0.0
                            return len(s1_words.intersection(s2_words)) / len(s1_words.union(s2_words))
                        
                        token_similarity = token_sim(s1_clean, s2_clean)
                        
                        # Contains check
                        contains = 1.0 if s1_clean in s2_clean or s2_clean in s1_clean else 0.0
                        
                        # Combine with simple weights
                        score = 0.5 * jw_sim + 0.3 * token_similarity + 0.2 * contains
                        
                        # Determine level
                        if score >= 0.85:
                            level = 'high_match'
                        elif score >= 0.75:
                            level = 'medium_match'
                        elif score >= 0.6:
                            level = 'low_match'
                        else:
                            level = 'no_match'
                        
                        # Update results
                        result_df.at[idx, 'match_score'] = score
                        result_df.at[idx, 'match_level'] = level
                        
                    except Exception as e:
                        # Skip on error
                        continue
                
                return result_df
            
            # Apply function to each partition
            start_time = time.time()
            result_dask_df = dask_df.map_partitions(match_merchants_func)
            
            # Compute results
            result_df = result_dask_df.compute()
            
            # Save results if output path provided
            if output_path:
                # Determine output format from extension
                file_ext = os.path.splitext(output_path)[1].lower()
                
                if file_ext == '.csv':
                    result_df.to_csv(output_path, index=False)
                elif file_ext in ['.xlsx', '.xls']:
                    result_df.to_excel(output_path, index=False)
                elif file_ext == '.json':
                    result_df.to_json(output_path, orient='records', lines=True)
                elif file_ext == '.parquet':
                    result_df.to_parquet(output_path, index=False)
                else:
                    # Default to CSV
                    csv_path = os.path.splitext(output_path)[0] + '.csv'
                    result_df.to_csv(csv_path, index=False)
            
            # Update metrics
            self.metrics['records_processed'] = len(result_df)
            self.metrics['processing_time'] = time.time() - start_time
            if self.metrics['processing_time'] > 0:
                self.metrics['throughput'] = self.metrics['records_processed'] / self.metrics['processing_time']
            
            self.metrics['partitions'] = dask_df.npartitions
            
            return result_df
            
        except Exception as e:
            self.logger.error(f"Error in Dask processing: {e}")
            raise
        finally:
            self.metrics['end_time'] = time.time()
    
    def get_performance_report(self):
        """
        Get detailed performance report
        
        Returns:
            dict: Performance metrics and analysis
        """
        report = self.metrics.copy()
        
        # Add processing mode
        if hasattr(self, 'spark_matcher') and self.spark_matcher is not None:
            report['processing_mode'] = 'spark'
        elif hasattr(self, 'dask_client') and self.dask_client is not None:
            report['processing_mode'] = 'dask'
        else:
            report['processing_mode'] = 'batch'
        
        # Format times
        if report['start_time'] and report['end_time']:
            report['total_duration_seconds'] = report['end_time'] - report['start_time']
            report['formatted_duration'] = self._format_duration(report['total_duration_seconds'])
            
            # Add timestamp
            from datetime import datetime
            report['timestamp'] = datetime.fromtimestamp(report['end_time']).strftime(
                '%Y-%m-%d %H:%M:%S'
            )
        
        # Add formatted throughput
        if report.get('throughput', 0) > 0:
            report['formatted_throughput'] = f"{report['throughput']:.2f} records/second"
        
        return report
    
    def _format_duration(self, seconds):
        """Format duration in human-readable form"""
        if seconds < 60:
            return f"{seconds:.2f} seconds"
        elif seconds < 3600:
            minutes = seconds / 60
            return f"{minutes:.2f} minutes"
        else:
            hours = seconds / 3600
            return f"{hours:.2f} hours"
    
    def visualize_performance(self, output_path=None):
        """
        Create performance visualization
        
        Args:
            output_path (str, optional): Path to save visualization
            
        Returns:
            str or None: Path to visualization or None if not generated
        """
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Create figure with performance metrics
            fig, axs = plt.subplots(2, 1, figsize=(10, 10))
            
            # Throughput and record count
            axs[0].bar(['Records Processed'], [self.metrics['records_processed']], color='blue')
            axs[0].set_ylabel('Count')
            axs[0].set_title('Records Processed')
            
            # Add throughput as text
            if self.metrics.get('throughput', 0) > 0:
                axs[0].text(
                    0, self.metrics['records_processed'] * 0.5,
                    f"Throughput: {self.metrics['throughput']:.2f} records/second",
                    fontsize=12
                )
            
            # Processing time
            axs[1].bar(['Processing Time'], [self.metrics['processing_time']], color='green')
            axs[1].set_ylabel('Seconds')
            axs[1].set_title('Processing Time')
            
            plt.tight_layout()
            
            # Save if output path provided
            if output_path:
                plt.savefig(output_path)
                plt.close()
                return output_path
            else:
                return None
                
        except ImportError:
            self.logger.warning("Matplotlib not available for visualization")
            return None
    
    def shutdown(self):
        """Shutdown all distributed resources"""
        # Stop Spark session
        if hasattr(self, 'spark_matcher') and self.spark_matcher is not None:
            try:
                self.spark_matcher.stop()
                self.logger.info("Spark session stopped")
            except:
                pass
        
        # Close Dask client
        if hasattr(self, 'dask_client') and self.dask_client is not None:
            try:
                self.dask_client.close()
                self.logger.info("Dask client closed")
            except:
                pass

Cell 11: Evaluation and Testing
I'll now implement a comprehensive evaluation framework for the merchant matching system, focusing on rigorous comparison methods and performance analysis.


In [49]:
# Cell 11: Evaluation and Testing

# 11.1: Evaluation Framework and Metrics Calculation

class MerchantMatchingEvaluator:
    """
    Comprehensive evaluation framework for merchant name matching algorithms
    with rigorous statistical analysis and visualization capabilities.
    
    This class enables thorough assessment of matching algorithm performance,
    comparison between different approaches, and statistical validation of improvements.
    
    Key features:
    - Standard metrics calculation (precision, recall, F1, accuracy)
    - Advanced metrics (AUC-ROC, AUC-PR, MCC, confusion matrices)
    - Multi-algorithm comparison framework
    - Statistical significance testing
    - Cross-validation for robustness
    - Performance visualization across different dimensions
    - Error analysis and classification
    """
    
    def __init__(self, ground_truth_data=None, test_size=0.2, random_state=42, 
                matcher=None, baseline_matchers=None):
        """
        Initialize evaluator with ground truth data and matchers
        
        Args:
            ground_truth_data (DataFrame): DataFrame with labeled merchant pairs
            test_size (float): Proportion of data for testing (if splitting)
            random_state (int): Random seed for reproducibility
            matcher: Primary matcher to evaluate
            baseline_matchers (dict): Dictionary of {name: matcher} for comparison
        """
        self.ground_truth_data = ground_truth_data
        self.test_size = test_size
        self.random_state = random_state
        self.matcher = matcher
        self.baseline_matchers = baseline_matchers or {}
        
        # Set up logging
        self.logger = logging.getLogger('MerchantEvaluator')
        self.logger.setLevel(logging.INFO)
        
        # Tracking for results
        self.results = {}
        self.statistical_tests = {}
        self.error_analysis = {}
        self.cross_validation_results = {}
        
        # Default column names
        self.default_cols = {
            's1_col': 's1',
            's2_col': 's2',
            'label_col': 'is_match',
            'domain_col': 'domain'
        }
        
        # Initialize standard baseline matchers if none provided
        if not self.baseline_matchers:
            self._initialize_baseline_matchers()
    
    def _initialize_baseline_matchers(self):
        """Initialize standard baseline matchers for comparison"""
        try:
            # Create dictionary of baseline matchers
            # These are simple functions that return a similarity score
            
            # Import required libraries
            from Levenshtein import distance as levenshtein_distance
            from Levenshtein import jaro_winkler, ratio as levenshtein_ratio
            import textdistance
            from fuzzywuzzy import fuzz
            import jellyfish
            import re
            
            def preprocess(text):
                """Simple preprocessing for baseline matchers"""
                if not isinstance(text, str):
                    return ""
                # Lowercase and trim
                text = text.lower().strip()
                # Remove most punctuation and normalize spaces
                text = re.sub(r'[^a-z0-9\s]', ' ', text)
                # Normalize spaces
                text = re.sub(r'\s+', ' ', text).strip()
                return text
                
            # Define matcher functions
            def jaro_winkler_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return jaro_winkler(s1_clean, s2_clean)
            
            def levenshtein_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                max_len = max(len(s1_clean), len(s2_clean))
                if max_len == 0:
                    return 0.0
                distance = levenshtein_distance(s1_clean, s2_clean)
                return 1.0 - (distance / max_len)
            
            def token_sort_ratio_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return fuzz.token_sort_ratio(s1_clean, s2_clean) / 100.0
            
            def token_set_ratio_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return fuzz.token_set_ratio(s1_clean, s2_clean) / 100.0
            
            def jaccard_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return textdistance.jaccard.normalized_similarity(s1_clean, s2_clean)
            
            def cosine_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return textdistance.cosine.normalized_similarity(s1_clean, s2_clean)
            
            def sorensen_dice_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return textdistance.sorensen_dice.normalized_similarity(s1_clean, s2_clean)
            
            def overlap_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                return textdistance.overlap.normalized_similarity(s1_clean, s2_clean)
            
            def metaphone_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                # Calculate metaphone similarity at word level
                s1_words = s1_clean.split()
                s2_words = s2_clean.split()
                if not s1_words or not s2_words:
                    return 0.0
                    
                # Get metaphone codes for each word
                s1_codes = [jellyfish.metaphone(word) for word in s1_words]
                s2_codes = [jellyfish.metaphone(word) for word in s2_words]
                
                # Count matching codes
                matches = 0
                for code in s1_codes:
                    if code in s2_codes:
                        matches += 1
                        s2_codes.remove(code)
                
                total = max(len(s1_words), len(s2_words))
                return matches / total if total > 0 else 0.0
            
            def soundex_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                # Calculate soundex similarity at word level
                s1_words = s1_clean.split()
                s2_words = s2_clean.split()
                if not s1_words or not s2_words:
                    return 0.0
                    
                # Get soundex codes for each word
                s1_codes = [jellyfish.soundex(word) for word in s1_words]
                s2_codes = [jellyfish.soundex(word) for word in s2_words]
                
                # Count matching codes
                matches = 0
                for code in s1_codes:
                    if code in s2_codes:
                        matches += 1
                        s2_codes.remove(code)
                
                total = max(len(s1_words), len(s2_words))
                return matches / total if total > 0 else 0.0
            
            def contains_matcher(s1, s2, domain=None):
                s1_clean = preprocess(s1)
                s2_clean = preprocess(s2)
                if not s1_clean or not s2_clean:
                    return 0.0
                if s1_clean in s2_clean or s2_clean in s1_clean:
                    return 1.0
                else:
                    # Check word-level containment
                    s1_words = set(s1_clean.split())
                    s2_words = set(s2_clean.split())
                    if s1_words.issubset(s2_words) or s2_words.issubset(s1_words):
                        return 0.9
                    # Check overlap
                    intersection = s1_words.intersection(s2_words)
                    shorter_len = min(len(s1_words), len(s2_words))
                    if shorter_len == 0:
                        return 0.0
                    return len(intersection) / shorter_len
            
            # Add all matchers to baseline dictionary
            self.baseline_matchers = {
                'Jaro-Winkler': jaro_winkler_matcher,
                'Levenshtein': levenshtein_matcher,
                'Token Sort Ratio': token_sort_ratio_matcher,
                'Token Set Ratio': token_set_ratio_matcher,
                'Jaccard': jaccard_matcher,
                'Cosine': cosine_matcher,
                'Sorensen-Dice': sorensen_dice_matcher,
                'Overlap': overlap_matcher,
                'Metaphone': metaphone_matcher,
                'Soundex': soundex_matcher,
                'Contains': contains_matcher
            }
            
            self.logger.info(f"Initialized {len(self.baseline_matchers)} baseline matchers")
            
        except ImportError as e:
            self.logger.warning(f"Could not initialize all baseline matchers: {e}")
            self.baseline_matchers = {}
    
    def evaluate_matcher(self, matcher, test_data=None, s1_col=None, s2_col=None, 
                        label_col=None, domain_col=None, threshold=0.75, name="Primary"):
        """
        Evaluate a single matcher on test data
        
        Args:
            matcher: Matcher to evaluate (function or object with match_merchants method)
            test_data (DataFrame): Test data with merchant pairs and labels
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            label_col (str): Column name for match label (1=match, 0=no match)
            domain_col (str): Column name for domain information
            threshold (float): Score threshold for binary classification
            name (str): Name for the matcher in results
            
        Returns:
            dict: Evaluation metrics
        """
        # Use provided test data or split ground truth data
        if test_data is None:
            if self.ground_truth_data is None:
                self.logger.error("No test data or ground truth data provided")
                return None
            test_data = self._split_data(self.ground_truth_data)[1]
        
        # Use provided column names or defaults
        s1_col = s1_col or self.default_cols['s1_col']
        s2_col = s2_col or self.default_cols['s2_col']
        label_col = label_col or self.default_cols['label_col']
        domain_col = domain_col or self.default_cols['domain_col']
        
        # Check required columns
        if s1_col not in test_data.columns or s2_col not in test_data.columns:
            self.logger.error(f"Required columns missing: {s1_col}, {s2_col}")
            return None
        
        if label_col not in test_data.columns:
            self.logger.error(f"Label column missing: {label_col}")
            return None
        
        # Calculate match scores
        self.logger.info(f"Evaluating matcher: {name}")
        
        y_true = []
        y_scores = []
        domains = []
        errors = []
        
        start_time = time.time()
        
        for i, row in test_data.iterrows():
            try:
                s1 = row[s1_col]
                s2 = row[s2_col]
                
                # Skip if missing data
                if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                    continue
                
                # Get true label
                true_label = int(row[label_col])
                y_true.append(true_label)
                
                # Get domain if available
                domain = row[domain_col] if domain_col and domain_col in row else None
                if domain:
                    domains.append(domain)
                
                # Calculate similarity score based on matcher type
                if hasattr(matcher, 'match_merchants'):
                    # Matcher is an object with match_merchants method
                    score = matcher.match_merchants(s1, s2, domain)
                else:
                    # Matcher is a function
                    score = matcher(s1, s2, domain)
                
                y_scores.append(score)
                
            except Exception as e:
                errors.append((i, str(e)))
                # Skip this pair
                continue
        
        processing_time = time.time() - start_time
        
        # Check if we have enough data
        if len(y_true) < 10:
            self.logger.error(f"Insufficient data for evaluation: {len(y_true)} valid pairs")
            return None
        
        # Convert to numpy arrays
        y_true = np.array(y_true)
        y_scores = np.array(y_scores)
        
        # Calculate metrics
        metrics = self._calculate_metrics(y_true, y_scores, threshold)
        
        # Add additional information
        metrics['processing_time'] = processing_time
        metrics['avg_processing_time'] = processing_time / len(y_true) if len(y_true) > 0 else 0
        metrics['error_count'] = len(errors)
        metrics['error_rate'] = len(errors) / (len(y_true) + len(errors)) if (len(y_true) + len(errors)) > 0 else 0
        
        # Add domain-specific metrics if domains available
        if domains:
            metrics['domain_metrics'] = self._calculate_domain_metrics(
                y_true, y_scores, domains, threshold
            )
        
        # Store results
        self.results[name] = metrics
        
        # Log summary
        self.logger.info(
            f"Evaluation complete: {name}, Accuracy: {metrics['accuracy']:.4f}, "
            f"F1: {metrics['f1_score']:.4f}, AUC: {metrics['auc_roc']:.4f}"
        )
        
        return metrics
    
    def _calculate_metrics(self, y_true, y_scores, threshold):
        """
        Calculate comprehensive evaluation metrics
        
        Args:
            y_true (array): True binary labels
            y_scores (array): Predicted scores
            threshold (float): Score threshold for binary classification
            
        Returns:
            dict: Calculated metrics
        """
        from sklearn.metrics import (
            accuracy_score, precision_score, recall_score, f1_score,
            roc_auc_score, precision_recall_curve, auc,
            confusion_matrix, matthews_corrcoef, balanced_accuracy_score
        )
        
        # Convert scores to binary predictions using threshold
        y_pred = (y_scores >= threshold).astype(int)
        
        # Basic metrics
        metrics = {
            'accuracy': accuracy_score(y_true, y_pred),
            'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred, zero_division=0),
            'recall': recall_score(y_true, y_pred, zero_division=0),
            'f1_score': f1_score(y_true, y_pred, zero_division=0),
            'matthews_corrcoef': matthews_corrcoef(y_true, y_pred),
        }
        
        # Confusion matrix
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
        metrics['true_positives'] = int(tp)
        metrics['false_positives'] = int(fp)
        metrics['true_negatives'] = int(tn)
        metrics['false_negatives'] = int(fn)
        
        # Calculate AUC metrics if possible
        try:
            metrics['auc_roc'] = roc_auc_score(y_true, y_scores)
            
            # Precision-recall curve and AUC
            precision, recall, _ = precision_recall_curve(y_true, y_scores)
            metrics['auc_pr'] = auc(recall, precision)
        except Exception as e:
            self.logger.warning(f"Could not calculate AUC metrics: {e}")
            metrics['auc_roc'] = 0.0
            metrics['auc_pr'] = 0.0
        
        # Calculate optimal threshold using F1 score
        try:
            thresholds = np.linspace(0, 1, 100)
            f1_scores = []
            
            for t in thresholds:
                y_pred_t = (y_scores >= t).astype(int)
                f1 = f1_score(y_true, y_pred_t, zero_division=0)
                f1_scores.append(f1)
            
            optimal_idx = np.argmax(f1_scores)
            metrics['optimal_threshold'] = thresholds[optimal_idx]
            metrics['optimal_f1_score'] = f1_scores[optimal_idx]
        except Exception as e:
            self.logger.warning(f"Could not calculate optimal threshold: {e}")
            metrics['optimal_threshold'] = threshold
            metrics['optimal_f1_score'] = metrics['f1_score']
        
        return metrics
    
    def _calculate_domain_metrics(self, y_true, y_scores, domains, threshold):
        """
        Calculate metrics for each domain separately
        
        Args:
            y_true (array): True binary labels
            y_scores (array): Predicted scores
            domains (list): Domain for each pair
            threshold (float): Score threshold for binary classification
            
        Returns:
            dict: Domain-specific metrics
        """
        # Get unique domains
        unique_domains = list(set(domains))
        
        # Calculate metrics for each domain
        domain_metrics = {}
        
        for domain in unique_domains:
            # Get indices for this domain
            indices = [i for i, d in enumerate(domains) if d == domain]
            
            # Skip if too few samples
            if len(indices) < 10:
                continue
                
            # Calculate metrics for this domain
            domain_y_true = np.array([y_true[i] for i in indices])
            domain_y_scores = np.array([y_scores[i] for i in indices])
            
            domain_metrics[domain] = self._calculate_metrics(
                domain_y_true, domain_y_scores, threshold
            )
        
        return domain_metrics
    
    def _split_data(self, data, stratify_col=None):
        """
        Split data into training and testing sets
        
        Args:
            data (DataFrame): Data to split
            stratify_col (str, optional): Column to use for stratified sampling
            
        Returns:
            tuple: (train_data, test_data)
        """
        from sklearn.model_selection import train_test_split
        
        if stratify_col and stratify_col in data.columns:
            stratify = data[stratify_col]
        else:
            stratify = None
        
        train_data, test_data = train_test_split(
            data,
            test_size=self.test_size,
            random_state=self.random_state,
            stratify=stratify
        )
        
        return train_data, test_data
    
    def compare_matchers(self, test_data=None, s1_col=None, s2_col=None, 
                         label_col=None, domain_col=None, threshold=0.75):
        """
        Compare multiple matchers on the same test data
        
        Args:
            test_data (DataFrame): Test data with merchant pairs and labels
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            label_col (str): Column name for match label (1=match, 0=no match)
            domain_col (str): Column name for domain information
            threshold (float): Score threshold for binary classification
            
        Returns:
            dict: Comparison results for all matchers
        """
        # Use provided test data or split ground truth data
        if test_data is None:
            if self.ground_truth_data is None:
                self.logger.error("No test data or ground truth data provided")
                return None
            test_data = self._split_data(self.ground_truth_data)[1]
        
        # Use provided column names or defaults
        s1_col = s1_col or self.default_cols['s1_col']
        s2_col = s2_col or self.default_cols['s2_col']
        label_col = label_col or self.default_cols['label_col']
        domain_col = domain_col or self.default_cols['domain_col']
        
        # Reset results
        self.results = {}
        
        # Evaluate primary matcher if available
        if self.matcher:
            self.evaluate_matcher(
                self.matcher,
                test_data,
                s1_col,
                s2_col,
                label_col,
                domain_col,
                threshold,
                "Primary"
            )
        
        # Evaluate all baseline matchers
        for name, matcher in self.baseline_matchers.items():
            self.evaluate_matcher(
                matcher,
                test_data,
                s1_col,
                s2_col,
                label_col,
                domain_col,
                threshold,
                name
            )
        
        # Perform statistical significance testing
        if len(self.results) > 1 and "Primary" in self.results:
            self._perform_statistical_tests(test_data, s1_col, s2_col, label_col, domain_col)
        
        # Return results
        return self.results
    
    def _perform_statistical_tests(self, test_data, s1_col, s2_col, label_col, domain_col):
        """
        Perform statistical significance tests on matcher results
        
        Args:
            test_data (DataFrame): Test data
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            label_col (str): Column name for true label
            domain_col (str): Column name for domain
        """
        try:
            from scipy import stats
            import numpy as np
            
            # Reset statistical tests
            self.statistical_tests = {}
            
            # We need to compare predictions of each matcher on exactly the same examples
            # First, collect predictions from all matchers
            matcher_predictions = {}
            
            # Primary matcher (target for comparison)
            if "Primary" not in self.results:
                return
                
            primary_matcher = self.matcher
            primary_predictions = []
            ground_truth = []
            
            # Other matchers
            baseline_matchers = {}
            baseline_predictions = {}
            
            # Initialize for baseline matchers
            for name in self.results:
                if name != "Primary":
                    baseline_matchers[name] = self.baseline_matchers.get(name)
                    baseline_predictions[name] = []
            
            # Collect predictions for each example
            for i, row in test_data.iterrows():
                try:
                    s1 = row[s1_col]
                    s2 = row[s2_col]
                    
                    # Skip if missing data
                    if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                        continue
                    
                    # Get true label
                    true_label = int(row[label_col])
                    
                    # Get domain if available
                    domain = row[domain_col] if domain_col and domain_col in row else None
                    
                    # Get primary matcher prediction
                    if hasattr(primary_matcher, 'match_merchants'):
                        # Matcher is an object with match_merchants method
                        primary_score = primary_matcher.match_merchants(s1, s2, domain)
                    else:
                        # Matcher is a function
                        primary_score = primary_matcher(s1, s2, domain)
                    
                    # Skip if primary matcher fails
                    if primary_score is None:
                        continue
                    
                    # Get baseline matcher predictions
                    baseline_scores = {}
                    all_valid = True
                    
                    for name, matcher in baseline_matchers.items():
                        try:
                            if matcher is None:
                                continue
                                
                            if hasattr(matcher, 'match_merchants'):
                                score = matcher.match_merchants(s1, s2, domain)
                            else:
                                score = matcher(s1, s2, domain)
                                
                            if score is None:
                                all_valid = False
                                break
                                
                            baseline_scores[name] = score
                        except Exception:
                            all_valid = False
                            break
                    
                    # Only include examples where all matchers produced valid scores
                    if all_valid and baseline_scores:
                        ground_truth.append(true_label)
                        primary_predictions.append(primary_score)
                        
                        for name, score in baseline_scores.items():
                            baseline_predictions[name].append(score)
                    
                except Exception:
                    # Skip examples with errors
                    continue
            
            # Convert to numpy arrays
            ground_truth = np.array(ground_truth)
            primary_predictions = np.array(primary_predictions)
            
            for name in baseline_predictions:
                baseline_predictions[name] = np.array(baseline_predictions[name])
            
            # Perform McNemar's test for binary classification
            # First convert scores to binary predictions using optimal thresholds
            primary_threshold = self.results["Primary"].get("optimal_threshold", 0.75)
            primary_binary = (primary_predictions >= primary_threshold).astype(int)
            
            # For each baseline matcher
            for name, predictions in baseline_predictions.items():
                if len(predictions) != len(primary_binary):
                    continue
                    
                baseline_threshold = self.results[name].get("optimal_threshold", 0.75)
                baseline_binary = (predictions >= baseline_threshold).astype(int)
                
                # Create contingency table for McNemar's test
                # [both wrong, baseline right & primary wrong,
                #  baseline wrong & primary right, both right]
                contingency_table = [
                    sum((baseline_binary == 0) & (primary_binary == 0) & (ground_truth == 1) | 
                        (baseline_binary == 1) & (primary_binary == 1) & (ground_truth == 0)),
                    sum((baseline_binary == 1) & (primary_binary == 0) & (ground_truth == 1) | 
                        (baseline_binary == 0) & (primary_binary == 1) & (ground_truth == 0)),
                    sum((baseline_binary == 0) & (primary_binary == 1) & (ground_truth == 1) | 
                        (baseline_binary == 1) & (primary_binary == 0) & (ground_truth == 0)),
                    sum((baseline_binary == 1) & (primary_binary == 1) & (ground_truth == 1) | 
                        (baseline_binary == 0) & (primary_binary == 0) & (ground_truth == 0))
                ]
                
                # Reshape for statsmodels
                table = np.array([[contingency_table[0], contingency_table[1]],
                                  [contingency_table[2], contingency_table[3]]])
                
                # Perform McNemar's test
                try:
                    mcnemar_result = stats.mcnemar(table, exact=True)
                    
                    # Store results
                    self.statistical_tests[name] = {
                        'test_name': "McNemar's test",
                        'statistic': float(mcnemar_result.statistic),
                        'p_value': float(mcnemar_result.pvalue),
                        'significant': mcnemar_result.pvalue < 0.05,
                        'contingency_table': contingency_table
                    }
                except Exception as e:
                    self.logger.warning(f"McNemar's test failed for {name}: {e}")
                    
                # Also perform signed-rank test on score differences
                try:
                    # Calculate absolute errors
                    primary_errors = np.abs(primary_predictions - ground_truth)
                    baseline_errors = np.abs(predictions - ground_truth)
                    
                    # Perform Wilcoxon signed-rank test
                    wilcoxon_result = stats.wilcoxon(baseline_errors, primary_errors)
                    
                    # Store results
                    self.statistical_tests[f"{name}_wilcoxon"] = {
                        'test_name': "Wilcoxon signed-rank test",
                        'statistic': float(wilcoxon_result.statistic),
                        'p_value': float(wilcoxon_result.pvalue),
                        'significant': wilcoxon_result.pvalue < 0.05
                    }
                except Exception as e:
                    self.logger.warning(f"Wilcoxon test failed for {name}: {e}")
            
            self.logger.info(f"Completed statistical significance testing against {len(baseline_predictions)} baselines")
            
        except ImportError:
            self.logger.warning("SciPy not available for statistical testing")
        
    def perform_cross_validation(self, data=None, matcher=None, n_splits=5,
                                s1_col=None, s2_col=None, label_col=None, domain_col=None,
                                threshold=0.75, name="Primary", stratify=True):
        """
        Perform cross-validation for robust performance estimation
        
        Args:
            data (DataFrame): Data with merchant pairs and labels
            matcher: Matcher to evaluate
            n_splits (int): Number of cross-validation splits
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            label_col (str): Column name for match label
            domain_col (str): Column name for domain
            threshold (float): Score threshold for binary classification
            name (str): Name for the matcher in results
            stratify (bool): Whether to use stratified sampling
            
        Returns:
            dict: Cross-validation results
        """
        try:
            from sklearn.model_selection import StratifiedKFold, KFold
            
            # Use provided data or ground truth data
            if data is None:
                if self.ground_truth_data is None:
                    self.logger.error("No data or ground truth data provided")
                    return None
                data = self.ground_truth_data
            
            # Use provided matcher or primary matcher
            if matcher is None:
                matcher = self.matcher
            
            if matcher is None:
                self.logger.error("No matcher provided")
                return None
            
            # Use provided column names or defaults
            s1_col = s1_col or self.default_cols['s1_col']
            s2_col = s2_col or self.default_cols['s2_col']
            label_col = label_col or self.default_cols['label_col']
            domain_col = domain_col or self.default_cols['domain_col']
            
            # Check required columns
            if s1_col not in data.columns or s2_col not in data.columns:
                self.logger.error(f"Required columns missing: {s1_col}, {s2_col}")
                return None
            
            if label_col not in data.columns:
                self.logger.error(f"Label column missing: {label_col}")
                return None
            
            # Create cross-validation splits
            if stratify and label_col in data.columns:
                # Use stratified sampling to maintain class balance
                cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
                splits = list(cv.split(data, data[label_col]))
            else:
                # Use regular k-fold cross-validation
                cv = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
                splits = list(cv.split(data))
            
            # Initialize results
            cv_results = {
                'fold_metrics': [],
                'aggregate_metrics': {},
                'roc_curves': [],
                'pr_curves': []
            }
            
            # For each fold
            for fold, (train_idx, test_idx) in enumerate(splits):
                # Split data
                train_data = data.iloc[train_idx]
                test_data = data.iloc[test_idx]
                
                # Evaluate on this fold
                fold_metrics = self.evaluate_matcher(
                    matcher, 
                    test_data, 
                    s1_col, 
                    s2_col, 
                    label_col, 
                    domain_col, 
                    threshold, 
                    f"{name}_fold{fold+1}"
                )
                
                # Add fold number
                fold_metrics['fold'] = fold + 1
                
                # Add to results
                cv_results['fold_metrics'].append(fold_metrics)
                
                # Store ROC and PR curves
                if test_data is not None:
                    # Calculate predictions
                    y_true = []
                    y_scores = []
                    
                    for i, row in test_data.iterrows():
                        try:
                            s1 = row[s1_col]
                            s2 = row[s2_col]
                            
                            # Skip if missing data
                            if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                                continue
                            
                            # Get true label
                            true_label = int(row[label_col])
                            y_true.append(true_label)
                            
                            # Get domain if available
                            domain = row[domain_col] if domain_col and domain_col in row else None
                            
                            # Calculate similarity score
                            if hasattr(matcher, 'match_merchants'):
                                score = matcher.match_merchants(s1, s2, domain)
                            else:
                                score = matcher(s1, s2, domain)
                            
                            y_scores.append(score)
                            
                        except Exception:
                            # Skip examples with errors
                            continue
                    
                    # Convert to numpy arrays
                    y_true = np.array(y_true)
                    y_scores = np.array(y_scores)
                    
                    # Calculate ROC curve
                    try:
                        from sklearn.metrics import roc_curve, precision_recall_curve
                        
                        fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
                        cv_results['roc_curves'].append({
                            'fold': fold + 1,
                            'fpr': fpr.tolist(),
                            'tpr': tpr.tolist(),
                            'thresholds': roc_thresholds.tolist()
                        })
                        
                        # Calculate PR curve
                        precision, recall, pr_thresholds = precision_recall_curve(y_true, y_scores)
                        cv_results['pr_curves'].append({
                            'fold': fold + 1,
                            'precision': precision.tolist(),
                            'recall': recall.tolist(),
                            'thresholds': pr_thresholds.tolist() if len(pr_thresholds) > 0 else []
                        })
                    except Exception as e:
                        self.logger.warning(f"Error calculating curves for fold {fold+1}: {e}")
            
            # Calculate aggregate metrics across folds
            metrics_keys = [
                'accuracy', 'balanced_accuracy', 'precision', 'recall', 'f1_score',
                'matthews_corrcoef', 'auc_roc', 'auc_pr'
            ]
            
            for key in metrics_keys:
                values = [m.get(key, 0) for m in cv_results['fold_metrics'] if key in m]
                if values:
                    cv_results['aggregate_metrics'][key] = {
                        'mean': np.mean(values),
                        'std': np.std(values),
                        'min': np.min(values),
                        'max': np.max(values),
                        'values': values
                    }
            
            # Store in instance
            self.cross_validation_results[name] = cv_results
            
            # Log summary
            self.logger.info(
                f"Cross-validation complete for {name}: "
                f"Mean F1: {cv_results['aggregate_metrics'].get('f1_score', {}).get('mean', 0):.4f} ± "
                f"{cv_results['aggregate_metrics'].get('f1_score', {}).get('std', 0):.4f}"
            )
            
            return cv_results
            
        except ImportError:
            self.logger.warning("scikit-learn not available for cross-validation")
            return None
    
    def analyze_errors(self, test_data=None, matcher=None, 
                     s1_col=None, s2_col=None, label_col=None, domain_col=None,
                     threshold=0.75, name="Primary"):
        """
        Perform detailed error analysis to understand matcher weaknesses
        
        Args:
            test_data (DataFrame): Test data
            matcher: Matcher to analyze
            s1_col (str): Column name for first merchant name
            s2_col (str): Column name for second merchant name
            label_col (str): Column name for label
            domain_col (str): Column name for domain
            threshold (float): Score threshold for binary classification
            name (str): Name for the matcher in results
            
        Returns:
            dict: Error analysis results
        """
        # Use provided test data or split ground truth data
        if test_data is None:
            if self.ground_truth_data is None:
                self.logger.error("No test data or ground truth data provided")
                return None
            test_data = self._split_data(self.ground_truth_data)[1]
        
        # Use provided matcher or primary matcher
        if matcher is None:
            matcher = self.matcher
        
        if matcher is None:
            self.logger.error("No matcher provided")
            return None
        
        # Use provided column names or defaults
        s1_col = s1_col or self.default_cols['s1_col']
        s2_col = s2_col or self.default_cols['s2_col']
        label_col = label_col or self.default_cols['label_col']
        domain_col = domain_col or self.default_cols['domain_col']
        
        # Initialize error analysis
        analysis = {
            'false_positives': [],
            'false_negatives': [],
            'error_patterns': {},
            'summary': {},
            'error_counts': {
                'false_positives': 0,
                'false_negatives': 0
            }
        }
        
        # Collect predictions and analyze errors
        for i, row in test_data.iterrows():
            try:
                s1 = row[s1_col]
                s2 = row[s2_col]
                
                # Skip if missing data
                if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                    continue
                
                # Get true label
                true_label = int(row[label_col])
                
                # Get domain if available
                domain = row[domain_col] if domain_col and domain_col in row else None
                
                # Calculate similarity score
                if hasattr(matcher, 'match_merchants'):
                    score = matcher.match_merchants(s1, s2, domain)
                else:
                    score = matcher(s1, s2, domain)
                
                # Convert to binary prediction
                pred_label = 1 if score >= threshold else 0
                
                # Check for error
                if pred_label != true_label:
                    if pred_label == 1 and true_label == 0:
                        # False positive
                        analysis['false_positives'].append({
                            'id': i,
                            's1': s1,
                            's2': s2,
                            'domain': domain,
                            'score': score,
                            'threshold': threshold
                        })
                        analysis['error_counts']['false_positives'] += 1
                    elif pred_label == 0 and true_label == 1:
                        # False negative
                        analysis['false_negatives'].append({
                            'id': i,
                            's1': s1,
                            's2': s2,
                            'domain': domain,
                            'score': score,
                            'threshold': threshold
                        })
                        analysis['error_counts']['false_negatives'] += 1
            
            except Exception:
                # Skip examples with errors
                continue
        
        # Analyze error patterns
        analysis['error_patterns'] = self._identify_error_patterns(
            analysis['false_positives'], 
            analysis['false_negatives']
        )
        
        # Create summary
        analysis['summary'] = {
            'total_errors': analysis['error_counts']['false_positives'] + analysis['error_counts']['false_negatives'],
            'false_positive_rate': analysis['error_counts']['false_positives'] / len(test_data) if len(test_data) > 0 else 0,
            'false_negative_rate': analysis['error_counts']['false_negatives'] / len(test_data) if len(test_data) > 0 else 0,
            'error_rate': (analysis['error_counts']['false_positives'] + analysis['error_counts']['false_negatives']) / len(test_data) if len(test_data) > 0 else 0,
            'top_error_patterns': sorted(
                analysis['error_patterns'].items(), 
                key=lambda x: x[1]['count'], 
                reverse=True
            )[:5] if analysis['error_patterns'] else []
        }
        
        # Store in instance
        self.error_analysis[name] = analysis
        
        # Log summary
        self.logger.info(
            f"Error analysis complete for {name}: "
            f"False positives: {analysis['error_counts']['false_positives']}, "
            f"False negatives: {analysis['error_counts']['false_negatives']}"
        )
        
        return analysis
    
    def _identify_error_patterns(self, false_positives, false_negatives):
        """
        Identify common patterns in errors
        
        Args:
            false_positives (list): List of false positive examples
            false_negatives (list): List of false negative examples
            
        Returns:
            dict: Error patterns with explanations
        """
        patterns = {}
        
        # Helper function to detect patterns
        def check_patterns(errors, error_type):
            for error in errors:
                s1 = error['s1'].lower()
                s2 = error['s2'].lower()
                
                # Check for length difference
                len_ratio = min(len(s1), len(s2)) / max(len(s1), len(s2)) if max(len(s1), len(s2)) > 0 else 0
                if len_ratio < 0.5:
                    pattern = "large_length_difference"
                    if pattern not in patterns:
                        patterns[pattern] = {
                            'name': "Large length difference",
                            'description': "One name is much shorter than the other",
                            'count': 0,
                            'examples': [],
                            'error_types': set()
                        }
                    patterns[pattern]['count'] += 1
                    patterns[pattern]['error_types'].add(error_type)
                    if len(patterns[pattern]['examples']) < 3:
                        patterns[pattern]['examples'].append((s1, s2))
                
                # Check for containment
                if s1 in s2 or s2 in s1:
                    pattern = "name_containment"
                    if pattern not in patterns:
                        patterns[pattern] = {
                            'name': "Name containment",
                            'description': "One name is contained within the other",
                            'count': 0,
                            'examples': [],
                            'error_types': set()
                        }
                    patterns[pattern]['count'] += 1
                    patterns[pattern]['error_types'].add(error_type)
                    if len(patterns[pattern]['examples']) < 3:
                        patterns[pattern]['examples'].append((s1, s2))
                
                # Check for acronyms
                if len(s1.split()) == 1 and len(s1) <= 5:
                    # s1 might be acronym
                    s2_words = s2.split()
                    if len(s2_words) >= 2:
                        s2_initials = ''.join([w[0] for w in s2_words if w])
                        if s1 in s2_initials or s2_initials in s1:
                            pattern = "acronym_confusion"
                            if pattern not in patterns:
                                patterns[pattern] = {
                                    'name': "Acronym confusion",
                                    'description': "One name is an acronym or abbreviation of the other",
                                    'count': 0,
                                    'examples': [],
                                    'error_types': set()
                                }
                            patterns[pattern]['count'] += 1
                            patterns[pattern]['error_types'].add(error_type)
                            if len(patterns[pattern]['examples']) < 3:
                                patterns[pattern]['examples'].append((s1, s2))
                
                # Check for word reordering
                s1_words = set(s1.split())
                s2_words = set(s2.split())
                if s1_words == s2_words and s1 != s2:
                    pattern = "word_reordering"
                    if pattern not in patterns:
                        patterns[pattern] = {
                            'name': "Word reordering",
                            'description': "Names have the same words but in different order",
                            'count': 0,
                            'examples': [],
                            'error_types': set()
                        }
                    patterns[pattern]['count'] += 1
                    patterns[pattern]['error_types'].add(error_type)
                    if len(patterns[pattern]['examples']) < 3:
                        patterns[pattern]['examples'].append((s1, s2))
                
                # Check for partial word matches
                common_words = s1_words.intersection(s2_words)
                if common_words and len(common_words) / max(len(s1_words), len(s2_words)) > 0.5:
                    pattern = "partial_word_match"
                    if pattern not in patterns:
                        patterns[pattern] = {
                            'name': "Partial word match",
                            'description': "Names share many common words",
                            'count': 0,
                            'examples': [],
                            'error_types': set()
                        }
                    patterns[pattern]['count'] += 1
                    patterns[pattern]['error_types'].add(error_type)
                    if len(patterns[pattern]['examples']) < 3:
                        patterns[pattern]['examples'].append((s1, s2))
        
        # Check both error types
        check_patterns(false_positives, "false_positive")
        check_patterns(false_negatives, "false_negative")
        
        return patterns

In [51]:
# 11.2: Visualization and Reporting

class MerchantMatchingVisualizer:
    """
    Comprehensive visualization toolkit for merchant matching evaluation results,
    providing intuitive graphical displays of performance metrics, comparisons,
    and error analysis.
    
    Key features:
    - Performance metric visualizations (ROC curves, PR curves, confusion matrices)
    - Multi-algorithm comparison charts
    - Cross-validation result visualization
    - Domain-specific performance analysis
    - Error pattern visualization
    - Interactive HTML report generation
    """
    
    def __init__(self, evaluator=None, output_dir=None, interactive=True):
        """
        Initialize visualizer with evaluator and output settings
        
        Args:
            evaluator (MerchantMatchingEvaluator): Evaluator with results
            output_dir (str): Directory for saving visualizations
            interactive (bool): Whether to create interactive visualizations
        """
        self.evaluator = evaluator
        self.output_dir = output_dir or os.path.join(os.getcwd(), 'merchant_matching_results')
        self.interactive = interactive
        
        # Create output directory if it doesn't exist
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        
        # Set up logging
        self.logger = logging.getLogger('MerchantVisualizer')
        self.logger.setLevel(logging.INFO)
        
        # Check if matplotlib is available
        try:
            import matplotlib.pyplot as plt
            self.matplotlib_available = True
        except ImportError:
            self.logger.warning("Matplotlib not available. Visualizations limited.")
            self.matplotlib_available = False
        
        # Check if plotly is available for interactive visualizations
        if self.interactive:
            try:
                import plotly.graph_objects as go
                self.plotly_available = True
            except ImportError:
                self.logger.warning("Plotly not available. Interactive visualizations disabled.")
                self.plotly_available = False
                self.interactive = False
        else:
            self.plotly_available = False
    
    def plot_roc_curves(self, save=True, filename='roc_curves.png', figsize=(10, 8)):
        """
        Plot ROC curves for all evaluated matchers
        
        Args:
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if not self.matplotlib_available or not self.evaluator or not self.evaluator.results:
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            from sklearn.metrics import roc_curve, auc
            
            # Create figure
            plt.figure(figsize=figsize)
            
            # Sort matchers by AUC (descending)
            sorted_matchers = sorted(
                self.evaluator.results.items(),
                key=lambda x: x[1].get('auc_roc', 0),
                reverse=True
            )
            
            # Plot ROC curve for each matcher
            for name, metrics in sorted_matchers:
                if 'auc_roc' not in metrics:
                    continue
                
                # Check if we have test data available from the evaluator
                test_data = getattr(self.evaluator, 'last_test_data', None)
                if test_data is not None:
                    # Extract column names
                    s1_col = self.evaluator.default_cols['s1_col']
                    s2_col = self.evaluator.default_cols['s2_col']
                    label_col = self.evaluator.default_cols['label_col']
                    domain_col = self.evaluator.default_cols['domain_col']
                    
                    # Get matcher
                    if name == "Primary":
                        matcher = self.evaluator.matcher
                    else:
                        matcher = self.evaluator.baseline_matchers.get(name)
                    
                    if matcher is None:
                        continue
                    
                    # Calculate predictions
                    y_true = []
                    y_scores = []
                    
                    for i, row in test_data.iterrows():
                        try:
                            s1 = row[s1_col]
                            s2 = row[s2_col]
                            
                            # Skip if missing data
                            if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                                continue
                            
                            # Get true label
                            true_label = int(row[label_col])
                            y_true.append(true_label)
                            
                            # Get domain if available
                            domain = row[domain_col] if domain_col and domain_col in row else None
                            
                            # Calculate similarity score
                            if hasattr(matcher, 'match_merchants'):
                                score = matcher.match_merchants(s1, s2, domain)
                            else:
                                score = matcher(s1, s2, domain)
                            
                            y_scores.append(score)
                            
                        except Exception:
                            # Skip examples with errors
                            continue
                    
                    # Convert to numpy arrays
                    y_true = np.array(y_true)
                    y_scores = np.array(y_scores)
                    
                    # Calculate ROC curve
                    fpr, tpr, _ = roc_curve(y_true, y_scores)
                    roc_auc = auc(fpr, tpr)
                    
                    # Plot ROC curve
                    plt.plot(
                        fpr, tpr, 
                        label=f"{name} (AUC = {roc_auc:.3f})",
                        linewidth=2
                    )
                else:
                    # Use AUC from metrics
                    plt.plot(
                        [0, 1], [0, 1], 
                        label=f"{name} (AUC = {metrics['auc_roc']:.3f})",
                        linewidth=2
                    )
            
            # Add diagonal line
            plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
            
            # Add labels and legend
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiver Operating Characteristic (ROC) Curves')
            plt.legend(loc='lower right')
            plt.grid(alpha=0.3)
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"ROC curves saved to {save_path}")
            
            return plt.gcf()
            
        except Exception as e:
            self.logger.error(f"Error plotting ROC curves: {e}")
            return None
    
    def plot_precision_recall_curves(self, save=True, filename='pr_curves.png', figsize=(10, 8)):
        """
        Plot precision-recall curves for all evaluated matchers
        
        Args:
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if not self.matplotlib_available or not self.evaluator or not self.evaluator.results:
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            from sklearn.metrics import precision_recall_curve, auc
            
            # Create figure
            plt.figure(figsize=figsize)
            
            # Sort matchers by PR AUC (descending)
            sorted_matchers = sorted(
                self.evaluator.results.items(),
                key=lambda x: x[1].get('auc_pr', 0),
                reverse=True
            )
            
            # Plot PR curve for each matcher
            for name, metrics in sorted_matchers:
                if 'auc_pr' not in metrics:
                    continue
                
                # Check if we have test data available from the evaluator
                test_data = getattr(self.evaluator, 'last_test_data', None)
                if test_data is not None:
                    # Extract column names
                    s1_col = self.evaluator.default_cols['s1_col']
                    s2_col = self.evaluator.default_cols['s2_col']
                    label_col = self.evaluator.default_cols['label_col']
                    domain_col = self.evaluator.default_cols['domain_col']
                    
                    # Get matcher
                    if name == "Primary":
                        matcher = self.evaluator.matcher
                    else:
                        matcher = self.evaluator.baseline_matchers.get(name)
                    
                    if matcher is None:
                        continue
                    
                    # Calculate predictions
                    y_true = []
                    y_scores = []
                    
                    for i, row in test_data.iterrows():
                        try:
                            s1 = row[s1_col]
                            s2 = row[s2_col]
                            
                            # Skip if missing data
                            if pd.isna(s1) or pd.isna(s2) or not isinstance(s1, str) or not isinstance(s2, str):
                                continue
                            
                            # Get true label
                            true_label = int(row[label_col])
                            y_true.append(true_label)
                            
                            # Get domain if available
                            domain = row[domain_col] if domain_col and domain_col in row else None
                            
                            # Calculate similarity score
                            if hasattr(matcher, 'match_merchants'):
                                score = matcher.match_merchants(s1, s2, domain)
                            else:
                                score = matcher(s1, s2, domain)
                            
                            y_scores.append(score)
                            
                        except Exception:
                            # Skip examples with errors
                            continue
                    
                    # Convert to numpy arrays
                    y_true = np.array(y_true)
                    y_scores = np.array(y_scores)
                    
                    # Calculate PR curve
                    precision, recall, _ = precision_recall_curve(y_true, y_scores)
                    pr_auc = auc(recall, precision)
                    
                    # Plot PR curve
                    plt.plot(
                        recall, precision, 
                        label=f"{name} (AUC = {pr_auc:.3f})",
                        linewidth=2
                    )
                else:
                    # Use AUC from metrics
                    plt.plot(
                        [0, 1], [1, 0], 
                        label=f"{name} (AUC = {metrics['auc_pr']:.3f})",
                        linewidth=2
                    )
            
            # Add labels and legend
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title('Precision-Recall Curves')
            plt.legend(loc='lower left')
            plt.grid(alpha=0.3)
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Precision-recall curves saved to {save_path}")
            
            return plt.gcf()
            
        except Exception as e:
            self.logger.error(f"Error plotting precision-recall curves: {e}")
            return None
    
    def plot_metric_comparison(self, metrics=None, save=True, filename='metric_comparison.png', figsize=(12, 8)):
        """
        Plot comparison of multiple metrics across all matchers
        
        Args:
            metrics (list): List of metrics to compare
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if not self.matplotlib_available or not self.evaluator or not self.evaluator.results:
            return None
        
        # Default metrics to compare
        if metrics is None:
            metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc']
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Create figure
            plt.figure(figsize=figsize)
            
            # Get matcher names and sort by f1_score
            matcher_names = sorted(
                self.evaluator.results.keys(),
                key=lambda x: self.evaluator.results[x].get('f1_score', 0),
                reverse=True
            )
            
            # Set up bar positions
            x = np.arange(len(matcher_names))
            width = 0.15  # width of bars
            
            # Create a bar for each metric
            for i, metric in enumerate(metrics):
                values = [self.evaluator.results[name].get(metric, 0) for name in matcher_names]
                plt.bar(x + (i - len(metrics)/2 + 0.5) * width, values, width, label=metric.replace('_', ' ').title())
            
            # Add labels and legend
            plt.xlabel('Matcher')
            plt.ylabel('Score')
            plt.title('Performance Metric Comparison')
            plt.xticks(x, matcher_names, rotation=45, ha='right')
            plt.legend(loc='lower right')
            plt.grid(axis='y', alpha=0.3)
            plt.tight_layout()
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Metric comparison saved to {save_path}")
            
            return plt.gcf()
            
        except Exception as e:
            self.logger.error(f"Error plotting metric comparison: {e}")
            return None
    
    def plot_confusion_matrices(self, save=True, filename_prefix='confusion_matrix_', figsize=(8, 6)):
        """
        Plot confusion matrices for all evaluated matchers
        
        Args:
            save (bool): Whether to save the plots to files
            filename_prefix (str): Prefix for saved plot filenames
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            dict: Dictionary of matcher name to figure object
        """
        if not self.matplotlib_available or not self.evaluator or not self.evaluator.results:
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            figures = {}
            
            # For each matcher
            for name, metrics in self.evaluator.results.items():
                # Check if confusion matrix components are available
                required_keys = ['true_positives', 'false_positives', 'true_negatives', 'false_negatives']
                if not all(key in metrics for key in required_keys):
                    continue
                
                # Create confusion matrix
                cm = np.array([
                    [metrics['true_negatives'], metrics['false_positives']],
                    [metrics['false_negatives'], metrics['true_positives']]
                ])
                
                # Create figure
                plt.figure(figsize=figsize)
                
                # Plot confusion matrix
                plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
                plt.title(f'Confusion Matrix: {name}')
                plt.colorbar()
                
                # Add labels
                classes = ['Non-match', 'Match']
                tick_marks = np.arange(len(classes))
                plt.xticks(tick_marks, classes)
                plt.yticks(tick_marks, classes)
                
                # Add text annotations
                thresh = cm.max() / 2.0
                for i in range(cm.shape[0]):
                    for j in range(cm.shape[1]):
                        plt.text(j, i, format(cm[i, j], 'd'),
                                horizontalalignment="center",
                                color="white" if cm[i, j] > thresh else "black")
                
                plt.ylabel('True Label')
                plt.xlabel('Predicted Label')
                plt.tight_layout()
                
                # Save if requested
                if save:
                    filename = f"{filename_prefix}{name.lower().replace(' ', '_')}.png"
                    save_path = os.path.join(self.output_dir, filename)
                    plt.savefig(save_path, dpi=300, bbox_inches='tight')
                    self.logger.info(f"Confusion matrix for {name} saved to {save_path}")
                
                figures[name] = plt.gcf()
            
            return figures
            
        except Exception as e:
            self.logger.error(f"Error plotting confusion matrices: {e}")
            return None
    
    def plot_cross_validation_results(self, matcher_name="Primary", save=True, 
                                   filename='cross_validation_results.png', figsize=(10, 8)):
        """
        Plot cross-validation results for a matcher
        
        Args:
            matcher_name (str): Name of the matcher to visualize
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if (not self.matplotlib_available or not self.evaluator or 
            not hasattr(self.evaluator, 'cross_validation_results') or 
            matcher_name not in self.evaluator.cross_validation_results):
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Get cross-validation results
            cv_results = self.evaluator.cross_validation_results[matcher_name]
            
            # Create figure
            fig, axs = plt.subplots(1, 2, figsize=figsize)
            
            # Plot metric values across folds
            metrics_to_plot = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc']
            
            # Extract fold metrics
            fold_metrics = cv_results['fold_metrics']
            folds = range(1, len(fold_metrics) + 1)
            
            # Plot metrics
            for metric in metrics_to_plot:
                values = [fold.get(metric, 0) for fold in fold_metrics]
                axs[0].plot(folds, values, 'o-', label=metric.replace('_', ' ').title())
            
            axs[0].set_xlabel('Fold')
            axs[0].set_ylabel('Score')
            axs[0].set_title('Metrics Across Folds')
            axs[0].set_xticks(folds)
            axs[0].legend()
            axs[0].grid(alpha=0.3)
            
            # Plot aggregate metrics with error bars
            aggregate_metrics = cv_results['aggregate_metrics']
            metrics = list(aggregate_metrics.keys())
            values = [aggregate_metrics[metric]['mean'] for metric in metrics]
            errors = [aggregate_metrics[metric]['std'] for metric in metrics]
            
            # Format metric names for display
            formatted_metrics = [metric.replace('_', ' ').title() for metric in metrics]
            
            axs[1].barh(formatted_metrics, values, xerr=errors, alpha=0.7)
            
            for i, value in enumerate(values):
                axs[1].text(value, i, f" {value:.3f} ± {errors[i]:.3f}", va='center')
            
            axs[1].set_title('Aggregate Metrics (Mean ± Std)')
            axs[1].set_xlim(0, 1.0)
            axs[1].grid(axis='x', alpha=0.3)
            
            plt.tight_layout()
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Cross-validation results saved to {save_path}")
            
            return fig
            
        except Exception as e:
            self.logger.error(f"Error plotting cross-validation results: {e}")
            return None
    
    def plot_domain_performance(self, matcher_name="Primary", save=True,
                              filename='domain_performance.png', figsize=(12, 8)):
        """
        Plot performance across different domains for a matcher
        
        Args:
            matcher_name (str): Name of the matcher to visualize
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if not self.matplotlib_available or not self.evaluator or not self.evaluator.results:
            return None
        
        # Check if domain metrics are available
        if (matcher_name not in self.evaluator.results or 
            'domain_metrics' not in self.evaluator.results[matcher_name]):
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Get domain metrics
            domain_metrics = self.evaluator.results[matcher_name]['domain_metrics']
            
            # Skip if empty
            if not domain_metrics:
                return None
            
            # Create figure
            plt.figure(figsize=figsize)
            
            # Get domains and metrics
            domains = list(domain_metrics.keys())
            metrics = ['accuracy', 'precision', 'recall', 'f1_score']
            
            # Set up positions
            x = np.arange(len(domains))
            width = 0.2  # width of bars
            
            # Plot each metric
            for i, metric in enumerate(metrics):
                values = [domain_metrics[domain].get(metric, 0) for domain in domains]
                plt.bar(x + (i - len(metrics)/2 + 0.5) * width, values, width, label=metric.replace('_', ' ').title())
            
            # Add labels and legend
            plt.xlabel('Domain')
            plt.ylabel('Score')
            plt.title(f'Performance by Domain: {matcher_name}')
            plt.xticks(x, domains, rotation=45, ha='right')
            plt.legend(loc='lower right')
            plt.grid(axis='y', alpha=0.3)
            plt.tight_layout()
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Domain performance saved to {save_path}")
            
            return plt.gcf()
            
        except Exception as e:
            self.logger.error(f"Error plotting domain performance: {e}")
            return None
    
    def plot_error_analysis(self, matcher_name="Primary", save=True,
                          filename='error_analysis.png', figsize=(10, 8)):
        """
        Plot error analysis results for a matcher
        
        Args:
            matcher_name (str): Name of the matcher to visualize
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if (not self.matplotlib_available or not self.evaluator or 
            not hasattr(self.evaluator, 'error_analysis') or
            matcher_name not in self.evaluator.error_analysis):
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Get error analysis results
            error_analysis = self.evaluator.error_analysis[matcher_name]
            
            # Create figure
            fig, axs = plt.subplots(1, 2, figsize=figsize)
            
            # Plot error counts
            error_types = ['false_positives', 'false_negatives']
            error_counts = [error_analysis['error_counts'][et] for et in error_types]
            
            axs[0].bar(error_types, error_counts, color=['#ff9999', '#99ccff'])
            
            for i, count in enumerate(error_counts):
                axs[0].text(i, count, str(count), ha='center', va='bottom')
            
            axs[0].set_ylabel('Count')
            axs[0].set_title('Error Counts')
            axs[0].set_xticklabels([et.replace('_', ' ').title() for et in error_types])
            
            # Plot error patterns
            patterns = error_analysis['error_patterns']
            
            if patterns:
                # Get top patterns by count
                top_patterns = sorted(
                    patterns.items(),
                    key=lambda x: x[1]['count'],
                    reverse=True
                )[:5]
                
                pattern_names = [p[1]['name'] for p in top_patterns]
                pattern_counts = [p[1]['count'] for p in top_patterns]
                
                # Plot horizontal bar chart
                axs[1].barh(pattern_names, pattern_counts, color='#99cc99')
                
                for i, count in enumerate(pattern_counts):
                    axs[1].text(count, i, str(count), va='center')
                
                axs[1].set_xlabel('Count')
                axs[1].set_title('Top Error Patterns')
            else:
                axs[1].text(0.5, 0.5, "No error patterns identified", ha='center', va='center')
                axs[1].set_title('Error Patterns')
            
            plt.tight_layout()
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Error analysis saved to {save_path}")
            
            return fig
            
        except Exception as e:
            self.logger.error(f"Error plotting error analysis: {e}")
            return None
    
    def plot_significance_tests(self, save=True, filename='significance_tests.png', figsize=(10, 6)):
        """
        Plot statistical significance test results
        
        Args:
            save (bool): Whether to save the plot to file
            filename (str): Filename for saved plot
            figsize (tuple): Figure size (width, height) in inches
            
        Returns:
            matplotlib.figure.Figure or None: Figure object or None if plotting failed
        """
        if (not self.matplotlib_available or not self.evaluator or 
            not hasattr(self.evaluator, 'statistical_tests') or
            not self.evaluator.statistical_tests):
            return None
        
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Get significance test results
            tests = self.evaluator.statistical_tests
            
            # Filter to just the McNemar tests (not the Wilcoxon tests)
            mcnemar_tests = {k: v for k, v in tests.items() if 'wilcoxon' not in k}
            
            if not mcnemar_tests:
                return None
            
            # Create figure
            plt.figure(figsize=figsize)
            
            # Get baseline names and p-values
            baselines = list(mcnemar_tests.keys())
            p_values = [mcnemar_tests[b]['p_value'] for b in baselines]
            
            # Sort by p-value
            sorted_indices = np.argsort(p_values)
            baselines = [baselines[i] for i in sorted_indices]
            p_values = [p_values[i] for i in sorted_indices]
            
            # Plot bar chart of p-values
            bars = plt.bar(baselines, p_values, color=['green' if p < 0.05 else 'red' for p in p_values])
            
            # Add significance threshold line
            plt.axhline(y=0.05, color='black', linestyle='--', alpha=0.7, label='Significance threshold (p=0.05)')
            
            # Add labels
            for i, (bar, p) in enumerate(zip(bars, p_values)):
                plt.text(
                    bar.get_x() + bar.get_width()/2,
                    0.01,
                    f"p={p:.4f}",
                    ha='center',
                    rotation=90,
                    color='white' if p < 0.05 else 'black'
                )
                
                # Add significance indicator
                if p < 0.05:
                    plt.text(
                        bar.get_x() + bar.get_width()/2,
                        p + 0.02,
                        '*',
                        ha='center',
                        fontsize=16
                    )
            
            plt.xlabel('Baseline Matcher')
            plt.ylabel('p-value')
            plt.title("Statistical Significance Tests (McNemar's Test)")
            plt.xticks(rotation=45, ha='right')
            plt.legend()
            plt.tight_layout()
            
            # Save if requested
            if save:
                save_path = os.path.join(self.output_dir, filename)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Significance tests saved to {save_path}")
            
            return plt.gcf()
            
        except Exception as e:
            self.logger.error(f"Error plotting significance tests: {e}")
            return None
    
    def generate_html_report(self, title="Merchant Matching Evaluation Report", 
                            filename='merchant_matching_report.html'):
        """
        Generate a comprehensive HTML report with all visualizations and results
        
        Args:
            title (str): Report title
            filename (str): Filename for saved report
            
        Returns:
            str: Path to the generated HTML report
        """
        try:
            # Save all plots first
            self.plot_roc_curves()
            self.plot_precision_recall_curves()
            self.plot_metric_comparison()
            self.plot_confusion_matrices()
            self.plot_cross_validation_results()
            self.plot_domain_performance()
            self.plot_error_analysis()
            self.plot_significance_tests()
            
            # Generate HTML content
            html_content = f"""
            <!DOCTYPE html>
            <html>
            <head>
                <title>{title}</title>
                <style>
                    body {{
                        font-family: Arial, sans-serif;
                        margin: 20px;
                        line-height: 1.6;
                    }}
                    h1, h2, h3 {{
                        color: #2c3e50;
                    }}
                    .section {{
                        margin-bottom: 30px;
                        border-bottom: 1px solid #eee;
                        padding-bottom: 20px;
                    }}
                    table {{
                        border-collapse: collapse;
                        width: 100%;
                    }}
                    th, td {{
                        text-align: left;
                        padding: 8px;
                        border-bottom: 1px solid #ddd;
                    }}
                    th {{
                        background-color: #f2f2f2;
                    }}
                    tr:hover {{
                        background-color: #f5f5f5;
                    }}
                    .figure {{
                        margin: 20px 0;
                        text-align: center;
                    }}
                    .figure img {{
                        max-width: 100%;
                        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                    }}
                    .caption {{
                        font-style: italic;
                        color: #666;
                        margin-top: 10px;
                    }}
                    .metric-good {{
                        color: green;
                        font-weight: bold;
                    }}
                    .metric-medium {{
                        color: orange;
                    }}
                    .metric-poor {{
                        color: red;
                    }}
                </style>
            </head>
            <body>
                <h1>{title}</h1>
                <p>Generated on {time.strftime("%Y-%m-%d %H:%M:%S")}</p>
                
                <div class="section">
                    <h2>Overview</h2>
                    <p>This report presents the evaluation results of various merchant name matching algorithms.</p>
            """
            
            # Add summary table
            if self.evaluator and self.evaluator.results:
                html_content += """
                    <h3>Performance Summary</h3>
                    <table>
                        <tr>
                            <th>Matcher</th>
                            <th>Accuracy</th>
                            <th>Precision</th>
                            <th>Recall</th>
                            <th>F1 Score</th>
                            <th>AUC-ROC</th>
                        </tr>
                """
                
                # Sort by F1 score
                sorted_matchers = sorted(
                    self.evaluator.results.items(),
                    key=lambda x: x[1].get('f1_score', 0),
                    reverse=True
                )
                
                for name, metrics in sorted_matchers:
                    html_content += f"""
                        <tr>
                            <td>{name}</td>
                            <td class="{'metric-good' if metrics.get('accuracy', 0) > 0.8 else 'metric-medium' if metrics.get('accuracy', 0) > 0.6 else 'metric-poor'}">{metrics.get('accuracy', 0):.4f}</td>
                            <td class="{'metric-good' if metrics.get('precision', 0) > 0.8 else 'metric-medium' if metrics.get('precision', 0) > 0.6 else 'metric-poor'}">{metrics.get('precision', 0):.4f}</td>
                            <td class="{'metric-good' if metrics.get('recall', 0) > 0.8 else 'metric-medium' if metrics.get('recall', 0) > 0.6 else 'metric-poor'}">{metrics.get('recall', 0):.4f}</td>
                            <td class="{'metric-good' if metrics.get('f1_score', 0) > 0.8 else 'metric-medium' if metrics.get('f1_score', 0) > 0.6 else 'metric-poor'}">{metrics.get('f1_score', 0):.4f}</td>
                            <td class="{'metric-good' if metrics.get('auc_roc', 0) > 0.8 else 'metric-medium' if metrics.get('auc_roc', 0) > 0.6 else 'metric-poor'}">{metrics.get('auc_roc', 0):.4f}</td>
                        </tr>
                    """
                
                html_content += """
                    </table>
                """
            
            # Add ROC curves
            html_content += """
                </div>
                
                <div class="section">
                    <h2>ROC Curves</h2>
                    <div class="figure">
                        <img src="roc_curves.png" alt="ROC Curves">
                        <p class="caption">Receiver Operating Characteristic (ROC) curves for all matchers.</p>
                    </div>
                </div>
                
                <div class="section">
                    <h2>Precision-Recall Curves</h2>
                    <div class="figure">
                        <img src="pr_curves.png" alt="Precision-Recall Curves">
                        <p class="caption">Precision-Recall curves for all matchers.</p>
                    </div>
                </div>
                
                <div class="section">
                    <h2>Metric Comparison</h2>
                    <div class="figure">
                        <img src="metric_comparison.png" alt="Metric Comparison">
                        <p class="caption">Comparison of performance metrics across all matchers.</p>
                    </div>
                </div>
            """
            
            # Add cross-validation results if available
            if hasattr(self.evaluator, 'cross_validation_results') and self.evaluator.cross_validation_results:
                html_content += """
                <div class="section">
                    <h2>Cross-Validation Results</h2>
                    <div class="figure">
                        <img src="cross_validation_results.png" alt="Cross-Validation Results">
                        <p class="caption">Cross-validation results showing performance stability across folds.</p>
                    </div>
                </div>
                """
            
            # Add domain performance if available
            if (self.evaluator and self.evaluator.results and 
                any('domain_metrics' in metrics for metrics in self.evaluator.results.values())):
                html_content += """
                <div class="section">
                    <h2>Domain-Specific Performance</h2>
                    <div class="figure">
                        <img src="domain_performance.png" alt="Domain Performance">
                        <p class="caption">Performance metrics across different domains.</p>
                    </div>
                </div>
                """
            
            # Add error analysis if available
            if hasattr(self.evaluator, 'error_analysis') and self.evaluator.error_analysis:
                html_content += """
                <div class="section">
                    <h2>Error Analysis</h2>
                    <div class="figure">
                        <img src="error_analysis.png" alt="Error Analysis">
                        <p class="caption">Analysis of error types and patterns.</p>
                    </div>
                """
                
                # Add example errors
                if "Primary" in self.evaluator.error_analysis:
                    error_analysis = self.evaluator.error_analysis["Primary"]
                    
                    # Add false positives
                    if error_analysis['false_positives']:
                        html_content += """
                        <h3>Example False Positives</h3>
                        <table>
                            <tr>
                                <th>First Name</th>
                                <th>Second Name</th>
                                <th>Score</th>
                                <th>Domain</th>
                            </tr>
                        """
                        
                        for fp in error_analysis['false_positives'][:5]:  # Show top 5
                            html_content += f"""
                            <tr>
                                <td>{fp['s1']}</td>
                                <td>{fp['s2']}</td>
                                <td>{fp['score']:.4f}</td>
                                <td>{fp.get('domain', 'N/A')}</td>
                            </tr>
                            """
                        
                        html_content += """
                        </table>
                        """
                    
                    # Add false negatives
                    if error_analysis['false_negatives']:
                        html_content += """
                        <h3>Example False Negatives</h3>
                        <table>
                            <tr>
                                <th>First Name</th>
                                <th>Second Name</th>
                                <th>Score</th>
                                <th>Domain</th>
                            </tr>
                        """
                        
                        for fn in error_analysis['false_negatives'][:5]:  # Show top 5
                            html_content += f"""
                            <tr>
                                <td>{fn['s1']}</td>
                                <td>{fn['s2']}</td>
                                <td>{fn['score']:.4f}</td>
                                <td>{fn.get('domain', 'N/A')}</td>
                            </tr>
                            """
                        
                        html_content += """
                        </table>
                        """
                
                html_content += """
                </div>
                """
            
            # Add significance tests if available
            if hasattr(self.evaluator, 'statistical_tests') and self.evaluator.statistical_tests:
                html_content += """
                <div class="section">
                    <h2>Statistical Significance</h2>
                    <div class="figure">
                        <img src="significance_tests.png" alt="Significance Tests">
                        <p class="caption">Statistical significance of performance improvements.</p>
                    </div>
                """
                
                # Add significance test results table
                html_content += """
                    <h3>Significance Test Results</h3>
                    <table>
                        <tr>
                            <th>Baseline</th>
                            <th>Test</th>
                            <th>p-value</th>
                            <th>Significant</th>
                        </tr>
                """
                
                for name, test in self.evaluator.statistical_tests.items():
                    if 'wilcoxon' in name:  # Skip Wilcoxon tests for simplicity
                        continue
                        
                    html_content += f"""
                    <tr>
                        <td>{name}</td>
                        <td>{test['test_name']}</td>
                        <td>{test['p_value']:.4f}</td>
                        <td class="{'metric-good' if test['significant'] else 'metric-poor'}">{test['significant']}</td>
                    </tr>
                    """
                
                html_content += """
                    </table>
                </div>
                """
            
            # Close HTML
            html_content += """
                <div class="section">
                    <h2>Conclusion</h2>
                    <p>This evaluation demonstrates the performance characteristics of various merchant name matching algorithms.</p>
            """
            
            # Add overall recommendation
            if self.evaluator and self.evaluator.results and "Primary" in self.evaluator.results:
                # Check if Primary is the best
                is_best = True
                primary_f1 = self.evaluator.results["Primary"].get('f1_score', 0)
                
                for name, metrics in self.evaluator.results.items():
                    if name != "Primary" and metrics.get('f1_score', 0) > primary_f1:
                        is_best = False
                        break
                
                if is_best:
                    html_content += """
                    <p class="metric-good">The primary hybrid algorithm outperforms all baseline algorithms, demonstrating superior merchant name matching capability.</p>
                    """
                else:
                    html_content += """
                    <p class="metric-medium">The primary algorithm shows competitive performance, but some baseline algorithms may perform better in specific scenarios.</p>
                    """
            
            html_content += """
                </div>
                
            </body>
            </html>
            """
            
            # Write to file
            report_path = os.path.join(self.output_dir, filename)
            with open(report_path, 'w') as f:
                f.write(html_content)
            
            self.logger.info(f"HTML report generated at {report_path}")
            
            return report_path
            
        except Exception as e:
            self.logger.error(f"Error generating HTML report: {e}")
            return None

In [53]:
# 11.3: Comprehensive Evaluation Pipeline
class MerchantMatchingEvaluationPipeline:
    """
    End-to-end pipeline for rigorous evaluation of merchant name matching algorithms
    with comprehensive analytics, comparison, and reporting capabilities.
    
    This pipeline integrates data loading, preprocessing, evaluation, validation,
    and reporting into a unified workflow, ensuring consistent and thorough
    assessment of matching performance.
    
    Key features:
    - Automated test suite generation from real-world data
    - Comprehensive performance evaluation across matchers
    - Cross-validation for robust assessment
    - Domain-specific analysis
    - Error pattern analysis and classification
    - Detailed reports with actionable insights
    - Performance visualization and comparison tools
    """
    
    def __init__(self, matchers=None, data_path=None, test_set=None, 
                 domains=None, metrics=None, k_folds=5, random_state=42):
        """
        Initialize the evaluation pipeline with matchers and configuration
        
        Args:
            matchers (dict): Dictionary mapping matcher names to matcher instances
            data_path (str, optional): Path to evaluation dataset
            test_set (DataFrame, optional): Predefined test set
            domains (list, optional): List of domains to evaluate
            metrics (list, optional): List of evaluation metrics to use
            k_folds (int): Number of folds for cross-validation
            random_state (int): Random seed for reproducibility
        """
        # Configure logging
        self.logger = logging.getLogger(__name__)
        
        # Initialize matcher registry
        self.matchers = matchers or {}
        
        # Set evaluation parameters
        self.data_path = data_path
        self.test_set = test_set
        self.domains = domains or ['general', 'banking', 'retail', 'restaurant', 'hotel']
        
        # Set default metrics if not provided
        self.metrics = metrics or [
            'precision', 'recall', 'f1_score', 'accuracy', 
            'roc_auc', 'confusion_matrix', 'precision_at_k'
        ]
        
        # Cross-validation parameters
        self.k_folds = k_folds
        self.random_state = random_state
        
        # Initialize results storage
        self.results = {}
        self.domain_results = {}
        self.error_analysis = {}
        self.cross_val_results = {}
        
        # Load test data if path provided
        if data_path and not test_set:
            self.test_set = self._load_test_data(data_path)
        
        # Register standard error types for classification
        self.error_types = {
            'acronym_errors': 'Failures in matching acronyms to full forms',
            'spelling_errors': 'Failures due to misspellings or typos',
            'structural_errors': 'Failures due to word order or structural differences',
            'semantic_errors': 'Failures in understanding semantic relationships',
            'domain_specific_errors': 'Failures due to domain-specific terminology',
            'false_positives': 'Incorrect matches between different merchants',
            'false_negatives': 'Missed matches between same merchants'
        }
        
        self.logger.info(f"Evaluation pipeline initialized with {len(self.matchers)} matchers")
    
    def register_matcher(self, name, matcher):
        """
        Register a matcher for evaluation
        
        Args:
            name (str): Name identifier for the matcher
            matcher: Matcher instance with match_merchants method
            
        Returns:
            bool: Success status
        """
        if name in self.matchers:
            self.logger.warning(f"Matcher '{name}' already registered. Overwriting.")
            
        # Validate matcher interface
        if not hasattr(matcher, 'match_merchants'):
            self.logger.error(f"Invalid matcher '{name}': missing match_merchants method")
            return False
            
        self.matchers[name] = matcher
        self.logger.info(f"Registered matcher '{name}'")
        return True
    
    def _load_test_data(self, data_path):
        """
        Load test data from file
        
        Args:
            data_path (str): Path to test data file
            
        Returns:
            DataFrame: Loaded test data
        """
        try:
            # Determine file type and load accordingly
            if data_path.endswith('.csv'):
                data = pd.read_csv(data_path)
            elif data_path.endswith('.xlsx'):
                data = pd.read_excel(data_path)
            elif data_path.endswith('.json'):
                data = pd.read_json(data_path)
            else:
                self.logger.error(f"Unsupported file format: {data_path}")
                return None
                
            # Validate required columns
            required_cols = ['s1', 's2', 'is_match']
            if not all(col in data.columns for col in required_cols):
                self.logger.error(f"Test data missing required columns: {required_cols}")
                return None
                
            self.logger.info(f"Loaded test data from {data_path} with {len(data)} examples")
            return data
            
        except Exception as e:
            self.logger.error(f"Error loading test data: {str(e)}")
            return None
    
    def generate_test_suite(self, source_data, sample_size=1000, match_ratio=0.5,
                            include_edge_cases=True, stratify_by_domain=True):
        """
        Generate a comprehensive test suite from source data
        
        Args:
            source_data (DataFrame): Source data with merchant names
            sample_size (int): Size of generated test suite
            match_ratio (float): Ratio of matching pairs to non-matching pairs
            include_edge_cases (bool): Whether to include challenging edge cases
            stratify_by_domain (bool): Whether to stratify sampling by domain
            
        Returns:
            DataFrame: Generated test suite
        """
        self.logger.info(f"Generating test suite with {sample_size} examples")
        
        try:
            # Check for required columns
            if 'merchant_name' not in source_data.columns:
                self.logger.error("Source data missing required 'merchant_name' column")
                return None
                
            # Create positive (matching) pairs
            pos_pairs = []
            
            # Group by domain if stratification requested
            if stratify_by_domain and 'domain' in source_data.columns:
                domains = source_data['domain'].unique()
                pos_per_domain = int((sample_size * match_ratio) / len(domains))
                
                for domain in domains:
                    domain_merchants = source_data[source_data['domain'] == domain]['merchant_name']
                    domain_pairs = self._generate_positive_pairs(domain_merchants, pos_per_domain)
                    
                    for pair in domain_pairs:
                        pos_pairs.append((*pair, 1, domain))
            else:
                # Generate without stratification
                merchants = source_data['merchant_name'].unique()
                pos_count = int(sample_size * match_ratio)
                all_pos_pairs = self._generate_positive_pairs(merchants, pos_count)
                
                for pair in all_pos_pairs:
                    domain = 'general'
                    if 'domain' in source_data.columns:
                        domains = source_data[source_data['merchant_name'].isin(pair)]['domain'].unique()
                        if len(domains) > 0:
                            domain = domains[0]
                    
                    pos_pairs.append((*pair, 1, domain))
            
            # Create negative (non-matching) pairs
            neg_pairs = []
            neg_count = sample_size - len(pos_pairs)
            
            if stratify_by_domain and 'domain' in source_data.columns:
                domains = source_data['domain'].unique()
                neg_per_domain = int(neg_count / len(domains))
                
                for domain in domains:
                    domain_merchants = source_data[source_data['domain'] == domain]['merchant_name']
                    domain_pairs = self._generate_negative_pairs(domain_merchants, neg_per_domain)
                    
                    for pair in domain_pairs:
                        neg_pairs.append((*pair, 0, domain))
            else:
                # Generate without stratification
                merchants = source_data['merchant_name'].unique()
                all_neg_pairs = self._generate_negative_pairs(merchants, neg_count)
                
                for pair in all_neg_pairs:
                    domain = 'general'
                    if 'domain' in source_data.columns:
                        domains = source_data[source_data['merchant_name'] == pair[0]]['domain'].unique()
                        if len(domains) > 0:
                            domain = domains[0]
                    
                    neg_pairs.append((*pair, 0, domain))
            
            # Combine positive and negative pairs
            all_pairs = pos_pairs + neg_pairs
            
            # Add edge cases if requested
            if include_edge_cases:
                edge_cases = self._generate_edge_cases(source_data)
                all_pairs.extend(edge_cases)
            
            # Convert to DataFrame
            test_suite = pd.DataFrame(all_pairs, columns=['s1', 's2', 'is_match', 'domain'])
            
            # Shuffle the data
            test_suite = test_suite.sample(frac=1, random_state=self.random_state).reset_index(drop=True)
            
            self.test_set = test_suite
            self.logger.info(f"Generated test suite with {len(test_suite)} examples")
            
            return test_suite
            
        except Exception as e:
            self.logger.error(f"Error generating test suite: {str(e)}")
            import traceback
            self.logger.debug(traceback.format_exc())
            return None
    
    def _generate_positive_pairs(self, merchants, count):
        """Generate matching pairs through systematic variation"""
        import random
        random.seed(self.random_state)
        
        pairs = []
        merchants = list(merchants)
        
        if len(merchants) < 10:
            self.logger.warning("Too few merchants to generate diverse positive pairs")
            return pairs
            
        # Generate pairs through various transformations
        while len(pairs) < count:
            if not merchants:
                break
                
            # Select a random merchant
            merchant = random.choice(merchants)
            
            # Skip very short names
            if len(merchant) < 3:
                continue
                
            # Apply transformations to create matching variants
            variant = None
            transform_type = random.randint(1, 6)
            
            if transform_type == 1:
                # Change business suffix
                suffixes = [' Inc', ' LLC', ' Corp', ' Co', ' Ltd']
                if any(merchant.endswith(suffix) for suffix in suffixes):
                    # Remove suffix
                    for suffix in suffixes:
                        if merchant.endswith(suffix):
                            variant = merchant[:-len(suffix)]
                            break
                else:
                    # Add suffix
                    variant = merchant + random.choice(suffixes)
                    
            elif transform_type == 2:
                # Add/remove spaces or hyphens
                if ' ' in merchant:
                    variant = merchant.replace(' ', '')
                elif '-' in merchant:
                    variant = merchant.replace('-', ' ')
                else:
                    # Insert spaces between words (heuristic)
                    words = []
                    current_word = ""
                    for c in merchant:
                        if c.isupper() and current_word and not current_word[-1].isupper():
                            words.append(current_word)
                            current_word = c
                        else:
                            current_word += c
                    if current_word:
                        words.append(current_word)
                    variant = ' '.join(words)
                    
            elif transform_type == 3:
                # Abbreviate or expand
                words = merchant.split()
                if len(words) >= 3:
                    # Create acronym
                    variant = ''.join(word[0].upper() for word in words)
                elif len(words) == 1 and len(merchant) <= 3:
                    # Expand common acronyms
                    expansions = {
                        'BOA': 'Bank of America',
                        'WF': 'Wells Fargo',
                        'JPM': 'JPMorgan Chase',
                        'GS': 'Goldman Sachs',
                        'MS': 'Morgan Stanley',
                        'WM': 'Walmart',
                        'TGT': 'Target',
                        'HD': 'Home Depot',
                        'MCD': 'McDonalds'
                    }
                    variant = expansions.get(merchant.upper())
                    
            elif transform_type == 4:
                # Change case
                if merchant.isupper():
                    variant = merchant.lower()
                elif merchant.islower():
                    variant = merchant.upper()
                else:
                    variant = merchant.upper()
                    
            elif transform_type == 5:
                # Add/remove "The" prefix
                if merchant.startswith('The '):
                    variant = merchant[4:]
                else:
                    variant = 'The ' + merchant
                    
            elif transform_type == 6:
                # Introduce minor misspellings
                variant = merchant
                if len(merchant) > 5:
                    pos = random.randint(1, len(merchant)-2)
                    chars = list(merchant)
                    # Either swap adjacent chars or replace with a similar char
                    if random.random() < 0.5:
                        chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
                    else:
                        similar_chars = {
                            'a': 'e', 'e': 'a', 'i': 'y', 'o': 'u', 'u': 'o',
                            's': 'z', 'c': 'k', 'm': 'n', 'b': 'p', 'g': 'j'
                        }
                        if chars[pos].lower() in similar_chars:
                            chars[pos] = similar_chars[chars[pos].lower()]
                    variant = ''.join(chars)
            
            # Add pair if valid variant was created
            if variant and variant != merchant and (merchant, variant) not in pairs and (variant, merchant) not in pairs:
                pairs.append((merchant, variant))
                
            # Break if we've tried too many merchants without success
            if len(pairs) < count and len(pairs) < len(merchants) // 2:
                merchants.remove(merchant)
        
        return pairs[:count]
    
    def _generate_negative_pairs(self, merchants, count):
        """Generate non-matching pairs with varying similarity levels"""
        import random
        random.seed(self.random_state)
        
        pairs = []
        merchants = list(merchants)
        
        if len(merchants) < 10:
            self.logger.warning("Too few merchants to generate diverse negative pairs")
            return pairs
            
        # Generate random pairs with different similarity levels
        while len(pairs) < count:
            if len(merchants) < 2:
                break
                
            # Select two different merchants
            m1 = random.choice(merchants)
            remaining = [m for m in merchants if m != m1]
            if not remaining:
                continue
                
            m2 = random.choice(remaining)
            
            # Skip very short names
            if len(m1) < 3 or len(m2) < 3:
                continue
                
            # Check if pair already exists
            if (m1, m2) in pairs or (m2, m1) in pairs:
                continue
                
            # Add the pair
            pairs.append((m1, m2))
        
        return pairs[:count]
    
    def _generate_edge_cases(self, source_data):
        """Generate challenging edge cases for testing"""
        edge_cases = []
        
        # Example edge cases (these would be more sophisticated in a real implementation)
        # 1. Acronyms with same letters but different order
        # 2. Names with same words but different order
        # 3. Names with high character overlap but different meanings
        # 4. Very similar names but semantically different merchants
        
        # For this implementation, we'll just add a few hand-crafted examples
        banking_cases = [
            # Positive cases (is_match=1)
            ('Bank of America', 'Bank of America, N.A.', 1, 'banking'),
            ('Bank of America', 'BoA', 1, 'banking'),
            ('JPMorgan Chase', 'J.P. Morgan', 1, 'banking'),
            # Negative cases (is_match=0)
            ('Bank of America', 'Bank of American Express', 0, 'banking'),
            ('First National Bank', 'First National Bancorp', 0, 'banking'),
        ]
        
        retail_cases = [
            # Positive cases
            ('Walmart', 'Wal-Mart Supercenter', 1, 'retail'),
            ('Target', 'Target Corp', 1, 'retail'),
            ('Best Buy', 'BestBuy', 1, 'retail'),
            # Negative cases
            ('Ross Dress for Less', 'Ross Stores', 0, 'retail'),
            ('TJ Maxx', 'T.J. Max', 0, 'retail'),
        ]
        
        restaurant_cases = [
            # Positive cases
            ('McDonald\'s', 'McDonalds', 1, 'restaurant'),
            ('Starbucks Coffee', 'SBUX', 1, 'restaurant'),
            ('Chipotle Mexican Grill', 'Chipotle', 1, 'restaurant'),
            # Negative cases
            ('Panda Express', 'Panda Inn', 0, 'restaurant'),
            ('Five Guys', 'Five Guys Burgers', 0, 'restaurant'),
        ]
        
        edge_cases.extend(banking_cases + retail_cases + restaurant_cases)
        return edge_cases
    
    def evaluate_all(self, use_cross_val=True):
        """
        Evaluate all registered matchers on the test set
        
        Args:
            use_cross_val (bool): Whether to use cross-validation
            
        Returns:
            dict: Evaluation results for all matchers
        """
        if not self.test_set is not None:
            self.logger.error("No test set available for evaluation")
            return None
            
        if not self.matchers:
            self.logger.error("No matchers registered for evaluation")
            return None
            
        self.logger.info(f"Evaluating {len(self.matchers)} matchers on {len(self.test_set)} examples")
        
        # Initialize results storage
        results = {}
        
        # Evaluate each matcher
        for name, matcher in self.matchers.items():
            self.logger.info(f"Evaluating matcher: {name}")
            
            # Standard evaluation
            matcher_results = self.evaluate_matcher(matcher, name)
            results[name] = matcher_results
            
            # Cross-validation if requested
            if use_cross_val:
                cv_results = self.cross_validate(matcher, name)
                results[name]['cross_validation'] = cv_results
        
        # Store results
        self.results = results
        
        # Generate comparative analysis
        comparison = self._compare_matchers(results)
        results['comparison'] = comparison
        
        return results
    
    def evaluate_matcher(self, matcher, matcher_name):
        """
        Evaluate a single matcher on the test set
        
        Args:
            matcher: Matcher instance to evaluate
            matcher_name (str): Name identifier for the matcher
            
        Returns:
            dict: Evaluation results
        """
        if self.test_set is None:
            self.logger.error("No test set available for evaluation")
            return None
            
        # Extract test pairs
        s1_values = self.test_set['s1'].values
        s2_values = self.test_set['s2'].values
        y_true = self.test_set['is_match'].values
        domains = self.test_set['domain'].values if 'domain' in self.test_set.columns else None
        
        # Get match scores
        self.logger.info(f"Processing {len(s1_values)} test pairs for {matcher_name}")
        
        y_scores = []
        y_pred = []
        
        # Process in batches for large test sets
        batch_size = 1000
        for i in range(0, len(s1_values), batch_size):
            batch_end = min(i + batch_size, len(s1_values))
            self.logger.info(f"Processing batch {i//batch_size + 1}/{(len(s1_values)-1)//batch_size + 1}")
            
            batch_scores = []
            for j in range(i, batch_end):
                s1 = s1_values[j]
                s2 = s2_values[j]
                domain = domains[j] if domains is not None else None
                
                try:
                    # Get match score
                    if hasattr(matcher, 'match_merchants'):
                        score = matcher.match_merchants(s1, s2, domain)
                    else:
                        # Fallback for simplistic matchers
                        score = matcher(s1, s2)
                        
                    batch_scores.append(float(score))
                except Exception as e:
                    self.logger.warning(f"Error matching '{s1}' and '{s2}': {e}")
                    batch_scores.append(0.0)
            
            y_scores.extend(batch_scores)
        
        # Convert scores to array
        y_scores = np.array(y_scores)
        
        # Calculate optimal threshold using Youden's J statistic
        from sklearn.metrics import roc_curve
        fpr, tpr, thresholds = roc_curve(y_true, y_scores)
        optimal_idx = np.argmax(tpr - fpr)
        optimal_threshold = thresholds[optimal_idx]
        
        # Generate predictions based on optimal threshold
        y_pred = (y_scores >= optimal_threshold).astype(int)
        
        # Calculate metrics
        metrics = self._calculate_metrics(y_true, y_pred, y_scores)
        metrics['optimal_threshold'] = optimal_threshold
        
        # Calculate domain-specific performance if domains available
        domain_performance = {}
        if domains is not None:
            unique_domains = np.unique(domains)
            
            for domain in unique_domains:
                domain_mask = domains == domain
                if sum(domain_mask) < 10:  # Skip domains with too few samples
                    continue
                    
                domain_y_true = y_true[domain_mask]
                domain_y_pred = y_pred[domain_mask]
                domain_y_scores = y_scores[domain_mask]
                
                domain_metrics = self._calculate_metrics(domain_y_true, domain_y_pred, domain_y_scores)
                domain_performance[domain] = domain_metrics
        
        # Identify error patterns
        error_analysis = self._analyze_errors(
            s1_values, s2_values, y_true, y_pred, y_scores, domains
        )
        
        # Compile results
        results = {
            'matcher_name': matcher_name,
            'overall_metrics': metrics,
            'domain_performance': domain_performance,
            'error_analysis': error_analysis,
            'test_size': len(y_true),
            'positive_examples': int(sum(y_true)),
            'negative_examples': int(len(y_true) - sum(y_true)),
            'timestamp': time.time()
        }
        
        # Store in instance results
        self.results[matcher_name] = results
        self.domain_results[matcher_name] = domain_performance
        self.error_analysis[matcher_name] = error_analysis
        
        return results
    
    def _calculate_metrics(self, y_true, y_pred, y_scores):
        """Calculate comprehensive evaluation metrics"""
        from sklearn.metrics import (
            precision_score, recall_score, f1_score, accuracy_score, 
            roc_auc_score, confusion_matrix, precision_recall_curve
        )
        
        metrics = {}
        
        # Basic classification metrics
        metrics['precision'] = float(precision_score(y_true, y_pred, zero_division=0))
        metrics['recall'] = float(recall_score(y_true, y_pred, zero_division=0))
        metrics['f1_score'] = float(f1_score(y_true, y_pred, zero_division=0))
        metrics['accuracy'] = float(accuracy_score(y_true, y_pred))
        
        # ROC and AUC
        try:
            metrics['roc_auc'] = float(roc_auc_score(y_true, y_scores))
        except:
            metrics['roc_auc'] = 0.5  # Default for failed calculation
        
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        metrics['confusion_matrix'] = cm.tolist()
        
        # True/False Positives/Negatives
        tn, fp, fn, tp = cm.ravel()
        metrics['true_positives'] = int(tp)
        metrics['false_positives'] = int(fp)
        metrics['true_negatives'] = int(tn)
        metrics['false_negatives'] = int(fn)
        
        # Precision-Recall curve
        precision, recall, pr_thresholds = precision_recall_curve(y_true, y_scores)
        # Store selected points from the curve (100 points max)
        step = max(1, len(precision) // 100)
        metrics['pr_curve'] = {
            'precision': precision[::step].tolist(),
            'recall': recall[::step].tolist()
        }
        
        # Additional metrics
        if tp + fp > 0:
            metrics['precision_at_k'] = float(tp / (tp + fp))
        else:
            metrics['precision_at_k'] = 0.0
            
        return metrics
    
    def _analyze_errors(self, s1_values, s2_values, y_true, y_pred, y_scores, domains=None):
        """
        Analyze error patterns in matching results
        
        Args:
            s1_values: Array of first merchant names
            s2_values: Array of second merchant names
            y_true: Array of true labels
            y_pred: Array of predicted labels
            y_scores: Array of match scores
            domains: Array of domains (optional)
            
        Returns:
            dict: Error analysis results
        """
        # Find incorrect predictions
        errors = (y_true != y_pred)
        error_indices = np.where(errors)[0]
        
        if len(error_indices) == 0:
            return {'total_errors': 0, 'error_examples': []}
            
        # Collect error examples
        error_examples = []
        for idx in error_indices[:min(100, len(error_indices))]:  # Limit to 100 examples
            s1 = s1_values[idx]
            s2 = s2_values[idx]
            true_label = int(y_true[idx])
            pred_label = int(y_pred[idx])
            score = float(y_scores[idx])
            domain = domains[idx] if domains is not None else 'unknown'
            
            error_type = 'false_positive' if pred_label == 1 and true_label == 0 else 'false_negative'
            
            error_examples.append({
                'index': int(idx),
                's1': s1,
                's2': s2,
                'true_label': true_label,
                'pred_label': pred_label,
                'score': score,
                'domain': domain,
                'error_type': error_type,
                'error_subtype': self._classify_error(s1, s2, true_label, pred_label)
            })
        
        # Classify errors by type
        error_counts = {
            'false_positives': int(sum((y_pred == 1) & (y_true == 0))),
            'false_negatives': int(sum((y_pred == 0) & (y_true == 1)))
        }
        
        # Classify errors by domain if available
        domain_errors = {}
        if domains is not None:
            unique_domains = np.unique(domains)
            for domain in unique_domains:
                domain_mask = domains == domain
                if sum(domain_mask) < 5:  # Skip domains with too few samples
                    continue
                    
                domain_errors[domain] = {
                    'total': int(sum(errors & domain_mask)),
                    'false_positives': int(sum((y_pred == 1) & (y_true == 0) & domain_mask)),
                    'false_negatives': int(sum((y_pred == 0) & (y_true == 1) & domain_mask))
                }
        
        # Classify error subtypes
        error_subtypes = {}
        for example in error_examples:
            subtype = example['error_subtype']
            if subtype not in error_subtypes:
                error_subtypes[subtype] = 0
            error_subtypes[subtype] += 1
        
        return {
            'total_errors': int(sum(errors)),
            'error_rate': float(sum(errors) / len(y_true)),
            'error_counts': error_counts,
            'domain_errors': domain_errors,
            'error_subtypes': error_subtypes,
            'error_examples': error_examples
        }
    
    def _classify_error(self, s1, s2, true_label, pred_label):
        """Classify the error type based on string analysis"""
        # Initialize potential error types
        error_types = []
        
        # Check for acronym pattern
        if (len(s1) <= 5 and s1.isupper() and len(s2) > 10) or \
           (len(s2) <= 5 and s2.isupper() and len(s1) > 10):
            error_types.append('acronym_error')
        
        # Check for spelling differences
        if len(s1) > 3 and len(s2) > 3:
            if self._levenshtein_distance(s1.lower(), s2.lower()) <= 2:
                error_types.append('spelling_error')
        
        # Check for word order differences
        s1_words = set(s1.lower().split())
        s2_words = set(s2.lower().split())
        if s1_words == s2_words and s1.lower() != s2.lower():
            error_types.append('structural_error')
        
        # Check for partial containment
        if s1.lower() in s2.lower() or s2.lower() in s1.lower():
            error_types.append('partial_match_error')
        
        # Default error types based on prediction type
        if true_label == 1 and pred_label == 0:
            if not error_types:
                error_types.append('semantic_error')  # Failed to recognize semantic relationship
        else:  # true_label == 0 and pred_label == 1
            if not error_types:
                error_types.append('false_similar_error')  # Incorrectly identified as similar
        
        # Return primary error type
        return error_types[0] if error_types else 'unknown_error'
    
    def _levenshtein_distance(self, s1, s2):
        """Calculate Levenshtein distance between two strings"""
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)
            
        if len(s2) == 0:
            return len(s1)
            
        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row
            
        return previous_row[-1]
    
    def cross_validate(self, matcher, matcher_name, n_folds=None):
        """
        Perform cross-validation on the matcher
        
        Args:
            matcher: Matcher instance to evaluate
            matcher_name (str): Name identifier for the matcher
            n_folds (int, optional): Number of folds (defaults to self.k_folds)
            
        Returns:
            dict: Cross-validation results
        """
        if self.test_set is None:
            self.logger.error("No test set available for cross-validation")
            return None
            
        # Use default folds if not specified
        n_folds = n_folds or self.k_folds
        
        self.logger.info(f"Performing {n_folds}-fold cross-validation for {matcher_name}")
        
        # Initialize result storage
        cv_results = {
            'folds': [],
            'avg_metrics': {},
            'std_metrics': {}
        }
        
        # Prepare data
        data = self.test_set.copy()
        
        # Create fold indices
        from sklearn.model_selection import StratifiedKFold
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=self.random_state)
        fold_indices = list(skf.split(data, data['is_match']))
        
        # Evaluate each fold
        all_metrics = []
        
        for fold_idx, (train_idx, test_idx) in enumerate(fold_indices):
            self.logger.info(f"Processing fold {fold_idx+1}/{n_folds}")
            
            # Split data
            train_data = data.iloc[train_idx]
            test_data = data.iloc[test_idx]
            
            # Train matcher if it supports training
            if hasattr(matcher, 'train') and callable(getattr(matcher, 'train')):
                try:
                    matcher.train(train_data)
                except Exception as e:
                    self.logger.warning(f"Error training matcher on fold {fold_idx+1}: {e}")
            
            # Evaluate on test split
            s1_values = test_data['s1'].values
            s2_values = test_data['s2'].values
            y_true = test_data['is_match'].values
            
            y_scores = []
            for s1, s2 in zip(s1_values, s2_values):
                try:
                    score = matcher.match_merchants(s1, s2)
                    y_scores.append(float(score))
                except Exception as e:
                    self.logger.warning(f"Error matching '{s1}' and '{s2}': {e}")
                    y_scores.append(0.0)
            
            y_scores = np.array(y_scores)
            
            # Find optimal threshold
            from sklearn.metrics import roc_curve
            fpr, tpr, thresholds = roc_curve(y_true, y_scores)
            optimal_idx = np.argmax(tpr - fpr)
            optimal_threshold = thresholds[optimal_idx]
            
            # Generate predictions
            y_pred = (y_scores >= optimal_threshold).astype(int)
            
            # Calculate metrics
            fold_metrics = self._calculate_metrics(y_true, y_pred, y_scores)
            fold_metrics['optimal_threshold'] = optimal_threshold
            
            # Store fold results
            cv_results['folds'].append({
                'fold_idx': fold_idx,
                'metrics': fold_metrics,
                'train_size': len(train_data),
                'test_size': len(test_data)
            })
            
            all_metrics.append(fold_metrics)
        
        # Calculate aggregate statistics
        metric_names = ['precision', 'recall', 'f1_score', 'accuracy', 'roc_auc']
        for metric in metric_names:
            values = [fold['metrics'][metric] for fold in cv_results['folds']]
            cv_results['avg_metrics'][metric] = float(np.mean(values))
            cv_results['std_metrics'][metric] = float(np.std(values))
        
        # Store cross-validation results
        self.cross_val_results[matcher_name] = cv_results
        
        return cv_results
    
    def _compare_matchers(self, results):
        """Generate comparative analysis of matcher performance"""
        if not results:
            return {}
            
        comparison = {
            'overall_ranking': {},
            'domain_ranking': {},
            'metric_comparison': {},
            'error_comparison': {},
            'statistical_significance': {}
        }
        
        # Extract matcher names
        matcher_names = list(results.keys())
        
        # Compare overall metrics
        metric_names = ['precision', 'recall', 'f1_score', 'accuracy', 'roc_auc']
        for metric in metric_names:
            comparison['metric_comparison'][metric] = {}
            
            # Get metric values for all matchers
            values = {}
            for name in matcher_names:
                if name in results and 'overall_metrics' in results[name]:
                    values[name] = results[name]['overall_metrics'].get(metric, 0)
            
            # Rank matchers by the metric
            ranked_matchers = sorted(values.items(), key=lambda x: x[1], reverse=True)
            comparison['metric_comparison'][metric]['values'] = values
            comparison['metric_comparison'][metric]['ranking'] = [name for name, _ in ranked_matchers]
        
        # Generate overall ranking based on F1 score
        f1_scores = {}
        for name in matcher_names:
            if name in results and 'overall_metrics' in results[name]:
                f1_scores[name] = results[name]['overall_metrics'].get('f1_score', 0)
                
        overall_ranking = sorted(f1_scores.items(), key=lambda x: x[1], reverse=True)
        comparison['overall_ranking'] = {
            'metric': 'f1_score',
            'ranking': [name for name, _ in overall_ranking],
            'scores': f1_scores
        }
        
        # Compare domain performance if available
        domain_data = {}
        for name in matcher_names:
            if name in results and 'domain_performance' in results[name]:
                for domain, metrics in results[name]['domain_performance'].items():
                    if domain not in domain_data:
                        domain_data[domain] = {}
                    domain_data[domain][name] = metrics.get('f1_score', 0)
        
        # Rank by domain
        for domain, scores in domain_data.items():
            ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
            comparison['domain_ranking'][domain] = {
                'ranking': [name for name, _ in ranked],
                'scores': scores
            }
        
        # Compare error patterns
        error_data = {}
        for name in matcher_names:
            if name in results and 'error_analysis' in results[name]:
                error_analysis = results[name]['error_analysis']
                error_data[name] = {
                    'total_errors': error_analysis.get('total_errors', 0),
                    'error_rate': error_analysis.get('error_rate', 0),
                    'false_positives': error_analysis.get('error_counts', {}).get('false_positives', 0),
                    'false_negatives': error_analysis.get('error_counts', {}).get('false_negatives', 0)
                }
        
        comparison['error_comparison'] = error_data
        
        # Perform statistical significance testing (McNemar's test)
        if len(matcher_names) > 1:
            try:
                from statsmodels.stats.contingency_tables import mcnemar
                
                # Get prediction arrays for each matcher
                predictions = {}
                for name in matcher_names:
                    if name in self.results:
                        # Reconstruct predictions from error analysis
                        y_true = self.test_set['is_match'].values
                        error_indices = [ex['index'] for ex in self.results[name]['error_analysis'].get('error_examples', [])]
                        y_pred = np.array(y_true)  # Start with correct predictions
                        y_pred[error_indices] = 1 - y_true[error_indices]  # Flip predictions at error indices
                        predictions[name] = y_pred
                
                # Perform pairwise McNemar tests
                significance_results = {}
                for i, name1 in enumerate(matcher_names):
                    for j, name2 in enumerate(matcher_names):
                        if i < j and name1 in predictions and name2 in predictions:
                            # Create contingency table
                            pred1 = predictions[name1]
                            pred2 = predictions[name2]
                            
                            # Count contingency table cells
                            both_correct = sum((pred1 == self.test_set['is_match']) & (pred2 == self.test_set['is_match']))
                            name1_correct = sum((pred1 == self.test_set['is_match']) & (pred2 != self.test_set['is_match']))
                            name2_correct = sum((pred1 != self.test_set['is_match']) & (pred2 == self.test_set['is_match']))
                            both_wrong = sum((pred1 != self.test_set['is_match']) & (pred2 != self.test_set['is_match']))
                            
                            table = np.array([[both_correct, name2_correct], 
                                             [name1_correct, both_wrong]])
                            
                            # Perform McNemar test
                            result = mcnemar(table, exact=True)
                            p_value = result.pvalue
                            
                            pair_key = f"{name1}_vs_{name2}"
                            significance_results[pair_key] = {
                                'p_value': float(p_value),
                                'significant': p_value < 0.05,
                                'better_matcher': name1 if name1_correct > name2_correct else name2
                            }
                
                comparison['statistical_significance'] = significance_results
            except Exception as e:
                self.logger.warning(f"Error performing statistical significance tests: {e}")
                comparison['statistical_significance'] = {'error': str(e)}
        
        return comparison
    
    def generate_report(self, output_path=None, format='html'):
        """
        Generate a comprehensive evaluation report
        
        Args:
            output_path (str, optional): Path to save the report
            format (str): Report format ('html', 'json', 'pdf')
            
        Returns:
            str: Path to the generated report or HTML string
        """
        if not self.results:
            self.logger.error("No evaluation results available for reporting")
            return None
            
        self.logger.info(f"Generating {format} evaluation report")
        
        # Generate report content based on format
        if format == 'html':
            report_content = self._generate_html_report()
        elif format == 'json':
            import json
            report_content = json.dumps(self.results, indent=2)
        else:
            self.logger.error(f"Unsupported report format: {format}")
            return None
        
        # Save to file if output path provided
        if output_path:
            try:
                with open(output_path, 'w') as f:
                    f.write(report_content)
                self.logger.info(f"Report saved to {output_path}")
                return output_path
            except Exception as e:
                self.logger.error(f"Error saving report: {e}")
                return report_content
        
        return report_content
    
    def _generate_html_report(self):
        """Generate an HTML evaluation report"""
        # Basic HTML template
        html_template = '''
        <!DOCTYPE html>
        <html>
        <head>
            <title>Merchant Matching Evaluation Report</title>
            <style>
                body { font-family: Arial, sans-serif; line-height: 1.6; padding: 20px; max-width: 1200px; margin: 0 auto; }
                h1, h2, h3 { color: #333; }
                table { border-collapse: collapse; width: 100%; margin-bottom: 20px; }
                th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
                th { background-color: #f2f2f2; }
                tr:nth-child(even) { background-color: #f9f9f9; }
                .chart-container { width: 600px; height: 400px; margin: 20px 0; }
                .matcher-section { border: 1px solid #ddd; padding: 15px; margin-bottom: 20px; }
                .metric-good { color: green; }
                .metric-medium { color: orange; }
                .metric-bad { color: red; }
                .summary-box { background-color: #f8f8f8; border-left: 4px solid #4CAF50; padding: 10px; margin-bottom: 20px; }
            </style>
            <!-- Add Chart.js for visualizations -->
            <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
        </head>
        <body>
            <h1>Merchant Matching Evaluation Report</h1>
            <div class="summary-box">
                <h2>Executive Summary</h2>
                {executive_summary}
            </div>
            
            <h2>Overall Comparison</h2>
            {overall_comparison}
            
            <h2>Detailed Matcher Results</h2>
            {matcher_results}
            
            <h2>Domain-Specific Performance</h2>
            {domain_performance}
            
            <h2>Error Analysis</h2>
            {error_analysis}
            
            <h2>Statistical Significance</h2>
            {statistical_significance}
            
            <footer>
                <p>Generated on {generation_date}</p>
            </footer>
            
            <!-- Charts initialization -->
            <script>
            {chart_scripts}
            </script>
        </body>
        </html>
        '''
        
        # Generate sections
        executive_summary = self._generate_executive_summary_html()
        overall_comparison = self._generate_overall_comparison_html()
        matcher_results = self._generate_matcher_results_html()
        domain_performance = self._generate_domain_performance_html()
        error_analysis = self._generate_error_analysis_html()
        statistical_significance = self._generate_statistical_significance_html()
        chart_scripts = self._generate_chart_scripts()
        
        # Format the template
        formatted_html = html_template.format(
            executive_summary=executive_summary,
            overall_comparison=overall_comparison,
            matcher_results=matcher_results,
            domain_performance=domain_performance,
            error_analysis=error_analysis,
            statistical_significance=statistical_significance,
            chart_scripts=chart_scripts,
            generation_date=time.strftime('%Y-%m-%d %H:%M:%S')
        )
        
        return formatted_html
    
    def _generate_executive_summary_html(self):
        """Generate HTML executive summary section"""
        # Find the best performing matcher
        best_matcher = None
        best_f1 = -1
        
        for name, results in self.results.items():
            if 'overall_metrics' in results:
                f1 = results['overall_metrics'].get('f1_score', 0)
                if f1 > best_f1:
                    best_f1 = f1
                    best_matcher = name
        
        html = f'''
        <p>This report evaluates {len(self.results)} merchant name matching algorithms on a test set of 
        {len(self.test_set) if self.test_set is not None else 0} examples.</p>
        
        <p>The best performing matcher is <strong>{best_matcher}</strong> with an F1 score of {best_f1:.3f}.</p>
        
        <p>Key findings:</p>
        <ul>
        '''
        
        # Add key findings
        if best_matcher and 'overall_metrics' in self.results[best_matcher]:
            metrics = self.results[best_matcher]['overall_metrics']
            precision = metrics.get('precision', 0)
            recall = metrics.get('recall', 0)
            
            if precision > 0.9:
                html += f'<li>{best_matcher} achieved high precision ({precision:.3f}), making it suitable for applications requiring accurate matches.</li>'
            if recall > 0.9:
                html += f'<li>{best_matcher} achieved high recall ({recall:.3f}), making it effective at finding all potential matches.</li>'
        
        # Add domain-specific insights
        if best_matcher and 'domain_performance' in self.results[best_matcher]:
            domain_metrics = self.results[best_matcher]['domain_performance']
            
            # Find best and worst performing domains
            if domain_metrics:
                domain_f1 = {domain: metrics.get('f1_score', 0) for domain, metrics in domain_metrics.items()}
                best_domain = max(domain_f1.items(), key=lambda x: x[1])
                worst_domain = min(domain_f1.items(), key=lambda x: x[1])
                
                html += f'<li>Best performance in the <strong>{best_domain[0]}</strong> domain (F1: {best_domain[1]:.3f}).</li>'
                html += f'<li>Potential improvement needed in the <strong>{worst_domain[0]}</strong> domain (F1: {worst_domain[1]:.3f}).</li>'
        
        html += '</ul>'
        return html
    
    def _generate_overall_comparison_html(self):
        """Generate HTML overall comparison section"""
        # Extract overall metrics for all matchers
        html = '''
        <div class="chart-container">
            <canvas id="overallChart"></canvas>
        </div>
        
        <table>
            <tr>
                <th>Matcher</th>
                <th>Precision</th>
                <th>Recall</th>
                <th>F1 Score</th>
                <th>Accuracy</th>
                <th>ROC AUC</th>
            </tr>
        '''
        
        # Add rows for each matcher
        for name, results in self.results.items():
            if 'overall_metrics' in results:
                metrics = results['overall_metrics']
                html += f'''
                <tr>
                    <td>{name}</td>
                    <td>{metrics.get('precision', 0):.3f}</td>
                    <td>{metrics.get('recall', 0):.3f}</td>
                    <td>{metrics.get('f1_score', 0):.3f}</td>
                    <td>{metrics.get('accuracy', 0):.3f}</td>
                    <td>{metrics.get('roc_auc', 0):.3f}</td>
                </tr>
                '''
        
        html += '</table>'
        return html
    
    def _generate_matcher_results_html(self):
        """Generate HTML sections for individual matcher results"""
        html = ''
        
        for name, results in self.results.items():
            # Skip the comparison results
            if name == 'comparison':
                continue
                
            if 'overall_metrics' not in results:
                continue
                
            metrics = results['overall_metrics']
            test_size = results.get('test_size', 0)
            html += f'''
            <div class="matcher-section">
                <h3>{name}</h3>
                <p>Evaluated on {test_size} examples.</p>
                
                <h4>Overall Performance</h4>
                <table>
                    <tr>
                        <th>Metric</th>
                        <th>Value</th>
                    </tr>
                    <tr>
                        <td>Precision</td>
                        <td>{metrics.get('precision', 0):.3f}</td>
                    </tr>
                    <tr>
                        <td>Recall</td>
                        <td>{metrics.get('recall', 0):.3f}</td>
                    </tr>
                    <tr>
                        <td>F1 Score</td>
                        <td>{metrics.get('f1_score', 0):.3f}</td>
                    </tr>
                    <tr>
                        <td>Accuracy</td>
                        <td>{metrics.get('accuracy', 0):.3f}</td>
                    </tr>
                    <tr>
                        <td>ROC AUC</td>
                        <td>{metrics.get('roc_auc', 0):.3f}</td>
                    </tr>
                    <tr>
                        <td>Optimal Threshold</td>
                        <td>{metrics.get('optimal_threshold', 0):.3f}</td>
                    </tr>
                </table>
                
                <h4>Confusion Matrix</h4>
                <div class="chart-container" style="width: 400px; height: 300px;">
                    <canvas id="confusionMatrix_{name.replace(' ', '_')}"></canvas>
                </div>
            '''
            
            # Add cross-validation results if available
            if 'cross_validation' in results:
                cv = results['cross_validation']
                html += f'''
                <h4>Cross-Validation Results ({len(cv.get('folds', []))} folds)</h4>
                <table>
                    <tr>
                        <th>Metric</th>
                        <th>Average</th>
                        <th>Std Dev</th>
                    </tr>
                '''
                
                for metric, value in cv.get('avg_metrics', {}).items():
                    std = cv.get('std_metrics', {}).get(metric, 0)
                    html += f'''
                    <tr>
                        <td>{metric}</td>
                        <td>{value:.3f}</td>
                        <td>±{std:.3f}</td>
                    </tr>
                    '''
                    
                html += '</table>'
            
            html += '</div>'
        
        return html
    
    def _generate_domain_performance_html(self):
        """Generate HTML for domain-specific performance"""
        html = '''
        <div class="chart-container">
            <canvas id="domainChart"></canvas>
        </div>
        
        <table>
            <tr>
                <th>Matcher</th>
                <th>Domain</th>
                <th>F1 Score</th>
                <th>Precision</th>
                <th>Recall</th>
            </tr>
        '''
        
        # Add rows for each matcher and domain
        for name, results in self.results.items():
            if name == 'comparison' or 'domain_performance' not in results:
                continue
                
            domain_results = results['domain_performance']
            for domain, metrics in domain_results.items():
                html += f'''
                <tr>
                    <td>{name}</td>
                    <td>{domain}</td>
                    <td>{metrics.get('f1_score', 0):.3f}</td>
                    <td>{metrics.get('precision', 0):.3f}</td>
                    <td>{metrics.get('recall', 0):.3f}</td>
                </tr>
                '''
        
        html += '</table>'
        return html
    
    def _generate_error_analysis_html(self):
        """Generate HTML for error analysis"""
        html = '''
        <div class="chart-container">
            <canvas id="errorChart"></canvas>
        </div>
        '''
        
        # Add error examples for each matcher
        for name, results in self.results.items():
            if name == 'comparison' or 'error_analysis' not in results:
                continue
                
            error_analysis = results['error_analysis']
            total_errors = error_analysis.get('total_errors', 0)
            error_rate = error_analysis.get('error_rate', 0)
            
            html += f'''
            <h3>Errors for {name}</h3>
            <p>Total Errors: {total_errors} (Error Rate: {error_rate:.2%})</p>
            '''
            
            # Error subtype breakdown
            if 'error_subtypes' in error_analysis and error_analysis['error_subtypes']:
                html += '''
                <h4>Error Type Distribution</h4>
                <table>
                    <tr>
                        <th>Error Type</th>
                        <th>Count</th>
                        <th>Percentage</th>
                    </tr>
                '''
                
                for subtype, count in error_analysis['error_subtypes'].items():
                    percentage = count / total_errors if total_errors > 0 else 0
                    html += f'''
                    <tr>
                        <td>{subtype}</td>
                        <td>{count}</td>
                        <td>{percentage:.2%}</td>
                    </tr>
                    '''
                    
                html += '</table>'
            
            # Show sample error examples
            if 'error_examples' in error_analysis and error_analysis['error_examples']:
                examples = error_analysis['error_examples'][:10]  # Limit to 10 examples
                
                html += '''
                <h4>Sample Error Examples</h4>
                <table>
                    <tr>
                        <th>Merchant 1</th>
                        <th>Merchant 2</th>
                        <th>True Label</th>
                        <th>Predicted</th>
                        <th>Score</th>
                        <th>Error Type</th>
                    </tr>
                '''
                
                for example in examples:
                    true_label = "Match" if example['true_label'] == 1 else "No Match"
                    pred_label = "Match" if example['pred_label'] == 1 else "No Match"
                    
                    html += f'''
                    <tr>
                        <td>{example['s1']}</td>
                        <td>{example['s2']}</td>
                        <td>{true_label}</td>
                        <td>{pred_label}</td>
                        <td>{example['score']:.3f}</td>
                        <td>{example.get('error_subtype', 'Unknown')}</td>
                    </tr>
                    '''
                    
                html += '</table>'
        
        return html
    
    def _generate_statistical_significance_html(self):
        """Generate HTML for statistical significance tests"""
        # Check if comparison results exist
        if 'comparison' not in self.results or 'statistical_significance' not in self.results['comparison']:
            return '<p>No statistical significance tests available.</p>'
            
        significance_tests = self.results['comparison']['statistical_significance']
        
        if not significance_tests or isinstance(significance_tests, dict) and 'error' in significance_tests:
            return '<p>No statistical significance tests available.</p>'
            
        html = '''
        <p>McNemar's test was used to determine if differences between matchers are statistically significant.</p>
        
        <table>
            <tr>
                <th>Comparison</th>
                <th>p-value</th>
                <th>Significant (p < 0.05)</th>
                <th>Better Matcher</th>
            </tr>
        '''
        
        for pair, results in significance_tests.items():
            significant = "Yes" if results['significant'] else "No"
            
            html += f'''
            <tr>
                <td>{pair}</td>
                <td>{results['p_value']:.4f}</td>
                <td>{significant}</td>
                <td>{results.get('better_matcher', 'N/A')}</td>
            </tr>
            '''
            
        html += '</table>'
        return html
    
    def _generate_chart_scripts(self):
        """Generate JavaScript for charts"""
        # Extract data for charts
        matcher_names = [name for name in self.results.keys() if name != 'comparison']
        
        # Overall metrics data
        precision_data = []
        recall_data = []
        f1_data = []
        
        for name in matcher_names:
            if 'overall_metrics' in self.results[name]:
                metrics = self.results[name]['overall_metrics']
                precision_data.append(metrics.get('precision', 0))
                recall_data.append(metrics.get('recall', 0))
                f1_data.append(metrics.get('f1_score', 0))
        
        # Domain performance data
        domain_data = {}
        for name in matcher_names:
            if 'domain_performance' in self.results[name]:
                domain_results = self.results[name]['domain_performance']
                for domain, metrics in domain_results.items():
                    if domain not in domain_data:
                        domain_data[domain] = {}
                    domain_data[domain][name] = metrics.get('f1_score', 0)
        
        # Build chart scripts
        scripts = []
        
        # Overall metrics chart
        scripts.append(f'''
        const overallCtx = document.getElementById('overallChart').getContext('2d');
        new Chart(overallCtx, {{
            type: 'bar',
            data: {{
                labels: {json.dumps(matcher_names)},
                datasets: [
                    {{
                        label: 'Precision',
                        data: {json.dumps(precision_data)},
                        backgroundColor: 'rgba(54, 162, 235, 0.5)',
                        borderColor: 'rgb(54, 162, 235)',
                        borderWidth: 1
                    }},
                    {{
                        label: 'Recall',
                        data: {json.dumps(recall_data)},
                        backgroundColor: 'rgba(255, 99, 132, 0.5)',
                        borderColor: 'rgb(255, 99, 132)',
                        borderWidth: 1
                    }},
                    {{
                        label: 'F1 Score',
                        data: {json.dumps(f1_data)},
                        backgroundColor: 'rgba(75, 192, 192, 0.5)',
                        borderColor: 'rgb(75, 192, 192)',
                        borderWidth: 1
                    }}
                ]
            }},
            options: {{
                responsive: true,
                title: {{
                    display: true,
                    text: 'Overall Matcher Performance'
                }},
                scales: {{
                    y: {{
                        beginAtZero: true,
                        max: 1
                    }}
                }}
            }}
        }});
        ''')
        
        # Domain performance chart
        domain_chart_data = []
        domains = list(domain_data.keys())
        
        for name in matcher_names:
            dataset = {
                'label': name,
                'data': [domain_data.get(domain, {}).get(name, 0) for domain in domains]
            }
            domain_chart_data.append(dataset)
            
        if domains and domain_chart_data:
            scripts.append(f'''
            const domainCtx = document.getElementById('domainChart').getContext('2d');
            new Chart(domainCtx, {{
                type: 'radar',
                data: {{
                    labels: {json.dumps(domains)},
                    datasets: {json.dumps(domain_chart_data)}
                }},
                options: {{
                    responsive: true,
                    title: {{
                        display: true,
                        text: 'Domain-Specific Performance (F1 Score)'
                    }},
                    scale: {{
                        ticks: {{
                            beginAtZero: true,
                            max: 1
                        }}
                    }}
                }}
            }});
            ''')
        
        # Confusion matrix charts for each matcher
        for name in matcher_names:
            if 'overall_metrics' in self.results[name] and 'confusion_matrix' in self.results[name]['overall_metrics']:
                cm = self.results[name]['overall_metrics']['confusion_matrix']
                
                if isinstance(cm, list) and len(cm) == 2 and len(cm[0]) == 2:
                    # Extract values from 2x2 confusion matrix
                    tn, fp = cm[0]
                    fn, tp = cm[1]
                    
                    scripts.append(f'''
                    const cmCtx_{name.replace(' ', '_')} = document.getElementById('confusionMatrix_{name.replace(' ', '_')}').getContext('2d');
                    new Chart(cmCtx_{name.replace(' ', '_')}, {{
                        type: 'pie',
                        data: {{
                            labels: ['True Positives', 'False Positives', 'True Negatives', 'False Negatives'],
                            datasets: [{{
                                data: [{tp}, {fp}, {tn}, {fn}],
                                backgroundColor: [
                                    'rgba(75, 192, 192, 0.5)',  // TP: Teal
                                    'rgba(255, 99, 132, 0.5)',  // FP: Red
                                    'rgba(54, 162, 235, 0.5)',  // TN: Blue
                                    'rgba(255, 206, 86, 0.5)'   // FN: Yellow
                                ],
                                borderColor: [
                                    'rgb(75, 192, 192)',
                                    'rgb(255, 99, 132)',
                                    'rgb(54, 162, 235)',
                                    'rgb(255, 206, 86)'
                                ],
                                borderWidth: 1
                            }}]
                        }},
                        options: {{
                            responsive: true,
                            title: {{
                                display: true,
                                text: 'Confusion Matrix'
                            }}
                        }}
                    }});
                    ''')
        
        # Error analysis chart
        error_labels = []
        error_data = []
        
        for name in matcher_names:
            if 'error_analysis' in self.results[name]:
                error_analysis = self.results[name]['error_analysis']
                error_labels.append(name)
                error_data.append(error_analysis.get('error_rate', 0))
                
        if error_labels and error_data:
            scripts.append(f'''
            const errorCtx = document.getElementById('errorChart').getContext('2d');
            new Chart(errorCtx, {{
                type: 'horizontalBar',
                data: {{
                    labels: {json.dumps(error_labels)},
                    datasets: [{{
                        label: 'Error Rate',
                        data: {json.dumps(error_data)},
                        backgroundColor: 'rgba(255, 99, 132, 0.5)',
                        borderColor: 'rgb(255, 99, 132)',
                        borderWidth: 1
                    }}]
                }},
                options: {{
                    responsive: true,
                    title: {{
                        display: true,
                        text: 'Error Rate by Matcher'
                    }},
                    scales: {{
                        x: {{
                            beginAtZero: true,
                            max: 1
                        }}
                    }}
                }}
            }});
            ''')
            
        return '\n'.join(scripts)
    
    def visualize_errors(self, matcher_name, n_examples=10):
        """
        Visualize representative error examples for a matcher
        
        Args:
            matcher_name (str): Name of the matcher to visualize
            n_examples (int): Number of examples to display
            
        Returns:
            str: HTML visualization
        """
        if matcher_name not in self.results or 'error_analysis' not in self.results[matcher_name]:
            return f"No error analysis available for matcher '{matcher_name}'"
            
        error_analysis = self.results[matcher_name]['error_analysis']
        if 'error_examples' not in error_analysis or not error_analysis['error_examples']:
            return f"No error examples available for matcher '{matcher_name}'"
            
        # Get representative examples (mixed error types)
        examples = error_analysis['error_examples'][:n_examples]
        
        # Generate HTML visualization
        html = f'''
        <h2>Error Analysis for {matcher_name}</h2>
        <p>Total Errors: {error_analysis.get('total_errors', 0)} (Error Rate: {error_analysis.get('error_rate', 0):.2%})</p>
        
        <h3>Representative Error Examples</h3>
        <table style="width:100%; border-collapse: collapse; margin-bottom: 20px;">
            <tr style="background-color: #f2f2f2;">
                <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Merchant 1</th>
                <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Merchant 2</th>
                <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">True</th>
                <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Predicted</th>
                <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Score</th>
                <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Error Type</th>
            </tr>
        '''
        
        for example in examples:
            true_label = "Match" if example['true_label'] == 1 else "No Match"
            pred_label = "Match" if example['pred_label'] == 1 else "No Match"
            
            # Color-code based on error type
            bg_color = "#ffebee" if example['error_type'] == 'false_positive' else "#e8f5e9"
            
            html += f'''
            <tr style="background-color: {bg_color};">
                <td style="border: 1px solid #ddd; padding: 8px;">{example['s1']}</td>
                <td style="border: 1px solid #ddd; padding: 8px;">{example['s2']}</td>
                <td style="border: 1px solid #ddd; padding: 8px;">{true_label}</td>
                <td style="border: 1px solid #ddd; padding: 8px;">{pred_label}</td>
                <td style="border: 1px solid #ddd; padding: 8px;">{example['score']:.3f}</td>
                <td style="border: 1px solid #ddd; padding: 8px;">{example.get('error_subtype', 'Unknown')}</td>
            </tr>
            '''
            
        html += '</table>'
        
        # Add error type distribution
        if 'error_subtypes' in error_analysis and error_analysis['error_subtypes']:
            html += '''
            <h3>Error Type Distribution</h3>
            <table style="width:60%; border-collapse: collapse; margin-bottom: 20px;">
                <tr style="background-color: #f2f2f2;">
                    <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Error Type</th>
                    <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Count</th>
                    <th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Percentage</th>
                </tr>
            '''
            
            total_errors = error_analysis.get('total_errors', 0)
            for subtype, count in error_analysis['error_subtypes'].items():
                percentage = count / total_errors if total_errors > 0 else 0
                html += f'''
                <tr>
                    <td style="border: 1px solid #ddd; padding: 8px;">{subtype}</td>
                    <td style="border: 1px solid #ddd; padding: 8px;">{count}</td>
                    <td style="border: 1px solid #ddd; padding: 8px;">{percentage:.2%}</td>
                </tr>
                '''
                
            html += '</table>'
        
        return html

In [57]:
# Cell 12: Comprehensive Testing and Analysis Process

class MerchantMatcherTester:
    """
    Comprehensive testing and analysis framework for the merchant name matching system.
    This class provides end-to-end functionality for validating the matching system,
    generating test cases, analyzing performance, optimizing parameters, and generating
    detailed reports with explanations and confidence metrics.
    
    Key features:
    - Input validation and preprocessing
    - Test suite generation with diverse edge cases
    - Automated threshold optimization
    - Performance analysis with detailed metrics
    - Output generation with explanations
    - Confidence scoring mechanism
    """
    
    def __init__(self, pipeline=None, config_path=None, domain=None, 
                 output_dir="./merchant_matcher_results"):
        """
        Initialize the tester with the matching pipeline and configuration
        
        Args:
            pipeline (GMARTMerchantMatchingPipeline): Existing matcher pipeline
            config_path (str): Path to configuration file
            domain (str): Default domain for testing
            output_dir (str): Directory for output files
        """
        # Use existing pipeline or create a new one
        if pipeline:
            self.pipeline = pipeline
        else:
            self.pipeline = GMARTMerchantMatchingPipeline(
                config_path=config_path,
                domain=domain,
                debug_mode=True,
                log_level='INFO'
            )
        
        # Set up output directory
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Initialize performance metrics tracking
        self.performance_metrics = {}
        self.test_results = pd.DataFrame()
        self.optimal_thresholds = {}
        
        # Tracking for confidence score calibration
        self.score_distribution = {
            'true_positives': [],
            'false_positives': [],
            'true_negatives': [],
            'false_negatives': []
        }
        
        # Set up logging for test results
        self._setup_logging()
        
        # Load or initialize confusion matrix
        self.confusion_matrix = {
            'true_positives': 0,
            'false_positives': 0,
            'false_negatives': 0,
            'true_negatives': 0
        }
    
    def _setup_logging(self):
        """Set up logging for test results"""
        self.logger = logging.getLogger('merchant_matcher_tester')
        self.logger.setLevel(logging.INFO)
        
        # Create file handler
        log_file = os.path.join(self.output_dir, "test_results.log")
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        
        # Create formatter and add to handler
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        
        # Add handler to logger
        self.logger.addHandler(file_handler)
        
        self.logger.info("MerchantMatcherTester initialized")
    
    def _validate_test_data(self, test_data):
        """
        Validate test data and preprocess for testing
        
        Args:
            test_data (DataFrame): Test data with merchant pairs
            
        Returns:
            DataFrame: Validated and preprocessed test data
        """
        # Check if DataFrame
        if not isinstance(test_data, pd.DataFrame):
            raise ValueError("Test data must be a pandas DataFrame")
        
        # Check required columns
        required_cols = ['s1', 's2']
        if not all(col in test_data.columns for col in required_cols):
            # Try to adapt common column formats
            if 'Acronym' in test_data.columns and 'Full_Name' in test_data.columns:
                self.logger.info("Adapting Acronym/Full_Name format to s1/s2")
                test_data = test_data.rename(columns={'Acronym': 's1', 'Full_Name': 's2'})
            elif 'input_name' in test_data.columns and 'matched_name' in test_data.columns:
                self.logger.info("Adapting input_name/matched_name format to s1/s2")
                test_data = test_data.rename(columns={'input_name': 's1', 'matched_name': 's2'})
            else:
                raise ValueError(f"Test data must contain 's1' and 's2' columns. Found: {test_data.columns.tolist()}")
        
        # Check if ground truth exists
        if 'is_match' not in test_data.columns and 'expected_match' not in test_data.columns:
            self.logger.warning("No ground truth column found. Adding placeholder 'is_match' column.")
            test_data['is_match'] = None
        elif 'expected_match' in test_data.columns and 'is_match' not in test_data.columns:
            self.logger.info("Renaming 'expected_match' to 'is_match'")
            test_data = test_data.rename(columns={'expected_match': 'is_match'})
        
        # Handle missing values
        test_data = test_data.dropna(subset=['s1', 's2'])
        self.logger.info(f"Validated test data with {len(test_data)} valid pairs")
        
        return test_data
    
    def generate_test_suite(self, base_data=None, num_cases=100, include_edge_cases=True,
                            domain=None, output_file=None):
        """
        Generate a comprehensive test suite with diverse test cases
        
        Args:
            base_data (DataFrame): Base data to augment with test cases
            num_cases (int): Number of test cases to generate
            include_edge_cases (bool): Whether to include edge cases
            domain (str): Domain for test cases
            output_file (str): Path to save test suite
            
        Returns:
            DataFrame: Generated test suite
        """
        self.logger.info(f"Generating test suite with {num_cases} cases")
        
        # Initialize test suite
        if base_data is not None:
            test_suite = self._validate_test_data(base_data).copy()
        else:
            test_suite = pd.DataFrame(columns=['s1', 's2', 'is_match', 'case_type'])
        
        # Get merchant matcher for generation
        matcher = self.pipeline._get_merchant_matcher()
        preprocessor = self.pipeline._initialized_components.get('preprocessor')
        
        if not preprocessor:
            raise ValueError("Preprocessor not initialized in pipeline")
        
        # Generate positive pairs (matching merchants)
        positive_pairs = []
        
        # Common merchant names for test generation
        common_merchants = {
            'banking': [
                'Bank of America', 'Chase Bank', 'Wells Fargo', 'Citibank', 
                'JP Morgan', 'Capital One', 'HSBC Bank', 'TD Bank', 'PNC Bank',
                'US Bank', 'Barclays', 'Santander', 'Bank of the West', 'Regions Bank'
            ],
            'retail': [
                'Walmart', 'Target', 'Costco', 'Amazon', 'Best Buy', 'Home Depot', 
                'Lowe\'s', 'Macy\'s', 'Kohl\'s', 'Nordstrom', 'Staples', 'Office Depot',
                'CVS Pharmacy', 'Walgreens', 'Dollar General'
            ],
            'restaurant': [
                'McDonald\'s', 'Starbucks', 'Subway', 'Taco Bell', 'Burger King',
                'Wendy\'s', 'KFC', 'Pizza Hut', 'Domino\'s Pizza', 'Chipotle',
                'Panera Bread', 'Olive Garden', 'Chili\'s', 'Applebee\'s', 'Outback'
            ],
            'gas_station': [
                'Shell', 'Exxon', 'BP', 'Chevron', 'Texaco', 'Mobil', 'Sunoco',
                'Valero', 'Marathon', 'Phillips 66', 'Circle K', '7-Eleven',
                'QuikTrip', 'RaceTrac', 'Speedway'
            ],
            'hotel': [
                'Marriott', 'Hilton', 'Hyatt', 'Holiday Inn', 'Sheraton', 'Westin',
                'Hampton Inn', 'Comfort Inn', 'Best Western', 'Courtyard Marriott',
                'Embassy Suites', 'Doubletree', 'Four Seasons', 'Ritz-Carlton'
            ]
        }
        
        # Select domain merchants
        if domain and domain.lower() in common_merchants:
            merchants = common_merchants[domain.lower()]
        else:
            # Combine all merchants
            merchants = []
            for domain_merchants in common_merchants.values():
                merchants.extend(domain_merchants)
        
        # Generate positive pairs with variations
        num_positive = min(num_cases // 2, 50)
        for i in range(num_positive):
            if len(merchants) > 0:
                # Select random merchant
                merchant = random.choice(merchants)
                
                # Create variations
                variation_type = random.choice([
                    'suffix', 'prefix', 'abbreviation', 'typo', 'spacing',
                    'punctuation', 'capitalization', 'word_order'
                ])
                
                if variation_type == 'suffix':
                    # Add business suffix
                    suffixes = [' Inc', ' LLC', ' Corp', ' Company', ' Co', ' Ltd']
                    variation = merchant + random.choice(suffixes)
                elif variation_type == 'prefix':
                    # Add location prefix
                    prefixes = ['Downtown ', 'North ', 'South ', 'East ', 'West ', 'Central ']
                    variation = random.choice(prefixes) + merchant
                elif variation_type == 'abbreviation':
                    # Create abbreviation or acronym
                    if ' ' in merchant:
                        words = merchant.split()
                        variation = ''.join(word[0] for word in words)
                    else:
                        # Abbreviate by removing vowels
                        variation = ''.join(c for c in merchant if c.lower() not in 'aeiou')
                elif variation_type == 'typo':
                    # Introduce typo
                    chars = list(merchant)
                    if len(chars) > 3:
                        pos = random.randint(1, len(chars)-2)
                        chars[pos] = random.choice('abcdefghijklmnopqrstuvwxyz')
                        variation = ''.join(chars)
                    else:
                        variation = merchant
                elif variation_type == 'spacing':
                    # Modify spacing
                    if ' ' in merchant:
                        variation = merchant.replace(' ', '')
                    else:
                        # Insert random space
                        pos = random.randint(1, len(merchant)-1)
                        variation = merchant[:pos] + ' ' + merchant[pos:]
                elif variation_type == 'punctuation':
                    # Add/modify punctuation
                    if "'" in merchant:
                        variation = merchant.replace("'", "")
                    else:
                        # Add apostrophe or hyphen
                        if random.random() < 0.5:
                            # Add apostrophe
                            pos = random.randint(1, len(merchant)-1)
                            variation = merchant[:pos] + "'" + merchant[pos:]
                        else:
                            # Add hyphen if spaces
                            variation = merchant.replace(' ', '-')
                elif variation_type == 'capitalization':
                    # Change capitalization
                    options = [merchant.upper(), merchant.lower(), merchant.title()]
                    variation = random.choice(options)
                elif variation_type == 'word_order':
                    # Change word order if multiple words
                    if ' ' in merchant:
                        words = merchant.split()
                        if len(words) > 1:
                            random.shuffle(words)
                            variation = ' '.join(words)
                        else:
                            variation = merchant
                    else:
                        variation = merchant
                
                positive_pairs.append({
                    's1': merchant,
                    's2': variation,
                    'is_match': 1,
                    'case_type': f'positive_{variation_type}'
                })
        
        # Generate negative pairs (non-matching merchants)
        negative_pairs = []
        num_negative = min(num_cases - num_positive, 50)
        for i in range(num_negative):
            if len(merchants) > 1:
                # Select two different merchants
                merchant1 = random.choice(merchants)
                merchant2 = random.choice([m for m in merchants if m != merchant1])
                
                # Create different types of negative pairs
                case_type = random.choice([
                    'completely_different', 'partial_match', 'similar_industry', 
                    'similar_word', 'substring', 'acronym_confusion'
                ])
                
                if case_type == 'completely_different':
                    # Use merchants as is
                    pass
                elif case_type == 'partial_match':
                    # Create partial word match
                    if ' ' in merchant1 and ' ' in merchant2:
                        words1 = merchant1.split()
                        words2 = merchant2.split()
                        # Share one word
                        shared_word = random.choice(words1)
                        pos = random.randint(0, len(words2)-1)
                        words2[pos] = shared_word
                        merchant2 = ' '.join(words2)
                elif case_type == 'similar_industry':
                    # Ensure from same industry but different
                    for domain_name, domain_merchants in common_merchants.items():
                        if merchant1 in domain_merchants:
                            other_merchants = [m for m in domain_merchants 
                                             if m != merchant1]
                            if other_merchants:
                                merchant2 = random.choice(other_merchants)
                            break
                elif case_type == 'similar_word':
                    # Find similar sounding merchant
                    similar_merchants = sorted(merchants, 
                                             key=lambda m: jaro_winkler(merchant1, m))
                    if len(similar_merchants) > 2:
                        # Pick second most similar (most similar would be itself)
                        merchant2 = similar_merchants[-2]
                elif case_type == 'substring':
                    # One is substring of other
                    if len(merchant1) > 5:
                        merchant2 = merchant1[:len(merchant1)//2]
                    else:
                        # If too short, just use different merchant
                        merchant2 = random.choice([m for m in merchants if m != merchant1])
                elif case_type == 'acronym_confusion':
                    # Create confusing acronym
                    if ' ' in merchant1:
                        words = merchant1.split()
                        acronym = ''.join(word[0] for word in words)
                        # Find another merchant with same acronym
                        for m in merchants:
                            if m != merchant1 and ' ' in m:
                                m_words = m.split()
                                m_acronym = ''.join(word[0] for word in m_words)
                                if m_acronym == acronym:
                                    merchant2 = m
                                    break
                
                # Add negative pair
                negative_pairs.append({
                    's1': merchant1,
                    's2': merchant2,
                    'is_match': 0,
                    'case_type': f'negative_{case_type}'
                })
        
        # Add edge cases if requested
        edge_cases = []
        if include_edge_cases:
            # Empty strings
            edge_cases.append({
                's1': '', 's2': 'Valid Merchant', 'is_match': 0, 'case_type': 'edge_empty_s1'
            })
            edge_cases.append({
                's1': 'Valid Merchant', 's2': '', 'is_match': 0, 'case_type': 'edge_empty_s2'
            })
            
            # Very short names
            edge_cases.append({
                's1': 'IBM', 's2': 'IBM Corp', 'is_match': 1, 'case_type': 'edge_short_name'
            })
            
            # Special characters
            edge_cases.append({
                's1': 'AT&T', 's2': 'AT and T', 'is_match': 1, 'case_type': 'edge_special_chars'
            })
            
            # Numbers
            edge_cases.append({
                's1': '7-Eleven', 's2': 'Seven Eleven', 'is_match': 1, 'case_type': 'edge_numbers'
            })
            
            # Extreme length difference
            edge_cases.append({
                's1': 'IBM', 
                's2': 'International Business Machines Corporation Worldwide', 
                'is_match': 1, 
                'case_type': 'edge_length_diff'
            })
            
            # Repeated words
            edge_cases.append({
                's1': 'Buffalo Buffalo', 
                's2': 'Buffalo', 
                'is_match': 1, 
                'case_type': 'edge_repeated_words'
            })
            
            # Non-English characters
            edge_cases.append({
                's1': 'Café Coffee', 
                's2': 'Cafe Coffee', 
                'is_match': 1, 
                'case_type': 'edge_non_english'
            })
        
        # Combine all test cases
        all_cases = pd.DataFrame(positive_pairs + negative_pairs + edge_cases)
        
        # Append to existing test suite if provided
        if not test_suite.empty:
            test_suite = pd.concat([test_suite, all_cases], ignore_index=True)
        else:
            test_suite = all_cases
        
        # Save to file if requested
        if output_file:
            output_path = os.path.join(self.output_dir, output_file)
            test_suite.to_csv(output_path, index=False)
            self.logger.info(f"Test suite saved to {output_path}")
        
        return test_suite
    
    def run_tests(self, test_data, domain=None, threshold=None, save_results=True):
        """
        Run comprehensive tests on the matcher with the provided test data
        
        Args:
            test_data (DataFrame): Test data with merchant pairs and expected matches
            domain (str): Domain for testing
            threshold (float): Custom threshold for matching
            save_results (bool): Whether to save detailed results
            
        Returns:
            dict: Test results and performance metrics
        """
        # Validate and preprocess test data
        test_data = self._validate_test_data(test_data)
        
        # Check if ground truth exists
        has_ground_truth = 'is_match' in test_data.columns and not test_data['is_match'].isna().all()
        
        # Get effective domain
        effective_domain = domain or self.pipeline.domain
        
        # Run matcher on test data
        self.logger.info(f"Running tests on {len(test_data)} merchant pairs")
        
        # Get matcher
        matcher = self.pipeline._get_merchant_matcher()
        
        # Customize threshold if provided
        original_thresholds = None
        if threshold is not None:
            original_thresholds = matcher.thresholds.copy()
            matcher.thresholds = {
                'high': threshold + 0.1,
                'medium': threshold,
                'low': threshold - 0.1
            }
        
        # Process in batches
        batch_size = 100
        results = []
        
        for i in range(0, len(test_data), batch_size):
            batch = test_data.iloc[i:i+batch_size]
            self.logger.info(f"Processing batch {i//batch_size + 1}/{(len(test_data)-1)//batch_size + 1}")
            
            batch_results = []
            for _, row in batch.iterrows():
                s1 = row['s1']
                s2 = row['s2']
                
                # Skip invalid pairs
                if not isinstance(s1, str) or not isinstance(s2, str):
                    continue
                
                # Get detailed match result
                match_result = matcher.match_merchants(
                    s1, s2, effective_domain, return_details=True
                )
                
                # Create result entry
                result_entry = {
                    's1': s1,
                    's2': s2,
                    'score': match_result['match_score'],
                    'level': match_result['match_level'],
                    'confidence': self._calculate_confidence(match_result),
                    'expected_match': row.get('is_match') if has_ground_truth else None,
                    'explanation': match_result['explanation'],
                    'case_type': row.get('case_type', 'unknown')
                }
                
                # Calculate match correctness if ground truth exists
                if has_ground_truth and not pd.isna(row['is_match']):
                    expected = bool(row['is_match'])
                    actual = match_result['match_score'] >= (threshold or matcher.thresholds['medium'])
                    result_entry['correct'] = expected == actual
                    
                    # Update confusion matrix
                    if expected and actual:
                        self.confusion_matrix['true_positives'] += 1
                        self.score_distribution['true_positives'].append(match_result['match_score'])
                    elif expected and not actual:
                        self.confusion_matrix['false_negatives'] += 1
                        self.score_distribution['false_negatives'].append(match_result['match_score'])
                    elif not expected and actual:
                        self.confusion_matrix['false_positives'] += 1
                        self.score_distribution['false_positives'].append(match_result['match_score'])
                    else:
                        self.confusion_matrix['true_negatives'] += 1
                        self.score_distribution['true_negatives'].append(match_result['match_score'])
                
                batch_results.append(result_entry)
            
            results.extend(batch_results)
        
        # Restore original thresholds if modified
        if original_thresholds:
            matcher.thresholds = original_thresholds
        
        # Convert results to DataFrame
        results_df = pd.DataFrame(results)
        self.test_results = results_df
        
        # Calculate performance metrics if ground truth exists
        performance = {}
        if has_ground_truth:
            # Calculate accuracy, precision, recall, F1
            tp = self.confusion_matrix['true_positives']
            fp = self.confusion_matrix['false_positives']
            fn = self.confusion_matrix['false_negatives']
            tn = self.confusion_matrix['true_negatives']
            
            total = tp + fp + fn + tn
            
            if total > 0:
                accuracy = (tp + tn) / total
            else:
                accuracy = 0.0
                
            if (tp + fp) > 0:
                precision = tp / (tp + fp)
            else:
                precision = 0.0
                
            if (tp + fn) > 0:
                recall = tp / (tp + fn)
            else:
                recall = 0.0
                
            if (precision + recall) > 0:
                f1 = 2 * (precision * recall) / (precision + recall)
            else:
                f1 = 0.0
            
            performance = {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1_score': f1,
                'true_positives': tp,
                'false_positives': fp,
                'false_negatives': fn,
                'true_negatives': tn
            }
            
            # Calculate performance by case type
            case_types = results_df['case_type'].unique()
            case_performance = {}
            
            for case_type in case_types:
                case_results = results_df[results_df['case_type'] == case_type]
                if not case_results.empty and 'correct' in case_results.columns:
                    case_performance[case_type] = {
                        'count': len(case_results),
                        'accuracy': case_results['correct'].mean(),
                        'avg_score': case_results['score'].mean(),
                        'avg_confidence': case_results['confidence'].mean()
                    }
            
            performance['case_type_performance'] = case_performance
        
        # Save the performance metrics
        self.performance_metrics = performance
        
        # Save results if requested
        if save_results:
            # Save detailed results
            results_path = os.path.join(self.output_dir, "test_results.csv")
            results_df.to_csv(results_path, index=False)
            
            # Save performance metrics
            if performance:
                metrics_path = os.path.join(self.output_dir, "performance_metrics.json")
                with open(metrics_path, 'w') as f:
                    json.dump(performance, f, indent=2)
            
            self.logger.info(f"Test results saved to {self.output_dir}")
        
        # Log performance summary
        if performance:
            self.logger.info(f"Performance Summary:")
            self.logger.info(f"  Accuracy: {performance['accuracy']:.4f}")
            self.logger.info(f"  Precision: {performance['precision']:.4f}")
            self.logger.info(f"  Recall: {performance['recall']:.4f}")
            self.logger.info(f"  F1 Score: {performance['f1_score']:.4f}")
        
        return {
            'results': results_df,
            'performance': performance,
            'domain': effective_domain,
            'threshold': threshold
        }
    
    def _calculate_confidence(self, match_result):
        """
        Calculate confidence score for a match result
        
        Args:
            match_result (dict): Match result from matcher
            
        Returns:
            float: Confidence score between 0 and 1
        """
        # Extract key metrics
        score = match_result['match_score']
        features = match_result.get('features', {})
        patterns = match_result.get('patterns', {})
        
        # Base confidence on match score
        confidence = score
        
        # Adjust confidence based on feature consistency
        feature_values = [v for k, v in features.items() 
                       if k in ['string_similarity', 'token_set_similarity', 
                              'semantic_similarity', 'contains_score']]
        
        if feature_values:
            # Reduce confidence if features disagree significantly
            feature_std = np.std(feature_values) if len(feature_values) > 1 else 0
            confidence *= (1 - 0.5 * min(feature_std, 0.5))
        
        # Boost confidence if patterns detected
        if patterns:
            confidence = min(confidence + 0.1, 1.0)
        
        # Penalize confidence for extreme length differences
        s1_clean = match_result.get('processed_s1', '')
        s2_clean = match_result.get('processed_s2', '')
        if s1_clean and s2_clean:
            len_ratio = min(len(s1_clean), len(s2_clean)) / max(len(s1_clean), len(s2_clean)) if max(len(s1_clean), len(s2_clean)) > 0 else 1
            if len_ratio < 0.3:
                confidence *= max(0.5, len_ratio + 0.2)
        
        return confidence
    
    def optimize_thresholds(self, test_data=None, domain=None):
        """
        Optimize matching thresholds based on test results
        
        Args:
            test_data (DataFrame): Test data with merchant pairs
            domain (str): Domain for optimization
            
        Returns:
            dict: Optimized thresholds and performance
        """
        # Use existing test results if no data provided
        if test_data is None:
            if self.test_results.empty:
                raise ValueError("No test data provided and no existing test results")
            test_data = self.test_results
        else:
            test_data = self._validate_test_data(test_data)
        
        # Ensure ground truth exists
        if 'is_match' not in test_data.columns or test_data['is_match'].isna().all():
            raise ValueError("Test data must contain 'is_match' ground truth for threshold optimization")
        
        # Get effective domain
        effective_domain = domain or self.pipeline.domain
        
        self.logger.info(f"Optimizing thresholds for domain: {effective_domain or 'general'}")
        
        # Try multiple thresholds and measure performance
        thresholds_to_try = np.arange(0.5, 0.95, 0.05)
        threshold_results = []
        
        for threshold in thresholds_to_try:
            # Run test with this threshold
            result = self.run_tests(
                test_data, 
                domain=effective_domain,
                threshold=threshold,
                save_results=False
            )
            
            # Store performance
            performance = result['performance']
            threshold_results.append({
                'threshold': threshold,
                'accuracy': performance.get('accuracy', 0),
                'precision': performance.get('precision', 0),
                'recall': performance.get('recall', 0),
                'f1_score': performance.get('f1_score', 0)
            })
        
        # Create thresholds DataFrame
        thresholds_df = pd.DataFrame(threshold_results)
        
        # Find best thresholds for different metrics
        best_accuracy = thresholds_df.loc[thresholds_df['accuracy'].idxmax()]
        best_f1 = thresholds_df.loc[thresholds_df['f1_score'].idxmax()]
        best_precision = thresholds_df.loc[thresholds_df['precision'].idxmax()]
        best_recall = thresholds_df.loc[thresholds_df['recall'].idxmax()]
        
        # Choose balanced threshold (best F1)
        optimal_threshold = best_f1['threshold']
        
        # Set optimal thresholds
        optimal_thresholds = {
            'high': optimal_threshold + 0.1,
            'medium': optimal_threshold,
            'low': optimal_threshold - 0.1
        }
        
        # Update matcher thresholds
        matcher = self.pipeline._get_merchant_matcher()
        matcher.thresholds = optimal_thresholds
        
        # Save thresholds
        self.optimal_thresholds = {
            'threshold': optimal_threshold,
            'accuracy': best_accuracy['accuracy'],
            'f1': best_f1['f1_score'],
            'precision': best_precision['precision'],
            'recall': best_recall['recall'],
            'thresholds': optimal_thresholds,
            'all_results': thresholds_df.to_dict('records')
        }
        
        # Save results
        thresholds_path = os.path.join(self.output_dir, "optimal_thresholds.json")
        with open(thresholds_path, 'w') as f:
            json.dump(self.optimal_thresholds, f, indent=2)
        
        # Generate visualization if matplotlib available
        try:
            self._visualize_threshold_optimization(thresholds_df)
        except:
            self.logger.warning("Could not generate threshold visualization")
        
        # Log results
        self.logger.info(f"Optimal thresholds determined:")
        self.logger.info(f"  High: {optimal_thresholds['high']:.2f}")
        self.logger.info(f"  Medium: {optimal_thresholds['medium']:.2f}")
        self.logger.info(f"  Low: {optimal_thresholds['low']:.2f}")
        self.logger.info(f"  Best F1 Score: {best_f1['f1_score']:.4f}")
        
        return self.optimal_thresholds
    
    def _visualize_threshold_optimization(self, thresholds_df):
        """Generate visualization for threshold optimization"""
        try:
            import matplotlib.pyplot as plt
            
            # Create figure
            plt.figure(figsize=(10, 6))
            
            # Plot metrics vs thresholds
            plt.plot(thresholds_df['threshold'], thresholds_df['accuracy'], 'b-', label='Accuracy')
            plt.plot(thresholds_df['threshold'], thresholds_df['precision'], 'g-', label='Precision')
            plt.plot(thresholds_df['threshold'], thresholds_df['recall'], 'r-', label='Recall')
            plt.plot(thresholds_df['threshold'], thresholds_df['f1_score'], 'k-', label='F1 Score')
            
            # Mark optimal threshold
            optimal_threshold = self.optimal_thresholds['threshold']
            optimal_f1 = self.optimal_thresholds['f1']
            plt.axvline(x=optimal_threshold, color='m', linestyle='--', alpha=0.5)
            plt.scatter([optimal_threshold], [optimal_f1], color='m', s=100, 
                      label=f'Optimal Threshold: {optimal_threshold:.2f}')
            
            # Add labels and legend
            plt.xlabel('Threshold')
            plt.ylabel('Metric Value')
            plt.title('Performance Metrics vs. Threshold')
            plt.legend()
            plt.grid(True, alpha=0.3)
            
            # Save figure
            plt.savefig(os.path.join(self.output_dir, "threshold_optimization.png"))
            plt.close()
            
        except Exception as e:
            self.logger.warning(f"Could not generate threshold visualization: {e}")
    
    def analyze_performance(self, by_domain=False, by_case_type=True, by_confidence=True):
        """
        Analyze matcher performance in depth
        
        Args:
            by_domain (bool): Whether to analyze by domain
            by_case_type (bool): Whether to analyze by case type
            by_confidence (bool): Whether to analyze by confidence level
            
        Returns:
            dict: Detailed performance analysis
        """
        if self.test_results.empty or not self.performance_metrics:
            raise ValueError("No test results available for analysis")
        
        self.logger.info("Analyzing matcher performance")
        
        analysis = {
            'overall': self.performance_metrics,
            'by_domain': {},
            'by_case_type': {},
            'by_confidence': {},
            'error_analysis': {}
        }
        
        # Analyze by domain if requested
        if by_domain and 'domain' in self.test_results.columns:
            domains = self.test_results['domain'].unique()
            for domain in domains:
                if not pd.isna(domain):
                    domain_results = self.test_results[self.test_results['domain'] == domain]
                    if 'correct' in domain_results.columns:
                        analysis['by_domain'][domain] = {
                            'count': len(domain_results),
                            'accuracy': domain_results['correct'].mean(),
                            'avg_score': domain_results['score'].mean(),
                            'avg_confidence': domain_results['confidence'].mean()
                        }
                        
                        # Calculate precision, recall, F1 if expected matches are present
                        if 'expected_match' in domain_results.columns:
                            tp = sum((domain_results['expected_match'] == 1) & (domain_results['score'] >= 0.7))
                            fp = sum((domain_results['expected_match'] == 0) & (domain_results['score'] >= 0.7))
                            fn = sum((domain_results['expected_match'] == 1) & (domain_results['score'] < 0.7))
                            tn = sum((domain_results['expected_match'] == 0) & (domain_results['score'] < 0.7))
                            
                            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
                            
                            analysis['by_domain'][domain].update({
                                'precision': precision,
                                'recall': recall,
                                'f1_score': f1,
                                'confusion_matrix': {
                                    'true_positives': int(tp),
                                    'false_positives': int(fp),
                                    'false_negatives': int(fn),
                                    'true_negatives': int(tn)
                                }
                            })
        
        # Analyze by case type if requested
        if by_case_type and 'case_type' in self.test_results.columns:
            case_types = self.test_results['case_type'].unique()
            for case_type in case_types:
                if not pd.isna(case_type):
                    case_results = self.test_results[self.test_results['case_type'] == case_type]
                    if 'correct' in case_results.columns:
                        analysis['by_case_type'][case_type] = {
                            'count': len(case_results),
                            'accuracy': case_results['correct'].mean(),
                            'avg_score': case_results['score'].mean(),
                            'avg_confidence': case_results['confidence'].mean()
                        }
                        
                        # Identify common failure patterns
                        if len(case_results) > 0 and 'correct' in case_results.columns:
                            failures = case_results[~case_results['correct']]
                            if len(failures) > 0:
                                analysis['by_case_type'][case_type]['failure_analysis'] = {
                                    'failure_count': len(failures),
                                    'failure_rate': len(failures) / len(case_results),
                                    'avg_failure_score': failures['score'].mean(),
                                    'example_failures': failures.head(3)[['s1', 's2', 'score', 'expected_match']].to_dict('records')
                                }
        
        # Analyze by confidence level if requested
        if by_confidence and 'confidence' in self.test_results.columns:
            # Create confidence bins
            self.test_results['confidence_bin'] = pd.cut(
                self.test_results['confidence'], 
                bins=[0, 0.5, 0.7, 0.85, 0.95, 1.0], 
                labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']
            )
            
            confidence_bins = self.test_results['confidence_bin'].unique()
            for confidence_bin in confidence_bins:
                if not pd.isna(confidence_bin):
                    bin_results = self.test_results[self.test_results['confidence_bin'] == confidence_bin]
                    if 'correct' in bin_results.columns:
                        analysis['by_confidence'][str(confidence_bin)] = {
                            'count': len(bin_results),
                            'accuracy': bin_results['correct'].mean(),
                            'avg_score': bin_results['score'].mean()
                        }
                        
                        # Calculate reliability ratio (how well confidence predicts accuracy)
                        if len(bin_results) > 0 and 'correct' in bin_results.columns:
                            avg_confidence = bin_results['confidence'].mean()
                            accuracy = bin_results['correct'].mean()
                            reliability_ratio = min(avg_confidence, accuracy) / max(avg_confidence, accuracy) if max(avg_confidence, accuracy) > 0 else 0
                            
                            analysis['by_confidence'][str(confidence_bin)]['reliability_ratio'] = reliability_ratio
        
        # Error analysis - identify common patterns in false positives and negatives
        if 'expected_match' in self.test_results.columns and 'score' in self.test_results.columns:
            # Get false positives (matched but shouldn't have)
            false_positives = self.test_results[
                (self.test_results['expected_match'] == 0) & 
                (self.test_results['score'] >= 0.7)
            ]
            
            # Get false negatives (didn't match but should have)
            false_negatives = self.test_results[
                (self.test_results['expected_match'] == 1) & 
                (self.test_results['score'] < 0.7)
            ]
            
            # Analyze false positives
            if len(false_positives) > 0:
                analysis['error_analysis']['false_positives'] = {
                    'count': len(false_positives),
                    'avg_score': false_positives['score'].mean(),
                    'avg_confidence': false_positives['confidence'].mean(),
                    'examples': false_positives.head(5)[['s1', 's2', 'score', 'confidence']].to_dict('records')
                }
                
                # Analyze by case type if available
                if 'case_type' in false_positives.columns:
                    type_counts = false_positives['case_type'].value_counts()
                    analysis['error_analysis']['false_positives']['case_type_distribution'] = type_counts.to_dict()
            
            # Analyze false negatives
            if len(false_negatives) > 0:
                analysis['error_analysis']['false_negatives'] = {
                    'count': len(false_negatives),
                    'avg_score': false_negatives['score'].mean(),
                    'avg_confidence': false_negatives['confidence'].mean(),
                    'examples': false_negatives.head(5)[['s1', 's2', 'score', 'confidence']].to_dict('records')
                }
                
                # Analyze by case type if available
                if 'case_type' in false_negatives.columns:
                    type_counts = false_negatives['case_type'].value_counts()
                    analysis['error_analysis']['false_negatives']['case_type_distribution'] = type_counts.to_dict()
        
        # Save analysis results
        analysis_path = os.path.join(self.output_dir, "performance_analysis.json")
        with open(analysis_path, 'w') as f:
            json.dump(analysis, f, indent=2)
        
        self.logger.info(f"Performance analysis saved to {analysis_path}")
        return analysis
    
    def generate_report(self, output_format='html', include_visualizations=True):
        """
        Generate a comprehensive report of test results and analysis
        
        Args:
            output_format (str): Report format ('html', 'pdf', 'markdown')
            include_visualizations (bool): Include visualizations in report
            
        Returns:
            str: Path to generated report
        """
        if self.test_results.empty:
            raise ValueError("No test results available for report generation")
        
        self.logger.info(f"Generating {output_format} report")
        
        # Prepare report data
        report_data = {
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'test_summary': {
                'total_tests': len(self.test_results),
                'performance_metrics': self.performance_metrics,
                'optimal_thresholds': self.optimal_thresholds
            },
            'test_results': self.test_results.to_dict('records')[:100]  # Limit to 100 examples
        }
        
        # Generate visualizations if requested
        if include_visualizations:
            visualization_paths = self._generate_visualizations()
            report_data['visualizations'] = visualization_paths
        
        # Create report based on format
        if output_format == 'html':
            report_path = self._generate_html_report(report_data)
        elif output_format == 'markdown':
            report_path = self._generate_markdown_report(report_data)
        elif output_format == 'pdf':
            try:
                # Try to convert HTML to PDF if library available
                report_path = self._generate_html_report(report_data)
                pdf_path = report_path.replace('.html', '.pdf')
                
                try:
                    import weasyprint
                    html = weasyprint.HTML(filename=report_path)
                    html.write_pdf(pdf_path)
                    report_path = pdf_path
                except ImportError:
                    self.logger.warning("WeasyPrint not available. Generated HTML report instead.")
            except Exception as e:
                self.logger.error(f"Failed to generate PDF report: {e}")
                report_path = self._generate_html_report(report_data)
        else:
            # Default to HTML
            report_path = self._generate_html_report(report_data)
        
        self.logger.info(f"Report generated at {report_path}")
        return report_path
    
    def _generate_visualizations(self):
        """Generate visualizations for report"""
        visualization_paths = []
        
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            
            # Create output directory for visualizations
            viz_dir = os.path.join(self.output_dir, "visualizations")
            os.makedirs(viz_dir, exist_ok=True)
            
            # 1. Score distribution plot
            plt.figure(figsize=(10, 6))
            if 'score' in self.test_results.columns and 'expected_match' in self.test_results.columns:
                sns.histplot(
                    data=self.test_results, 
                    x='score', 
                    hue='expected_match',
                    bins=20,
                    kde=True
                )
                plt.title('Score Distribution by Expected Match')
                plt.xlabel('Match Score')
                plt.ylabel('Count')
                plt.legend(['Not a Match', 'Match'])
                
                score_dist_path = os.path.join(viz_dir, "score_distribution.png")
                plt.savefig(score_dist_path)
                plt.close()
                visualization_paths.append(score_dist_path)
            
            # 2. Confusion matrix heatmap
            if self.confusion_matrix:
                plt.figure(figsize=(8, 6))
                cm_data = np.array([
                    [self.confusion_matrix['true_negatives'], self.confusion_matrix['false_positives']],
                    [self.confusion_matrix['false_negatives'], self.confusion_matrix['true_positives']]
                ])
                
                sns.heatmap(
                    cm_data, 
                    annot=True, 
                    fmt='d', 
                    cmap='Blues',
                    xticklabels=['Predicted No Match', 'Predicted Match'],
                    yticklabels=['Actual No Match', 'Actual Match']
                )
                plt.title('Confusion Matrix')
                
                cm_path = os.path.join(viz_dir, "confusion_matrix.png")
                plt.savefig(cm_path)
                plt.close()
                visualization_paths.append(cm_path)
            
            # 3. Confidence vs. Accuracy plot
            if 'confidence' in self.test_results.columns and 'correct' in self.test_results.columns:
                plt.figure(figsize=(10, 6))
                
                # Group by confidence bins
                self.test_results['confidence_bin'] = pd.cut(
                    self.test_results['confidence'], 
                    bins=10
                )
                
                conf_acc = self.test_results.groupby('confidence_bin')['correct'].mean().reset_index()
                conf_acc['bin_center'] = conf_acc['confidence_bin'].apply(lambda x: x.mid)
                
                plt.plot(conf_acc['bin_center'], conf_acc['correct'], 'o-', linewidth=2)
                plt.plot([0, 1], [0, 1], 'k--')  # Ideal calibration line
                
                plt.title('Confidence Calibration Plot')
                plt.xlabel('Confidence')
                plt.ylabel('Accuracy')
                plt.grid(alpha=0.3)
                
                calib_path = os.path.join(viz_dir, "confidence_calibration.png")
                plt.savefig(calib_path)
                plt.close()
                visualization_paths.append(calib_path)
            
            # 4. Case type performance
            if 'case_type' in self.test_results.columns and 'correct' in self.test_results.columns:
                plt.figure(figsize=(12, 8))
                
                case_perf = self.test_results.groupby('case_type')['correct'].mean().sort_values()
                
                colors = ['g' if x > 0.8 else 'y' if x > 0.6 else 'r' for x in case_perf]
                
                case_perf.plot(kind='barh', color=colors)
                plt.title('Accuracy by Case Type')
                plt.xlabel('Accuracy')
                plt.ylabel('Case Type')
                plt.xlim(0, 1)
                plt.tight_layout()
                
                case_path = os.path.join(viz_dir, "case_type_performance.png")
                plt.savefig(case_path)
                plt.close()
                visualization_paths.append(case_path)
            
        except Exception as e:
            self.logger.warning(f"Failed to generate visualizations: {e}")
        
        return visualization_paths
    
    def _generate_html_report(self, report_data):
        """Generate HTML report from data"""
        report_path = os.path.join(self.output_dir, "merchant_matcher_report.html")
        
        # Basic HTML template
        html_template = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Merchant Matcher Test Report</title>
            <style>
                body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; max-width: 1200px; margin: 0 auto; padding: 20px; }
                h1, h2, h3 { color: #2c3e50; }
                .header { background-color: #2c3e50; color: white; padding: 20px; margin-bottom: 20px; }
                .section { margin-bottom: 30px; background-color: #f9f9f9; padding: 20px; border-radius: 5px; }
                table { border-collapse: collapse; width: 100%; margin-bottom: 20px; }
                th, td { border: 1px solid #ddd; padding: 12px; }
                th { background-color: #f2f2f2; text-align: left; }
                tr:nth-child(even) { background-color: #f9f9f9; }
                .metrics { display: flex; flex-wrap: wrap; gap: 20px; margin-bottom: 20px; }
                .metric-card { background-color: white; border-radius: 5px; padding: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); flex: 1; min-width: 150px; }
                .metric-value { font-size: 24px; font-weight: bold; color: #2c3e50; }
                .visualizations { display: flex; flex-wrap: wrap; gap: 20px; }
                .visualization { max-width: 100%; height: auto; }
                .visualization img { max-width: 100%; height: auto; border: 1px solid #ddd; }
                pre { background-color: #f5f5f5; padding: 10px; border-radius: 5px; overflow-x: auto; }
                .good { color: green; }
                .medium { color: orange; }
                .poor { color: red; }
            </style>
        </head>
        <body>
            <div class="header">
                <h1>Merchant Matcher Test Report</h1>
                <p>Generated on: {timestamp}</p>
            </div>
            
            <div class="section">
                <h2>Summary</h2>
                <div class="metrics">
                    <div class="metric-card">
                        <div>Accuracy</div>
                        <div class="metric-value {accuracy_class}">{accuracy:.2%}</div>
                    </div>
                    <div class="metric-card">
                        <div>Precision</div>
                        <div class="metric-value {precision_class}">{precision:.2%}</div>
                    </div>
                    <div class="metric-card">
                        <div>Recall</div>
                        <div class="metric-value {recall_class}">{recall:.2%}</div>
                    </div>
                    <div class="metric-card">
                        <div>F1 Score</div>
                        <div class="metric-value {f1_class}">{f1:.2%}</div>
                    </div>
                    <div class="metric-card">
                        <div>Total Tests</div>
                        <div class="metric-value">{total_tests}</div>
                    </div>
                </div>
                
                <h3>Confusion Matrix</h3>
                <table>
                    <tr>
                        <th></th>
                        <th>Predicted No Match</th>
                        <th>Predicted Match</th>
                    </tr>
                    <tr>
                        <th>Actual No Match</th>
                        <td>{true_negatives}</td>
                        <td>{false_positives}</td>
                    </tr>
                    <tr>
                        <th>Actual Match</th>
                        <td>{false_negatives}</td>
                        <td>{true_positives}</td>
                    </tr>
                </table>
                
                <h3>Optimal Thresholds</h3>
                <table>
                    <tr>
                        <th>Level</th>
                        <th>Threshold</th>
                    </tr>
                    <tr>
                        <td>High</td>
                        <td>{high_threshold:.2f}</td>
                    </tr>
                    <tr>
                        <td>Medium</td>
                        <td>{medium_threshold:.2f}</td>
                    </tr>
                    <tr>
                        <td>Low</td>
                        <td>{low_threshold:.2f}</td>
                    </tr>
                </table>
            </div>
            
            <div class="section">
                <h2>Visualizations</h2>
                <div class="visualizations">
                    {visualizations_html}
                </div>
            </div>
            
            <div class="section">
                <h2>Performance by Case Type</h2>
                <table>
                    <tr>
                        <th>Case Type</th>
                        <th>Count</th>
                        <th>Accuracy</th>
                        <th>Avg Score</th>
                    </tr>
                    {case_type_rows}
                </table>
            </div>
            
            <div class="section">
                <h2>Sample Test Results</h2>
                <table>
                    <tr>
                        <th>Merchant 1</th>
                        <th>Merchant 2</th>
                        <th>Score</th>
                        <th>Confidence</th>
                        <th>Expected</th>
                        <th>Match Level</th>
                    </tr>
                    {test_result_rows}
                </table>
            </div>
            
            <div class="section">
                <h2>Error Analysis</h2>
                <h3>False Positives (Incorrectly Matched)</h3>
                <table>
                    <tr>
                        <th>Merchant 1</th>
                        <th>Merchant 2</th>
                        <th>Score</th>
                        <th>Confidence</th>
                    </tr>
                    {false_positive_rows}
                </table>
                
                <h3>False Negatives (Incorrectly Not Matched)</h3>
                <table>
                    <tr>
                        <th>Merchant 1</th>
                        <th>Merchant 2</th>
                        <th>Score</th>
                        <th>Confidence</th>
                    </tr>
                    {false_negative_rows}
                </table>
            </div>
        </body>
        </html>
        """
        
        # Extract metrics
        metrics = report_data['test_summary']['performance_metrics']
        
        accuracy = metrics.get('accuracy', 0)
        precision = metrics.get('precision', 0)
        recall = metrics.get('recall', 0)
        f1 = metrics.get('f1_score', 0)
        
        accuracy_class = 'good' if accuracy >= 0.9 else 'medium' if accuracy >= 0.7 else 'poor'
        precision_class = 'good' if precision >= 0.9 else 'medium' if precision >= 0.7 else 'poor'
        recall_class = 'good' if recall >= 0.9 else 'medium' if recall >= 0.7 else 'poor'
        f1_class = 'good' if f1 >= 0.9 else 'medium' if f1 >= 0.7 else 'poor'
        
        # Generate visualizations HTML
        visualizations_html = ""
        if 'visualizations' in report_data:
            for viz_path in report_data['visualizations']:
                # Get relative path
                rel_path = os.path.relpath(viz_path, self.output_dir)
                visualizations_html += f"""
                <div class="visualization">
                    <img src="{rel_path}" alt="Visualization">
                </div>
                """
        
        # Generate case type rows
        case_type_rows = ""
        if 'case_type_performance' in metrics:
            for case_type, case_data in metrics['case_type_performance'].items():
                acc_class = 'good' if case_data['accuracy'] >= 0.9 else 'medium' if case_data['accuracy'] >= 0.7 else 'poor'
                case_type_rows += f"""
                <tr>
                    <td>{case_type}</td>
                    <td>{case_data['count']}</td>
                    <td class="{acc_class}">{case_data['accuracy']:.2%}</td>
                    <td>{case_data['avg_score']:.2f}</td>
                </tr>
                """
        
        # Generate test result rows
        test_result_rows = ""
        for result in report_data['test_results'][:20]:  # Show first 20
            expected = result.get('expected_match')
            expected_str = 'Yes' if expected == 1 else 'No' if expected == 0 else 'Unknown'
            test_result_rows += f"""
            <tr>
                <td>{result['s1']}</td>
                <td>{result['s2']}</td>
                <td>{result['score']:.2f}</td>
                <td>{result.get('confidence', 0):.2f}</td>
                <td>{expected_str}</td>
                <td>{result['level']}</td>
            </tr>
            """
        
        # Generate false positive and negative rows
        false_positive_rows = ""
        false_negative_rows = ""
        
        if 'error_analysis' in metrics and 'false_positives' in metrics['error_analysis']:
            for fp in metrics['error_analysis']['false_positives']['examples'][:10]:
                false_positive_rows += f"""
                <tr>
                    <td>{fp['s1']}</td>
                    <td>{fp['s2']}</td>
                    <td>{fp['score']:.2f}</td>
                    <td>{fp.get('confidence', 0):.2f}</td>
                </tr>
                """
        
        if 'error_analysis' in metrics and 'false_negatives' in metrics['error_analysis']:
            for fn in metrics['error_analysis']['false_negatives']['examples'][:10]:
                false_negative_rows += f"""
                <tr>
                    <td>{fn['s1']}</td>
                    <td>{fn['s2']}</td>
                    <td>{fn['score']:.2f}</td>
                    <td>{fn.get('confidence', 0):.2f}</td>
                </tr>
                """
        
        # Extract confusion matrix data
        true_positives = metrics.get('true_positives', 0)
        false_positives = metrics.get('false_positives', 0)
        false_negatives = metrics.get('false_negatives', 0)
        true_negatives = metrics.get('true_negatives', 0)
        
        # Extract threshold data
        thresholds = report_data['test_summary'].get('optimal_thresholds', {}).get('thresholds', {})
        high_threshold = thresholds.get('high', 0.85)
        medium_threshold = thresholds.get('medium', 0.75)
        low_threshold = thresholds.get('low', 0.60)
        
        # Format the HTML
        formatted_html = html_template.format(
            timestamp=report_data['timestamp'],
            accuracy=accuracy,
            precision=precision,
            recall=recall,
            f1=f1,
            total_tests=report_data['test_summary']['total_tests'],
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            true_negatives=true_negatives,
            high_threshold=high_threshold,
            medium_threshold=medium_threshold,
            low_threshold=low_threshold,
            visualizations_html=visualizations_html,
            case_type_rows=case_type_rows,
            test_result_rows=test_result_rows,
            false_positive_rows=false_positive_rows,
            false_negative_rows=false_negative_rows,
            accuracy_class=accuracy_class,
            precision_class=precision_class,
            recall_class=recall_class,
            f1_class=f1_class
        )
        
        # Write to file
        with open(report_path, 'w') as f:
            f.write(formatted_html)
        
        return report_path
    
    def _generate_markdown_report(self, report_data):
        """Generate Markdown report from data"""
        report_path = os.path.join(self.output_dir, "merchant_matcher_report.md")
        
        # Extract metrics
        metrics = report_data['test_summary']['performance_metrics']
        
        accuracy = metrics.get('accuracy', 0)
        precision = metrics.get('precision', 0)
        recall = metrics.get('recall', 0)
        f1 = metrics.get('f1_score', 0)
        
        # Extract confusion matrix data
        true_positives = metrics.get('true_positives', 0)
        false_positives = metrics.get('false_positives', 0)
        false_negatives = metrics.get('false_negatives', 0)
        true_negatives = metrics.get('true_negatives', 0)
        
        # Extract threshold data
        thresholds = report_data['test_summary'].get('optimal_thresholds', {}).get('thresholds', {})
        high_threshold = thresholds.get('high', 0.85)
        medium_threshold = thresholds.get('medium', 0.75)
        low_threshold = thresholds.get('low', 0.60)
        
        # Markdown content
        markdown_content = f"""
        # Merchant Matcher Test Report
    
        *Generated on: {report_data['timestamp']}*
    
        ## Summary
    
        - **Total Tests:** {report_data['test_summary']['total_tests']}
        - **Accuracy:** {accuracy:.2%}
        - **Precision:** {precision:.2%}
        - **Recall:** {recall:.2%}
        - **F1 Score:** {f1:.2%}
    
        ### Confusion Matrix
    
        |                  | Predicted No Match | Predicted Match |
        |------------------|------------------|----------------|
        | **Actual No Match** | {true_negatives} | {false_positives} |
        | **Actual Match**    | {false_negatives} | {true_positives} |
    
        ### Optimal Thresholds
    
        | Level  | Threshold |
        |--------|-----------|
        | High   | {high_threshold:.2f} |
        | Medium | {medium_threshold:.2f} |
        | Low    | {low_threshold:.2f} |
    
        ## Performance by Case Type
    
        """
        
        # Add case type performance
        if 'case_type_performance' in metrics:
            markdown_content += """
        | Case Type | Count | Accuracy | Avg Score |
        |-----------|-------|----------|-----------|
        """
            for case_type, case_data in metrics['case_type_performance'].items():
                markdown_content += f"| {case_type} | {case_data['count']} | {case_data['accuracy']:.2%} | {case_data['avg_score']:.2f} |\n"
        
        # Add sample test results
        markdown_content += """
        ## Sample Test Results
    
        | Merchant 1 | Merchant 2 | Score | Confidence | Expected | Match Level |
        |------------|------------|-------|------------|----------|-------------|
        """
        
        for result in report_data['test_results'][:10]:
            expected = result.get('expected_match')
            expected_str = 'Yes' if expected == 1 else 'No' if expected == 0 else 'Unknown'
            markdown_content += f"| {result['s1']} | {result['s2']} | {result['score']:.2f} | {result.get('confidence', 0):.2f} | {expected_str} | {result['level']} |\n"
        
        # Add error analysis
        markdown_content += """
        ## Error Analysis
    
        ### False Positives (Incorrectly Matched)
    
        | Merchant 1 | Merchant 2 | Score | Confidence |
        |------------|------------|-------|------------|
        """
        
        if 'error_analysis' in metrics and 'false_positives' in metrics['error_analysis']:
            for fp in metrics['error_analysis']['false_positives']['examples'][:5]:
                markdown_content += f"| {fp['s1']} | {fp['s2']} | {fp['score']:.2f} | {fp.get('confidence', 0):.2f} |\n"
        
        markdown_content += """
        ### False Negatives (Incorrectly Not Matched)
    
        | Merchant 1 | Merchant 2 | Score | Confidence |
        |------------|------------|-------|------------|
        """
        
        if 'error_analysis' in metrics and 'false_negatives' in metrics['error_analysis']:
            for fn in metrics['error_analysis']['false_negatives']['examples'][:5]:
                markdown_content += f"| {fn['s1']} | {fn['s2']} | {fn['score']:.2f} | {fn.get('confidence', 0):.2f} |\n"
        
        # Write to file
        with open(report_path, 'w') as f:
            f.write(markdown_content)
        
        return report_path
    
    def export_results(self, format='excel', include_explanations=True):
        """
        Export test results to various formats
        
        Args:
            format (str): Export format ('excel', 'csv', 'json')
            include_explanations (bool): Include detailed explanations
            
        Returns:
            str: Path to exported results
        """
        if self.test_results.empty:
            raise ValueError("No test results available for export")
        
        self.logger.info(f"Exporting results to {format}")
        
        # Prepare results for export
        export_df = self.test_results.copy()
        
        # Truncate explanations if they are too long
        if 'explanation' in export_df.columns and include_explanations:
            export_df['explanation'] = export_df['explanation'].apply(
                lambda x: x[:500] + '...' if isinstance(x, str) and len(x) > 500 else x
            )
        elif 'explanation' in export_df.columns and not include_explanations:
            export_df = export_df.drop(columns=['explanation'])
        
        # Export based on format
        if format == 'excel':
            export_path = os.path.join(self.output_dir, "merchant_matcher_results.xlsx")
            
            # Create Excel writer
            try:
                import openpyxl
                from openpyxl.styles import PatternFill, Font, Alignment
                from openpyxl.utils import get_column_letter
                
                # Create writer
                writer = pd.ExcelWriter(export_path, engine='openpyxl')
                
                # Write main results
                export_df.to_excel(writer, sheet_name='Test Results', index=False)
                
                # Add summary sheet
                summary_data = pd.DataFrame([{
                    'Metric': 'Accuracy',
                    'Value': self.performance_metrics.get('accuracy', 0)
                }, {
                    'Metric': 'Precision',
                    'Value': self.performance_metrics.get('precision', 0)
                }, {
                    'Metric': 'Recall',
                    'Value': self.performance_metrics.get('recall', 0)
                }, {
                    'Metric': 'F1 Score',
                    'Value': self.performance_metrics.get('f1_score', 0)
                }, {
                    'Metric': 'Total Tests',
                    'Value': len(export_df)
                }])
                
                summary_data.to_excel(writer, sheet_name='Summary', index=False)
                
                # Access workbook
                workbook = writer.book
                
                # Format Test Results sheet
                if 'Test Results' in workbook.sheetnames:
                    sheet = workbook['Test Results']
                    
                    # Format header
                    for col in range(1, sheet.max_column + 1):
                        cell = sheet.cell(row=1, column=col)
                        cell.font = Font(bold=True)
                        cell.fill = PatternFill(start_color="DDDDDD", end_color="DDDDDD", fill_type="solid")
                    
                    # Auto-size columns
                    for col in range(1, sheet.max_column + 1):
                        column_letter = get_column_letter(col)
                        sheet.column_dimensions[column_letter].width = 15
                    
                    # Add conditional formatting for score column
                    score_col = None
                    for col in range(1, sheet.max_column + 1):
                        if sheet.cell(row=1, column=col).value == 'score':
                            score_col = col
                            break
                    
                    if score_col:
                        for row in range(2, sheet.max_row + 1):
                            cell = sheet.cell(row=row, column=score_col)
                            score = cell.value if cell.value is not None else 0
                            
                            if score >= 0.85:
                                cell.fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
                            elif score >= 0.7:
                                cell.fill = PatternFill(start_color="FFEB9C", end_color="FFEB9C", fill_type="solid")
                            else:
                                cell.fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
                
                # Format Summary sheet
                if 'Summary' in workbook.sheetnames:
                    sheet = workbook['Summary']
                    
                    # Format header
                    for col in range(1, sheet.max_column + 1):
                        cell = sheet.cell(row=1, column=col)
                        cell.font = Font(bold=True)
                        cell.fill = PatternFill(start_color="DDDDDD", end_color="DDDDDD", fill_type="solid")
                    
                    # Format values
                    for row in range(2, sheet.max_row + 1):
                        metric_cell = sheet.cell(row=row, column=1)
                        value_cell = sheet.cell(row=row, column=2)
                        
                        metric_cell.font = Font(bold=True)
                        
                        # Format percentage values
                        if metric_cell.value in ['Accuracy', 'Precision', 'Recall', 'F1 Score']:
                            value = value_cell.value if value_cell.value is not None else 0
                            
                            if value >= 0.9:
                                value_cell.fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
                            elif value >= 0.7:
                                value_cell.fill = PatternFill(start_color="FFEB9C", end_color="FFEB9C", fill_type="solid")
                            else:
                                value_cell.fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
                
                # Save workbook
                writer.close()
                
            except ImportError:
                # Fallback to basic Excel export
                export_df.to_excel(export_path, index=False)
                
        elif format == 'csv':
            export_path = os.path.join(self.output_dir, "merchant_matcher_results.csv")
            export_df.to_csv(export_path, index=False)
            
        elif format == 'json':
            export_path = os.path.join(self.output_dir, "merchant_matcher_results.json")
            
            # Create JSON with results and metrics
            export_data = {
                'results': export_df.to_dict('records'),
                'performance_metrics': self.performance_metrics,
                'optimal_thresholds': self.optimal_thresholds
            }
            
            with open(export_path, 'w') as f:
                json.dump(export_data, f, indent=2)
        else:
            # Default to CSV
            export_path = os.path.join(self.output_dir, "merchant_matcher_results.csv")
            export_df.to_csv(export_path, index=False)
        
        self.logger.info(f"Results exported to {export_path}")
        return export_path
    
    def calibrate_confidence(self):
        """
        Calibrate confidence scores based on test results
        
        Returns:
            dict: Calibration parameters and metrics
        """
        if self.test_results.empty or 'confidence' not in self.test_results.columns:
            raise ValueError("Test results with confidence scores required for calibration")
        
        self.logger.info("Calibrating confidence scores")
        
        # Check if we have ground truth
        has_ground_truth = 'expected_match' in self.test_results.columns and not self.test_results['expected_match'].isna().all()
        
        if not has_ground_truth:
            self.logger.warning("Ground truth not available. Skipping calibration.")
            return None
        
        # Create confidence bins
        self.test_results['confidence_bin'] = pd.cut(
            self.test_results['confidence'], 
            bins=10
        )
        
        # Calculate actual accuracy in each bin
        confidence_calibration = self.test_results.groupby('confidence_bin').agg({
            'confidence': 'mean',
            'correct': 'mean',
            'expected_match': 'count'
        }).rename(columns={
            'confidence': 'avg_confidence',
            'correct': 'actual_accuracy',
            'expected_match': 'count'
        }).reset_index()
        
        # Calculate calibration metrics
        calibration_error = np.mean(np.abs(confidence_calibration['avg_confidence'] - confidence_calibration['actual_accuracy']))
        
        # Calculate scaling parameters for logistic calibration
        from scipy.optimize import curve_fit
        
        def logistic_function(x, a, b):
            return 1 / (1 + np.exp(-(a * x + b)))
        
        try:
            # Fit logistic function to confidence-accuracy relationship
            params, _ = curve_fit(
                logistic_function, 
                confidence_calibration['avg_confidence'], 
                confidence_calibration['actual_accuracy'],
                p0=[1, 0]
            )
            
            a, b = params
            
            # Calculate calibrated confidence
            def calibrate(confidence):
                return logistic_function(confidence, a, b)
            
            # Apply calibration to test results
            self.test_results['calibrated_confidence'] = self.test_results['confidence'].apply(calibrate)
            
            # Calculate new calibration error
            self.test_results['calibrated_confidence_bin'] = pd.cut(
                self.test_results['calibrated_confidence'], 
                bins=10
            )
            
            calibrated_metrics = self.test_results.groupby('calibrated_confidence_bin').agg({
                'calibrated_confidence': 'mean',
                'correct': 'mean',
                'expected_match': 'count'
            }).rename(columns={
                'calibrated_confidence': 'avg_calibrated_confidence',
                'correct': 'actual_accuracy',
                'expected_match': 'count'
            }).reset_index()
            
            new_calibration_error = np.mean(np.abs(calibrated_metrics['avg_calibrated_confidence'] - calibrated_metrics['actual_accuracy']))
            
            # Save calibration parameters
            calibration_params = {
                'method': 'logistic',
                'parameters': {
                    'a': float(a),
                    'b': float(b)
                },
                'original_calibration_error': float(calibration_error),
                'calibrated_error': float(new_calibration_error),
                'improvement': float(calibration_error - new_calibration_error)
            }
            
            # Save calibration results
            calibration_path = os.path.join(self.output_dir, "confidence_calibration.json")
            with open(calibration_path, 'w') as f:
                json.dump(calibration_params, f, indent=2)
            
            self.logger.info(f"Confidence calibration completed. Parameters saved to {calibration_path}")
            self.logger.info(f"Calibration improvement: {calibration_params['improvement']:.4f}")
            
            return calibration_params
            
        except Exception as e:
            self.logger.error(f"Failed to calibrate confidence: {e}")
            return None
    
    def compare_with_baseline(self, baseline_thresholds=None):
        """
        Compare matcher performance with baseline algorithms
        
        Args:
            baseline_thresholds (dict): Thresholds for baseline algorithms
            
        Returns:
            dict: Comparison results
        """
        if self.test_results.empty:
            raise ValueError("No test results available for comparison")
        
        # Check if we have ground truth
        has_ground_truth = 'expected_match' in self.test_results.columns and not self.test_results['expected_match'].isna().all()
        
        if not has_ground_truth:
            self.logger.warning("Ground truth not available. Skipping baseline comparison.")
            return None
        
        self.logger.info("Comparing with baseline algorithms")
        
        # Default baseline thresholds
        if baseline_thresholds is None:
            baseline_thresholds = {
                'jaro_winkler': 0.85,
                'levenshtein': 0.7,
                'token_set': 0.8
            }
        
        # Prepare comparison data
        comparison_results = {
            'enhanced_matcher': {
                'accuracy': self.performance_metrics.get('accuracy', 0),
                'precision': self.performance_metrics.get('precision', 0),
                'recall': self.performance_metrics.get('recall', 0),
                'f1_score': self.performance_metrics.get('f1_score', 0)
            },
            'baseline_algorithms': {}
        }
        
        # Initialize similarity algorithms for baseline
        similarity_algorithms = SimilarityAlgorithms(
            preprocessor=self.pipeline._initialized_components.get('preprocessor')
        )
        
        # Run baseline algorithm comparisons
        baselines = [
            ('jaro_winkler', similarity_algorithms.jaro_winkler_similarity),
            ('levenshtein', lambda s1, s2, domain=None: 1 - levenshtein_distance(s1, s2) / max(len(s1), len(s2)) if max(len(s1), len(s2)) > 0 else 0),
            ('token_set', similarity_algorithms.token_set_ratio)
        ]
        
        for baseline_name, similarity_func in baselines:
            # Get threshold for this baseline
            threshold = baseline_thresholds.get(baseline_name, 0.7)
            
            # Calculate baseline scores
            baseline_results = []
            
            for _, row in self.test_results.iterrows():
                s1 = row['s1']
                s2 = row['s2']
                expected = row['expected_match']
                
                # Skip invalid entries
                if not isinstance(s1, str) or not isinstance(s2, str) or pd.isna(expected):
                    continue
                
                # Calculate similarity
                similarity = similarity_func(s1, s2)
                
                # Determine match
                predicted = similarity >= threshold
                
                baseline_results.append({
                    'expected': bool(expected),
                    'predicted': predicted,
                    'score': similarity
                })
            
            # Calculate metrics
            if baseline_results:
                df = pd.DataFrame(baseline_results)
                
                tp = sum((df['expected'] == True) & (df['predicted'] == True))
                fp = sum((df['expected'] == False) & (df['predicted'] == True))
                fn = sum((df['expected'] == True) & (df['predicted'] == False))
                tn = sum((df['expected'] == False) & (df['predicted'] == False))
                
                accuracy = (tp + tn) / len(df) if len(df) > 0 else 0
                precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
                
                comparison_results['baseline_algorithms'][baseline_name] = {
                    'threshold': threshold,
                    'accuracy': accuracy,
                    'precision': precision,
                    'recall': recall,
                    'f1_score': f1
                }
        
        # Calculate improvements over baseline
        if comparison_results['baseline_algorithms']:
            # Find best baseline
            best_baseline = max(
                comparison_results['baseline_algorithms'].items(),
                key=lambda x: x[1]['f1_score']
            )
            
            best_baseline_name, best_baseline_metrics = best_baseline
            
            # Calculate improvements
            accuracy_improvement = comparison_results['enhanced_matcher']['accuracy'] - best_baseline_metrics['accuracy']
            precision_improvement = comparison_results['enhanced_matcher']['precision'] - best_baseline_metrics['precision']
            recall_improvement = comparison_results['enhanced_matcher']['recall'] - best_baseline_metrics['recall']
            f1_improvement = comparison_results['enhanced_matcher']['f1_score'] - best_baseline_metrics['f1_score']
            
            comparison_results['improvements'] = {
                'best_baseline': best_baseline_name,
                'accuracy_improvement': accuracy_improvement,
                'precision_improvement': precision_improvement,
                'recall_improvement': recall_improvement,
                'f1_improvement': f1_improvement,
                'relative_f1_improvement': f1_improvement / best_baseline_metrics['f1_score'] if best_baseline_metrics['f1_score'] > 0 else 0
            }
        
        # Save comparison results
        comparison_path = os.path.join(self.output_dir, "baseline_comparison.json")
        with open(comparison_path, 'w') as f:
            json.dump(comparison_results, f, indent=2)
        
        self.logger.info(f"Baseline comparison saved to {comparison_path}")
        
        # Generate comparison visualization
        try:
            self._visualize_baseline_comparison(comparison_results)
        except:
            self.logger.warning("Could not generate baseline comparison visualization")
        
        return comparison_results
    
    def _visualize_baseline_comparison(self, comparison_results):
        """Generate visualization for baseline comparison"""
        try:
            import matplotlib.pyplot as plt
            import numpy as np
            
            # Extract metrics for plotting
            algorithms = ['Enhanced Matcher'] + list(comparison_results['baseline_algorithms'].keys())
            
            accuracy_values = [comparison_results['enhanced_matcher']['accuracy']]
            precision_values = [comparison_results['enhanced_matcher']['precision']]
            recall_values = [comparison_results['enhanced_matcher']['recall']]
            f1_values = [comparison_results['enhanced_matcher']['f1_score']]
            
            for baseline in comparison_results['baseline_algorithms'].values():
                accuracy_values.append(baseline['accuracy'])
                precision_values.append(baseline['precision'])
                recall_values.append(baseline['recall'])
                f1_values.append(baseline['f1_score'])
            
            # Create figure
            plt.figure(figsize=(12, 8))
            
            # Set width of bars
            barWidth = 0.2
            
            # Set positions of bars on X axis
            r1 = np.arange(len(algorithms))
            r2 = [x + barWidth for x in r1]
            r3 = [x + barWidth for x in r2]
            r4 = [x + barWidth for x in r3]
            
            # Create bars
            plt.bar(r1, accuracy_values, width=barWidth, label='Accuracy', color='#3498db')
            plt.bar(r2, precision_values, width=barWidth, label='Precision', color='#2ecc71')
            plt.bar(r3, recall_values, width=barWidth, label='Recall', color='#e74c3c')
            plt.bar(r4, f1_values, width=barWidth, label='F1 Score', color='#f39c12')
            
            # Add labels and legend
            plt.xlabel('Algorithm')
            plt.ylabel('Score')
            plt.title('Performance Comparison with Baseline Algorithms')
            plt.xticks([r + barWidth*1.5 for r in range(len(algorithms))], algorithms)
            plt.legend()
            
            # Add value labels
            for i, v in enumerate(accuracy_values):
                plt.text(r1[i], v + 0.02, f'{v:.2f}', ha='center')
            for i, v in enumerate(precision_values):
                plt.text(r2[i], v + 0.02, f'{v:.2f}', ha='center')
            for i, v in enumerate(recall_values):
                plt.text(r3[i], v + 0.02, f'{v:.2f}', ha='center')
            for i, v in enumerate(f1_values):
                plt.text(r4[i], v + 0.02, f'{v:.2f}', ha='center')
            
            # Set y-axis limit
            plt.ylim(0, 1.2)
            
            # Add grid
            plt.grid(axis='y', alpha=0.3)
            
            # Save figure
            comparison_viz_path = os.path.join(self.output_dir, "baseline_comparison.png")
            plt.savefig(comparison_viz_path)
            plt.close()
            
        except Exception as e:
            self.logger.warning(f"Could not generate baseline comparison visualization: {e}")
    
    def run_end_to_end_pipeline(self, test_data=None, num_test_cases=200, domain=None,
                              optimize_thresholds=True, calibrate_confidence=True,
                              compare_baseline=True, generate_report=True):
        """
        Run a complete end-to-end testing pipeline
        
        Args:
            test_data (DataFrame): Test data with merchant pairs
            num_test_cases (int): Number of test cases to generate if no test data
            domain (str): Domain for testing
            optimize_thresholds (bool): Whether to optimize thresholds
            calibrate_confidence (bool): Whether to calibrate confidence scores
            compare_baseline (bool): Whether to compare with baseline algorithms
            generate_report (bool): Whether to generate report
            
        Returns:
            dict: Pipeline results
        """
        self.logger.info("Starting end-to-end testing pipeline")
        
        # Step 1: Prepare test data
        if test_data is None:
            # Generate test suite
            self.logger.info(f"Generating test suite with {num_test_cases} cases")
            test_data = self.generate_test_suite(
                num_cases=num_test_cases,
                include_edge_cases=True,
                domain=domain,
                output_file="generated_test_suite.csv"
            )
        else:
            # Validate test data
            test_data = self._validate_test_data(test_data)
        
        # Step 2: Run tests
        self.logger.info("Running tests")
        test_results = self.run_tests(
            test_data,
            domain=domain,
            save_results=True
        )
        
        pipeline_results = {
            'test_results': test_results
        }
        
        # Step 3: Analyze performance
        self.logger.info("Analyzing performance")
        performance_analysis = self.analyze_performance(
            by_domain=(domain is not None),
            by_case_type=True,
            by_confidence=True
        )
        
        pipeline_results['performance_analysis'] = performance_analysis
        
        # Step 4: Optimize thresholds if requested
        if optimize_thresholds:
            self.logger.info("Optimizing thresholds")
            optimal_thresholds = self.optimize_thresholds(test_data, domain)
            pipeline_results['optimal_thresholds'] = optimal_thresholds
        
        # Step 5: Calibrate confidence if requested
        if calibrate_confidence:
            self.logger.info("Calibrating confidence scores")
            try:
                calibration_results = self.calibrate_confidence()
                pipeline_results['confidence_calibration'] = calibration_results
            except Exception as e:
                self.logger.warning(f"Confidence calibration failed: {e}")
        
        # Step 6: Compare with baseline if requested
        if compare_baseline:
            self.logger.info("Comparing with baseline algorithms")
            comparison_results = self.compare_with_baseline()
            pipeline_results['baseline_comparison'] = comparison_results
        
        # Step 7: Generate report if requested
        if generate_report:
            self.logger.info("Generating report")
            report_path = self.generate_report(
                output_format='html',
                include_visualizations=True
            )
            pipeline_results['report_path'] = report_path
        
        # Step 8: Export results
        self.logger.info("Exporting results")
        export_path = self.export_results(format='excel', include_explanations=True)
        pipeline_results['export_path'] = export_path
        
        self.logger.info("End-to-end testing pipeline completed successfully")
        
        return pipeline_results
    
    def show_performance_dashboard(self):
        """
        Display an interactive performance dashboard if running in a Jupyter notebook
        
        Returns:
            object: Interactive dashboard object
        """
        try:
            from IPython.display import display, HTML
            import matplotlib.pyplot as plt
            from ipywidgets import interact, widgets
            
            if self.test_results.empty:
                return HTML("<p>No test results available for dashboard</p>")
            
            # Create dashboard HTML
            dashboard_html = """
            <div style="padding: 20px; background-color: #f9f9f9; border-radius: 10px; margin-bottom: 20px;">
                <h2 style="color: #2c3e50;">Merchant Matcher Performance Dashboard</h2>
                <div style="display: flex; flex-wrap: wrap; gap: 15px; margin-top: 20px;">
                    <div style="flex: 1; min-width: 150px; background-color: white; padding: 15px; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                        <div style="font-size: 14px; color: #7f8c8d;">Accuracy</div>
                        <div style="font-size: 24px; font-weight: bold; color: #2c3e50;">{accuracy:.2%}</div>
                    </div>
                    <div style="flex: 1; min-width: 150px; background-color: white; padding: 15px; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                        <div style="font-size: 14px; color: #7f8c8d;">Precision</div>
                        <div style="font-size: 24px; font-weight: bold; color: #2c3e50;">{precision:.2%}</div>
                    </div>
                    <div style="flex: 1; min-width: 150px; background-color: white; padding: 15px; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                        <div style="font-size: 14px; color: #7f8c8d;">Recall</div>
                        <div style="font-size: 24px; font-weight: bold; color: #2c3e50;">{recall:.2%}</div>
                    </div>
                    <div style="flex: 1; min-width: 150px; background-color: white; padding: 15px; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);">
                        <div style="font-size: 14px; color: #7f8c8d;">F1 Score</div>
                        <div style="font-size: 24px; font-weight: bold; color: #2c3e50;">{f1:.2%}</div>
                    </div>
                </div>
            </div>
            """
            
            # Fill in metrics
            metrics = self.performance_metrics
            accuracy = metrics.get('accuracy', 0)
            precision = metrics.get('precision', 0)
            recall = metrics.get('recall', 0)
            f1 = metrics.get('f1_score', 0)
            
            dashboard = HTML(dashboard_html.format(
                accuracy=accuracy,
                precision=precision,
                recall=recall,
                f1=f1
            ))
            
            # Display dashboard
            display(dashboard)
            
            # Create interactive functions for exploring results
            def explore_results(case_type='All', score_threshold=0.7, correct_only=False):
                """Display filtered test results"""
                filtered = self.test_results.copy()
                
                # Filter by case type
                if case_type != 'All' and 'case_type' in filtered.columns:
                    filtered = filtered[filtered['case_type'] == case_type]
                
                # Filter by score threshold
                filtered = filtered[filtered['score'] >= score_threshold]
                
                # Filter by correctness
                if correct_only and 'correct' in filtered.columns:
                    filtered = filtered[filtered['correct']]
                
                # Select columns to display
                display_cols = ['s1', 's2', 'score', 'confidence', 'level']
                if 'expected_match' in filtered.columns:
                    display_cols.append('expected_match')
                if 'correct' in filtered.columns:
                    display_cols.append('correct')
                
                # Show top 20 rows
                return filtered[display_cols].head(20)
            
            # Create dropdown for case types
            case_types = ['All']
            if 'case_type' in self.test_results.columns:
                case_types.extend(self.test_results['case_type'].unique())
            
            # Create interactive widgets
            interact(
                explore_results,
                case_type=widgets.Dropdown(options=case_types, description='Case Type:'),
                score_threshold=widgets.FloatSlider(min=0, max=1, step=0.05, value=0.7, description='Score ≥:'),
                correct_only=widgets.Checkbox(value=False, description='Correct Only')
            )
            
            # Create visualization function
            def plot_visualization(plot_type):
                """Display selected visualization"""
                plt.figure(figsize=(10, 6))
                
                if plot_type == 'Score Distribution':
                    if 'expected_match' in self.test_results.columns:
                        plt.hist(
                            [
                                self.test_results[self.test_results['expected_match'] == 1]['score'],
                                self.test_results[self.test_results['expected_match'] == 0]['score']
                            ],
                            bins=20,
                            label=['Match', 'No Match'],
                            alpha=0.7
                        )
                        plt.legend()
                    else:
                        plt.hist(self.test_results['score'], bins=20)
                    plt.title('Score Distribution')
                    plt.xlabel('Match Score')
                    plt.ylabel('Count')
                    
                elif plot_type == 'Confidence Calibration':
                    if 'confidence' in self.test_results.columns and 'correct' in self.test_results.columns:
                        # Group by confidence bins
                        self.test_results['confidence_bin'] = pd.cut(
                            self.test_results['confidence'], 
                            bins=10
                        )
                        
                        conf_acc = self.test_results.groupby('confidence_bin')['correct'].mean().reset_index()
                        conf_acc['bin_center'] = conf_acc['confidence_bin'].apply(lambda x: x.mid)
                        
                        plt.plot(conf_acc['bin_center'], conf_acc['correct'], 'o-', linewidth=2)
                        plt.plot([0, 1], [0, 1], 'k--')  # Ideal calibration line
                        
                        plt.title('Confidence Calibration Plot')
                        plt.xlabel('Confidence')
                        plt.ylabel('Accuracy')
                        plt.grid(alpha=0.3)
                    else:
                        plt.text(0.5, 0.5, 'Confidence data not available', ha='center', va='center')
                        
                elif plot_type == 'Case Type Performance':
                    if 'case_type' in self.test_results.columns and 'correct' in self.test_results.columns:
                        case_perf = self.test_results.groupby('case_type')['correct'].mean().sort_values()
                        
                        colors = ['g' if x > 0.8 else 'y' if x > 0.6 else 'r' for x in case_perf]
                        
                        case_perf.plot(kind='barh', color=colors)
                        plt.title('Accuracy by Case Type')
                        plt.xlabel('Accuracy')
                        plt.ylabel('Case Type')
                        plt.xlim(0, 1)
                    else:
                        plt.text(0.5, 0.5, 'Case type data not available', ha='center', va='center')
                        
                plt.tight_layout()
                plt.show()
            
            # Create visualization widget
            viz_options = [
                'Score Distribution',
                'Confidence Calibration',
                'Case Type Performance'
            ]
            
            interact(
                plot_visualization,
                plot_type=widgets.Dropdown(options=viz_options, description='Plot:')
            )
            
            return dashboard
            
        except ImportError:
            self.logger.warning("IPython dependencies not available. Cannot display dashboard.")
            return None

In [59]:
# 1. Initialize the tester
print("Initializing Merchant Matcher Tester...")
tester = MerchantMatcherTester(output_dir=output_dir)

# 2. Generate test suite with diverse cases
print("Generating comprehensive test suite...")
test_suite = tester.generate_test_suite(
    num_cases=300,
    include_edge_cases=True,
    output_file="merchant_test_suite.csv"
)

# 3. Run end-to-end pipeline
print("Running end-to-end evaluation pipeline...")
pipeline_results = tester.run_end_to_end_pipeline(
    test_data=test_suite,
    optimize_thresholds=True,
    calibrate_confidence=True,
    compare_baseline=True,
    generate_report=True
)

# 4. Print summary results
print("\n==== Merchant Matcher Evaluation Results ====")
if 'test_results' in pipeline_results and 'performance' in pipeline_results['test_results']:
    perf = pipeline_results['test_results']['performance']
    print(f"Accuracy:  {perf.get('accuracy', 0):.2%}")
    print(f"Precision: {perf.get('precision', 0):.2%}")
    print(f"Recall:    {perf.get('recall', 0):.2%}")
    print(f"F1 Score:  {perf.get('f1_score', 0):.2%}")

# 5. Print comparison with baseline if available
if 'baseline_comparison' in pipeline_results and 'improvements' in pipeline_results['baseline_comparison']:
    imp = pipeline_results['baseline_comparison']['improvements']
    best_baseline = imp.get('best_baseline', 'Unknown')
    f1_imp = imp.get('f1_improvement', 0)
    rel_imp = imp.get('relative_f1_improvement', 0)
    
    print(f"\nImprovement over best baseline ({best_baseline}):")
    print(f"F1 Score Improvement: {f1_imp:.2%} ({rel_imp:.2%} relative improvement)")

# 6. Print paths to output files
print("\nOutput Files:")
if 'report_path' in pipeline_results:
    print(f"- Detailed Report: {pipeline_results['report_path']}")
if 'export_path' in pipeline_results:
    print(f"- Results Export: {pipeline_results['export_path']}")

print("\nEvaluation complete! See above paths for detailed results.")
return tester, pipeline_results

Initializing Merchant Matcher Tester...


NameError: name 'output_dir' is not defined