# MDP Hyperbolic Embeddings: Hitting Time Conjecture

This notebook explores hyperbolic embeddings for (state, goal) pairs in a 6-state MDP.

**Conjecture:**
- **Norm** (distance from origin) correlates **negatively** with **variance** in hitting times (low variance → high norm, high variance → low norm)
- **Angular coordinate** correlates with **mean** hitting times

## MDP Structure
```
State 1 (start)
    |
    |-- a11 (stochastic) --> 4 (p=0.5) or 5 (p=0.5) --> 6 (goal)
    |
    |-- a12 (deterministic) --> 2 --> (0.9 self-loop, 0.1 to 4) --> 6
    |
    |-- a13 (deterministic) --> 3 --> (0.9 self-loop, 0.1 to 5) --> 6
```
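The notebook uses a uniform random policy throughout: `SimpleMDP.step` picks an action uniformly when `action=None`. Under that policy the MDP reduces to a Markov chain over the six states. Below is a minimal sketch of that chain's transition matrix, written out by hand from the diagram above. `build_uniform_policy_chain` is a hypothetical helper, and indices are 0-based, so index 5 is the goal State 6.

```python
import numpy as np

def build_uniform_policy_chain():
    """Transition matrix of the 6-state MDP under a uniform random policy.

    Hypothetical helper; indices are 0-based (0 = State 1, 5 = State 6 goal).
    """
    P = np.zeros((6, 6))
    # State 1: each of a11, a12, a13 is chosen with probability 1/3.
    P[0, 3] += (1 / 3) * 0.5  # a11 -> State 4
    P[0, 4] += (1 / 3) * 0.5  # a11 -> State 5
    P[0, 1] += 1 / 3          # a12 -> State 2
    P[0, 2] += 1 / 3          # a13 -> State 3
    # States 2 and 3 have a single action with a 0.9 self-loop.
    P[1, 1], P[1, 3] = 0.9, 0.1
    P[2, 2], P[2, 4] = 0.9, 0.1
    # States 4 and 5 move deterministically to the goal; the goal is absorbing.
    P[3, 5] = 1.0
    P[4, 5] = 1.0
    P[5, 5] = 1.0
    return P

P = build_uniform_policy_chain()
print(np.round(P, 3))
```

Row 0 weights each action by 1/3. As a result, the stochastic action a11 contributes only 1/6 to each of States 4 and 5.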

In [None]:
!pip install hypll geoopt

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
from scipy import stats

# Import hypll library
from hypll.manifolds.poincare_ball import PoincareBall, Curvature
from hypll.tensors.manifold_tensor import ManifoldTensor
from hypll.tensors import TangentTensor
import hypll.nn as hnn

# Setup
sns.set_style("whitegrid")
np.random.seed(42)
torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## MDP Definition

In [None]:
class SimpleMDP:
    """
    6-state MDP with the following structure:
    
    States (0-indexed internally, displayed as 1-6; transition targets
    below use the 0-indexed numbering):
        0 (State 1): Start state, 3 actions available
        1 (State 2): Self-loop (0.9) or to state 3 = State 4 (0.1)
        2 (State 3): Self-loop (0.9) or to state 4 = State 5 (0.1)
        3 (State 4): Deterministic to goal (state 5 = State 6)
        4 (State 5): Deterministic to goal (state 5 = State 6)
        5 (State 6): Goal (terminal)
    
    From state 0 (a0, a1, a2 correspond to a11, a12, a13 in the diagram):
        - a0: Stochastic -> state 3 (p=0.5) or state 4 (p=0.5)
        - a1: Deterministic -> state 1
        - a2: Deterministic -> state 2
    """
    
    def __init__(self):
        self.n_states = 6
        self.start_state = 0  # State 1
        self.goal_state = 5   # State 6
        
        # Number of actions per state
        self.n_actions = {
            0: 3,  # State 1: a11, a12, a13
            1: 1,  # State 2: only one action (stochastic outcome)
            2: 1,  # State 3: only one action (stochastic outcome)
            3: 1,  # State 4: deterministic to goal
            4: 1,  # State 5: deterministic to goal
            5: 0,  # State 6: terminal
        }
        
    def get_transitions(self, state, action):
        """
        Get transition probabilities for (state, action) pair.
        Returns list of (next_state, probability) tuples.
        """
        if state == 0:  # State 1 (start)
            if action == 0:  # a11: stochastic
                return [(3, 0.5), (4, 0.5)]  # -> State 4 or 5
            elif action == 1:  # a12: deterministic
                return [(1, 1.0)]  # -> State 2
            elif action == 2:  # a13: deterministic
                return [(2, 1.0)]  # -> State 3
        
        elif state == 1:  # State 2
            return [(1, 0.9), (3, 0.1)]  # Self-loop or -> State 4
        
        elif state == 2:  # State 3
            return [(2, 0.9), (4, 0.1)]  # Self-loop or -> State 5
        
        elif state == 3:  # State 4
            return [(5, 1.0)]  # -> Goal (State 6)
        
        elif state == 4:  # State 5
            return [(5, 1.0)]  # -> Goal (State 6)
        
        elif state == 5:  # State 6 (goal)
            return [(5, 1.0)]  # Stay at goal
        
        return []
    
    def step(self, state, action=None):
        """
        Take a step from state with given action.
        If action is None, sample uniformly from available actions.
        Returns next_state.
        """
        if state == self.goal_state:
            return state
        
        if action is None:
            n_actions = self.n_actions[state]
            action = np.random.randint(0, max(1, n_actions))
        
        transitions = self.get_transitions(state, action)
        
        if len(transitions) == 1:
            return transitions[0][0]
        
        # Sample according to probabilities
        probs = [t[1] for t in transitions]
        next_states = [t[0] for t in transitions]
        return np.random.choice(next_states, p=probs)
    
    def state_name(self, state):
        """Return human-readable state name."""
        return f"S{state + 1}"


# Test the MDP
mdp = SimpleMDP()
print("MDP Structure Test:")
print(f"Start state: {mdp.state_name(mdp.start_state)}")
print(f"Goal state: {mdp.state_name(mdp.goal_state)}")
print()
for state in range(mdp.n_states):
    print(f"{mdp.state_name(state)}: {mdp.n_actions[state]} actions")
    for action in range(mdp.n_actions[state]):
        transitions = mdp.get_transitions(state, action)
        trans_str = ", ".join([f"{mdp.state_name(s)} (p={p})" for s, p in transitions])
        print(f"  a{action}: {trans_str}")

## Trajectory Generation

In [None]:
def generate_mdp_trajectories(mdp, n_trajectories, max_length=1000, seed=42):
    """
    Generate trajectories from start to goal under uniform random policy.
    
    Args:
        mdp: SimpleMDP instance
        n_trajectories: Number of trajectories to generate
        max_length: Maximum trajectory length
        seed: Random seed
    
    Returns:
        List of state sequences (lists of state indices)
    """
    np.random.seed(seed)
    trajectories = []
    
    for _ in range(n_trajectories):
        traj = [mdp.start_state]
        state = mdp.start_state
        
        for _ in range(max_length):
            if state == mdp.goal_state:
                break
            
            # Uniform random action selection
            state = mdp.step(state, action=None)
            traj.append(state)
        
        trajectories.append(traj)
    
    return trajectories


# Generate trajectories
trajectories = generate_mdp_trajectories(mdp, n_trajectories=5000, max_length=1000, seed=42)

# Analyze trajectories
lengths = [len(t) for t in trajectories]
print(f"Generated {len(trajectories)} trajectories")
print(f"Length stats: mean={np.mean(lengths):.1f}, std={np.std(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

# Show sample trajectories
print("\nSample trajectories:")
for i in range(5):
    traj_str = " -> ".join([mdp.state_name(s) for s in trajectories[i][:15]])
    if len(trajectories[i]) > 15:
        traj_str += f" ... ({len(trajectories[i])} states total)"
    print(f"  {i+1}: {traj_str}")
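The length statistics printed above can be checked by hand. Each trajectory holds the start state plus one state per step, so its length is the hitting time T of S6 from S1, plus one. The sketch below conditions on the first action and applies the law of total expectation, using exact fractions. The branch decomposition is worked out from the diagram and is not code from the notebook.

```python
import math
from fractions import Fraction

third = Fraction(1, 3)                   # each action at S1 has probability 1/3
p_leave = Fraction(1, 10)                # per-step exit probability from S2 or S3
geo_mean = 1 / p_leave                   # E[G] = 1/p = 10
geo_var = (1 - p_leave) / p_leave ** 2   # Var[G] = (1 - p)/p^2 = 90

# a11: S1 -> S4 or S5 -> S6 takes exactly T = 2 steps.
# a12/a13: 1 step into S2/S3, G steps until the 0.1 exit fires, 1 step to S6: T = G + 2.
loop_mean = geo_mean + 2
branch_means = [Fraction(2), loop_mean, loop_mean]
branch_second_moments = [Fraction(4), geo_var + loop_mean ** 2, geo_var + loop_mean ** 2]

mean_T = sum(third * m for m in branch_means)             # law of total expectation
second_T = sum(third * s for s in branch_second_moments)  # E[T^2] by the same rule
var_T = second_T - mean_T ** 2

mean_len = mean_T + 1          # trajectories also contain the start state
std_len = math.sqrt(var_T)     # shifting by 1 does not change the spread
print(f"mean length = {mean_len} (~{float(mean_len):.2f}), std ~ {std_len:.2f}")
```

The exact mean length is 29/3 ≈ 9.67 and the standard deviation is √740/3 ≈ 9.07. The cap `max_length=1000` is practically never reached, because staying in a 0.9 self-loop for that long has negligible probability.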

## Contrastive Dataset

In [None]:
class MDPContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on MDP trajectory intervals.
    
    For each sample:
    - Anchor: [start_state, goal_state] interval from trajectory
    - Positive: Subinterval [k, l] where anchor_i <= k <= l <= anchor_j
    - Negatives: Intervals that are NOT subintervals of anchor
    
    States are normalized to [0, 1] by dividing by (n_states - 1).
    """
    
    def __init__(self, trajectories, n_states=6, num_samples=10000, n_negatives=5, seed=42):
        self.trajectories = trajectories
        self.n_states = n_states
        self.num_samples = num_samples
        self.n_negatives = n_negatives
        
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        # Filter valid trajectories (need at least 2 states for intervals)
        self.valid_traj_indices = [i for i, t in enumerate(trajectories) if len(t) >= 2]
        
        # Pre-generate all samples
        self.anchors, self.positives, self.negatives = self._generate_all_samples()
    
    def _normalize_state(self, state):
        """Normalize state index to [0, 1]."""
        return state / (self.n_states - 1)
    
    def _generate_all_samples(self):
        anchors = []
        positives = []
        negatives_list = []
        
        for _ in range(self.num_samples):
            anchor, positive, negs = self._generate_single_sample()
            anchors.append(anchor)
            positives.append(positive)
            negatives_list.append(negs)
        
        return (
            torch.tensor(anchors, dtype=torch.float32),
            torch.tensor(positives, dtype=torch.float32),
            torch.tensor(negatives_list, dtype=torch.float32),
        )
    
    def _generate_single_sample(self):
        """Generate a single (anchor, positive, negatives) tuple."""
        # Sample trajectory
        traj_idx = np.random.choice(self.valid_traj_indices)
        traj = self.trajectories[traj_idx]
        T = len(traj) - 1
        
        # Sample anchor interval [i, j] where i <= j
        j = np.random.randint(0, T + 1)
        i = np.random.randint(0, j + 1)
        
        anchor = [self._normalize_state(traj[i]), self._normalize_state(traj[j])]
        
        # Sample positive (subinterval): i <= k <= l <= j
        l = np.random.randint(i, j + 1)
        k = np.random.randint(i, l + 1)
        
        positive = [self._normalize_state(traj[k]), self._normalize_state(traj[l])]
        
        # Sample negatives (non-subintervals)
        negatives = []
        for _ in range(self.n_negatives):
            neg = self._sample_negative(traj, i, j, T)
            negatives.append(neg)
        
        return anchor, positive, negatives
    
    def _sample_negative(self, traj, anchor_i, anchor_j, T, max_attempts=1000):
        """Sample an interval that is NOT a subinterval of anchor."""
        for _ in range(max_attempts):
            l = np.random.randint(0, T + 1)
            k = np.random.randint(0, l + 1)
            
            # Check if NOT a subinterval (temporal containment)
            is_subinterval = (anchor_i <= k) and (l <= anchor_j)
            
            if not is_subinterval:
                return [self._normalize_state(traj[k]), self._normalize_state(traj[l])]
        
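        # Fallback after max_attempts failures: a single-state interval outside
        # the anchor. Caveat: if the anchor spans the whole trajectory
        # (anchor_i == 0 and anchor_j == T), every interval is a subinterval,
        # so the [T, T] interval returned below is not a true negative.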
        if anchor_i > 0:
            return [self._normalize_state(traj[0]), self._normalize_state(traj[0])]
        else:
            return [self._normalize_state(traj[T]), self._normalize_state(traj[T])]
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.anchors[idx], self.positives[idx], self.negatives[idx]


# Create dataset
dataset = MDPContrastiveDataset(
    trajectories=trajectories,
    n_states=6,
    num_samples=10000,
    n_negatives=5,
    seed=42
)

# Check shapes
anchor, positive, negatives = dataset[0]
print(f"Anchor shape: {anchor.shape}")
print(f"Positive shape: {positive.shape}")
print(f"Negatives shape: {negatives.shape}")
print(f"\nSample anchor: {anchor}")
print(f"Sample positive: {positive}")
print(f"Sample negatives[0]: {negatives[0]}")
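Positives are subintervals of the anchor, and negatives are every other interval on the same trajectory. How many negatives are available therefore depends on how much of the trajectory the anchor covers. The sketch below counts both kinds, using the same index convention as `_generate_single_sample`: intervals are pairs [k, l] with 0 <= k <= l <= T. `count_intervals` is a hypothetical helper. When the anchor spans the whole trajectory, no temporal negative exists. In that case `_sample_negative` exhausts its attempts and falls back to `[traj[T], traj[T]]`.

```python
def count_intervals(T, anchor_i, anchor_j):
    """Count sub- and non-subintervals of [anchor_i, anchor_j] on a trajectory
    with time indices 0..T (intervals are pairs k <= l)."""
    sub, non_sub = 0, 0
    for l in range(T + 1):
        for k in range(l + 1):
            if anchor_i <= k and l <= anchor_j:
                sub += 1
            else:
                non_sub += 1
    return sub, non_sub

# The shortest trajectory S1 -> S4 -> S6 has T = 2 (three time steps).
print(count_intervals(2, 0, 1))  # (3, 3)
print(count_intervals(2, 0, 2))  # (6, 0): the anchor covers everything
```

In closed form, an anchor of width w = j - i has (w+1)(w+2)/2 subintervals out of (T+1)(T+2)/2 intervals in total. Short trajectories, such as the a11 branch, are therefore the ones most likely to produce anchors with no valid negative.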

## Hyperbolic Encoder

In [None]:
def manifold_map(x, manifold):
    """Map Euclidean tensor to hyperbolic manifold via exponential map."""
    tangents = TangentTensor(x, man_dim=-1, manifold=manifold)
    return manifold.expmap(tangents)


class HyperbolicIntervalEncoder(nn.Module):
    """
    Encode (state, goal) pairs to Poincare ball.
    Architecture: 2 Euclidean layers + 2 Hyperbolic layers
    """
    
    def __init__(self, embedding_dim=2, c=1.0, euc_width=128, hyp_width=128):
        super().__init__()
        
        # Create manifold
        curvature = Curvature(value=c, requires_grad=False)
        self.manifold = PoincareBall(c=curvature)
        
        # Euclidean layers
        self.euc_layer1 = nn.Linear(2, euc_width)
        self.euc_layer2 = nn.Linear(euc_width, hyp_width)
        self.euc_relu = nn.ReLU()
        
        # Hyperbolic layers
        self.hyp_layer1 = hnn.HLinear(
            in_features=hyp_width,
            out_features=hyp_width,
            manifold=self.manifold
        )
        self.hyp_layer2 = hnn.HLinear(
            in_features=hyp_width,
            out_features=embedding_dim,
            manifold=self.manifold
        )
        self.hyp_relu = hnn.HReLU(manifold=self.manifold)
    
    def forward(self, x):
        # Euclidean part
        x = self.euc_relu(self.euc_layer1(x))
        x = self.euc_layer2(x)
        
        # Map to hyperbolic space
        x = manifold_map(x, self.manifold)
        
        # Hyperbolic part
        x = self.hyp_relu(self.hyp_layer1(x))
        x = self.hyp_layer2(x)
        
        return x


# Test encoder
model = HyperbolicIntervalEncoder(embedding_dim=2, c=1.0).to(device)
test_input = torch.tensor([[0.0, 1.0], [0.2, 0.8]], dtype=torch.float32).to(device)
test_output = model(test_input)
print(f"Test input shape: {test_input.shape}")
print(f"Test output type: {type(test_output)}")
if isinstance(test_output, ManifoldTensor):
    print(f"Test output tensor shape: {test_output.tensor.shape}")
    print(f"Test output: {test_output.tensor}")
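The encoder outputs points on the Poincaré ball, and the loss compares them with `manifold.dist`. For reference, here is a NumPy sketch of the standard geodesic distance on a ball of curvature $-K$ ($K > 0$):
$d_K(x,y) = \frac{1}{\sqrt{K}}\operatorname{arcosh}\left(1 + \frac{2K\|x-y\|^2}{(1-K\|x\|^2)(1-K\|y\|^2)}\right)$.
Setting $x = 0$ gives $d_K(0,z) = \frac{2}{\sqrt{K}}\operatorname{artanh}(\sqrt{K}\|z\|)$, the formula that `hyperbolic_norm` uses later. Whether hypll's `Curvature(value=c)` corresponds exactly to $K = c$ is an assumption here; check the hypll docs.

```python
import numpy as np

def poincare_dist(x, y, K=1.0):
    """Geodesic distance on the Poincare ball of curvature -K (sketch)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    sq = np.sum((x - y) ** 2)
    denom = (1 - K * np.sum(x ** 2)) * (1 - K * np.sum(y ** 2))
    return np.arccosh(1 + 2 * K * sq / denom) / np.sqrt(K)

def dist_from_origin(z, K=1.0):
    """Closed form d_K(0, z) = (2 / sqrt(K)) * artanh(sqrt(K) * ||z||)."""
    return 2 / np.sqrt(K) * np.arctanh(np.sqrt(K) * np.linalg.norm(z))

z = np.array([0.3, -0.4])  # Euclidean norm 0.5
# With K = 1 both expressions equal 2 * artanh(0.5) = ln(3).
print(poincare_dist(np.zeros(2), z), dist_from_origin(z), np.log(3))
```

$d_K(0,z)$ is strictly increasing in $\|z\|$. Ranking embeddings by hyperbolic norm or by Euclidean norm therefore gives the same order.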

## Loss Function

In [None]:
def info_nce_loss(anchor, positive, negatives, manifold, temperature=0.5):
    """InfoNCE loss using hyperbolic distance."""
    batch_size = anchor.shape[0] if not isinstance(anchor, ManifoldTensor) else anchor.tensor.shape[0]
    num_neg = negatives.shape[1] if not isinstance(negatives, ManifoldTensor) else negatives.tensor.shape[1]
    
    # Compute positive distance
    pos_dist = manifold.dist(x=anchor, y=positive)
    
    # Expand anchor for negative comparisons
    if isinstance(anchor, ManifoldTensor):
        anchor_tensor = anchor.tensor.unsqueeze(1).expand(-1, num_neg, -1)
        anchor_expanded = ManifoldTensor(anchor_tensor, manifold=manifold)
    else:
        anchor_expanded = anchor.unsqueeze(1).expand(-1, num_neg, -1)
    
    neg_dist = manifold.dist(x=anchor_expanded, y=negatives)
    
    # Extract tensors
    if isinstance(pos_dist, ManifoldTensor):
        pos_dist = pos_dist.tensor
    if isinstance(neg_dist, ManifoldTensor):
        neg_dist = neg_dist.tensor
    
    # Reshape
    if pos_dist.dim() > 1:
        pos_dist = pos_dist.squeeze(-1)
    if neg_dist.dim() > 2:
        neg_dist = neg_dist.squeeze(-1)
    
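    # Check for NaN: return a NaN loss so that train_model's NaN/Inf guard skips
    # the batch, instead of averaging a constant zero into the epoch loss.
    if torch.isnan(pos_dist).any() or torch.isnan(neg_dist).any():
        print("WARNING: NaN in distances")
        return torch.tensor(float('nan'), device=pos_dist.device)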
    
    # Convert distances to similarities
    pos_sim = -pos_dist / temperature
    neg_sim = -neg_dist / temperature
    
    # Combine for cross-entropy
    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
    labels = torch.zeros(batch_size, dtype=torch.long, device=logits.device)
    
    loss = nn.functional.cross_entropy(logits, labels)
    return loss
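Without the manifold bookkeeping, `info_nce_loss` is a cross-entropy over logits −d/τ, with the positive in column 0. The NumPy sketch below performs the same computation on raw distances. The helper name `info_nce_from_distances` and the example distances are illustrative; they do not come from the trained model.

```python
import numpy as np

def info_nce_from_distances(pos_dist, neg_dist, temperature=0.1):
    """Mean InfoNCE loss from distances.

    pos_dist: shape (B,); neg_dist: shape (B, N).
    Logits are -distance / temperature with the positive at index 0, matching
    cross_entropy(logits, labels=0) in info_nce_loss.
    """
    pos_dist = np.asarray(pos_dist, dtype=float)
    neg_dist = np.asarray(neg_dist, dtype=float)
    logits = np.concatenate([-pos_dist[:, None], -neg_dist], axis=1) / temperature
    # Numerically stable log-sum-exp; loss = -log softmax(logits)[0].
    m = logits.max(axis=1, keepdims=True)
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return float(np.mean(log_z - logits[:, 0]))

# If the positive and all 5 negatives are equally far, the loss is log(6).
print(info_nce_from_distances([1.0], [[1.0] * 5]))
# A positive much closer than the negatives drives the loss toward 0.
print(info_nce_from_distances([0.1], [[2.0] * 5]))
```

A smaller temperature scales up the distance gaps in the logits. With `temperature=0.1`, as used in `train_model`, the loss therefore saturates quickly once positives are even slightly closer than negatives.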

## Training

In [None]:
def train_model(model, dataset, num_epochs=200, batch_size=32, lr=0.001, temperature=0.1):
    """Train the hyperbolic encoder."""
    from hypll.optim import RiemannianAdam
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    optimizer = RiemannianAdam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
    
    losses = []
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        
        for anchor, positive, negatives in dataloader:
            anchor = anchor.to(device)
            positive = positive.to(device)
            negatives = negatives.to(device)
            
            # Forward pass
            anchor_emb = model(anchor)
            positive_emb = model(positive)
            
            bs, num_neg, _ = negatives.shape
            negatives_emb = model(negatives.view(-1, 2))
            
            # Reshape negatives
            if isinstance(negatives_emb, ManifoldTensor):
                neg_tensor = negatives_emb.tensor.view(bs, num_neg, -1)
                negatives_emb = ManifoldTensor(neg_tensor, manifold=model.manifold)
            else:
                negatives_emb = negatives_emb.view(bs, num_neg, -1)
            
            # Loss
            loss = info_nce_loss(anchor_emb, positive_emb, negatives_emb,
                                model.manifold, temperature=temperature)
            
            if torch.isnan(loss) or torch.isinf(loss):
                print("Skipping batch due to NaN/Inf loss")
                continue
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
        
        if num_batches > 0:
            avg_loss = epoch_loss / num_batches
            losses.append(avg_loss)
            scheduler.step()
            
            if (epoch + 1) % 20 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
    return losses

In [None]:
# Train the model
model = HyperbolicIntervalEncoder(embedding_dim=2, c=1.0, euc_width=128, hyp_width=128).to(device)
losses = train_model(model, dataset, num_epochs=200, batch_size=32, lr=0.001, temperature=0.1)

# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

## Hitting Time Analysis

In [None]:
def compute_hitting_times(mdp, start_state, goal_state=5, n_simulations=10000, max_steps=10000):
    """
    Monte Carlo estimation of hitting times from start_state to goal_state
    under uniform random policy.
    
    Returns (mean, variance, std) of hitting times.
    """
    hitting_times = []
    
    for _ in range(n_simulations):
        state = start_state
        steps = 0
        
        while state != goal_state and steps < max_steps:
            state = mdp.step(state, action=None)
            steps += 1
        
        if state == goal_state:
            hitting_times.append(steps)
    
    if len(hitting_times) == 0:
        return float('inf'), float('inf'), float('inf')
    
    return np.mean(hitting_times), np.var(hitting_times), np.std(hitting_times)


def compute_all_hitting_times(mdp, n_simulations=10000):
    """
    Compute hitting-time statistics from every non-goal start state to mdp.goal_state.
    
    Returns dict: {start_state: {'mean': ..., 'var': ..., 'std': ...}}
    """
    hitting_stats = {}
    goal = mdp.goal_state
    
    for start in range(mdp.n_states):
        if start != goal:
            mean, var, std = compute_hitting_times(mdp, start, goal, n_simulations)
            hitting_stats[start] = {'mean': mean, 'var': var, 'std': std}
            print(f"{mdp.state_name(start)} -> {mdp.state_name(goal)}: mean={mean:.2f}, var={var:.2f}, std={std:.2f}")
    
    return hitting_stats


# Compute hitting times
print("Computing hitting time statistics...")
print("="*60)
hitting_stats = compute_all_hitting_times(mdp, n_simulations=50000)
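The Monte Carlo estimates can be checked against exact values. Let Q be the non-goal block of the transition matrix induced by the uniform random policy. First-step analysis then shows that the mean m and second moment s of the hitting time satisfy (I − Q) m = 1 and (I − Q) s = 2m − 1, so Var = s − m². The sketch below writes the matrix out by hand to mirror `SimpleMDP` (0-indexed, goal = 5). `exact_hitting_moments` is a hypothetical helper.

```python
import numpy as np

def exact_hitting_moments(P, goal):
    """Exact mean and variance of the hitting time of `goal` from every other state.

    First-step analysis: (I - Q) m = 1 and (I - Q) s = 2m - 1,
    where Q restricts P to the non-goal states and s = E[T^2].
    """
    transient = [i for i in range(P.shape[0]) if i != goal]
    Q = P[np.ix_(transient, transient)]
    A = np.eye(len(transient)) - Q
    m = np.linalg.solve(A, np.ones(len(transient)))
    s = np.linalg.solve(A, 2 * m - 1)
    return {state: {"mean": m[k], "var": s[k] - m[k] ** 2}
            for k, state in enumerate(transient)}

# Uniform-random-policy transition matrix of the 6-state MDP (0-indexed).
P = np.zeros((6, 6))
P[0, [1, 2]] = 1 / 3           # a12 -> State 2, a13 -> State 3
P[0, [3, 4]] = (1 / 3) * 0.5   # a11 -> State 4 or State 5
P[1, [1, 3]] = [0.9, 0.1]
P[2, [2, 4]] = [0.9, 0.1]
P[3, 5] = P[4, 5] = P[5, 5] = 1.0

exact = exact_hitting_moments(P, goal=5)
for state, st in exact.items():
    print(f"S{state + 1} -> S6: mean={st['mean']:.2f}, var={st['var']:.2f}")
```

Expected values:
- From S1: mean 26/3 ≈ 8.67 and variance 740/9 ≈ 82.2.
- From S2 and S3: mean 11 and variance 90.
- From S4 and S5: mean 1 and variance 0.

The simulated values above should land close to these. S2/S3 and S4/S5 tie exactly, which matters for rank-based tests over only five pairs.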

## Visualization on Poincare Ball

In [None]:
# =============================================================================
# Utility Functions
# =============================================================================

def pair_label(start, goal):
    """
    Create label for (start, goal) pair.
    - Atomic pairs (s, s) displayed as just the number: "3"
    - Non-atomic pairs displayed as: "(1,6)"
    """
    # Convert from 0-indexed to 1-indexed for display
    s, g = start + 1, goal + 1
    if start == goal:
        return f"{s}"
    return f"({s},{g})"


def plot_poincare_disk(ax, title=""):
    """Draw Poincare disk boundary and origin."""
    circle = plt.Circle((0, 0), 1, color='black', fill=False, linewidth=2, linestyle='--')
    ax.add_patch(circle)
    ax.scatter([0], [0], s=60, c='black', marker='+', linewidth=2, zorder=10)
    ax.set_xlim(-1.15, 1.15)
    ax.set_ylim(-1.15, 1.15)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.2)
    if title:
        ax.set_title(title, fontsize=13)


def hyperbolic_norm(z, curvature=1.0):
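    """
    Hyperbolic distance from the origin to z on a Poincare ball of curvature -K:
    d(0, z) = (2 / sqrt(K)) * artanh(sqrt(K) * ||z||).

    `curvature` should match the manifold the model was trained on. Because the
    map is monotone in ||z||, rank (Spearman) correlations do not depend on it.
    """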
    K = curvature
    z_norm = np.linalg.norm(z)
    z_norm = np.clip(z_norm, 0, (1 / np.sqrt(K)) - 1e-10)
    return (2 / np.sqrt(K)) * np.arctanh(np.sqrt(K) * z_norm)


def get_embedding(model, start, goal, n_states=6):
    """
    Get the embedding for a specific (start, goal) pair.
    
    Args:
        model: Trained HyperbolicIntervalEncoder
        start: Start state (0-indexed)
        goal: Goal state (0-indexed)
        n_states: Total number of states
    
    Returns:
        numpy array of shape (2,) - the embedding coordinates
    """
    model.eval()
    with torch.no_grad():
        s_norm = start / (n_states - 1)
        g_norm = goal / (n_states - 1)
        x = torch.tensor([[s_norm, g_norm]], dtype=torch.float32).to(device)
        emb = model(x)
        if isinstance(emb, ManifoldTensor):
            emb = emb.tensor
        return emb.squeeze(0).cpu().numpy()


def get_all_embeddings(model, n_states=6, include_atomic=True):
    """
    Get embeddings for all (state, goal) pairs.
    
    Args:
        model: Trained HyperbolicIntervalEncoder
        n_states: Total number of states
        include_atomic: Whether to include atomic (s, s) pairs
    
    Returns:
        Dict mapping (start, goal) -> embedding array
    """
    model.eval()
    embeddings = {}
    
    with torch.no_grad():
        for start in range(n_states):
            for goal in range(n_states):
                if include_atomic or start != goal:
                    emb = get_embedding(model, start, goal, n_states)
                    embeddings[(start, goal)] = emb
    
    return embeddings


def get_state_embeddings(model, mdp):
    """
    Get embeddings for all (state, final_goal) pairs.
    This is for backwards compatibility.
    
    Returns dict: {start_state: embedding_array}
    """
    model.eval()
    embeddings = {}
    goal = mdp.goal_state
    n_states = mdp.n_states
    
    with torch.no_grad():
        for state in range(n_states):
            if state != goal:
                emb = get_embedding(model, state, goal, n_states)
                embeddings[state] = emb
    
    return embeddings


# Get embeddings for all (state, goal=6) pairs
embeddings = get_state_embeddings(model, mdp)

# Print embeddings
print("State embeddings (state -> goal=6):")
for state, emb in embeddings.items():
    norm = hyperbolic_norm(emb)
    angle = np.degrees(np.arctan2(emb[1], emb[0]))
    label = pair_label(state, mdp.goal_state)
    print(f"  {label}: ({emb[0]:.4f}, {emb[1]:.4f}), norm={norm:.4f}, angle={angle:.1f}deg")
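`np.arctan2` returns angles in (−180°, 180°]. The training loss depends only on distances, so the overall orientation of the embedding is arbitrary. As a consequence, a linear correlation between raw angles and mean hitting times can change if the whole embedding is rotated. One rotation-invariant alternative is the circular-linear correlation coefficient: the multiple correlation of x regressed on (cos θ, sin θ). It is swapped in here for illustration and is not what `test_conjecture` computes. A minimal sketch; the angles below are made up, not model output.

```python
import numpy as np

def circular_linear_corr(x, theta):
    """Circular-linear correlation between a linear variable x and angles theta (radians).

    Equals the multiple correlation of x on cos(theta) and sin(theta), so it lies
    in [0, 1] and is unchanged when every angle is shifted by the same constant.
    """
    x = np.asarray(x, dtype=float)
    theta = np.asarray(theta, dtype=float)
    c, s = np.cos(theta), np.sin(theta)
    r_xc = np.corrcoef(x, c)[0, 1]
    r_xs = np.corrcoef(x, s)[0, 1]
    r_cs = np.corrcoef(c, s)[0, 1]
    r2 = (r_xc**2 + r_xs**2 - 2 * r_xc * r_xs * r_cs) / (1 - r_cs**2)
    return float(np.sqrt(np.clip(r2, 0.0, 1.0)))

# Illustrative angles only: rotating every angle by 2 rad (and re-wrapping)
# changes the raw-angle Pearson correlation but not the circular-linear one.
means = np.array([26 / 3, 11.0, 11.0, 1.0, 1.0])
theta = np.array([0.4, 2.6, 2.9, -2.8, -1.0])
rotated = np.angle(np.exp(1j * (theta + 2.0)))  # wrap back into (-pi, pi]
print(np.corrcoef(theta, means)[0, 1], np.corrcoef(rotated, means)[0, 1])
print(circular_linear_corr(means, theta), circular_linear_corr(means, rotated))
```

The coefficient is unsigned, so it measures how strongly angle tracks mean, not in which rotational direction.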

In [None]:
def plot_mdp_embeddings(embeddings, hitting_stats, mdp):
    """
    Visualize embeddings on the Poincare disk, colored by hitting time statistics.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    states = sorted(embeddings.keys())
    coords = np.array([embeddings[s] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    
    # Plot 1: Colored by mean hitting time
    ax = axes[0]
    plot_poincare_disk(ax, "Colored by Mean Hitting Time")
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=means, cmap='viridis',
                        s=120, edgecolors='black', linewidth=1.5, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time')
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=12, fontweight='bold',
                   xytext=(6, 6), textcoords='offset points')
    
    # Plot 2: Colored by variance
    ax = axes[1]
    plot_poincare_disk(ax, "Colored by Variance of Hitting Time")
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma',
                        s=120, edgecolors='black', linewidth=1.5, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Variance')
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=12, fontweight='bold',
                   xytext=(6, 6), textcoords='offset points')
    
    # Plot 3: Labeled with all info
    ax = axes[2]
    plot_poincare_disk(ax, "All State Embeddings")
    
    colors = plt.cm.tab10(np.linspace(0, 1, len(states)))
    for i, s in enumerate(states):
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.scatter([emb[0]], [emb[1]], c=[colors[i]], s=120,
                  edgecolors='black', linewidth=1.5, zorder=5,
                  label=f"{label}: T={hitting_stats[s]['mean']:.1f}")
        ax.annotate(label, (emb[0], emb[1]), fontsize=12, fontweight='bold',
                   xytext=(6, 6), textcoords='offset points')
    
    ax.legend(loc='upper right', fontsize=10)
    
    plt.tight_layout()
    plt.savefig('mdp_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_mdp_embeddings(embeddings, hitting_stats, mdp)

In [None]:
def test_conjecture(embeddings, hitting_stats, mdp):
    """
    Test the conjecture:
    - Norm (distance from origin) correlates NEGATIVELY with variance in hitting times
      (low variance -> high norm, high variance -> low norm)
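    - Angular coordinate correlates with mean hitting times

    The |rho| > 0.3 thresholds are heuristic. There are only five (start, goal)
    pairs, and S2/S3 and S4/S5 share the same true hitting-time distribution.
    Spearman p-values are therefore rough, and raw arctan2 angles wrap at
    +/-180 degrees.
    """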
    states = sorted(embeddings.keys())
    
    # Extract data
    coords = np.array([embeddings[s] for s in states])
    norms = np.array([hyperbolic_norm(embeddings[s]) for s in states])
    angles = np.array([np.arctan2(embeddings[s][1], embeddings[s][0]) for s in states])
    
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    stds = np.array([hitting_stats[s]['std'] for s in states])
    
    # Compute correlations
    corr_norm_var, p_norm_var = stats.spearmanr(norms, vars_)
    corr_norm_std, p_norm_std = stats.spearmanr(norms, stds)
    corr_angle_mean, p_angle_mean = stats.spearmanr(angles, means)
    
    # Also compute Pearson
    pearson_norm_var = np.corrcoef(norms, vars_)[0, 1]
    pearson_angle_mean = np.corrcoef(angles, means)[0, 1]
    
    print("="*70)
    print("CONJECTURE TEST RESULTS")
    print("="*70)
    
    print("\n1. NORM vs VARIANCE Conjecture:")
    print(f"   Spearman correlation: {corr_norm_var:.4f} (p={p_norm_var:.4f})")
    print(f"   Pearson correlation:  {pearson_norm_var:.4f}")
    print(f"   Expected: NEGATIVE (low variance -> high norm)")
    norm_var_supported = corr_norm_var < -0.3
    print(f"   Result: {'SUPPORTED' if norm_var_supported else 'NOT SUPPORTED'}")
    
    print("\n2. ANGLE vs MEAN Conjecture:")
    print(f"   Spearman correlation: {corr_angle_mean:.4f} (p={p_angle_mean:.4f})")
    print(f"   Pearson correlation:  {pearson_angle_mean:.4f}")
    print("   Expected: |rho| > 0.3 (similar means -> similar angles)")
    angle_mean_supported = abs(corr_angle_mean) > 0.3
    print(f"   Result: {'SUPPORTED' if angle_mean_supported else 'NOT SUPPORTED'}")
    
    print("\n" + "="*70)
    print("STATE-BY-STATE ANALYSIS")
    print("="*70)
    print(f"{'Pair':<10} {'Norm':<10} {'Angle(deg)':<12} {'Mean T':<10} {'Var T':<10} {'Std T':<10}")
    print("-"*70)
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        print(f"{label:<10} {norms[i]:<10.4f} {np.degrees(angles[i]):<12.1f} "
              f"{means[i]:<10.2f} {vars_[i]:<10.2f} {stds[i]:<10.2f}")
    
    return {
        'corr_norm_var': corr_norm_var,
        'corr_angle_mean': corr_angle_mean,
        'p_norm_var': p_norm_var,
        'p_angle_mean': p_angle_mean,
        'norms': norms,
        'angles': angles,
        'means': means,
        'vars': vars_,
        'states': states,
        'norm_var_supported': norm_var_supported,
        'angle_mean_supported': angle_mean_supported
    }


results = test_conjecture(embeddings, hitting_stats, mdp)
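Only five (start, goal) pairs are tested, and some of them tie. SciPy's documentation cautions that the p-value from `stats.spearmanr` is accurate only for large samples. Since 5! = 120, the exact permutation distribution can be enumerated directly. A sketch follows; `spearman_permutation_test` is a hypothetical helper. It is two-sided and counts the permutations whose |ρ| is at least the observed |ρ|.

```python
import itertools
import numpy as np
from scipy import stats

def spearman_permutation_test(x, y, tol=1e-12):
    """Exact two-sided permutation p-value for Spearman's rho (small n only).

    Enumerates all n! reorderings of y, so keep n around 8 or below.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    observed = stats.spearmanr(x, y)[0]
    extreme = 0
    total = 0
    for perm in itertools.permutations(range(len(y))):
        rho = stats.spearmanr(x, y[list(perm)])[0]
        if abs(rho) >= abs(observed) - tol:  # tolerance guards float round-off
            extreme += 1
        total += 1
    return observed, extreme / total

# Perfectly monotone data with n = 5: only the identity and the reversal reach
# |rho| = 1, so the exact p-value is 2 / 120.
rho, p = spearman_permutation_test([1, 2, 3, 4, 5], [2, 4, 6, 8, 10])
print(rho, p)
```

In this notebook it could be applied as `spearman_permutation_test(results['norms'], results['vars'])`. Note that with n = 5, even a perfect rank correlation cannot reach a two-sided p below 1/60.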

In [None]:
def plot_summary(embeddings, results, hitting_stats, mdp):
    """
    Create a comprehensive summary plot with Poincare disk, scatter plots, and statistics.
    """
    fig = plt.figure(figsize=(16, 14))
    
    # Create grid spec for layout
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.25)
    
    states = results['states']
    norms = results['norms']
    angles = results['angles']
    means = results['means']
    vars_ = results['vars']
    coords = np.array([embeddings[s] for s in states])
    
    # =========================================================================
    # Plot 1: Poincare Ball (top-left)
    # =========================================================================
    ax1 = fig.add_subplot(gs[0, 0])
    
    # Draw disk
    circle = plt.Circle((0, 0), 1, color='black', fill=False, linewidth=2, linestyle='--')
    ax1.add_patch(circle)
    ax1.scatter([0], [0], s=60, c='black', marker='+', linewidth=2, zorder=10)
    
    # Color by variance (inverse - low var = brighter)
    scatter = ax1.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma_r',
                         s=100, edgecolors='black', linewidth=1.5, zorder=5)
    plt.colorbar(scatter, ax=ax1, label='Variance', shrink=0.8)
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax1.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                    xytext=(5, 5), textcoords='offset points')
    
    ax1.set_xlim(-1.15, 1.15)
    ax1.set_ylim(-1.15, 1.15)
    ax1.set_aspect('equal')
    ax1.grid(True, alpha=0.2)
    ax1.set_title('Poincaré Ball Embeddings\n(colored by variance)', fontsize=13, fontweight='bold')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    
    # =========================================================================
    # Plot 2: Norm vs Variance (top-right)
    # =========================================================================
    ax2 = fig.add_subplot(gs[0, 1])
    
    scatter = ax2.scatter(vars_, norms, c=means, cmap='viridis', s=100, 
                         edgecolors='black', linewidth=1.5)
    plt.colorbar(scatter, ax=ax2, label='Mean Hitting Time', shrink=0.8)
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        ax2.annotate(label, (vars_[i], norms[i]), fontsize=11, fontweight='bold',
                    xytext=(5, 5), textcoords='offset points')
    
    # Add trend line
    z = np.polyfit(vars_, norms, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(vars_) - 5, max(vars_) + 5, 100)
    ax2.plot(x_trend, p(x_trend), 'r--', linewidth=2, alpha=0.7, label=f'Trend (slope={z[0]:.4f})')
    
    ax2.set_xlabel('Variance of Hitting Time', fontsize=11)
    ax2.set_ylabel('Hyperbolic Norm', fontsize=11)
    
    # Add correlation annotation
    corr = results['corr_norm_var']
    status = "SUPPORTED" if results['norm_var_supported'] else "NOT SUPPORTED"
    color = 'green' if results['norm_var_supported'] else 'red'
    ax2.set_title(f'Conjecture 1: Norm vs Variance\nρ = {corr:.3f} ({status})', 
                 fontsize=13, fontweight='bold', color=color)
    ax2.legend(loc='upper right')
    ax2.grid(True, alpha=0.3)
    
    # =========================================================================
    # Plot 3: Angle vs Mean (bottom-left)
    # =========================================================================
    ax3 = fig.add_subplot(gs[1, 0])
    
    scatter = ax3.scatter(means, np.degrees(angles), c=vars_, cmap='plasma', s=100,
                         edgecolors='black', linewidth=1.5)
    plt.colorbar(scatter, ax=ax3, label='Variance', shrink=0.8)
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        ax3.annotate(label, (means[i], np.degrees(angles[i])), fontsize=11, fontweight='bold',
                    xytext=(5, 5), textcoords='offset points')
    
    # Add trend line
    z = np.polyfit(means, np.degrees(angles), 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(means) - 1, max(means) + 1, 100)
    ax3.plot(x_trend, p(x_trend), 'r--', linewidth=2, alpha=0.7, label=f'Trend (slope={z[0]:.2f})')
    
    ax3.set_xlabel('Mean Hitting Time', fontsize=11)
    ax3.set_ylabel('Angular Coordinate (degrees)', fontsize=11)
    
    corr = results['corr_angle_mean']
    status = "SUPPORTED" if abs(corr) > 0.3 else "NOT SUPPORTED"
    ax3.set_title(f'Conjecture 2: Angle vs Mean\nρ = {corr:.3f} ({status})', 
                 fontsize=13, fontweight='bold')
    ax3.legend(loc='best')
    ax3.grid(True, alpha=0.3)
    
    # =========================================================================
    # Plot 4: Summary Statistics (bottom-right)
    # =========================================================================
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.axis('off')
    
    # Create summary text
    norm_var_status = "SUPPORTED" if results['norm_var_supported'] else "NOT SUPPORTED"
    angle_mean_status = "SUPPORTED" if results['angle_mean_supported'] else "NOT SUPPORTED"
    
    summary_text = f"""
╔══════════════════════════════════════════════════════════╗
║              HYPOTHESIS TEST SUMMARY                      ║
╠══════════════════════════════════════════════════════════╣
║                                                          ║
║  CONJECTURE 1: Norm ~ 1/Variance                         ║
║  ─────────────────────────────────────────────────────   ║
║  Expected: Low variance → High norm (negative corr)      ║
║  Spearman ρ = {results['corr_norm_var']:+.4f}  (p = {results['p_norm_var']:.4f})                  ║
║  Status: {norm_var_status:<20}                       ║
║                                                          ║
╠══════════════════════════════════════════════════════════╣
║                                                          ║
║  CONJECTURE 2: Angle ~ Mean                              ║
║  ─────────────────────────────────────────────────────   ║
║  Expected: Similar mean → Similar angle (strong corr)    ║
║  Spearman ρ = {results['corr_angle_mean']:+.4f}  (p = {results['p_angle_mean']:.4f})                  ║
║  Status: {angle_mean_status:<20}                       ║
║                                                          ║
╠══════════════════════════════════════════════════════════╣
║                                                          ║
║  STATE STATISTICS:                                       ║
║  ─────────────────────────────────────────────────────   ║"""
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        summary_text += f"\n║  {label}: norm={norms[i]:.2f}, angle={np.degrees(angles[i]):+6.1f}°, "
        summary_text += f"μ={means[i]:5.1f}, σ²={vars_[i]:6.1f}  ║"
    
    summary_text += """
║                                                          ║
╚══════════════════════════════════════════════════════════╝
"""
    
    ax4.text(0.02, 0.98, summary_text, transform=ax4.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='lightyellow', alpha=0.9, edgecolor='black'))
    
    plt.suptitle('MDP Hyperbolic Embedding Analysis', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.savefig('mdp_summary.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()


plot_summary(embeddings, results, hitting_stats, mdp)

In [None]:
def plot_conjecture_analysis(results, mdp):
    """
    Detailed visualization of the conjecture analysis.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    
    norms = results['norms']
    angles = results['angles']
    means = results['means']
    vars_ = results['vars']
    states = results['states']
    
    # Plot 1: Norm vs Variance
    ax = axes[0, 0]
    scatter = ax.scatter(vars_, norms, c=means, cmap='viridis', s=200, edgecolors='black', linewidth=2)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time')
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (vars_[i], norms[i]), fontsize=12, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    # Add trend line
    z = np.polyfit(vars_, norms, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(vars_), max(vars_), 100)
    ax.plot(x_trend, p(x_trend), 'r--', linewidth=2, label=f'Trend (slope={z[0]:.4f})')
    
    ax.set_xlabel('Variance of Hitting Time', fontsize=12)
    ax.set_ylabel('Hyperbolic Norm', fontsize=12)
    ax.set_title(f'Conjecture 1: Norm vs Variance\nCorr={results["corr_norm_var"]:.3f}', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Angle vs Mean
    ax = axes[0, 1]
    scatter = ax.scatter(means, np.degrees(angles), c=vars_, cmap='plasma', s=200, edgecolors='black', linewidth=2)
    plt.colorbar(scatter, ax=ax, label='Variance')
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (means[i], np.degrees(angles[i])), fontsize=12, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    ax.set_xlabel('Mean Hitting Time', fontsize=12)
    ax.set_ylabel('Angular Coordinate (degrees)', fontsize=12)
    ax.set_title(f'Conjecture 2: Angle vs Mean\nCorr={results["corr_angle_mean"]:.3f}', fontsize=14)
    ax.grid(True, alpha=0.3)
    
    # Plot 3: Norm vs Mean
    ax = axes[1, 0]
    scatter = ax.scatter(means, norms, c=vars_, cmap='plasma', s=200, edgecolors='black', linewidth=2)
    plt.colorbar(scatter, ax=ax, label='Variance')
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (means[i], norms[i]), fontsize=12, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    corr_norm_mean, _ = stats.spearmanr(norms, means)  # Spearman, consistent with the other panels
    ax.set_xlabel('Mean Hitting Time', fontsize=12)
    ax.set_ylabel('Hyperbolic Norm', fontsize=12)
    ax.set_title(f'Additional: Norm vs Mean\nSpearman ρ={corr_norm_mean:.3f}', fontsize=14)
    ax.grid(True, alpha=0.3)
    
    # Plot 4: Summary stats
    ax = axes[1, 1]
    ax.axis('off')
    
    summary_text = f"""
    CONJECTURE ANALYSIS SUMMARY
    {'='*50}
    
    Conjecture 1: Norm ~ 1/Variance
    -------------------------------
    Spearman correlation: {results['corr_norm_var']:.4f}
    Interpretation: {'SUPPORTED' if results['corr_norm_var'] < -0.3 else 'WEAK/NOT SUPPORTED'}
    
    Conjecture 2: Angle ~ Mean
    ---------------------------
    Spearman correlation: {results['corr_angle_mean']:.4f}
    Interpretation: {'SUPPORTED' if abs(results['corr_angle_mean']) > 0.3 else 'WEAK/NOT SUPPORTED'}
    
    {'='*50}
    
    Note: The MDP has structure that affects embeddings:
    - States 4, 5: 1-step to goal (low var, low mean)
    - States 2, 3: Geometric waiting (high var, high mean)
    - State 1: Depends on action choice
    """
    
    ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig('conjecture_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_conjecture_analysis(results, mdp)

In [None]:
def plot_all_state_pairs(model, mdp):
    """
    Plot embeddings for ALL (state, goal) pairs, not just (state, final_goal).
    """
    model.eval()
    n_states = mdp.n_states
    
    fig, ax = plt.subplots(figsize=(10, 10))
    plot_poincare_disk(ax, "All (State, Goal) Pair Embeddings")
    
    all_embeddings = []
    labels = []
    colors = []
    
    with torch.no_grad():
        for start in range(n_states):
            for goal in range(n_states):
                if start != goal:  # Skip trivial (s, s) pairs
                    s_norm = start / (n_states - 1)
                    g_norm = goal / (n_states - 1)
                    
                    x = torch.tensor([[s_norm, g_norm]], dtype=torch.float32).to(device)
                    emb = model(x)
                    
                    if isinstance(emb, ManifoldTensor):
                        emb = emb.tensor
                    
                    emb = emb.squeeze(0).cpu().numpy()
                    all_embeddings.append(emb)
                    labels.append(pair_label(start, goal))
                    colors.append(goal)  # Color by goal state
    
    all_embeddings = np.array(all_embeddings)
    colors = np.array(colors)
    
    scatter = ax.scatter(all_embeddings[:, 0], all_embeddings[:, 1], 
                        c=colors, cmap='tab10', s=80, edgecolors='black', linewidth=1, zorder=5)
    
    # Add labels
    for i, label in enumerate(labels):
        ax.annotate(label, all_embeddings[i], fontsize=7, 
                   xytext=(3, 3), textcoords='offset points')
    
    plt.colorbar(scatter, ax=ax, label='Goal State Index')
    plt.tight_layout()
    plt.savefig('all_state_pairs.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_all_state_pairs(model, mdp)

In [None]:
def plot_polar_representation(embeddings, hitting_stats, mdp):
    """
    Polar plot showing angle vs hyperbolic norm.
    """
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': 'polar'})
    
    states = sorted(embeddings.keys())
    
    angles = [np.arctan2(embeddings[s][1], embeddings[s][0]) for s in states]
    norms = [hyperbolic_norm(embeddings[s]) for s in states]
    means = [hitting_stats[s]['mean'] for s in states]
    vars_ = [hitting_stats[s]['var'] for s in states]
    
    # Marker size shrinks linearly as variance grows (lowest variance = largest marker); color by mean
    max_var = max(vars_) if max(vars_) > 0 else 1
    sizes = 50 + 100 * (1 - np.array(vars_) / max_var)
    
    scatter = ax.scatter(angles, norms, c=means, cmap='viridis', s=sizes, 
                        edgecolors='black', linewidth=1.5)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time', pad=0.1)
    
    for i, s in enumerate(states):
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (angles[i], norms[i]), fontsize=11, fontweight='bold')
    
    ax.set_title('Polar View: Angle vs Hyperbolic Norm\n(Larger marker = lower variance, Color = mean)', fontsize=13, pad=20)
    
    plt.tight_layout()
    plt.savefig('polar_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_polar_representation(embeddings, hitting_stats, mdp)

## Discussion

### MDP Structure Analysis

The MDP has distinct structural properties that affect hitting times:

1. **States 4 and 5 (S4, S5)**: Deterministic 1-step transitions to goal
   - Mean hitting time: 1
   - Variance: 0
   
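2. **States 2 and 3 (S2, S3)**: Geometric waiting times due to self-loops
   - Mean hitting time: 11 (a Geometric(0.1) number of steps to leave the self-loop, with mean 1/0.1 = 10, plus 1 final step to the goal)
   - Variance: 90, i.e. (1 − p)/p² with p = 0.1 (the deterministic final step adds no variance)
   
3. **State 1 (S1)**: Depends on action selection
   - a11 reaches the goal in exactly 2 steps; a12 and a13 take 1 + 11 = 12 steps in expectation
   - Under a uniformly random policy: a mixture of the 2-step path (via a11) and the longer, high-variance paths (via a12, a13), with mean (2 + 12 + 12)/3 = 26/3 ≈ 8.7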
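
The hitting-time statistics of this MDP can be checked exactly by treating it, under a fixed policy, as an absorbing Markov chain. With Q the transient-to-transient block of the transition matrix and N = (I − Q)⁻¹ its fundamental matrix, the expected steps to the goal are t = N·1 and their variances are (2N − I)t − t∘t. The sketch below assumes a uniformly random action at State 1 (the policy behind `trajectories` is not shown in this section) and uses the same 0-indexed states as `SimpleMDP`; `random_policy_chain` and `hitting_time_moments` are hypothetical helpers, not part of the notebook.

```python
import numpy as np


def random_policy_chain():
    """Transition matrix of the 6-state MDP with a uniformly random action at State 1.

    Assumption: States 2-5 each have a single action (as in the MDP diagram);
    index 5 (State 6) is the absorbing goal.
    """
    P = np.zeros((6, 6))
    P[0, [1, 2]] = 1 / 3            # a12 -> State 2, a13 -> State 3
    P[0, [3, 4]] = 1 / 3 * 0.5      # a11 -> State 4 or State 5, p = 0.5 each
    P[1, 1], P[1, 3] = 0.9, 0.1     # State 2: self-loop, or move to State 4
    P[2, 2], P[2, 4] = 0.9, 0.1     # State 3: self-loop, or move to State 5
    P[3, 5] = P[4, 5] = 1.0         # States 4 and 5 step straight to the goal
    P[5, 5] = 1.0                   # goal is absorbing
    return P


def hitting_time_moments(P, goal=5):
    """Exact mean and variance of steps-to-goal from every transient state."""
    transient = [s for s in range(len(P)) if s != goal]
    Q = P[np.ix_(transient, transient)]
    I = np.eye(len(transient))
    N = np.linalg.inv(I - Q)          # fundamental matrix: expected visits
    t = N @ np.ones(len(transient))   # expected steps to absorption
    var = (2 * N - I) @ t - t ** 2    # variance of steps to absorption
    return dict(zip(transient, t)), dict(zip(transient, var))


exact_means, exact_vars = hitting_time_moments(random_policy_chain())
for s in sorted(exact_means):
    print(f"State {s + 1}: mean={exact_means[s]:.2f}, var={exact_vars[s]:.2f}")
```

Under these assumptions it gives mean 11 and variance 90 for States 2 and 3, mean 1 and variance 0 for States 4 and 5, and mean 26/3 ≈ 8.67 with variance 740/9 ≈ 82.2 for State 1, which is a reference point to compare `hitting_stats` against.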

### Conjecture Interpretation

The conjecture states:
- **Norm ~ 1/Variance**: States with **lower** hitting time variance should be embedded **farther** from the origin (higher norm)
- **Angle ~ Mean**: States with similar mean hitting times should have similar angular coordinates

This would create a structure where:
- S4, S5 (low variance, deterministic) should have **high norms** (far from origin)
- S2, S3 (high variance, geometric waiting) should have **low norms** (near origin)
- States with similar means lie along similar radial directions
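
The decision rule behind the "SUPPORTED" labels in the plotting cells (Spearman ρ below −0.3 for norm vs. variance, |ρ| above 0.3 for angle vs. mean) can be written as one small helper, which makes the sign requirement on the first conjecture explicit. `conjecture_verdicts` is a hypothetical name, and the 0.3 threshold is a heuristic rather than a significance test. One caveat it does not handle: angles wrap around at ±180°, so a rank correlation on raw angles can mislead if embeddings straddle that cut.

```python
from scipy import stats


def conjecture_verdicts(norms, variances, means, angles_deg, threshold=0.3):
    """Hypothetical helper: Spearman-threshold verdicts for both conjectures."""
    rho_nv, _ = stats.spearmanr(norms, variances)
    rho_am, _ = stats.spearmanr(angles_deg, means)
    return {
        # Conjecture 1 needs a *negative* correlation (low variance -> high norm)
        "norm_var_supported": bool(rho_nv < -threshold),
        # Conjecture 2 accepts a strong monotone trend in either direction
        "angle_mean_supported": bool(abs(rho_am) > threshold),
        "rho_norm_var": float(rho_nv),
        "rho_angle_mean": float(rho_am),
    }


# Toy values made up for illustration, not results from this notebook
print(conjecture_verdicts(
    norms=[0.9, 0.8, 0.3, 0.2, 0.5],
    variances=[0, 1, 90, 95, 80],
    means=[1, 1.5, 11, 12, 8.7],
    angles_deg=[10, 12, 60, 70, 40],
))
```

A strongly *positive* norm/variance correlation is reported as not supported, since it contradicts the direction the conjecture predicts.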

### Intuition

In hyperbolic space, the origin represents maximal uncertainty/abstraction. States with:
- **Low variance** (predictable hitting times) are more "specific" → farther from origin
- **High variance** (unpredictable hitting times) are more "abstract" → closer to origin
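
"Distance from the origin" here is hyperbolic distance. For a point x in the Poincaré ball of curvature −c (c > 0) it is d(0, x) = (2/√c)·artanh(√c‖x‖), which grows without bound as ‖x‖ approaches the boundary radius 1/√c. The notebook's `hyperbolic_norm` helper is not shown in this section, so whether it uses exactly this formula (and c = 1) is an assumption; `poincare_distance_from_origin` is a hypothetical name.

```python
import numpy as np


def poincare_distance_from_origin(x, c=1.0):
    """Hyperbolic distance from the origin to x (last axis = coordinates).

    d(0, x) = (2 / sqrt(c)) * artanh(sqrt(c) * ||x||) in the Poincare ball of curvature -c.
    """
    r = np.sqrt(c) * np.linalg.norm(np.asarray(x, dtype=float), axis=-1)
    if np.any(r >= 1.0):
        raise ValueError("point lies on or outside the ball boundary")
    return 2.0 / np.sqrt(c) * np.arctanh(r)


# Equal Euclidean steps toward the rim cost more and more hyperbolic distance
for radius in (0.5, 0.9, 0.99):
    print(radius, poincare_distance_from_origin([radius, 0.0]))
```

Because ‖x‖ ↦ d(0, x) is strictly increasing, rank-based statistics such as the Spearman correlations used here come out the same whether they are computed from the Euclidean radius or from the hyperbolic distance.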

In [None]:
# Final summary
print("="*70)
print("FINAL SUMMARY")
print("="*70)
print(f"\nTraining completed with {len(trajectories)} trajectories")
print(f"Final training loss: {losses[-1]:.4f}")

print(f"\n--- Conjecture 1: Norm ~ 1/Variance ---")
print(f"Spearman correlation: {results['corr_norm_var']:.4f}")
print(f"Expected: Negative (low variance -> high norm)")
print(f"Status: {'SUPPORTED' if results['norm_var_supported'] else 'NOT SUPPORTED'}")

print(f"\n--- Conjecture 2: Angle ~ Mean ---")
print(f"Spearman correlation: {results['corr_angle_mean']:.4f}")
print(f"Expected: Strong correlation")
print(f"Status: {'SUPPORTED' if results['angle_mean_supported'] else 'NOT SUPPORTED'}")

print("\n" + "="*70)
print("See the summary plot above for detailed visualization.")


## Goal-Conditioned Behavioral Cloning (GCBC)

We train two policies:
1. **Raw State Policy**: Takes (state, goal) as one-hot encoded inputs
2. **Embedding Policy**: Takes hyperbolic embeddings of (state, goal) pairs

Both are trained via behavioral cloning on the generated trajectories.
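
In both cases behavioral cloning is maximum likelihood on the (state, goal, action) triples extracted from the trajectories. Writing $\mathbf{e}_s$ for the one-hot vector of state $s$, $\phi$ for the trained hyperbolic encoder (held fixed while the policy trains, since its embeddings are precomputed), and $\mathcal{D}$ for the dataset, the `nn.CrossEntropyLoss` objectives below are minimized on minibatch estimates:

$$
\mathcal{L}_{\text{raw}}(\theta) = -\frac{1}{|\mathcal{D}|}\sum_{(s,g,a)\in\mathcal{D}} \log \pi_\theta\big(a \mid [\mathbf{e}_s;\, \mathbf{e}_g]\big),
\qquad
\mathcal{L}_{\text{emb}}(\theta) = -\frac{1}{|\mathcal{D}|}\sum_{(s,g,a)\in\mathcal{D}} \log \pi_\theta\big(a \mid \phi(s, g)\big)
$$

where $\pi_\theta(\cdot \mid x) = \operatorname{softmax}(f_\theta(x))$ for the MLP logits $f_\theta$. The two policies differ only in their input: a 12-dimensional concatenation of one-hot vectors versus a 2-dimensional embedding.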

In [None]:
# =============================================================================
# GCBC Dataset Creation
# =============================================================================

def create_gcbc_dataset(trajectories, mdp, goal_type="final"):
    """
    Create dataset for goal-conditioned behavioral cloning.
    
    For each (state, action) in a trajectory, we create a sample:
    - state: current state (0-indexed)
    - goal: final state of trajectory (or random future state)
    - action: action taken (inferred from transition)
    
    Args:
        trajectories: List of state sequences
        mdp: SimpleMDP instance
        goal_type: "final" (trajectory endpoint) or "random" (random future state)
    
    Returns:
        states, goals, actions as numpy arrays
    """
    states = []
    goals = []
    actions = []
    
    for traj in trajectories:
        final_state = traj[-1]  # Usually the goal state
        
        for t in range(len(traj) - 1):
            state = traj[t]
            next_state = traj[t + 1]
            
            if state == mdp.goal_state:
                continue  # Skip terminal state
            
            # Infer action from (state, next_state) transition
            action = infer_action(mdp, state, next_state)
            if action is None:
                continue  # Skip invalid transitions
            
            if goal_type == "final":
                goal = final_state
            else:  # random future
                future_idx = np.random.randint(t + 1, len(traj))
                goal = traj[future_idx]
            
            states.append(state)
            goals.append(goal)
            actions.append(action)
    
    return np.array(states), np.array(goals), np.array(actions)


def infer_action(mdp, state, next_state):
    """
    Infer which action was taken to transition from state to next_state.
    Returns action index or None if transition is impossible.
    """
    for action in range(mdp.n_actions[state]):
        transitions = mdp.get_transitions(state, action)
        for ns, prob in transitions:
            if ns == next_state and prob > 0:
                return action
    return None


# Create GCBC dataset
gcbc_states, gcbc_goals, gcbc_actions = create_gcbc_dataset(trajectories, mdp, goal_type="final")

print(f"GCBC Dataset created:")
print(f"  Total samples: {len(gcbc_states)}")
print(f"  Unique states: {np.unique(gcbc_states)}")
print(f"  Unique goals: {np.unique(gcbc_goals)}")
print(f"  Unique actions: {np.unique(gcbc_actions)}")

# Distribution of actions per state
print("\nAction distribution per state:")
for s in range(mdp.n_states - 1):  # Exclude goal state
    mask = gcbc_states == s
    if mask.sum() > 0:
        action_counts = np.bincount(gcbc_actions[mask], minlength=3)
        print(f"  State {s+1}: {dict(enumerate(action_counts.tolist()))}")

In [None]:
# =============================================================================
# GCBC Policy Networks
# =============================================================================

class GCBCPolicyRaw(nn.Module):
    """
    Goal-Conditioned Policy using raw (one-hot) state representations.
    
    Input: concatenated one-hot encodings of state and goal
    Output: action logits
    """
    
    def __init__(self, n_states=6, n_actions=3, hidden_dim=64):
        super().__init__()
        self.n_states = n_states
        self.n_actions = n_actions
        
        # Input: one-hot(state) + one-hot(goal) = 2 * n_states
        input_dim = 2 * n_states
        
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
    
    def forward(self, states, goals):
        """
        Args:
            states: tensor of shape (batch,) with state indices
            goals: tensor of shape (batch,) with goal indices
        Returns:
            action logits of shape (batch, n_actions)
        """
        # One-hot encode
        state_onehot = torch.nn.functional.one_hot(states.long(), self.n_states).float()
        goal_onehot = torch.nn.functional.one_hot(goals.long(), self.n_states).float()
        
        # Concatenate
        x = torch.cat([state_onehot, goal_onehot], dim=-1)
        
        return self.network(x)
    
    def get_action(self, state, goal, deterministic=True):
        """Get action for a single (state, goal) pair."""
        self.eval()
        with torch.no_grad():
            s = torch.tensor([state], dtype=torch.long, device=device)
            g = torch.tensor([goal], dtype=torch.long, device=device)
            logits = self.forward(s, g)
            
            if deterministic:
                return logits.argmax(dim=-1).item()
            else:
                probs = torch.softmax(logits, dim=-1)
                return torch.multinomial(probs, 1).item()


class GCBCPolicyEmbedding(nn.Module):
    """
    Goal-Conditioned Policy using hyperbolic embeddings.
    
    Input: embedding of (state, goal) pair from trained encoder
    Output: action logits
    """
    
    def __init__(self, embedding_dim=2, n_actions=3, hidden_dim=64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_actions = n_actions
        
        self.network = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
    
    def forward(self, embeddings):
        """
        Args:
            embeddings: tensor of shape (batch, embedding_dim)
        Returns:
            action logits of shape (batch, n_actions)
        """
        return self.network(embeddings)
    
    def get_action(self, embedding, deterministic=True):
        """Get action for a single embedding."""
        self.eval()
        with torch.no_grad():
            emb = torch.tensor(embedding, dtype=torch.float32, device=device).unsqueeze(0)
            logits = self.forward(emb)
            
            if deterministic:
                return logits.argmax(dim=-1).item()
            else:
                probs = torch.softmax(logits, dim=-1)
                return torch.multinomial(probs, 1).item()


# Test policy networks
raw_policy = GCBCPolicyRaw(n_states=6, n_actions=3).to(device)
emb_policy = GCBCPolicyEmbedding(embedding_dim=2, n_actions=3).to(device)

# Test forward pass
test_states = torch.tensor([0, 1, 2], dtype=torch.long, device=device)
test_goals = torch.tensor([5, 5, 5], dtype=torch.long, device=device)
test_emb = torch.randn(3, 2, device=device)

print("Raw Policy output shape:", raw_policy(test_states, test_goals).shape)
print("Embedding Policy output shape:", emb_policy(test_emb).shape)

In [None]:
# =============================================================================
# Training Functions
# =============================================================================

def train_raw_policy(policy, states, goals, actions, epochs=100, batch_size=64, lr=0.001):
    """Train the raw state policy via behavioral cloning."""
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Convert to tensors
    states_t = torch.tensor(states, dtype=torch.long, device=device)
    goals_t = torch.tensor(goals, dtype=torch.long, device=device)
    actions_t = torch.tensor(actions, dtype=torch.long, device=device)
    
    n_samples = len(states)
    losses = []
    
    policy.train()
    for epoch in range(epochs):
        # Shuffle data
        perm = torch.randperm(n_samples)
        epoch_loss = 0.0
        n_batches = 0
        
        for i in range(0, n_samples, batch_size):
            idx = perm[i:i+batch_size]
            batch_states = states_t[idx]
            batch_goals = goals_t[idx]
            batch_actions = actions_t[idx]
            
            # Forward pass
            logits = policy(batch_states, batch_goals)
            loss = criterion(logits, batch_actions)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            n_batches += 1
        
        avg_loss = epoch_loss / n_batches
        losses.append(avg_loss)
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    return losses


def train_embedding_policy(policy, encoder, states, goals, actions, epochs=100, batch_size=64, lr=0.001):
    """Train the embedding policy via behavioral cloning."""
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Get embeddings for all (state, goal) pairs
    encoder.eval()
    n_states = 6
    
    # Pre-compute embeddings
    with torch.no_grad():
        embeddings = []
        for s, g in zip(states, goals):
            emb = get_embedding(encoder, s, g, n_states)
            embeddings.append(emb)
        embeddings = np.array(embeddings)
    
    # Convert to tensors
    embeddings_t = torch.tensor(embeddings, dtype=torch.float32, device=device)
    actions_t = torch.tensor(actions, dtype=torch.long, device=device)
    
    n_samples = len(states)
    losses = []
    
    policy.train()
    for epoch in range(epochs):
        # Shuffle data
        perm = torch.randperm(n_samples)
        epoch_loss = 0.0
        n_batches = 0
        
        for i in range(0, n_samples, batch_size):
            idx = perm[i:i+batch_size]
            batch_emb = embeddings_t[idx]
            batch_actions = actions_t[idx]
            
            # Forward pass
            logits = policy(batch_emb)
            loss = criterion(logits, batch_actions)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            n_batches += 1
        
        avg_loss = epoch_loss / n_batches
        losses.append(avg_loss)
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    return losses


# Train both policies
print("="*60)
print("Training Raw State Policy")
print("="*60)
raw_policy = GCBCPolicyRaw(n_states=6, n_actions=3, hidden_dim=64).to(device)
raw_losses = train_raw_policy(raw_policy, gcbc_states, gcbc_goals, gcbc_actions, epochs=100)

print("\n" + "="*60)
print("Training Embedding Policy")
print("="*60)
emb_policy = GCBCPolicyEmbedding(embedding_dim=2, n_actions=3, hidden_dim=64).to(device)
emb_losses = train_embedding_policy(emb_policy, model, gcbc_states, gcbc_goals, gcbc_actions, epochs=100)

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(raw_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Raw State Policy Training')
axes[0].grid(True, alpha=0.3)

axes[1].plot(emb_losses)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Embedding Policy Training')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Policy Evaluation

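Evaluate both learned policies (raw and embedding-based) in two modes, and compare them against a random-action baseline:
1. **Deterministic**: Always take the argmax action
2. **Stochastic**: Sample from the policy's action distribution

Metrics: mean (± std) number of steps from the start state to the goal, and success rate (fraction of episodes reaching the goal within `max_steps`; failed episodes are counted as `max_steps` in the step statistics)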
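
For reference when reading the numbers below, the expected number of steps from State 1 can be worked out by hand, assuming (as in the MDP diagram) that only State 1 offers a real choice and States 2-5 each have a single action. Writing $\pi(a_{11})$ for the probability a policy puts on the stochastic shortcut at State 1:

$$
\mathbb{E}[T \mid a_{11}] = 1 + 1 = 2, \qquad
\mathbb{E}[T \mid a_{12}] = \mathbb{E}[T \mid a_{13}] = 1 + \frac{1}{0.1} + 1 = 12
$$

$$
\mathbb{E}_\pi[T] = 2\,\pi(a_{11}) + 12\,\big(1 - \pi(a_{11})\big) = 12 - 10\,\pi(a_{11})
$$

So the best achievable mean is 2 steps (always choose $a_{11}$, with zero variance), a uniformly random choice ($\pi(a_{11}) = 1/3$) gives $26/3 \approx 8.67$, and a deterministic policy that commits to $a_{12}$ or $a_{13}$ should average about 12 steps.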

In [None]:
# =============================================================================
# Policy Evaluation
# =============================================================================

def evaluate_raw_policy(policy, mdp, n_episodes=1000, max_steps=1000, deterministic=True):
    """
    Evaluate raw state policy.
    
    Args:
        policy: GCBCPolicyRaw
        mdp: SimpleMDP
        n_episodes: Number of evaluation episodes
        max_steps: Maximum steps per episode
        deterministic: If True, use argmax; else sample from distribution
    
    Returns:
        dict with mean_steps, std_steps, success_rate, and steps_list
        (failed episodes are recorded as max_steps)
    """
    policy.eval()
    steps_list = []
    successes = 0
    
    for _ in range(n_episodes):
        state = mdp.start_state
        goal = mdp.goal_state
        steps = 0
        
        while state != goal and steps < max_steps:
            # Get action from policy
            action = policy.get_action(state, goal, deterministic=deterministic)
            
            # Clip action to valid range for this state
            action = min(action, mdp.n_actions[state] - 1)
            
            # Take step in environment
            state = mdp.step(state, action)
            steps += 1
        
        if state == goal:
            successes += 1
            steps_list.append(steps)
        else:
            steps_list.append(max_steps)  # Failure case
    
    return {
        'mean_steps': np.mean(steps_list),
        'std_steps': np.std(steps_list),
        'success_rate': successes / n_episodes,
        'steps_list': steps_list
    }


def evaluate_embedding_policy(policy, encoder, mdp, n_episodes=1000, max_steps=1000, deterministic=True):
    """
    Evaluate embedding policy.
    
    Args:
        policy: GCBCPolicyEmbedding
        encoder: HyperbolicIntervalEncoder
        mdp: SimpleMDP
        n_episodes: Number of evaluation episodes
        max_steps: Maximum steps per episode
        deterministic: If True, use argmax; else sample from distribution
    
    Returns:
        dict with mean_steps, std_steps, success_rate, and steps_list
        (failed episodes are recorded as max_steps)
    """
    policy.eval()
    encoder.eval()
    steps_list = []
    successes = 0
    
    for _ in range(n_episodes):
        state = mdp.start_state
        goal = mdp.goal_state
        steps = 0
        
        while state != goal and steps < max_steps:
            # Get embedding
            emb = get_embedding(encoder, state, goal, mdp.n_states)
            
            # Get action from policy
            action = policy.get_action(emb, deterministic=deterministic)
            
            # Clip action to valid range for this state
            action = min(action, mdp.n_actions[state] - 1)
            
            # Take step in environment
            state = mdp.step(state, action)
            steps += 1
        
        if state == goal:
            successes += 1
            steps_list.append(steps)
        else:
            steps_list.append(max_steps)
    
    return {
        'mean_steps': np.mean(steps_list),
        'std_steps': np.std(steps_list),
        'success_rate': successes / n_episodes,
        'steps_list': steps_list
    }


def evaluate_random_policy(mdp, n_episodes=1000, max_steps=1000):
    """Evaluate random policy (baseline)."""
    steps_list = []
    successes = 0
    
    for _ in range(n_episodes):
        state = mdp.start_state
        steps = 0
        
        while state != mdp.goal_state and steps < max_steps:
            state = mdp.step(state, action=None)  # Random action
            steps += 1
        
        if state == mdp.goal_state:
            successes += 1
            steps_list.append(steps)
        else:
            steps_list.append(max_steps)
    
    return {
        'mean_steps': np.mean(steps_list),
        'std_steps': np.std(steps_list),
        'success_rate': successes / n_episodes,
        'steps_list': steps_list
    }


# Evaluate all policies
print("="*70)
print("POLICY EVALUATION")
print("="*70)

n_eval_episodes = 5000

# Random baseline
print("\nRandom Policy (Baseline):")
random_results = evaluate_random_policy(mdp, n_episodes=n_eval_episodes)
print(f"  Mean steps: {random_results['mean_steps']:.2f} ± {random_results['std_steps']:.2f}")
print(f"  Success rate: {random_results['success_rate']*100:.1f}%")

# Raw policy - deterministic
print("\nRaw State Policy (Deterministic):")
raw_det_results = evaluate_raw_policy(raw_policy, mdp, n_episodes=n_eval_episodes, deterministic=True)
print(f"  Mean steps: {raw_det_results['mean_steps']:.2f} ± {raw_det_results['std_steps']:.2f}")
print(f"  Success rate: {raw_det_results['success_rate']*100:.1f}%")

# Raw policy - stochastic
print("\nRaw State Policy (Stochastic):")
raw_stoch_results = evaluate_raw_policy(raw_policy, mdp, n_episodes=n_eval_episodes, deterministic=False)
print(f"  Mean steps: {raw_stoch_results['mean_steps']:.2f} ± {raw_stoch_results['std_steps']:.2f}")
print(f"  Success rate: {raw_stoch_results['success_rate']*100:.1f}%")

# Embedding policy - deterministic
print("\nEmbedding Policy (Deterministic):")
emb_det_results = evaluate_embedding_policy(emb_policy, model, mdp, n_episodes=n_eval_episodes, deterministic=True)
print(f"  Mean steps: {emb_det_results['mean_steps']:.2f} ± {emb_det_results['std_steps']:.2f}")
print(f"  Success rate: {emb_det_results['success_rate']*100:.1f}%")

# Embedding policy - stochastic
print("\nEmbedding Policy (Stochastic):")
emb_stoch_results = evaluate_embedding_policy(emb_policy, model, mdp, n_episodes=n_eval_episodes, deterministic=False)
print(f"  Mean steps: {emb_stoch_results['mean_steps']:.2f} ± {emb_stoch_results['std_steps']:.2f}")
print(f"  Success rate: {emb_stoch_results['success_rate']*100:.1f}%")

In [None]:
# Visualize evaluation results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar plot of mean steps
ax = axes[0]
policies = ['Random', 'Raw\n(Det)', 'Raw\n(Stoch)', 'Emb\n(Det)', 'Emb\n(Stoch)']
means = [random_results['mean_steps'], raw_det_results['mean_steps'], 
         raw_stoch_results['mean_steps'], emb_det_results['mean_steps'], 
         emb_stoch_results['mean_steps']]
stds = [random_results['std_steps'], raw_det_results['std_steps'], 
        raw_stoch_results['std_steps'], emb_det_results['std_steps'], 
        emb_stoch_results['std_steps']]

colors = ['gray', 'steelblue', 'lightblue', 'coral', 'lightsalmon']
bars = ax.bar(policies, means, yerr=stds, capsize=5, color=colors, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Mean Steps to Goal', fontsize=12)
ax.set_title('Policy Evaluation: Steps to Goal', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, mean in zip(bars, means):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
            f'{mean:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

# Success rate bar plot
ax = axes[1]
success_rates = [random_results['success_rate']*100, raw_det_results['success_rate']*100,
                 raw_stoch_results['success_rate']*100, emb_det_results['success_rate']*100,
                 emb_stoch_results['success_rate']*100]

bars = ax.bar(policies, success_rates, color=colors, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Success Rate (%)', fontsize=12)
ax.set_title('Policy Evaluation: Success Rate', fontsize=14, fontweight='bold')
ax.set_ylim(0, 105)
ax.grid(True, alpha=0.3, axis='y')

# Add value labels
for bar, rate in zip(bars, success_rates):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
            f'{rate:.0f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('policy_evaluation.png', dpi=150, bbox_inches='tight')
plt.show()

# Summary table
print("\n" + "="*70)
print("EVALUATION SUMMARY")
print("="*70)
print(f"{'Policy':<25} {'Mean Steps':>12} {'Std':>10} {'Success':>12}")
print("-"*70)
print(f"{'Random (Baseline)':<25} {random_results['mean_steps']:>12.2f} {random_results['std_steps']:>10.2f} {random_results['success_rate']*100:>11.1f}%")
print(f"{'Raw State (Det)':<25} {raw_det_results['mean_steps']:>12.2f} {raw_det_results['std_steps']:>10.2f} {raw_det_results['success_rate']*100:>11.1f}%")
print(f"{'Raw State (Stoch)':<25} {raw_stoch_results['mean_steps']:>12.2f} {raw_stoch_results['std_steps']:>10.2f} {raw_stoch_results['success_rate']*100:>11.1f}%")
print(f"{'Embedding (Det)':<25} {emb_det_results['mean_steps']:>12.2f} {emb_det_results['std_steps']:>10.2f} {emb_det_results['success_rate']*100:>11.1f}%")
print(f"{'Embedding (Stoch)':<25} {emb_stoch_results['mean_steps']:>12.2f} {emb_stoch_results['std_steps']:>10.2f} {emb_stoch_results['success_rate']*100:>11.1f}%")
print("="*70)

## Hyperbolic Planning Mechanism

The planning mechanism works as follows:

1. **Compute embedding** φ(s₁, s₂) for the (start, goal) pair
2. **Find subgoal**: Find the lowest-norm atomic embedding φ(s', s') (with s' different from s₁ and s₂) whose direction from the origin lies within an angular tolerance of the radial line through φ(s₁, s₂)
3. **Decode state**: Map the embedding back to a concrete state s'
4. **Recurse**: Apply the algorithm recursively on (s₁, s') and (s', s₂)
5. **Display plan**: Show the resulting hierarchical plan

The intuition is that low-norm atomic embeddings φ(s, s), which sit closer to the origin, act as "waypoints", and the radial line represents the "direction" of the plan.
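The subgoal rule in step 2 can be sketched in isolation on toy 2D points. This is a minimal, stdlib-only sketch: `poincare_norm`, `angle_between`, and `pick_subgoal` are hypothetical helpers, `poincare_norm` assumes the unit Poincaré disk with curvature -1 (d(0, p) = 2·artanh‖p‖), which may not match the notebook's `hyperbolic_norm`, and the atomic coordinates are made up rather than taken from the trained model.

```python
import math

def poincare_norm(p):
    """Hyperbolic distance from the origin, ASSUMING the unit Poincare disk (curvature -1)."""
    r = math.hypot(p[0], p[1])
    return 2.0 * math.atanh(min(r, 1.0 - 1e-9))  # clamp keeps points strictly inside the disk

def angle_between(u, v):
    """Angle in [0, pi] between the rays from the origin through u and v."""
    cos = (u[0] * v[0] + u[1] * v[1]) / (math.hypot(*u) * math.hypot(*v) + 1e-12)
    return math.acos(max(-1.0, min(1.0, cos)))

def pick_subgoal(target, atomic, exclude, max_angle=math.pi / 4):
    """Lowest-norm atomic embedding within max_angle of the ray through target, else None."""
    candidates = [
        (poincare_norm(emb), s)
        for s, emb in atomic.items()
        if s not in exclude and angle_between(target, emb) <= max_angle
    ]
    return min(candidates)[1] if candidates else None

# Hypothetical toy coordinates (NOT learned embeddings)
atomic = {0: (0.0, 0.9), 1: (0.2, 0.05), 2: (0.7, -0.1), 3: (-0.4, 0.0), 5: (0.1, -0.6)}
target = (0.5, 0.0)  # stands in for phi(start, goal)

print(pick_subgoal(target, atomic, exclude={0, 5}))                 # 1: near the ray, smallest norm
print(pick_subgoal(target, atomic, exclude={0, 5, 1}))              # 2: next candidate near the ray
print(pick_subgoal(target, atomic, exclude={0, 5}, max_angle=0.1))  # None: nothing within 0.1 rad
```

Because the angle is measured against the ray from the origin rather than the full line, embeddings on the opposite side of the origin (state 3 here) are never selected, which is also how the notebook's `angular_distance` behaves.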

In [None]:
# =============================================================================
# Hyperbolic Planning
# =============================================================================

def get_atomic_embeddings(model, mdp):
    """
    Get embeddings for all atomic (state, state) pairs.
    
    Returns:
        dict: {state: embedding_array}
    """
    model.eval()
    atomic_emb = {}
    
    with torch.no_grad():
        for s in range(mdp.n_states):
            emb = get_embedding(model, s, s, mdp.n_states)
            atomic_emb[s] = emb
    
    return atomic_emb


def angular_distance(v1, v2):
    """
    Compute angular distance between two vectors.
    Returns angle in radians (0 to π).
    """
    # Normalize vectors
    v1_norm = v1 / (np.linalg.norm(v1) + 1e-10)
    v2_norm = v2 / (np.linalg.norm(v2) + 1e-10)
    
    # Dot product gives cosine of angle
    cos_angle = np.clip(np.dot(v1_norm, v2_norm), -1.0, 1.0)
    
    return np.arccos(cos_angle)


def find_subgoal_on_radial_line(target_emb, atomic_embeddings, start_state, goal_state, 
                                 angle_threshold=np.pi/4, exclude_states=None):
    """
    Find the best subgoal along the radial line from origin through target_emb.
    
    The subgoal should be:
    - An atomic state (s', s') that is different from start and goal
    - Within angle_threshold of the radial line (low angular distance)
    - Of lowest hyperbolic norm among the valid candidates
    
    Args:
        target_emb: Embedding of (start, goal) pair
        atomic_embeddings: Dict of {state: embedding} for atomic states
        start_state: Start state (to exclude)
        goal_state: Goal state (to exclude)
        angle_threshold: Maximum angular deviation from radial line
        exclude_states: Additional states to exclude
    
    Returns:
        best_state: The best subgoal state, or None if no valid candidate
        best_emb: Embedding of the subgoal
    """
    if exclude_states is None:
        exclude_states = set()
    
    exclude_states = set(exclude_states) | {start_state, goal_state}
    
    best_state = None
    best_norm = float('inf')
    best_emb = None
    
    for state, emb in atomic_embeddings.items():
        if state in exclude_states:
            continue
        
        # Check angular distance from radial line
        angle_dist = angular_distance(target_emb, emb)
        
        if angle_dist <= angle_threshold:
            norm = hyperbolic_norm(emb)
            if norm < best_norm:
                best_norm = norm
                best_state = state
                best_emb = emb
    
    return best_state, best_emb


def hyperbolic_plan(model, mdp, start_state, goal_state, max_depth=3, angle_threshold=np.pi/4):
    """
    Generate a hierarchical plan using hyperbolic embeddings.
    
    The algorithm:
    1. Get embedding φ(start, goal)
    2. Find lowest-norm atomic embedding φ(s', s') along radial line
    3. If found, recursively plan (start, s') and (s', goal)
    4. Return the hierarchical plan
    
    Args:
        model: Trained hyperbolic encoder
        mdp: MDP instance
        start_state: Starting state
        goal_state: Goal state
        max_depth: Maximum recursion depth
        angle_threshold: Angular tolerance for radial line matching
    
    Returns:
        plan: List of subgoals (hierarchical decomposition)
    """
    model.eval()
    
    # Get atomic embeddings
    atomic_emb = get_atomic_embeddings(model, mdp)
    
    def plan_recursive(s1, s2, depth, visited):
        """Recursive planning helper."""
        if depth >= max_depth:
            return [s1, s2]
        
        if s1 == s2:
            return [s1]
        
        # Get embedding for this pair
        pair_emb = get_embedding(model, s1, s2, mdp.n_states)
        
        # Find subgoal on radial line
        subgoal, subgoal_emb = find_subgoal_on_radial_line(
            pair_emb, atomic_emb, s1, s2, 
            angle_threshold=angle_threshold,
            exclude_states=visited
        )
        
        if subgoal is None or subgoal in visited:
            # No valid subgoal found
            return [s1, s2]
        
        # Recurse on both halves
        visited_new = visited | {subgoal}
        left_plan = plan_recursive(s1, subgoal, depth + 1, visited_new)
        right_plan = plan_recursive(subgoal, s2, depth + 1, visited_new)
        
        # Combine plans (avoiding duplicate subgoal)
        if right_plan and right_plan[0] == subgoal:
            return left_plan + right_plan[1:]
        return left_plan + right_plan
    
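    # Seed `visited` with both endpoints so neither can be chosen as an
    # intermediate subgoal deeper in the recursion
    plan = plan_recursive(start_state, goal_state, 0, {start_state, goal_state})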
    return plan


def visualize_plan(model, mdp, start_state, goal_state, plan):
    """
    Visualize the planning process on the Poincare disk.
    """
    model.eval()
    
    fig, ax = plt.subplots(figsize=(10, 10))
    plot_poincare_disk(ax, f"Hyperbolic Plan: {pair_label(start_state, goal_state)}")
    
    # Get atomic embeddings
    atomic_emb = get_atomic_embeddings(model, mdp)
    
    # Plot all atomic embeddings in light gray
    for s, emb in atomic_emb.items():
        ax.scatter([emb[0]], [emb[1]], c='lightgray', s=60, zorder=2, alpha=0.7)
        ax.annotate(pair_label(s, s), (emb[0], emb[1]), fontsize=8, alpha=0.5,
                   xytext=(3, 3), textcoords='offset points')
    
    # Plot the main (start, goal) embedding
    main_emb = get_embedding(model, start_state, goal_state, mdp.n_states)
    ax.scatter([main_emb[0]], [main_emb[1]], c='red', s=200, marker='*', 
              edgecolors='black', linewidth=2, zorder=10,
              label=f'Target: {pair_label(start_state, goal_state)}')
    
    # Draw radial line from origin through target
    direction = main_emb / (np.linalg.norm(main_emb) + 1e-10)
    line_end = direction * 0.95  # Almost to boundary
    ax.plot([0, line_end[0]], [0, line_end[1]], 'r--', linewidth=2, alpha=0.5, label='Radial line')
    
    # Highlight plan states
    colors = plt.cm.viridis(np.linspace(0, 1, len(plan)))
    for i, state in enumerate(plan):
        emb = atomic_emb[state]
        ax.scatter([emb[0]], [emb[1]], c=[colors[i]], s=150, 
                  edgecolors='black', linewidth=2, zorder=5)
        ax.annotate(f"{i+1}:{pair_label(state, state)}", (emb[0], emb[1]), 
                   fontsize=11, fontweight='bold',
                   xytext=(6, 6), textcoords='offset points')
    
    # Draw arrows between plan states
    for i in range(len(plan) - 1):
        emb1 = atomic_emb[plan[i]]
        emb2 = atomic_emb[plan[i + 1]]
        ax.annotate('', xy=emb2, xytext=emb1,
                   arrowprops=dict(arrowstyle='->', color='blue', lw=2))
    
    ax.legend(loc='upper left')
    
    plt.tight_layout()
    plt.savefig('hyperbolic_plan.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    return fig


# Generate and display plan from state 1 to state 6
print("="*70)
print("HYPERBOLIC PLANNING")
print("="*70)

start = mdp.start_state  # State 1 (index 0)
goal = mdp.goal_state    # State 6 (index 5)

plan = hyperbolic_plan(model, mdp, start, goal, max_depth=3, angle_threshold=np.pi/3)

print(f"\nPlan from {pair_label(start, goal)}:")
print(f"  Sequence: {' -> '.join([pair_label(s, s) for s in plan])}")
print(f"  States: {' -> '.join([str(s+1) for s in plan])}")

# Visualize
visualize_plan(model, mdp, start, goal, plan)

In [None]:
# =============================================================================
# Analysis of Atomic Embeddings and Plans
# =============================================================================

# Get and display atomic embeddings
print("="*70)
print("ATOMIC EMBEDDINGS φ(s, s)")
print("="*70)
print(f"{'State':<10} {'Embedding':<25} {'Norm':<12} {'Angle (deg)':<12}")
print("-"*70)

atomic_emb = get_atomic_embeddings(model, mdp)
for s in range(mdp.n_states):
    emb = atomic_emb[s]
    norm = hyperbolic_norm(emb)
    angle = np.degrees(np.arctan2(emb[1], emb[0]))
    print(f"{pair_label(s, s):<10} ({emb[0]:+.4f}, {emb[1]:+.4f}){'':>5} {norm:<12.4f} {angle:+.1f}")

# Plot all atomic embeddings
fig, ax = plt.subplots(figsize=(10, 10))
plot_poincare_disk(ax, "All Atomic Embeddings φ(s, s)")

colors = plt.cm.tab10(np.linspace(0, 1, mdp.n_states))
for s in range(mdp.n_states):
    emb = atomic_emb[s]
    norm = hyperbolic_norm(emb)
    ax.scatter([emb[0]], [emb[1]], c=[colors[s]], s=150, 
              edgecolors='black', linewidth=2, zorder=5)
    ax.annotate(f"{pair_label(s, s)}\n(||·||={norm:.2f})", (emb[0], emb[1]), 
               fontsize=10, fontweight='bold',
               xytext=(8, 8), textcoords='offset points')

plt.tight_layout()
plt.savefig('atomic_embeddings.png', dpi=150, bbox_inches='tight')
plt.show()

# Generate plans from all non-goal states
print("\n" + "="*70)
print("PLANS FROM ALL STATES TO GOAL")
print("="*70)

for start_s in range(mdp.n_states):
    if start_s != mdp.goal_state:
        plan = hyperbolic_plan(model, mdp, start_s, mdp.goal_state, max_depth=3)
        print(f"  {pair_label(start_s, mdp.goal_state)}: {' -> '.join([str(s+1) for s in plan])}")

---

# Part 2: Shadow Cone Loss in Poincaré Half-Space

This section implements an alternative loss function using **Shadow Cone geometry** in the **Poincaré Half-Space model**.

## Key Differences from InfoNCE:

| Aspect | InfoNCE (Part 1) | Shadow Cone Loss |
|--------|------------------|------------------|
| **Geometry** | Poincaré Ball | Poincaré Half-Space |
| **Relation** | Symmetric similarity | Asymmetric containment |
| **What it learns** | "Same vs different" | "Contains vs doesn't contain" |
| **Structure** | Clusters | Hierarchy (tree-like) |

## Poincaré Half-Space Model

$$\mathbb{H}^n = \{(x_1, \ldots, x_{n-1}, y) \in \mathbb{R}^n : y > 0\}$$

- Points near the boundary (y → 0) are "specific" (leaf intervals)
- Points higher up (large y) are "general" (root intervals)
- Shadow cones point downward from parents to children
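To make the "specific near the boundary" picture concrete, here is a small stdlib sketch of the half-space distance d_H(p, q) = arcosh(1 + ‖p - q‖² / (2 y_p y_q)), the same formula `HalfSpaceGeometry.distance` uses below. Two points on one vertical line are |ln(y₂ / y₁)| apart, while a fixed horizontal gap costs more the closer the points are to y = 0. The helper name `halfspace_distance` is illustrative only.

```python
import math

def halfspace_distance(p, q):
    """Hyperbolic distance in the Poincare half-space; the last coordinate is y > 0."""
    sq = sum((a - b) ** 2 for a, b in zip(p, q))
    return math.acosh(1.0 + sq / (2.0 * p[-1] * q[-1]))

# Vertical pair: distance equals |ln(y2 / y1)|
print(halfspace_distance((0.0, 1.0), (0.0, 2.0)), math.log(2.0))  # both ~0.6931

# Same Euclidean gap of 1.0, at two different heights
high = halfspace_distance((0.0, 1.0), (1.0, 1.0))
low = halfspace_distance((0.0, 0.1), (1.0, 0.1))
print(high, low)  # the pair near the boundary is much farther apart
```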

In [None]:
# =============================================================================
# Poincaré Half-Space Geometry
# =============================================================================

class HalfSpaceGeometry:
    """
    Utility class for Poincaré Half-Space computations.
    
    The half-space model: H^n = {(x, y) : y > 0}
    - x: horizontal coordinates (can be multi-dimensional)
    - y: vertical coordinate (must be positive)
    """
    
    def __init__(self, min_y=0.01):
        """
        Args:
            min_y: Minimum y value for numerical stability
        """
        self.min_y = min_y
    
    def clamp_y(self, y):
        """Ensure y > min_y for numerical stability."""
        return torch.clamp(y, min=self.min_y)
    
    def distance(self, p, q):
        """
        Compute hyperbolic distance in the half-space model.
        
        d_H(p, q) = arcosh(1 + ||p - q||^2 / (2 * y_p * y_q))
        
        Args:
            p: Points of shape (..., dim) where last coord is y
            q: Points of shape (..., dim) where last coord is y
        
        Returns:
            Distances of shape (...)
        """
        # Extract y coordinates (last dimension)
        y_p = self.clamp_y(p[..., -1])
        y_q = self.clamp_y(q[..., -1])
        
        # Squared Euclidean distance
        diff = p - q
        sq_dist = (diff ** 2).sum(dim=-1)
        
        # Hyperbolic distance formula
        cosh_dist = 1.0 + sq_dist / (2.0 * y_p * y_q)
        
        # Clamp for numerical stability (cosh_dist >= 1)
        cosh_dist = torch.clamp(cosh_dist, min=1.0 + 1e-7)
        
        return torch.acosh(cosh_dist)
    
    def is_below(self, v, u):
        """
        Check if v is below u (y_v < y_u).
        
        Returns:
            Boolean tensor of shape (...)
        """
        y_v = v[..., -1]
        y_u = u[..., -1]
        return y_v < y_u
    
    def horizontal_distance(self, p, q):
        """
        Compute horizontal (x) distance between points.
        
        Args:
            p, q: Points of shape (..., dim)
        
        Returns:
            Horizontal distances of shape (...)
        """
        # All coordinates except y (last one)
        x_p = p[..., :-1]
        x_q = q[..., :-1]
        return torch.norm(x_p - x_q, dim=-1)


# Test the geometry
geom = HalfSpaceGeometry(min_y=0.01)

# Test points
p1 = torch.tensor([0.0, 1.0])  # (x=0, y=1)
p2 = torch.tensor([0.0, 2.0])  # (x=0, y=2) - directly above
p3 = torch.tensor([1.0, 1.0])  # (x=1, y=1) - same height, different x

print("Half-Space Geometry Tests:")
print(f"  d(p1, p2) = {geom.distance(p1, p2).item():.4f}  (vertical)")
print(f"  d(p1, p3) = {geom.distance(p1, p3).item():.4f}  (horizontal)")
print(f"  d(p2, p3) = {geom.distance(p2, p3).item():.4f}  (diagonal)")
print(f"  p1 below p2? {geom.is_below(p1, p2).item()}")
print(f"  p2 below p1? {geom.is_below(p2, p1).item()}")

In [None]:
# =============================================================================
# Shadow Cone Geometry
# =============================================================================

class ShadowCone:
    """
    Shadow Cone computations in Poincaré Half-Space.
    
    A shadow cone from point u contains all points v that are:
    1. Below u (y_v < y_u)
    2. Within hyperbolic distance r (the aperture) of the vertical axis through u
    
    The cone boundary is a hypercycle at hyperbolic distance r from the axis.
    """
    
    def __init__(self, aperture=0.1, min_y=0.01):
        """
        Args:
            aperture: Cone aperture radius r (in hyperbolic distance units)
            min_y: Minimum y value for numerical stability
        """
        self.aperture = aperture
        self.min_y = min_y
        self.geom = HalfSpaceGeometry(min_y=min_y)
    
    def cone_energy(self, u, v):
        """
        Compute energy E(u, v): distance from v to cone(u).
        
        E(u, v) = 0 if v is inside the cone
        E(u, v) > 0 if v is outside the cone
        
        Implementation:
        - If y_v >= y_u: E = d_H(u, v) (must go through apex)
        - If y_v < y_u: E = max(0, d_H(v, axis(u)) - r)
        
        For simplicity, we use:
        E(u, v) = d_H(u, v) - r * 1[y_v < y_u]
        
        Args:
            u: Parent embeddings of shape (batch, dim) or (dim,)
            v: Child embeddings of shape (batch, dim) or (dim,)
        
        Returns:
            Energy values of shape (batch,) or scalar
        """
        # Compute hyperbolic distance
        d = self.geom.distance(u, v)
        
        # Check if v is below u
        is_below = self.geom.is_below(v, u).float()
        
        # Surrogate energy: d - r if v is below u, d otherwise.
        # This is not the exact cone energy max(0, d_H(v, axis(u)) - r): being inside
        # the cone does not imply d < r, but the surrogate keeps gradients smooth.
        energy = d - self.aperture * is_below
        
        return energy
    
    def signed_cone_energy(self, u, v):
        """
        Compute signed energy: negative inside cone, positive outside.
        
        This gives stronger gradients for points already in the cone.
        
        Args:
            u: Parent embeddings of shape (batch, dim)
            v: Child embeddings of shape (batch, dim)
        
        Returns:
            Signed energy values
        """
        d = self.geom.distance(u, v)
        is_below = self.geom.is_below(v, u).float()
        
        # Compute distance to vertical axis through u
        # In 2D half-space, axis is vertical line x = x_u
        # The hyperbolic distance from v to this vertical geodesic is exactly
        # asinh(|x_v - x_u| / y_v)
        x_u = u[..., :-1]
        x_v = v[..., :-1]
        y_v = self.geom.clamp_y(v[..., -1])
        
        # Exact hyperbolic distance from v to the axis (up to the clamp on y_v)
        horizontal_dist = torch.norm(x_v - x_u, dim=-1)
        axis_dist = torch.asinh(horizontal_dist / y_v)
        
        # Signed energy
        # If below and close to axis: negative (inside cone)
        # If below but far from axis: positive (outside cone)
        # If above: use distance through apex
        
        inside_cone = is_below * (axis_dist < self.aperture).float()
        
        energy = torch.where(
            inside_cone > 0.5,
            axis_dist - self.aperture,  # Negative if truly inside
            d  # Distance through apex
        )
        
        return energy
    
    def in_cone(self, u, v, strict=True):
        """
        Check if v is inside the shadow cone of u.
        
        Args:
            u: Parent embeddings
            v: Child embeddings
            strict: If True, require v strictly below u
        
        Returns:
            Boolean tensor
        """
        is_below = self.geom.is_below(v, u)
        
        if not strict:
            is_below = is_below | (v[..., -1] == u[..., -1])
        
        # Check hyperbolic distance from v to the vertical axis through u
        x_u = u[..., :-1]
        x_v = v[..., :-1]
        y_v = self.geom.clamp_y(v[..., -1])
        
        horizontal_dist = torch.norm(x_v - x_u, dim=-1)
        axis_dist = torch.asinh(horizontal_dist / y_v)
        
        in_angular_range = axis_dist <= self.aperture
        
        return is_below & in_angular_range


# Test shadow cone
cone = ShadowCone(aperture=0.5, min_y=0.01)

# Parent at (0, 2), children at various positions
parent = torch.tensor([0.0, 2.0])
child_below_center = torch.tensor([0.0, 1.0])  # Directly below - should be in cone
child_below_side = torch.tensor([0.5, 1.0])    # Below but off-center
child_above = torch.tensor([0.0, 3.0])         # Above parent
child_far = torch.tensor([2.0, 0.5])           # Far away

print("\nShadow Cone Tests (aperture=0.5):")
print(f"  Parent: (0, 2)")
print(f"  Child below center (0, 1): energy={cone.cone_energy(parent, child_below_center).item():.4f}, in_cone={cone.in_cone(parent, child_below_center).item()}")
print(f"  Child below side (0.5, 1): energy={cone.cone_energy(parent, child_below_side).item():.4f}, in_cone={cone.in_cone(parent, child_below_side).item()}")
print(f"  Child above (0, 3): energy={cone.cone_energy(parent, child_above).item():.4f}, in_cone={cone.in_cone(parent, child_above).item()}")
print(f"  Child far (2, 0.5): energy={cone.cone_energy(parent, child_far).item():.4f}, in_cone={cone.in_cone(parent, child_far).item()}")
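`ShadowCone.signed_cone_energy` and `ShadowCone.in_cone` measure how far v is from the vertical axis through u with `asinh(|x_v - x_u| / y_v)`. In the half-plane this is the closed-form hyperbolic distance from a point to a vertical geodesic (sinh d = |x_v - x_u| / y_v), attained at the axis point of height sqrt(|x_v - x_u|² + y_v²). The stdlib sketch below checks this numerically against a brute-force search over axis points; it covers 2D points only, and the helpers `axis_distance` and `axis_distance_brute` are illustrative, not part of the notebook.

```python
import math

def halfspace_distance(p, q):
    """Hyperbolic distance in the 2D half-plane; the last coordinate is y > 0."""
    sq = sum((a - b) ** 2 for a, b in zip(p, q))
    return math.acosh(1.0 + sq / (2.0 * p[-1] * q[-1]))

def axis_distance(x_u, v):
    """Closed form: distance from v = (x_v, y_v) to the vertical geodesic x = x_u."""
    return math.asinh(abs(v[0] - x_u) / v[1])

def axis_distance_brute(x_u, v, n=20001):
    """Minimise d_H(v, (x_u, t)) over log-spaced heights t in [1e-4, 1e4]."""
    return min(
        halfspace_distance(v, (x_u, 10 ** (-4 + 8 * k / (n - 1))))
        for k in range(n)
    )

for v in [(0.5, 1.0), (2.0, 0.5), (-0.3, 0.05)]:
    closed = axis_distance(0.0, v)
    foot = (0.0, math.hypot(v[0], v[1]))  # closest axis point when x_u = 0
    print(v, closed, halfspace_distance(v, foot), axis_distance_brute(0.0, v))
```

All three columns should agree to about five decimal places; the brute-force value can only be slightly larger, since the grid never lands exactly on the closest axis point.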

In [None]:
# =============================================================================
# Interval Containment and Sampling
# =============================================================================

def is_contained(child_interval, parent_interval):
    """
    Check if child_interval is temporally contained in parent_interval.
    
    [i_c, j_c] ⊆ [i_p, j_p] iff i_p <= i_c AND j_c <= j_p
    
    Args:
        child_interval: (i_c, j_c) tuple
        parent_interval: (i_p, j_p) tuple
    
    Returns:
        bool: True if child is contained in parent
    """
    i_c, j_c = child_interval
    i_p, j_p = parent_interval
    return i_p <= i_c and j_c <= j_p


def is_strictly_contained(child_interval, parent_interval):
    """
    Check if child is strictly contained (proper subset).
    
    Child ⊆ Parent AND Child ≠ Parent
    """
    if child_interval == parent_interval:
        return False
    return is_contained(child_interval, parent_interval)


def are_incomparable(interval1, interval2):
    """
    Check if two intervals are incomparable (neither contains the other).
    
    With i1 < i2 (swap otherwise), this happens exactly when j1 < j2, covering:
    - Partial overlap: i1 < i2 < j1 < j2
    - Shared endpoint: i1 < i2 = j1 < j2
    - Disjoint: j1 < i2
    """
    return (not is_contained(interval1, interval2) and 
            not is_contained(interval2, interval1))


def extract_all_intervals(trajectory_length):
    """
    Extract all interval pairs from a trajectory of given length.
    
    For trajectory of length T+1 (states s_0, ..., s_T),
    intervals are all (i, j) with 0 <= i < j <= T.
    
    Args:
        trajectory_length: T+1 (number of states)
    
    Returns:
        List of (i, j) tuples
    """
    T = trajectory_length - 1
    intervals = []
    for i in range(T + 1):
        for j in range(i + 1, T + 1):
            intervals.append((i, j))
    return intervals


def build_containment_pairs(intervals):
    """
    Build all (parent, child) containment pairs from intervals.
    
    Returns:
        List of (parent_interval, child_interval) tuples
    """
    pairs = []
    for parent in intervals:
        for child in intervals:
            if is_strictly_contained(child, parent):
                pairs.append((parent, child))
    return pairs


def build_incomparable_pairs(intervals):
    """
    Build all incomparable pairs from intervals.
    
    Returns:
        List of (interval1, interval2) tuples where neither contains the other
    """
    pairs = []
    n = len(intervals)
    for i in range(n):
        for j in range(i + 1, n):
            if are_incomparable(intervals[i], intervals[j]):
                pairs.append((intervals[i], intervals[j]))
    return pairs


# Test with a short trajectory
test_traj_len = 4  # States s_0, s_1, s_2, s_3
test_intervals = extract_all_intervals(test_traj_len)
print(f"Intervals for trajectory of length {test_traj_len}:")
print(f"  All intervals ({len(test_intervals)}): {test_intervals}")

containment_pairs = build_containment_pairs(test_intervals)
print(f"\nContainment pairs ({len(containment_pairs)}):")
for parent, child in containment_pairs[:5]:
    print(f"  {child} ⊂ {parent}")
if len(containment_pairs) > 5:
    print(f"  ... and {len(containment_pairs) - 5} more")

incomparable_pairs = build_incomparable_pairs(test_intervals)
print(f"\nIncomparable pairs ({len(incomparable_pairs)}):")
for int1, int2 in incomparable_pairs[:5]:
    print(f"  {int1} ⊥ {int2}")
if len(incomparable_pairs) > 5:
    print(f"  ... and {len(incomparable_pairs) - 5} more")
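`build_containment_pairs`, `build_incomparable_pairs`, and `ShadowConeDataset.__getitem__` build both relations with nested Python loops. With the 50-interval cap used here that is cheap, but the same relations can also be written as boolean masks with NumPy broadcasting. The sketch below is an optional, hypothetical helper (`containment_masks` is not part of the notebook); on the 4-state test trajectory it gives 6 intervals, 9 containment pairs, and 6 incomparable pairs, in the same order as the loop-based functions.

```python
import numpy as np

def containment_masks(intervals):
    """Return (strict, incomparable) boolean masks for unique (i, j) intervals."""
    iv = np.asarray(intervals)
    i, j = iv[:, 0], iv[:, 1]
    # contains[p, c] is True when interval c lies inside interval p: i_p <= i_c and j_c <= j_p
    contains = (i[:, None] <= i[None, :]) & (j[None, :] <= j[:, None])
    strict = contains & ~np.eye(len(iv), dtype=bool)  # drop self-pairs (intervals are unique)
    comparable = contains | contains.T
    incomparable = np.triu(~comparable, k=1)  # each unordered pair once, idx1 < idx2
    return strict, incomparable

intervals = [(i, j) for i in range(4) for j in range(i + 1, 4)]  # same order as extract_all_intervals(4)
strict, incomp = containment_masks(intervals)
pos_pairs = np.argwhere(strict)  # rows are (parent_idx, child_idx)
neg_pairs = np.argwhere(incomp)  # rows are (idx1, idx2)
print(len(intervals), len(pos_pairs), len(neg_pairs))  # 6 9 6
```

`np.argwhere` walks the mask in row-major order, which is why the pairs come out in the same order as the nested loops.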

In [None]:
# =============================================================================
# Shadow Cone Dataset
# =============================================================================

class ShadowConeDataset(Dataset):
    """
    Dataset for Shadow Cone contrastive learning.
    
    Each sample contains:
    - Interval representations (normalized start/end states) for a single trajectory
    - Positive pairs (containment relations)
    - Negative pairs (incomparable intervals)
    
    Args:
        trajectories: List of trajectory state sequences
        n_states: Number of states in MDP
        max_intervals_per_traj: Maximum intervals to sample per trajectory
        max_positive_pairs: Maximum positive pairs per sample
        max_negative_pairs: Maximum negative pairs per sample
        seed: Random seed
    """
    
    def __init__(
        self,
        trajectories,
        n_states=6,
        max_intervals_per_traj=50,
        max_positive_pairs=100,
        max_negative_pairs=100,
        seed=42,
    ):
        self.trajectories = trajectories
        self.n_states = n_states
        self.max_intervals = max_intervals_per_traj
        self.max_pos = max_positive_pairs
        self.max_neg = max_negative_pairs
        
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        # Filter trajectories with sufficient length
        self.valid_indices = [i for i, t in enumerate(trajectories) if len(t) >= 3]
        
        print(f"ShadowConeDataset: {len(self.valid_indices)} valid trajectories")
    
    def __len__(self):
        return len(self.valid_indices)
    
    def __getitem__(self, idx):
        """
        Get a sample: all data from one trajectory.
        
        Returns:
            dict with:
            - 'intervals': tensor of shape (n_intervals, 2) with normalized (start, end) states
            - 'interval_indices': tensor of shape (n_intervals, 2) with time indices (i, j)
            - 'positive_pairs': tensor of shape (n_pos, 2) with indices into intervals
            - 'negative_pairs': tensor of shape (n_neg, 2) with indices into intervals
        """
        traj_idx = self.valid_indices[idx]
        traj = self.trajectories[traj_idx]
        
        # Extract all intervals
        all_intervals = extract_all_intervals(len(traj))
        
        # Subsample if too many
        if len(all_intervals) > self.max_intervals:
            sampled_indices = np.random.choice(len(all_intervals), self.max_intervals, replace=False)
            intervals = [all_intervals[i] for i in sampled_indices]
        else:
            intervals = all_intervals
        
        # Build interval representations: (start_state_normalized, end_state_normalized)
        interval_reps = []
        interval_indices = []
        for (i, j) in intervals:
            start_state = traj[i]
            end_state = traj[j]
            # Normalize states to [0, 1]
            start_norm = start_state / (self.n_states - 1)
            end_norm = end_state / (self.n_states - 1)
            interval_reps.append([start_norm, end_norm])
            interval_indices.append([i, j])
        
        # Build containment pairs
        pos_pairs = []
        for p_idx, parent in enumerate(intervals):
            for c_idx, child in enumerate(intervals):
                if is_strictly_contained(child, parent):
                    pos_pairs.append([p_idx, c_idx])  # parent_idx, child_idx
        
        # Subsample positive pairs
        if len(pos_pairs) > self.max_pos:
            sampled = np.random.choice(len(pos_pairs), self.max_pos, replace=False)
            pos_pairs = [pos_pairs[i] for i in sampled]
        
        # Build incomparable pairs
        neg_pairs = []
        n_intervals = len(intervals)
        for i in range(n_intervals):
            for j in range(i + 1, n_intervals):
                if are_incomparable(intervals[i], intervals[j]):
                    neg_pairs.append([i, j])
        
        # Subsample negative pairs
        if len(neg_pairs) > self.max_neg:
            sampled = np.random.choice(len(neg_pairs), self.max_neg, replace=False)
            neg_pairs = [neg_pairs[i] for i in sampled]
        
        # Convert to tensors
        intervals_tensor = torch.tensor(interval_reps, dtype=torch.float32)
        indices_tensor = torch.tensor(interval_indices, dtype=torch.long)
        pos_tensor = torch.tensor(pos_pairs, dtype=torch.long) if pos_pairs else torch.zeros(0, 2, dtype=torch.long)
        neg_tensor = torch.tensor(neg_pairs, dtype=torch.long) if neg_pairs else torch.zeros(0, 2, dtype=torch.long)
        
        return {
            'intervals': intervals_tensor,
            'interval_indices': indices_tensor,
            'positive_pairs': pos_tensor,
            'negative_pairs': neg_tensor,
        }


def shadow_cone_collate_fn(batch):
    """
    Custom collate function for ShadowConeDataset.
    
    Since each trajectory has different numbers of intervals and pairs,
    we return lists instead of stacking.
    """
    return {
        'intervals': [item['intervals'] for item in batch],
        'interval_indices': [item['interval_indices'] for item in batch],
        'positive_pairs': [item['positive_pairs'] for item in batch],
        'negative_pairs': [item['negative_pairs'] for item in batch],
    }


# Create dataset
shadow_dataset = ShadowConeDataset(
    trajectories=trajectories,
    n_states=6,
    max_intervals_per_traj=50,
    max_positive_pairs=100,
    max_negative_pairs=100,
    seed=42
)

# Test one sample
sample = shadow_dataset[0]
print(f"\nSample from ShadowConeDataset:")
print(f"  Intervals shape: {sample['intervals'].shape}")
print(f"  Interval indices shape: {sample['interval_indices'].shape}")
print(f"  Positive pairs: {sample['positive_pairs'].shape[0]}")
print(f"  Negative pairs: {sample['negative_pairs'].shape[0]}")
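A trajectory with L states yields L(L-1)/2 intervals, so the `max_intervals_per_traj=50` cap above only starts subsampling at L = 11 (55 intervals); shorter trajectories keep every interval. Also, every unordered pair of distinct intervals is either one containment pair or one incomparable pair, so with 50 sampled intervals the positive and negative counts sum to 1225 before the 100-pair caps apply. A quick stdlib check of that arithmetic (the helper `n_intervals` is hypothetical):

```python
from itertools import count

def n_intervals(L):
    """Number of intervals (i, j) with 0 <= i < j <= L - 1 for a trajectory of L states."""
    return L * (L - 1) // 2

max_intervals = 50  # value passed as max_intervals_per_traj above
first_capped = next(L for L in count(2) if n_intervals(L) > max_intervals)
print(n_intervals(4), n_intervals(10), first_capped)  # 6 45 11
print(n_intervals(max_intervals))  # 1225 unordered pairs among 50 sampled intervals
```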

In [None]:
# =============================================================================
# Half-Space Encoder
# =============================================================================

class HalfSpaceEncoder(nn.Module):
    """
    Encode (start_state, end_state) interval pairs to Poincaré Half-Space.
    
    Output: (x, y) where x is unrestricted and y > 0.
    
    Architecture:
    - Euclidean MLP layers
    - Final layer outputs (x, raw y); y = min_y + softplus(raw y) ensures y > min_y > 0
    """
    
    def __init__(self, input_dim=2, hidden_dim=128, output_dim=2, min_y=0.01):
        """
        Args:
            input_dim: Input dimension (2 for normalized start/end states)
            hidden_dim: Hidden layer dimension
            output_dim: Output dimension (2 for 2D half-space)
            min_y: Minimum y value for numerical stability
        """
        super().__init__()
        self.output_dim = output_dim
        self.min_y = min_y
        
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
        # Initialize output layer for reasonable starting positions:
        # with the last bias at 0, y starts near min_y + softplus(0) = min_y + ln 2 ≈ 0.70
        with torch.no_grad():
            self.network[-1].bias.data[-1] = 0.0  # raw y output starts near 0 -> y ≈ 0.70
    
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, input_dim) with normalized states
        
        Returns:
            Tensor of shape (batch, output_dim) in half-space (last dim is y > 0)
        """
        out = self.network(x)
        
        # Split into horizontal and vertical components
        x_coord = out[..., :-1]
        y_raw = out[..., -1:]
        
        # Apply softplus to ensure y > min_y
        y = self.min_y + nn.functional.softplus(y_raw)
        
        return torch.cat([x_coord, y], dim=-1)
    
    def get_embedding(self, start_state, end_state, n_states=6):
        """Get embedding for a single interval."""
        self.eval()
        with torch.no_grad():
            s_norm = start_state / (n_states - 1)
            e_norm = end_state / (n_states - 1)
            x = torch.tensor([[s_norm, e_norm]], dtype=torch.float32, device=device)
            return self.forward(x).squeeze(0).cpu().numpy()


# Test encoder
test_encoder = HalfSpaceEncoder(input_dim=2, hidden_dim=64, output_dim=2).to(device)
test_input = torch.tensor([[0.0, 1.0], [0.2, 0.8], [0.4, 0.6]], dtype=torch.float32, device=device)
test_output = test_encoder(test_input)

print("Half-Space Encoder Test:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Sample outputs:")
for i in range(test_input.shape[0]):
    x, y = test_output[i, 0].item(), test_output[i, 1].item()
    print(f"    Input {test_input[i].tolist()} -> (x={x:.4f}, y={y:.4f})")
print(f"  All y > 0? {(test_output[:, -1] > 0).all().item()}")

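The encoder keeps y positive with a shifted softplus instead of an exponential, so y grows only linearly for large raw outputs. A minimal standalone sketch of that output head (the helper name `to_half_space` is hypothetical, not part of the notebook) shows the two properties the code relies on: y is bounded below by `min_y`, and a zero raw value lands at y = min_y + ln 2 ≈ 0.70.

```python
import math

import torch
import torch.nn.functional as F


def to_half_space(out, min_y=0.01):
    """Hypothetical helper mirroring the output head of HalfSpaceEncoder.forward.

    out: (..., d) raw network output whose last column is the unconstrained y_raw.
    Returns (..., d) points whose last coordinate is y = min_y + softplus(y_raw).
    """
    x_coord = out[..., :-1]
    y = min_y + F.softplus(out[..., -1:])
    return torch.cat([x_coord, y], dim=-1)


raw = torch.tensor([[0.3, 0.0],     # zero raw vertical value
                    [-1.0, -20.0],  # very negative: y approaches min_y from above
                    [2.0, 5.0]])    # large: softplus grows roughly linearly, unlike exp
points = to_half_space(raw)
print(points)
print("y at zero raw value:", points[0, 1].item(), "vs min_y + ln 2 =", 0.01 + math.log(2.0))
```

The x column passes through unchanged; only the last coordinate is constrained.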
In [None]:
# =============================================================================
# Shadow Cone Loss Functions
# =============================================================================

class ShadowConeLoss(nn.Module):
    """
    Shadow Cone Contrastive Loss for hierarchical embeddings.
    
    Loss = L_positive + L_negative + L_regularization
    
    - L_positive: Push children into parent cones
    - L_negative: Push incomparable pairs apart
    - L_regularization: Encourage spread in y-coordinate (hierarchy levels)
    """
    
    def __init__(
        self,
        aperture=0.2,
        positive_margin=0.1,
        negative_margin=1.0,
        reg_weight=0.01,
        min_y=0.01,
    ):
        """
        Args:
            aperture: Cone aperture radius r
            positive_margin: Margin γ₁ for positive pairs
            negative_margin: Margin γ₂ for negative pairs
            reg_weight: Weight λ for regularization term
            min_y: Minimum y value for numerical stability
        """
        super().__init__()
        self.cone = ShadowCone(aperture=aperture, min_y=min_y)
        self.positive_margin = positive_margin
        self.negative_margin = negative_margin
        self.reg_weight = reg_weight
    
    def forward(self, embeddings, positive_pairs, negative_pairs):
        """
        Compute the shadow cone loss.
        
        Args:
            embeddings: Tensor of shape (n_intervals, dim) - all interval embeddings
            positive_pairs: Tensor of shape (n_pos, 2) with (parent_idx, child_idx)
            negative_pairs: Tensor of shape (n_neg, 2) with (idx1, idx2) incomparable
        
        Returns:
            Total loss and dict of component losses
        """
        losses = {}
        
        # Positive loss: push children into parent cones
        if positive_pairs.shape[0] > 0:
            parent_emb = embeddings[positive_pairs[:, 0]]  # (n_pos, dim)
            child_emb = embeddings[positive_pairs[:, 1]]   # (n_pos, dim)
            
            # Energy: distance from child to parent's cone
            energy = self.cone.cone_energy(parent_emb, child_emb)
            
            # Soft margin loss: log(1 + exp(E - γ₁)), computed stably via softplus
            pos_loss = nn.functional.softplus(energy - self.positive_margin).mean()
            losses['positive'] = pos_loss
        else:
            pos_loss = torch.tensor(0.0, device=embeddings.device)
            losses['positive'] = pos_loss
        
        # Negative loss: push incomparable pairs apart
        if negative_pairs.shape[0] > 0:
            emb1 = embeddings[negative_pairs[:, 0]]
            emb2 = embeddings[negative_pairs[:, 1]]
            
            # For negatives, neither should be in the other's cone
            # Compute energy in both directions
            energy_12 = self.cone.cone_energy(emb1, emb2)  # emb2 in cone of emb1?
            energy_21 = self.cone.cone_energy(emb2, emb1)  # emb1 in cone of emb2?
            
            # Take minimum: we want BOTH energies to be large
            # If either is small, it's a violation
            min_energy = torch.min(energy_12, energy_21)
            
            # Soft margin loss: log(1 + exp(γ₂ - E)), computed stably via softplus
            neg_loss = nn.functional.softplus(self.negative_margin - min_energy).mean()
            losses['negative'] = neg_loss
        else:
            neg_loss = torch.tensor(0.0, device=embeddings.device)
            losses['negative'] = neg_loss
        
        # Regularization: encourage spread in y-coordinate
        y_coords = embeddings[:, -1]
        if y_coords.shape[0] > 1:
            y_var = y_coords.var()
            reg_loss = self.reg_weight / (y_var + 1e-6)
            losses['regularization'] = reg_loss
        else:
            reg_loss = torch.tensor(0.0, device=embeddings.device)
            losses['regularization'] = reg_loss
        
        total_loss = pos_loss + neg_loss + reg_loss
        losses['total'] = total_loss
        
        return total_loss, losses


class ShadowConeInfoNCELoss(nn.Module):
    """
    InfoNCE-style Shadow Cone Loss.
    
    For each parent with child and negatives:
    L = -log(exp(-E(parent, child)) / (exp(-E(parent, child)) + Σ exp(-E(parent, neg))))
    """
    
    def __init__(self, aperture=0.2, temperature=0.5, min_y=0.01):
        super().__init__()
        self.cone = ShadowCone(aperture=aperture, min_y=min_y)
        self.temperature = temperature
    
    def forward(self, embeddings, positive_pairs, negative_pairs):
        """
        Compute InfoNCE-style shadow cone loss.
        
        Args:
            embeddings: Tensor of shape (n_intervals, dim)
            positive_pairs: Tensor of shape (n_pos, 2) with (parent_idx, child_idx)
            negative_pairs: Tensor of shape (n_neg, 2) with incomparable pairs
        
        Returns:
            Loss and component dict
        """
        if positive_pairs.shape[0] == 0:
            return torch.tensor(0.0, device=embeddings.device), {'total': 0.0}
        
        parent_emb = embeddings[positive_pairs[:, 0]]
        child_emb = embeddings[positive_pairs[:, 1]]
        
        # Positive energies
        pos_energy = self.cone.cone_energy(parent_emb, child_emb)
        pos_score = -pos_energy / self.temperature
        
        # Negatives: the second element of every incomparable pair, shared by all positives
        if negative_pairs.shape[0] > 0:
            neg_emb = embeddings[negative_pairs[:, 1]]
            
            n_pos = parent_emb.shape[0]
            n_neg = neg_emb.shape[0]
            dim = parent_emb.shape[-1]
            
            # Pair every parent with every negative: flatten to (n_pos * n_neg, dim)
            # for cone_energy, then reshape the energies back to (n_pos, n_neg)
            parent_rep = parent_emb.unsqueeze(1).expand(n_pos, n_neg, dim).reshape(-1, dim)
            neg_rep = neg_emb.unsqueeze(0).expand(n_pos, n_neg, dim).reshape(-1, dim)
            neg_energies = self.cone.cone_energy(parent_rep, neg_rep).reshape(n_pos, n_neg)
            
            neg_scores = -neg_energies / self.temperature
            
            # InfoNCE loss
            all_scores = torch.cat([pos_score.unsqueeze(1), neg_scores], dim=1)
            labels = torch.zeros(n_pos, dtype=torch.long, device=embeddings.device)
            loss = nn.functional.cross_entropy(all_scores, labels)
        else:
            loss = torch.tensor(0.0, device=embeddings.device)
        
        return loss, {'total': loss.item() if isinstance(loss, torch.Tensor) else loss}


# Test loss functions
print("Testing Shadow Cone Loss:")
test_emb = torch.tensor([
    [0.0, 2.0],   # High y - parent
    [0.1, 1.0],   # Below, close to axis - child (in cone)
    [1.0, 0.5],   # Below, far from axis - negative
    [0.0, 3.0],   # Above - negative
], dtype=torch.float32)

test_pos = torch.tensor([[0, 1]], dtype=torch.long)  # parent=0, child=1
test_neg = torch.tensor([[0, 2], [1, 3]], dtype=torch.long)  # incomparable pairs

loss_fn = ShadowConeLoss(aperture=0.3, positive_margin=0.1, negative_margin=0.5)
loss, components = loss_fn(test_emb, test_pos, test_neg)

print(f"  Total loss: {loss.item():.4f}")
for name, val in components.items():
    if isinstance(val, torch.Tensor):
        print(f"  {name}: {val.item():.4f}")
    else:
        print(f"  {name}: {val:.4f}")

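Both losses above rest on two standard identities. The soft-margin term log(1 + exp(z)) is the softplus function; PyTorch's `softplus` evaluates it without overflow, whereas the literal expression overflows in float32 once z exceeds roughly 88. The InfoNCE expression in `ShadowConeInfoNCELoss` is a cross-entropy over the score vector [s_pos, s_neg_1, ..., s_neg_k] with target index 0, which is why that code can call `cross_entropy` with all-zero labels. The sketch below checks both; the energies are made-up numbers for illustration only.

```python
import torch
import torch.nn.functional as F

# 1) Soft margin: log(1 + exp(z)) equals softplus(z), but only softplus is overflow-safe
z = torch.tensor([-5.0, 0.0, 5.0, 100.0])
naive = torch.log1p(torch.exp(z))
stable = F.softplus(z)
print("naive: ", naive)   # the last entry overflows to inf in float32
print("stable:", stable)  # the last entry stays finite

# 2) InfoNCE equals cross-entropy with the positive score in column 0
temperature = 0.5
pos_energy = torch.tensor([0.2, 1.5])          # (n_pos,), made-up energies
neg_energy = torch.tensor([[2.0, 3.0, 0.5],
                           [1.0, 4.0, 2.5]])   # (n_pos, n_neg), made-up energies

pos_score = -pos_energy / temperature
neg_scores = -neg_energy / temperature
all_scores = torch.cat([pos_score.unsqueeze(1), neg_scores], dim=1)
labels = torch.zeros(all_scores.shape[0], dtype=torch.long)
ce = F.cross_entropy(all_scores, labels)

# Literal formula: -log(exp(s_pos) / (exp(s_pos) + sum_j exp(s_neg_j))), averaged over rows
manual = -torch.log(torch.exp(pos_score) / torch.exp(all_scores).sum(dim=1)).mean()
print("cross_entropy:", ce.item(), " literal InfoNCE:", manual.item())
```

The literal form is fine here only because the scores are small; `cross_entropy` uses a log-sum-exp internally and stays stable for large scores.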
In [None]:
# =============================================================================
# Training with Shadow Cone Loss
# =============================================================================

def train_shadow_cone_model(
    encoder,
    dataset,
    num_epochs=100,
    lr=0.001,
    aperture=0.2,
    positive_margin=0.1,
    negative_margin=1.0,
    reg_weight=0.01,
):
    """
    Train the half-space encoder with shadow cone loss.
    
    Args:
        encoder: HalfSpaceEncoder model
        dataset: ShadowConeDataset
        num_epochs: Number of training epochs
        lr: Learning rate
        aperture: Cone aperture parameter
        positive_margin: Margin for positive pairs
        negative_margin: Margin for negative pairs
        reg_weight: Regularization weight
    
    Returns:
        Training history (list of dicts)
    """
    optimizer = optim.Adam(encoder.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
    
    loss_fn = ShadowConeLoss(
        aperture=aperture,
        positive_margin=positive_margin,
        negative_margin=negative_margin,
        reg_weight=reg_weight,
    )
    
    dataloader = DataLoader(
        dataset, 
        batch_size=1,  # Process one trajectory at a time
        shuffle=True,
        collate_fn=shadow_cone_collate_fn
    )
    
    history = []
    encoder.train()
    
    for epoch in range(num_epochs):
        epoch_losses = {'total': 0, 'positive': 0, 'negative': 0, 'regularization': 0}
        n_batches = 0
        
        for batch in dataloader:
            # Process each trajectory in the batch
            for intervals, pos_pairs, neg_pairs in zip(
                batch['intervals'], batch['positive_pairs'], batch['negative_pairs']
            ):
                if intervals.shape[0] < 2:
                    continue
                
                intervals = intervals.to(device)
                pos_pairs = pos_pairs.to(device)
                neg_pairs = neg_pairs.to(device)
                
                # Forward pass: encode all intervals
                embeddings = encoder(intervals)
                
                # Compute loss
                loss, components = loss_fn(embeddings, pos_pairs, neg_pairs)
                
                if torch.isnan(loss) or torch.isinf(loss):
                    continue
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
                optimizer.step()
                
                # Track losses
                for key in epoch_losses:
                    if key in components:
                        val = components[key]
                        if isinstance(val, torch.Tensor):
                            epoch_losses[key] += val.item()
                        else:
                            epoch_losses[key] += val
                n_batches += 1
        
        scheduler.step()
        
        # Average losses
        if n_batches > 0:
            for key in epoch_losses:
                epoch_losses[key] /= n_batches
        
        history.append(epoch_losses)
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}: "
                  f"total={epoch_losses['total']:.4f}, "
                  f"pos={epoch_losses['positive']:.4f}, "
                  f"neg={epoch_losses['negative']:.4f}, "
                  f"reg={epoch_losses['regularization']:.4f}")
    
    return history


# Train the shadow cone model
print("="*70)
print("Training Shadow Cone Model")
print("="*70)

shadow_encoder = HalfSpaceEncoder(
    input_dim=2, 
    hidden_dim=128, 
    output_dim=2,
    min_y=0.01
).to(device)

shadow_history = train_shadow_cone_model(
    shadow_encoder,
    shadow_dataset,
    num_epochs=150,
    lr=0.001,
    aperture=0.3,
    positive_margin=0.1,
    negative_margin=0.8,
    reg_weight=0.005,
)

# Plot training curves
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

losses_dict = {k: [h[k] for h in shadow_history] for k in ['total', 'positive', 'negative', 'regularization']}

for ax, (name, values) in zip(axes, losses_dict.items()):
    ax.plot(values)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'{name.capitalize()} Loss')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('shadow_cone_training.png', dpi=150, bbox_inches='tight')
plt.show()

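`train_shadow_cone_model` anneals the learning rate once per epoch with `CosineAnnealingLR`, whose closed form is η_t = η_min + (η_max - η_min)·(1 + cos(π·t/T))/2. The standalone check below (one dummy parameter, no real training) uses the settings of the run above, lr = 1e-3, eta_min = 1e-5 and T = 150 epochs: the rate should sit midway between lr and eta_min at epoch 75 and reach eta_min at epoch 150. The helper `closed_form` is illustrative only.

```python
import math

import torch
import torch.optim as optim

lr, eta_min, T = 1e-3, 1e-5, 150

# A dummy parameter is enough to drive the optimizer/scheduler pair
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T, eta_min=eta_min)


def closed_form(t):
    """Closed-form cosine-annealed learning rate after t scheduler steps."""
    return eta_min + (lr - eta_min) * (1 + math.cos(math.pi * t / T)) / 2


observed = {}
for epoch in range(1, T + 1):
    optimizer.step()   # no gradients here, so the parameter is left unchanged
    scheduler.step()   # one scheduler step per epoch, as in the training loop
    observed[epoch] = optimizer.param_groups[0]["lr"]

for t in (1, 75, 150):
    print(f"epoch {t:3d}: scheduler lr = {observed[t]:.6e}, closed form = {closed_form(t):.6e}")
```

Because the schedule is stepped per epoch while the optimizer steps once per trajectory, every trajectory within an epoch shares the same learning rate.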
In [None]:
# =============================================================================
# Half-Space Visualization
# =============================================================================

def plot_half_space(ax, title="", y_max=3.0):
    """Draw half-space boundary and setup axes."""
    ax.axhline(y=0, color='black', linewidth=2, linestyle='-')
    ax.fill_between([-3, 3], 0, -0.5, color='lightgray', alpha=0.5)
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-0.2, y_max)
    ax.set_xlabel('x (horizontal)')
    ax.set_ylabel('y (height = generality)')
    ax.grid(True, alpha=0.3)
    if title:
        ax.set_title(title, fontsize=13, fontweight='bold')


def visualize_half_space_embeddings(encoder, mdp, hitting_stats):
    """
    Visualize embeddings for (state, goal) pairs in half-space.
    """
    encoder.eval()
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Get embeddings for all (state, final_goal) pairs
    embeddings = {}
    for state in range(mdp.n_states):
        if state != mdp.goal_state:
            emb = encoder.get_embedding(state, mdp.goal_state, mdp.n_states)
            embeddings[state] = emb
    
    states = sorted(embeddings.keys())
    coords = np.array([embeddings[s] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    
    # Determine y_max for all plots
    y_max = coords[:, 1].max() * 1.3
    
    # Plot 1: Colored by mean hitting time
    ax = axes[0]
    plot_half_space(ax, "Colored by Mean Hitting Time", y_max)
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=means, cmap='viridis',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time')
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    # Plot 2: Colored by variance
    ax = axes[1]
    plot_half_space(ax, "Colored by Variance", y_max)
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Variance')
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    # Plot 3: Show shadow cones
    ax = axes[2]
    plot_half_space(ax, "Shadow Cone Structure", y_max)
    
    colors = plt.cm.tab10(np.linspace(0, 1, len(states)))
    for i, s in enumerate(states):
        emb = embeddings[s]
        ax.scatter([emb[0]], [emb[1]], c=[colors[i]], s=150,
                  edgecolors='black', linewidth=2, zorder=5,
                  label=f"{pair_label(s, mdp.goal_state)}")
        
        # Draw shadow cone (simplified as triangle)
        cone_aperture = 0.3
        y_bottom = 0.01
        width_at_bottom = (emb[1] - y_bottom) * np.sinh(cone_aperture)
        
        triangle = plt.Polygon([
            [emb[0], emb[1]],
            [emb[0] - width_at_bottom, y_bottom],
            [emb[0] + width_at_bottom, y_bottom]
        ], alpha=0.2, color=colors[i])
        ax.add_patch(triangle)
        
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    ax.legend(loc='upper right', fontsize=9)
    
    plt.tight_layout()
    plt.savefig('half_space_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    return embeddings


# Visualize
print("\n" + "="*70)
print("Half-Space Embeddings Visualization")
print("="*70)

shadow_embeddings = visualize_half_space_embeddings(shadow_encoder, mdp, hitting_stats)

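The third panel draws each cone as a Euclidean triangle whose half-width at `y_bottom` is `(y - y_bottom) * sinh(aperture)`, which the code itself labels a simplification. A related exact fact, useful if one wants to draw hyperbolic balls of radius `aperture` precisely: in the upper half-plane, the hyperbolic disc of radius r around (x0, y0) is the Euclidean disc centred at (x0, y0·cosh r) with Euclidean radius y0·sinh r. The sketch below checks this numerically with the standard half-plane distance d = arcosh(1 + ((x1 - x2)² + (y1 - y2)²) / (2·y1·y2)); `half_plane_distance` and `hyperbolic_disc` are illustrative helpers, not functions defined in this notebook.

```python
import numpy as np


def half_plane_distance(p, q):
    """Hyperbolic distance between p = (x1, y1) and q = (x2, y2) in the upper half-plane."""
    (x1, y1), (x2, y2) = p, q
    return np.arccosh(1.0 + ((x1 - x2) ** 2 + (y1 - y2) ** 2) / (2.0 * y1 * y2))


def hyperbolic_disc(x0, y0, r):
    """Euclidean centre and radius of the hyperbolic disc of radius r about (x0, y0)."""
    return (x0, y0 * np.cosh(r)), y0 * np.sinh(r)


x0, y0, r = 0.4, 1.2, 0.3
(cx, cy), rad = hyperbolic_disc(x0, y0, r)

# Sample points on the Euclidean boundary circle and measure their hyperbolic distance
thetas = np.linspace(0.0, 2.0 * np.pi, 12, endpoint=False)
boundary = [(cx + rad * np.cos(t), cy + rad * np.sin(t)) for t in thetas]
dists = np.array([half_plane_distance((x0, y0), b) for b in boundary])
print(dists)  # every entry should be close to r
```

Note that the Euclidean centre sits above the hyperbolic centre, and the disc's lowest point is y0·e^(-r), so it never touches the boundary y = 0. Such a disc could be overlaid with `plt.Circle((cx, cy), rad)`.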
In [None]:
# =============================================================================
# Shadow Cone Analysis and Evaluation
# =============================================================================

def analyze_shadow_cone_embeddings(encoder, mdp, hitting_stats):
    """
    Analyze the learned half-space embeddings.
    
    Expected structure:
    - y-coordinate should correlate with interval "size" (generality)
    - Larger intervals should have higher y (more general)
    - States with high variance should have lower y (more specific)
    """
    encoder.eval()
    
    # Get embeddings for all (state, goal) pairs
    embeddings = {}
    for s in range(mdp.n_states):
        if s != mdp.goal_state:
            emb = encoder.get_embedding(s, mdp.goal_state, mdp.n_states)
            embeddings[s] = emb
    
    states = sorted(embeddings.keys())
    
    # Extract coordinates
    x_coords = np.array([embeddings[s][0] for s in states])
    y_coords = np.array([embeddings[s][1] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    
    # Compute correlations
    # Hypothesis: y ~ 1/variance (more specific = lower y, higher variance)
    # Hypothesis: x ~ mean (similar means = similar x)
    
    corr_y_var, p_y_var = stats.spearmanr(y_coords, vars_)
    corr_y_mean, p_y_mean = stats.spearmanr(y_coords, means)
    corr_x_mean, p_x_mean = stats.spearmanr(x_coords, means)
    
    print("="*70)
    print("SHADOW CONE EMBEDDING ANALYSIS")
    print("="*70)
    
    print("\nExpected structure in Half-Space:")
    print("  - High y = general (large intervals)")
    print("  - Low y = specific (small intervals / high variance)")
    print("  - x separates different 'branches' of the hierarchy")
    
    print("\nCorrelations:")
    print(f"  y vs variance:  ρ = {corr_y_var:.4f} (p = {p_y_var:.4f})")
    print(f"  y vs mean:      ρ = {corr_y_mean:.4f} (p = {p_y_mean:.4f})")
    print(f"  x vs mean:      ρ = {corr_x_mean:.4f} (p = {p_x_mean:.4f})")
    
    print("\n" + "="*70)
    print("STATE-BY-STATE ANALYSIS")
    print("="*70)
    print(f"{'Pair':<10} {'x':<12} {'y':<12} {'Mean T':<10} {'Var T':<10}")
    print("-"*70)
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        print(f"{label:<10} {emb[0]:<12.4f} {emb[1]:<12.4f} "
              f"{hitting_stats[s]['mean']:<10.2f} {hitting_stats[s]['var']:<10.2f}")
    
    return {
        'embeddings': embeddings,
        'corr_y_var': corr_y_var,
        'corr_y_mean': corr_y_mean,
        'corr_x_mean': corr_x_mean,
    }


def evaluate_containment_accuracy(encoder, trajectories, mdp, n_samples=100,
                                  aperture=0.3, max_pairs=20, seed=42):
    """
    Evaluate whether containment relationships are correctly embedded.
    
    For (parent, child) pairs where child ⊂ parent:
    - child should be in parent's shadow cone (y_child < y_parent and close to axis)
    For incomparable pairs:
    - neither interval should be in the other's shadow cone
    
    Args:
        aperture: Cone aperture; should match the value used for training
        max_pairs: Maximum number of positive (and of negative) pairs checked per trajectory
        seed: Seed for a local RNG, so the global NumPy random state is left untouched
    """
    encoder.eval()
    cone = ShadowCone(aperture=aperture, min_y=0.01)
    rng = np.random.default_rng(seed)
    
    correct_pos = 0
    total_pos = 0
    correct_neg = 0
    total_neg = 0
    
    # Sample trajectories
    sampled_trajs = rng.choice(len(trajectories), min(n_samples, len(trajectories)), replace=False)
    
    for traj_idx in sampled_trajs:
        traj = trajectories[traj_idx]
        if len(traj) < 3:
            continue
        
        intervals = extract_all_intervals(len(traj))
        if len(intervals) < 2:
            continue
        
        # Embed all intervals of this trajectory in one batch
        with torch.no_grad():
            x = torch.tensor(
                [[traj[i] / (mdp.n_states - 1), traj[j] / (mdp.n_states - 1)]
                 for (i, j) in intervals],
                dtype=torch.float32, device=device,
            )
            interval_emb = encoder(x)
        
        # Check containment pairs (sample subset)
        pos_pairs = build_containment_pairs(intervals)
        if len(pos_pairs) > max_pairs:
            pos_indices = rng.choice(len(pos_pairs), max_pairs, replace=False)
            pos_pairs = [pos_pairs[i] for i in pos_indices]
        
        for parent_int, child_int in pos_pairs:
            parent_emb = interval_emb[intervals.index(parent_int)]
            child_emb = interval_emb[intervals.index(child_int)]
            
            if cone.in_cone(parent_emb, child_emb).item():
                correct_pos += 1
            total_pos += 1
        
        # Check incomparable pairs (sample subset)
        neg_pairs = build_incomparable_pairs(intervals)
        if len(neg_pairs) > max_pairs:
            neg_indices = rng.choice(len(neg_pairs), max_pairs, replace=False)
            neg_pairs = [neg_pairs[i] for i in neg_indices]
        
        for int1, int2 in neg_pairs:
            emb1 = interval_emb[intervals.index(int1)]
            emb2 = interval_emb[intervals.index(int2)]
            
            # Neither should be in the other's cone
            not_in_1 = not cone.in_cone(emb1, emb2).item()
            not_in_2 = not cone.in_cone(emb2, emb1).item()
            
            if not_in_1 and not_in_2:
                correct_neg += 1
            total_neg += 1
    
    pos_accuracy = correct_pos / total_pos if total_pos > 0 else 0
    neg_accuracy = correct_neg / total_neg if total_neg > 0 else 0
    
    print("\n" + "="*70)
    print("CONTAINMENT EVALUATION")
    print("="*70)
    print(f"Positive pairs (child should be in parent's cone):")
    print(f"  Accuracy: {pos_accuracy*100:.1f}% ({correct_pos}/{total_pos})")
    print(f"\nNegative pairs (neither should be in the other's cone):")
    print(f"  Accuracy: {neg_accuracy*100:.1f}% ({correct_neg}/{total_neg})")
    
    return {
        'positive_accuracy': pos_accuracy,
        'negative_accuracy': neg_accuracy,
    }


# Run analysis
shadow_analysis = analyze_shadow_cone_embeddings(shadow_encoder, mdp, hitting_stats)

# Evaluate containment accuracy
containment_results = evaluate_containment_accuracy(shadow_encoder, trajectories, mdp, n_samples=200)

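With only five (state, goal) pairs, the p-values printed by `stats.spearmanr` come from an asymptotic (t-distribution) approximation and should be read with caution. For n this small the exact permutation distribution of Spearman's ρ can be enumerated directly (5! = 120 orderings); without ties, the smallest achievable two-sided p-value is 2/120 ≈ 0.017, reached only by a perfectly monotone relationship. `exact_spearman_pvalue` is a hypothetical helper for this check, and the toy inputs below are not results from this notebook.

```python
import itertools

import numpy as np
from scipy import stats


def exact_spearman_pvalue(x, y):
    """Hypothetical helper: Spearman's rho with an exact two-sided permutation p-value.

    Enumerates all n! reorderings of y, so it is only practical for very small n.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    observed, _ = stats.spearmanr(x, y)
    n_extreme = 0
    n_total = 0
    for perm in itertools.permutations(range(len(y))):
        rho, _ = stats.spearmanr(x, y[list(perm)])
        # Count orderings at least as extreme as the observed one (small float tolerance)
        if abs(rho) >= abs(observed) - 1e-12:
            n_extreme += 1
        n_total += 1
    return observed, n_extreme / n_total


# Toy inputs (not notebook results): a perfectly monotone relationship over 5 points
rho, p = exact_spearman_pvalue([1, 2, 3, 4, 5], [2, 4, 6, 8, 10])
print(f"rho = {rho:.4f}, exact two-sided p = {p:.4f}")
```

One could pass the y-coordinates from `shadow_analysis['embeddings']` together with the matching `hitting_stats[s]['var']` values to get exact p-values for the correlations reported above.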
In [None]:
# =============================================================================
# Comparison: InfoNCE vs Shadow Cone
# =============================================================================

def plot_comparison(poincare_embeddings, half_space_embeddings, hitting_stats, mdp):
    """
    Side-by-side comparison of Poincaré Ball (InfoNCE) and Half-Space (Shadow Cone) embeddings.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    states = sorted(poincare_embeddings.keys())
    
    # Poincaré Ball embeddings
    ax = axes[0]
    plot_poincare_disk(ax, "Poincaré Ball (InfoNCE Loss)")
    
    coords = np.array([poincare_embeddings[s] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma_r',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Variance', shrink=0.8)
    
    for s in states:
        emb = poincare_embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    # Half-Space embeddings
    ax = axes[1]
    hs_coords = np.array([half_space_embeddings[s] for s in states])
    y_max = hs_coords[:, 1].max() * 1.3
    
    plot_half_space(ax, "Half-Space (Shadow Cone Loss)", y_max)
    
    scatter = ax.scatter(hs_coords[:, 0], hs_coords[:, 1], c=vars_, cmap='plasma_r',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Variance', shrink=0.8)
    
    for s in states:
        emb = half_space_embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    plt.tight_layout()
    plt.savefig('embedding_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()


# Compare the two approaches
print("="*70)
print("COMPARISON: InfoNCE vs Shadow Cone Loss")
print("="*70)

# Get Poincaré Ball embeddings from the first model
poincare_emb = get_state_embeddings(model, mdp)

# Get Half-Space embeddings from the shadow cone model
half_space_emb = {s: shadow_embeddings[s] for s in shadow_embeddings}

plot_comparison(poincare_emb, half_space_emb, hitting_stats, mdp)

# Summary statistics
print("\n" + "="*70)
print("SUMMARY COMPARISON")
print("="*70)

print("\n--- Poincaré Ball (InfoNCE) ---")
print(f"  Norm vs Variance correlation: {results['corr_norm_var']:.4f}")
print(f"  Angle vs Mean correlation:    {results['corr_angle_mean']:.4f}")

print("\n--- Half-Space (Shadow Cone) ---")
print(f"  y vs Variance correlation:    {shadow_analysis['corr_y_var']:.4f}")
print(f"  y vs Mean correlation:        {shadow_analysis['corr_y_mean']:.4f}")
print(f"  Containment accuracy:         {containment_results['positive_accuracy']*100:.1f}%")
print(f"  Separation accuracy:          {containment_results['negative_accuracy']*100:.1f}%")

print("\n" + "="*70)
print("KEY DIFFERENCES")
print("="*70)
print("""
| Aspect            | InfoNCE (Ball)                | Shadow Cone (Half-Space)      |
|-------------------|-------------------------------|-------------------------------|
| Geometry          | Poincaré Ball                 | Poincaré Half-Space           |
| Loss function     | Symmetric contrastive         | Asymmetric containment        |
| Learns            | "Same vs different" clusters  | Hierarchical containment      |
| Structure         | Radial (norm/angle)           | Vertical hierarchy (y)        |
| Interpretability  | Norm ~ variance, angle ~ mean | y ~ generality, x ~ branch    |
""")
print("="*70)

---

# Part 3: Shadow Cone Loss in Poincaré Ball

This section implements the **Shadow Cone Loss** in the **Poincaré Ball** model.

## Key Differences from Half-Space Shadow Cone:

| Aspect | Half-Space | Poincaré Ball |
|--------|------------|---------------|
| **Most general** | High y (top) | Origin (norm ≈ 0) |
| **Most specific** | y → 0 (boundary) | Boundary (norm → 1) |
| **Parent-child** | Parent has higher y | Parent has lower norm |
| **Cone direction** | Downward | Outward from origin |

## Hierarchy Convention

- **Origin** (norm = 0): most general/abstract (root intervals)
- **Boundary** (norm → 1): most specific/concrete (leaf intervals)
- **Parents** have **lower norm** than children

In [None]:
# =============================================================================
# Poincaré Ball Geometry (for Shadow Cone)
# =============================================================================

class PoincareBallGeometry:
    """
    Utility class for Poincaré Ball computations.
    
    The ball model: B^n = {x ∈ R^n : ||x|| < 1}
    """
    
    def __init__(self, eps=1e-5, max_norm=0.95):
        """
        Args:
            eps: Small constant for numerical stability
            max_norm: Maximum allowed norm (keep away from boundary)
        """
        self.eps = eps
        self.max_norm = max_norm
    
    def distance(self, u, v):
        """
        Compute hyperbolic distance in the Poincaré ball.
        
        d_B(u, v) = arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2)(1 - ||v||^2)))
        
        Args:
            u, v: Points of shape (..., dim)
        
        Returns:
            Distances of shape (...)
        """
        diff_norm_sq = torch.sum((u - v)**2, dim=-1)
        u_norm_sq = torch.sum(u**2, dim=-1)
        v_norm_sq = torch.sum(v**2, dim=-1)
        
        # Clamp norms to avoid numerical issues at boundary
        u_norm_sq = torch.clamp(u_norm_sq, max=self.max_norm**2)
        v_norm_sq = torch.clamp(v_norm_sq, max=self.max_norm**2)
        
        denom = (1 - u_norm_sq) * (1 - v_norm_sq)
        denom = torch.clamp(denom, min=self.eps)
        
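        arg = 1 + 2 * diff_norm_sq / denom
        # acosh has an infinite derivative at 1; clamping keeps gradients finite,
        # at the cost of d(u, u) ≈ sqrt(2 * eps) instead of exactly 0
        arg = torch.clamp(arg, min=1 + self.eps)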
        
        return torch.acosh(arg)
    
    def norm(self, x):
        """Compute Euclidean norm."""
        return torch.norm(x, dim=-1)
    
    def project_to_ball(self, x):
        """
        Project points to inside the ball (norm < max_norm).
        
        Uses smooth tanh projection.
        """
        norm = torch.norm(x, dim=-1, keepdim=True)
        # Smooth radial squashing: the output norm is max_norm * tanh(||x||)
        # (up to eps), so it always stays below max_norm
        scale = torch.tanh(norm) * self.max_norm / (norm + self.eps)
        return x * scale
    
    def clip_to_ball(self, x):
        """
        Hard clip to ball (simpler, but less smooth gradients).
        """
        norm = torch.norm(x, dim=-1, keepdim=True)
        scale = torch.clamp(self.max_norm / (norm + self.eps), max=1.0)
        return x * scale
    
    def cosine_similarity(self, u, v):
        """
        Compute cosine similarity between vectors.
        """
        u_norm = torch.norm(u, dim=-1)
        v_norm = torch.norm(v, dim=-1)
        dot = torch.sum(u * v, dim=-1)
        return dot / (u_norm * v_norm + self.eps)
    
    def safe_normalize(self, x):
        """Normalize vector, handling zero vectors."""
        norm = torch.norm(x, dim=-1, keepdim=True)
        return x / (norm + self.eps)


# Test the geometry
ball_geom = PoincareBallGeometry()

# Test points
p1 = torch.tensor([0.0, 0.0])   # Origin
p2 = torch.tensor([0.5, 0.0])   # Mid-radius
p3 = torch.tensor([0.9, 0.0])   # Near boundary
p4 = torch.tensor([0.0, 0.5])   # Different angle

print("Poincaré Ball Geometry Tests:")
print(f"  d(origin, mid) = {ball_geom.distance(p1, p2).item():.4f}")
print(f"  d(origin, near_boundary) = {ball_geom.distance(p1, p3).item():.4f}")
print(f"  d(mid, near_boundary) = {ball_geom.distance(p2, p3).item():.4f}")
print(f"  d(mid_x, mid_y) = {ball_geom.distance(p2, p4).item():.4f}")
print(f"  cos(p2, p4) = {ball_geom.cosine_similarity(p2, p4).item():.4f}")

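For distances measured from the origin, the ball distance formula collapses to d(0, x) = arcosh((1 + ‖x‖²)/(1 - ‖x‖²)) = 2·artanh(‖x‖) = ln((1 + ‖x‖)/(1 - ‖x‖)). This makes the hierarchy convention concrete: hyperbolic distance from the origin is strictly increasing in the Euclidean norm and diverges as ‖x‖ → 1. The standalone sketch below re-implements the formula from `PoincareBallGeometry.distance` (without its clamping) as a hypothetical `ball_distance` helper and checks it against the closed form; for the mid-radius test point (norm 0.5) the closed form gives ln 3 ≈ 1.0986.

```python
import torch


def ball_distance(u, v):
    """Poincaré-ball distance arcosh(1 + 2||u-v||^2 / ((1-||u||^2)(1-||v||^2))), unclamped."""
    diff_sq = torch.sum((u - v) ** 2, dim=-1)
    denom = (1 - torch.sum(u ** 2, dim=-1)) * (1 - torch.sum(v ** 2, dim=-1))
    return torch.acosh(1 + 2 * diff_sq / denom)


norms = torch.tensor([0.1, 0.5, 0.9, 0.95], dtype=torch.float64)
points = torch.stack([norms, torch.zeros_like(norms)], dim=-1)  # points on the x-axis
origin = torch.zeros(2, dtype=torch.float64)

general = ball_distance(origin.expand_as(points), points)
closed = 2 * torch.atanh(norms)
print("general formula:", general)
print("2 * artanh(norm):", closed)
```

Because the map from norm to distance is monotone, ordering parents and children by Euclidean norm (as the ball shadow cone below does) is equivalent to ordering them by hyperbolic distance from the origin.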
In [None]:
# =============================================================================
# Shadow Cone in Poincaré Ball
# =============================================================================

class BallShadowCone:
    """
    Shadow Cone computations in Poincaré Ball.
    
    Imagine a point light source at the origin. Each point u is an opaque ball.
    The shadow cone of u contains all points "behind" u from the origin's perspective.
    
    For child v to be in parent u's shadow cone:
    1. ||v|| > ||u|| (child has higher norm - further from origin)
    2. v is angularly close to u (within cone aperture)
    
    Convention:
    - Origin (low norm) = general/abstract (parents)
    - Boundary (high norm) = specific/concrete (children)
    """
    
    def __init__(self, 
                 cone_angle_deg=30.0,
                 radial_margin=0.05,
                 angular_weight=1.0,
                 eps=1e-5,
                 max_norm=0.95):
        """
        Args:
            cone_angle_deg: Maximum half-angle of cone in degrees
            radial_margin: Margin for radial ordering
            angular_weight: Weight for angular component in energy
            eps: Small constant for numerical stability
            max_norm: Maximum allowed norm
        """
        self.cone_angle_rad = np.radians(cone_angle_deg)
        self.cos_threshold = np.cos(self.cone_angle_rad)
        self.radial_margin = radial_margin
        self.angular_weight = angular_weight
        self.eps = eps
        self.geom = PoincareBallGeometry(eps=eps, max_norm=max_norm)
    
    def radial_energy(self, u, v):
        """
        Compute radial component of energy.
        
        For containment (u=parent, v=child), we want ||v|| > ||u||.
        Energy is high if this is violated.
        
        E_radial = ReLU(||u|| - ||v|| + margin)
        
        Args:
            u: Parent embeddings (should have lower norm)
            v: Child embeddings (should have higher norm)
        
        Returns:
            Radial energy (0 if constraint satisfied with margin)
        """
        u_norm = self.geom.norm(u)
        v_norm = self.geom.norm(v)
        
        # Want v_norm > u_norm + margin
        # Violation: u_norm - v_norm + margin > 0
        return torch.relu(u_norm - v_norm + self.radial_margin)
    
    def angular_energy(self, u, v):
        """
        Compute angular component of energy.
        
        For containment, v should be angularly aligned with u.
        cos(angle(u, v)) should be >= cos_threshold
        
        E_angular = ReLU(cos_threshold - cos(angle(u, v)))
        
        Args:
            u: Parent embeddings
            v: Child embeddings
        
        Returns:
            Angular energy (0 if v is within cone angle of u)
        """
        cos_sim = self.geom.cosine_similarity(u, v)
        
        # Want cos_sim >= cos_threshold
        # Violation: cos_threshold - cos_sim > 0
        return torch.relu(self.cos_threshold - cos_sim)
    
    def cone_energy(self, u, v):
        """
        Compute total energy for (parent, child) pair.
        
        E(u, v) = E_radial + λ * E_angular
        
        Low energy means v is properly inside u's shadow cone.
        
        Args:
            u: Parent embeddings
            v: Child embeddings
        
        Returns:
            Total cone energy
        """
        e_radial = self.radial_energy(u, v)
        e_angular = self.angular_energy(u, v)
        
        return e_radial + self.angular_weight * e_angular
    
    def hyperbolic_cone_energy(self, u, v, aperture=0.3):
        """
        Alternative energy using hyperbolic distance to geodesic.
        
        More faithful to the hyperbolic geometry but more complex.
        
        Args:
            u: Parent embeddings
            v: Child embeddings
            aperture: Cone aperture in hyperbolic distance units
        
        Returns:
            Energy based on hyperbolic distance to geodesic axis
        """
        u_norm = self.geom.norm(u)
        v_norm = self.geom.norm(v)
        
        # Radial check
        radial_violation = torch.relu(u_norm - v_norm + self.radial_margin)
        
        # For points where radial is satisfied, check angular alignment.
        # In the Poincaré ball, the geodesic through the origin and u is the
        # diameter along u, so the cone axis is a straight Euclidean line.
        u_normalized = self.geom.safe_normalize(u)
        
        # Euclidean orthogonal projection of v onto that axis:
        # v_parallel = (v · û) * û
        dot_product = torch.sum(v * u_normalized, dim=-1, keepdim=True)
        v_parallel = dot_product * u_normalized
        
        # Hyperbolic distance from v to its Euclidean foot on the axis. This foot is
        # not the hyperbolic closest point in general, so the value is an upper bound
        # on the true distance from v to the geodesic.
        d_to_axis = self.geom.distance(v, v_parallel)
        
        # Angular violation: d_to_axis > aperture
        angular_violation = torch.relu(d_to_axis - aperture)
        
        # Combine: the radial term always applies; the angular term is added
        # only once v is strictly farther from the origin than u
        radial_ok = (u_norm < v_norm).float()
        
        energy = radial_violation + radial_ok * angular_violation
        
        return energy
    
    def in_cone(self, u, v):
        """
        Check if v is inside u's shadow cone.
        
        Returns:
            Boolean tensor
        """
        u_norm = self.geom.norm(u)
        v_norm = self.geom.norm(v)
        cos_sim = self.geom.cosine_similarity(u, v)
        
        radial_ok = v_norm > u_norm
        angular_ok = cos_sim >= self.cos_threshold
        
        return radial_ok & angular_ok


# Test Ball Shadow Cone
ball_cone = BallShadowCone(cone_angle_deg=30.0, radial_margin=0.05)

# Parent near origin, children at various positions
parent = torch.tensor([0.3, 0.0])
child_in_cone = torch.tensor([0.6, 0.1])      # Higher norm, aligned
child_wrong_angle = torch.tensor([0.1, 0.6])  # Higher norm, wrong angle
child_wrong_radius = torch.tensor([0.2, 0.0]) # Lower norm than parent
child_at_origin = torch.tensor([0.01, 0.01])  # Near origin

print("\nBall Shadow Cone Tests (angle=30°, margin=0.05):")
print(f"  Parent: norm={ball_cone.geom.norm(parent).item():.3f}")
for name, child in [("Child in cone (0.6, 0.1)", child_in_cone),
                    ("Child wrong angle (0.1, 0.6)", child_wrong_angle),
                    ("Child wrong radius (0.2, 0.0)", child_wrong_radius),
                    ("Child near origin (0.01, 0.01)", child_at_origin)]:
    print(f"\n  {name}:")
    print(f"    norm={ball_cone.geom.norm(child).item():.3f}, "
          f"energy={ball_cone.cone_energy(parent, child).item():.4f}, "
          f"in_cone={ball_cone.in_cone(parent, child).item()}")
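
`hyperbolic_cone_energy` measures the hyperbolic distance from `v` to its *Euclidean* projection onto the axis through `u`. In the Poincaré disk, the distance from a point to a diameter also has a closed form, $\sinh d = 2\,\|v_\perp\| / (1 - \|v\|^2)$, where $v_\perp$ is the component of $v$ orthogonal to $u$. The sketch below is standalone NumPy with curvature $-1$ assumed, and its helper names are hypothetical. It compares the two quantities; because the Euclidean foot lies on the geodesic, the projection-based value can only over-estimate the exact one.

```python
import numpy as np

def poincare_distance(x, y):
    """Hyperbolic distance in the unit Poincaré ball (curvature -1)."""
    sq_dist = np.sum((x - y) ** 2, axis=-1)
    denom = (1.0 - np.sum(x ** 2, axis=-1)) * (1.0 - np.sum(y ** 2, axis=-1))
    return np.arccosh(1.0 + 2.0 * sq_dist / denom)

def split_along_axis(u, v):
    """Return (v_parallel, ||v_perp||) relative to the line through 0 and u."""
    u_hat = u / np.linalg.norm(u, axis=-1, keepdims=True)
    v_par = np.sum(v * u_hat, axis=-1, keepdims=True) * u_hat
    return v_par, np.linalg.norm(v - v_par, axis=-1)

def axis_distance_projection(u, v):
    """Distance to the Euclidean foot (the quantity hyperbolic_cone_energy uses)."""
    v_par, _ = split_along_axis(u, v)
    return poincare_distance(v, v_par)

def axis_distance_exact(u, v):
    """Exact distance from v to the diameter through u: sinh d = 2|v_perp| / (1 - |v|^2)."""
    _, perp = split_along_axis(u, v)
    return np.arcsinh(2.0 * perp / (1.0 - np.sum(v ** 2, axis=-1)))

parent = np.array([0.3, 0.0])
for child in [np.array([0.6, 0.1]), np.array([0.1, 0.6]), np.array([0.85, 0.4])]:
    print(child, f"projection={axis_distance_projection(parent, child):.4f}",
          f"exact={axis_distance_exact(parent, child):.4f}")
```

For children close to the axis the two agree closely, so the projection is a reasonable cheap surrogate; the exact form avoids the bias for points far off-axis.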

In [None]:
# =============================================================================
# Ball Shadow Cone Encoder and Loss
# =============================================================================

class BallShadowConeEncoder(nn.Module):
    """
    Encode interval pairs to Poincaré Ball with explicit projection.
    
    Unlike the hypll-based encoder, this one uses a simple tanh projection
    to ensure outputs are inside the ball.
    """
    
    def __init__(self, input_dim=2, hidden_dim=128, output_dim=2, max_norm=0.95):
        super().__init__()
        self.max_norm = max_norm
        
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
        self.geom = PoincareBallGeometry(max_norm=max_norm)
    
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, input_dim)
        
        Returns:
            Embeddings in Poincaré Ball of shape (batch, output_dim)
        """
        raw = self.network(x)
        # Project to ball using smooth tanh projection
        return self.geom.project_to_ball(raw)
    
    def get_embedding(self, start_state, end_state, n_states=6):
        """Get embedding for a single interval."""
        self.eval()
        with torch.no_grad():
            s_norm = start_state / (n_states - 1)
            e_norm = end_state / (n_states - 1)
            x = torch.tensor([[s_norm, e_norm]], dtype=torch.float32, device=device)
            return self.forward(x).squeeze(0).cpu().numpy()


class BallShadowConeLoss(nn.Module):
    """
    Shadow Cone Loss for Poincaré Ball embeddings.
    
    Loss = L_positive + μ * L_negative + L_regularization
    
    - L_positive: Push children into parent cones (higher norm, aligned)
    - L_negative: Push incomparable pairs to different cones
    - L_regularization: Prevent collapse and encourage spread
    """
    
    def __init__(
        self,
        cone_angle_deg=30.0,
        radial_margin=0.05,
        angular_weight=1.0,
        positive_margin=0.0,
        negative_margin=0.5,
        neg_weight=1.0,
        temperature=5.0,
        boundary_weight=0.1,
        origin_weight=0.01,
        spread_weight=0.01,
        max_norm=0.95,
        min_norm=0.01,
    ):
        super().__init__()
        
        self.cone = BallShadowCone(
            cone_angle_deg=cone_angle_deg,
            radial_margin=radial_margin,
            angular_weight=angular_weight,
            max_norm=max_norm,
        )
        
        self.positive_margin = positive_margin
        self.negative_margin = negative_margin
        self.neg_weight = neg_weight
        self.temperature = temperature
        
        # Regularization weights
        self.boundary_weight = boundary_weight
        self.origin_weight = origin_weight
        self.spread_weight = spread_weight
        self.max_norm = max_norm
        self.min_norm = min_norm
    
    def positive_loss(self, parent_emb, child_emb):
        """
        Loss for positive (containment) pairs.
        
        Child should be in parent's shadow cone.
        L_pos = mean(softplus(β * (E(parent, child) - γ₁)))
        """
        energy = self.cone.cone_energy(parent_emb, child_emb)
        
        # Softplus loss with temperature; nn.functional.softplus is the
        # numerically stable form of log(1 + exp(x))
        loss = nn.functional.softplus(self.temperature * (energy - self.positive_margin))
        
        return loss.mean()
    
    def negative_loss(self, emb1, emb2):
        """
        Loss for negative (incomparable) pairs.
        
        Neither should be in the other's cone.
        L_neg = mean(ReLU(γ₂ - E(u, v)) + ReLU(γ₂ - E(v, u)))
        """
        # Energy in both directions
        energy_12 = self.cone.cone_energy(emb1, emb2)
        energy_21 = self.cone.cone_energy(emb2, emb1)
        
        # Penalize if either energy is too low (meaning one is in the other's cone)
        loss_12 = torch.relu(self.negative_margin - energy_12)
        loss_21 = torch.relu(self.negative_margin - energy_21)
        
        return (loss_12 + loss_21).mean()
    
    def regularization_loss(self, embeddings):
        """
        Regularization to prevent collapse and encourage spread.
        """
        norms = self.cone.geom.norm(embeddings)
        
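        # Boundary penalty: keep away from boundary. With BallShadowConeEncoder's
        # tanh projection (same max_norm), norms stay strictly below max_norm, so
        # this term is zero unless embeddings come from another source.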
        boundary_loss = torch.relu(norms - self.max_norm).pow(2).mean()
        
        # Origin penalty: keep away from origin (prevent collapse)
        origin_loss = torch.relu(self.min_norm - norms).pow(2).mean()
        
        # Spread penalty: inverse variance of the norms, which grows sharply
        # as the norms collapse to a single value
        if norms.shape[0] > 1:
            norm_var = norms.var()
            spread_loss = 1.0 / (norm_var + 1e-6)
        else:
            spread_loss = torch.tensor(0.0, device=embeddings.device)
        
        return (self.boundary_weight * boundary_loss + 
                self.origin_weight * origin_loss + 
                self.spread_weight * spread_loss)
    
    def forward(self, embeddings, positive_pairs, negative_pairs):
        """
        Compute total loss.
        
        Args:
            embeddings: All interval embeddings (n_intervals, dim)
            positive_pairs: (n_pos, 2) with (parent_idx, child_idx)
            negative_pairs: (n_neg, 2) with (idx1, idx2) incomparable
        
        Returns:
            Total loss and component dict
        """
        losses = {}
        
        # Positive loss
        if positive_pairs.shape[0] > 0:
            parent_emb = embeddings[positive_pairs[:, 0]]
            child_emb = embeddings[positive_pairs[:, 1]]
            pos_loss = self.positive_loss(parent_emb, child_emb)
        else:
            pos_loss = torch.tensor(0.0, device=embeddings.device)
        losses['positive'] = pos_loss
        
        # Negative loss
        if negative_pairs.shape[0] > 0:
            emb1 = embeddings[negative_pairs[:, 0]]
            emb2 = embeddings[negative_pairs[:, 1]]
            neg_loss = self.negative_loss(emb1, emb2)
        else:
            neg_loss = torch.tensor(0.0, device=embeddings.device)
        losses['negative'] = neg_loss
        
        # Regularization
        reg_loss = self.regularization_loss(embeddings)
        losses['regularization'] = reg_loss
        
        # Total
        total_loss = pos_loss + self.neg_weight * neg_loss + reg_loss
        losses['total'] = total_loss
        
        return total_loss, losses


# Test the loss
print("Testing Ball Shadow Cone Loss:")
test_encoder = BallShadowConeEncoder(input_dim=2, hidden_dim=64, output_dim=2).to(device)
test_loss_fn = BallShadowConeLoss(cone_angle_deg=30.0, radial_margin=0.05)

# Test forward pass
test_input = torch.tensor([[0.0, 1.0], [0.2, 0.8], [0.4, 0.6], [0.5, 0.5]], dtype=torch.float32, device=device)
test_emb = test_encoder(test_input)

print(f"  Input shape: {test_input.shape}")
print(f"  Embedding shape: {test_emb.shape}")
print(f"  Embedding norms: {[f'{n:.3f}' for n in torch.norm(test_emb, dim=-1).tolist()]}")

# Test loss
test_pos = torch.tensor([[0, 1]], dtype=torch.long, device=device)  # 0 is parent of 1
test_neg = torch.tensor([[0, 2], [1, 3]], dtype=torch.long, device=device)

loss, components = test_loss_fn(test_emb, test_pos, test_neg)
print(f"\n  Loss components:")
for name, val in components.items():
    print(f"    {name}: {val.item():.4f}")
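
The positive term applies a temperature-scaled softplus, $\operatorname{softplus}(\beta (E - \gamma_1)) = \log(1 + e^{\beta (E - \gamma_1)})$, to the cone energy. Two properties are worth seeing numerically: with $\gamma_1 = 0$ and $E \ge 0$, the loss bottoms out at $\ln 2$ (not 0) for a pair already inside the cone, and the naive `log1p(exp(x))` form overflows in float32 for large arguments while `torch.nn.functional.softplus` does not. The energies below are made-up values for illustration.

```python
import math
import torch
import torch.nn.functional as F

beta, gamma1 = 5.0, 0.0                          # temperature and positive_margin defaults
energies = torch.tensor([0.0, 0.1, 0.5, 2.0])    # illustrative cone energies

loss = F.softplus(beta * (energies - gamma1))
print(loss)   # first entry is ln 2 ≈ 0.6931 for a pair already in the cone

# float32 overflow of the naive form vs. the stable built-in
x = torch.tensor([100.0])
naive = torch.log1p(torch.exp(x))   # exp(100) overflows float32 -> inf
stable = F.softplus(x)              # above the default threshold (20) it returns x
print(naive.item(), stable.item())
```

The $\ln 2$ floor is a constant offset, so it does not affect gradients, but it explains why the reported positive loss never reaches zero.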

In [None]:
# =============================================================================
# Training Ball Shadow Cone Model
# =============================================================================

def train_ball_shadow_cone_model(
    encoder,
    dataset,
    num_epochs=150,
    lr=0.001,
    cone_angle_deg=30.0,
    radial_margin=0.05,
    positive_margin=0.0,
    negative_margin=0.5,
    neg_weight=1.0,
    temperature=5.0,
):
    """
    Train the ball encoder with shadow cone loss.
    """
    optimizer = optim.Adam(encoder.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
    
    loss_fn = BallShadowConeLoss(
        cone_angle_deg=cone_angle_deg,
        radial_margin=radial_margin,
        positive_margin=positive_margin,
        negative_margin=negative_margin,
        neg_weight=neg_weight,
        temperature=temperature,
    )
    
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=shadow_cone_collate_fn
    )
    
    history = []
    encoder.train()
    
    for epoch in range(num_epochs):
        epoch_losses = {'total': 0, 'positive': 0, 'negative': 0, 'regularization': 0}
        n_batches = 0
        
        for batch in dataloader:
            for intervals, pos_pairs, neg_pairs in zip(
                batch['intervals'], batch['positive_pairs'], batch['negative_pairs']
            ):
                if intervals.shape[0] < 2:
                    continue
                
                intervals = intervals.to(device)
                pos_pairs = pos_pairs.to(device)
                neg_pairs = neg_pairs.to(device)
                
                # Forward pass
                embeddings = encoder(intervals)
                
                # Compute loss
                loss, components = loss_fn(embeddings, pos_pairs, neg_pairs)
                
                if torch.isnan(loss) or torch.isinf(loss):
                    continue
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
                optimizer.step()
                
                for key in epoch_losses:
                    if key in components:
                        val = components[key]
                        if isinstance(val, torch.Tensor):
                            epoch_losses[key] += val.item()
                        else:
                            epoch_losses[key] += val
                n_batches += 1
        
        scheduler.step()
        
        if n_batches > 0:
            for key in epoch_losses:
                epoch_losses[key] /= n_batches
        
        history.append(epoch_losses)
        
        if (epoch + 1) % 25 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}: "
                  f"total={epoch_losses['total']:.4f}, "
                  f"pos={epoch_losses['positive']:.4f}, "
                  f"neg={epoch_losses['negative']:.4f}, "
                  f"reg={epoch_losses['regularization']:.4f}")
    
    return history


# Train the Ball Shadow Cone model
print("="*70)
print("Training Ball Shadow Cone Model")
print("="*70)

ball_shadow_encoder = BallShadowConeEncoder(
    input_dim=2,
    hidden_dim=128,
    output_dim=2,
    max_norm=0.95
).to(device)

ball_shadow_history = train_ball_shadow_cone_model(
    ball_shadow_encoder,
    shadow_dataset,
    num_epochs=150,
    lr=0.001,
    cone_angle_deg=35.0,
    radial_margin=0.05,
    positive_margin=0.0,
    negative_margin=0.3,
    neg_weight=1.0,
    temperature=5.0,
)

# Plot training curves
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

losses_dict = {k: [h[k] for h in ball_shadow_history] for k in ['total', 'positive', 'negative', 'regularization']}

for ax, (name, values) in zip(axes, losses_dict.items()):
    ax.plot(values)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'{name.capitalize()} Loss')
    ax.grid(True, alpha=0.3)

plt.suptitle('Ball Shadow Cone Training', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('ball_shadow_cone_training.png', dpi=150, bbox_inches='tight')
plt.show()
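
The training curves show the losses, but not how many pairs end up satisfying the hard cone test. A small diagnostic can count the fraction of positive pairs whose child lies inside the parent's cone, and the fraction of negative pairs where neither point is inside the other's cone. The sketch below reimplements the check from `BallShadowCone.in_cone` (child norm larger, cosine at least $\cos\theta$) so it runs standalone; `cone_pair_accuracy` is a hypothetical helper, and the toy embeddings stand in for `encoder(intervals)` on one training sample.

```python
import math
import torch

def in_cone(u, v, cone_angle_deg):
    """Hard check mirroring BallShadowCone.in_cone: v farther out and within the angle of u."""
    cos_threshold = math.cos(math.radians(cone_angle_deg))
    u_norm, v_norm = u.norm(dim=-1), v.norm(dim=-1)
    cos_sim = (u * v).sum(dim=-1) / (u_norm * v_norm + 1e-5)
    return (v_norm > u_norm) & (cos_sim >= cos_threshold)

def cone_pair_accuracy(embeddings, pos_pairs, neg_pairs, cone_angle_deg=35.0):
    """Fraction of positive pairs inside the cone and of negative pairs outside both cones."""
    parent, child = embeddings[pos_pairs[:, 0]], embeddings[pos_pairs[:, 1]]
    pos_acc = in_cone(parent, child, cone_angle_deg).float().mean().item()
    a, b = embeddings[neg_pairs[:, 0]], embeddings[neg_pairs[:, 1]]
    separated = ~in_cone(a, b, cone_angle_deg) & ~in_cone(b, a, cone_angle_deg)
    return pos_acc, separated.float().mean().item()

# Toy embeddings: 0 is a parent of 1; 2 points in a different direction
toy = torch.tensor([[0.3, 0.0], [0.6, 0.1], [0.0, 0.6]])
pos = torch.tensor([[0, 1]])
neg = torch.tensor([[1, 2]])
print(cone_pair_accuracy(toy, pos, neg))
```

To apply it to the trained model, one would pass `ball_shadow_encoder(intervals)` (under `torch.no_grad()`) together with the `pos_pairs`/`neg_pairs` tensors of a sample from the same `DataLoader` used in training, with `cone_angle_deg` matching the value used for training.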

In [None]:
# =============================================================================
# Visualization and Analysis: Ball Shadow Cone
# =============================================================================

def visualize_ball_shadow_cone_embeddings(encoder, mdp, hitting_stats):
    """
    Visualize Ball Shadow Cone embeddings with cone visualization.
    """
    encoder.eval()
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Get embeddings
    embeddings = {}
    for state in range(mdp.n_states):
        if state != mdp.goal_state:
            emb = encoder.get_embedding(state, mdp.goal_state, mdp.n_states)
            embeddings[state] = emb
    
    states = sorted(embeddings.keys())
    coords = np.array([embeddings[s] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    norms = np.array([np.linalg.norm(embeddings[s]) for s in states])
    
    # Plot 1: Colored by mean
    ax = axes[0]
    plot_poincare_disk(ax, "Colored by Mean Hitting Time")
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=means, cmap='viridis',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time')
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    # Plot 2: Colored by variance
    ax = axes[1]
    plot_poincare_disk(ax, "Colored by Variance")
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    plt.colorbar(scatter, ax=ax, label='Variance')
    
    for s in states:
        emb = embeddings[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=11, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    # Plot 3: Shadow cone visualization
    ax = axes[2]
    plot_poincare_disk(ax, "Shadow Cone Structure (from origin)")
    
    colors = plt.cm.tab10(np.linspace(0, 1, len(states)))
    cone_angle = np.radians(35)  # Must match cone_angle_deg passed to train_ball_shadow_cone_model
    
    for i, s in enumerate(states):
        emb = embeddings[s]
        norm = np.linalg.norm(emb)
        
        # Draw point
        ax.scatter([emb[0]], [emb[1]], c=[colors[i]], s=150,
                  edgecolors='black', linewidth=2, zorder=5)
        
        # Draw the region tested by BallShadowCone.in_cone: the annular sector
        # {||v|| > ||u||, angle(u, v) <= cone_angle} starting at this point's radius
        if norm > 0.05:
            angle = np.arctan2(emb[1], emb[0])
            theta1 = angle - cone_angle
            theta2 = angle + cone_angle
            
            # Radial edges of the sector, from radius ||u|| out to r = 0.95
            for theta in [theta1, theta2]:
                ax.plot([norm * np.cos(theta), 0.95 * np.cos(theta)],
                        [norm * np.sin(theta), 0.95 * np.sin(theta)],
                        color=colors[i], linewidth=1, alpha=0.3)
            
            # Draw arc at boundary
            arc_theta = np.linspace(theta1, theta2, 30)
            arc_x = 0.95 * np.cos(arc_theta)
            arc_y = 0.95 * np.sin(arc_theta)
            ax.plot(arc_x, arc_y, color=colors[i], linewidth=2, alpha=0.5)
        
        label = pair_label(s, mdp.goal_state)
        ax.annotate(f"{label}\n(r={norm:.2f})", (emb[0], emb[1]), fontsize=9, fontweight='bold',
                   xytext=(5, 5), textcoords='offset points')
    
    plt.tight_layout()
    plt.savefig('ball_shadow_cone_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    return embeddings


# Visualize
print("\n" + "="*70)
print("Ball Shadow Cone Embeddings")
print("="*70)

ball_shadow_embeddings = visualize_ball_shadow_cone_embeddings(ball_shadow_encoder, mdp, hitting_stats)

# Analyze
print("\nEmbedding Analysis:")
print(f"{'Pair':<10} {'x':<10} {'y':<10} {'Norm':<10} {'Mean T':<10} {'Var T':<10}")
print("-"*70)

for s in sorted(ball_shadow_embeddings.keys()):
    emb = ball_shadow_embeddings[s]
    norm = np.linalg.norm(emb)
    label = pair_label(s, mdp.goal_state)
    print(f"{label:<10} {emb[0]:<10.4f} {emb[1]:<10.4f} {norm:<10.4f} "
          f"{hitting_stats[s]['mean']:<10.2f} {hitting_stats[s]['var']:<10.2f}")
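
The table reports the norm but not the angular coordinate that the conjecture ties to mean hitting time. Because angles wrap at $\pm\pi$, correlating raw `atan2` values can split a tight cluster of embeddings in two. One option, sketched here as an assumption rather than the notebook's method, is to measure each angle relative to the circular mean direction before computing Spearman's $\rho$. The synthetic points and hitting times below are hypothetical and deliberately straddle the $\pm\pi$ cut.

```python
import numpy as np
from scipy import stats

def centred_angles(coords):
    """Angle of each 2-D point, measured from the circular mean direction, in [-pi, pi]."""
    theta = np.arctan2(coords[:, 1], coords[:, 0])
    mean_dir = np.arctan2(np.sin(theta).mean(), np.cos(theta).mean())
    # Wrap the differences back into [-pi, pi] via the complex exponential
    return np.angle(np.exp(1j * (theta - mean_dir)))

# Synthetic embeddings at 170, 175, 180, 185, 190 degrees, radius 0.5
degrees = np.array([170.0, 175.0, 180.0, 185.0, 190.0])
coords = 0.5 * np.column_stack([np.cos(np.radians(degrees)), np.sin(np.radians(degrees))])
means = np.array([1.0, 2.0, 3.0, 4.0, 5.0])   # hypothetical mean hitting times

raw_rho, _ = stats.spearmanr(np.arctan2(coords[:, 1], coords[:, 0]), means)
centred_rho, _ = stats.spearmanr(centred_angles(coords), means)
print(f"raw atan2: rho={raw_rho:+.2f}, centred: rho={centred_rho:+.2f}")
```

For the trained model, `coords` would be the array of `ball_shadow_embeddings` values and `means` the corresponding `hitting_stats[s]['mean']` entries.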

In [None]:
# =============================================================================
# Final Comparison: All Three Approaches
# =============================================================================

def plot_all_three_comparison(infonce_emb, halfspace_emb, ball_shadow_emb, hitting_stats, mdp):
    """
    Compare all three embedding approaches side by side.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    states = sorted(infonce_emb.keys())
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    
    # 1. InfoNCE (Poincaré Ball)
    ax = axes[0]
    plot_poincare_disk(ax, "InfoNCE Loss\n(Poincaré Ball)")
    coords = np.array([infonce_emb[s] for s in states])
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma_r',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    for s in states:
        emb = infonce_emb[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=10, fontweight='bold',
                   xytext=(4, 4), textcoords='offset points')
    
    # 2. Shadow Cone (Half-Space)
    ax = axes[1]
    hs_coords = np.array([halfspace_emb[s] for s in states])
    y_max = hs_coords[:, 1].max() * 1.3
    plot_half_space(ax, "Shadow Cone Loss\n(Half-Space)", y_max)
    scatter = ax.scatter(hs_coords[:, 0], hs_coords[:, 1], c=vars_, cmap='plasma_r',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    for s in states:
        emb = halfspace_emb[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=10, fontweight='bold',
                   xytext=(4, 4), textcoords='offset points')
    
    # 3. Shadow Cone (Poincaré Ball)
    ax = axes[2]
    plot_poincare_disk(ax, "Shadow Cone Loss\n(Poincaré Ball)")
    ball_coords = np.array([ball_shadow_emb[s] for s in states])
    scatter = ax.scatter(ball_coords[:, 0], ball_coords[:, 1], c=vars_, cmap='plasma_r',
                        s=150, edgecolors='black', linewidth=2, zorder=5)
    for s in states:
        emb = ball_shadow_emb[s]
        label = pair_label(s, mdp.goal_state)
        ax.annotate(label, (emb[0], emb[1]), fontsize=10, fontweight='bold',
                   xytext=(4, 4), textcoords='offset points')
    
    # Add colorbar
    plt.colorbar(scatter, ax=axes[2], label='Variance (lighter = lower)')
    
    plt.tight_layout()
    plt.savefig('all_three_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()


# Get all embeddings
infonce_emb = get_state_embeddings(model, mdp)

# Plot comparison
print("="*70)
print("FINAL COMPARISON: All Three Approaches")
print("="*70)

plot_all_three_comparison(infonce_emb, shadow_embeddings, ball_shadow_embeddings, hitting_stats, mdp)

# Compute correlations for all approaches
print("\n" + "="*70)
print("CORRELATION ANALYSIS")
print("="*70)

for name, emb_dict in [("InfoNCE (Ball)", infonce_emb), 
                        ("Shadow Cone (Half-Space)", shadow_embeddings),
                        ("Shadow Cone (Ball)", ball_shadow_embeddings)]:
    states = sorted(emb_dict.keys())
    embs = np.array([emb_dict[s] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    
    if name == "Shadow Cone (Half-Space)":
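        # Use the y-coordinate. Higher y means "more general" here, the role that
        # lower norm plays in the ball, so expect the opposite sign to the norm rows.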
        key_coord = embs[:, 1]
        coord_name = "y"
    else:
        # Use norm
        key_coord = np.linalg.norm(embs, axis=1)
        coord_name = "norm"
    
    # One point per non-goal state: with so few samples, treat p-values as a rough guide
    corr_var, p_var = stats.spearmanr(key_coord, vars_)
    corr_mean, p_mean = stats.spearmanr(key_coord, means)
    
    print(f"\n{name}:")
    print(f"  {coord_name} vs variance: ρ = {corr_var:+.4f} (p = {p_var:.4f})")
    print(f"  {coord_name} vs mean:     ρ = {corr_mean:+.4f} (p = {p_mean:.4f})")
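
With one point per non-goal state, the correlations above rest on very few samples. For small $n$, an exact permutation test shows how much evidence even a perfect ranking can carry: with $n = 5$, only 2 of the $5! = 120$ orderings give $|\rho| = 1$, so the smallest attainable two-sided p-value is $2/120 \approx 0.017$. The helper below is a hypothetical addition, not part of the notebook, and the example norms and variances are made up; it enumerates every permutation, so it is only practical for roughly $n \le 8$.

```python
import itertools
import numpy as np
from scipy import stats

def exact_spearman_pvalue(x, y):
    """Two-sided exact permutation p-value for Spearman's rho (small n only)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    rho_obs, _ = stats.spearmanr(x, y)
    hits = total = 0
    for perm in itertools.permutations(range(len(y))):
        rho, _ = stats.spearmanr(x, y[list(perm)])
        # Count orderings at least as extreme as the observed one
        hits += int(abs(rho) >= abs(rho_obs) - 1e-12)
        total += 1
    return rho_obs, hits / total

# Example: a perfectly monotone (decreasing) relation between 5 norms and 5 variances
rho, p = exact_spearman_pvalue([0.2, 0.4, 0.5, 0.7, 0.9], [9.0, 7.5, 3.0, 1.0, 0.5])
print(f"rho = {rho:+.2f}, exact p = {p:.4f}")
```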

## Shadow Cone Loss: Summary

### Implementation

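We implemented the **Shadow Cone Loss** in two models of hyperbolic space, the **Poincaré Half-Space** and the **Poincaré Ball**:

1. **Half-Space Geometry**: Points $(x, y)$ with $y > 0$
   - Distance: $d_{\mathbb{H}}(p, q) = \text{arcosh}\left(1 + \frac{\|p - q\|^2}{2 y_p y_q}\right)$
   - Higher $y$ = more "general" (parent intervals)
   - Lower $y$ = more "specific" (child intervals)

2. **Ball Geometry**: Points $u$ with $\|u\| < 1$
   - Distance: $d_{\mathbb{B}}(u, v) = \text{arcosh}\left(1 + \frac{2\|u - v\|^2}{(1 - \|u\|^2)(1 - \|v\|^2)}\right)$
   - Lower norm = more "general" (parent intervals, near the origin)
   - Higher norm = more "specific" (child intervals, near the boundary)

3. **Shadow Cone**: Each parent point casts a "shadow cone" away from the general region: downward toward $y = 0$ in the half-space, and outward along the ray from the origin in the ball
   - Children (contained intervals) should fall inside the parent's cone
   - Incomparable intervals should be outside each other's cones

4. **Loss Function**:
   - **Positive loss**: Push children into parent cones
   - **Negative loss**: Push incomparable pairs apart
   - **Regularization**: Encourage spread in the $y$-coordinate (half-space) or in the norm (ball)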
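
The half-space distance makes the "higher $y$ = more general" convention concrete. For two points on the same vertical line it reduces to $d_{\mathbb{H}} = |\ln(y_p / y_q)|$, since $\cosh(\ln(a/b)) = \frac{a^2 + b^2}{2ab} = 1 + \frac{(a - b)^2}{2ab}$; each factor-of-$e$ step in $y$ costs one unit of hyperbolic distance at any height, while a fixed horizontal offset becomes far more expensive near $y = 0$. A minimal NumPy check (the function name is illustrative, not taken from the notebook):

```python
import numpy as np

def half_space_distance(p, q):
    """Hyperbolic distance in the upper half-plane {(x, y) : y > 0}."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    sq_dist = np.sum((p - q) ** 2)
    return np.arccosh(1.0 + sq_dist / (2.0 * p[1] * q[1]))

# Same x, y differing by a factor of e: distance 1 at any height
for y in [0.01, 1.0, 100.0]:
    print(y, half_space_distance([0.0, y], [0.0, np.e * y]))

# The same Euclidean horizontal offset is much longer close to y = 0
print(half_space_distance([0.0, 1.0], [1.0, 1.0]),
      half_space_distance([0.0, 0.1], [1.0, 0.1]))
```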

### Key Differences from InfoNCE

| Property | InfoNCE | Shadow Cone |
|----------|---------|-------------|
| Relation type | Symmetric similarity | Asymmetric containment |
| What it encodes | "Same vs different" | "Contains vs doesn't contain" |
| Geometric structure | Clusters | Tree/hierarchy |
| Boundary interpretation | Points at boundary are specific | Points near $y=0$ (half-space) or near the boundary (ball) are specific |

### Use Cases

- **InfoNCE**: Good for learning similarity-based representations
- **Shadow Cone**: Good for learning hierarchical/compositional structure

The Shadow Cone approach explicitly encodes the **poset structure** induced by temporal containment, which may be beneficial for:
- Hierarchical planning (decomposing goals into subgoals)
- Transfer learning (sharing knowledge between related intervals)
- Interpretable representations (clear parent-child relationships)

## Conclusion

This notebook demonstrates:

1. **Hyperbolic Embeddings**: We trained a hyperbolic encoder to embed (state, goal) pairs from MDP trajectories using contrastive learning with temporal containment constraints.

2. **Hitting Time Conjecture**: We tested whether:
   - **Norm** correlates negatively with **variance** in hitting times
   - **Angle** correlates with **mean** hitting times

3. **GCBC Policies**: We trained goal-conditioned behavioral cloning policies using:
   - Raw state representations (one-hot)
   - Hyperbolic embeddings

4. **Hyperbolic Planning**: We implemented a planning mechanism that:
   - Finds subgoals along the radial line through the target embedding
   - Uses lowest-norm atomic embeddings as waypoints
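   - Recursively decomposes the planning problem

5. **Shadow Cone Embeddings**: We encoded temporal containment between intervals with a shadow cone loss in both the Poincaré half-space and the Poincaré ball, and compared their norm (or $y$) correlations with hitting-time statistics against the InfoNCE embeddings.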
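
As a rough illustration of the radial-line idea above, one waypoint-selection step could keep the candidate embeddings whose direction lies within some angle of the target's and whose norm is smaller, then return the lowest-norm one. This is a hypothetical sketch, not the notebook's planner (its implementation is not reproduced here); `pick_radial_waypoint` and `max_angle_deg` are illustrative names, and the angular threshold is an assumption.

```python
import numpy as np

def pick_radial_waypoint(target, candidates, max_angle_deg=20.0):
    """Index of the lowest-norm candidate near the radial line through `target`
    with a smaller norm than the target, or None if no candidate qualifies."""
    t_norm = np.linalg.norm(target)
    c_norms = np.linalg.norm(candidates, axis=1)
    cos_sim = candidates @ target / (c_norms * t_norm + 1e-12)
    ok = (cos_sim >= np.cos(np.radians(max_angle_deg))) & (c_norms < t_norm)
    if not ok.any():
        return None
    idx = np.flatnonzero(ok)
    return int(idx[np.argmin(c_norms[idx])])

# Hypothetical atomic-state embeddings and a target (state, goal) embedding
candidates = np.array([[0.2, 0.02], [0.5, 0.05], [0.1, 0.6], [0.9, 0.1]])
target = np.array([0.8, 0.1])
print(pick_radial_waypoint(target, candidates))
```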

### Key Insights

- The hyperbolic space naturally represents the hierarchical structure of goals
- Atomic states (s, s) with low norm may represent "easy-to-reach" waypoints
- The radial line provides a natural direction for planning
- This approach could scale to continuous state spaces with learned decoders
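
The tree-like character of hyperbolic space can be made quantitative. In the hyperbolic plane (curvature $-1$), a circle of radius $r$ has circumference $2\pi \sinh r$, versus $2\pi r$ in the Euclidean plane, so the room available at distance $r$ grows roughly like $e^r$, in the same way that the number of nodes in a balanced tree grows with depth. A quick numeric comparison:

```python
import numpy as np

radii = np.arange(1, 7)
hyperbolic = 2 * np.pi * np.sinh(radii)   # circumference in the hyperbolic plane
euclidean = 2 * np.pi * radii             # circumference in the Euclidean plane
for r, h, e in zip(radii, hyperbolic, euclidean):
    print(f"r={r}: hyperbolic={h:8.1f}, euclidean={e:5.1f}, ratio={h / e:6.1f}")
```

Each extra unit of radius multiplies the hyperbolic circumference by roughly $e \approx 2.72$, which is why a low-dimensional ball can hold a branching hierarchy without crowding.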