# CodeRAG Focused Learning 2: DS-Code Graph with Semantic Relationships

**Mục tiêu**: Hiểu sâu về DS-Code Graph và cách xây dựng semantic relationships giữa các code elements

**Paper Reference**: Section 3.2 - DS-Code Graph Construction

---

## 🎯 Khái niệm cốt lõi

### Từ Paper (Section 3.2):
> *"Different from existing code graphs, DS-code graph not only models dependency relationships but also introduces semantic relationships among nodes."*

> *"DS-code graph contains four node types: Module, Class, Method, Function. DS-code graph contains five edge types: Import, Contain, Inherit, Call, Similarity."*

### Đặc điểm phức tạp:
1. **Multi-type Node Schema**: 4 loại nodes với hierarchy rõ ràng
2. **Hybrid Edge System**: Kết hợp dependency và semantic relationships
3. **Embedding-based Similarity**: Sử dụng embeddings để tính semantic similarity
4. **Graph Database Storage**: Paper lưu trữ graph trong Neo4j; notebook này dùng NetworkX (in-memory) để minh họa
5. **Language-Specific Design**: Schema tailored cho Python characteristics

---

## 🔧 Environment Setup

In [None]:
import os
import ast
import json
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional, Set
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
from pathlib import Path
import hashlib
from collections import defaultdict

# For embeddings and similarity computation
from langchain.embeddings import OpenAIEmbeddings
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer

# For code parsing
import tree_sitter
from tree_sitter import Language, Parser

# Set environment.
# Only install the placeholder when no key is configured, so a real
# OPENAI_API_KEY exported in the shell is not clobbered by this cell.
os.environ.setdefault('OPENAI_API_KEY', 'your-openai-api-key')

plt.style.use('seaborn-v0_8')
sns.set_palette("Set2")

## 📚 Lý thuyết sâu: DS-Code Graph Schema

### From Paper Section 3.2:

**Node Types:**
- **Module**: represents a code file
- **Class**: a class defined in the repository  
- **Method**: a method defined in the class
- **Function**: a function code defined in the repository

**Edge Types:**
- **Import**: modules' dependency relationships
- **Contain**: source code of a node contains the counterpart of another node
- **Inherit**: class inheritance relationships
- **Call**: invoking relationship among code snippets
- **Similarity**: two nodes have similar semantics

In [None]:
@dataclass
class CodeNode:
    """A DS-Code Graph node: a Module, Class, Method or Function.

    ``semantic_hash`` is filled in automatically from ``name``,
    ``node_type`` and ``source_code`` unless one is passed explicitly. It is an
    exact-content fingerprint (identical text -> identical hash), not a fuzzy
    similarity measure.
    """
    id: str
    node_type: str  # Module, Class, Method, Function
    name: str
    source_code: str
    signature: str
    file_path: str

    # Additional attributes for different node types
    class_name: Optional[str] = None  # owning class; only meaningful for methods
    line_start: Optional[int] = None
    line_end: Optional[int] = None
    complexity_score: float = 0.0

    # Embedding vector. Excluded from the generated __eq__ (comparing two
    # ndarrays inside a dataclass tuple comparison raises "truth value ... is
    # ambiguous") and from __repr__ (a 1000+-dim vector makes reprs unreadable).
    embedding: Optional[np.ndarray] = field(default=None, compare=False, repr=False)
    semantic_hash: Optional[str] = None

    # Dependencies
    imports: List[str] = field(default_factory=list)
    calls: List[str] = field(default_factory=list)

    def __post_init__(self):
        if self.semantic_hash is None:
            self.semantic_hash = self._compute_semantic_hash()

    def _compute_semantic_hash(self) -> str:
        """Return a 16-hex-char (64-bit) fingerprint of name, type and source.

        The value is used as a cache key, so it is kept long enough that
        collisions between distinct nodes are negligible even for large
        repositories. (A 32-bit prefix collides with ~1% probability at
        ~10k nodes.)
        """
        content = f"{self.name}_{self.node_type}_{self.source_code}"
        return hashlib.md5(content.encode()).hexdigest()[:16]

@dataclass
class CodeEdge:
    """A directed, typed relationship between two ``CodeNode`` ids.

    ``edge_type`` is one of the lowercase kinds used in this notebook:
    ``import``, ``contain``, ``inherit``, ``call`` or ``similarity``.
    ``confidence`` and ``weight`` both default to 1.0. ``metadata`` carries
    free-form extra attributes and is a fresh dict for every instance.
    """
    # Endpoints, given as CodeNode.id values
    source_id: str
    target_id: str
    # Relationship kind (see class docstring)
    edge_type: str
    # Scores attached to the relationship
    confidence: float = 1.0
    weight: float = 1.0
    # Extra details, e.g. call_name, base_class, similarity_score
    metadata: Dict = field(default_factory=dict)

