# In [None]:
#!/usr/bin/env python3
"""
Enhanced Bundle Suggester with BERT Embeddings and FAISS Search
Combines Qwen-generated suggestions with semantic similarity search using BERT and FAISS.
"""

import pandas as pd
import numpy as np
import re
import logging
import pickle
import os
import sys
from itertools import islice
from typing import List, Dict, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Check and install required packages
def install_if_missing(package_name, import_name=None):
    """
    Ensure a package is importable, installing it with pip when it is not.

    Installation is best-effort: a failed pip run is reported on stdout and
    the function returns normally, so the subsequent real import is what
    ultimately fails.

    Args:
        package_name: Distribution name passed to ``pip install``
        import_name: Module name used for the import check. Defaults to
            ``package_name`` with hyphens replaced by underscores, because a
            hyphenated distribution name (e.g. ``sentence-transformers``)
            can never be imported as-is.
    """
    import importlib
    import subprocess

    if import_name is None:
        import_name = package_name.replace("-", "_")

    try:
        importlib.import_module(import_name)
        return
    except ImportError:
        print(f"{package_name} not found. Installing...")

    # Invoke pip through the running interpreter so the package lands in the
    # same environment; passing a list avoids building a shell command string.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", package_name],
        check=False,
    )
    if result.returncode != 0:
        print(f"Failed to install {package_name} (pip exit code {result.returncode})")
        return

    # Let the import system notice modules installed after startup.
    importlib.invalidate_caches()

# Install required packages. The second argument is the importable module
# name; it is given whenever it differs from the pip distribution name, so
# the import check can actually succeed once the package is present.
install_if_missing("sentence-transformers", "sentence_transformers")
install_if_missing("faiss-cpu", "faiss")
install_if_missing("vllm")
install_if_missing("huggingface_hub")
install_if_missing("torch")
install_if_missing("transformers")

# Now import everything
import torch
import faiss
from sentence_transformers import SentenceTransformer
from huggingface_hub import snapshot_download
import vllm

# Set up logging: configure the root logger at INFO level and create a
# module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def filt(text: str, max_items: int = 10) -> str:
    """
    Pull the entries of a numbered list out of text and join them with ' | '.

    Lines of the form ``<number>. <content>`` are collected, limited to the
    first ``max_items`` matches, stripped of parenthesised remarks, and
    trimmed of surrounding whitespace.

    Args:
        text: Raw text that may contain a numbered list
        max_items: Maximum number of list entries to keep

    Returns:
        The cleaned entries joined by ' | ' (empty string when none match)
    """
    numbered_entry = re.compile(r'^\s*\d+\.\s*(.+)$', re.MULTILINE)
    parenthetical = re.compile(r'\s*\(.*?\)')

    kept = []
    for entry in numbered_entry.findall(text)[:max_items]:
        without_remarks = parenthetical.sub('', entry)
        kept.append(without_remarks.strip())
    return ' | '.join(kept)

class BERTEmbedder:
    """Embeds product titles with a multilingual sentence-transformer model."""

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        """
        Load the sentence-transformer model and start with no cached data.

        Args:
            model_name: Sentence transformer model name
        """
        print(f"Loading BERT model: {model_name}")
        # Populated later by encode_items() or load_embeddings().
        self.embeddings = None
        self.item_titles = None
        self.model = SentenceTransformer(model_name)
        print("BERT model loaded successfully")

    def encode_items(self, item_titles: List[str], batch_size: int = 32) -> np.ndarray:
        """
        Embed the given titles and remember both titles and vectors.

        Args:
            item_titles: List of product titles
            batch_size: Batch size handed to the model's encode call

        Returns:
            Array of embeddings as returned by the model
        """
        title_count = len(item_titles)
        print(f"Encoding {title_count} items with BERT...")
        self.item_titles = item_titles

        # The model batches internally; the result is requested as numpy.
        encode_options = {
            'batch_size': batch_size,
            'show_progress_bar': True,
            'convert_to_numpy': True,
        }
        vectors = self.model.encode(item_titles, **encode_options)

        self.embeddings = vectors
        print(f"Generated embeddings shape: {vectors.shape}")
        return vectors

    def save_embeddings(self, filepath: str):
        """Pickle the current embeddings and titles into ``filepath``."""
        payload = {'embeddings': self.embeddings, 'item_titles': self.item_titles}
        with open(filepath, 'wb') as handle:
            pickle.dump(payload, handle)
        print(f"Embeddings saved to {filepath}")

    def load_embeddings(self, filepath: str):
        """Restore embeddings and titles from a pickle written by save_embeddings."""
        with open(filepath, 'rb') as handle:
            payload = pickle.load(handle)
        self.embeddings = payload['embeddings']
        self.item_titles = payload['item_titles']
        print(f"Embeddings loaded from {filepath}")

