# Building Custom Components

LlamaIndex exposes base classes for most stages of a RAG pipeline. This notebook shows how to build custom LLMs, embedding models, node parsers, callback handlers, and prompt templates.

## Learning Objectives

By the end of this notebook, you will be able to:
1. Create custom LLM wrappers
2. Build custom embedding models
3. Implement custom node parsers
4. Create custom callbacks for monitoring
5. Write custom prompt templates

---

In [None]:
# Setup
import nest_asyncio
nest_asyncio.apply()

from dotenv import load_dotenv
load_dotenv()

from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from typing import Any, List, Optional, Sequence

print("✓ Setup complete!")

## 1. Custom LLM Wrapper

Subclass `CustomLLM` and implement `metadata`, `complete()`, and `stream_complete()` to wrap any LLM API or local model:

In [None]:
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
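# Import Field through LlamaIndex's pydantic bridge so it matches the pydantic version the library uses
from llama_index.core.bridge.pydantic import Field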

class MockLLM(CustomLLM):
    """A mock LLM for testing and demonstration."""
    
    model_name: str = Field(default="mock-model")
    response_prefix: str = Field(default="Mock response: ")
    
    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            context_window=4096,
            num_output=256,
            model_name=self.model_name,
        )
    
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Generate a completion."""
        # In real implementation, call your LLM API here
        response_text = f"{self.response_prefix}Based on your prompt about '{prompt[:50]}...', here is my response."
        return CompletionResponse(text=response_text)
    
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream a completion."""
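        response_text = f"{self.response_prefix}Streaming response for: {prompt[:30]}..."
        
        # Simulate streaming: `delta` is the new piece, `text` is everything so far
        accumulated = ""
        for word in response_text.split():
            delta = word + " "
            accumulated += delta
            yield CompletionResponse(text=accumulated, delta=delta)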

print("✓ MockLLM defined!")

In [None]:
# Test the mock LLM
mock_llm = MockLLM(response_prefix="[MOCK] ")

print(f"Model: {mock_llm.metadata.model_name}")
print(f"Context window: {mock_llm.metadata.context_window}")

# Test completion
response = mock_llm.complete("What is machine learning?")
print(f"\nCompletion: {response.text}")

# Test streaming
print("\nStreaming: ", end="")
for chunk in mock_llm.stream_complete("Explain AI"):
    print(chunk.delta, end="", flush=True)
print()
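
A streamed response carries two fields: `delta` (the newly generated piece) and `text` (by convention in LlamaIndex's own examples, everything generated so far). The sketch below shows that relationship in plain Python. `Chunk` and `stream_words` are hypothetical stand-ins for illustration, not LlamaIndex classes.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class Chunk:
    """Hypothetical stand-in for CompletionResponse (text and delta only)."""
    text: str   # everything generated so far
    delta: str  # only the newly generated piece


def stream_words(response_text: str) -> Iterator[Chunk]:
    """Yield one chunk per word, keeping `text` cumulative."""
    accumulated = ""
    for word in response_text.split():
        delta = word + " "
        accumulated += delta
        yield Chunk(text=accumulated, delta=delta)


chunks = list(stream_words("streamed one word at a time"))

# Joining the deltas reproduces the final cumulative text.
from_deltas = "".join(chunk.delta for chunk in chunks)
print(from_deltas == chunks[-1].text)  # True
```

A consumer can therefore either print each `delta` as it arrives or read `text` from the last chunk.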

### Real Custom LLM Example (with Caching)

A more practical example is an LLM wrapper that keeps completions in an LRU cache:

In [None]:
import hashlib
from collections import OrderedDict

class CachedLLM(CustomLLM):
    """LLM wrapper with response caching."""
    
    base_llm: Any = Field(default=None)
    cache_size: int = Field(default=100)
    _cache: OrderedDict = None
    _cache_hits: int = 0
    _cache_misses: int = 0
    
    def __init__(self, base_llm: Any, cache_size: int = 100, **kwargs):
        super().__init__(base_llm=base_llm, cache_size=cache_size, **kwargs)
        self._cache = OrderedDict()
        self._cache_hits = 0
        self._cache_misses = 0
    
    @property
    def metadata(self) -> LLMMetadata:
        return self.base_llm.metadata
    
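    def _get_cache_key(self, prompt: str) -> str:
        # MD5 serves only as a fast, non-cryptographic cache key here.
        # The key ignores kwargs, so calls that differ only in e.g. temperature share an entry.
        return hashlib.md5(prompt.encode()).hexdigest()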
    
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        cache_key = self._get_cache_key(prompt)
        
        # Check cache
        if cache_key in self._cache:
            self._cache_hits += 1
            # Move to end (LRU)
            self._cache.move_to_end(cache_key)
            return CompletionResponse(text=self._cache[cache_key])
        
        # Cache miss - call base LLM
        self._cache_misses += 1
        response = self.base_llm.complete(prompt, **kwargs)
        
        # Store in cache
        self._cache[cache_key] = response.text
        
        # Evict if necessary
        while len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)
        
        return response
    
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Streaming doesn't use cache
        return self.base_llm.stream_complete(prompt, **kwargs)
    
    def get_stats(self) -> dict:
        total = self._cache_hits + self._cache_misses
        return {
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate": self._cache_hits / total if total > 0 else 0,
            "cache_size": len(self._cache),
        }

print("✓ CachedLLM defined!")

In [None]:
# Test cached LLM
base_llm = OpenAI(model="gpt-4o-mini")
cached_llm = CachedLLM(base_llm=base_llm, cache_size=50)

# First call (cache miss)
print("First call (should be cache miss)...")
response1 = cached_llm.complete("What is Python?")
print(f"Stats: {cached_llm.get_stats()}")

# Second call with same prompt (cache hit)
print("\nSecond call with same prompt (should be cache hit)...")
response2 = cached_llm.complete("What is Python?")
print(f"Stats: {cached_llm.get_stats()}")
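
The eviction order in `CachedLLM` can be checked without calling an API. The sketch below isolates the same `OrderedDict` bookkeeping in plain Python. `LRUCache` and `make_cache_key` are hypothetical helpers introduced for illustration; `make_cache_key` also folds keyword arguments into the key, which is one possible way to keep calls with different settings apart.

```python
import hashlib
import json
from collections import OrderedDict
from typing import Any, Optional


def make_cache_key(prompt: str, **kwargs: Any) -> str:
    """Hash the prompt together with its keyword arguments."""
    payload = json.dumps({"prompt": prompt, "kwargs": kwargs}, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


class LRUCache:
    """Hypothetical helper mirroring the OrderedDict logic in CachedLLM."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._items: "OrderedDict[str, str]" = OrderedDict()

    def get(self, key: str) -> Optional[str]:
        if key not in self._items:
            return None
        self._items.move_to_end(key)  # mark as most recently used
        return self._items[key]

    def put(self, key: str, value: str) -> None:
        self._items[key] = value
        self._items.move_to_end(key)
        while len(self._items) > self.max_size:
            self._items.popitem(last=False)  # drop the least recently used entry


cache = LRUCache(max_size=2)
key_a = make_cache_key("What is Python?")
key_b = make_cache_key("What is Rust?")
key_c = make_cache_key("What is Go?")

cache.put(key_a, "answer A")
cache.put(key_b, "answer B")
cache.get(key_a)              # touch A, so B becomes the oldest entry
cache.put(key_c, "answer C")  # exceeds max_size and evicts B

print(cache.get(key_a), cache.get(key_b), cache.get(key_c))  # answer A None answer C
```

Reading an entry counts as a use, which is why B is evicted even though it was inserted after A.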

## 2. Custom Embedding Model

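Subclass `BaseEmbedding` to plug in any embedding model. The hooks to implement are `_get_text_embedding()`, `_get_query_embedding()`, and `_aget_query_embedding()`: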

In [None]:
from llama_index.core.embeddings import BaseEmbedding
import hashlib
import numpy as np

class SimpleHashEmbedding(BaseEmbedding):
    """A simple hash-based embedding for demonstration.
    
    In practice, you would wrap a real embedding model here.
    """
    
    embed_dim: int = Field(default=128)
    
    def __init__(self, embed_dim: int = 128, **kwargs):
        super().__init__(embed_dim=embed_dim, **kwargs)
    
    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a single text."""
        # Simple hash-based embedding (for demonstration only!)
        # In practice, use a real embedding model.
        # Seed from a stable digest: the built-in hash() is salted per process,
        # so it would produce different vectors in every session.
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:8], "big")
        rng = np.random.default_rng(seed)  # local generator, leaves global NumPy state alone
        embedding = rng.standard_normal(self.embed_dim)
        # Normalize to unit length so a dot product equals cosine similarity
        return (embedding / np.linalg.norm(embedding)).tolist()
    
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a query."""
        return self._get_text_embedding(query)
    
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async query embedding."""
        return self._get_query_embedding(query)
    
    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async text embedding."""
        return self._get_text_embedding(text)

print("✓ SimpleHashEmbedding defined!")

In [None]:
# Test custom embedding
custom_embed = SimpleHashEmbedding(embed_dim=64)

text1 = "Machine learning is fascinating"
text2 = "Deep learning uses neural networks"

emb1 = custom_embed.get_text_embedding(text1)
emb2 = custom_embed.get_text_embedding(text2)

print(f"Embedding dimension: {len(emb1)}")
print(f"Embedding 1 (first 5): {emb1[:5]}")
print(f"Embedding 2 (first 5): {emb2[:5]}")

# Calculate similarity
similarity = np.dot(emb1, emb2)
print(f"\nCosine similarity: {similarity:.4f}")
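
A hash-derived embedding is only reproducible if the hash itself is stable across processes, and Python's built-in `hash()` for strings is salted per process by default. The sketch below derives a unit-length vector from SHA-256 using only the standard library. `hash_embedding` is a hypothetical function for illustration; like any hash-based vector, it carries no semantic meaning.

```python
import hashlib
import math
import struct
from typing import List


def hash_embedding(text: str, dim: int = 8) -> List[float]:
    """Map text to a deterministic unit vector derived from SHA-256 digests."""
    values: List[float] = []
    counter = 0
    while len(values) < dim:
        # Each digest is 32 bytes, i.e. eight unsigned 32-bit integers.
        digest = hashlib.sha256(f"{counter}:{text}".encode("utf-8")).digest()
        for (number,) in struct.iter_unpack(">I", digest):
            values.append(number / 0xFFFFFFFF * 2.0 - 1.0)  # scale to [-1, 1]
        counter += 1
    values = values[:dim]
    norm = math.sqrt(sum(v * v for v in values))
    return [v / norm for v in values]


first = hash_embedding("Machine learning is fascinating")
second = hash_embedding("Machine learning is fascinating")
other = hash_embedding("Deep learning uses neural networks")

print(first == second)  # True: same text, same vector, in any session
print(first == other)   # False
```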

## 3. Custom Node Parser

Subclass `NodeParser` and implement `_parse_nodes()` to control how documents are chunked. This parser splits on Markdown headers:

In [None]:
from llama_index.core.node_parser import NodeParser
from llama_index.core.schema import Document, TextNode, BaseNode, MetadataMode
import re

class MarkdownHeaderParser(NodeParser):
    """Parse documents by markdown headers."""
    
    min_chunk_size: int = Field(default=50)
    
    def __init__(self, min_chunk_size: int = 50, **kwargs):
        super().__init__(min_chunk_size=min_chunk_size, **kwargs)
    
    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs,
    ) -> List[BaseNode]:
        """Parse nodes by markdown headers."""
        all_nodes = []
        
        for node in nodes:
            # get_content() returns the raw text for Document and TextNode inputs alike,
            # so the parser does not depend on how those classes inherit from each other
            text = node.get_content(metadata_mode=MetadataMode.NONE)
            all_nodes.extend(self._split_by_headers(text, node.metadata))
        
        return all_nodes
    
    def _split_by_headers(self, text: str, metadata: dict) -> List[TextNode]:
        """Split text by markdown headers."""
        # Find all headers
        header_pattern = r'^(#{1,6})\s+(.+)$'
        
        lines = text.split('\n')
        sections = []
        current_section = []
        current_header = None
        current_level = 0
        
        for line in lines:
            match = re.match(header_pattern, line)
            if match:
                # Save previous section
                if current_section:
                    section_text = '\n'.join(current_section)
                    if len(section_text.strip()) >= self.min_chunk_size:
                        sections.append({
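                            # Text before the first header has no title; label it like the final section does
                            'header': current_header or 'Introduction',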
                            'level': current_level,
                            'text': section_text,
                        })
                
                # Start new section
                current_level = len(match.group(1))
                current_header = match.group(2)
                current_section = [line]
            else:
                current_section.append(line)
        
        # Don't forget last section
        if current_section:
            section_text = '\n'.join(current_section)
            if len(section_text.strip()) >= self.min_chunk_size:
                sections.append({
                    'header': current_header or 'Introduction',
                    'level': current_level,
                    'text': section_text,
                })
        
        # Create nodes
        nodes = []
        for section in sections:
            node_metadata = {
                **metadata,
                'section_header': section['header'],
                'header_level': section['level'],
            }
            nodes.append(TextNode(
                text=section['text'],
                metadata=node_metadata,
            ))
        
        return nodes if nodes else [TextNode(text=text, metadata=metadata)]

print("✓ MarkdownHeaderParser defined!")

In [None]:
# Test custom parser
test_markdown = """
# Introduction

This document explains machine learning concepts.
It covers the basics and some advanced topics.

## Types of Learning

There are several types of machine learning:
- Supervised learning
- Unsupervised learning
- Reinforcement learning

### Supervised Learning

Supervised learning uses labeled data to train models.
Examples include classification and regression.

### Unsupervised Learning

Unsupervised learning finds patterns in unlabeled data.
Common techniques include clustering and dimensionality reduction.

## Conclusion

Machine learning is a powerful tool for data analysis.
"""

# Create document and parse
doc = Document(text=test_markdown, metadata={"source": "test"})
parser = MarkdownHeaderParser(min_chunk_size=30)

nodes = parser.get_nodes_from_documents([doc])

print(f"Parsed into {len(nodes)} sections:\n")
for i, node in enumerate(nodes):
    print(f"Section {i+1}:")
    print(f"  Header: {node.metadata.get('section_header', 'N/A')}")
    print(f"  Level: {node.metadata.get('header_level', 0)}")
    print(f"  Text preview: {node.text[:50]}...")
    print()
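
The parser's behaviour depends on which lines the header pattern accepts. The sketch below runs the same regular expression on a few sample lines; `parse_header` and the sample strings are hypothetical and only illustrate the matching rules.

```python
import re
from typing import Optional, Tuple

HEADER_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$")  # same pattern as in the parser


def parse_header(line: str) -> Optional[Tuple[int, str]]:
    """Return (level, title) for a header line, or None for ordinary text."""
    match = HEADER_PATTERN.match(line)
    if match is None:
        return None
    return len(match.group(1)), match.group(2)


samples = [
    "# Introduction",
    "### Supervised Learning",
    "#NoSpace",                   # no whitespace after the hash
    "####### Seven hashes",       # more than six levels
    "  ## Indented",              # the pattern is anchored at column 0
    "Plain text with a # in it",
]
for line in samples:
    print(repr(line), "->", parse_header(line))
```

Only the first two samples are treated as headers. Note that the pattern does not know about fenced code, so a `# comment` line at column 0 inside a code sample would also start a new section.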

## 4. Custom Callbacks for Monitoring

Subclass `BaseCallbackHandler` to time or log LlamaIndex events such as LLM calls, embedding, and retrieval:

In [None]:
from llama_index.core.callbacks import CallbackManager, CBEventType
from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from typing import Any, Dict, List, Optional
import time

class PerformanceCallbackHandler(BaseCallbackHandler):
    """Track performance metrics for LlamaIndex operations."""
    
    def __init__(self):
        super().__init__(
            event_starts_to_ignore=[],
            event_ends_to_ignore=[],
        )
        self.event_times: Dict[str, list] = {}
        self._start_times: Dict[str, float] = {}
    
    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        self._start_times[event_id] = time.perf_counter()
        return event_id
    
    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        if event_id in self._start_times:
            # perf_counter() is monotonic, so elapsed time is unaffected by system clock changes
            elapsed = time.perf_counter() - self._start_times[event_id]
            event_name = event_type.value
            
            if event_name not in self.event_times:
                self.event_times[event_name] = []
            self.event_times[event_name].append(elapsed)
            
            del self._start_times[event_id]
    
    def start_trace(self, trace_id: Optional[str] = None) -> None:
        pass
    
    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        pass
    
    def get_stats(self) -> Dict[str, Dict[str, float]]:
        """Get statistics for all event types."""
        stats = {}
        for event_name, times in self.event_times.items():
            if times:
                stats[event_name] = {
                    "count": len(times),
                    "total_time": sum(times),
                    "avg_time": sum(times) / len(times),
                    "min_time": min(times),
                    "max_time": max(times),
                }
        return stats

print("✓ PerformanceCallbackHandler defined!")

In [None]:
# Test with callback
perf_handler = PerformanceCallbackHandler()
callback_manager = CallbackManager([perf_handler])

# Configure settings with callback
Settings.llm = OpenAI(model="gpt-4o-mini")
Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
Settings.callback_manager = callback_manager

# Load documents and create index with monitoring
documents = SimpleDirectoryReader("../data/sample_docs").load_data()
index = VectorStoreIndex.from_documents(documents, show_progress=True)

print("\n✓ Index created with monitoring!")

In [None]:
# Run some queries to gather metrics
query_engine = index.as_query_engine()

queries = [
    "What is machine learning?",
    "Explain neural networks.",
    "How does Python work?",
]

print("Running queries...")
for q in queries:
    response = query_engine.query(q)
    print(f"Q: {q[:30]}... ✓")

In [None]:
# View performance statistics
print("\n=== Performance Statistics ===")
stats = perf_handler.get_stats()

for event_type, metrics in stats.items():
    print(f"\n{event_type}:")
    for metric, value in metrics.items():
        if isinstance(value, float):
            print(f"  {metric}: {value:.4f}s")
        else:
            print(f"  {metric}: {value}")
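
The handler works by pairing each end event with its start event through `event_id`, then grouping durations by event name. The sketch below isolates that pairing in plain Python. `EventTimer` is a hypothetical helper, and the injected fake clock exists only to make the example deterministic.

```python
import time
from collections import defaultdict
from typing import Callable, Dict, List


class EventTimer:
    """Hypothetical helper: pair start/end calls by event id, group durations by name."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter) -> None:
        self._clock = clock
        self._started: Dict[str, float] = {}
        self.durations: Dict[str, List[float]] = defaultdict(list)

    def start(self, event_id: str) -> None:
        self._started[event_id] = self._clock()

    def end(self, event_name: str, event_id: str) -> None:
        started_at = self._started.pop(event_id, None)
        if started_at is None:
            return  # end without a matching start: ignore it
        self.durations[event_name].append(self._clock() - started_at)


# Fake clock with fixed readings; real code would keep the perf_counter default.
ticks = iter([0.0, 1.0, 1.5, 4.0])
timer = EventTimer(clock=lambda: next(ticks))

timer.start("evt-1")         # clock reads 0.0
timer.start("evt-2")         # clock reads 1.0
timer.end("llm", "evt-2")    # clock reads 1.5, duration 0.5
timer.end("llm", "evt-1")    # clock reads 4.0, duration 4.0
timer.end("llm", "evt-404")  # never started, ignored

print(dict(timer.durations))  # {'llm': [0.5, 4.0]}
```

Keying start times by `event_id` rather than by event name is what lets overlapping or nested events of the same type be timed independently.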

## 5. Custom Prompt Templates

Create and use custom prompt templates:

In [None]:
from llama_index.core import PromptTemplate
from llama_index.core.prompts import PromptType

# Custom QA prompt with structured output
STRUCTURED_QA_PROMPT = PromptTemplate(
    """You are an expert assistant. Answer the question using the context provided.

CONTEXT:
{context_str}

QUESTION: {query_str}

Provide your answer in the following format:

ANSWER: [Your direct answer]

CONFIDENCE: [High/Medium/Low]

REASONING: [Brief explanation of why you gave this answer]

SOURCES: [Which parts of the context you used]
""",
    prompt_type=PromptType.QUESTION_ANSWER,
)

print("✓ Custom prompt template defined!")

In [None]:
# Use custom prompt
query_engine_custom = index.as_query_engine(
    text_qa_template=STRUCTURED_QA_PROMPT,
)

response = query_engine_custom.query("What is supervised learning?")
print(response)
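
Because the prompt asks for labelled fields, the reply can be split back into a dictionary. The sketch below is one way to do that, assuming the model actually followed the format; `parse_structured_answer` is a hypothetical helper, and `sample` is hand-written text for illustration, not real model output.

```python
import re
from typing import Dict

FIELDS = ("ANSWER", "CONFIDENCE", "REASONING", "SOURCES")
FIELD_PATTERN = re.compile(r"^(" + "|".join(FIELDS) + r"):[ \t]*", re.MULTILINE)


def parse_structured_answer(text: str) -> Dict[str, str]:
    """Split a reply that follows the ANSWER/CONFIDENCE/REASONING/SOURCES layout."""
    matches = list(FIELD_PATTERN.finditer(text))
    parsed: Dict[str, str] = {}
    for index, match in enumerate(matches):
        # A field's value runs until the next label, or to the end of the text.
        end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
        parsed[match.group(1).lower()] = text[match.end():end].strip()
    return parsed


# Hand-written sample in the requested format (not real model output).
sample = """ANSWER: Supervised learning trains models on labeled data.

CONFIDENCE: High

REASONING: The context defines it directly
and lists classification as an example.

SOURCES: The 'Supervised Learning' section
"""

result = parse_structured_answer(sample)
print(result["confidence"])  # High
```

With the query engine above, the input would be `str(response)`. Fields the model omitted are simply absent from the result, so callers should use `result.get(...)` rather than assume all four keys exist.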

## 6. Summary

You've learned how to create custom components in LlamaIndex.

### Key Takeaways

| Component | Base Class | Key Methods |
|-----------|------------|-------------|
| **Custom LLM** | `CustomLLM` | `complete()`, `stream_complete()` |
| **Custom Embedding** | `BaseEmbedding` | `_get_text_embedding()` |
| **Custom Parser** | `NodeParser` | `_parse_nodes()` |
| **Custom Callback** | `BaseCallbackHandler` | `on_event_start()`, `on_event_end()` |
| **Custom Prompt** | `PromptTemplate` (used directly) | passed as `text_qa_template` |

### When to Build Custom Components

1. **Custom LLM**: Wrap proprietary APIs, add caching/logging
2. **Custom Embedding**: Use specialized embedding models
3. **Custom Parser**: Domain-specific document structures
4. **Custom Callbacks**: Monitoring, debugging, analytics
5. **Custom Prompts**: Structured or domain-specific answer formats

### Next Steps

In the Specialty section, we'll explore multimodal RAG, GraphRAG, and LlamaCloud.

---

## Exercises

1. **Rate-limited LLM**: Create an LLM wrapper with rate limiting

2. **Semantic chunker**: Build a parser that uses embeddings to find semantic boundaries

3. **Cost tracker**: Create a callback that estimates API costs

4. **Hybrid embedding**: Combine multiple embedding models

In [None]:
# Exercise space
# Build your own custom components here!