class AdvancedCodeParser:
    """Extract DS-Code Graph nodes from Python source using the stdlib ``ast``.

    Produces one Module node per file plus Class, Method and top-level
    Function nodes. ``async def`` definitions are treated like regular ones.
    """

    # Both sync and async definitions count as functions/methods in the graph.
    _FUNCTION_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef)

    def __init__(self):
        # Regex forms of the import statements handled by _extract_imports.
        # Extraction itself is AST-based; the attribute is kept for callers
        # that may read it.
        self.import_patterns = [
            r'import\s+([\w\.]+)',
            r'from\s+([\w\.]+)\s+import'
        ]

    def parse_file(self, file_path: str, content: str) -> List[CodeNode]:
        """Parse one file into graph nodes.

        The first element is always the Module node for ``file_path``. It is
        followed, in ``ast.walk`` order, by Class nodes (each immediately
        followed by its Method nodes) and top-level Function nodes. A function
        is top-level when it is not nested inside any class or function body;
        functions under module-level ``if``/``try`` blocks still count. If
        ``content`` cannot be parsed, the error is printed and only the Module
        node (with no imports) is returned.
        """
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError) as e:
            # ValueError: e.g. null bytes in the source on some Python versions.
            print(f"Syntax error parsing {file_path}: {e}")
            tree = None

        nodes = [CodeNode(
            id=f"module:{file_path}",
            node_type="Module",
            name=Path(file_path).stem,
            source_code=content,
            signature=f"module {file_path}",
            file_path=file_path,
            imports=self._extract_imports(content, tree) if tree is not None else [],
        )]
        if tree is None:
            return nodes

        # Computed once (AST nodes hash by identity), replacing a per-function
        # re-walk of the whole tree.
        top_level_functions = self._top_level_functions(tree)
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                nodes.extend(self._extract_class_node(node, content, file_path))
            elif node in top_level_functions:
                nodes.append(self._extract_function_node(node, content, file_path))
        return nodes

    @classmethod
    def _top_level_functions(cls, tree: ast.AST) -> set:
        """Return the function defs not nested in any class or function body.

        The tree is walked iteratively, so deeply nested expressions cannot hit
        the recursion limit.
        """
        found = set()
        stack = [tree]
        while stack:
            for child in ast.iter_child_nodes(stack.pop()):
                if isinstance(child, cls._FUNCTION_TYPES):
                    found.add(child)  # do not descend: inner defs are not top-level
                elif not isinstance(child, ast.ClassDef):
                    stack.append(child)  # methods live under ClassDef; skip them
        return found

    def _extract_imports(self, content: str, tree: Optional[ast.AST] = None) -> List[str]:
        """Return imported module names in ``ast.walk`` order.

        ``import a.b as c`` yields ``"a.b"``, and ``from x.y import z`` yields
        ``"x.y"``. A relative ``from . import z`` has no module and is skipped.
        Pass an already-parsed ``tree`` to avoid parsing ``content`` again.
        Returns ``[]`` when ``content`` does not parse.
        """
        if tree is None:
            try:
                tree = ast.parse(content)
            except (SyntaxError, ValueError):
                return []
        imports = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imports.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module:
                imports.append(node.module)
        return imports

    @staticmethod
    def _node_source(source_lines: List[str], node: ast.AST) -> str:
        """Return the whole source lines spanned by ``node`` (lineno..end_lineno)."""
        return '\n'.join(source_lines[node.lineno - 1:node.end_lineno])

    @staticmethod
    def _base_text(base: ast.expr, content: str) -> str:
        """Render a base-class expression as it appears in the source.

        Examples: ``Base`` or ``models.Model``. The type name of the
        expression is used if its position is unavailable.
        """
        if isinstance(base, ast.Name):
            return base.id
        segment = ast.get_source_segment(content, base)
        return segment if segment is not None else type(base).__name__

    def _extract_class_node(self, node: ast.ClassDef, content: str, file_path: str) -> List[CodeNode]:
        """Return the Class node followed by one Method node per direct method.

        Only ``def``/``async def`` statements directly in the class body are
        methods; nested classes are picked up separately by ``parse_file``.
        """
        source_lines = content.split('\n')

        # Base classes are recorded in the signature; the graph builder
        # parses them back out for inherit edges.
        base_classes = [self._base_text(base, content) for base in node.bases]

        nodes = [CodeNode(
            id=f"class:{file_path}:{node.name}",
            node_type="Class",
            name=node.name,
            source_code=self._node_source(source_lines, node),
            signature=f"class {node.name}({', '.join(base_classes)}):",
            file_path=file_path,
            line_start=node.lineno,
            line_end=node.end_lineno
        )]

        for method in node.body:
            if isinstance(method, self._FUNCTION_TYPES):
                nodes.append(CodeNode(
                    id=f"method:{file_path}:{node.name}:{method.name}",
                    node_type="Method",
                    name=method.name,
                    source_code=self._node_source(source_lines, method),
                    signature=self._extract_signature(method),
                    file_path=file_path,
                    class_name=node.name,
                    line_start=method.lineno,
                    line_end=method.end_lineno,
                    calls=self._extract_function_calls(method)
                ))
        return nodes

    def _extract_function_node(self, node: ast.AST, content: str, file_path: str) -> CodeNode:
        """Build a Function node for a (sync or async) top-level function def."""
        return CodeNode(
            id=f"function:{file_path}:{node.name}",
            node_type="Function",
            name=node.name,
            source_code=self._node_source(content.split('\n'), node),
            signature=self._extract_signature(node),
            file_path=file_path,
            line_start=node.lineno,
            line_end=node.end_lineno,
            calls=self._extract_function_calls(node)
        )

    def _extract_signature(self, node: ast.AST) -> str:
        """Render ``def name(params):`` using parameter names only.

        Positional-only (followed by ``/``), regular, ``*args`` (or a bare
        ``*`` before keyword-only parameters), keyword-only and ``**kwargs``
        parameters are included. Defaults and annotations are omitted, and
        ``async def`` functions keep their ``async`` prefix.
        """
        args = node.args
        params = [arg.arg for arg in args.posonlyargs]
        if params:
            params.append('/')
        params.extend(arg.arg for arg in args.args)
        if args.vararg is not None:
            params.append(f"*{args.vararg.arg}")
        elif args.kwonlyargs:
            params.append('*')
        params.extend(arg.arg for arg in args.kwonlyargs)
        if args.kwarg is not None:
            params.append(f"**{args.kwarg.arg}")
        prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
        return f"{prefix} {node.name}({', '.join(params)}):"

    def _extract_function_calls(self, node: ast.AST) -> List[str]:
        """Return the names of callees in ``node``, in ``ast.walk`` order.

        ``f(...)`` yields ``"f"`` and ``obj.m(...)`` yields ``"m"``. Other call
        shapes are ignored, and duplicates are kept.
        """
        calls = []
        for child in ast.walk(node):
            if isinstance(child, ast.Call):
                if isinstance(child.func, ast.Name):
                    calls.append(child.func.id)
                elif isinstance(child.func, ast.Attribute):
                    calls.append(child.func.attr)
        return calls

# Demo: run the parser over a small in-memory module and list what it finds
parser = AdvancedCodeParser()

test_code = '''
import os
from typing import List

class DataProcessor:
    def __init__(self, config):
        self.config = config
    
    def process_data(self, data):
        cleaned = self.clean_data(data)
        return self.validate_data(cleaned)
    
    def clean_data(self, data):
        return [x for x in data if x is not None]
    
    def validate_data(self, data):
        return len(data) > 0

def utility_function(x, y):
    processor = DataProcessor({"strict": True})
    return processor.process_data([x, y])
'''

print("Testing Advanced Code Parser...")
nodes = parser.parse_file("test.py", test_code)
print("Extracted {} nodes:".format(len(nodes)))
for node in nodes:
    print("  {}: {} (calls: {})".format(node.node_type, node.name, node.calls))

## 🧠 Deep Dive: Semantic Relationship Construction

### Key Innovation từ paper:
> *"Besides dependency relationships, we also construct semantic relationships between nodes in the DS-code graph... we use a reliable embedding model to encode the source code of each node and complete semantic relationships according to their vectors' cosine similarities."*

### Advanced Features:
1. **Type-aware Text Representation**: Mỗi node type được chuyển thành text khác nhau trước khi embed (signature + docstring + calls cho Function/Method, imports cho Module)
2. **Embedding Optimization**: Caching và efficient computation
3. **Threshold Tuning**: Threshold riêng cho từng cặp node types
4. **Similarity Metric**: Cosine similarity giữa các embedding vectors

