# CSSR Algorithm Debugging: Golden Mean Process

This notebook provides a step-by-step walkthrough of the CSSR (Causal State Splitting Reconstruction) algorithm applied to the Golden Mean process.

## Expected Result

The Golden Mean process forbids consecutive 1s. It has exactly **2 causal states**:
- **State A**: Last symbol was 0 → emits 1 with probability p and 0 with probability 1-p
- **State B**: Last symbol was 1 → must emit 0 with probability 1.0

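The true entropy rate is $h_\mu = H(p) / (1 + p)$, where $H(p) = -p \log_2 p - (1-p) \log_2 (1-p)$ is the binary entropy and $p$ is the probability of emitting 1 after a 0. The factor $1/(1+p)$ is the stationary probability of State A, the only state with any uncertainty. For $p = 0.5$ this gives $h_\mu = 2/3$ bit per symbol.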

In [None]:
import itertools
import math
from collections import defaultdict

from emic.sources import GoldenMeanSource
from emic.inference import CSSR, CSSRConfig
from emic.inference.cssr.algorithm import SuffixTree, StatePartition

## Step 1: Generate Data

Generate a sample from the Golden Mean process with p=0.5.

In [None]:
# Draw 10,000 symbols from a seeded Golden Mean source.
source = GoldenMeanSource(p=0.5, _seed=42)
data = list(itertools.islice(source, 10000))

# Basic sanity statistics on the sample.
n_zeros, n_ones = data.count(0), data.count(1)
print(f"Data length: {len(data)}")
print(f"Symbol counts: 0={n_zeros}, 1={n_ones}")
print(f"First 50 symbols: {data[:50]}")

# The process forbids "11": count adjacent (1, 1) pairs, which must be zero.
consecutive_ones = sum(1 for prev, cur in zip(data, data[1:]) if prev == 1 and cur == 1)
print(f"Consecutive 1s (should be 0): {consecutive_ones}")

## Step 2: Build the Suffix Tree

The suffix tree stores statistics for each history (suffix). For each history, we track:
- How many times it occurred
- The distribution of next symbols

In [None]:
# Index the histories of `data` up to length max_history, together with the
# symbols that followed them.
max_history = 3
alphabet = frozenset({0, 1})
tree = SuffixTree(alphabet=alphabet, max_depth=max_history)
tree.build_from_sequence(data)

summary_lines = (
    f"Suffix tree built with max_history={max_history}",
    f"Alphabet: {alphabet}",
    f"Number of histories: {len(tree)}",
)
for line in summary_lines:
    print(line)

In [None]:
# Examine the suffix tree entries
print("=== Suffix Tree Statistics ===")


def _print_history_stats(tree, suffix, require_count=True):
    """Print one history's occurrence count and next-symbol distribution.

    Args:
        tree: Suffix tree exposing ``get_stats(suffix)``.
        suffix: History tuple to report on.
        require_count: When True (default), skip histories whose stats are
            missing or have ``count == 0``. When False, only missing stats
            are skipped (used for the empty history).
    """
    stats = tree.get_stats(suffix)
    if not stats or (require_count and stats.count == 0):
        return
    total = sum(stats.next_symbol_counts.values())
    print(f"  {suffix}: count={stats.count}, next_symbols={dict(stats.next_symbol_counts)}")
    if total == 0:
        # No successor was recorded (e.g. history seen only at the end of the
        # data); there is nothing to normalize, so avoid dividing by zero.
        return
    for sym, cnt in stats.next_symbol_counts.items():
        print(f"      P({sym}|{suffix}) = {cnt / total:.4f}")


print("\nLength 0 (empty history):")
_print_history_stats(tree, (), require_count=False)

# Enumerate every word over the alphabet for lengths 1..max_history instead of
# hard-coding them. Forbidden words (containing "11") have no stats or a zero
# count, so they are skipped automatically.
symbols = sorted(alphabet)
for length in range(1, max_history + 1):
    print(f"\nLength {length} histories:")
    for suffix in itertools.product(symbols, repeat=length):
        _print_history_stats(tree, suffix)

# Note: (1,1) should have 0 count - forbidden
stats_11 = tree.get_stats((1, 1))
print(f"\n  (1,1): count={stats_11.count if stats_11 else 0} (forbidden - should be 0)")

## Step 3: Analyze Expected State Assignments

Based on the suffix tree, we should see:
- Histories ending in 1: `(1,)`, `(0,1)`, `(0,0,1)`, etc. → All have P(0|h) = 1.0 → **Same causal state**
- Histories ending in 0: `(0,)`, `(1,0)`, `(0,0)`, etc. → All have P(0|h) ≈ P(1|h) ≈ 0.5 → **Same causal state**

The key insight: **Causal states are determined by the distribution of futures, not the specific history.**

In [None]:
# Categorize histories by their predictive distribution
def get_distribution(tree, suffix):
    """Return the empirical next-symbol distribution P(symbol | suffix).

    Args:
        tree: Suffix tree exposing ``get_stats(suffix)``. The stats object
            must provide ``count`` and ``next_symbol_counts``.
        suffix: History tuple to look up.

    Returns:
        A dict mapping each next symbol to its relative frequency, or ``None``
        in any of these cases: the history is unknown, it was never observed,
        or no successor symbol was recorded for it.
    """
    stats = tree.get_stats(suffix)
    if not stats or stats.count == 0:
        return None
    counts = stats.next_symbol_counts
    total = sum(counts.values())
    if total == 0:
        # A history can be counted without any recorded successor (e.g. one
        # occurring only at the very end of the data); avoid ZeroDivisionError.
        return None
    return {sym: cnt / total for sym, cnt in counts.items()}


# Every allowed Golden Mean history of length 1-3, grouped by its final symbol.
histories_ending_in_0 = [(0,), (1, 0), (0, 0), (0, 1, 0), (1, 0, 0), (0, 0, 0)]
histories_ending_in_1 = [(1,), (0, 1), (1, 0, 1), (0, 0, 1)]

report_groups = (
    ("=== Histories Ending in 0 (should all have ~50/50 distribution) ===", histories_ending_in_0),
    ("\n=== Histories Ending in 1 (should all have 100% → 0) ===", histories_ending_in_1),
)
for heading, group in report_groups:
    print(heading)
    for h in group:
        dist = get_distribution(tree, h)
        if not dist:
            continue
        print(f"  {h}: P(0)={dist.get(0, 0):.4f}, P(1)={dist.get(1, 0):.4f}")

## Step 4: Run CSSR and Trace State Formation

In [None]:
# Configure CSSR: histories up to length 3, significance level 0.05, min_count 10.
from emic.inference import CSSRConfig

config = CSSRConfig(max_history=3, significance=0.05, min_count=10)
cssr = CSSR(config)

print("=== CSSR Algorithm Trace ===")

# Fresh suffix tree with the configured depth, so the individual (private)
# CSSR steps below can be driven one at a time.
tree2 = SuffixTree(alphabet=frozenset({0, 1}), max_depth=config.max_history)
tree2.build_from_sequence(data)


# Helper to display partition
def show_partition(partition, label):
    """Print ``label`` with the state count, then each state's member histories.

    Within a state, histories are listed shortest first; ties are broken by
    the natural ordering of the history tuples.
    """
    state_ids = partition.state_ids()
    print(f"\n{label}: {len(state_ids)} states")
    for sid in state_ids:
        ordered = sorted(partition.get_histories(sid), key=lambda hist: (len(hist), hist))
        print(f"  {sid}: {ordered}")


# Seed the partition with CSSR's own initializer, then display it.
partition = cssr._initialize_partition(tree2)
show_partition(partition, label="After initialization")

In [None]:
# Apply a single CSSR splitting pass to the current partition and display it.
partition = cssr._split_states(partition, tree2)
show_partition(partition, label="After splitting")

In [None]:
# Apply a single CSSR merging pass to the split partition and display it.
partition = cssr._merge_states(partition, tree2)
show_partition(partition, label="After merging")

In [None]:
# For every state, list the empirical next-symbol distribution of each
# member history (shortest histories first).
print("=== State Distributions ===")
for sid in partition.state_ids():
    print(f"\n{sid}:")
    members = sorted(partition.get_histories(sid), key=lambda hist: (len(hist), hist))
    for hist in members:
        probs = get_distribution(tree2, hist)
        if not probs:
            continue
        p0, p1 = probs.get(0, 0), probs.get(1, 0)
        print(f"  {hist}: P(0)={p0:.4f}, P(1)={p1:.4f}")

## Step 5: Run Full Inference and Compare

In [None]:
# Run the full inference
result = CSSR(config).infer(data)

print("=== Inferred Machine ===")
print(f"Number of states: {len(result.machine.states)}")

# Expected entropy rate for the Golden Mean process: h = H(p) / (1 + p).
# H is the binary entropy of p = P(1 | last symbol 0), and 1 / (1 + p) is the
# stationary probability of the only state that can emit a 1.
# (The previous expression -p*log2(p)/(1+p) dropped the (1-p) term of H(p).)
p = 0.5
if 0 < p < 1:
    expected_h = (-p * math.log2(p) - (1 - p) * math.log2(1 - p)) / (1 + p)
else:
    expected_h = 0.0  # p in {0, 1}: every emission is deterministic
print(f"Expected entropy rate: {expected_h:.4f}")

print("\nTransitions:")
for state in result.machine.states:
    print(f"  State {state.id}:")
    for t in state.transitions:
        print(f"    --{t.symbol} (p={t.probability:.3f})--> {t.target}")

In [None]:
# Trace the full iteration loop
print("=== Tracing Full CSSR Iteration ===\n")

tree3 = SuffixTree(max_depth=config.max_history, alphabet=frozenset([0, 1]))
tree3.build_from_sequence(data)

# Enumerate the tree's histories once, in display order (shortest first).
# NOTE(review): assumes _split_states/_merge_states only read tree3 and never
# add histories to it.
history_order = sorted(tree3.all_histories(), key=lambda x: (len(x), x))

partition3 = cssr._initialize_partition(tree3)
print(f"Initial: {len(partition3.state_ids())} states")

max_iterations = 5
for i in range(max_iterations):
    old_assignments = {h: partition3.get_state(h) for h in history_order}

    # Split
    partition3 = cssr._split_states(partition3, tree3)
    print(f"  Iter {i + 1} after split: {len(partition3.state_ids())} states")

    # Merge
    partition3 = cssr._merge_states(partition3, tree3)
    print(f"  Iter {i + 1} after merge: {len(partition3.state_ids())} states")

    new_assignments = {h: partition3.get_state(h) for h in history_order}

    # Converged once a full split+merge pass leaves every assignment unchanged.
    if new_assignments == old_assignments:
        print(f"\nConverged at iteration {i + 1}!")
        break

    # Show which histories moved to a different state during this pass.
    for h in history_order:
        old_s = old_assignments.get(h)
        new_s = new_assignments.get(h)
        if old_s != new_s:
            print(f"    {h}: {old_s} -> {new_s}")
    print()
else:
    # Make a non-converged trace explicit instead of silently falling through.
    print(f"\nDid not converge within {max_iterations} iterations.")

print(f"\nFinal partition: {len(partition3.state_ids())} states")
show_partition(partition3, "Final state assignments")

In [None]:
# Diagnose: why does the empty history () end up in a different state from
# the other "ending in 0" histories?
print("=== Distribution Comparison ===\n")

print("Empty history vs length-1 histories:")
for hist in [(), (0,), (1,)]:
    probs = get_distribution(tree3, hist)
    if not probs:
        continue
    print(f"  {hist}: P(0)={probs.get(0, 0):.4f}, P(1)={probs.get(1, 0):.4f}")

explanation = [
    "\n() reflects the STATIONARY distribution of the process",
    "(0,) reflects the CONDITIONAL distribution given last symbol was 0",
    "\nFor Golden Mean with p=0.5:",
    "  Stationary: ~2/3 zeros, ~1/3 ones (because 1 must be followed by 0)",
    "  After 0: ~50/50 (can emit either)",
    "  After 1: 100% -> 0",
]
for line in explanation:
    print(line)

print("\n\n=== Chi-squared test: () vs (0,) ===")
from emic.inference.cssr.tests import distributions_differ

stats_empty = tree3.get_stats(())
stats_0 = tree3.get_stats((0,))

print(f"(): {dict(stats_empty.next_symbol_counts)}")
print(f"(0,): {dict(stats_0.next_symbol_counts)}")

# Ask the same test CSSR uses whether the two next-symbol distributions differ.
differ = distributions_differ(
    stats_empty.next_symbol_counts,
    stats_0.next_symbol_counts,
    config.significance,
)
print(f"\nDo they differ (α=0.05)? {differ}")

## Root Cause Analysis

The issue is that the **empty history `()`** represents a mixture of causal states (weighted by the stationary distribution), not a single causal state.

For the Golden Mean process:
- `()` has distribution ~66% 0, ~33% 1 (the stationary distribution)
- `(0,)` has distribution ~50% 0, ~50% 1 (conditional on being in "state after 0")
- `(1,)` has distribution 100% 0 (conditional on being in "state after 1")

Since `()` is statistically different from both `(0,)` and `(1,)`, the chi-squared test correctly identifies it as distinct, leading to 3 states instead of 2.

### Solution Options

1. **Exclude empty history**: Don't include `()` in the partition. Only use histories of length ≥ 1.

2. **Assign `()` to the most likely state**: After inferring states from length ≥ 1 histories, assign `()` to the state with highest stationary probability.

3. **Use `()` only for machine building**: Include `()` in the suffix tree for transition calculations but not in state equivalence testing.

The standard CSSR approach is option 1: only partition histories of length 1 to L.

In [None]:
# Test fix: leave the empty history out of the partition entirely
print("=== Testing Fix: Exclude Empty History ===\n")

# Fresh tree, so this trace does not depend on the earlier trace cells.
tree4 = SuffixTree(alphabet=frozenset({0, 1}), max_depth=config.max_history)
tree4.build_from_sequence(data)

# Hand-rolled initialization: every non-empty history seen at least
# config.min_count times starts out in one shared state.
valid_histories = [
    hist
    for hist in tree4.all_histories()
    if hist and tree4.get_stats(hist).count >= config.min_count
]
partition4 = StatePartition()
initial_state = partition4.new_state_id()
for hist in valid_histories:
    partition4.assign(hist, initial_state)

print(f"Initial (length >= 1 only): {len(partition4.state_ids())} states")
print(f"  Histories: {sorted(valid_histories, key=lambda hist: (len(hist), hist))}")


def _snapshot(part):
    """Map each tracked history to its current state id in ``part``."""
    return {hist: part.get_state(hist) for hist in valid_histories}


# Alternate split and merge passes (at most 5) until no history changes state.
for iteration in range(1, 6):
    before = _snapshot(partition4)
    partition4 = cssr._merge_states(cssr._split_states(partition4, tree4), tree4)
    after = _snapshot(partition4)

    print(f"Iter {iteration}: {len(partition4.state_ids())} states")

    if after == before:
        print("Converged!")
        break

print(f"\nFinal: {len(partition4.state_ids())} states (expected: 2)")
show_partition(partition4, "Final partition")

## Step 6: Verify Entropy Rate Calculation

Now that we have 2 states, let's verify the entropy rate is calculated correctly.

In [None]:
# Calculate expected entropy rate for Golden Mean
# For Golden Mean with parameter p (prob of 1 after 0):
# Entropy rate h = H(p) / (1 + p) where H is binary entropy


def binary_entropy(p):
    """Return the binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p), in bits.

    Args:
        p: Probability of one of the two outcomes; must lie in [0, 1].

    Returns:
        H(p) as a float. At the deterministic endpoints p = 0 and p = 1 it is
        0.0, using the convention 0 * log2(0) = 0.

    Raises:
        ValueError: If ``p`` is outside [0, 1] (or NaN).
    """
    # Reject invalid probabilities up front with a clear message, rather than
    # an opaque math domain error (or a silent NaN result).
    if not 0 <= p <= 1:
        raise ValueError(f"p must be in [0, 1], got {p!r}")
    if p == 0 or p == 1:
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)


p = 0.5  # Our Golden Mean parameter
expected_entropy_rate = binary_entropy(p) / (1 + p)
print(f"Expected entropy rate for Golden Mean (p={p}): {expected_entropy_rate:.4f}")

In [None]:
# Manual entropy rate calculation from the inferred machine
# Entropy rate = sum over states: pi_s * H(P(·|s))
# where pi_s is the stationary probability of state s
#
# Expected structure of the inferred machine (state ids vary between runs):
#   A (after 0): emits 0 or 1; a 1 leads to B
#   B (after 1): emits 0 with probability 1.0, returning to A

# State A is the state with an outgoing transition on symbol 1 (the first such
# state found). Use an explicit "not found" sentinel so that a probability of
# exactly 0.0 is not mistaken for a missing transition.
state_A, p_1_given_A = next(
    (
        (state.id, t.probability)
        for state in result.machine.states
        for t in state.transitions
        if t.symbol == 1
    ),
    (None, None),
)
if state_A is None:
    raise RuntimeError(
        "Inferred machine has no transition on symbol 1; "
        "it does not have the Golden Mean structure assumed below."
    )

print("From inferred machine:")
print(f"  State A ({state_A}): P(1) = {p_1_given_A:.4f}")

# Stationary distribution
# Balance: pi_A * P(1|A) = pi_B (since B always goes to A)
# pi_A + pi_B = 1
# => pi_A = 1 / (1 + P(1|A))
pi_A = 1 / (1 + p_1_given_A)
pi_B = p_1_given_A / (1 + p_1_given_A)

print("\nStationary distribution:")
print(f"  pi_A = {pi_A:.4f}")
print(f"  pi_B = {pi_B:.4f}")

# Entropy of each state's next-symbol distribution
H_A = binary_entropy(p_1_given_A)
H_B = 0  # deterministic: B always emits 0

print("\nState entropies:")
print(f"  H(A) = {H_A:.4f}")
print(f"  H(B) = {H_B:.4f}")

# Entropy rate
manual_entropy_rate = pi_A * H_A + pi_B * H_B
print(f"\nManual entropy rate: {manual_entropy_rate:.4f}")
print(f"Expected: {expected_entropy_rate:.4f}")

In [None]:
# Compare the library's analysis with the hand calculation and the theory.
from emic.analysis import analyze

summary = analyze(result.machine)
print("Analysis summary:")
for label, value in (
    ("Entropy rate (from analyze)", summary.entropy_rate),
    ("Manual calculation", manual_entropy_rate),
    ("Expected", expected_entropy_rate),
):
    print(f"  {label}: {value:.4f}")
print(f"  Number of states: {summary.num_states}")
print(f"  Statistical complexity: {summary.statistical_complexity:.4f}")

## Conclusion

### Key Findings

1. **Root Cause of 3-State Bug**: The empty history `()` was being included in the partition. Since `()` reflects the stationary distribution (mixture of causal states), not a specific causal state, it was incorrectly identified as a distinct state.

2. **Fix Applied**: Modified `_initialize_partition()` to exclude the empty history `()` from partitioning. Only histories of length ≥ 1 are used for state equivalence testing.

3. **Result**: After the fix, CSSR correctly infers 2 states for the Golden Mean process:
   - **State A** (histories ending in 0): ~50/50 distribution
   - **State B** (histories ending in 1): 100% → 0 (deterministic)

4. **Entropy Rate**: The inferred machine's entropy rate matches the theoretical value of ~0.667 bits for Golden Mean with p=0.5.

In [None]:
# Final verification: run a fresh CSSR inference with the fix.
#
# importlib.reload() re-executes only the module it is given. Names that parent
# packages re-export (so that `from emic.inference import CSSR` works) still
# refer to the objects created before the reload. Reload the parents as well,
# innermost first, so those names are rebound to the freshly loaded code.
# NOTE(review): assumes emic.inference / emic.inference.cssr re-export CSSR from
# the algorithm module; restarting the kernel remains the most reliable option.
import importlib
import emic.inference
import emic.inference.cssr
import emic.inference.cssr.algorithm

for _module in (emic.inference.cssr.algorithm, emic.inference.cssr, emic.inference):
    importlib.reload(_module)

from emic.sources import GoldenMeanSource
from emic.sources.transforms import TakeN
from emic.inference import CSSR, CSSRConfig
from emic.analysis import analyze

# Generate fresh data and run inference
final_result = GoldenMeanSource(p=0.5, _seed=42) >> TakeN(10000) >> CSSR(CSSRConfig(max_history=3))

print("=== Final Verification ===")
print(f"\nNumber of states: {len(final_result.machine.states)} (expected: 2)")

print("\nMachine structure:")
for state in final_result.machine.states:
    print(f"  {state.id}:")
    for t in state.transitions:
        print(f"    --{t.symbol} (p={t.probability:.4f})--> {t.target}")

final_summary = analyze(final_result.machine)
print(f"\nEntropy rate: {final_summary.entropy_rate:.4f}")
print(f"Expected: {expected_entropy_rate:.4f}")
print(f"Match: {abs(final_summary.entropy_rate - expected_entropy_rate) < 0.05}")

In [None]:
# Chi-squared comparisons using our implementation
from emic.inference.cssr.tests import distributions_differ


def compare_histories(tree, h1, h2, significance=0.05):
    """Run CSSR's chi-squared test on the next-symbol counts of two histories.

    Returns the result of ``distributions_differ`` (True means the
    distributions differ at level ``significance``), or None when either
    history has no stats in ``tree``.
    """
    first, second = tree.get_stats(h1), tree.get_stats(h2)
    if first and second:
        return distributions_differ(
            first.next_symbol_counts, second.next_symbol_counts, significance
        )
    return None


print("=== Chi-Squared Comparisons (True = distributions DIFFER) ===")
comparison_sections = [
    (
        "\nHistories ending in 0 (should all be similar -> False):",
        [((0,), (1, 0)), ((0,), (0, 0)), ((1, 0), (0, 0))],
    ),
    (
        "\nHistories ending in 1 (should all be similar -> False):",
        [((1,), (0, 1))],
    ),
    (
        "\nCross comparisons (should be True - different distributions):",
        [((0,), (1,)), ((0, 0), (0, 1))],
    ),
]
for heading, pairs in comparison_sections:
    print(heading)
    for h1, h2 in pairs:
        # Named result_cmp so the notebook-level `result` (the inferred
        # machine) is not overwritten.
        result_cmp = compare_histories(tree, h1, h2)
        print(f"  {h1} vs {h2}: differ={result_cmp}")

## Detailed Chi-Squared Analysis

Let's examine the chi-squared test with scipy for detailed statistics.

In [None]:
# Detailed chi-squared analysis using scipy
from scipy import stats as scipy_stats


def detailed_chi_squared(tree, h1, h2, significance=0.05):
    """Print a scipy chi-squared test of independence for two histories.

    Builds a (symbol x history) contingency table of next-symbol counts and
    runs ``scipy.stats.chi2_contingency`` on it. By default scipy applies
    Yates' continuity correction when the table has one degree of freedom
    (e.g. a binary alphabet), so p-values can differ slightly from an
    uncorrected chi-squared test.

    Args:
        tree: Suffix tree exposing ``get_stats(history)``.
        h1, h2: The two history tuples to compare.
        significance: Level used for the printed "same distribution" verdict.

    Returns:
        None. Prints nothing when either history has no stats.
    """
    stats1 = tree.get_stats(h1)
    stats2 = tree.get_stats(h2)

    if not stats1 or not stats2:
        return

    counts1 = stats1.next_symbol_counts
    counts2 = stats2.next_symbol_counts
    total1 = sum(counts1.values())
    total2 = sum(counts2.values())

    print(f"\n{h1} vs {h2}:")
    print(f"  {h1}: {dict(counts1)} (total: {total1})")
    print(f"  {h2}: {dict(counts2)} (total: {total2})")

    # Build contingency table: one row per symbol, one column per history.
    symbols = sorted(set(counts1) | set(counts2))
    observed = [[counts1.get(sym, 0), counts2.get(sym, 0)] for sym in symbols]
    print(f"  Contingency table: {observed}")

    # chi2_contingency raises ValueError when any expected frequency is zero.
    # That happens for a symbol never seen after either history (an all-zero
    # row) or for a history with no successors (an all-zero column). Drop
    # all-zero rows; with an empty column there is nothing to test.
    observed = [row for row in observed if any(row)]
    if total1 == 0 or total2 == 0 or not observed:
        print("  Not enough data for a chi-squared test")
        return

    # If only one row remains, both histories are always followed by the same
    # symbol: scipy then reports dof=0, chi2=0, p=1 (identical distributions).
    chi2, p_value, dof, _expected = scipy_stats.chi2_contingency(observed)
    print(f"  Chi2={chi2:.4f}, p-value={p_value:.6f}, dof={dof}")
    print(f"  Same distribution (α={significance})? {p_value > significance}")


print("=== Detailed Chi-Squared Tests ===")
# Same-state pairs: both ending in 0, or both ending in 1.
for first, second in [((0,), (1, 0)), ((0,), (0, 0)), ((1, 0), (0, 0)), ((1,), (0, 1))]:
    detailed_chi_squared(tree, first, second)