In [4]:
"""
Company Name Resolution using ReAct Agent
Identifies and maps company names from user queries to database entries.
"""

import os
import re
import json
from typing import List, Dict, Optional, Tuple, Set
from collections import defaultdict
import unicodedata

import pandas as pd
from openai import OpenAI
from rapidfuzz import fuzz
from sentence_transformers import SentenceTransformer
import numpy as np
# Source notebook: ClientReferenceAgent/run_agent_v2.ipynb

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# import os
# from openai import OpenAI

# # Assumes OPENAI_API_KEY environment variable is set
# client = OpenAI()

# response = client.embeddings.create(
#     input="<company name>",
#     model="text-embedding-3-small" # or "text-embedding-3-large"
# )

# # # The embedding is a list of floating point numbers
# # embedding = response.data[0].embedding
# # print(f"Embedding length: {len(embedding)}")


In [19]:
# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Central configuration, read as class attributes (e.g. ``Config.OPENAI_MODEL``)."""

    # Chat model used for the LLM-backed tools
    OPENAI_MODEL = "gpt-4o-mini"
    # NOTE(review): presumably the ReAct loop's step cap; its consumer is outside this cell
    MAX_ITERATIONS = 10

    # Company suffixes to normalize.
    # NOTE(review): FastFuzzyMatcher._normalize_string hard-codes its own, shorter
    # suffix list (llc/ltd/corp/inc) and does not read this one.
    COMPANY_SUFFIXES = [
        "Inc.", "Inc", "LLC", "L.L.C.", "LP", "L.P.",
        "Ltd.", "Ltd", "Limited", "Corporation", "Corp.", "Corp",
        "Associates", "Advisors", "Management", "Partners",
        "Investments", "Capital", "Group", "Authority"
    ]

    # Nickname / abbreviation -> list of official database names.
    # A list with several names is an "entity block" (one nickname -> multiple companies).
    # Keys must be lowercase: map_nickname looks them up via candidate_name.lower().strip().
    # Each key must appear only once -- a repeated key silently overwrites the earlier entry.
    NICKNAME_MAPPING = {
        "millennium": [
            "Millennium Partners",
            "Millennium Management LLC"
        ],
        "bridge": ["Bridge Associates"],
        "bridgewater": ["Bridgewater Associates"],
        "bridger": ["Bridger Capital"],
        "2sigma": ["Two Sigma Investments, LP"],
        "googl": ["Alphabet Inc."],
        "adia": ["Abu Dhabi Investment Authority"]
    }

In [20]:
# ============================================================================
# DATA SETUP
# ============================================================================

def initialize_company_database() -> pd.DataFrame:
    """Build the in-memory company reference table.

    Returns:
        A DataFrame with a single ``Company`` column listing the official
        company names, one per row, in a fixed order.
    """
    official_names = [
        "Apple Inc.",
        "Two Sigma Investments, LP",
        "Millennium Partners",
        "Millennium Management LLC",
        "WorldQuant Millennium Advisors",
        "Bridge Associates",
        "Bridgewater Associates",
        "Bridger Capital",
        "Curry's Retail Ltd.",
        "Alphabet Inc.",
        "Abu Dhabi Investment Authority",
    ]
    return pd.DataFrame({'Company': official_names})


# def initialize_embeddings(companies: List[str]):
#     """Initialize embedding model for semantic similarity"""
#     try:
#         model = SentenceTransformer('all-MiniLM-L6-v2')
#         embeddings = model.encode(companies)
#         print("✓ Embedding model loaded")
#         return model, embeddings
#     except Exception as e:
#         print(f"⚠ Could not load embedding model: {e}")
#         return None, None
    

# def initialize_embeddings(companies: List[str], openai_client=None):
#     """Initialize embedding model for semantic similarity"""
#     try:
#         if openai_client:
#             # Use OpenAI embeddings
#             print("Generating OpenAI embeddings...")
#             response = openai_client.embeddings.create(
#                 input=companies,
#                 model="text-embedding-3-small"
#             )
#             embeddings = np.array([item.embedding for item in response.data])
#             print("✓ OpenAI embeddings loaded")
#             return None, embeddings  # No model needed, just embeddings
#         # else:
#         #     # Fallback to sentence-transformers (if you want to keep it)
#         #     model = SentenceTransformer('all-MiniLM-L6-v2')
#         #     embeddings = model.encode(companies)
#         #     print("✓ Embedding model loaded")
#         #     return model, embeddings
#     except Exception as e:
#         print(f"⚠ Could not load embedding model: {e}")
#         return None, None

In [21]:
# ============================================================================
# FAST FUZZY MATCHER
# ============================================================================