In [None]:
class SemanticRelationshipBuilder:
    """Build ``similarity`` edges between code nodes from embedding cosine scores.

    Embeddings come from ``OpenAIEmbeddings`` when ``embedding_model ==
    "openai"`` and from a TF-IDF vectorizer otherwise. TF-IDF is also the
    fallback when an OpenAI request raises.
    """

    def __init__(self, embedding_model="openai", cache_embeddings=True):
        self.embedding_model = embedding_model
        self.cache_embeddings = cache_embeddings
        # cache key -> vector. Only vectors from the OpenAI model are stored;
        # see compute_embeddings for why TF-IDF vectors are not cached.
        self.embedding_cache = {}

        if embedding_model == "openai":
            self.embeddings = OpenAIEmbeddings()
        # Always create the vectorizer. It is the primary backend for
        # non-"openai" models and the fallback when an OpenAI call fails, so
        # the fallback path must never find it missing.
        self.tfidf = TfidfVectorizer(max_features=1000, stop_words='english')

        # Similarity thresholds by node type pair. Lookup is order-insensitive
        # (see _get_threshold); unlisted pairs use 0.65.
        self.similarity_thresholds = {
            ('Function', 'Function'): 0.75,
            ('Method', 'Method'): 0.70,
            ('Function', 'Method'): 0.65,
            ('Class', 'Class'): 0.60,
            ('Module', 'Module'): 0.50
        }

    def _cache_key(self, node: CodeNode) -> str:
        """Cache key for ``node``'s vector under the configured model."""
        return f"{node.semantic_hash}_{self.embedding_model}"

    def compute_embeddings(self, nodes: List[CodeNode]) -> Dict[str, np.ndarray]:
        """Return ``{node.id: vector}`` for every node in ``nodes``.

        Vectors from the OpenAI model are cached by ``semantic_hash`` when
        ``cache_embeddings`` is on, and reused on later calls. TF-IDF vectors
        are never cached: the vectorizer is re-fit on every call, so vectors
        from different calls live in different, incomparable spaces. If the
        OpenAI call fails after some nodes were served from the cache, all
        nodes are re-embedded with TF-IDF so the returned vectors stay
        mutually comparable.
        """
        embeddings_map = {}
        pending = []
        for node in nodes:
            cached = (self.embedding_cache.get(self._cache_key(node))
                      if self.cache_embeddings else None)
            if cached is None:
                pending.append(node)
            else:
                embeddings_map[node.id] = cached

        if not pending:
            return embeddings_map

        print(f"Computing embeddings for {len(pending)} nodes...")
        vectors, from_model = self._embed_texts(
            [self._prepare_text_for_embedding(node) for node in pending])

        if not from_model and embeddings_map:
            # Cached model vectors cannot be compared with fresh TF-IDF
            # vectors; redo the whole set in one TF-IDF space.
            pending = list(nodes)
            vectors = self._compute_tfidf_embeddings(
                [self._prepare_text_for_embedding(node) for node in pending])

        for node, vector in zip(pending, vectors):
            embedding_array = np.array(vector)
            embeddings_map[node.id] = embedding_array
            if self.cache_embeddings and from_model:
                self.embedding_cache[self._cache_key(node)] = embedding_array
        return embeddings_map

    def _embed_texts(self, texts: List[str]) -> Tuple[List, bool]:
        """Embed ``texts``; return ``(vectors, came_from_openai_model)``."""
        if self.embedding_model == "openai":
            try:
                return self.embeddings.embed_documents(texts), True
            except Exception as e:  # network/auth/quota errors: degrade to TF-IDF
                print(f"OpenAI embedding error: {e}, falling back to TF-IDF")
        return self._compute_tfidf_embeddings(texts), False

    def _prepare_text_for_embedding(self, node: CodeNode) -> str:
        """Build the text that represents ``node`` for embedding."""
        if node.node_type == "Module":
            # Modules: name plus imported module names
            text = f"Module {node.name} imports: {' '.join(node.imports)}"
        elif node.node_type in ["Function", "Method"]:
            # Functions/methods: signature + docstring + called names
            docstring = self._extract_docstring(node.source_code)
            text = f"{node.signature} {docstring} calls: {' '.join(node.calls)}"
        elif node.node_type == "Class":
            # Classes: signature + the first 200 characters of the source
            text = f"{node.signature} {node.source_code[:200]}"
        else:
            text = node.source_code

        return text.strip()

    def _extract_docstring(self, source_code: str) -> str:
        """Return the docstring of the def/class in ``source_code``, or ``""``.

        The source is dedented first: method sources are sliced straight out
        of their class and keep the class-body indentation, which
        ``ast.parse`` would reject as an unexpected indent.
        """
        from textwrap import dedent

        try:
            tree = ast.parse(dedent(source_code))
        except (SyntaxError, ValueError):
            return ""
        if tree.body and isinstance(
                tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            return ast.get_docstring(tree.body[0]) or ""
        return ""

    def _compute_tfidf_embeddings(self, texts: List[str]) -> List[np.ndarray]:
        """TF-IDF vectors for ``texts``, fitting the vectorizer on them.

        If fitting fails (e.g. empty vocabulary, because every text is empty
        or only stop words), zero vectors are returned. Zero vectors score
        0.0 similarity and therefore create no edges. Random vectors must not
        be used here: random positive vectors have an expected cosine of about
        0.75, above most thresholds, which would create spurious edges.
        """
        try:
            tfidf_matrix = self.tfidf.fit_transform(texts)
        except ValueError as e:
            print(f"TF-IDF embedding failed ({e}); using zero vectors")
            return [np.zeros(1) for _ in texts]
        return [tfidf_matrix[i].toarray().flatten() for i in range(tfidf_matrix.shape[0])]

    def find_semantic_similarities(self, nodes: List[CodeNode], embeddings_map: Dict[str, np.ndarray]) -> List[CodeEdge]:
        """Return a ``similarity`` edge for each node pair at/above its threshold.

        Each unordered pair is compared once. The edge is directed from the
        earlier node in ``nodes`` to the later one. Confidence and weight are
        the cosine score.
        """
        similarity_edges = []

        for i, node1 in enumerate(nodes):
            for node2 in nodes[i + 1:]:
                if node1.id == node2.id:
                    continue

                # Two Module nodes from the same file are never linked.
                if (node1.file_path == node2.file_path and
                        node1.node_type == node2.node_type and
                        node1.node_type == "Module"):
                    continue

                emb1 = embeddings_map.get(node1.id)
                emb2 = embeddings_map.get(node2.id)
                if emb1 is None or emb2 is None:
                    continue

                similarity = self._compute_similarity(emb1, emb2)
                threshold = self._get_threshold(node1.node_type, node2.node_type)

                if similarity >= threshold:
                    similarity_edges.append(CodeEdge(
                        source_id=node1.id,
                        target_id=node2.id,
                        edge_type="similarity",
                        confidence=similarity,
                        weight=similarity,
                        metadata={
                            'similarity_score': similarity,
                            'threshold': threshold,
                            'node_types': (node1.node_type, node2.node_type),
                            # Stored explicitly so analysis does not have to
                            # re-parse file paths out of ids.
                            'file_paths': (node1.file_path, node2.file_path),
                        }
                    ))

        print(f"Found {len(similarity_edges)} semantic similarity relationships")
        return similarity_edges

    def _compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """Cosine similarity of two vectors.

        Returns 0.0 when either vector is all-zero or their shapes differ.
        Vectors of different shapes come from different embedding spaces and
        are not comparable.
        """
        a = np.asarray(emb1, dtype=float).ravel()
        b = np.asarray(emb2, dtype=float).ravel()
        if a.shape != b.shape:
            return 0.0

        norm1 = np.linalg.norm(a)
        norm2 = np.linalg.norm(b)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(a, b) / (norm1 * norm2))

    def _get_threshold(self, type1: str, type2: str) -> float:
        """Get similarity threshold for a node type pair (either order)."""
        key = (type1, type2) if (type1, type2) in self.similarity_thresholds else (type2, type1)
        return self.similarity_thresholds.get(key, 0.65)  # Default threshold

    def analyze_similarity_patterns(self, edges: List[CodeEdge]) -> Dict[str, object]:
        """Summarize similarity edges.

        The summary holds the total, counts per node-type pair, the list of
        scores, the number of edges scoring >= 0.8, and the number of
        cross-file edges.
        """
        analysis = {
            'total_similarities': len(edges),
            'by_node_types': defaultdict(int),
            'similarity_distribution': [],
            'high_confidence': 0,
            'cross_file_similarities': 0
        }

        for edge in edges:
            node_types = edge.metadata.get('node_types', ('Unknown', 'Unknown'))
            analysis['by_node_types'][f"{node_types[0]}-{node_types[1]}"] += 1

            similarity = edge.metadata.get('similarity_score', 0)
            analysis['similarity_distribution'].append(similarity)
            if similarity >= 0.8:
                analysis['high_confidence'] += 1

            file_paths = edge.metadata.get('file_paths')
            if file_paths is not None:
                source_file, target_file = file_paths
            else:
                # Edges built elsewhere: fall back to "<kind>:<file>:..." ids.
                source_file = edge.source_id.split(':')[1] if ':' in edge.source_id else ""
                target_file = edge.target_id.split(':')[1] if ':' in edge.target_id else ""
            if source_file != target_file:
                analysis['cross_file_similarities'] += 1

        return analysis

# Demo: semantic relationships over the parsed test nodes, using the TF-IDF
# backend so no API key is required
print("\nTesting Semantic Relationship Builder...")
semantic_builder = SemanticRelationshipBuilder(embedding_model="tfidf")

embeddings_map = semantic_builder.compute_embeddings(nodes)
print("Computed embeddings for {} nodes".format(len(embeddings_map)))

similarity_edges = semantic_builder.find_semantic_similarities(nodes, embeddings_map)
analysis = semantic_builder.analyze_similarity_patterns(similarity_edges)

print("\nSemantic Analysis Results:")
print("Total similarities: {}".format(analysis['total_similarities']))
print("High confidence (>=0.8): {}".format(analysis['high_confidence']))
print("Cross-file similarities: {}".format(analysis['cross_file_similarities']))
print("Node type patterns: {}".format(dict(analysis['by_node_types'])))

## 🕸️ Complete DS-Code Graph Construction

### Integration of All Relationship Types:
Kết hợp tất cả dependency và semantic relationships thành complete graph

