# In [None]:
import json
from pathlib import Path
from typing import Dict, List, Tuple
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.vector_stores import FaissVectorStore
from llama_index.embeddings import OpenAIEmbedding
from llama_index.core import Document
import faiss    
import numpy as np

class TableMatchingRAG:
    """Match free-text queries to table descriptions with embedding search.

    Table descriptions are read from a JSON object mapping table name to
    description. Each description is embedded with ``OpenAIEmbedding`` and
    stored in a flat L2 FAISS index; queries are embedded the same way and the
    nearest descriptions are returned. Description embeddings are cached on
    disk so they are only computed once per set of descriptions.
    """

    # Used only when there are no embeddings to infer the dimension from
    # (this is the value the index was previously hard-coded to).
    DEFAULT_DIMENSION = 1536

    def __init__(self, json_path: str, embedding_cache_path: str = "embeddings_cache.npy"):
        """Load descriptions, obtain their embeddings and build the index.

        Args:
            json_path: Path to a JSON file holding ``{table_name: description}``.
            embedding_cache_path: File used to cache the description embeddings.
        """
        self.embedding_cache_path = Path(embedding_cache_path)
        self.embed_model = OpenAIEmbedding()
        self.table_data = self._load_json(json_path)
        self.descriptions = list(self.table_data.values())
        self.table_names = list(self.table_data.keys())

        # Embeddings come first so the index dimension follows the data
        # instead of a hard-coded constant.
        self.embeddings = self._load_or_create_embeddings()
        self.dimension = int(self.embeddings.shape[1])

        self.index = faiss.IndexFlatL2(self.dimension)
        if len(self.embeddings):
            self.index.add(self.embeddings)

    def _load_json(self, json_path: str) -> Dict[str, str]:
        """Load table descriptions from JSON file."""
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_or_create_embeddings(self) -> np.ndarray:
        """Load embeddings from cache or create new ones.

        A cached array is reused only if it is 2-D and has one row per current
        description; otherwise it is treated as stale and regenerated.
        NOTE: only the row count is compared, so edits that keep the number of
        descriptions unchanged are not detected — delete the cache file then.

        Returns:
            A float32 array of shape ``(len(self.descriptions), dimension)``.
        """
        expected_rows = len(self.descriptions)

        if self.embedding_cache_path.exists():
            cached = np.load(str(self.embedding_cache_path))
            if cached.ndim == 2 and cached.shape[0] == expected_rows:
                return np.ascontiguousarray(cached, dtype=np.float32)
            # Row count differs from the current descriptions: rebuild below.

        if expected_rows:
            # One batched request instead of one request per description.
            vectors = self.embed_model.get_text_embedding_batch(self.descriptions)
            embeddings_array = np.array(vectors, dtype=np.float32)
        else:
            # Keep the array 2-D so the index can still be constructed.
            embeddings_array = np.empty((0, self.DEFAULT_DIMENSION), dtype=np.float32)

        # Write through a file handle: np.save(path) appends ".npy" to paths
        # lacking that suffix, which would make the exists() check above miss.
        with open(self.embedding_cache_path, 'wb') as cache_file:
            np.save(cache_file, embeddings_array)
        return embeddings_array

    def _format_hits(self, indices, distances) -> List[Tuple[str, str, float]]:
        """Turn one row of search output into (name, description, score) tuples.

        Entries with a negative index are dropped: the index pads its output
        with -1 when fewer than ``top_k`` vectors are available, and a negative
        index would otherwise silently pick a table from the end of the list.
        """
        hits = []
        for idx, distance in zip(indices, distances):
            position = int(idx)
            if position < 0:
                continue
            # Map distance in [0, inf) to a score in (0, 1]; plain float so
            # the result matches the annotation and is JSON-serialisable.
            similarity_score = 1.0 / (1.0 + float(distance))
            hits.append(
                (self.table_names[position], self.descriptions[position], similarity_score)
            )
        return hits

    def find_matching_tables(self, query: str, top_k: int = 3) -> List[Tuple[str, str, float]]:
        """Find the most similar tables based on the query.

        Args:
            query: Free-text question to match against the descriptions.
            top_k: Maximum number of matches to return.

        Returns:
            Up to ``top_k`` ``(table_name, description, similarity_score)``
            tuples in the order returned by the index (closest first); fewer
            when the index holds fewer than ``top_k`` tables.
        """
        query_embedding = self.embed_model.get_text_embedding(query)
        query_embedding_array = np.array([query_embedding], dtype=np.float32)

        distances, indices = self.index.search(query_embedding_array, top_k)
        return self._format_hits(indices[0], distances[0])

    def batch_queries(self, queries: List[str], top_k: int = 3) -> Dict[str, List[Tuple[str, str, float]]]:
        """Process multiple queries efficiently.

        Args:
            queries: Free-text questions; duplicates collapse to a single key.
            top_k: Maximum number of matches per query.

        Returns:
            Mapping of each query to its list of
            ``(table_name, description, similarity_score)`` tuples. An empty
            ``queries`` list yields an empty dict.
        """
        if not queries:
            # Nothing to embed; an empty array would not be a valid 2-D search input.
            return {}

        # Embed all queries in one batched call, then search them together.
        query_embeddings = self.embed_model.get_text_embedding_batch(list(queries))
        query_embedding_array = np.array(query_embeddings, dtype=np.float32)

        distances, indices = self.index.search(query_embedding_array, top_k)

        return {
            query: self._format_hits(indices[row], distances[row])
            for row, query in enumerate(queries)
        }

# Demo: build a matcher from a tiny sample and run one single and one batch lookup
if __name__ == "__main__":
    # Illustrative descriptions, keyed by table name
    sample_data = {
        "users": "Table containing user information including id, name, email, and registration date",
        "orders": "Table storing customer order details with order_id, user_id, products, and order_date",
        "products": "Product catalog with product_id, name, description, price, and category",
    }

    # Persist the sample so the matcher can read it back from disk
    with open("table_descriptions.json", "w") as f:
        f.write(json.dumps(sample_data))

    rag = TableMatchingRAG("table_descriptions.json")

    # --- one query ---
    query = "Where can I find customer purchase information?"
    results = rag.find_matching_tables(query)
    print("\nSingle Query Results:")
    for table_name, description, score in results:
        # Trailing newline on the last field leaves a blank line between hits
        print(
            f"Table: {table_name}",
            f"Description: {description}",
            f"Similarity Score: {score:.4f}\n",
            sep="\n",
        )

    # --- several queries in one call ---
    queries = [
        "Where is user data stored?",
        "I need to look up product prices",
    ]
    batch_results = rag.batch_queries(queries)
    print("\nBatch Query Results:")
    for query, results in batch_results.items():
        print(f"\nQuery: {query}")
        for table_name, description, score in results:
            print(
                f"Table: {table_name}",
                f"Description: {description}",
                f"Similarity Score: {score:.4f}",
                sep="\n",
            )