# Unified Language Model Algebra: Projections + Constraints

This notebook explores how n-gram projections and schema-based constraints represent two sides of the same algebraic framework for composing and controlling language models.

## The Unified View

Both projects implement different forms of **language model algebra**:

1. **N-gram Projections**: Algebraic operations on model *inputs* and *combinations*
   - Projections transform context before feeding to models
   - Models combine through mixture operations (+, *, >>)
   - Focus: How models see and process input

2. **Schema Constraints (guidedgen)**: Algebraic operations on model *outputs*
   - Schemas constrain token generation through logit masking
   - Schemas combine through set operations (union, intersection)
   - Focus: What models are allowed to generate

Together, they form a complete **input → model → output** algebra.
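
As a minimal sketch of that input → model → output view, the three stages can be written as plain function composition. The names `last_k`, `toy_model`, `allow_only`, and `pipeline` are illustrative assumptions and are not part of either project.

```python
from typing import Callable, Dict, List, Set

Dist = Dict[str, float]

def last_k(k: int) -> Callable[[List[str]], List[str]]:
    """Input projection: keep only the last k context tokens."""
    return lambda context: context[-k:]

def toy_model(context: List[str]) -> Dist:
    """Stand-in model: a fixed distribution keyed on the last token."""
    if context and context[-1] == "{":
        return {'"': 0.6, "}": 0.3, "x": 0.1}
    return {"x": 1.0}

def allow_only(valid: Set[str]) -> Callable[[Dist], Dist]:
    """Output constraint: zero out disallowed tokens, then renormalize."""
    def constrain(dist: Dist) -> Dist:
        kept = {t: (p if t in valid else 0.0) for t, p in dist.items()}
        total = sum(kept.values())
        return {t: p / total for t, p in kept.items()} if total > 0 else kept
    return constrain

def pipeline(input_proj, model, output_proj):
    """Compose the three stages: output_proj(model(input_proj(context)))."""
    return lambda context: output_proj(model(input_proj(context)))

predict = pipeline(last_k(2), toy_model, allow_only({'"', "}"}))
result = predict(["a", "b", "{"])
```

Here the constraint removes `x` and the remaining mass is split 2:1 between the quote and the closing brace.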

In [None]:
# Setup
import sys
sys.path.append('.')
from typing import Dict, List, Set, Optional, Callable
from dataclasses import dataclass
import numpy as np

# Import our n-gram projection system
from ngram_projections.models.base import LanguageModel
from ngram_projections.models.ngram import NGramModel
from ngram_projections.projections.recency import RecencyProjection
from ngram_projections.projections.semantic import SemanticProjection

## Conceptual Bridge: Schemas as Output Projections
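
The classes in the next cell operate on probabilities, where masking a token means setting it to zero. On raw logits the same constraint is usually expressed by setting disallowed entries to negative infinity before the softmax. A minimal sketch of that variant follows; `mask_logits` and `softmax` are illustrative helpers, not guidedgen's API.

```python
import math
from typing import Dict, Set

def mask_logits(logits: Dict[str, float], valid: Set[str]) -> Dict[str, float]:
    """Set the logit of every disallowed token to -inf."""
    return {t: (v if t in valid else -math.inf) for t, v in logits.items()}

def softmax(logits: Dict[str, float]) -> Dict[str, float]:
    """Numerically stable softmax; exp(-inf) evaluates to 0.0."""
    finite = [v for v in logits.values() if v != -math.inf]
    if not finite:
        raise ValueError("every token was masked")
    peak = max(finite)
    exps = {t: math.exp(v - peak) for t, v in logits.items()}
    total = sum(exps.values())
    return {t: e / total for t, e in exps.items()}

logits = {"{": 2.0, "}": 1.0, "x": 3.0}
probs = softmax(mask_logits(logits, valid={"{", "}"}))
```

The masked token `x` had the highest logit, yet it ends up with probability zero and the two allowed tokens share the full mass.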

In [None]:
# Conceptual implementation showing how schemas work like output projections

class OutputProjection:
    """Base class for output projections (constraints)."""
    
    def project_logits(self, logits: Dict[str, float], state: dict) -> Dict[str, float]:
        """Project/constrain the output distribution."""
        raise NotImplementedError

class SchemaProjection(OutputProjection):
    """Project outputs to match a schema (like guidedgen)."""
    
    def __init__(self, valid_tokens_fn: Callable):
        self.valid_tokens_fn = valid_tokens_fn
    
    def project_logits(self, logits: Dict[str, float], state: dict) -> Dict[str, float]:
        # Get valid tokens based on current state
        valid_tokens = self.valid_tokens_fn(state)
        
        # Mask invalid tokens. The values are probabilities here, so masking
        # means zeroing them (on raw logits this would be -inf instead).
        return {
            token: (prob if token in valid_tokens else 0.0)
            for token, prob in logits.items()
        }

class BiasProjection(OutputProjection):
    """Bias certain tokens in the output."""
    
    def __init__(self, bias_weights: Dict[str, float]):
        self.bias_weights = bias_weights
    
    def project_logits(self, logits: Dict[str, float], state: dict) -> Dict[str, float]:
        # Apply bias weights
        biased = {}
        for token, prob in logits.items():
            bias = self.bias_weights.get(token, 1.0)
            biased[token] = prob * bias
        
        # Renormalize; if all mass was removed, return the zeros unnormalized
        total = sum(biased.values())
        if total <= 0:
            return biased
        return {k: v / total for k, v in biased.items()}

# Algebraic operations on output projections
class UnionProjection(OutputProjection):
    """Union of valid tokens from multiple projections."""
    
    def __init__(self, projections: List[OutputProjection]):
        self.projections = projections
    
    def project_logits(self, logits: Dict[str, float], state: dict) -> Dict[str, float]:
        # Union: token is valid if ANY projection allows it
        result = {token: 0.0 for token in logits}
        
        for proj in self.projections:
            projected = proj.project_logits(logits, state)
            for token, prob in projected.items():
                result[token] = max(result[token], prob)
        
        return result

class IntersectionProjection(OutputProjection):
    """Intersection of valid tokens from multiple projections."""
    
    def __init__(self, projections: List[OutputProjection]):
        self.projections = projections
    
    def project_logits(self, logits: Dict[str, float], state: dict) -> Dict[str, float]:
        # Intersection: token is valid if ALL projections allow it
        result = logits.copy()
        
        for proj in self.projections:
            projected = proj.project_logits(logits, state)
            result = {
                token: min(result.get(token, 0), projected.get(token, 0))
                for token in result
            }
        
        return result

## Complete Pipeline: Input → Model → Output Algebra
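
One detail matters when these operators are mixed: in Python, `*` and `@` share the same precedence and group left to right, so `0.3 * m @ p` parses as `(0.3 * m) @ p`. The toy `Expr` class below is hypothetical and only records how an expression is grouped.

```python
class Expr:
    """Records how Python groups an expression built from *, @ and +."""

    def __init__(self, text: str):
        self.text = text

    def __rmul__(self, weight: float) -> "Expr":
        # Called for `weight * expr` because float.__mul__ returns NotImplemented.
        return Expr(f"({weight} * {self.text})")

    def __matmul__(self, other: "Expr") -> "Expr":
        return Expr(f"({self.text} @ {other.text})")

    def __add__(self, other: "Expr") -> "Expr":
        return Expr(f"({self.text} + {other.text})")

m, p = Expr("m"), Expr("p")
implicit = (0.3 * m @ p).text      # the weight binds first
explicit = (0.3 * (m @ p)).text    # the projection binds first
```

Writing the parentheses explicitly keeps the weight on the outside of the projected model, which is what a mixture expects.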

In [None]:
class AlgebraicLanguageModel:
    """Complete algebraic model with input and output projections."""
    
    def __init__(self, 
                 model: LanguageModel,
                 input_projection=None,
                 output_projection=None):
        self.model = model
        self.input_projection = input_projection
        self.output_projection = output_projection
    
    def predict(self, context: List[str], state: dict = None) -> Dict[str, float]:
        # Step 1: Apply input projection
        if self.input_projection:
            context = self.input_projection.project(context)
        
        # Step 2: Get model predictions
        logits = self.model.predict(context)
        
        # Step 3: Apply output projection
        if self.output_projection:
            logits = self.output_projection.project_logits(logits, state or {})
        
        return logits
    
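    def __add__(self, other):
        """Mixture of models: 50/50 unless both operands carry weights."""
        if not isinstance(other, AlgebraicLanguageModel):
            return NotImplemented
        if (isinstance(self, WeightedAlgebraicModel)
                and isinstance(other, WeightedAlgebraicModel)):
            # `0.3 * a + 0.7 * b`: unwrap and mix with the given weights
            return MixtureAlgebraicModel([self.model, other.model],
                                         [self.weight, other.weight])
        return MixtureAlgebraicModel([self, other], [0.5, 0.5])
    
    def __mul__(self, weight: float):
        """Weighted model for mixtures."""
        return WeightedAlgebraicModel(self, weight)
    
    # Needed for `0.3 * model`, where the float is the left operand
    __rmul__ = __mul__
    
    def __matmul__(self, projection):
        """Apply projection with @ operator."""
        if hasattr(projection, 'project'):  # Input projection
            return AlgebraicLanguageModel(
                self.model, 
                projection,
                self.output_projection
            )
        elif hasattr(projection, 'project_logits'):  # Output projection
            return AlgebraicLanguageModel(
                self.model,
                self.input_projection,
                projection
            )
        return NotImplemented

class WeightedAlgebraicModel(AlgebraicLanguageModel):
    """A model paired with a mixture weight (minimal sketch)."""
    
    def __init__(self, model: AlgebraicLanguageModel, weight: float):
        super().__init__(model)
        self.weight = weight
    
    def predict(self, context: List[str], state: dict = None) -> Dict[str, float]:
        pred = self.model.predict(context, state)
        return {token: self.weight * prob for token, prob in pred.items()}

class MixtureAlgebraicModel(AlgebraicLanguageModel):
    """Mixture of multiple algebraic models."""
    
    def __init__(self, models: List[AlgebraicLanguageModel], weights: List[float]):
        # The mixture acts as its own underlying model, so the inherited
        # `@` operator wraps the whole mixture.
        super().__init__(model=self)
        assert abs(sum(weights) - 1.0) < 1e-6, "Weights must sum to 1"
        self.models = models
        self.weights = weights
    
    def predict(self, context: List[str], state: dict = None) -> Dict[str, float]:
        result = {}
        
        for model, weight in zip(self.models, self.weights):
            pred = model.predict(context, state)
            for token, prob in pred.items():
                result[token] = result.get(token, 0) + weight * prob
        
        return result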

## Example: JSON Generation with N-gram Grounding
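
A validator of this kind reads a `state` dictionary, so something has to keep that state in step with the tokens emitted so far. The sketch below shows one way to do that for the `depth` and `in_string` fields. `advance_state` and `run` are hypothetical helpers, and the logic deliberately ignores escape sequences and arrays.

```python
from typing import Dict, List

def advance_state(state: Dict[str, object], token: str) -> Dict[str, object]:
    """Return the state after emitting `token` (the input is not mutated)."""
    depth = int(state.get("depth", 0))
    in_string = bool(state.get("in_string", False))

    if token == '"':
        in_string = not in_string      # a quote opens or closes a string
    elif not in_string:
        if token == "{":
            depth += 1
        elif token == "}":
            depth -= 1
    return {"depth": depth, "in_string": in_string}

def run(tokens: List[str]) -> Dict[str, object]:
    """Fold advance_state over a token sequence, starting outside any object."""
    state: Dict[str, object] = {"depth": 0, "in_string": False}
    for token in tokens:
        state = advance_state(state, token)
    return state

after_open = run(["{"])
inside_key = run(["{", '"', "n", "a"])
closed = run(["{", '"', "k", '"', ":", '"', "v", '"', "}"])
```

In a generation loop, the state returned after each sampled token would be passed to the validator for the next step.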

In [None]:
# Simulate a JSON schema constraint
def json_object_validator(state: dict) -> Set[str]:
    """Returns valid tokens for JSON object generation."""
    depth = state.get('depth', 0)
    in_string = state.get('in_string', False)
    
    if in_string:
        # Any token except quotes (simplified)
        return {t for t in "abcdefghijklmnopqrstuvwxyz0123456789 "}
    
    if depth == 0:
        return {'{'}  # Must start with opening brace
    
    # Simplified: allow keys, values, and closing brace
    return {'"', ':', ',', '}'}

# Create models
ngram = NGramModel(n=3)
ngram.train(["{", '"name"', ":", '"John"', ",", '"age"', ":", "30", "}"])
ngram.train(["{", '"id"', ":", '"123"', ",", '"type"', ":", '"user"', "}"])

# Create a mock LLM that tends to generate valid JSON
class MockJSONLLM(LanguageModel):
    def predict(self, context: List[str]) -> Dict[str, float]:
        # Simplified: prefer JSON-like tokens
        if context and context[-1] == '{':
            return {'"': 0.9, 'null': 0.1}
        elif context and context[-1] == '"':
            return {'name': 0.3, 'id': 0.3, 'type': 0.2, 'age': 0.2}
        elif context and context[-1] == ':':
            return {'"': 0.5, '123': 0.25, '30': 0.25}
        else:
            return {',': 0.4, '}': 0.3, ':': 0.3}

llm = MockJSONLLM()

# Create projections
recency = RecencyProjection(max_suffix_len=3)
json_schema = SchemaProjection(json_object_validator)

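# Compose: N-gram with recency + LLM, both constrained by JSON schema.
# `*` and `@` have equal precedence and group left to right, so the
# projection is parenthesized to apply it before the weight.
model = ((0.3 * (AlgebraicLanguageModel(ngram) @ recency)) + 
         (0.7 * AlgebraicLanguageModel(llm))) @ json_schema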

# Generate
context = ['{']
state = {'depth': 1, 'in_string': False}

print("Generating JSON with n-gram grounding and schema constraints:")
print("Context:", context)
print("Predictions:", model.predict(context, state))

## Theoretical Unification: Categories and Functors
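
The category laws themselves are easy to spot-check on ordinary functions. In the sketch below, `compose`, `lowercase`, `drop_the`, and `last_two` are illustrative names. The check runs on one sample, so it is a sanity test rather than a proof.

```python
from typing import Callable, List

Tokens = List[str]
Morphism = Callable[[Tokens], Tokens]

def compose(f: Morphism, g: Morphism) -> Morphism:
    """Left-to-right composition: apply f, then g (the `>>` reading)."""
    return lambda tokens: g(f(tokens))

identity: Morphism = lambda tokens: tokens
lowercase: Morphism = lambda tokens: [t.lower() for t in tokens]
drop_the: Morphism = lambda tokens: [t for t in tokens if t != "the"]
last_two: Morphism = lambda tokens: tokens[-2:]

sample = ["The", "Theory", "Of", "Relativity"]

# Associativity: (f >> g) >> h agrees with f >> (g >> h)
left = compose(compose(lowercase, drop_the), last_two)(sample)
right = compose(lowercase, compose(drop_the, last_two))(sample)

# Identity: id >> f and f >> id both agree with f
id_left = compose(identity, last_two)(sample)
id_right = compose(last_two, identity)(sample)
```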

In [None]:
# The algebraic structure can be formalized using category theory

@dataclass
class LanguageModelCategory:
    """
    Category where:
    - Objects: Language models
    - Morphisms: Projections/transformations
    - Composition: Function composition
    - Identity: Identity projection
    """
    
    class Object:
        """A language model as an object in the category."""
        def __init__(self, model: LanguageModel):
            self.model = model
    
    class Morphism:
        """A projection/transformation as a morphism."""
        def __init__(self, transform: Callable):
            self.transform = transform
        
        def __rshift__(self, other):
            """Compose morphisms."""
            return LanguageModelCategory.Morphism(
                lambda x: other.transform(self.transform(x))
            )
    
    @staticmethod
    def identity():
        """Identity morphism."""
        return LanguageModelCategory.Morphism(lambda x: x)

# Functors between categories
class ProjectionFunctor:
    """Maps from token sequences to projected sequences."""
    
    def __init__(self, projection):
        self.projection = projection
    
    def map_object(self, tokens: List[str]) -> List[str]:
        """Apply projection to tokens."""
        return self.projection.project(tokens)
    
    def map_morphism(self, f: Callable) -> Callable:
        """Lift a function on tokens to work on projected tokens."""
        return lambda x: f(self.map_object(x))

class SchemaFunctor:
    """Maps from distributions to constrained distributions."""
    
    def __init__(self, schema):
        self.schema = schema
    
    def map_object(self, logits: Dict[str, float]) -> Dict[str, float]:
        """Apply schema constraints to logits."""
        return self.schema.project_logits(logits, {})
    
    def map_morphism(self, f: Callable) -> Callable:
        """Lift a function on logits to work on constrained logits."""
        return lambda x: self.map_object(f(x))

print("Category Theory Formalization:")
print("- Objects: Language Models")
print("- Morphisms: Projections (input) and Constraints (output)")
print("- Composition: Function composition (>> operator)")
print("- Identity: No projection/constraint")
print("\nThis gives us a principled algebraic framework!")

## Practical Integration: Reliable Wikipedia-Grounded Generation
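
The mixing rule used in the next cell is a weighted sum in which tokens supported by the Wikipedia n-gram get a 1.5× boost on their LLM term. The sketch below isolates that rule with made-up numbers (`wiki`, `llm`, and `mix_with_boost` are illustrative) and shows why the result has to be renormalized afterwards.

```python
from typing import Dict

def mix_with_boost(wiki: Dict[str, float], llm: Dict[str, float],
                   fact_weight: float, boost: float = 1.5,
                   threshold: float = 0.01) -> Dict[str, float]:
    """Weighted sum; the LLM term is boosted for tokens the n-gram supports."""
    mixed = {}
    for token in set(wiki) | set(llm):
        wiki_p = wiki.get(token, 0.0)
        llm_p = llm.get(token, 0.0)
        if wiki_p > threshold:
            mixed[token] = fact_weight * wiki_p + (1 - fact_weight) * llm_p * boost
        else:
            mixed[token] = (1 - fact_weight) * llm_p
    return mixed

wiki = {"was": 1.0}
llm = {"was": 0.5, "is": 0.5}

mixed = mix_with_boost(wiki, llm, fact_weight=0.6)
total = sum(mixed.values())          # no longer 1.0 because of the boost
normalized = {t: p / total for t, p in mixed.items()}
```

With these inputs, `was` receives 0.6 · 1.0 + 0.4 · 0.5 · 1.5 = 0.9 and `is` receives 0.4 · 0.5 = 0.2, so the mixed values sum to 1.1 before normalization.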

In [None]:
class WikipediaGroundedModel:
    """
    Combines:
    1. N-gram model trained on Wikipedia for factual grounding
    2. LLM for fluent generation
    3. Schema constraints for structured output
    """
    
    def __init__(self, wikipedia_ngram, llm, fact_weight=0.4):
        self.wiki_ngram = wikipedia_ngram
        self.llm = llm
        self.fact_weight = fact_weight
        
        # Define schema for factual statements
        self.fact_schema = self._create_fact_schema()
    
    def _create_fact_schema(self):
        """Schema for generating factual statements."""
        def fact_validator(state: dict) -> Set[str]:
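            # Restrict generation to tokens the Wikipedia n-gram has seen
            if state.get('generating_fact'):
                return state.get('wiki_tokens', set())
            # NOTE: SchemaProjection treats an empty set as "no token is valid",
            # so this masks everything. generate_fact always sets
            # 'generating_fact', so this branch is not reached in this notebook.
            return set()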
        
        return SchemaProjection(fact_validator)
    
    def generate_fact(self, topic: str, context: List[str]) -> Dict[str, float]:
        """
        Generate a fact about the topic, grounded in Wikipedia.
        """
        # Step 1: Find relevant Wikipedia context
        wiki_context = self._find_wiki_context(topic, context)
        
        # Step 2: Get Wikipedia n-gram predictions
        wiki_pred = self.wiki_ngram.predict(wiki_context)
        
        # Step 3: Get LLM predictions
        llm_pred = self.llm.predict(context)
        
        # Step 4: Mix with Wikipedia bias
        mixed = {}
        all_tokens = set(wiki_pred.keys()) | set(llm_pred.keys())
        
        for token in all_tokens:
            wiki_p = wiki_pred.get(token, 0.0)
            llm_p = llm_pred.get(token, 0.0)
            
            # Boost Wikipedia tokens for factual grounding
            if wiki_p > 0.01:  # Token appears in Wikipedia
                mixed[token] = self.fact_weight * wiki_p + (1 - self.fact_weight) * llm_p * 1.5
            else:
                mixed[token] = (1 - self.fact_weight) * llm_p
        
        # Step 5: Apply schema constraints
        state = {'generating_fact': True, 'wiki_tokens': set(wiki_pred.keys())}
        constrained = self.fact_schema.project_logits(mixed, state)
        
        # Normalize
        total = sum(constrained.values())
        if total > 0:
            constrained = {k: v/total for k, v in constrained.items()}
        
        return constrained
    
    def _find_wiki_context(self, topic: str, context: List[str]) -> List[str]:
        """Find relevant Wikipedia context for the topic."""
        # Simplified: use topic words as context
        return topic.lower().split() + context[-3:]

# Example usage
wiki_ngram = NGramModel(n=3)
# Train on "Wikipedia" data
wiki_ngram.train("Albert Einstein was a theoretical physicist".split())
wiki_ngram.train("Einstein developed the theory of relativity".split())
wiki_ngram.train("The theory of relativity revolutionized physics".split())

class SimpleLLM(LanguageModel):
    def predict(self, context: List[str]) -> Dict[str, float]:
        # Simple mock LLM
        return {
            "was": 0.2, "is": 0.15, "developed": 0.15,
            "the": 0.1, "theory": 0.1, "of": 0.1,
            "relativity": 0.05, "physics": 0.05, "scientist": 0.1
        }

llm = SimpleLLM()
grounded_model = WikipediaGroundedModel(wiki_ngram, llm, fact_weight=0.6)

# Generate a fact
result = grounded_model.generate_fact(
    topic="Einstein",
    context=["Albert", "Einstein"]
)

print("Wikipedia-Grounded Generation:")
print("Topic: Einstein")
print("Context: ['Albert', 'Einstein']")
print("\nTop predictions (Wikipedia-grounded):")
for token, prob in sorted(result.items(), key=lambda x: x[1], reverse=True)[:5]:
    print(f"  {token}: {prob:.3f}")

## The Complete Algebra: Composition Laws
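
Some of these laws depend on how the operators are defined. As a quick check on plain dictionaries, an equal-weight `+` (0.5/0.5, as in `AlgebraicLanguageModel.__add__`) is commutative but not associative, because the operand added last always ends up with half of the mass. `mix` is an illustrative helper.

```python
from typing import Dict

Dist = Dict[str, float]

def mix(a: Dist, b: Dist, w: float = 0.5) -> Dist:
    """Mixture w * a + (1 - w) * b over the union of tokens."""
    return {t: w * a.get(t, 0.0) + (1 - w) * b.get(t, 0.0)
            for t in set(a) | set(b)}

A, B, C = {"a": 1.0}, {"b": 1.0}, {"c": 1.0}

commutes = mix(A, B) == mix(B, A)
left = mix(mix(A, B), C)    # (A + B) + C
right = mix(A, mix(B, C))   # A + (B + C)
```

Associativity can be recovered by carrying explicit weights, for example by treating a mixture as a flat list of (weight, model) pairs.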

In [None]:
# Define the algebraic laws the system is designed to satisfy

class AlgebraicLaws:
    """
    Laws the language model algebra is intended to satisfy.
    They are stated as strings and are not checked here.
    """
    
    @staticmethod
    def commutativity_of_mixture():
        """A + B = B + A for model mixtures."""
        return "model1 + model2 == model2 + model1"
    
    @staticmethod
    def associativity_of_mixture():
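        """(A + B) + C = A + (B + C), provided mixtures carry explicit weights.
        
        A fixed 0.5/0.5 `+` is not associative: the last operand gets half the mass.
        """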
        return "(model1 + model2) + model3 == model1 + (model2 + model3)"
    
    @staticmethod
    def distributivity_of_projection():
        """(A + B) @ P = (A @ P) + (B @ P) for input projections."""
        return "(model1 + model2) @ projection == (model1 @ projection) + (model2 @ projection)"
    
    @staticmethod
    def composition_of_projections():
        """M @ (P1 >> P2) = (M @ P2) @ P1 for projection composition."""
        return "model @ (proj1 >> proj2) == (model @ proj2) @ proj1"
    
    @staticmethod
    def identity_projection():
        """M @ I = M for identity projection."""
        return "model @ identity == model"
    
    @staticmethod
    def schema_intersection_associativity():
        """(S1 ∩ S2) ∩ S3 = S1 ∩ (S2 ∩ S3) for schema intersection."""
        return "(schema1 & schema2) & schema3 == schema1 & (schema2 & schema3)"
    
    @staticmethod
    def schema_union_associativity():
        """(S1 ∪ S2) ∪ S3 = S1 ∪ (S2 ∪ S3) for schema union."""
        return "(schema1 | schema2) | schema3 == schema1 | (schema2 | schema3)"
    
    @staticmethod
    def de_morgans_law():
        """¬(S1 ∪ S2) = ¬S1 ∩ ¬S2 for schema complement."""
        return "~(schema1 | schema2) == (~schema1) & (~schema2)"

print("Algebraic Laws of Language Model Composition:")
print("="*50)

laws = AlgebraicLaws()
for law_name in dir(laws):
    if not law_name.startswith('_'):
        law_fn = getattr(laws, law_name)
        if callable(law_fn):
            print(f"\n{law_name.replace('_', ' ').title()}:")
            print(f"  {law_fn()}")

print("\n" + "="*50)
print("These laws enable reasoning about complex model compositions!")

## Unified Framework Summary

The combination of **n-gram projections** and **schema constraints** gives us a complete algebraic framework:

### Input Algebra (N-gram Projections)
- **Recency**: Project to longest matching suffix
- **Semantic**: Project using embeddings
- **Edit Distance**: Find similar contexts
- **Composition**: Chain projections with `>>`
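
The recency projection can be sketched as a back-off over suffixes: shorten the context from the left until what remains has been seen in training. `longest_known_suffix` and `known_contexts` are illustrative names, and this is not the `RecencyProjection` implementation.

```python
from typing import List, Set, Tuple

def longest_known_suffix(context: List[str],
                         known_contexts: Set[Tuple[str, ...]],
                         max_len: int = 3) -> List[str]:
    """Return the longest suffix of `context` (up to max_len) that was seen."""
    for length in range(min(max_len, len(context)), 0, -1):
        suffix = tuple(context[-length:])
        if suffix in known_contexts:
            return list(suffix)
    return []  # nothing matched: fall back to the empty context

known = {("theory",), ("theory", "of"), ("the", "theory", "of")}

full_match = longest_known_suffix(["developed", "the", "theory", "of"], known)
partial = longest_known_suffix(["a", "new", "theory", "of"], known)
no_match = longest_known_suffix(["quantum", "mechanics"], known)
```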

### Model Algebra (Mixtures)
- **Addition**: Equal-weight mixture `model1 + model2`
- **Scaling**: Weighted contribution `0.3 * model`
- **Union**: Fallback/ensemble `model1 | model2`

### Output Algebra (Schema Constraints)
- **Intersection**: All constraints must hold `schema1 & schema2`
- **Union**: Any constraint can hold `schema1 | schema2`
- **Switch**: Discriminated unions based on fields
- **Sequence**: Different schemas per position
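
At a single decoding step, a schema reduces to the set of tokens it allows, so intersection, union, and complement can be sketched with Python set operations over a fixed vocabulary. `TokenSchema` and `VOCAB` are hypothetical and stand in for whatever guidedgen uses internally.

```python
from dataclasses import dataclass
from typing import FrozenSet

VOCAB: FrozenSet[str] = frozenset({"{", "}", '"', ":", ",", "a", "1"})

@dataclass(frozen=True)
class TokenSchema:
    """A one-step schema: the set of tokens allowed next."""
    allowed: FrozenSet[str]

    def __and__(self, other: "TokenSchema") -> "TokenSchema":
        return TokenSchema(self.allowed & other.allowed)

    def __or__(self, other: "TokenSchema") -> "TokenSchema":
        return TokenSchema(self.allowed | other.allowed)

    def __invert__(self) -> "TokenSchema":
        # Complement is taken relative to the fixed vocabulary.
        return TokenSchema(VOCAB - self.allowed)

punct = TokenSchema(frozenset({"{", "}", ":", ","}))
braces = TokenSchema(frozenset({"{", "}"}))
values = TokenSchema(frozenset({'"', "a", "1"}))

both = punct & braces
either = braces | values
de_morgan_holds = ~(braces | values) == (~braces) & (~values)
```

Complement uses `~` because Python's `not` always returns a plain `bool` and cannot produce a new schema.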

### The Complete Pipeline
```python
model = (
    0.3 * (ngram @ recency_projection) +     # Grounded n-gram
    0.7 * (llm @ semantic_projection)        # Semantic LLM
) @ json_schema                              # Output constraints
```

This unified algebra enables:
1. **Reliable generation** through constraints
2. **Factual grounding** through n-gram biasing
3. **Continuous learning** via n-gram updates
4. **Interpretable control** over model behavior
5. **Composable building blocks** for complex systems
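
The continuous-learning point follows from n-gram models being count tables: new text can be folded in by incrementing counts, with no retraining step. A minimal bigram sketch follows, where `CountBigram` is an illustrative class rather than the project's `NGramModel`.

```python
from collections import Counter, defaultdict
from typing import Dict, List

class CountBigram:
    """Bigram model backed by raw counts, so updates are incremental."""

    def __init__(self) -> None:
        self.counts: Dict[str, Counter] = defaultdict(Counter)

    def update(self, tokens: List[str]) -> None:
        """Fold a new token sequence into the existing counts."""
        for prev, nxt in zip(tokens, tokens[1:]):
            self.counts[prev][nxt] += 1

    def predict(self, context: List[str]) -> Dict[str, float]:
        """Relative frequencies of tokens that followed the last context token."""
        if not context or context[-1] not in self.counts:
            return {}
        following = self.counts[context[-1]]
        total = sum(following.values())
        return {tok: n / total for tok, n in following.items()}

model = CountBigram()
model.update("Einstein developed the theory".split())
before = model.predict(["Einstein"])

model.update("Einstein was a physicist".split())
after = model.predict(["Einstein"])
```

After the second update the prediction for `Einstein` is split evenly between the two continuations seen so far.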