In [None]:
class DSCodeGraph:
    """Complete DS-Code Graph implementation on top of ``networkx.DiGraph``.

    Graph nodes are ``CodeNode`` ids that carry the node's fields as
    attributes. networkx keeps at most one edge per ordered node pair, so
    several relationship kinds linking the same pair are merged into a single
    edge, for example a ``call`` that is also a ``similarity``. On a merged
    edge, ``edge_type`` is the first kind added and ``edge_types`` lists every
    kind. ``build_from_repository`` adds dependency edges before semantic
    ones, so ``edge_type`` prefers the dependency kind.
    """

    def __init__(self, semantic_builder: SemanticRelationshipBuilder):
        self.graph = nx.DiGraph()
        self.nodes = {}  # id -> CodeNode
        self.edges = {}  # (source, target) -> first CodeEdge added for that pair
        self.parser = AdvancedCodeParser()
        self.semantic_builder = semantic_builder

        # Statistics. edges_by_type counts distinct (pair, kind) relationships.
        self.stats = {
            'nodes_by_type': defaultdict(int),
            'edges_by_type': defaultdict(int)
        }

    def build_from_repository(self, repository: Dict[str, str]):
        """Build the graph from a ``{file_path: source_code}`` mapping.

        The steps are:
        1. Parse every file into nodes.
        2. Add import/contain/call/inherit edges.
        3. Compute embeddings, store them on the nodes (``CodeNode.embedding``
           and the graph node attribute), and add similarity edges.
        """
        print("Building DS-Code Graph...")

        # Step 1: Extract all nodes
        print("Step 1: Extracting code nodes...")
        all_nodes = []
        for file_path, content in repository.items():
            all_nodes.extend(self.parser.parse_file(file_path, content))

        for node in all_nodes:
            self.add_node(node)

        print(f"Extracted {len(all_nodes)} nodes")

        # Step 2: Extract dependency relationships
        print("Step 2: Extracting dependency relationships...")
        dependency_edges = self._extract_dependency_relationships(all_nodes)

        for edge in dependency_edges:
            self.add_edge(edge)

        print(f"Added {len(dependency_edges)} dependency relationships")

        # Step 3: Extract semantic relationships
        print("Step 3: Extracting semantic relationships...")
        embeddings_map = self.semantic_builder.compute_embeddings(all_nodes)
        self._attach_embeddings(embeddings_map)
        semantic_edges = self.semantic_builder.find_semantic_similarities(all_nodes, embeddings_map)

        for edge in semantic_edges:
            self.add_edge(edge)

        print(f"Added {len(semantic_edges)} semantic relationships")

        self._update_statistics()

        print(f"\nDS-Code Graph completed:")
        print(f"Total nodes: {len(self.nodes)}")
        print(f"Total edges: {len(self.edges)}")

    def _attach_embeddings(self, embeddings_map: Dict[str, np.ndarray]):
        """Store computed vectors on ``CodeNode.embedding`` and the graph node."""
        for node_id, vector in embeddings_map.items():
            if node_id in self.nodes:
                self.nodes[node_id].embedding = vector
                self.graph.nodes[node_id]['embedding'] = vector

    def add_node(self, node: CodeNode):
        """Add (or replace) a node. The type counter only counts new ids."""
        if node.id not in self.nodes:
            self.stats['nodes_by_type'][node.node_type] += 1
        self.graph.add_node(node.id, **node.__dict__)
        self.nodes[node.id] = node

    def add_edge(self, edge: CodeEdge):
        """Add ``edge`` if both endpoints are known nodes; otherwise ignore it.

        If an edge of a different kind already links the same ordered pair,
        the new kind is appended to that edge's ``edge_types``. Its metadata
        keys that are not yet present are also added. The existing edge is
        not overwritten, which would silently drop a dependency relationship.
        Re-adding an already recorded kind is a no-op.
        """
        u, v = edge.source_id, edge.target_id
        if u not in self.nodes or v not in self.nodes:
            return

        data = self.graph.get_edge_data(u, v)
        if data is None:
            # Core attributes take precedence over same-named metadata keys.
            attrs = dict(edge.metadata)
            attrs.update(
                edge_type=edge.edge_type,
                edge_types=[edge.edge_type],
                confidence=edge.confidence,
                weight=edge.weight,
            )
            self.graph.add_edge(u, v, **attrs)
            self.edges[(u, v)] = edge
        else:
            kinds = data.setdefault('edge_types', [data.get('edge_type')])
            if edge.edge_type in kinds:
                return
            kinds.append(edge.edge_type)
            for key, value in edge.metadata.items():
                data.setdefault(key, value)

        self.stats['edges_by_type'][edge.edge_type] += 1

    @staticmethod
    def _module_keys(node: CodeNode) -> Set[str]:
        """Names an import statement may use to refer to module ``node``.

        For ``pkg/sub/mod.py`` these are ``pkg.sub.mod``, ``sub.mod`` and
        ``mod``. A package ``__init__.py`` is named by its directory.
        ``node.name`` is always included.
        """
        path = Path(node.file_path)
        parts = [p for p in (*path.parent.parts, path.stem) if p and p != path.anchor]
        if parts and parts[-1] == '__init__':
            parts = parts[:-1]
        keys = {'.'.join(parts[i:]) for i in range(len(parts))}
        keys.add(node.name)
        return keys

    @staticmethod
    def _base_classes(signature: str) -> List[str]:
        """Base-class names parsed from a ``class Name(Base, ...):`` signature."""
        if '(' not in signature or ')' not in signature:
            return []
        base_part = signature.split('(')[1].split(')')[0]
        return [bc.strip() for bc in base_part.split(',')
                if bc.strip() and bc.strip() != "object"]

    def _extract_dependency_relationships(self, nodes: List[CodeNode]) -> List[CodeEdge]:
        """Return import, contain, call and inherit edges between ``nodes``.

        - import: module -> module whose dotted path (or a dotted suffix of it)
          equals the imported name. Self-imports are excluded.
        - contain: module -> its classes/functions, class -> its methods.
        - call: function/method -> every function/method with the called name.
          Confidence is 0.8 within the same file and 0.6 across files.
        - inherit: class -> classes named like a base class. For dotted bases
          (``pkg.Base``) the last component is matched.
        """
        edges = []

        # Lookup maps, built once so every relationship is a dict lookup.
        modules_by_name = defaultdict(list)   # import name -> Module nodes
        members_by_file = defaultdict(list)   # file_path -> Class/Function nodes
        methods_by_owner = defaultdict(list)  # (file_path, class_name) -> Methods
        functions_by_name = defaultdict(list)
        methods_by_name = defaultdict(list)
        classes_by_name = defaultdict(list)

        for node in nodes:
            if node.node_type == "Module":
                for key in self._module_keys(node):
                    modules_by_name[key].append(node)
            elif node.node_type == "Class":
                classes_by_name[node.name].append(node)
                members_by_file[node.file_path].append(node)
            elif node.node_type == "Function":
                functions_by_name[node.name].append(node)
                members_by_file[node.file_path].append(node)
            elif node.node_type == "Method":
                methods_by_name[node.name].append(node)
                methods_by_owner[(node.file_path, node.class_name)].append(node)

        for node in nodes:
            if node.node_type == "Module":
                # 1. Import relationships (exact module-name matching, deduplicated)
                imported = set()
                for import_name in node.imports:
                    for target in modules_by_name.get(import_name, []):
                        if target.id != node.id and target.id not in imported:
                            imported.add(target.id)
                            edges.append(CodeEdge(
                                source_id=node.id,
                                target_id=target.id,
                                edge_type="import",
                                confidence=0.9
                            ))

                # 2a. Module contains its classes and functions
                for member in members_by_file.get(node.file_path, []):
                    edges.append(CodeEdge(
                        source_id=node.id,
                        target_id=member.id,
                        edge_type="contain",
                        confidence=1.0
                    ))

            elif node.node_type == "Class":
                # 2b. Class contains its methods
                for method in methods_by_owner.get((node.file_path, node.name), []):
                    edges.append(CodeEdge(
                        source_id=node.id,
                        target_id=method.id,
                        edge_type="contain",
                        confidence=1.0
                    ))

                # 4. Inheritance relationships
                for base_class in self._base_classes(node.signature):
                    short_name = base_class.rsplit('.', 1)[-1]
                    for target_node in classes_by_name.get(short_name, []):
                        if target_node.id != node.id:  # no self-inheritance loops
                            edges.append(CodeEdge(
                                source_id=node.id,
                                target_id=target_node.id,
                                edge_type="inherit",
                                confidence=0.9,
                                metadata={'base_class': base_class}
                            ))

            elif node.node_type in ["Function", "Method"]:
                # 3. Call relationships (each callee name once per caller)
                for call_name in dict.fromkeys(node.calls):
                    target_nodes = functions_by_name.get(call_name, []) + methods_by_name.get(call_name, [])
                    for target_node in target_nodes:
                        if target_node.id != node.id:  # Don't call self
                            confidence = 0.8 if target_node.file_path == node.file_path else 0.6
                            edges.append(CodeEdge(
                                source_id=node.id,
                                target_id=target_node.id,
                                edge_type="call",
                                confidence=confidence,
                                metadata={'call_name': call_name}
                            ))

        return edges

    def _update_statistics(self):
        """Update graph statistics"""
        self.stats['total_nodes'] = len(self.nodes)
        self.stats['total_edges'] = len(self.edges)
        self.stats['density'] = nx.density(self.graph)

        # Calculate connectivity metrics
        if self.graph.nodes():
            degrees = dict(self.graph.degree()).values()
            self.stats['avg_degree'] = sum(degrees) / len(self.graph.nodes())
            self.stats['max_degree'] = max(degrees)

    def get_one_hop_neighbors(self, node_id: str, edge_types: Optional[List[str]] = None) -> List[Tuple[str, Dict]]:
        """Return ``(neighbor_id, edge_attrs)`` for outgoing, then incoming edges.

        With ``edge_types`` given, an edge is kept when any of its recorded
        kinds (``edge_types`` attribute, or ``edge_type`` for edges added
        directly to the networkx graph) is requested.
        """
        if node_id not in self.graph:
            return []

        wanted = None if edge_types is None else set(edge_types)

        def matches(data: Dict) -> bool:
            if wanted is None:
                return True
            kinds = data.get('edge_types', [data.get('edge_type')])
            return not wanted.isdisjoint(kinds)

        neighbors = [(target, data) for _, target, data
                     in self.graph.out_edges(node_id, data=True) if matches(data)]
        neighbors += [(source, data) for source, _, data
                      in self.graph.in_edges(node_id, data=True) if matches(data)]
        return neighbors

    def visualize_subgraph(self, center_node_id: str, max_depth: int = 2,
                          edge_types: Optional[List[str]] = None, figsize=(16, 12)):
        """Visualize the subgraph around a center node.

        Breadth-first expansion goes up to ``max_depth`` hops, taking at most
        5 neighbors per expanded node. Edges are colored by their primary
        ``edge_type``.
        """

        if center_node_id not in self.nodes:
            print(f"Node {center_node_id} not found")
            return

        # BFS to find subgraph
        subgraph_nodes = {center_node_id}
        current_level = {center_node_id}

        for depth in range(max_depth):
            next_level = set()
            for node_id in current_level:
                neighbors = self.get_one_hop_neighbors(node_id, edge_types)
                for neighbor_id, _ in neighbors[:5]:  # Limit to 5 neighbors per node
                    if neighbor_id not in subgraph_nodes:
                        next_level.add(neighbor_id)
                        subgraph_nodes.add(neighbor_id)

            current_level = next_level
            if not current_level:
                break

        # Create subgraph
        subgraph = self.graph.subgraph(subgraph_nodes)

        if len(subgraph.nodes()) == 0:
            print("No nodes to visualize")
            return

        plt.figure(figsize=figsize)

        # Layout
        pos = nx.spring_layout(subgraph, k=3, iterations=50)

        # Node colors by type
        node_colors = {
            'Module': 'lightcoral',
            'Class': 'lightblue',
            'Method': 'lightgreen',
            'Function': 'lightyellow'
        }

        colors = []
        sizes = []
        for node_id in subgraph.nodes():
            node = self.nodes[node_id]
            colors.append(node_colors.get(node.node_type, 'gray'))
            # Center node larger
            sizes.append(2000 if node_id == center_node_id else 1000)

        # Draw nodes
        nx.draw_networkx_nodes(subgraph, pos,
                              node_color=colors,
                              node_size=sizes,
                              alpha=0.8)

        # Draw edges by type
        edge_colors = {
            'import': 'blue',
            'contain': 'green',
            'inherit': 'purple',
            'call': 'red',
            'similarity': 'orange'
        }

        edge_styles = {
            'similarity': 'dashed',
            'inherit': 'dotted'
        }

        for edge_type, color in edge_colors.items():
            edges = [(u, v) for u, v, d in subgraph.edges(data=True)
                    if d.get('edge_type') == edge_type]

            if edges:
                style = edge_styles.get(edge_type, 'solid')
                width = 3 if edge_type == 'call' else 2
                alpha = 0.6 if edge_type == 'similarity' else 0.8

                nx.draw_networkx_edges(subgraph, pos,
                                      edgelist=edges,
                                      edge_color=color,
                                      width=width,
                                      alpha=alpha,
                                      style=style)

        # Draw labels
        labels = {node_id: self.nodes[node_id].name for node_id in subgraph.nodes()}
        nx.draw_networkx_labels(subgraph, pos, labels, font_size=8)

        plt.title(f'DS-Code Graph: {self.nodes[center_node_id].name}\n'
                 f'(Depth: {max_depth}, Nodes: {len(subgraph.nodes())}, Edges: {len(subgraph.edges())})',
                 fontsize=14)

        # Legend
        legend_elements = []
        for edge_type, color in edge_colors.items():
            if any(d.get('edge_type') == edge_type for _, _, d in subgraph.edges(data=True)):
                legend_elements.append(plt.Line2D([0], [0], color=color, lw=2, label=edge_type.title()))

        if legend_elements:
            plt.legend(handles=legend_elements, loc='upper right')

        plt.axis('off')
        plt.tight_layout()
        plt.show()

        # Print subgraph statistics
        print(f"\nSubgraph Statistics:")
        print(f"Center node: {self.nodes[center_node_id].name} ({self.nodes[center_node_id].node_type})")
        print(f"Nodes: {len(subgraph.nodes())}")
        print(f"Edges: {len(subgraph.edges())}")

        edge_type_counts = defaultdict(int)
        for _, _, d in subgraph.edges(data=True):
            edge_type_counts[d.get('edge_type', 'unknown')] += 1

        print("Edge types:", dict(edge_type_counts))