class FAISSSearcher:
    """FAISS-based cosine-similarity search for product embeddings."""

    def __init__(self, embeddings: np.ndarray, item_titles: List[str]):
        """
        Build an inner-product FAISS index over unit-length embeddings.

        The input array is copied before normalisation, so the caller's
        embeddings are never modified.

        Args:
            embeddings: 2-D numpy array with one embedding per row
            item_titles: List of corresponding item titles
        """
        self.item_titles = item_titles
        # Keep the normalised float32 copy that is actually indexed.
        self.embeddings = self._unit_rows(embeddings)
        self.dimension = self.embeddings.shape[1]

        print("Building FAISS index...")
        # Inner product over unit vectors equals cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)
        self.index.add(self.embeddings)

        print(f"FAISS index built with {self.index.ntotal} vectors")

    @staticmethod
    def _unit_rows(vectors: np.ndarray) -> np.ndarray:
        """Return a C-contiguous float32 copy of ``vectors`` with unit-length rows."""
        rows = np.array(vectors, dtype=np.float32, order="C")
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        # Leave all-zero rows unchanged instead of dividing by zero.
        norms[norms == 0] = 1.0
        rows /= norms
        return rows

    def search_similar_items(self, query_text: str, bert_model: SentenceTransformer,
                           top_k: int = 10) -> List[Tuple[str, float]]:
        """
        Search for similar items given a query text.

        Args:
            query_text: Text to search for
            bert_model: Model whose ``encode`` turns the query into a vector
            top_k: Number of top results to request from the index

        Returns:
            List of (item_title, similarity_score) tuples; may hold fewer
            than ``top_k`` entries when the index has fewer vectors
        """
        query_embedding = self._unit_rows(
            bert_model.encode([query_text], convert_to_numpy=True)
        )

        scores, indices = self.index.search(query_embedding, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with index -1; a plain upper-bound
            # check would let -1 through and return the last title.
            if 0 <= idx < len(self.item_titles):
                results.append((self.item_titles[int(idx)], float(score)))

        return results

    def find_complementary_items(self, complement_categories: str, bert_model: SentenceTransformer,
                               original_item: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
        """
        Find items that match complementary categories.

        Args:
            complement_categories: Pipe-separated category names
            bert_model: Model used to encode each category
            original_item: Item title to exclude from results (compared
                case-insensitively); ignored when it is not a string
            top_k: Number of results requested per category

        Returns:
            Up to ``top_k * 2`` unique (item_title, similarity_score) tuples,
            sorted by score descending; each item carries its best score
            across all categories
        """
        # Non-string input (e.g. NaN from a CSV) has no categories to search.
        if not isinstance(complement_categories, str) or not complement_categories.strip():
            return []

        excluded = original_item.lower() if isinstance(original_item, str) and original_item else None

        all_results = []
        for category in (part.strip() for part in complement_categories.split('|')):
            if not category:
                continue
            for item, score in self.search_similar_items(category, bert_model, top_k):
                if excluded is not None and item.lower() == excluded:
                    continue
                all_results.append((item, score))

        # Sort first so that de-duplication keeps each item's highest score.
        all_results.sort(key=lambda pair: pair[1], reverse=True)
        seen = set()
        unique_results = []
        for item, score in all_results:
            if item not in seen:
                seen.add(item)
                unique_results.append((item, score))

        return unique_results[:top_k * 2]  # Return more results than one category yields

    # NOTE(review): the two methods below read self.prompt_template, self.llm
    # and self.sampling_params, none of which __init__ sets. They look like
    # they belong to a separate LLM-suggester class - confirm before use.

    def format_prompt(self, description: str) -> str:
        """Fill ``self.prompt_template`` with the given description."""
        return self.prompt_template.format(description=description)

    def suggest_for_batch(self, descriptions: List[str]) -> List[str]:
        """
        Generate cleaned suggestions for a batch of descriptions.

        Args:
            descriptions: Item descriptions to build prompts from

        Returns:
            One string per description: the generated text passed through
            ``filt``, or a fixed error marker for every entry when
            generation raises
        """
        prompts = [self.format_prompt(d) for d in descriptions]
        try:
            responses = self.llm.generate(prompts, self.sampling_params, use_tqdm=False)
            return [filt(resp.outputs[0].text.strip()) for resp in responses]
        except Exception as e:
            # Best-effort: report the failure and keep the output aligned with the input.
            logging.error("Error in batch generation: %s", e)
            return ["- Error generating suggestions."] * len(descriptions)

def batch_iterator(iterable, size: int = 32):
    """
    Yield successive batches from iterable.

    Args:
        iterable: Any iterable to split up
        size: Maximum number of elements per batch

    Yields:
        Lists of up to ``size`` consecutive elements; the last list may be
        shorter, and nothing is yielded for an empty iterable
    """
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            break
        yield chunk


def create_sample_data():
    """
    Write a small sample inventory CSV for testing and return its path.

    Returns:
        The file name of the CSV that was written ('sample_inventory.csv')
    """
    # NOTE(review): the original sample rows were missing from this file (only
    # the tail of this function survived, fused onto batch_iterator). The rows
    # below are placeholders using the two columns main() reads - replace them
    # with the intended sample data if it is available.
    sample_data = {
        'Item title': [
            'Wireless Bluetooth Headphones',
            'Stainless Steel Water Bottle',
            'Yoga Mat',
            'USB-C Charging Cable',
            'Laptop Backpack',
        ],
        'compl_items': [
            'Charging Cable | Carrying Case',
            'Backpack | Yoga Mat',
            'Water Bottle',
            'Headphones | Backpack',
            'Water Bottle | Charging Cable',
        ],
    }
    df = pd.DataFrame(sample_data)
    df.to_csv('sample_inventory.csv', index=False)
    print("Created sample_inventory.csv for testing")
    return 'sample_inventory.csv'

def main():
    """
    Main execution function.

    Loads the inventory CSV, embeds the item titles (reusing a cached
    embeddings file only when it matches the current titles), builds the
    FAISS index, looks up complementary items for every row that has a
    'compl_items' value, and writes the augmented CSV.
    """
    # Configuration
    INPUT_CSV = "/kaggle/input/datasettt/final_dataset.csv"
    OUTPUT_CSV = "final_dataset_with_suggestions.csv"
    EMBEDDINGS_FILE = "bert_embeddings.pkl"
    BATCH_SIZE = 32

    # Check if input file exists
    if not os.path.exists(INPUT_CSV):
        print(f"Input file {INPUT_CSV} not found.")
        choice = input("Create sample data for testing? (y/n): ").lower()
        if choice == 'y':
            INPUT_CSV = create_sample_data()
        else:
            print("Please provide a valid CSV file with 'Item title' column.")
            sys.exit(1)

    # Load data
    print("Loading data...")
    try:
        df = pd.read_csv(INPUT_CSV)
        if 'Item title' not in df.columns:
            print("Error: CSV must contain 'Item title' column")
            sys.exit(1)
    except Exception as e:
        print(f"Error loading CSV: {e}")
        sys.exit(1)

    items = df["Item title"].tolist()
    print(f"Loaded {len(items)} items from {INPUT_CSV}")

    # Step 1: Generate embeddings with BERT
    print("\n=== Step 1: Generating BERT Embeddings ===")
    bert_embedder = BERTEmbedder()

    # A cache written for a different dataset would make the index return
    # titles that are not in this CSV, so it is only trusted on an exact match.
    cache_matches = False
    if os.path.exists(EMBEDDINGS_FILE):
        print("Loading existing embeddings...")
        bert_embedder.load_embeddings(EMBEDDINGS_FILE)
        cached_titles = bert_embedder.item_titles
        cache_matches = cached_titles is not None and list(cached_titles) == items
        if not cache_matches:
            print("Cached embeddings do not match the current items; re-encoding...")
    if not cache_matches:
        bert_embedder.encode_items(items, batch_size=BATCH_SIZE)
        bert_embedder.save_embeddings(EMBEDDINGS_FILE)

    # Step 2: Create FAISS index
    print("\n=== Step 2: Building FAISS Index ===")
    faiss_searcher = FAISSSearcher(bert_embedder.embeddings, bert_embedder.item_titles)

    # Step 3: Generate complementary categories (if not already present)
    if 'compl_items' not in df.columns:
        print("\n=== No compl items found ===")
    else:
        print("Using existing complementary categories...")

    # Step 4: Find similar items using FAISS
    print("\n=== Step 4: Finding Similar Items with FAISS ===")
    similar_items_list = []
    similarity_scores_list = []
    total_rows = len(df)

    # Count rows by position: index labels need not start at 0 or be numeric.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        original_item = row['Item title']
        complement_categories = row.get('compl_items', '')

        similar_results = []
        if not (pd.isna(complement_categories) or complement_categories == ''):
            # Find complementary items
            similar_results = faiss_searcher.find_complementary_items(
                complement_categories,
                bert_embedder.model,
                original_item=original_item,
                top_k=3
            )

        # Format results: keep at most five matches; rows without matches get ''
        top_results = similar_results[:5]
        similar_items_list.append(' | '.join(item for item, _ in top_results))
        similarity_scores_list.append(' | '.join(f"{score:.3f}" for _, score in top_results))

        if position % 100 == 0:
            print(f"Processed {position}/{total_rows} items")

    # Add results to dataframe
    df['similar_items'] = similar_items_list
    df['similarity_scores'] = similarity_scores_list

    # Step 5: Save results
    print("\n=== Step 5: Saving Results ===")
    df.to_csv(OUTPUT_CSV, index=False)
    print(f"Results saved to {OUTPUT_CSV}")

    # Show sample results
    print("\n=== Sample Results ===")
    for number, (_, row) in enumerate(df.head(3).iterrows(), start=1):
        print(f"\nItem #{number}:")
        print(f"Original: {row['Item title']}")
        # The column is optional, so read it without raising when absent.
        print(f"Categories: {row.get('compl_items', '')}")
        print(f"Similar Items: {row['similar_items']}")
        print(f"Scores: {row['similarity_scores']}")
        print("-" * 80)

    print(f"\nProcessing complete! Check {OUTPUT_CSV} for full results.")

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()