# Guide to Subclassing LUMEModel

This guide shows how to subclass the `LUMEModel` abstract base class to build virtual accelerator models and digital twins.

## Overview

The `LUMEModel` class provides a standardized interface for:
- Getting measurements/state from simulators
- Setting control parameters
- Resetting simulator state
- Variable validation and type safety

## Key Concepts

1. **Abstract Methods**: Must be implemented by subclasses
   - `_get()`: Internal method to retrieve variable values
   - `_set()`: Internal method to set variables and run simulation
   - `reset()`: Reset simulator to initial state
   - `supported_variables` property: Define available variables

2. **Variable Types**: Define inputs/outputs with validation
   - `ScalarVariable`: For float values with ranges and units
   - Custom variables: For specialized validation

3. **Workflow**: The model handles validation, then calls your internal methods
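The split between public validation and your internal methods is an instance of the template-method pattern. The sketch below imitates it with a hypothetical `MiniModel` base class. It is not part of `lume` and is simplified to plain-dict variable specs. It rejects unknown and read-only names before delegating to `_get()`/`_set()`. The real `LUMEModel` checks may differ in detail.

```python
from abc import ABC, abstractmethod
from typing import Any


class MiniModel(ABC):
    """Hypothetical stand-in for a LUMEModel-like base class (not the real API)."""

    @property
    @abstractmethod
    def supported_variables(self) -> dict[str, dict[str, Any]]:
        """Map variable names to specs; here a spec is a plain dict with 'read_only'."""

    @abstractmethod
    def _get(self, names: list[str]) -> dict[str, Any]: ...

    @abstractmethod
    def _set(self, values: dict[str, Any]) -> None: ...

    def get(self, names: list[str]) -> dict[str, Any]:
        # Public entry point: validate names, then delegate to the subclass
        unknown = [n for n in names if n not in self.supported_variables]
        if unknown:
            raise ValueError(f"Unknown variables: {unknown}")
        return self._get(names)

    def set(self, values: dict[str, Any]) -> None:
        # Public entry point: validate names and writability, then delegate
        for name in values:
            spec = self.supported_variables.get(name)
            if spec is None:
                raise ValueError(f"Unknown variable: {name}")
            if spec.get("read_only", False):
                raise ValueError(f"Variable '{name}' is read-only")
        self._set(values)


class Doubler(MiniModel):
    """Toy subclass: output y = 2 * x."""

    def __init__(self) -> None:
        self._state = {"x": 0.0, "y": 0.0}
        self._vars = {"x": {"read_only": False}, "y": {"read_only": True}}

    @property
    def supported_variables(self) -> dict[str, dict[str, Any]]:
        return self._vars

    def _get(self, names: list[str]) -> dict[str, Any]:
        return {n: self._state[n] for n in names}

    def _set(self, values: dict[str, Any]) -> None:
        self._state.update(values)
        self._state["y"] = 2 * self._state["x"]


model = Doubler()
model.set({"x": 3.0})
print(model.get(["y"]))  # {'y': 6.0}
```

The subclass never checks names or writability itself; it only stores inputs and recomputes outputs.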

In [None]:
# Import required modules
from typing import Any

import yaml

from lume.model import LUMEModel
from lume.variables import ScalarVariable

## Example 1: Simple Mathematical Model

Let's start with a basic model that adds two inputs. This example demonstrates:
- Defining input and output variables
- Implementing basic calculations
- State management and reset functionality

In [None]:
class SimpleMathModel(LUMEModel):
    """
    A simple mathematical model that demonstrates basic LUMEModel implementation.

    This model computes:
    - sum_output = input_a + input_b
    """

    def __init__(self):
        """Initialize the model with default state."""
        # Define initial values
        self._initial_state = {"input_a": 1.0, "input_b": 1.0, "sum_output": 2.0}
        # Current state (will be modified during simulation)
        self._state = self._initial_state.copy()

        # Define supported variables
        self._variables = {
            "input_a": ScalarVariable(
                name="input_a",
                default_value=1.0,
                value_range=(-10.0, 10.0),
                unit="dimensionless",
                read_only=False,
            ),
            "input_b": ScalarVariable(
                name="input_b",
                default_value=1.0,
                value_range=(-10.0, 10.0),
                unit="dimensionless",
                read_only=False,
            ),
            "sum_output": ScalarVariable(
                name="sum_output",
                default_value=2.0,
                unit="dimensionless",
                read_only=True,  # This is computed, not set directly
            ),
        }

    @property
    def supported_variables(self) -> dict[str, ScalarVariable]:
        """Return the dictionary of supported variables."""
        return self._variables

    def _get(self, names: list[str]) -> dict[str, Any]:
        """
        Internal method to retrieve current values for specified variables.

        Parameters
        ----------
        names : list[str]
            List of variable names to retrieve

        Returns
        -------
        dict[str, Any]
            Dictionary mapping variable names to their current values
        """
        return {name: self._state[name] for name in names}

    def _set(self, values: dict[str, Any]) -> None:
        """
        Internal method to set input variables and compute outputs.

        This method:
        1. Updates input variables in the state
        2. Performs calculations to update output variables
        3. Stores results in the state

        Parameters
        ----------
        values : dict[str, Any]
            Dictionary of variable names and values to set
        """
        # Update input values in state
        for name, value in values.items():
            self._state[name] = value

        # Perform calculations to update outputs
        input_a = self._state["input_a"]
        input_b = self._state["input_b"]

        # Calculate outputs
        self._state["sum_output"] = input_a + input_b

    def reset(self) -> None:
        """Reset the model to its initial state."""
        self._state = self._initial_state.copy()

### Using the SimpleMathModel

Now let's test our model implementation:

In [None]:
# Create an instance of our model
math_model = SimpleMathModel()

# Check initial state
print("Initial state:")
all_vars = list(math_model.supported_variables.keys())
initial_values = math_model.get(all_vars)
for name, value in initial_values.items():
    print(f"  {name}: {value}")

# Set new input values
print("\nSetting input_a=3.0, input_b=4.0")
math_model.set({"input_a": 3.0, "input_b": 4.0})

# Get updated values
updated_values = math_model.get(all_vars)
print("\nUpdated state:")
for name, value in updated_values.items():
    print(f"  {name}: {value}")

In [None]:
# Test reset functionality
print("\nTesting reset functionality:")
math_model.reset()
reset_values = math_model.get(all_vars)
print("State after reset:")
for name, value in reset_values.items():
    print(f"  {name}: {value}")

# Test validation - try to set read-only variable
print("\nTesting validation - attempting to set read-only variable:")
try:
    math_model.set({"sum_output": 999.0})
except ValueError as e:
    print(f"  Error (expected): {e}")

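# Test validation - try to set an out-of-range value.
# Depending on the variable's validation settings, this may be accepted
# (possibly with a warning) or rejected with an error.
print("\nTesting validation - setting out-of-range value:")
try:
    math_model.set({"input_a": 15.0})  # Outside range [-10, 10]
except ValueError as e:
    print(f"  Rejected: {e}")
out_of_range_values = math_model.get(["input_a", "sum_output"])
print("Values after attempting out-of-range input:")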
for name, value in out_of_range_values.items():
    print(f"  {name}: {value}")
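How an out-of-range value is treated depends on the variable's validation settings and the library version, so check the printed output rather than assuming the result. The hypothetical `check_range` function below makes the two usual policies concrete: it either warns or raises. Neither the function nor its `policy` parameter is `lume` API.

```python
import warnings
from typing import Optional


def check_range(name: str, value: float,
                value_range: Optional[tuple[float, float]],
                policy: str = "warn") -> None:
    """Hypothetical range check: 'warn' emits a warning, 'error' raises ValueError."""
    if value_range is None:
        return  # no range declared, e.g. a computed output
    low, high = value_range
    if low <= value <= high:
        return
    message = f"{name}={value} is outside [{low}, {high}]"
    if policy == "error":
        raise ValueError(message)
    warnings.warn(message)


check_range("input_a", 3.0, (-10.0, 10.0))  # in range: silently accepted
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_range("input_a", 15.0, (-10.0, 10.0))  # out of range: warns only
print(len(caught))  # 1
```

Under the `"warn"` policy, the value would still be stored, so `sum_output` would reflect it.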

## Best Practices

### 1. **Variable Definition**
- Use appropriate `ScalarVariable` ranges to prevent invalid inputs
- Set `read_only=True` for computed/output variables
- Include units for physical quantities
- Choose meaningful default values

### 2. **Calculation Organization**
- Drive all physics/math from `_set()` so outputs always reflect the latest inputs
- Check domain-specific constraints (beyond the per-variable ranges) before performing calculations
- Handle edge cases and boundary conditions

### 3. **Error Handling**
- Let the base class handle variable validation
- Focus on domain-specific validation in your implementation
- Provide meaningful error messages for invalid states
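Per-variable ranges cannot express constraints that involve a particular value or a combination of inputs. For example, a model computing `ratio = numerator / denominator` accepts any denominator in its declared range but must still refuse zero. The `compute_ratio` helper below is hypothetical. In a subclass, `_set()` would call it.

```python
def compute_ratio(state: dict[str, float]) -> dict[str, float]:
    """Hypothetical output computation for ratio = numerator / denominator.

    The denominator may lie inside its declared value_range and still be
    invalid (zero), so the check lives in the model, not in the variable.
    """
    denominator = state["denominator"]
    if abs(denominator) < 1e-12:
        raise ValueError(
            f"denominator={denominator} is too close to zero to compute 'ratio'; "
            "choose a nonzero value within the variable's range"
        )
    # Return a new dict instead of mutating the caller's state
    return {**state, "ratio": state["numerator"] / denominator}


print(compute_ratio({"numerator": 3.0, "denominator": 2.0})["ratio"])  # 1.5
```

Because the helper returns a new dict rather than mutating state in place, a subclass can assign the result to `self._state` only after the check passes. A rejected update then leaves the model unchanged.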

### 4. **Performance Considerations**
- Cache expensive computations when possible
- Only recalculate what changes when parameters are updated
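One way to apply both points is to record which inputs each output depends on, then recompute an output only when one of those inputs actually changes. `LazyOutputs` below is a plain-Python sketch of that idea, assuming dependencies are declared by hand. It is not part of `lume`. In a subclass, the same logic would live in `_set()`.

```python
from typing import Callable


class LazyOutputs:
    """Recompute each output only when one of its declared inputs changed."""

    def __init__(self, state: dict[str, float],
                 formulas: dict[str, tuple[tuple[str, ...], Callable[..., float]]]):
        # formulas maps output name -> (input names, function of those inputs)
        self.state = dict(state)
        self.formulas = formulas
        self.evaluations = 0  # counts formula calls, to show the savings
        for output in formulas:
            self._recompute(output)

    def _recompute(self, output: str) -> None:
        inputs, func = self.formulas[output]
        self.state[output] = func(*(self.state[name] for name in inputs))
        self.evaluations += 1

    def set(self, values: dict[str, float]) -> None:
        # Only inputs whose value differs from the stored one count as changed
        changed = {n for n, v in values.items() if self.state.get(n) != v}
        self.state.update(values)
        for output, (inputs, _) in self.formulas.items():
            if changed.intersection(inputs):
                self._recompute(output)


model = LazyOutputs(
    {"a": 1.0, "b": 2.0, "c": 3.0},
    {"ab": (("a", "b"), lambda a, b: a + b), "c_sq": (("c",), lambda c: c * c)},
)
model.set({"c": 4.0})  # only c_sq is recomputed
print(model.state["ab"], model.state["c_sq"], model.evaluations)  # 3.0 16.0 3
```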

## Example 2: Dynamic Variable Generation

Sometimes you need to generate variable definitions programmatically based on configuration, data files, or runtime parameters. This example shows how to create variables on the fly:

In [None]:
class DynamicVariableModel(LUMEModel):
    """Model that generates variables dynamically based on configuration."""

    def __init__(self, num_inputs=2):
        self.num_inputs = num_inputs
        self._initial_state = {}
        self._variables = {}

        # Create input variables dynamically
        for i in range(num_inputs):
            input_name = f"input_{i + 1}"
            self._initial_state[input_name] = 0.0
            self._variables[input_name] = ScalarVariable(
                name=input_name,
                default_value=0.0,
                value_range=(-5.0, 5.0),
                unit="units",
                read_only=False,
            )

        # Create simple sum output
        self._initial_state["sum_output"] = 0.0
        self._variables["sum_output"] = ScalarVariable(
            name="sum_output", default_value=0.0, unit="units", read_only=True
        )

        self._state = self._initial_state.copy()
        self._compute_outputs()

    @property
    def supported_variables(self) -> dict[str, ScalarVariable]:
        """Return dynamically generated variables."""
        return self._variables

    def _get(self, names: list[str]) -> dict[str, Any]:
        """Get current values for specified variables."""
        return {name: self._state[name] for name in names}

    def _set(self, values: dict[str, Any]) -> None:
        """Set input values and recompute all outputs."""
        # Update input values
        for name, value in values.items():
            self._state[name] = value

        # Recompute all output variables
        self._compute_outputs()

    def _compute_outputs(self):
        """Compute all output variables from current input values."""
        # Get all input values
        input_values = [self._state[f"input_{i + 1}"] for i in range(self.num_inputs)]

        # Compute outputs
        self._state["sum_output"] = sum(input_values)

    def reset(self) -> None:
        """Reset to initial state."""
        self._state = self._initial_state.copy()


# Demonstrate dynamic variable generation
print("Creating model with 4 inputs")
dynamic_model = DynamicVariableModel(num_inputs=4)

In [None]:
# Test the dynamic model
print("\nSetting inputs: input_1=1.0, input_2=2.0, input_3=3.0")
dynamic_model.set({"input_1": 1.0, "input_2": 2.0, "input_3": 3.0})

result = dynamic_model.get(["sum_output"])
print(f"Sum result: {result['sum_output']:.1f}")
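The same loop also works when variable names come from data rather than a counter, for example a list of devices with individual limits. The hypothetical helper below builds keyword-argument specs that mirror the `ScalarVariable` arguments used in this guide, so `__init__` could pass each one as `ScalarVariable(**spec)`. The device names and limits are invented for illustration.

```python
def build_variable_specs(devices: dict[str, tuple[float, float]],
                         unit: str = "units") -> dict[str, dict]:
    """Hypothetical helper: one writable input per device plus a read-only sum."""
    specs: dict[str, dict] = {}
    for device, (low, high) in devices.items():
        # Clamp 0.0 into the device's range so every default is valid
        default = min(max(0.0, low), high)
        specs[device] = {
            "name": device,
            "default_value": default,
            "value_range": (low, high),
            "unit": unit,
            "read_only": False,
        }
    # The output's default must agree with the inputs' defaults
    total = sum(spec["default_value"] for spec in specs.values())
    specs["sum_output"] = {
        "name": "sum_output",
        "default_value": total,
        "unit": unit,
        "read_only": True,
    }
    return specs


# Hypothetical device names and limits, for illustration only
specs = build_variable_specs({"QUAD1": (-5.0, 5.0), "QUAD2": (1.0, 3.0)})
print(list(specs))                           # ['QUAD1', 'QUAD2', 'sum_output']
print(specs["sum_output"]["default_value"])  # 1.0
```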

## Example 3: YAML-Defined Variables

For complex models, it's often useful to define variables, and the equations relating them, in external configuration files. This example loads both from a YAML string:

In [None]:
# Simple mathematical model configuration
simple_yaml = """
variables:
  x:
    type: ScalarVariable
    default_value: 1.0
    value_range: [0.0, 10.0]
    unit: "units"

  y:
    type: ScalarVariable
    default_value: 2.0
    value_range: [0.0, 10.0]
    unit: "units"

  result:
    type: ScalarVariable
    default_value: 3.0
    unit: "units"
    read_only: true

model_config:
  name: "SimpleModel"
  description: "Basic addition model"
  equations:
    result: "x + y"
"""

print("Simple YAML configuration created!")
print("Variables:", list(yaml.safe_load(simple_yaml)["variables"].keys()))
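Because equations are plain strings, a typo in a variable name only surfaces when the model evaluates it. A small check with Python's standard `ast` module can catch such mistakes at load time. The sketch assumes the configuration has already been parsed into a dict, as `yaml.safe_load` returns. The rules it enforces are our own assumptions, not part of `lume`.

```python
import ast


def check_equations(config: dict) -> list[str]:
    """Return problems found in the 'equations' section of a parsed config."""
    variables = config.get("variables", {})
    equations = config.get("model_config", {}).get("equations", {})
    problems = []
    for output, expression in equations.items():
        if output not in variables:
            problems.append(f"equation target '{output}' is not a declared variable")
        elif not variables[output].get("read_only", False):
            problems.append(f"equation target '{output}' should be read_only")
        # Collect every bare name used in the expression
        tree = ast.parse(expression, mode="eval")
        names = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
        for name in sorted(names - set(variables)):
            problems.append(f"'{output}' uses undeclared name '{name}'")
    return problems


config = {
    "variables": {
        "x": {"default_value": 1.0},
        "y": {"default_value": 2.0},
        "result": {"default_value": 3.0, "read_only": True},
    },
    "model_config": {"equations": {"result": "x + z"}},  # typo: z instead of y
}
print(check_equations(config))  # ["'result' uses undeclared name 'z'"]
```

With the configuration above, `check_equations(yaml.safe_load(simple_yaml))` should return an empty list.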

In [None]:
class YAMLConfiguredModel(LUMEModel):
    """Model that loads variable definitions from YAML configuration."""

    def __init__(self, yaml_config_string):
        # Parse YAML configuration
        self.config = yaml.safe_load(yaml_config_string)
        self.equations_config = self.config.get("model_config", {}).get("equations", {})

        # Initialize state and variables from YAML
        self._variables = {}
        self._initial_state = {}

        # Create variables from YAML configuration
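        for var_name, var_config in self.config["variables"].items():
            # Drop the "type" key: it names the variable class and is not
            # meant as a ScalarVariable constructor argument
            var_kwargs = {k: v for k, v in var_config.items() if k != "type"}
            self._variables[var_name] = ScalarVariable(name=var_name, **var_kwargs)
            self._initial_state[var_name] = var_config.get("default_value", 0.0)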

        # Copy initial state to current state
        self._state = self._initial_state.copy()
        # Perform initial calculations
        self._compute_from_equations()

    @property
    def supported_variables(self):
        return self._variables

    def _get(self, names):
        return {name: self._state[name] for name in names}

    def _set(self, values):
        # Update input values
        for name, value in values.items():
            self._state[name] = value
        # Recompute outputs using equations
        self._compute_from_equations()

    def _compute_from_equations(self):
        """Compute output variables using equations from YAML config."""
        for output_var, equation in self.equations_config.items():
            if output_var in self._state:
                # Create evaluation context with current state
                eval_context = self._state.copy()
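                # Evaluate equation. Emptying __builtins__ limits, but does not
                # fully sandbox, eval: only load configurations you trust.
                result = eval(equation, {"__builtins__": {}}, eval_context)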
                self._state[output_var] = float(result)

    def reset(self):
        # Restore defaults, then recompute outputs so they match the inputs
        self._state = self._initial_state.copy()
        self._compute_from_equations()

In [None]:
# Test the YAML model
print("\nTesting YAML model:")
yaml_model = YAMLConfiguredModel(simple_yaml)
yaml_model.set({"x": 3.0, "y": 4.0})
result = yaml_model.get(["result"])
print(f"x=3.0, y=4.0 → result={result['result']:.1f}")
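Emptying `__builtins__` restricts what `eval` can reach, but it is not a security boundary, so configurations from untrusted sources should never reach `eval`. A swapped-in alternative is to walk the expression's syntax tree and allow only arithmetic on declared names. The `safe_eval` function below sketches such an evaluator. Its name and its operator set are our own choices. It could replace the `eval(...)` call inside `_compute_from_equations()`.

```python
import ast
import operator

# Allowed operators: binary + - * / ** and unary minus/plus
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.USub: operator.neg, ast.UAdd: operator.pos}


def safe_eval(expression: str, variables: dict[str, float]) -> float:
    """Evaluate an arithmetic expression over named variables; reject anything else."""

    def visit(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Name):
            if node.id not in variables:
                raise ValueError(f"Unknown variable '{node.id}'")
            return variables[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            return _BINARY_OPS[type(node.op)](visit(node.left), visit(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](visit(node.operand))
        # Calls, attribute access, subscripts, etc. all end up here
        raise ValueError(f"Unsupported syntax: {ast.dump(node)}")

    return float(visit(ast.parse(expression, mode="eval")))


print(safe_eval("x + y", {"x": 3.0, "y": 4.0}))  # 7.0
```

Function calls such as `__import__('os')` are rejected with a `ValueError` instead of being executed.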