# Test complete DS-Code Graph
# Sample three-file repository used as input to the full DS-Code Graph demo below.
# Maps a file name to that file's source text; the dict is passed to
# DSCodeGraph.build_from_repository. The files exercise:
#   - cross-module imports (app.py imports from auth.py and utils.py),
#   - classes containing methods (User, Application),
#   - method -> function and function -> function calls,
#   - several password-related functions with overlapping purpose, which may
#     be linked by similarity edges depending on the semantic builder.
# NOTE: the string contents are program input; do not reformat them.
complex_repository = {
    # auth.py: User class whose methods call the module-level hashing helpers.
    'auth.py': '''
import hashlib
from typing import Optional

class User:
    def __init__(self, username: str, email: str):
        self.username = username
        self.email = email
        self.password_hash = None
    
    def set_password(self, password: str):
        """Set user password with hashing"""
        self.password_hash = hash_password(password)
    
    def verify_password(self, password: str) -> bool:
        """Verify user password"""
        return verify_password_hash(password, self.password_hash)

def hash_password(password: str) -> str:
    """Hash password using SHA256"""
    return hashlib.sha256(password.encode()).hexdigest()

def verify_password_hash(password: str, hash_value: str) -> bool:
    """Verify password against hash"""
    return hash_password(password) == hash_value
''',
    
    # utils.py: standalone helpers; sanitize_input calls clean_data.
    'utils.py': '''
from typing import List, Any

def validate_email(email: str) -> bool:
    """Validate email format"""
    return "@" in email and "." in email

def clean_data(data: List[Any]) -> List[Any]:
    """Clean data by removing None values"""
    return [item for item in data if item is not None]

def sanitize_input(text: str) -> str:
    """Sanitize user input"""
    cleaned = clean_data([text.strip()])
    return cleaned[0] if cleaned else ""
''',
    
    # app.py: imports from auth and utils; Application methods call into both.
    'app.py': '''
from auth import User, hash_password
from utils import validate_email, sanitize_input

class Application:
    def __init__(self):
        self.users = []
    
    def register_user(self, username: str, email: str, password: str) -> bool:
        """Register a new user"""
        if not validate_email(email):
            return False
        
        clean_username = sanitize_input(username)
        user = User(clean_username, email)
        user.set_password(password)
        self.users.append(user)
        return True
    
    def authenticate_user(self, username: str, password: str) -> bool:
        """Authenticate user credentials"""
        for user in self.users:
            if user.username == username:
                return user.verify_password(password)
        return False

def main():
    """Main application entry point"""
    app = Application()
    
    # Register demo user
    success = app.register_user("demo", "demo@example.com", "password123")
    print(f"Registration: {success}")
    
    # Test authentication
    auth_result = app.authenticate_user("demo", "password123")
    print(f"Authentication: {auth_result}")
'''
}