class FastFuzzyMatcher:
    """High-performance fuzzy string matcher with multiple matching strategies.

    Company names are normalized once (see ``_normalize_string``) and indexed by
    word and by character trigram. ``find_matches`` uses those indexes to build a
    shortlist of candidates and scores only that shortlist.
    """

    def __init__(self, company_data: List[Tuple[str, str]]):
        """
        Initialize matcher with company data.

        Args:
            company_data: List of tuples (company_id, company_name)
        """
        self.company_data = company_data
        # Parallel lists: position i describes the i-th entry of company_data.
        self.normalized_names = []
        self.original_names = []
        self.company_ids = []
        # Inverted indexes: word / trigram -> set of positions whose name contains it.
        self.word_to_companies = defaultdict(set)
        self.trigram_to_companies = defaultdict(set)

        self._build_indexes()

    def _normalize_string(self, s: str) -> str:
        """Normalize a company name or query for matching.

        Lowercases, strips accents, removes punctuation, drops the standalone
        words ``llc``/``ltd``/``corp``/``inc`` and collapses whitespace.

        Punctuation is removed *before* the suffix words so that dotted forms
        ("L.L.C.", "Corp.") normalize exactly like "LLC"/"Corp". In the reverse
        order "l.l.c." survived as a trailing "llc".
        """
        s = s.lower()
        # Remove accents: decompose, then drop the combining marks
        s = unicodedata.normalize('NFKD', s)
        s = ''.join(c for c in s if not unicodedata.combining(c))
        # Remove punctuation first, then common legal suffixes
        s = re.sub(r'[^\w\s]', '', s)
        s = re.sub(r'\b(llc|ltd|corp|inc)\b', '', s)
        s = re.sub(r'\s+', ' ', s).strip()
        return s

    def _get_trigrams(self, s: str) -> Set[str]:
        """Return the character trigrams of ``s`` after padding it with two spaces per side.

        The padding gives the beginning and end of the string their own trigrams.
        """
        s = f"  {s}  "
        return {s[i:i+3] for i in range(len(s) - 2)}

    def _build_indexes(self):
        """Normalize every company name and fill the word and trigram indexes."""
        for i, (company_id, company_name) in enumerate(self.company_data):
            normalized = self._normalize_string(company_name)
            self.normalized_names.append(normalized)
            self.original_names.append(company_name)
            self.company_ids.append(company_id)

            # Word index (single-character words are too unselective to index)
            for word in normalized.split():
                if len(word) >= 2:
                    self.word_to_companies[word].add(i)

            # Trigram index
            for trigram in self._get_trigrams(normalized):
                self.trigram_to_companies[trigram].add(i)

    def _word_similarity_with_position(self, query_words: List[str], target_words: List[str]) -> float:
        """
        Calculate word-based similarity with position and prefix awareness.
        This ensures "Millennium" ranks "Millennium Partners" higher than
        "WorldQuant Millennium Advisors"

        Returns 0.0 if either word list is empty.
        """
        if not query_words or not target_words:
            return 0.0

        # Base score: distinct words present in both
        exact_matches = len(set(query_words) & set(target_words))

        # Position bonus: same word at same position
        position_matches = sum(
            1 for i, q_word in enumerate(query_words)
            if i < len(target_words) and q_word == target_words[i]
        )

        # Prefix match: first word of query matches first word of target
        prefix_match = 1.0 if query_words[0] == target_words[0] else 0.0

        # Partial matches (substring): at most one half-credit per query word
        partial_matches = 0
        for q_word in query_words:
            for t_word in target_words:
                if q_word != t_word and (q_word in t_word or t_word in q_word):
                    partial_matches += 0.5
                    break

        max_words = max(len(query_words), len(target_words))
        base_score = (exact_matches + partial_matches) / max_words
        position_score = position_matches / max_words

        # Weighted combination (weights sum to 1.0)
        # - 50% base matching - "Do the words appear at all?" This is most fundamental
        # - 30% prefix matching - Company names usually start with the key identifier ("Millennium Partners", "Goldman Sachs")
        # - 20% position matching - Word order matters, but less than presence and prefix
        return (0.5 * base_score) + (0.3 * prefix_match) + (0.2 * position_score)

    def _jaccard_similarity(self, set1: Set[str], set2: Set[str]) -> float:
        """Jaccard similarity |A & B| / |A | B|; 1.0 when both sets are empty, 0.0 when exactly one is."""
        if not set1 and not set2:
            return 1.0
        if not set1 or not set2:
            return 0.0
        return len(set1 & set2) / len(set1 | set2)

    def _calculate_similarity(self, query: str, target_idx: int,
                              query_normalized: Optional[str] = None) -> float:
        """Score ``query`` against the indexed company at ``target_idx``.

        Args:
            query: Raw query string.
            target_idx: Position in the indexed company lists.
            query_normalized: Optional pre-normalized ``query``. ``find_matches``
                passes it so the query is normalized once, not once per candidate.

        Returns:
            0.0 if either normalized string is empty, 1.0 for an exact normalized
            match, 0.95 when one normalized string contains the other, otherwise
            an adaptive weighted blend of word, trigram, length and rapidfuzz scores.
        """
        if query_normalized is None:
            query_normalized = self._normalize_string(query)
        target_normalized = self.normalized_names[target_idx]

        # "" is a substring of every string; without this guard a query made only
        # of suffixes/punctuation (e.g. "LLC") scored 0.95 against every company.
        if not query_normalized or not target_normalized:
            return 0.0

        # Exact match
        if query_normalized == target_normalized:
            return 1.0

        # Substring match
        if query_normalized in target_normalized or target_normalized in query_normalized:
            return 0.95

        query_words = query_normalized.split()
        target_words = target_normalized.split()

        # Word similarity (for exact word matches)
        word_sim = self._word_similarity_with_position(query_words, target_words)

        # Trigram similarity (for fuzzy/spelling variations)
        trigram_sim = self._jaccard_similarity(
            self._get_trigrams(query_normalized),
            self._get_trigrams(target_normalized),
        )

        # Length ratio (shorter / longer); both strings are non-empty here
        length_ratio = (min(len(query_normalized), len(target_normalized))
                        / max(len(query_normalized), len(target_normalized)))

        # rapidfuzz is also imported in the notebook's import cell; imported here,
        # after the cheap exact/substring exits, so those paths never need it.
        from rapidfuzz import fuzz
        partial_ratio = fuzz.partial_ratio(query_normalized, target_normalized) / 100.0
        token_sort_ratio = fuzz.token_sort_ratio(query_normalized, target_normalized) / 100.0

        # Adaptive weighted combination
        if word_sim < 0.3:  # Poor word match - likely spelling variation
            # Rely heavily on character-level matching
            return 0.1 * word_sim + 0.3 * trigram_sim + 0.2 * length_ratio + 0.2 * partial_ratio + 0.2 * token_sort_ratio
        # Good word match
        return 0.5 * word_sim + 0.2 * trigram_sim + 0.1 * length_ratio + 0.1 * partial_ratio + 0.1 * token_sort_ratio

    def find_matches(self, query: str, top_n: int = 10, min_score: float = 0.25) -> List[Tuple[str, str, float]]:
        """
        Find top N matches for the query.

        Args:
            query: Search string
            top_n: Number of top matches to return
            min_score: Minimum similarity score threshold

        Returns:
            List of tuples (company_id, company_name, similarity_score), best first.
            Empty when the query is blank or normalizes to nothing (e.g. "LLC").
        """
        if not query.strip():
            return []

        query_normalized = self._normalize_string(query)
        # Queries made only of punctuation / suffix words carry no identifying text.
        if not query_normalized:
            return []
        query_words = set(query_normalized.split())
        query_trigrams = self._get_trigrams(query_normalized)

        # Word-based candidates (.get avoids inserting keys into the defaultdict)
        candidate_indices = set()
        for word in query_words:
            candidate_indices.update(self.word_to_companies.get(word, ()))

        # Count shared trigrams per company
        trigram_matches = defaultdict(int)
        for trigram in query_trigrams:
            for idx in self.trigram_to_companies.get(trigram, ()):
                trigram_matches[idx] += 1

        # Add candidates sharing at least a quarter of the query's trigrams
        min_trigram_matches = max(1, len(query_trigrams) // 4)
        candidate_indices.update(
            idx for idx, count in trigram_matches.items() if count >= min_trigram_matches
        )

        # Very few candidates: add the top trigram matches even if below threshold
        if len(candidate_indices) < 5:
            top_trigram_candidates = sorted(trigram_matches.items(),
                                            key=lambda x: x[1],
                                            reverse=True)[:20]
            candidate_indices.update(idx for idx, _ in top_trigram_candidates)

        # Fallback: if still no candidates, check all (very short queries / edge cases)
        if not candidate_indices:
            candidate_indices = set(range(len(self.company_ids)))

        scored_candidates = []
        for idx in candidate_indices:
            score = self._calculate_similarity(query, idx, query_normalized)
            if score >= min_score:
                scored_candidates.append((
                    self.company_ids[idx],
                    self.original_names[idx],
                    score
                ))

        scored_candidates.sort(key=lambda x: x[2], reverse=True)
        return scored_candidates[:top_n]


In [22]:
# ============================================================================
# TOOL IMPLEMENTATIONS
# ============================================================================

class CompanyResolutionTools:
    """Collection of tools for company name resolution.

    Each public tool method takes a single string and returns a JSON string.
    The LLM-backed and search tools catch exceptions and report them inside
    the returned JSON rather than raising.
    """

    def __init__(self, client: OpenAI, matcher: FastFuzzyMatcher):
        """
        Args:
            client: OpenAI client used by extract_names and disambiguate.
            matcher: Fuzzy matcher over the company database, used by fuzzy_search.
        """
        self.client = client
        self.matcher = matcher

    @staticmethod
    def _strip_code_fences(text: Optional[str]) -> str:
        """Return ``text`` stripped, with a surrounding Markdown code fence removed.

        Chat models sometimes wrap JSON in ```json ... ``` despite being told not
        to. A leading fence (with optional ``json`` tag) and the last closing fence
        are removed; text without a leading fence is only whitespace-stripped.
        ``None`` is treated as an empty string.
        """
        text = (text or "").strip()
        if text.startswith("```"):
            text = text[3:]
            if text.lower().startswith("json"):
                text = text[4:]
            end = text.rfind("```")
            if end != -1:
                text = text[:end]
        return text.strip()

    def extract_names(self, query: str) -> str:
        """Extract, correct, and expand company names from a user query using the LLM.

        Args:
            query: Free-text user query.

        Returns:
            JSON string ``{"companies": [{"original": ..., "variants": [...]}, ...],
            "message": ...}``. ``variants`` is ``[original, corrected, expanded]`` in
            that order, with empty/non-string values and case-insensitive duplicates
            removed. Entries without a usable ``original`` are skipped.
            On any failure: ``{"companies": [], "error": "..."}``.
        """
        prompt = f"""Extract all potential company names from this user query: "{query}"

                For each company name found, provide:
                1. The original name as it appears in the query (preserve exact formatting)
                2. A spelling-corrected version (fix typos, remove extraneous punctuation)
                3. An expanded version if it's an abbreviation/acronym/commonly used shortform (otherwise same as corrected)

                Note: Don't worry about capitalization - focus on spelling and expansion.

                Return a JSON object with this structure:
                {{
                "companies": [
                    {{
                    "original": "the exact name from query",
                    "corrected": "spelling-corrected version",
                    "expanded": "full expanded name if abbreviation, otherwise same as corrected"
                    }}
                ]
                }}

                Examples:

                Query: "What is the balance of ADIA' in 2025?"
                Output: {{
                "companies": [
                    {{
                    "original": "ADIA'",
                    "corrected": "ADIA",
                    "expanded": "Abu Dhabi Investment Authority"
                    }}
                ]
                }}

                Query: "Compare ups and Fedex performance"
                Output: {{
                "companies": [
                    {{
                    "original": "ups",
                    "corrected": "ups",
                    "expanded": "United Parcel Service"
                    }},
                    {{
                    "original": "Fedex",
                    "corrected": "FedEx",
                    "expanded": "Federal Express"
                    }}
                ]
                }}

                Query: "Show data for Microsft and Gogle"
                Output: {{
                "companies": [
                    {{
                    "original": "Microsft",
                    "corrected": "Microsoft",
                    "expanded": "Microsoft Corporation"
                    }},
                    {{
                    "original": "Gogle",
                    "corrected": "Google",
                    "expanded": "Google LLC"
                    }}
                ]
                }}

                Query: "Show me revenue data"
                Output: {{
                "companies": []
                }}

                Return ONLY the JSON object, nothing else."""
        try:
            response = self.client.chat.completions.create(
                model=Config.OPENAI_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=300
            )

            # Unwrap ```json fences and keep only the outermost {...} object, so a
            # slightly chatty reply does not turn into "no companies found".
            result = self._strip_code_fences(response.choices[0].message.content)
            start, end = result.find("{"), result.rfind("}")
            if start != -1 and end > start:
                result = result[start:end + 1]
            data = json.loads(result)

            companies_with_variants = []
            for company in data.get("companies", []):
                original = company.get("original")
                if not isinstance(original, str) or not original.strip():
                    continue  # nothing to anchor the variants to
                # Keep order (original, corrected, expanded); skip null/empty values
                # and case-insensitive duplicates.
                seen = set()
                ordered_variants = []
                for v in (original, company.get("corrected"), company.get("expanded")):
                    if not isinstance(v, str) or not v.strip():
                        continue
                    key = v.lower()
                    if key not in seen:
                        seen.add(key)
                        ordered_variants.append(v)

                companies_with_variants.append({
                    "original": original,
                    "variants": ordered_variants
                })

            return json.dumps({
                "companies": companies_with_variants,
                "message": f"Extracted {len(companies_with_variants)} company name(s) with variants"
            })
        except Exception as e:
            return json.dumps({"companies": [], "error": str(e)})

    def map_nickname(self, candidate_name: str) -> str:
        """
        Map nicknames to official company names.
        Supports entity blocks (one nickname -> multiple companies) as well as
        mappings whose value is a single official name string.

        Args:
            candidate_name: The nickname to map (e.g., "millennium", "2sigma")

        Returns:
            JSON string. For a known nickname: ``is_nickname`` true and
            ``official_names`` as a list; otherwise ``is_nickname`` false.
        """
        query_lower = candidate_name.lower().strip()
        mapped = Config.NICKNAME_MAPPING.get(query_lower)

        # A plain string mapping used to fall through and be reported as
        # "not a recognized nickname"; normalize it to a one-element block.
        if isinstance(mapped, str):
            mapped = [mapped]

        if mapped:
            return json.dumps({
                "is_nickname": True,
                "nickname": candidate_name,
                "official_names": mapped,
                "message": f"Mapped identifier '{candidate_name}' to '{mapped}'"
            })

        return json.dumps({
            "is_nickname": False,
            "message": f"'{candidate_name}' is not a recognized nickname"
        }, indent=2)

    def fuzzy_search(self, candidate_name: str, top_n: int = 10, min_score: float = 0.25) -> str:
        """Find similar company names using the fuzzy matcher.

        Returns:
            JSON string with ``status`` "success" and a ``matches`` list of
            ``{company_id, company_name, similarity_score}`` (score rounded to 3
            places), ``status`` "no_matches", or ``status`` "error".
        """
        try:
            # Debug: print what we're searching for
            print(f"  [DEBUG] Searching for: '{candidate_name}'")
            print(f"  [DEBUG] Min score threshold: {min_score}")

            matches = self.matcher.find_matches(candidate_name, top_n=top_n, min_score=min_score)
            print(f"  [DEBUG] Found {len(matches)} matches")
            if not matches:
                return json.dumps({
                    "status": "no_matches",
                    "matches": [],
                    "message": f"No matches found for '{candidate_name}'"
                })
            print(f"  [DEBUG] Top match: {matches[0]}")

            results = {
                "status": "success",
                "query": candidate_name,
                "total_matches": len(matches),
                "matches": [
                    {
                        "company_id": company_id,
                        "company_name": company_name,
                        "similarity_score": round(score, 3)
                    }
                    for company_id, company_name, score in matches
                ]
            }
            return json.dumps(results, indent=2)
        except Exception as e:
            return json.dumps({"status": "error", "message": str(e)})

    def disambiguate(self, query: str) -> str:
        """Pick the intended company(ies) from a candidate list using the LLM.

        Args:
            query: JSON string ``{"query": <user query>, "candidates": [<names>]}``.
                Surrounding quotes, backslash-escaped quotes and a doubly encoded
                JSON payload are tolerated; a single candidate may be a bare string.

        Returns:
            JSON string ``{"selected": [...], "from_candidates": [...]}``.
            Selections not in the candidate list are dropped. If nothing valid
            remains, or the LLM call/parse fails, all candidates are returned.
            Malformed input yields a JSON object with an ``error`` key. An empty
            candidate list yields ``{"selected": [], "message": ...}``.
        """
        try:
            # Strip quotes if present
            query = query.strip('"\'')

            # Parse as-is first; on failure retry with escaped quotes unescaped
            try:
                data = json.loads(query)
            except json.JSONDecodeError:
                query_unescaped = query.replace('\\"', '"').replace("\\'", "'")
                data = json.loads(query_unescaped)

            # A doubly encoded payload decodes to a str on the first pass
            if isinstance(data, str):
                data = json.loads(data)
            if not isinstance(data, dict):
                return json.dumps({
                    "error": "Input JSON parsing error: expected an object with 'query' and 'candidates'",
                    "received_input": query[:200]
                })

            candidates = data.get('candidates', [])
            # A bare name would otherwise be iterated character by character
            if isinstance(candidates, str):
                candidates = [candidates]
            original_query = data.get('query', '')

            if not candidates:
                return json.dumps({"selected": [], "message": "No candidates provided"})

            candidates_str = '\n'.join([f"- {c}" for c in candidates])

            prompt = f"""User query: "{original_query}"

                    Candidate companies from database:
                    {candidates_str}

                    Determine which company(ies) the user is most likely referring to.

                    **Key Rules:**

                    1. **Word Order Matters**: Match based on the BEGINNING words of company names
                    - "Millennium" matches "Millennium Partners" and "Millennium Management LLC" 
                    - "Millennium" does NOT strongly match "WorldQuant Millennium Advisors" (Millennium is not the first word)
                    - "WorldQuant" matches "WorldQuant Millennium Advisors"

                    2. **Specificity Determines Selection**:
                    - If user term is GENERIC/AMBIGUOUS (matches multiple companies as prefix): return ALL matching companies
                        Example: "Millennium" → ["Millennium Partners", "Millennium Management LLC"]
                    - If user term is SPECIFIC (uniquely identifies one company): return only that company
                        Example: "WorldQuant" → ["WorldQuant Millennium Advisors"]

                    3. **Matching Strategy**:
                    - Check if user's term appears as a STARTING word/phrase in the company name
                    - Partial matches in later words are lower priority
                    - When in doubt about user intent, include all plausible matches

                    **Examples:**

                    User query: "Show balance for Millennium"
                    Candidates: ["Millennium Partners", "Millennium Management LLC", "WorldQuant Millennium Advisors"]
                    Selected: ["Millennium Partners", "Millennium Management LLC"]
                    Reason: "Millennium" may refer to both "Millennium Partners", "Millennium Management LLC", 
                            but "WorldQuant Millennium Advisors" would have been referred to as "WorldQuant".
                            Note: The initial terms used to name a company are crucial for identifying a company. 

                    User query: "What about WorldQuant?"
                    Candidates: ["Millennium Partners", "Millennium Management LLC", "WorldQuant Millennium Advisors"]
                    Selected: ["WorldQuant Millennium Advisors"]
                    Reason: "WorldQuant" specifically identifies one company

                    User query: "Bridge Associates data"
                    Candidates: ["Bridge Associates", "Bridgewater Associates", "Bridger Capital"]
                    Selected: ["Bridge Associates"]
                    Reason: "Bridge Associates" is an exact match to a specific company

                    Return ONLY a JSON array of selected company names from the candidates list.
                    Example: ["Company Name 1", "Company Name 2"]

                    Do not include explanations, markdown, or any other text.
                    """
            try:
                response = self.client.chat.completions.create(
                    model=Config.OPENAI_MODEL,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.0,
                    max_tokens=200
                )

                result = (response.choices[0].message.content or "").strip()
                print(f"  [DEBUG] LLM returned: {repr(result)}")

                # Clean up BEFORE parsing: drop code fences, then keep the outermost [...]
                result = self._strip_code_fences(result)
                start_idx = result.find('[')
                end_idx = result.rfind(']')
                if start_idx != -1 and end_idx > start_idx:
                    result = result[start_idx:end_idx + 1]

                print(f"  [DEBUG] Cleaned result: {repr(result)}")

                try:
                    selected = json.loads(result)
                except json.JSONDecodeError as e:
                    print(f"  [DEBUG] LLM response parsing failed: {e}")
                    print(f"  [DEBUG] Falling back to all candidates")
                    selected = candidates

            except Exception as e:
                print(f"  [DEBUG] LLM call failed: {e}")
                # Fallback: return all candidates
                selected = candidates

            # Only keep selections that are actually among the candidates
            if isinstance(selected, list):
                selected = [s for s in selected if s in candidates]
            else:
                selected = candidates

            if not selected:
                selected = candidates

            return json.dumps({
                "selected": selected,
                "from_candidates": candidates
            })

        except json.JSONDecodeError as e:
            return json.dumps({
                "error": f"Input JSON parsing error: {str(e)}",
                "received_input": query[:200]
            })
        except Exception as e:
            return json.dumps({
                "error": f"Unexpected error: {str(e)}",
                "type": type(e).__name__
            })

In [23]:
# ============================================================================
# TOOL WRAPPER CLASS
# ============================================================================

class Tool:
    """Named wrapper around a single-argument callable used by the ReAct agent.

    The agent looks tools up by ``name``; ``description`` is a human-readable
    summary of what the tool does.
    """

    def __init__(self, name: str, func: callable, description: str):
        self.name, self.func, self.description = name, func, description

    def run(self, **kwargs):
        """Invoke the wrapped function with exactly one positional argument.

        The argument is taken from ``query`` when present, otherwise from
        ``candidate_name``, otherwise from the first keyword supplied. Any
        exception (including calling with no keywords at all) is converted into
        a JSON error string instead of propagating to the agent loop.
        """
        try:
            for key in ('query', 'candidate_name'):
                if key in kwargs:
                    arg = kwargs[key]
                    break
            else:
                # No well-known key: fall back to the first keyword value.
                arg = [*kwargs.values()][0]
            return self.func(arg)
        except Exception as exc:
            return json.dumps({"error": f"Tool {self.name} failed: {str(exc)}"})

In [24]:
# # Revised ReActAgent with Query-to-Candidate Mapping and Score Ranking

# ## Updated Implementation

# This version tracks:
# 1. Which original company name from the query maps to which database companies
# 2. The matching scores for each candidate
# 3. Ranking by score for each query term

from dataclasses import dataclass
from typing import List, Dict, Optional

@dataclass
class CompanyMatch:
    """One candidate company resolved for a term found in the user's query.

    ``metadata`` defaults to None (not a shared mutable dict), so callers must
    handle the None case.
    """
    query_term: str                  # Term as it appeared in the query (e.g., "Millenium")
    company_name: str                # Matched company name from the database
    match_method: str                # How it was matched: 'nickname', 'disambiguate', 'fuzzy', 'exact'
    score: float                     # Confidence/similarity score; used to rank matches per query term
    metadata: Optional[Dict] = None  # Method-specific extras, or None


class ReActAgent:
    """ReAct agent for company name resolution with query-term → candidate mapping.

    The LLM drives the Thought/Action loop, but the final answer is never taken
    from the model's text. Every tool call is recorded in ``tool_history``, and
    the result is rebuilt from those observations. It is grouped by the query
    term each call was made for and ranked by score.

    Attributes:
        client: OpenAI client used for chat completions.
        tools_map: Tool name -> ``Tool`` wrapper.
        messages: Chat transcript for the current ``run``.
        tool_call_regex: Pattern that parses ``Action: tool_name[input]``.
        tool_history: One dict per action with "tool", "input" and "observation".
        extracted_companies: Dicts with "original" and "variants" taken from
            ``extract_names`` observations. They are used to attribute later
            tool calls to the query term they belong to.
    """

    SYSTEM_PROMPT = """You are a company name resolution agent specializing in the financial domain. 
                Your task is to identify names of financial institutions 
                (e.g., asset management firms, hedge funds, investment companies, financial institutions) 
                in user queries and map them to their official names in our database.

                **Available Tools:**
                1. extract_names[query] - Extract company names with spelling corrections and expansions. Returns each company with an ordered list of variants: [original, corrected, expanded]
                2. map_nickname[name] - Check if a name is a nickname/shorthand (returns official name immediately if found)
                3. fuzzy_search[name] - Find similar company names using fuzzy matching (includes similarity scores)
                4. disambiguate[json_data] - Disambiguate between multiple candidates (required when fuzzy_search returns multiple matches)

                **CRITICAL: Action Format**
                All actions MUST use square brackets: Action: tool_name[input]
                Examples:
                - Action: extract_names[What is the balance of Apple?]
                - Action: map_nickname[2sigma]
                - Action: fuzzy_search[Millenia]
                - Action: disambiguate[{"candidates": ["Company A", "Company B"], "query": "term"}]


                **Required Workflow:**
                1. **Extract ALL company names** from the query using extract_names (may return multiple companies)
                2. **For EACH extracted company** (process each one separately):
                   a. **Check nicknames in order**: Try map_nickname on original → corrected → expanded
                   b. If ANY variant is a nickname → record the official name and move to next company
                   c. If NONE are nicknames → proceed to fuzzy_search for this company
                3. **For EACH company using fuzzy_search**:
                   a. Try fuzzy_search with original → corrected → expanded until you get matches
                   b. **If fuzzy_search returns 2+ matches** → MUST call disambiguate for THIS company
                   c. Record the disambiguated result and move to next company
                4. **After processing ALL companies** → Signal completion with "COMPANIES_RESOLVED"

                **IMPORTANT: When map_nickname returns multiple companies:**
                - This means the nickname represents an entity block (e.g., "Millennium" → multiple companies)
                - You do NOT need to pick one - the system will automatically include ALL of them
                - Simply note that the nickname was resolved and move to the next company

                **Critical Notes:**
                - Process EACH company independently - don't skip any
                - If query has 3 companies, you must resolve all 3 (nickname check OR fuzzy search for each)
                - Disambiguate is called PER COMPANY when that company has multiple fuzzy matches
                - Keep track of all companies found and ensure each gets resolved

                **Why check all nickname variants:**
                - Original might be a nickname (e.g., "2sigma")
                - Corrected might be a nickname (e.g., "aapl")
                - Expanded might be a nickname (e.g., full name that maps to official)
                - Checking variants ensures we don't miss direct mappings

                **Fuzzy Search Strategy (only if NO nickname match found for a company):**
                - Try variants in order: original → corrected → expanded
                - Stop when you get good matches for THIS company
                - If multiple matches returned, MUST call disambiguate for THIS company
                - Format for disambiguate: {"candidates": [matches for THIS company], "query": "this company's original search term"}

                **When to use disambiguate:**
                - ALWAYS use it when fuzzy_search returns 2 or more matches FOR A COMPANY
                - Called separately for each company that has multiple matches
                - Format: {"candidates": ["Company A", "Company B", ...], "query": "this specific company's search term"}
                - This determines if user meant all companies (ambiguous) or specific ones

                **Output Format:**
                Thought: [Your reasoning about what to do next]
                Action: tool_name[input_value]

                After receiving Observation:
                Thought: [Your reasoning based on the observation]
                Action: tool_name[input_value]

                OR when you have resolved all companies:
                Thought: I have resolved all companies from the query
                Answer: COMPANIES_RESOLVED

                **Important:**
                - ALWAYS use square brackets [] for actions, never parentheses or just curly braces
                - When done, simply return "Answer: COMPANIES_RESOLVED"
                - The system will automatically collect ALL companies from your tool calls WITH SCORES
                - Process EACH company in the query independently
                - Check ALL nickname variants (original, corrected, expanded) BEFORE fuzzy search
                - Always disambiguate when multiple fuzzy matches found PER COMPANY
                - If no match for a company, note it but continue"""

    # Only these tools contribute company matches. Other calls (extract_names,
    # unknown-tool errors) must not create entries in the structured result.
    RESOLUTION_TOOLS = ("map_nickname", "fuzzy_search", "disambiguate")

    # Sent as a user turn when a reply has neither an Action nor an Answer, so
    # the model is told what went wrong instead of seeing two assistant turns.
    FORMAT_REMINDER = (
        "Observation: No valid Action found. Respond with "
        "'Action: tool_name[input]' or 'Answer: COMPANIES_RESOLVED'."
    )

    def __init__(self, client: OpenAI, tools: List[Tool]):
        self.client = client
        self.tools_map = {tool.name: tool for tool in tools}
        self.messages = []
        # Greedy ``.+`` combined with re.DOTALL (see run) captures up to the LAST
        # "]" in the reply. This keeps nested JSON arrays in disambiguate inputs intact.
        self.tool_call_regex = r"Action:\s*(\w+)\[(.+)\]"
        self.tool_history = []
        self.extracted_companies = []

    # ------------------------------------------------------------------ loop

    def run(self, query: str, verbose: bool = True) -> str:
        """Run the ReAct loop and return the resolved companies.

        Args:
            query: Raw user question that may mention one or more companies.
            verbose: When True, print each iteration, tool call and the result.

        Returns:
            Comma-separated official company names collected from the tool
            history, or "No companies found". The model's own answer text is
            ignored. The same programmatic answer is returned whether the model
            replies with "Answer:" or ``Config.MAX_ITERATIONS`` is exhausted.
        """
        self.messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": f"User Query: {query}"},
        ]
        self.tool_history = []
        self.extracted_companies = []

        for i in range(Config.MAX_ITERATIONS):
            if verbose:
                print(f"\n{'='*60}")
                print(f"ITERATION {i+1}")
                print('='*60)

            response = self._get_completion()
            if verbose:
                print(f"\nAgent Response:\n{response}")

            # Any "Answer:" ends the loop; the result is rebuilt from tool calls.
            if "Answer:" in response:
                result = self._build_structured_result()
                final_answer = self._format_final_answer(result)
                if verbose:
                    self._print_structured_result(result, final_answer)
                return final_answer

            action_match = re.search(self.tool_call_regex, response, re.DOTALL)
            if not action_match:
                if verbose:
                    print("\n⚠ No valid Action found")
                self.messages.append({"role": "assistant", "content": response})
                self.messages.append({"role": "user", "content": self.FORMAT_REMINDER})
                continue

            tool_name, tool_input = action_match.groups()
            tool_input = tool_input.strip()

            if tool_name not in self.tools_map:
                observation = json.dumps({"error": f"Unknown tool '{tool_name}'"})
            else:
                if verbose:
                    print(f"\n→ Executing: {tool_name}[{tool_input}]")
                observation = self._execute_tool(tool_name, tool_input)

            if verbose:
                print(f"→ Observation: {observation}")

            self.tool_history.append({
                "tool": tool_name,
                "input": tool_input,
                "observation": observation,
            })

            if tool_name == "extract_names":
                self._track_extracted_companies(observation)

            self.messages.append({"role": "assistant", "content": response})
            self.messages.append({"role": "user", "content": f"Observation: {observation}"})

        # Max iterations reached: still return whatever the tool calls resolved.
        result = self._build_structured_result()
        return self._format_final_answer(result)

    @staticmethod
    def _print_structured_result(result: Dict[str, List[CompanyMatch]], final_answer: str) -> None:
        """Pretty-print the per-term matches and the final answer (verbose mode)."""
        print(f"\n{'='*60}")
        print("STRUCTURED RESULT:")
        print('='*60)
        for query_term, matches in result.items():
            print(f"\n'{query_term}' →")
            for match in matches:
                print(f"  - {match.company_name} (score: {match.score:.3f}, method: {match.match_method})")
        print(f"\n{'='*60}")
        print(f"FINAL ANSWER: {final_answer}")
        print('='*60)

    def _track_extracted_companies(self, observation: str) -> None:
        """Record companies from an extract_names observation; malformed data is skipped."""
        try:
            data = json.loads(observation)
        except (json.JSONDecodeError, TypeError):
            return
        if not isinstance(data, dict):
            return
        for company_data in data.get("companies") or []:
            if isinstance(company_data, dict):
                self.extracted_companies.append({
                    "original": company_data.get("original"),
                    "variants": company_data.get("variants", []),
                })

    # ------------------------------------------------------ term attribution

    @staticmethod
    def _strip_quotes(text: str) -> str:
        """Trim whitespace and surrounding quotes, mirroring _execute_tool."""
        return text.strip().strip('"\'').strip()

    @classmethod
    def _normalize_term(cls, text) -> str:
        """Case-insensitive comparison key; non-strings normalize to ""."""
        return cls._strip_quotes(text).lower() if isinstance(text, str) else ""

    def _resolve_query_term(self, text: str) -> Optional[str]:
        """Return the extracted company's original term whose original/variants equal ``text``.

        Returns None when nothing matches, or when the matching entry has no
        original term.
        """
        needle = self._normalize_term(text)
        if not needle:
            return None
        for extracted in self.extracted_companies:
            original = extracted.get("original")
            variants = extracted.get("variants")
            candidates = [original] + (variants if isinstance(variants, list) else [])
            if any(self._normalize_term(c) == needle for c in candidates):
                return original or None
        return None

    def _term_for_disambiguate(self, tool_input: str, previous_term: Optional[str]) -> str:
        """Pick the query term a disambiguate call refines.

        Priority: the extracted term matching the JSON "query" field, then the
        term of the preceding resolution call (normally the fuzzy_search being
        refined), then the raw "query" value, then the raw tool input.
        """
        requested = None
        try:
            payload = json.loads(tool_input)
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict) and isinstance(payload.get("query"), str):
            requested = payload["query"].strip() or None
        resolved = self._resolve_query_term(requested) if requested else None
        return resolved or previous_term or requested or tool_input

    # --------------------------------------------------------- result build

    @staticmethod
    def _as_name_list(official_names) -> List[str]:
        """Normalize map_nickname's "official_names" (list or single string) to a list of strings."""
        if isinstance(official_names, str):
            return [official_names]
        if isinstance(official_names, list):
            return [name for name in official_names if isinstance(name, str)]
        return []

    @staticmethod
    def _as_score(value) -> float:
        """Coerce a tool-reported score to float; non-numeric values become 0.0 so sorting cannot fail."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
        return 0.0

    @staticmethod
    def _dedupe_by_name(matches: List[CompanyMatch]) -> List[CompanyMatch]:
        """Keep the first occurrence of each company name, preserving order."""
        seen = set()
        unique = []
        for match in matches:
            if match.company_name not in seen:
                seen.add(match.company_name)
                unique.append(match)
        return unique

    def _build_structured_result(self) -> Dict[str, List[CompanyMatch]]:
        """
        Build a mapping of query terms to matched companies from ``tool_history``.

        Each map_nickname / fuzzy_search call is attributed to the extracted
        company whose original/variants equal the call input. Otherwise the
        quote-stripped input itself is the key; a previous call's term is never
        reused. A disambiguate call is attributed via ``_term_for_disambiguate``.

        Scores: nickname matches get 1.0, fuzzy matches keep the tool's
        "similarity_score" (0.0 if missing or non-numeric), and disambiguated
        matches get a fixed 0.9. A disambiguate call replaces earlier fuzzy and
        disambiguate matches for its term but keeps nickname matches.

        Returns:
            Dict mapping query terms to lists of CompanyMatch objects, sorted by
            score (descending), with at most one entry per company name (the
            highest-scoring one is kept). Terms whose calls produced no matches
            map to an empty list.
        """
        query_to_matches: Dict[str, List[CompanyMatch]] = {}
        previous_term: Optional[str] = None

        for tool_call in self.tool_history:
            tool_name = tool_call["tool"]
            if tool_name not in self.RESOLUTION_TOOLS:
                continue

            try:
                result_data = json.loads(tool_call["observation"])
            except (json.JSONDecodeError, TypeError):
                continue  # skip malformed observations
            if not isinstance(result_data, dict):
                continue

            tool_input = tool_call["input"]
            if tool_name == "disambiguate":
                term = self._term_for_disambiguate(tool_input, previous_term)
            else:
                term = (self._resolve_query_term(tool_input)
                        or self._strip_quotes(tool_input)
                        or tool_input)
            previous_term = term
            bucket = query_to_matches.setdefault(term, [])

            if tool_name == "map_nickname":
                if result_data.get("is_nickname"):
                    names = self._as_name_list(result_data.get("official_names", []))
                    bucket.extend(
                        CompanyMatch(
                            query_term=term,
                            company_name=name,
                            match_method='nickname',
                            score=1.0,  # direct nickname-table hit
                            metadata={'count': len(names)},  # entity-block size
                        )
                        for name in names
                    )

            elif tool_name == "fuzzy_search":
                raw_matches = result_data.get("matches")
                hits = [
                    m for m in (raw_matches if isinstance(raw_matches, list) else [])
                    if isinstance(m, dict) and isinstance(m.get("company_name"), str) and m["company_name"]
                ]
                for rank, hit in enumerate(hits):
                    bucket.append(CompanyMatch(
                        query_term=term,
                        company_name=hit["company_name"],
                        match_method='fuzzy',
                        score=self._as_score(hit.get("similarity_score", 0.0)),
                        metadata={'fuzzy_rank': rank},  # position within this fuzzy result
                    ))

            else:  # disambiguate refines the fuzzy results for this term
                selected = result_data.get("selected", [])
                if isinstance(selected, list):
                    bucket[:] = [m for m in bucket if m.match_method == 'nickname']
                    from_candidates = result_data.get('from_candidates', [])
                    bucket.extend(
                        CompanyMatch(
                            query_term=term,
                            company_name=name,
                            match_method='disambiguate',
                            score=0.9,  # fixed confidence for LLM-disambiguated picks
                            metadata={'from_candidates': from_candidates},
                        )
                        for name in selected
                        if isinstance(name, str)
                    )

        # sorted() is stable, so equal scores keep their tool-call order.
        return {
            term: self._dedupe_by_name(sorted(matches, key=lambda m: m.score, reverse=True))
            for term, matches in query_to_matches.items()
        }

    def _format_final_answer(self, structured_result: Dict[str, List[CompanyMatch]]) -> str:
        """Flatten the structured result into a comma-separated list of unique names.

        Names keep their per-term score order; terms appear in insertion order.
        Returns "No companies found" when nothing matched.
        """
        all_companies = []
        seen = set()
        for matches in structured_result.values():
            for match in matches:
                if match.company_name not in seen:
                    all_companies.append(match.company_name)
                    seen.add(match.company_name)
        return ", ".join(all_companies) if all_companies else "No companies found"

    def get_structured_result(self) -> Dict[str, List[CompanyMatch]]:
        """
        Get the structured result with full mapping information.
        Call this after run() to get detailed results.
        """
        return self._build_structured_result()

    def _get_completion(self) -> str:
        """Get the next assistant message from the OpenAI chat API (temperature 0).

        A null message content is returned as "" so callers can always do
        substring checks on the result.
        """
        completion = self.client.chat.completions.create(
            model=Config.OPENAI_MODEL,
            messages=self.messages,
            temperature=0.0
        )
        return completion.choices[0].message.content or ""

    def _execute_tool(self, tool_name: str, tool_input: str) -> str:
        """Execute a registered tool, passing the input under the keyword that tool expects."""
        tool = self.tools_map[tool_name]
        tool_input = tool_input.strip('"\'')

        if tool_name in ["extract_names", "disambiguate"]:
            return tool.run(query=tool_input)
        else:
            return tool.run(candidate_name=tool_input)


# ## Usage Example

# ```python
# # Initialize agent
# agent = initialize_agent()

# # Run query
# result = agent.run("What is the client balance of Millenium and Bridge?")

# # Get simple answer
# print(f"Answer: {result}")
# # → "Millennium Partners, Millennium Management LLC, Bridge Associates"

# # Get detailed structured result
# structured = agent.get_structured_result()

# # Access detailed information
# for query_term, matches in structured.items():
#     print(f"\nQuery term: '{query_term}'")
#     for match in matches:
#         print(f"  → {match.company_name}")
#         print(f"     Score: {match.score:.3f}")
#         print(f"     Method: {match.match_method}")
# ```

# ## Example Output

# ```
# Query term: 'Millenium'
#   → Millennium Partners
#      Score: 1.000
#      Method: nickname
#   → Millennium Management LLC
#      Score: 1.000
#      Method: nickname

# Query term: 'Bridge'
#   → Bridge Associates
#      Score: 1.000
#      Method: nickname

# FINAL ANSWER: Millennium Partners, Millennium Management LLC, Bridge Associates
# ```

# ## Key Features

# 1. **Query-to-Candidate Mapping**: Each original term is mapped to its matched companies
# 2. **Score Tracking**: Keeps the similarity scores reported by fuzzy_search, assigns a perfect score (1.0) to nickname matches, and assigns a fixed 0.9 to disambiguated matches
# 3. **Method Tracking**: Records how each match was found (nickname, fuzzy, disambiguate)
# 4. **Score Ranking**: Matches are sorted by score for each query term
# 5. **Metadata**: Additional context like entity block size, fuzzy rank, etc.

# This gives you full visibility into how each company was matched!

In [25]:
# # REACT AGENT


# class ReActAgent:
#     """ReAct Agent for company name resolution with programmatic answer construction"""
    
#     SYSTEM_PROMPT = """You are a company name resolution agent specializing in the financial domain. 
#                 Your task is to identify names of financial institutions 
#                 (e.g., asset management firms, hedge funds, investment companies, financial institutions) 
#                 in user queries and map them to their official names in our database.
                
