# Ember Quickstart: From Zero to Advanced AI Systems

Welcome to Ember - a principled framework for building production AI systems. This notebook takes you through all key concepts, from basic model calls to advanced JAX integration with learnable components.

**What you'll learn:**
1. Setting up and making your first LLM call
2. Working with the Models API for multi-provider orchestration
3. Efficient data processing with streaming-first design
4. Building composable AI systems with Operators
5. Creating ensemble and judge systems for evaluation
6. Zero-config optimization with XCS
7. Advanced: JAX gradients on mixed structures with learnable routing

**Time to complete:** ~30 minutes

**Prerequisites:** Basic Python knowledge

## Setup and Environment Check

First, let's ensure Ember is properly installed and configured. If you haven't run the setup wizard yet, we'll guide you through it.

In [ ]:
# Check if Ember is installed
try:
    import ember
    print(f"✅ Ember {ember.__version__} is installed")
except ImportError:
    print("❌ Ember not found. Please install with: pip install ember-ai")
    print("\n💡 Tip: After installation, run 'uv run ember setup' for interactive configuration")
    raise

# Check for API keys
import os
api_keys_found = []
for key in ["OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GOOGLE_API_KEY"]:
    if os.getenv(key):
        api_keys_found.append(key.replace("_API_KEY", ""))

if api_keys_found:
    print(f"✅ Found API keys for: {', '.join(api_keys_found)}")
else:
    print("⚠️  No API keys found. Run 'uv run ember setup' to configure")

# Import key modules
from ember.api import models, operators, op, data, stream
from ember.api.xcs import jit
import jax
import jax.numpy as jnp
import optax
import time

print("\n✅ All imports successful! Let's begin.")

## Module 1: Models API - Multi-Provider LLM Orchestration

Ember provides a unified interface for OpenAI, Anthropic, Google, and more. Let's explore available models and key patterns.

In [ ]:
# Discover available models
print("Available providers:")
providers = models.providers()
for provider in providers:
    print(f"  - {provider}")

print("\nAvailable models:")
all_models = models.list()

# Group by provider
from collections import defaultdict
by_provider = defaultdict(list)
for model in all_models:
    if model.startswith('gpt'):
        by_provider['openai'].append(model)
    elif model.startswith('claude'):
        by_provider['anthropic'].append(model)
    elif model.startswith('gemini'):
        by_provider['google'].append(model)

# Use a distinct loop variable so the imported `models` API is not shadowed
for provider, model_names in sorted(by_provider.items()):
    print(f"\n{provider.title()} ({len(model_names)} models):")
    for model in sorted(model_names)[:5]:  # Show first 5
        print(f"  - {model}")
    if len(model_names) > 5:
        print(f"  ... and {len(model_names) - 5} more")

In [ ]:
# Pattern 1: Direct invocation - perfect for one-off calls
response = models("gpt-4o", "What is machine learning in one sentence?")
print("Direct call response:")
print(f"  Text: {response.text}")
print(f"  Cost: ${response.usage['cost']:.4f}")
print(f"  Tokens: {response.usage['total_tokens']}")

In [ ]:
# Pattern 2: Model binding - reusable configurations
creative_writer = models.instance(
    "claude-3-5-sonnet-20241022",
    temperature=0.8,
    system="You are a creative writer who uses vivid metaphors."
)

technical_writer = models.instance(
    "gpt-4o",
    temperature=0.2,
    system="You are a technical writer who values precision and clarity."
)

prompt = "Explain how a neural network works"

print("Creative explanation:")
print(creative_writer(prompt).text[:200] + "...\n")

print("Technical explanation:")
print(technical_writer(prompt).text[:200] + "...")

In [ ]:
# Pattern 3: Cost tracking and usage monitoring
# Make a few calls to accumulate usage
for i in range(3):
    models("gpt-3.5-turbo", f"Generate a random fact #{i}")

# Get usage summary
from ember.models.registry import get_global_registry
registry = get_global_registry()
summary = registry.get_usage_summary("gpt-3.5-turbo")

print("Usage Summary:")
print(f"  Total calls: {summary.call_count}")
print(f"  Total tokens: {summary.total_tokens}")
print(f"  Total cost: ${summary.cost_usd:.4f}")
print(f"  Avg tokens/call: {summary.avg_tokens_per_call:.1f}")
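
As a rough picture of what a usage summary like the one above aggregates, here is a minimal sketch of a per-model accumulator. It is not Ember's registry implementation: the `UsageTotals` class and its `record` method are hypothetical, and only the field names are borrowed from the summary attributes printed above.

```python
from dataclasses import dataclass


@dataclass
class UsageTotals:
    """Hypothetical running totals for one model (illustration only)."""
    call_count: int = 0
    total_tokens: int = 0
    cost_usd: float = 0.0

    def record(self, tokens: int, cost_usd: float) -> None:
        """Add one call's token count and cost to the running totals."""
        self.call_count += 1
        self.total_tokens += tokens
        self.cost_usd += cost_usd

    @property
    def avg_tokens_per_call(self) -> float:
        # Guard against division by zero before any call is recorded
        return self.total_tokens / self.call_count if self.call_count else 0.0


# Made-up (tokens, cost) pairs standing in for three model calls
totals = UsageTotals()
for tokens, cost in [(120, 0.0002), (80, 0.0001), (100, 0.0002)]:
    totals.record(tokens, cost)

print(f"Total calls: {totals.call_count}")
print(f"Avg tokens/call: {totals.avg_tokens_per_call:.1f}")
```

Keeping the average as a derived property rather than a stored field means it cannot drift out of sync with the counters.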

## Module 2: Data API - Streaming-First Processing

Ember's data module is designed for memory-efficient processing of large datasets. Let's explore streaming patterns.

In [ ]:
# Stream data without loading entire dataset into memory
print("Streaming first 5 MMLU questions:\n")

for i, item in enumerate(stream("mmlu", subset="elementary_mathematics")):
    if i >= 5:
        break
    print(f"Q{i+1}: {item['question']}")
    print(f"A: {item['answer']}")
    if 'choices' in item:
        print(f"Choices: {list(item['choices'].values())}")
    print()

In [ ]:
# Advanced: Build a data processing pipeline
def add_difficulty_score(item):
    """Add a difficulty score based on question length."""
    item["difficulty_score"] = len(item["question"]) / 100.0
    return item

# Chain operations for efficient processing
hard_questions = (
    stream("mmlu", subset="elementary_mathematics")
    .transform(add_difficulty_score)
    .filter(lambda x: x["difficulty_score"] > 0.5)
    .first(3)
)

print("Hard questions (based on length):")
for q in hard_questions:
    print(f"- {q['question'][:100]}...")
    print(f"  Difficulty: {q['difficulty_score']:.2f}\n")
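
To make the streaming-first idea concrete without an API key or dataset download, the sketch below rebuilds the same transform, filter, take-first-three shape with plain Python generators. It is an illustration of lazy evaluation, not Ember's `stream` implementation; `fake_source`, `transform`, and `keep_if` are hypothetical names introduced here.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator


def transform(items: Iterable[dict], fn: Callable[[dict], dict]) -> Iterator[dict]:
    """Lazily apply fn to each item as it is requested."""
    for item in items:
        yield fn(item)


def keep_if(items: Iterable[dict], predicate: Callable[[dict], bool]) -> Iterator[dict]:
    """Lazily yield only the items for which predicate is true."""
    for item in items:
        if predicate(item):
            yield item


def fake_source(n: int, seen: list) -> Iterator[dict]:
    """Hypothetical data source that records every index it actually produces."""
    for i in range(n):
        seen.append(i)
        yield {"question": "x" * (i * 20)}


seen = []
pipeline = keep_if(
    transform(
        fake_source(1_000_000, seen),
        lambda item: {**item, "difficulty_score": len(item["question"]) / 100.0},
    ),
    lambda item: item["difficulty_score"] > 0.5,
)

# islice stops pulling from the pipeline once three items have been produced
first_three = list(islice(pipeline, 3))
print(f"Items read from the source: {len(seen)}")
```

Although the source could yield a million items, only as many are read as are needed to find three that pass the filter, which is the property that keeps memory use flat on large datasets.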

## Module 3: Operators - Composable AI Components

Operators are the building blocks for complex AI systems. They can be simple functions or sophisticated components.

In [ ]:
# Level 1: Simple function operator
@op
def sentiment_analyzer(text: str) -> dict:
    """Analyze sentiment of text."""
    prompt = f"Analyze the sentiment of this text. Return ONLY 'positive', 'negative', or 'neutral': {text}"
    sentiment = models("gpt-3.5-turbo", prompt).text.strip().lower()
    return {"text": text, "sentiment": sentiment}

# Test it
result = sentiment_analyzer("I love this new framework! It makes AI development so much easier.")
print(f"Sentiment: {result['sentiment']}")

In [ ]:
# Level 2: Class-based operator with state
from ember.operators import Operator

class Summarizer(Operator):
    def __init__(self, style="concise"):
        self.style = style
        self.model = models.instance(
            "gpt-3.5-turbo",
            system=f"You are a {style} summarizer. Keep summaries brief."
        )
    
    def forward(self, text: str) -> str:
        prompt = f"Summarize this text in a {self.style} way: {text}"
        return self.model(prompt).text

# Create different summarizers
technical_summary = Summarizer("technical")
simple_summary = Summarizer("simple")

text = """Neural networks are computational models inspired by biological neural networks. 
They consist of interconnected nodes (neurons) organized in layers. Each connection has 
a weight that is adjusted during training through backpropagation."""

print("Technical summary:")
print(technical_summary(text))
print("\nSimple summary:")
print(simple_summary(text))

In [ ]:
# Composition patterns with different models
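import ast
from ember.operators import chain, ensemble

# Chain operators sequentially - each optimized for its task
@op
def extract_key_points(text: str) -> list:
    # GPT-4o for structured extraction
    prompt = f"Extract 3 key points from this text as a Python list: {text}"
    response = models("gpt-4o", prompt).text.strip()
    # Simple parsing - in production use structured output.
    # ast.literal_eval only accepts Python literals, so unlike eval it
    # cannot execute arbitrary code returned by the model.
    try:
        points = ast.literal_eval(response)
    except (ValueError, SyntaxError):
        return [response]
    return points if isinstance(points, list) else [response]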

@op 
def expand_points(points: list) -> str:
    # Claude for creative expansion
    prompt = f"Expand these points into a coherent, engaging paragraph: {points}"
    return models("claude-3-5-sonnet-20241022", prompt).text

@op
def translate_to_spanish(text: str) -> str:
    # Gemini for translation tasks
    prompt = f"Translate to Spanish: {text}"
    return models("gemini-pro", prompt).text

# Create a multi-model pipeline
analysis_pipeline = chain(
    extract_key_points,    # GPT-4o extracts structure
    expand_points,         # Claude expands creatively
    sentiment_analyzer     # GPT-3.5 Turbo analyzes sentiment
)

# For translation pipeline
translation_pipeline = chain(
    extract_key_points,    # GPT-4o extracts key points
    expand_points,         # Claude expands
    translate_to_spanish   # Gemini translates
)

result = analysis_pipeline(text)
print("Multi-model pipeline result:")
print(f"Final sentiment of expanded summary: {result['sentiment']}")

print("\n" + "="*50 + "\n")

# Show translation pipeline
translation_result = translation_pipeline(text)
print("Translation pipeline result (Spanish):")
print(translation_result[:200] + "...")
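
Sequential composition of this kind amounts to feeding each step's output into the next step. The sketch below shows that idea in plain Python; `compose_chain` is a hypothetical helper written for illustration, not Ember's `chain`, and the two steps are stand-ins that need no API key.

```python
from functools import reduce
from typing import Any, Callable


def compose_chain(*steps: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Return a function that feeds each step's output into the next step."""
    def run(value: Any) -> Any:
        return reduce(lambda acc, step: step(acc), steps, value)
    return run


def split_sentences(text: str) -> list:
    """Stand-in for an extraction step: split text into non-empty sentences."""
    return [s.strip() for s in text.split(".") if s.strip()]


def count_items(items: list) -> dict:
    """Stand-in for an analysis step: report how many items arrived."""
    return {"count": len(items)}


pipeline = compose_chain(split_sentences, count_items)
result = pipeline("Neural networks have layers. Weights are learned.")
print(result)
```

The only contract between steps is that one step's return type matches the next step's input type, which is why `expand_points` (list in, string out) can sit between `extract_key_points` and `sentiment_analyzer` in the pipelines above.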

## Module 4: Ensemble & Judge System for MMLU Evaluation

Let's build a sophisticated evaluation system using ensembles of experts and judges.

In [ ]:
# Create specialized experts for different subjects
math_expert = models.instance(
    "gpt-4o",
    system="You are a mathematics professor. Answer questions precisely with step-by-step reasoning."
)

science_expert = models.instance(
    "claude-3-sonnet", 
    system="You are a science teacher. Explain concepts clearly using examples."
)

general_expert = models.instance(
    "gpt-3.5-turbo",
    system="You are a knowledgeable assistant. Provide accurate, well-reasoned answers."
)

# Function to format MMLU questions
def format_mmlu_question(item):
    """Format MMLU item for expert evaluation."""
    question = item["question"]
    if "choices" in item:
        choices_text = "\n".join([f"{k}: {v}" for k, v in item["choices"].items()])
        return f"{question}\n\nChoices:\n{choices_text}\n\nAnswer with the letter only."
    return question

In [ ]:
# Create a multi-provider ensemble evaluation system
def ensemble_evaluate(question_item, experts_dict):
    """Use multiple experts from different providers."""
    formatted_q = format_mmlu_question(question_item)
    
    # Get answers from all experts
    expert_answers = []
    for name, expert in experts_dict.items():
        answer = expert(formatted_q).text.strip()
        # Extract model info from the bound instance
        model_id = expert._model if hasattr(expert, '_model') else name
        expert_answers.append({
            "expert_name": name,
            "model": model_id,
            "answer": answer,
            "confidence": 0.7 + (len(expert_answers) * 0.05)  # Simulated confidence
        })
    
    return expert_answers

# Create a judge using Claude (excellent at analysis and comparison)
def judge_answers(question_item, expert_answers):
    """Use Claude as judge to select the best answer from multiple providers."""
    judge_prompt = f"""Given this question: {question_item['question']}

Expert answers from different AI providers:
{chr(10).join([f"{a['expert_name']} ({a['model']}): {a['answer']}" for a in expert_answers])}

Select the best answer based on:
1. Correctness and accuracy
2. Clarity of reasoning
3. Appropriateness to the question type

Return ONLY the expert name (math_expert, science_expert, general_expert, or speed_expert)."""
    
    # Claude is excellent at following complex instructions
    judge = models.instance("claude-3-haiku", temperature=0.1)
    judge_response = judge(judge_prompt).text.strip()
    
    # Find the selected answer
    for ans in expert_answers:
        if ans["expert_name"] in judge_response:
            return ans["answer"], ans["expert_name"], ans["model"]
    
    # Fallback to highest confidence
    best = max(expert_answers, key=lambda x: x["confidence"])
    return best["answer"], best["expert_name"], best["model"]

# Test on a sample MMLU question
sample_question = next(iter(stream("mmlu", subset="elementary_mathematics")))
print(f"Question: {sample_question['question']}")
print(f"Correct answer: {sample_question['answer']}\n")

# Run multi-provider ensemble evaluation
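# A fast, inexpensive expert for quick responses
speed_expert = models.instance(
    "gpt-4o-mini",
    system="You are a fast assistant. Answer concisely."
)

experts = {
    "math_expert": math_expert,
    "science_expert": science_expert,
    "general_expert": general_expert,
    "speed_expert": speed_expert
}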

expert_answers = ensemble_evaluate(sample_question, experts)

print("Expert answers from different providers:")
for ans in expert_answers:
    print(f"  {ans['expert_name']} ({ans['model']}): {ans['answer']}")

# Judge selects best answer
final_answer, selected_expert, selected_model = judge_answers(sample_question, expert_answers)
print(f"\nJudge (Claude) selected: {selected_expert} ({selected_model})")
print(f"Final answer: {final_answer}")
print(f"Correct: {final_answer == sample_question['answer']}")

print("\nThis demonstrates a multi-provider ensemble:")
print("- OpenAI (GPT-4o) for mathematical reasoning")
print("- Anthropic (Claude 3 Sonnet) for scientific explanation")
print("- OpenAI (GPT-3.5 Turbo) for general knowledge")
print("- OpenAI (GPT-4o-mini) for quick responses")
print("- Anthropic (Claude 3 Haiku) as judge")
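
An LLM judge is one way to combine expert answers. A cheaper baseline worth comparing against is a plain majority vote over the answer letters. The sketch below swaps in that technique; it assumes the choices are labelled A to D (the notebook does not state the label set), and `extract_choice` and `majority_vote` are hypothetical helpers.

```python
import re
from collections import Counter
from typing import Optional


def extract_choice(answer: str) -> Optional[str]:
    """Return the first standalone letter A-D in an answer, or None.

    Limitation: a leading English article "A" would be read as choice A,
    so this works best when experts are told to answer with the letter only.
    """
    match = re.search(r"\b([A-D])\b", answer.strip().upper())
    return match.group(1) if match else None


def majority_vote(answers: list) -> Optional[str]:
    """Pick the most common extracted letter; None if no answer had one."""
    votes = Counter(c for c in map(extract_choice, answers) if c is not None)
    return votes.most_common(1)[0][0] if votes else None


# Made-up expert replies for illustration
sample_answers = ["B", "The answer is B.", "C) 12", "I am not sure"]
print(majority_vote(sample_answers))
```

Normalising to a single letter before comparing also makes an exact-match check against the reference answer more reliable than comparing raw model text.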

## Module 5: XCS - Zero-Config Optimization

XCS automatically optimizes your Python functions by detecting parallelization opportunities. Let's see it in action.

In [ ]:
# Simulate slow operations that can be parallelized
def slow_operation(x, delay=0.1):
    """Simulate a slow operation."""
    time.sleep(delay)
    return x * 2

# Without optimization
def process_sequential(items):
    results = []
    for item in items:
        # These operations are independent
        r1 = slow_operation(item, 0.05)
        r2 = slow_operation(item * 2, 0.05)
        r3 = slow_operation(item * 3, 0.05)
        results.append(r1 + r2 + r3)
    return results

# With XCS optimization
@jit
def process_optimized(items):
    results = []
    for item in items:
        # XCS detects these can run in parallel!
        r1 = slow_operation(item, 0.05)
        r2 = slow_operation(item * 2, 0.05) 
        r3 = slow_operation(item * 3, 0.05)
        results.append(r1 + r2 + r3)
    return results

# Compare performance
test_items = [1, 2, 3, 4, 5]

print("Sequential execution:")
start = time.time()
result_seq = process_sequential(test_items)
seq_time = time.time() - start
print(f"Time: {seq_time:.2f}s")

print("\nOptimized execution (first run includes tracing):")
start = time.time()
result_opt = process_optimized(test_items)
opt_time_first = time.time() - start
print(f"Time: {opt_time_first:.2f}s")

print("\nOptimized execution (subsequent runs):")
start = time.time()
result_opt = process_optimized(test_items)
opt_time = time.time() - start
print(f"Time: {opt_time:.2f}s")
print(f"\nSpeedup: {seq_time/opt_time:.1f}x")

# Show optimization details
if hasattr(process_optimized, '_xcs_explain'):
    print("\nOptimization explanation:")
    print(process_optimized._xcs_explain())
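
For comparison, this is what running the three independent calls concurrently looks like when written by hand with the standard library. It is only a sketch of the effect being described, not a description of how XCS works internally; `slow_double` and `process_item_parallel` are hypothetical names.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_double(x, delay=0.05):
    """Simulate a slow, I/O-bound operation."""
    time.sleep(delay)
    return x * 2


def process_item_parallel(item, pool):
    """Submit the three independent operations at once, then combine them."""
    futures = [pool.submit(slow_double, item * k) for k in (1, 2, 3)]
    return sum(f.result() for f in futures)


with ThreadPoolExecutor(max_workers=3) as pool:
    start = time.perf_counter()
    parallel_results = [process_item_parallel(item, pool) for item in [1, 2, 3]]
    elapsed = time.perf_counter() - start

print(f"Results: {parallel_results}")
print(f"Elapsed: {elapsed:.2f}s")
```

Threads help here because the work is waiting (sleep, or a network call to an LLM provider) rather than CPU-bound computation. The manual version needs explicit futures and a pool, which is the boilerplate the `@jit` decorator above is meant to remove.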

In [ ]:
# Real-world example: Parallel LLM calls
@jit
def analyze_text_parallel(text):
    """Analyze text from multiple perspectives in parallel."""
    # These LLM calls are independent and will run in parallel
    sentiment = models("gpt-3.5-turbo", f"Sentiment of: {text}").text
    summary = models("gpt-3.5-turbo", f"Summarize: {text}").text  
    keywords = models("gpt-3.5-turbo", f"Extract keywords from: {text}").text
    
    return {
        "sentiment": sentiment,
        "summary": summary,
        "keywords": keywords
    }

# Test with real text
sample_text = """Ember is a revolutionary AI framework that makes building 
production AI systems simple and efficient. It provides clean APIs, 
automatic optimization, and seamless multi-provider support."""

print("Analyzing text with parallel LLM calls...")
start = time.time()
analysis = analyze_text_parallel(sample_text)
print(f"\nCompleted in {time.time() - start:.2f}s")
print(f"\nResults:")
for key, value in analysis.items():
    print(f"{key}: {value[:100]}..." if len(value) > 100 else f"{key}: {value}")

## Module 6: Advanced - JAX Integration with Mixed Structures

Ember's deep JAX integration enables gradient-based optimization of AI systems with mixed static/dynamic components.

In [ ]:
# Create a learnable router with mixed structure
from ember.operators import Operator

class LearnableRouter(Operator):
    """Routes prompts to different models based on learned embeddings."""
    
    def __init__(self, routes: dict, embed_dim: int = 32, key: jax.Array = None):
        # Static fields (not learnable)
        self.routes = routes  # Dict of model names to model instances
        self.route_names = list(routes.keys())
        
        # Dynamic fields (learnable via gradients)
        if key is None:
            key = jax.random.PRNGKey(42)
        k1, k2 = jax.random.split(key)
        
        self.route_embeddings = jax.random.normal(k1, (len(routes), embed_dim))
        self.projection = jax.random.normal(k2, (embed_dim, 1)) 
        self.temperature = jnp.array(1.0)
    
    def compute_routing_scores(self, prompt: str) -> jnp.ndarray:
        """Compute routing scores for each model."""
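        # Simple: use prompt length as the feature (in practice, use embeddings).
        # The last component is a constant bias; without it the scores would only
        # scale with length and every prompt would select the same route.
        dim = self.route_embeddings.shape[1]
        length_feature = jnp.full((dim - 1,), len(prompt) / 100.0)
        prompt_feature = jnp.concatenate([length_feature, jnp.ones((1,))])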
        
        # Compute similarity scores
        similarities = self.route_embeddings @ prompt_feature
        scores = jax.nn.softmax(similarities / self.temperature)
        return scores
    
    def forward(self, prompt: str) -> dict:
        """Route prompt to best model and return response."""
        scores = self.compute_routing_scores(prompt)
        best_idx = int(jnp.argmax(scores))  # plain int for list indexing
        
        # Use the selected model
        selected_route = self.route_names[best_idx]
        response = self.routes[selected_route](prompt)
        
        return {
            "response": response.text,
            "selected_model": selected_route,
            "scores": scores,
            "confidence": float(jnp.max(scores))
        }

# Create router with different models
simple_model = models.instance("gpt-3.5-turbo", temperature=0.3)
creative_model = models.instance("claude-3-sonnet", temperature=0.9) 

router = LearnableRouter({
    "simple": simple_model,
    "creative": creative_model
})

# Test routing
test_prompts = [
    "What is 2+2?",
    "Write a poem about the ocean"
]

for prompt in test_prompts:
    result = router(prompt)
    print(f"\nPrompt: {prompt}")
    print(f"Selected: {result['selected_model']} (confidence: {result['confidence']:.2f})")
    print(f"Response: {result['response'][:100]}...")

In [ ]:
# Train the router using gradients
def routing_loss(router, prompt: str, target_route: str) -> float:
    """Compute loss for routing decision."""
    scores = router.compute_routing_scores(prompt)
    target_idx = router.route_names.index(target_route)
    
    # Cross-entropy loss
    return -jnp.log(scores[target_idx] + 1e-8)

# Training data: short prompts -> simple, long prompts -> creative
training_data = [
    ("What is 2+2?", "simple"),
    ("Calculate 10*5", "simple"),
    ("Define AI", "simple"),
    ("Write a story about a magical forest where time flows backwards", "creative"),
    ("Compose a sonnet about the beauty of mathematics", "creative"),
    ("Create a metaphor for machine learning using cooking", "creative")
]

# Setup optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(router)

# Training loop
print("Training router...")
for epoch in range(50):
    total_loss = 0.0
    
    for prompt, target in training_data:
        # Compute loss and gradients in a single pass
        loss_val, grads = jax.value_and_grad(routing_loss)(router, prompt, target)
        
        # Update only dynamic parameters
        updates, opt_state = optimizer.update(grads, opt_state)
        
        # Apply updates using functional update pattern
        router = router.update_params(
            route_embeddings=router.route_embeddings + updates.route_embeddings,
            projection=router.projection + updates.projection,
            temperature=router.temperature + updates.temperature
        )
        
        total_loss += float(loss_val)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

print("\nTesting trained router:")
test_cases = [
    "What is 5+5?",
    "Imagine a world where colors have sounds"
]

for prompt in test_cases:
    result = router(prompt)
    print(f"\nPrompt: {prompt}")
    print(f"Routed to: {result['selected_model']} (confidence: {result['confidence']:.2f})")

In [ ]:
# Demonstrate gradient flow through mixed structures
print("Analyzing gradient flow through mixed structure:\n")

# Get gradients for a sample
sample_prompt = "Write a haiku"
grads = jax.grad(routing_loss)(router, sample_prompt, "creative")

# Check which fields received gradients
print("Fields with gradients (dynamic/learnable):")
for field in ['route_embeddings', 'projection', 'temperature']:
    if hasattr(grads, field) and getattr(grads, field) is not None:
        grad_val = getattr(grads, field)
        print(f"  - {field}: shape {grad_val.shape}, norm {jnp.linalg.norm(grad_val):.4f}")

print("\nStatic fields (preserved but no gradients):")
for field in ['routes', 'route_names']:
    if hasattr(grads, field):
        print(f"  - {field}: {type(getattr(grads, field)).__name__}")

# Batch routing decisions across several prompts
print("\nBatched routing on mixed structures:")

# jax.vmap cannot map over Python strings (JAX arrays hold only numeric
# dtypes), so score each prompt separately and stack the results
batch_prompts = ["Hello", "What is AI?", "Tell me a story"]
batched_scores = jnp.stack([router.compute_routing_scores(p) for p in batch_prompts])
print(f"Batched routing scores shape: {batched_scores.shape}")

print("\n✅ Mixed structure gradients working correctly!")
print("\nKey insights:")
print("- JAX arrays are automatically treated as learnable parameters")
print("- Python objects (dicts, strings) are preserved as static")
print("- Gradients flow only through dynamic fields")
print("- Standard optimizers (optax) work out of the box")
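
The static/dynamic split can also be seen in plain JAX, with no Ember classes involved. In the sketch below the route names are ordinary Python data and only the `weights` array is differentiated. The featurizer (scaled prompt length plus a constant bias term), the learning rate, and all names are assumptions made for illustration, not part of Ember.

```python
import jax
import jax.numpy as jnp

ROUTE_NAMES = ("simple", "creative")  # static metadata, never differentiated


def features(prompt: str) -> jnp.ndarray:
    """Hypothetical featurizer: scaled prompt length plus a constant bias."""
    return jnp.array([len(prompt) / 100.0, 1.0])


def loss_fn(weights: jnp.ndarray, feats: jnp.ndarray, target_idx: int) -> jnp.ndarray:
    """Cross-entropy of the target route under softmax(weights @ feats)."""
    log_probs = jax.nn.log_softmax(weights @ feats)
    return -log_probs[target_idx]


data = [
    ("What is 2+2?", "simple"),
    ("Write a story about a magical forest where time flows backwards", "creative"),
]
batch = [(features(p), ROUTE_NAMES.index(r)) for p, r in data]

weights = jnp.zeros((len(ROUTE_NAMES), 2))
# value_and_grad differentiates with respect to the first argument only
value_and_grad = jax.value_and_grad(loss_fn)

losses = []
for step in range(200):
    step_loss = 0.0
    for feats, target_idx in batch:
        loss, grad = value_and_grad(weights, feats, target_idx)
        weights = weights - 0.5 * grad  # plain gradient descent
        step_loss += float(loss)
    losses.append(step_loss)

predicted = [ROUTE_NAMES[int(jnp.argmax(weights @ f))] for f, _ in batch]
print(f"Loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
print(f"Predicted routes: {predicted}")
```

The bias term matters: with a feature that is only a scaled length, every prompt produces the same ranking of routes, so no amount of training could send short and long prompts to different models.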

## Putting It All Together

Let's combine everything we've learned into a complete system for MMLU evaluation with learnable routing.

In [ ]:
# Complete MMLU evaluation system with all components
@jit
def evaluate_mmlu_batch(questions, router, judge_model):
    """Evaluate a batch of MMLU questions using ensemble + judge + routing."""
    results = []
    
    for q in questions:
        # Step 1: Route to appropriate expert ensemble
        routing_result = router(q['question'])
        
        # Step 2: Get multiple expert opinions (parallelized by XCS)
        expert1_answer = models("gpt-3.5-turbo", format_mmlu_question(q)).text
        expert2_answer = models("gpt-3.5-turbo", format_mmlu_question(q)).text
        expert3_answer = models("gpt-3.5-turbo", format_mmlu_question(q)).text
        
        # Step 3: Judge selects best answer
        judge_prompt = f"""Select the best answer:
        Expert 1: {expert1_answer}
        Expert 2: {expert2_answer}  
        Expert 3: {expert3_answer}
        
        Return only the expert number (1, 2, or 3)."""
        
        judge_decision = judge_model(judge_prompt).text.strip()
        
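        # Step 4: Record results
        # The judge may reply with extra words (e.g. "Expert 2"), so take the
        # first valid digit and fall back to expert 1 if none is present.
        choice = next((int(c) for c in judge_decision if c in "123"), 1)
        selected_answer = [expert1_answer, expert2_answer, expert3_answer][choice - 1]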
        results.append({
            'question': q['question'],
            'predicted': selected_answer,
            'correct': q['answer'],
            'routing_confidence': routing_result['confidence']
        })
    
    return results

# Load a small batch of MMLU questions
mmlu_batch = list(stream("mmlu", subset="elementary_mathematics").first(3))

# Create judge model
judge = models.instance("gpt-3.5-turbo", temperature=0.1)

# Run evaluation
print("Running complete MMLU evaluation pipeline...\n")
results = evaluate_mmlu_batch(mmlu_batch, router, judge)

# Calculate accuracy
correct = sum(1 for r in results if r['predicted'].strip() == r['correct'].strip())
accuracy = correct / len(results)

print(f"Results:")
for r in results:
    print(f"Q: {r['question'][:50]}...")
    print(f"   Predicted: {r['predicted']}, Correct: {r['correct']}")
    print(f"   Routing confidence: {r['routing_confidence']:.2f}\n")

print(f"Accuracy: {accuracy:.1%} ({correct}/{len(results)})")

# Show cost summary (only gpt-3.5-turbo usage is queried here)
total_cost = registry.get_usage_summary("gpt-3.5-turbo").cost_usd
print(f"\nTotal gpt-3.5-turbo cost for this session: ${total_cost:.4f}")

## Summary and Next Steps

Congratulations! You've learned the core concepts of Ember:

1. **Models API** - Unified interface for all LLM providers with automatic cost tracking
2. **Data API** - Memory-efficient streaming for large datasets
3. **Operators** - Composable building blocks from simple functions to complex systems
4. **Ensemble & Judge** - Collect answers from several experts and let a judge model select the best one
5. **XCS** - Automatic optimization without configuration
6. **JAX Integration** - Gradients on mixed structures for learnable AI systems

### Next Steps:

1. **Explore more examples** in the `examples/` directory
2. **Read the documentation** for detailed API references
3. **Build your own operators** for your specific use cases
4. **Experiment with different models** and providers
5. **Join the community** to share ideas and get help

### Key Principles to Remember:

- **Simple things should be simple** - One-line model calls
- **Complex things should be possible** - Full JAX integration
- **Zero configuration** - Works out of the box
- **Performance by default** - Automatic optimization
- **Clean architecture** - Principled design throughout

Happy building with Ember!