# Build the DS-Code Graph for the sample repository defined above.
print("\nBuilding Complete DS-Code Graph...")
ds_graph = DSCodeGraph(semantic_builder)
ds_graph.build_from_repository(complex_repository)

# Print the statistics stored in ds_graph.stats
print(f"\nGraph Statistics:")
print(f"Nodes by type: {dict(ds_graph.stats['nodes_by_type'])}")
print(f"Edges by type: {dict(ds_graph.stats['edges_by_type'])}")
print(f"Density: {ds_graph.stats.get('density', 0):.3f}")
print(f"Average degree: {ds_graph.stats.get('avg_degree', 0):.2f}")

# Visualize the neighbourhood of the `Application` class: the first node named
# "Application" whose node_type is "Class" (None when no such node exists).
# A generator keeps the search variables out of the notebook's global scope.
app_class_id = next(
    (
        node_id
        for node_id, node in ds_graph.nodes.items()
        if node.name == "Application" and node.node_type == "Class"
    ),
    None,
)

# Compare against None explicitly so a falsy-but-valid node id is not skipped.
if app_class_id is not None:
    print(f"\nVisualizing subgraph around Application class...")
    ds_graph.visualize_subgraph(app_class_id, max_depth=2)
else:
    print("Application class not found for visualization")

## 📊 Performance Analysis và Quality Metrics

### Phân tích chất lượng của DS-Code Graph construction:

In [None]:
def analyze_ds_code_graph_quality(ds_graph: DSCodeGraph):
    """Compute quality metrics for a built DS-Code Graph.

    Args:
        ds_graph: Graph object whose ``graph`` (a networkx directed graph),
            ``nodes`` (id -> node exposing ``node_type``), ``edges``
            (id -> edge exposing ``edge_type``, ``confidence``, ``source_id``,
            ``target_id``) and ``stats`` (with 'nodes_by_type' and
            'edges_by_type') are read. Nothing is modified.

    Returns:
        Dict with four sections:

        - 'graph_connectivity': node/edge counts, density, weak connectivity,
          component count and degree statistics (``{}`` for an empty graph).
        - 'relationship_quality': edge counts and mean confidence per edge
          type, overall mean confidence and the share of edges with
          confidence >= 0.8 (edges lacking a confidence count as 0).
        - 'semantic_analysis': statistics over 'similarity' edges
          (``{}`` when there are none).
        - 'structural_properties': mean in/out/total degree per node type,
          plus the node/edge counts stored in ``ds_graph.stats``.
    """
    return {
        'graph_connectivity': _graph_connectivity_metrics(ds_graph.graph),
        'relationship_quality': _relationship_quality_metrics(ds_graph.graph),
        'semantic_analysis': _semantic_similarity_metrics(ds_graph.edges.values()),
        'structural_properties': _node_type_structure_metrics(ds_graph),
    }


def _graph_connectivity_metrics(graph):
    """Return size, density, weak-connectivity and degree stats of ``graph``."""
    if not graph.nodes():
        # networkx connectivity queries are undefined for a graph with no nodes
        return {}

    degrees = [deg for _, deg in graph.degree()]
    in_degrees = [deg for _, deg in graph.in_degree()]
    out_degrees = [deg for _, deg in graph.out_degree()]

    return {
        'num_nodes': len(graph.nodes()),
        'num_edges': len(graph.edges()),
        'density': nx.density(graph),
        'is_connected': nx.is_weakly_connected(graph),
        'num_components': nx.number_weakly_connected_components(graph),
        'avg_degree': np.mean(degrees),
        'max_degree': max(degrees),
        'avg_in_degree': np.mean(in_degrees),
        'avg_out_degree': np.mean(out_degrees),
    }


def _relationship_quality_metrics(graph):
    """Return per-type edge counts / mean confidence and overall confidence stats."""
    confidences_by_type = defaultdict(list)
    all_confidences = []

    for _, _, data in graph.edges(data=True):
        confidence = data.get('confidence', 0)
        confidences_by_type[data.get('edge_type', 'unknown')].append(confidence)
        all_confidences.append(confidence)

    if all_confidences:
        overall = np.mean(all_confidences)
        high_ratio = sum(1 for c in all_confidences if c >= 0.8) / len(all_confidences)
    else:
        overall = 0
        high_ratio = 0

    return {
        'edge_type_counts': {et: len(confs) for et, confs in confidences_by_type.items()},
        'avg_confidence_by_type': {et: np.mean(confs) for et, confs in confidences_by_type.items()},
        'overall_avg_confidence': overall,
        'high_confidence_ratio': high_ratio,
    }


def _node_file_key(node_id):
    """Return the second ':'-separated field of a node id, used as its file key.

    NOTE(review): assumes ids look like '<kind>:<file>:...' — confirm against
    DSCodeGraph's id scheme. Ids without ':' fall back to the whole id rather
    than raising IndexError.
    """
    parts = str(node_id).split(':')
    return parts[1] if len(parts) > 1 else str(node_id)


def _semantic_similarity_metrics(edges):
    """Return statistics over the 'similarity' edges in ``edges`` ({} if none)."""
    semantic_edges = [e for e in edges if e.edge_type == 'similarity']
    if not semantic_edges:
        return {}

    scores = [e.confidence for e in semantic_edges]
    cross_file = sum(
        1 for e in semantic_edges
        if _node_file_key(e.source_id) != _node_file_key(e.target_id)
    )

    return {
        'num_semantic_relationships': len(semantic_edges),
        'avg_semantic_confidence': np.mean(scores),
        'cross_file_semantic_ratio': cross_file / len(semantic_edges),
        'semantic_score_distribution': {
            'min': min(scores),
            'max': max(scores),
            'std': np.std(scores),
        },
    }


def _node_type_structure_metrics(ds_graph):
    """Return mean in/out/total degree per node type plus the stored type counts."""
    graph = ds_graph.graph
    degree_sums = defaultdict(lambda: {'in': 0, 'out': 0, 'total': 0})
    # Count nodes here (not via ds_graph.stats) so every average divides the
    # sum by exactly the nodes that contributed to it.
    node_counts = defaultdict(int)

    for node_id, node in ds_graph.nodes.items():
        # Nodes never added to the networkx graph contribute degree 0.
        in_graph = node_id in graph
        in_deg = graph.in_degree(node_id) if in_graph else 0
        out_deg = graph.out_degree(node_id) if in_graph else 0

        sums = degree_sums[node.node_type]
        sums['in'] += in_deg
        sums['out'] += out_deg
        sums['total'] += in_deg + out_deg
        node_counts[node.node_type] += 1

    connectivity = {
        node_type: {k: v / node_counts[node_type] for k, v in sums.items()}
        for node_type, sums in degree_sums.items()
    }

    return {
        'node_type_connectivity': connectivity,
        'nodes_by_type': dict(ds_graph.stats['nodes_by_type']),
        'edges_by_type': dict(ds_graph.stats['edges_by_type']),
    }