#                 **Available Tools:**
#                 1. extract_names[query] - Extract company names with spelling corrections and expansions. Returns each company with an ordered list of variants: [original, corrected, expanded]
#                 2. map_nickname[name] - For each company name in the ordered list of variants, including the original, check if a company name is a nickname/shorthand (returns official name immediately if found)
#                 3. fuzzy_search[name] - If there are no matches, for each company name in the ordered list of variants, find similar company names using fuzzy matching
#                 4. disambiguate[json_data] - Disambiguate between multiple candidates (required when fuzzy_search returns multiple matches)

#                 **CRITICAL: Action Format**
#                 All actions MUST use square brackets: Action: tool_name[input]
#                 Examples:
#                 - Action: extract_names[What is the balance of Apple?]
#                 - Action: map_nickname[2sigma]
#                 - Action: fuzzy_search[Millenia]
#                 - Action: disambiguate[{"candidates": ["Company A", "Company B"], "query": "term"}]


#                 **Required Workflow:**
#                 1. **Extract ALL company names** from the query using extract_names (may return multiple companies)
#                 2. **For EACH extracted company** (process each one separately):
#                    a. **Check nicknames in order**: Try map_nickname on original → corrected → expanded
#                    b. If ANY variant is a nickname → record the official name and move to next company
#                    c. If NONE are nicknames → proceed to fuzzy_search for this company
#                 3. **For EACH company using fuzzy_search**:
#                    a. Try fuzzy_search with original → corrected → expanded until you get matches
#                    b. **If fuzzy_search returns 2+ matches** → MUST call disambiguate for THIS company
#                    c. Record the disambiguated result and move to next company
#                 4. **After processing ALL companies** → Signal completion with "COMPANIES_RESOLVED"

