#!/usr/bin/env python3
"""
SIMPLE COMMAND-TO-METHOD MATCHING SYSTEM
=========================================

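This is the simple approach: embed each method docstring and each user
command with the pre-trained all-mpnet-base-v2 model, then rank the methods
by cosine similarity to the command.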

Run in Google Colab:
    !pip install sentence-transformers -q
    
Then run this script.
"""

# =============================================================================
# STEP 1: INSTALL (run in Colab cell first)
# =============================================================================
# !pip install sentence-transformers -q


# =============================================================================
# STEP 2: IMPORTS
# =============================================================================
from sentence_transformers import SentenceTransformer, util
import ast
from typing import List, Dict, Tuple
from dataclasses import dataclass


# =============================================================================
# STEP 3: DATA STRUCTURES
# =============================================================================
@dataclass
class Method:
    """A method pulled out of source code: its name plus its docstring text."""

    name: str  # the method's identifier
    docstring: str  # the method's docstring text

    def __repr__(self):
        # Overrides the dataclass-generated repr with a compact
        # "name: docstring" form.
        return "{}: {}".format(self.name, self.docstring)


# =============================================================================
# STEP 4: THE CORE SYSTEM (Very Simple!)
# =============================================================================
class CommandMatcher:
    """
    Simple command-to-method matching using embeddings.

    Method docstrings are embedded once by ``load_class``; each command is
    embedded on demand and ranked against them by cosine similarity.

    Usage:
        matcher = CommandMatcher()
        matcher.load_class(python_code)
        result = matcher.match("turn on the light")
    """

    def __init__(self, model_name: str = "all-mpnet-base-v2"):
        """
        Initialize with the recommended model.

        Args:
            model_name: Name handed to ``SentenceTransformer``.

        Options:
        - "all-mpnet-base-v2"   : Best quality (recommended)
        - "all-MiniLM-L6-v2"    : Faster, smaller, good quality
        """
        print(f"Loading model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print("Model loaded!")

        self.methods: List[Method] = []
        # Embeddings of the docstrings in ``self.methods`` (same order), or
        # None while no methods are loaded.
        self.method_embeddings = None

    def load_class(self, source_code: str) -> List[Method]:
        """
        Load Python source and extract methods with docstrings.

        Every class found anywhere in ``source_code`` is scanned. A method is
        kept when it is defined directly in a class body (``def`` or
        ``async def``), its name does not start with an underscore, and it
        has a non-empty docstring. Whatever a previous call loaded is
        replaced.

        Args:
            source_code: Python source text containing one or more classes.

        Returns:
            The extracted methods, in the order they were encountered.

        Raises:
            SyntaxError: If ``ast.parse`` rejects ``source_code``.
        """
        tree = ast.parse(source_code)
        self.methods = []
        # Drop embeddings from an earlier load so they can never be left
        # paired with a different (or empty) method list.
        self.method_embeddings = None

        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for item in node.body:
                # Coroutine methods are methods too; ast.get_docstring
                # accepts both node types.
                if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    continue
                # Skip private (and dunder) methods
                if item.name.startswith('_'):
                    continue
                docstring = ast.get_docstring(item)
                if docstring:
                    self.methods.append(Method(
                        name=item.name,
                        docstring=docstring
                    ))

        # Embed all method docstrings
        if self.methods:
            docstrings = [m.docstring for m in self.methods]
            self.method_embeddings = self.model.encode(docstrings)
            print(f"Loaded {len(self.methods)} methods:")
            for m in self.methods:
                # Only mark the preview as truncated when it really is.
                preview = m.docstring[:50]
                if len(m.docstring) > 50:
                    preview += "..."
                print(f"  - {m.name}: {preview}")

        return self.methods

    def match(self, command: str, top_k: int = 3) -> List[Tuple[Method, float]]:
        """
        Match a user command to the most similar methods.

        Args:
            command: Free-text command to match.
            top_k: Maximum number of matches to return.

        Returns:
            List of (Method, similarity_score) tuples, sorted by score
            (highest first). Empty when no methods are loaded or when
            ``top_k`` is not positive.
        """
        # A non-positive top_k would otherwise slice from the wrong end
        # (e.g. results[:-1]); treat it as "no results" and skip encoding.
        if not self.methods or top_k <= 0:
            return []

        # Embed the command
        command_embedding = self.model.encode(command)

        # Calculate similarities of the command against every method
        similarities = util.cos_sim(command_embedding, self.method_embeddings)[0]

        results = [
            (method, float(score))
            for method, score in zip(self.methods, similarities)
        ]
        results.sort(key=lambda pair: pair[1], reverse=True)

        return results[:top_k]

    def match_with_confidence(self, command: str, threshold: float = 0.5) -> Dict:
        """
        Match with confidence scoring.

        The confidence is the best similarity score scaled by a factor
        between 0.6 and 1.0 that grows with the gap to the runner-up; the
        factor reaches 1.0 once that gap is one third or more.

        Args:
            command: Free-text command to match.
            threshold: Minimum confidence for ``is_confident`` to be True.

        Returns:
            {
                "command": str,
                "best_match": Method or None,
                "confidence": float,
                "is_confident": bool,
                "all_matches": [(Method, score), ...]
            }
        """
        results = self.match(command, top_k=5)

        if not results:
            return {
                "command": command,
                "best_match": None,
                "confidence": 0.0,
                "is_confident": False,
                "all_matches": []
            }

        best_method, best_score = results[0]

        # Calculate margin (gap between 1st and 2nd); with a single
        # candidate there is no runner-up, so the margin stays at zero.
        margin = 0
        if len(results) > 1:
            margin = best_score - results[1][1]

        # Confidence = score * margin factor
        confidence = best_score * (0.6 + 0.4 * min(margin * 3, 1.0))

        return {
            "command": command,
            "best_match": best_method,
            # Reported value is rounded; the threshold test below uses the
            # unrounded value.
            "confidence": round(confidence, 3),
            "is_confident": confidence >= threshold,
            "all_matches": results
        }


# =============================================================================
# STEP 5: TEST CASES
# =============================================================================
def test_smart_home():
    """Run the matcher against a smart-home controller class and print the results."""

    heavy_rule = "=" * 70
    print("\n" + heavy_rule)
    print("TEST 1: SMART HOME CONTROLLER")
    print(heavy_rule)

    # Source of the class whose documented methods become match targets.
    controller_source = '''
class SmartHome:
    """Smart home controller"""
    
    def turn_light_on(self):
        """Turn on the light"""
        pass
    
    def turn_light_off(self):
        """Turn off the light"""
        pass
    
    def turn_ac_on(self):
        """Turn on the air conditioner"""
        pass
    
    def turn_ac_off(self):
        """Turn off the air conditioner"""
        pass
    
    def set_temperature(self, temp):
        """Set the temperature to a specific value"""
        pass
    
    def lock_door(self):
        """Lock the front door"""
        pass
    
    def unlock_door(self):
        """Unlock the front door"""
        pass
'''

    # Phrases to try, from exact docstring wording to loose paraphrases.
    phrases = [
        # Direct matches
        "Turn on the light",
        "Turn off the light",
        "Turn on the air conditioner",

        # Variations (the key test!)
        "Switch on the AC",
        "Turn the light off",
        "Enable the air conditioner",
        "Shut off the lights",
        "Lock the door",
        "Set temperature to 72",

        # Harder variations
        "I want the AC on",
        "Make it cooler in here",
        "It's too dark",
    ]

    home_matcher = CommandMatcher()
    home_matcher.load_class(controller_source)

    light_rule = "-" * 70
    print("\n" + light_rule)
    print("MATCHING RESULTS:")
    print(light_rule)

    for phrase in phrases:
        outcome = home_matcher.match_with_confidence(phrase)
        verdict = "✓ CONFIDENT" if outcome["is_confident"] else "? UNCERTAIN"

        print(f"\nCommand: \"{phrase}\"")
        print(f"  {verdict}")
        print(f"  Best: {outcome['best_match'].name} (conf: {outcome['confidence']:.2f})")

        # Show the three strongest candidates with a proportional bar
        for candidate, similarity in outcome["all_matches"][:3]:
            bar = "█" * int(similarity * 20)
            print(f"    {similarity:.2f} {bar} {candidate.name}")


def test_calculator():
    """Run the matcher against a calculator class and print the best match per command."""

    heavy_rule = "=" * 70
    print("\n" + heavy_rule)
    print("TEST 2: CALCULATOR")
    print(heavy_rule)

    # Source of the class whose documented methods become match targets.
    calc_source = '''
class Calculator:
    """Basic calculator operations"""
    
    def add(self, a, b):
        """Add two numbers together"""
        return a + b
    
    def subtract(self, a, b):
        """Subtract the second number from the first"""
        return a - b
    
    def multiply(self, a, b):
        """Multiply two numbers together"""
        return a * b
    
    def divide(self, a, b):
        """Divide the first number by the second"""
        return a / b
    
    def power(self, base, exponent):
        """Raise a number to a power"""
        return base ** exponent
    
    def square_root(self, n):
        """Calculate the square root of a number"""
        return n ** 0.5
'''

    phrases = [
        "Add 5 and 3",
        "Sum these numbers",
        "What's the total",
        "Subtract 10 from 20",
        "Take away 5",
        "Multiply 4 by 6",
        "What's 5 times 3",
        "Divide 10 by 2",
        "Split 20 into 4",
        "Calculate 2 to the power of 8",
        "What's the square root of 16",
    ]

    calc_matcher = CommandMatcher()
    calc_matcher.load_class(calc_source)

    light_rule = "-" * 70
    print("\n" + light_rule)
    print("MATCHING RESULTS:")
    print(light_rule)

    for phrase in phrases:
        outcome = calc_matcher.match_with_confidence(phrase)
        # Check mark for a confident match, question mark otherwise
        mark = "✓" if outcome["is_confident"] else "?"

        print(f"\n{mark} \"{phrase}\"")
        print(f"   → {outcome['best_match'].name} (conf: {outcome['confidence']:.2f})")


def test_file_manager():
    """Run the matcher against a file-manager class (a different domain) and print the best match per command."""

    heavy_rule = "=" * 70
    print("\n" + heavy_rule)
    print("TEST 3: FILE MANAGER (Different Domain)")
    print(heavy_rule)

    # Source of the class whose documented methods become match targets.
    fs_source = '''
class FileManager:
    """File system operations"""
    
    def read_file(self, path):
        """Read the contents of a file"""
        pass
    
    def write_file(self, path, content):
        """Write content to a file"""
        pass
    
    def delete_file(self, path):
        """Delete a file from the system"""
        pass
    
    def copy_file(self, source, destination):
        """Copy a file to a new location"""
        pass
    
    def move_file(self, source, destination):
        """Move a file to a new location"""
        pass
    
    def list_directory(self, path):
        """List all files in a directory"""
        pass
    
    def create_directory(self, path):
        """Create a new directory"""
        pass
'''

    phrases = [
        "Read the file",
        "Open this document",
        "Load the data",
        "Save this to a file",
        "Write to document.txt",
        "Delete the file",
        "Remove this",
        "Copy file to backup",
        "Move to another folder",
        "Show me all files",
        "Create a new folder",
    ]

    fs_matcher = CommandMatcher()
    fs_matcher.load_class(fs_source)

    light_rule = "-" * 70
    print("\n" + light_rule)
    print("MATCHING RESULTS:")
    print(light_rule)

    for phrase in phrases:
        outcome = fs_matcher.match_with_confidence(phrase)
        # Check mark for a confident match, question mark otherwise
        mark = "✓" if outcome["is_confident"] else "?"

        print(f"\n{mark} \"{phrase}\"")
        print(f"   → {outcome['best_match'].name} (conf: {outcome['confidence']:.2f})")


# =============================================================================
# STEP 6: RUN ALL TESTS
# =============================================================================
if __name__ == "__main__":
    # Script entry point: print a banner, run the three demo suites in
    # order, then print the closing guidance.
    heavy_rule = "=" * 70

    print(heavy_rule)
    print("COMMAND-TO-METHOD MATCHING SYSTEM")
    print("Model: all-mpnet-base-v2 (no training)")
    print(heavy_rule)

    for run_suite in (test_smart_home, test_calculator, test_file_manager):
        run_suite()

    print("\n" + heavy_rule)
    print("SUMMARY")
    print(heavy_rule)
    print("""
    This test shows how well the PRE-TRAINED model works
    WITHOUT any fine-tuning.
    
    If accuracy is >80%, you may not need training at all!
    
    If accuracy is <80%, proceed to fine-tuning.
    """)