def visualize_graph_quality_metrics(metrics: Dict):
    """Plot and print a quality report for DS-Code Graph metrics.

    Args:
        metrics: Dict as returned by ``analyze_ds_code_graph_quality`` with the
            sections 'graph_connectivity', 'relationship_quality',
            'semantic_analysis' and 'structural_properties'. Missing sections
            are treated as empty.

    Shows a 2x2 matplotlib figure: edge-type counts, mean confidence per edge
    type, mean in/out degree per node type, and overall graph scores. Then
    prints a textual summary followed by rule-based insights.
    """
    _plot_quality_dashboard(metrics)
    _print_quality_report(metrics)


def _plot_quality_dashboard(metrics: Dict):
    """Draw and show the 2x2 figure summarising the quality metrics."""
    _, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    relationship = metrics.get('relationship_quality', {})

    # 1. Edge type distribution
    edge_counts = relationship.get('edge_type_counts', {})
    if edge_counts:
        ax1.bar(list(edge_counts.keys()), list(edge_counts.values()), alpha=0.8)
        ax1.set_title('Edge Type Distribution')
        ax1.set_ylabel('Count')
        ax1.tick_params(axis='x', rotation=45)

    # 2. Confidence by edge type; types below the 0.8 threshold are highlighted
    conf_by_type = relationship.get('avg_confidence_by_type', {})
    if conf_by_type:
        colors = ['skyblue' if c >= 0.8 else 'lightcoral' for c in conf_by_type.values()]
        ax2.bar(list(conf_by_type.keys()), list(conf_by_type.values()), color=colors, alpha=0.8)
        ax2.set_title('Average Confidence by Edge Type')
        ax2.set_ylabel('Confidence Score')
        ax2.tick_params(axis='x', rotation=45)
        ax2.axhline(y=0.8, color='red', linestyle='--', alpha=0.5, label='High Confidence Threshold')
        ax2.legend()

    # 3. Node type connectivity (grouped in/out-degree bars)
    connectivity = metrics.get('structural_properties', {}).get('node_type_connectivity')
    if connectivity is not None:
        node_types = list(connectivity.keys())
        in_degrees = [connectivity[nt]['in'] for nt in node_types]
        out_degrees = [connectivity[nt]['out'] for nt in node_types]

        x = np.arange(len(node_types))
        width = 0.35

        ax3.bar(x - width / 2, in_degrees, width, label='In-degree', alpha=0.8)
        ax3.bar(x + width / 2, out_degrees, width, label='Out-degree', alpha=0.8)
        ax3.set_title('Average Connectivity by Node Type')
        ax3.set_ylabel('Average Degree')
        ax3.set_xticks(x)
        ax3.set_xticklabels(node_types)
        ax3.legend()

    # 4. Overall graph metrics on a shared 0-1 scale
    if 'graph_connectivity' in metrics:
        conn_metrics = metrics['graph_connectivity']

        metric_names = ['Density', 'Avg Degree', 'Connectivity']
        metric_values = [
            conn_metrics.get('density', 0),
            conn_metrics.get('avg_degree', 0) / 10,  # Normalized for visualization
            1.0 if conn_metrics.get('is_connected', False) else 0.0,
        ]

        colors = ['green' if v >= 0.5 else 'orange' if v >= 0.2 else 'red' for v in metric_values]
        ax4.bar(metric_names, metric_values, color=colors, alpha=0.8)
        ax4.set_title('Overall Graph Quality Metrics')
        ax4.set_ylabel('Normalized Score')
        ax4.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()


def _quality_insights(metrics: Dict):
    """Return the list of insight sentences triggered by threshold rules."""
    insights = []

    if metrics.get('relationship_quality', {}).get('high_confidence_ratio', 0) >= 0.7:
        insights.append("High-quality relationships với strong confidence scores")

    if metrics.get('semantic_analysis', {}).get('cross_file_semantic_ratio', 0) >= 0.3:
        insights.append("Good cross-file semantic relationships detected")

    if metrics.get('graph_connectivity', {}).get('density', 0) >= 0.1:
        insights.append("Well-connected graph structure")

    edge_types = metrics.get('relationship_quality', {}).get('edge_type_counts', {})
    if len(edge_types) >= 4:
        insights.append("Comprehensive relationship coverage (4+ edge types)")

    return insights


def _print_quality_report(metrics: Dict):
    """Print the per-section summary and the derived insights."""
    divider = "=" * 70
    print("\n" + divider)
    print("DS-CODE GRAPH QUALITY ANALYSIS")
    print(divider)

    if 'graph_connectivity' in metrics:
        conn = metrics['graph_connectivity']
        print("\n📊 Graph Connectivity:")
        print(f"• Nodes: {conn.get('num_nodes', 0)}")
        print(f"• Edges: {conn.get('num_edges', 0)}")
        print(f"• Density: {conn.get('density', 0):.3f}")
        print(f"• Connected: {conn.get('is_connected', False)}")
        print(f"• Components: {conn.get('num_components', 0)}")

    if 'relationship_quality' in metrics:
        rel = metrics['relationship_quality']
        print("\n🔗 Relationship Quality:")
        print(f"• Overall confidence: {rel.get('overall_avg_confidence', 0):.3f}")
        print(f"• High confidence ratio: {rel.get('high_confidence_ratio', 0):.3f}")
        print(f"• Edge types: {rel.get('edge_type_counts', {})}")

    if 'semantic_analysis' in metrics:
        sem = metrics['semantic_analysis']
        print("\n🧠 Semantic Analysis:")
        print(f"• Semantic relationships: {sem.get('num_semantic_relationships', 0)}")
        print(f"• Avg semantic confidence: {sem.get('avg_semantic_confidence', 0):.3f}")
        print(f"• Cross-file semantic ratio: {sem.get('cross_file_semantic_ratio', 0):.3f}")

    print("\n💡 Key Insights:")

    insights = _quality_insights(metrics)
    for insight in insights:
        print(f"• {insight}")

    if not insights:
        print("• Graph structure có thể cần optimization")

    print(divider)

# Score the DS-Code Graph built above, then plot and print the quality report.
_analysis_banner = "\nRunning DS-Code Graph quality analysis..."
print(_analysis_banner)
quality_metrics = analyze_ds_code_graph_quality(ds_graph)
visualize_graph_quality_metrics(quality_metrics)

## 🧪 Mock Data Testing và Independent Validation

Tạo test cases độc lập để validate DS-Code Graph construction:

In [None]:
def create_test_scenarios_for_ds_graph():
    """Return the DS-Code Graph test scenarios keyed by scenario name.

    Each scenario is a dict with:

    - 'description': human-readable purpose,
    - 'code': source of a single test module,
    - 'expected_edges': list of ``(edge_type, source_name, target_name)``.

    Node names follow the labelling used by ``run_ds_graph_test_scenario``:
    Method nodes are written ``ClassName.method``; classes and module-level
    functions use their bare name.
    """
    return {
        'inheritance_test': {
            'description': 'Test inheritance relationship extraction',
            'code': '''
class Animal:
    def speak(self):
        pass

class Dog(Animal):
    def speak(self):
        return "Woof!"

class Cat(Animal):
    def speak(self):
        return "Meow!"
''',
            'expected_edges': [
                ('inherit', 'Dog', 'Animal'),
                ('inherit', 'Cat', 'Animal'),
                ('similarity', 'Dog.speak', 'Cat.speak'),
            ],
        },

        'call_dependency_test': {
            'description': 'Test function call dependency extraction',
            'code': '''
def utility_func(x):
    return x * 2

def helper_func(y):
    return y + 1

def main_func(data):
    processed = utility_func(data)
    result = helper_func(processed)
    return result
''',
            'expected_edges': [
                ('call', 'main_func', 'utility_func'),
                ('call', 'main_func', 'helper_func'),
                ('similarity', 'utility_func', 'helper_func'),
            ],
        },

        'containment_test': {
            'description': 'Test containment relationship extraction',
            'code': '''
class Calculator:
    def add(self, a, b):
        return a + b
    
    def multiply(self, a, b):
        return a * b
    
    def complex_operation(self, x, y):
        sum_result = self.add(x, y)
        return self.multiply(sum_result, 2)
''',
            # Methods are class-qualified, matching how the runner labels
            # Method nodes; bare names like 'add' could never match.
            'expected_edges': [
                ('contain', 'Calculator', 'Calculator.add'),
                ('contain', 'Calculator', 'Calculator.multiply'),
                ('contain', 'Calculator', 'Calculator.complex_operation'),
                ('call', 'Calculator.complex_operation', 'Calculator.add'),
                ('call', 'Calculator.complex_operation', 'Calculator.multiply'),
            ],
        },
    }

def run_ds_graph_test_scenario(scenario_name: str, scenario_data: Dict, semantic_builder: SemanticRelationshipBuilder):
    """Build a DS-Code Graph for one scenario and compare its edges to the expectation.

    Args:
        scenario_name: Scenario key; also used as the file name
            (``<scenario_name>.py``) of the one-file test repository.
        scenario_data: Dict with 'description', 'code' and 'expected_edges'
            (a list of ``(edge_type, source_name, target_name)`` triples).
        semantic_builder: Builder passed to ``DSCodeGraph``.

    Returns:
        Dict with the scenario name, expected and actual edge counts, number
        of expected edges found, success rate, and the built graph.

    Edges are compared by label: Method nodes are written ``Class.method``,
    all other nodes by their bare name. 'similarity' edges match in either
    direction, because similarity is a symmetric relation.
    """
    print(f"\n🧪 Testing: {scenario_name}")
    print(f"Description: {scenario_data['description']}")

    # Create a one-file repository holding only the scenario's code
    test_repo = {f"{scenario_name}.py": scenario_data['code']}

    # Build DS-Code Graph
    test_graph = DSCodeGraph(semantic_builder)
    test_graph.build_from_repository(test_repo)

    # Extract actual relationships as (edge_type, source_label, target_label)
    actual_edges = [
        (
            data.get('edge_type'),
            _scenario_node_label(test_graph, source),
            _scenario_node_label(test_graph, target),
        )
        for source, target, data in test_graph.graph.edges(data=True)
    ]

    expected_edges = scenario_data['expected_edges']

    print(f"\nExpected edges: {len(expected_edges)}")
    print(f"Actual edges: {len(actual_edges)}")

    # Key sets make each membership test O(1) and normalise similarity direction
    actual_keys = {_edge_match_key(edge) for edge in actual_edges}
    expected_keys = {_edge_match_key(edge) for edge in expected_edges}

    matches = 0
    for expected in expected_edges:
        if _edge_match_key(expected) in actual_keys:
            matches += 1
            print(f"✓ Found: {expected}")
        else:
            print(f"✗ Missing: {expected}")

    # Show unexpected edges
    unexpected = [edge for edge in actual_edges if _edge_match_key(edge) not in expected_keys]
    if unexpected:
        print("\nUnexpected edges found:")
        for edge in unexpected[:5]:  # Limit output
            print(f"  + {edge}")

    success_rate = matches / len(expected_edges) if expected_edges else 0
    print(f"\nSuccess rate: {matches}/{len(expected_edges)} ({success_rate:.1%})")

    return {
        'scenario': scenario_name,
        'expected': len(expected_edges),
        'actual': len(actual_edges),
        'matches': matches,
        'success_rate': success_rate,
        'graph': test_graph
    }


def _scenario_node_label(graph, node_id):
    """Return a node's display name, prefixed with its class for Method nodes."""
    node = graph.nodes[node_id]
    if node.node_type == 'Method':
        return f"{node.class_name}.{node.name}"
    return node.name


def _edge_match_key(edge):
    """Return the comparison key for an ``(edge_type, source, target)`` triple.

    Similarity is symmetric, so its endpoints are put in sorted order; all
    other edge types keep their direction.
    """
    edge_type, source, target = edge
    if edge_type == 'similarity' and target < source:
        source, target = target, source
    return (edge_type, source, target)

def run_comprehensive_ds_graph_tests():
    """Run every DS-Code Graph scenario, print a summary and plot the success rates.

    A single TF-IDF based ``SemanticRelationshipBuilder`` is shared by all
    scenarios to keep the run fast.

    Returns:
        List of per-scenario result dicts from ``run_ds_graph_test_scenario``,
        in scenario order.
    """
    scenario_table = create_test_scenarios_for_ds_graph()
    # TF-IDF similarity keeps this test run fast
    tfidf_builder = SemanticRelationshipBuilder(embedding_model="tfidf")

    results = [
        run_ds_graph_test_scenario(name, spec, tfidf_builder)
        for name, spec in scenario_table.items()
    ]

    rule = "=" * 70
    print("\n" + rule)
    print("DS-CODE GRAPH TEST SUMMARY")
    print(rule)

    expected_total = sum(entry['expected'] for entry in results)
    matched_total = sum(entry['matches'] for entry in results)
    overall_rate = matched_total / expected_total if expected_total > 0 else 0

    print("\n📊 Overall Results:")
    print(f"• Total test scenarios: {len(results)}")
    print(f"• Total expected relationships: {expected_total}")
    print(f"• Total matches found: {matched_total}")
    print(f"• Overall success rate: {overall_rate:.1%}")

    print("\n📋 Individual Scenario Results:")
    for entry in results:
        print(f"• {entry['scenario']}: {entry['success_rate']:.1%} ({entry['matches']}/{entry['expected']})")

    def bar_color(rate):
        """Green at >= 80%, orange at >= 60%, red below that."""
        if rate >= 0.8:
            return 'green'
        if rate >= 0.6:
            return 'orange'
        return 'red'

    labels = [entry['scenario'].replace('_test', '') for entry in results]
    rates = [entry['success_rate'] for entry in results]

    plt.figure(figsize=(12, 6))
    plt.bar(labels, rates, color=[bar_color(rate) for rate in rates], alpha=0.8)
    plt.title('DS-Code Graph Test Results by Scenario')
    plt.ylabel('Success Rate')
    plt.xlabel('Test Scenario')
    plt.ylim(0, 1)
    plt.xticks(rotation=45)

    # Percentage label just above each bar
    for position, rate in enumerate(rates):
        plt.text(position, rate + 0.02, f'{rate:.1%}', ha='center', va='bottom')

    plt.axhline(y=0.8, color='green', linestyle='--', alpha=0.5, label='Target (80%)')
    plt.legend()
    plt.tight_layout()
    plt.show()

    return results

# Execute every DS-Code Graph scenario and keep the per-scenario results.
test_results = run_comprehensive_ds_graph_tests()

# Closing banner with the notebook's key takeaways, printed as a numbered list.
_separator = "=" * 70
_key_learnings = [
    "Multi-type node schema enables rich relationship modeling",
    "Semantic relationships complement dependency relationships",
    "Embedding-based similarity detection works for code",
    "Graph visualization reveals complex code structures",
    "Quality metrics help validate graph construction",
    "Test scenarios ensure reliable relationship extraction",
]

print("\n" + _separator)
print("DS-CODE GRAPH FOCUSED LEARNING COMPLETE")
print(_separator)
print("Key Learnings:")
for _idx, _text in enumerate(_key_learnings, start=1):
    print(f"{_idx}. {_text}")
print(_separator)