#                 **IMPORTANT: When map_nickname returns multiple companies:**
#                 - This means the nickname represents an entity block (e.g., "Millennium" → multiple companies)
#                 - You do NOT need to pick one - the system will automatically include ALL of them
#                 - Simply note that the nickname was resolved and move to the next company

#                 **Critical Notes:**
#                 - Process EACH company independently - don't skip any
#                 - If query has 3 companies, you must resolve all 3 (nickname check OR fuzzy search for each)
#                 - Disambiguate is called PER COMPANY when that company has multiple fuzzy matches
#                 - Keep track of all companies found and ensure each gets resolved

#                 **Why check all nickname variants:**
#                 - Original might be a nickname (e.g., "2sigma")
#                 - Corrected might be a nickname (e.g., "aapl")
#                 - Expanded might be a nickname (e.g., full name that maps to official)
#                 - Checking variants ensures we don't miss direct mappings

#                 **Fuzzy Search Strategy (only if NO nickname match found for a company):**
#                 - Try variants in order: original → corrected → expanded
#                 - Stop when you get good matches for THIS company
#                 - If multiple matches returned, MUST call disambiguate for THIS company
#                 - Format for disambiguate: {"candidates": [matches for THIS company], "query": "this company's original search term"}

#                 **When to use disambiguate:**
#                 - ALWAYS use it when fuzzy_search returns 2 or more matches FOR A COMPANY
#                 - Called separately for each company that has multiple matches
#                 - Format: {"candidates": ["Company A", "Company B", ...], "query": "this specific company's search term"}
#                 - This determines if user meant all companies (ambiguous) or specific ones

#                 **Output Format:**
#                 Thought: [Your reasoning about what to do next]
#                 Action: tool_name[input_value]

#                 After receiving Observation:
#                 Thought: [Your reasoning based on the observation]
#                 Action: tool_name[input_value]

#                 OR when you have resolved all companies:
#                 Thought: I have resolved all companies from the query
#                 Answer: COMPANIES_RESOLVED

#                 **Important:**
#                 - ALWAYS use square brackets [] for actions, never parentheses or just curly braces
#                 - When done, simply return "Answer: COMPANIES_RESOLVED"
#                 - The system will automatically collect ALL companies from your tool calls
#                 - Process EACH company in the query independently
#                 - Check ALL nickname variants (original, corrected, expanded) BEFORE fuzzy search
#                 - Always disambiguate when multiple fuzzy matches found PER COMPANY
#                 - If no match for a company, note it but continue"""
                        
#     def __init__(self, client: OpenAI, tools: List[Tool]):
#         self.client = client
#         self.tools_map = {tool.name: tool for tool in tools}
#         self.messages = []
#         self.tool_call_regex = r"Action:\s*(\w+)\[(.+)\]"
#         self.tool_history = []  # NEW: Track all tool calls
    
#     def run(self, query: str, verbose: bool = True) -> str:
#         """Run the ReAct loop to resolve company names"""
#         self.messages = [
#             {"role": "system", "content": self.SYSTEM_PROMPT},
#             {"role": "user", "content": f"User Query: {query}"}
#         ]
#         self.tool_history = []  # Reset history for each run
        
#         for i in range(Config.MAX_ITERATIONS):
#             if verbose:
#                 print(f"\n{'='*60}")
#                 print(f"ITERATION {i+1}")
#                 print('='*60)
            
#             # Get LLM response
#             response = self._get_completion()
#             if verbose:
#                 print(f"\nAgent Response:\n{response}")
            
#             # Check for final answer
#             if "Answer:" in response:
#                 answer_text = response.split("Answer:", 1)[1].strip()
                
#                 # NEW: If answer is the signal, construct programmatic answer
#                 if "COMPANIES_RESOLVED" in answer_text:
#                     final_companies = self._extract_all_companies_from_history()
#                     answer = ", ".join(final_companies) if final_companies else "No companies found"
                    
#                     if verbose:
#                         print(f"\n{'='*60}")
#                         print(f"FINAL ANSWER (from tool history): {answer}")
#                         print('='*60)
#                     return answer
#                 else:
#                     # Fallback: use LLM's answer but post-process it
#                     final_companies = self._extract_all_companies_from_history()
#                     if final_companies:
#                         answer = ", ".join(final_companies)
#                     else:
#                         answer = answer_text
                    
#                     if verbose:
#                         print(f"\n{'='*60}")
#                         print(f"FINAL ANSWER: {answer}")
#                         print('='*60)
#                     return answer
            
#             # Check for action
#             action_match = re.search(self.tool_call_regex, response, re.DOTALL)
#             if action_match:
#                 tool_name, tool_input = action_match.groups()
#                 tool_input = tool_input.strip()
                
#                 if tool_name not in self.tools_map:
#                     observation = json.dumps({"error": f"Unknown tool '{tool_name}'"})
#                 else:
#                     if verbose:
#                         print(f"\n→ Executing: {tool_name}[{tool_input}]")
#                     observation = self._execute_tool(tool_name, tool_input)
                
#                 if verbose:
#                     print(f"→ Observation: {observation}")
                
#                 # NEW: Record tool call in history
#                 self.tool_history.append({
#                     "tool": tool_name,
#                     "input": tool_input,
#                     "observation": observation
#                 })
                
#                 self.messages.append({"role": "assistant", "content": response})
#                 self.messages.append({"role": "user", "content": f"Observation: {observation}"})
#             else:
#                 if verbose:
#                     print("\n⚠ No valid Action found")
#                 self.messages.append({"role": "assistant", "content": response})
        
#         # NEW: Even if max iterations reached, try to extract companies
#         final_companies = self._extract_all_companies_from_history()
#         if final_companies:
#             return ", ".join(final_companies)
#         return "Error: Max iterations reached without finding an answer."
    
#     def _extract_all_companies_from_history(self) -> List[str]:
#         """
#         Extract all resolved companies from tool call history.
#         This is the core of Option 4 - programmatically parse tool results.
#         """
#         all_companies = []
#         seen_companies = set()  # Avoid duplicates
        
#         for tool_call in self.tool_history:
#             tool_name = tool_call["tool"]
#             observation = tool_call["observation"]
            
#             try:
#                 result_data = json.loads(observation)
                
#                 # Extract from map_nickname (entity blocks)
#                 if tool_name == "map_nickname" and result_data.get("is_nickname"):
#                     official_names = result_data.get("official_names", [])
#                     # Support both list and single string (backward compatibility)
#                     if isinstance(official_names, list):
#                         for name in official_names:
#                             if name not in seen_companies:
#                                 all_companies.append(name)
#                                 seen_companies.add(name)
#                     elif isinstance(official_names, str):
#                         if official_names not in seen_companies:
#                             all_companies.append(official_names)
#                             seen_companies.add(official_names)
                
#                 # Extract from disambiguate
#                 elif tool_name == "disambiguate":
#                     selected = result_data.get("selected", [])
#                     if isinstance(selected, list):
#                         for name in selected:
#                             if name not in seen_companies:
#                                 all_companies.append(name)
#                                 seen_companies.add(name)
                
#                 # Extract from fuzzy_search (if only 1 match, it's definitive)
#                 elif tool_name == "fuzzy_search":
#                     matches = result_data.get("matches", [])
#                     if len(matches) == 1:
#                         company_name = matches[0].get("company_name")
#                         if company_name and company_name not in seen_companies:
#                             all_companies.append(company_name)
#                             seen_companies.add(company_name)
#                     # If multiple matches but no disambiguate call followed,
#                     # this means the agent didn't complete properly
#                     # Don't include these to avoid ambiguity
            
#             except (json.JSONDecodeError, AttributeError, KeyError):
#                 # Skip malformed observations
#                 continue
        
#         return all_companies
    
#     def _get_completion(self) -> str:
#         """Get completion from OpenAI API"""
#         completion = self.client.chat.completions.create(
#             model=Config.OPENAI_MODEL,
#             messages=self.messages,
#             temperature=0.0
#         )
#         return completion.choices[0].message.content
    
#     def _execute_tool(self, tool_name: str, tool_input: str) -> str:
#         """Execute a tool with the given input"""
#         tool = self.tools_map[tool_name]

#         tool_input = tool_input.strip('"\'')

#         if tool_name in ["extract_names", "disambiguate"]:
#             return tool.run(query=tool_input)
#         else:
#             return tool.run(candidate_name=tool_input)



In [26]:
# ============================================================================
# MAIN INITIALIZATION
# ============================================================================

def initialize_agent():
    """Build the complete company-resolution ReAct agent.

    Steps, in order:
      1. Create an OpenAI client from the ``OPENAI_API_KEY`` environment variable.
      2. Load the company table via ``initialize_company_database()``.
      3. Build a ``FastFuzzyMatcher`` over ``(row_index_as_str, company_name)`` pairs
         taken from the table's ``'Company'`` column.
      4. Wrap the ``CompanyResolutionTools`` methods as named ``Tool`` objects.
      5. Construct a ``ReActAgent`` over those tools.

    A progress line is printed to stdout after each step.

    Returns:
        ReActAgent: the initialized agent.

    Raises:
        ValueError: if ``OPENAI_API_KEY`` is unset or empty.
    """
    # Setup OpenAI client (empty string is treated the same as unset).
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        raise ValueError("OPENAI_API_KEY not found in environment")
    client = OpenAI(api_key=api_key)
    print("✓ OpenAI client initialized")

    # Initialize company database
    df_companies = initialize_company_database()
    print(f"✓ Loaded {len(df_companies)} companies")

    # Pair each row's index label (stringified, used as an id) with its company
    # name. Iterating the single column gives the same pairs as
    # DataFrame.iterrows() without materializing a Series for every row.
    company_data = [(str(idx), name) for idx, name in df_companies['Company'].items()]

    # Initialize fuzzy matcher
    matcher = FastFuzzyMatcher(company_data)
    print("✓ Fuzzy matcher initialized")

    # Initialize tools
    tools_impl = CompanyResolutionTools(client, matcher)

    # Descriptions are presumably rendered into the agent's prompt (TODO confirm
    # in Tool/ReActAgent). Adjacent string literals are concatenated at compile
    # time; the previous backslash continuation *inside* a literal embedded the
    # next line's indentation (a run of spaces) into the description text.
    tools = [
        Tool("extract_names", tools_impl.extract_names,
             "Extract company names with spelling correction and expansion variants"),
        Tool("map_nickname", tools_impl.map_nickname,
             "Map nickname to official company name(s). "
             "IMPORTANT: If this returns multiple companies, "
             "ALL of them must be included in the final answer."),
        Tool("fuzzy_search", tools_impl.fuzzy_search,
             "Find similar company names using fuzzy matching"),
        Tool("disambiguate", tools_impl.disambiguate,
             "Disambiguate between multiple candidates and return the most likely candidates"),
    ]

    # Create agent
    agent = ReActAgent(client=client, tools=tools)
    print("✓ ReAct Agent initialized\n")

    return agent

In [27]:
# ============================================================================
# USAGE EXAMPLE
# ============================================================================

# Sample company table used by this example.
# NOTE(review): nothing in this cell reads `data` directly; presumably
# initialize_company_database() (defined in an earlier cell) consumes it —
# TODO confirm, since that would make an earlier cell depend on this one.
data = {
    "Company": [
        "Apple Inc.",
        "Two Sigma Investments, LP",
        "Millennium Partners",
        "Millennium Management LLC",
        "WorldQuant Millennium Advisors",
        "Bridge Associates",
        "Bridgewater Associates",
        "Bridger Capital",
        "Curry's Retail Ltd.",
        "Alphabet Inc.",
        "Abu Dhabi Investment Authority",
    ]
}

if __name__ == "__main__":
    # Build the agent once; every query below reuses it.
    agent = initialize_agent()

    # "Milenium" and "Bride Assoc." appear to be deliberate misspellings that
    # exercise the name-correction path; extra cases are left commented out.
    test_queries = [
        "What is the client balance of Milenium and Bride Assoc.?",
        # "Show me data for 2sigma",
        # "Compare Millennium and WorldQuant performance",
        # "What about Bridgwater Associates?"
    ]

    for query in test_queries:
        # Banner / query / banner, emitted with a single print call.
        print(f"\n{'#' * 60}\nQUERY: {query}\n{'#' * 60}")
        result = agent.run(query, verbose=True)
        print(f"\nFinal Result: {result}\n")

✓ OpenAI client initialized
✓ Loaded 11 companies
✓ Fuzzy matcher initialized
✓ ReAct Agent initialized


############################################################
QUERY: What is the client balance of Milenium and Bride Assoc.?
############################################################

ITERATION 1

Agent Response:
Thought: I need to extract the company names "Milenium" and "Bride Assoc." from the user query and process each one to find their official names. I will start by extracting the names.

Action: extract_names[What is the client balance of Milenium and Bride Assoc.?]

→ Executing: extract_names[What is the client balance of Milenium and Bride Assoc.?]
→ Observation: {"companies": [{"original": "Milenium", "variants": ["Milenium", "Millennium", "Millennium Corporation"]}, {"original": "Bride Assoc.", "variants": ["Bride Assoc.", "Bride Assoc", "Bride Associates"]}], "message": "Extracted 2 company name(s) with variants"}

ITERATION 2

Agent Response:
Thought: I have extract