# MDP Hyperbolic Embeddings: Hitting Time Conjecture

This notebook explores hyperbolic embeddings for (state, goal) pairs in a 6-state MDP.

**Conjecture:**
- **Norm** (distance from origin) correlates **negatively** with **variance** in hitting times (low variance → high norm, high variance → low norm)
- **Angular coordinate** correlates with **mean** hitting times

## MDP Structure
```
State 1 (start)
    |
    |-- a11 (stochastic) --> 4 (p=0.5) or 5 (p=0.5) --> 6 (goal)
    |
    |-- a12 (deterministic) --> 2 --> (0.9 self-loop, 0.1 to 4) --> 6
    |
    |-- a13 (deterministic) --> 3 --> (0.9 self-loop, 0.1 to 5) --> 6
```

In [None]:
!pip install hypll geoopt

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
from scipy import stats

# Import hypll library
from hypll.manifolds.poincare_ball import PoincareBall, Curvature
from hypll.tensors.manifold_tensor import ManifoldTensor
from hypll.tensors import TangentTensor
import hypll.nn as hnn

# Setup: plotting style, global RNG seeds, and compute device.
sns.set_style("whitegrid")
# Seed NumPy's and PyTorch's global generators so later cells are repeatable.
np.random.seed(42)
torch.manual_seed(42)
# Use the GPU when PyTorch reports CUDA as available; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## MDP Definition

In [None]:
class SimpleMDP:
    """
    Six-state MDP with one branching start state and an absorbing goal.

    Indices are 0-based internally; ``state_name`` renders them as S1..S6.

        0 (S1): start, three actions
        1 (S2): stays at index 1 with p=0.9, moves to index 3 with p=0.1
        2 (S3): stays at index 2 with p=0.9, moves to index 4 with p=0.1
        3 (S4): moves to the goal with p=1
        4 (S5): moves to the goal with p=1
        5 (S6): goal, absorbing

    Actions available at index 0:
        - 0: index 3 or index 4, each with p=0.5
        - 1: index 1 with p=1
        - 2: index 2 with p=1
    """

    def __init__(self):
        self.n_states = 6
        self.start_state = 0  # S1
        self.goal_state = 5   # S6

        # How many actions each state offers; the goal offers none.
        self.n_actions = {0: 3, 1: 1, 2: 1, 3: 1, 4: 1, 5: 0}

    def get_transitions(self, state, action):
        """
        Look up the outcome distribution of taking ``action`` in ``state``.

        ``action`` is only consulted for state 0. Returns a list of
        (next_state, probability) tuples; an unknown state, or an unknown
        action in state 0, yields an empty list.
        """
        if state == 0:
            # Start state: the only place where the action matters.
            if action == 0:
                return [(3, 0.5), (4, 0.5)]
            if action == 1:
                return [(1, 1.0)]
            if action == 2:
                return [(2, 1.0)]
            return []

        if state == 1:
            return [(1, 0.9), (3, 0.1)]

        if state == 2:
            return [(2, 0.9), (4, 0.1)]

        # S4 and S5 lead to the goal; the goal maps to itself.
        if state in (3, 4, 5):
            return [(5, 1.0)]

        return []

    def step(self, state, action=None):
        """
        Sample the successor of ``state``.

        The goal state is returned unchanged. When ``action`` is None an
        action index is drawn uniformly with ``np.random.randint``. A single
        possible outcome is returned directly without consuming randomness;
        otherwise the successor is drawn with ``np.random.choice``.
        """
        if state == self.goal_state:
            return state

        if action is None:
            action = np.random.randint(0, max(1, self.n_actions[state]))

        outcomes = self.get_transitions(state, action)
        if len(outcomes) == 1:
            return outcomes[0][0]

        weights = [prob for _, prob in outcomes]
        successors = [nxt for nxt, _ in outcomes]
        return np.random.choice(successors, p=weights)

    def state_name(self, state):
        """Return the 1-based display label, e.g. index 0 -> 'S1'."""
        return f"S{state + 1}"


# Test the MDP: build the instance used by the rest of the notebook and
# print every state's actions together with their transition distributions.
mdp = SimpleMDP()
print("MDP Structure Test:")
print(f"Start state: {mdp.state_name(mdp.start_state)}")
print(f"Goal state: {mdp.state_name(mdp.goal_state)}")
print()
for state in range(mdp.n_states):
    print(f"{mdp.state_name(state)}: {mdp.n_actions[state]} actions")
    # The goal state has 0 actions, so this inner loop prints nothing for it.
    for action in range(mdp.n_actions[state]):
        transitions = mdp.get_transitions(state, action)
        trans_str = ", ".join([f"{mdp.state_name(s)} (p={p})" for s, p in transitions])
        print(f"  a{action}: {trans_str}")

In [None]:
def visualize_mdp_graph(mdp):
    """
    Draw the MDP as a directed graph, save it to 'mdp_graph.png' and show it.

    The node layout, edges and edge labels are hard-coded for the six-state
    MDP of this notebook; ``mdp`` is only used for node labels through
    ``mdp.state_name``.

    Args:
        mdp: Object providing ``state_name(state)`` for states 0..5.
    """
    # Local import: only the self-loop drawing needs Arc.
    from matplotlib.patches import Arc

    fig, ax = plt.subplots(figsize=(14, 8))

    # Node positions (manually designed for clarity)
    positions = {
        0: (0, 0),      # S1 (start)
        1: (-2, -1.5),  # S2
        2: (2, -1.5),   # S3
        3: (-1, -3),    # S4
        4: (1, -3),     # S5
        5: (0, -4.5),   # S6 (goal)
    }

    # Draw nodes
    node_colors = {
        0: '#90EE90',  # Start - light green
        5: '#FFD700',  # Goal - gold
    }
    default_color = '#87CEEB'  # Light blue

    for state, pos in positions.items():
        color = node_colors.get(state, default_color)
        circle = plt.Circle(pos, 0.35, color=color, ec='black', linewidth=2, zorder=5)
        ax.add_patch(circle)

        # Label
        label = mdp.state_name(state)
        if state == 0:
            label += "\n(start)"
        elif state == 5:
            label += "\n(goal)"
        ax.text(pos[0], pos[1], label, ha='center', va='center', fontsize=11, fontweight='bold', zorder=6)

    def draw_arrow(start, end, label="", color='black', style='-', offset=0, curve=0):
        """Draw an arrow between two distinct nodes, optionally shifted, curved and labelled."""
        x1, y1 = positions[start]
        x2, y2 = positions[end]

        # Unit direction from start to end (start and end must differ).
        dx, dy = x2 - x1, y2 - y1
        dist = np.sqrt(dx**2 + dy**2)

        # Shorten both ends so the arrow does not overlap the node circles.
        r = 0.4
        x1 += r * dx / dist
        y1 += r * dy / dist
        x2 -= r * dx / dist
        y2 -= r * dy / dist

        # Shift sideways (perpendicular to the edge) to separate parallel edges.
        if offset != 0:
            perp_x, perp_y = -dy / dist, dx / dist
            x1 += offset * perp_x
            y1 += offset * perp_y
            x2 += offset * perp_x
            y2 += offset * perp_y

        if curve == 0:
            ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
                        arrowprops=dict(arrowstyle='->', color=color, lw=2, ls=style))
        else:
            ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
                        arrowprops=dict(arrowstyle='->', color=color, lw=2, ls=style,
                                        connectionstyle=f"arc3,rad={curve}"))

        if label:
            mid_x, mid_y = (x1 + x2) / 2 + offset * 0.3, (y1 + y2) / 2
            ax.text(mid_x, mid_y, label, fontsize=9, ha='center', va='bottom',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8))

    def draw_self_loop(state, label="", direction='left'):
        """Draw a self-loop arc beside a node, with an arrow head and a label."""
        x, y = positions[state]
        # side = -1 puts the loop to the left of the node, +1 to the right.
        side = -1 if direction == 'left' else 1
        loop_x, loop_y = x + side * 0.6, y
        theta1, theta2 = (120, 240) if side < 0 else (-60, 60)

        arc = Arc((loop_x, loop_y), 0.5, 0.5, angle=0, theta1=theta1, theta2=theta2,
                  color='black', lw=2)
        ax.add_patch(arc)

        # Arrow head pointing back at the node. It is mirrored for the
        # right-hand side, which previously had no arrow head at all.
        ax.annotate("", xy=(x + side * 0.35, y + 0.15), xytext=(x + side * 0.4, y + 0.25),
                    arrowprops=dict(arrowstyle='->', color='black', lw=2))

        # Put the label on the outer side of the arc. With the old fixed
        # `loop_x - 0.3, ha='right'` placement the right-hand label landed on
        # top of the node, where the node circle (zorder=5) covers it.
        ax.text(loop_x + side * 0.3, loop_y, label, fontsize=9,
                ha='right' if side < 0 else 'left', va='center',
                bbox=dict(boxstyle='round,pad=0.2', facecolor='lightyellow', alpha=0.9))

    # Edges from S1
    draw_arrow(0, 3, "a₁₁: p=0.5", color='blue', offset=-0.15, curve=-0.2)
    draw_arrow(0, 4, "a₁₁: p=0.5", color='blue', offset=0.15, curve=0.2)
    draw_arrow(0, 1, "a₁₂", color='green')
    draw_arrow(0, 2, "a₁₃", color='green')

    # Self-loops and exits for S2, S3
    draw_self_loop(1, "0.9", direction='left')
    draw_arrow(1, 3, "0.1", color='red')

    draw_self_loop(2, "0.9", direction='right')
    draw_arrow(2, 4, "0.1", color='red')

    # Deterministic to goal
    draw_arrow(3, 5, "1.0", color='purple')
    draw_arrow(4, 5, "1.0", color='purple')

    # Legend
    legend_elements = [
        plt.Line2D([0], [0], color='blue', lw=2, label='a₁₁: Stochastic'),
        plt.Line2D([0], [0], color='green', lw=2, label='a₁₂, a₁₃: Deterministic'),
        plt.Line2D([0], [0], color='red', lw=2, label='Stochastic exit'),
        plt.Line2D([0], [0], color='purple', lw=2, label='To goal'),
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

    ax.set_xlim(-4, 4)
    ax.set_ylim(-5.5, 1)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('MDP Structure', fontsize=16, fontweight='bold', pad=20)

    plt.tight_layout()
    plt.savefig('mdp_graph.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()


# Visualize the MDP
visualize_mdp_graph(mdp)

## Trajectory Generation

In [None]:
def generate_mdp_trajectories(mdp, n_trajectories, max_length=1000, seed=42):
    """
    Roll out episodes from the MDP's start state.

    Every episode begins at ``mdp.start_state`` and repeatedly calls
    ``mdp.step(state, action=None)`` until the state equals
    ``mdp.goal_state`` or ``max_length`` steps have been taken, whichever
    happens first. NumPy's global RNG is re-seeded with ``seed`` before the
    first episode.

    Args:
        mdp: Object exposing ``start_state``, ``goal_state`` and ``step``.
        n_trajectories: How many episodes to roll out.
        max_length: Upper bound on the number of steps per episode.
        seed: Value passed to ``np.random.seed``.

    Returns:
        A list with one entry per episode; each entry is the list of visited
        states, starting with the start state.
    """
    np.random.seed(seed)

    def rollout():
        # An episode that runs out of steps is returned as-is, without
        # having reached the goal.
        path = [mdp.start_state]
        steps_left = max_length
        while steps_left > 0 and path[-1] != mdp.goal_state:
            path.append(mdp.step(path[-1], action=None))
            steps_left -= 1
        return path

    return [rollout() for _ in range(n_trajectories)]


# Generate trajectories (re-seeds NumPy's global RNG with 42 inside the call)
trajectories = generate_mdp_trajectories(mdp, n_trajectories=5000, max_length=1000, seed=42)

# Analyze trajectories: a length counts visited states, including the start state
lengths = [len(t) for t in trajectories]
print(f"Generated {len(trajectories)} trajectories")
print(f"Length stats: mean={np.mean(lengths):.1f}, std={np.std(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")

# Show sample trajectories (first five; each truncated to its first 15 states)
print("\nSample trajectories:")
for i in range(5):
    traj_str = " -> ".join([mdp.state_name(s) for s in trajectories[i][:15]])
    if len(trajectories[i]) > 15:
        traj_str += f" ... ({len(trajectories[i])} states total)"
    print(f"  {i+1}: {traj_str}")

## Contrastive Dataset

In [None]:
class MDPContrastiveDataset(Dataset):
    """
    Contrastive triples built from position intervals of MDP trajectories.

    Each sample is drawn from a single trajectory and consists of
    - anchor: the states at positions (i, j), i <= j, of that trajectory
    - positive: the states at positions (k, l) with i <= k <= l <= j
    - negatives: ``n_negatives`` state pairs at positions (k, l) of the same
      trajectory where [k, l] is not contained in [i, j]

    Every pair is stored as two floats, state index / (n_states - 1).
    All samples are generated once in ``__init__`` with NumPy's global RNG,
    which is re-seeded with ``seed`` (PyTorch's global RNG is seeded too).

    Note: containment is decided on trajectory positions, but only the
    states at those positions are stored, so a negative pair can be equal to
    the anchor or positive pair when states repeat along the trajectory.
    """

    def __init__(self, trajectories, n_states=6, num_samples=10000, n_negatives=5, seed=42):
        """
        Args:
            trajectories: List of state-index sequences.
            n_states: Number of states, used only for normalization.
            num_samples: Number of (anchor, positive, negatives) triples.
            n_negatives: Negatives generated per triple.
            seed: Seed for NumPy's and PyTorch's global RNGs.

        Raises:
            ValueError: If samples are requested but no trajectory has at
                least two states.
        """
        self.trajectories = trajectories
        self.n_states = n_states
        self.num_samples = num_samples
        self.n_negatives = n_negatives

        np.random.seed(seed)
        torch.manual_seed(seed)

        # Only trajectories with at least two states are sampled from.
        self.valid_traj_indices = [i for i, t in enumerate(trajectories) if len(t) >= 2]
        if num_samples > 0 and not self.valid_traj_indices:
            # Fail with a clear message instead of the opaque error that
            # np.random.choice raises for an empty candidate list.
            raise ValueError(
                "MDPContrastiveDataset needs at least one trajectory with 2 or more states"
            )

        # Pre-generate all samples
        self.anchors, self.positives, self.negatives = self._generate_all_samples()

    def _normalize_state(self, state):
        """Scale a state index by 1 / (n_states - 1)."""
        return state / (self.n_states - 1)

    def _generate_all_samples(self):
        """Generate ``num_samples`` triples and stack them into float32 tensors."""
        anchors = []
        positives = []
        negatives_list = []

        for _ in range(self.num_samples):
            anchor, positive, negs = self._generate_single_sample()
            anchors.append(anchor)
            positives.append(positive)
            negatives_list.append(negs)

        return (
            torch.tensor(anchors, dtype=torch.float32),
            torch.tensor(positives, dtype=torch.float32),
            torch.tensor(negatives_list, dtype=torch.float32),
        )

    def _generate_single_sample(self):
        """Generate a single (anchor, positive, negatives) tuple."""
        # Sample trajectory
        traj_idx = np.random.choice(self.valid_traj_indices)
        traj = self.trajectories[traj_idx]
        T = len(traj) - 1  # last valid position

        # Anchor positions i <= j: draw the right end first, then the left.
        j = np.random.randint(0, T + 1)
        i = np.random.randint(0, j + 1)
        anchor = [self._normalize_state(traj[i]), self._normalize_state(traj[j])]

        # Positive positions lo <= hi inside [i, j].
        hi = np.random.randint(i, j + 1)
        lo = np.random.randint(i, hi + 1)
        positive = [self._normalize_state(traj[lo]), self._normalize_state(traj[hi])]

        # Negatives: intervals of the same trajectory not contained in [i, j].
        negatives = [self._sample_negative(traj, i, j, T) for _ in range(self.n_negatives)]

        return anchor, positive, negatives

    def _sample_negative(self, traj, anchor_i, anchor_j, T, max_attempts=1000):
        """
        Sample a state pair whose positions are NOT contained in the anchor.

        Uses rejection sampling over positions 0..T, up to ``max_attempts``
        times, then a fixed fallback. If the anchor spans positions 0..T no
        candidate can ever be accepted, so the fallback is used immediately
        without drawing random numbers; in that case the returned pair (the
        last position twice) necessarily lies inside the anchor.
        """
        # Every candidate satisfies 0 <= lo <= hi <= T, so an anchor that
        # covers 0..T contains all of them: skip the futile attempts.
        anchor_spans_all = anchor_i <= 0 and anchor_j >= T

        if not anchor_spans_all:
            for _ in range(max_attempts):
                hi = np.random.randint(0, T + 1)
                lo = np.random.randint(0, hi + 1)

                # Temporal containment check on positions
                is_subinterval = (anchor_i <= lo) and (hi <= anchor_j)
                if not is_subinterval:
                    return [self._normalize_state(traj[lo]), self._normalize_state(traj[hi])]

        # Fallback: a single position on the side of the anchor that has room.
        pos = 0 if anchor_i > 0 else T
        return [self._normalize_state(traj[pos]), self._normalize_state(traj[pos])]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.anchors[idx], self.positives[idx], self.negatives[idx]


# Create dataset (all samples are generated eagerly in the constructor)
dataset = MDPContrastiveDataset(
    trajectories=trajectories,
    n_states=6,
    num_samples=10000,
    n_negatives=5,
    seed=42
)

# Check shapes of one sample; each entry holds normalized (state, state) pairs
anchor, positive, negatives = dataset[0]
print(f"Anchor shape: {anchor.shape}")
print(f"Positive shape: {positive.shape}")
print(f"Negatives shape: {negatives.shape}")
print(f"\nSample anchor: {anchor}")
print(f"Sample positive: {positive}")
print(f"Sample negatives[0]: {negatives[0]}")

## Hyperbolic Encoder

In [None]:
def manifold_map(x, manifold):
    """
    Lift a Euclidean tensor onto ``manifold``.

    ``x`` is wrapped in a ``TangentTensor`` whose manifold dimension is the
    last axis, and the manifold's exponential map is applied to it.
    """
    return manifold.expmap(TangentTensor(x, manifold=manifold, man_dim=-1))


class HyperbolicIntervalEncoder(nn.Module):
    """
    Embed 2-feature inputs (the (state, goal) pairs) in a Poincare ball.

    Pipeline: Linear -> ReLU -> Linear in Euclidean space, exponential map
    onto the ball, then HLinear -> HReLU -> HLinear on the manifold.
    """

    def __init__(self, embedding_dim=2, c=1.0, euc_width=128, hyp_width=128):
        """
        Args:
            embedding_dim: Output features of the last hyperbolic layer.
            c: Curvature value of the Poincare ball (kept fixed, not trained).
            euc_width: Width of the first Euclidean layer.
            hyp_width: Width fed into and between the hyperbolic layers.
        """
        super().__init__()

        # Manifold with a non-trainable curvature.
        self.manifold = PoincareBall(c=Curvature(value=c, requires_grad=False))

        # Euclidean stem.
        self.euc_layer1 = nn.Linear(2, euc_width)
        self.euc_layer2 = nn.Linear(euc_width, hyp_width)
        self.euc_relu = nn.ReLU()

        def hyperbolic_linear(n_in, n_out):
            return hnn.HLinear(in_features=n_in, out_features=n_out, manifold=self.manifold)

        # Hyperbolic head.
        self.hyp_layer1 = hyperbolic_linear(hyp_width, hyp_width)
        self.hyp_layer2 = hyperbolic_linear(hyp_width, embedding_dim)
        self.hyp_relu = hnn.HReLU(manifold=self.manifold)

    def forward(self, x):
        """Return the embedding produced by the last hyperbolic layer."""
        hidden = self.euc_layer2(self.euc_relu(self.euc_layer1(x)))

        # Move from Euclidean features to points on the manifold.
        on_ball = manifold_map(hidden, self.manifold)

        on_ball = self.hyp_relu(self.hyp_layer1(on_ball))
        return self.hyp_layer2(on_ball)


# Test encoder: run two hand-written (state, goal) pairs through a freshly
# initialized model and inspect what comes out.
model = HyperbolicIntervalEncoder(embedding_dim=2, c=1.0).to(device)
test_input = torch.tensor([[0.0, 1.0], [0.2, 0.8]], dtype=torch.float32).to(device)
test_output = model(test_input)
print(f"Test input shape: {test_input.shape}")
print(f"Test output type: {type(test_output)}")
# Shape and values are only printed when the output is a ManifoldTensor.
if isinstance(test_output, ManifoldTensor):
    print(f"Test output tensor shape: {test_output.tensor.shape}")
    print(f"Test output: {test_output.tensor}")

## Loss Function

In [None]:
def info_nce_loss(anchor, positive, negatives, manifold, temperature=0.5):
    """
    InfoNCE loss where similarity is the negated ``manifold.dist``.

    Inputs may be ``ManifoldTensor`` objects or plain tensors. The anchor is
    rank 2 (batch, dim) and the negatives rank 3 (batch, num_neg, dim), as
    implied by the unsqueeze/expand below. Distances are divided by
    ``temperature`` and negated to form logits, with the positive in column
    0, and the loss is the cross-entropy against class 0.

    If any distance is NaN a warning is printed and a fresh zero tensor with
    ``requires_grad=True`` is returned; that tensor is not connected to the
    inputs.
    """
    def _raw(value):
        # Plain tensor behind a ManifoldTensor; anything else passes through.
        return value.tensor if isinstance(value, ManifoldTensor) else value

    n_rows = _raw(anchor).shape[0]
    n_neg = _raw(negatives).shape[1]

    # Anchor-to-positive distance.
    d_pos = manifold.dist(x=anchor, y=positive)

    # Repeat the anchor along a new axis so it lines up with the negatives.
    tiled_anchor = _raw(anchor).unsqueeze(1).expand(-1, n_neg, -1)
    if isinstance(anchor, ManifoldTensor):
        tiled_anchor = ManifoldTensor(tiled_anchor, manifold=manifold)
    d_neg = manifold.dist(x=tiled_anchor, y=negatives)

    d_pos, d_neg = _raw(d_pos), _raw(d_neg)

    # Drop a trailing axis if the distance kept one.
    if d_pos.dim() > 1:
        d_pos = d_pos.squeeze(-1)
    if d_neg.dim() > 2:
        d_neg = d_neg.squeeze(-1)

    if torch.isnan(d_pos).any() or torch.isnan(d_neg).any():
        print("WARNING: NaN in distances")
        return torch.tensor(0.0, device=d_pos.device, requires_grad=True)

    # Smaller distance -> larger logit; the positive occupies column 0.
    logits = torch.cat(
        [(-d_pos / temperature).unsqueeze(1), -d_neg / temperature],
        dim=1,
    )
    targets = torch.zeros(n_rows, dtype=torch.long, device=logits.device)

    return nn.functional.cross_entropy(logits, targets)

## Training

In [None]:
def train_model(model, dataset, num_epochs=200, batch_size=32, lr=0.001, temperature=0.1):
    """
    Optimize ``model`` on ``dataset`` with the InfoNCE loss.

    Uses hypll's ``RiemannianAdam`` (weight decay 1e-4), cosine annealing of
    the learning rate down to 1e-5 over ``num_epochs``, and gradient-norm
    clipping at 1.0. Batches are moved to the module-level ``device``.
    Batches whose loss is NaN or infinite are skipped.

    Returns:
        List of mean batch losses, one entry per epoch that had at least one
        non-skipped batch.
    """
    from hypll.optim import RiemannianAdam

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    optimizer = RiemannianAdam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)

    def embed_negatives(neg_batch):
        # Flatten (batch, num_neg, 2) for the encoder, then restore the grouping.
        rows, per_row, _ = neg_batch.shape
        flat = model(neg_batch.view(-1, 2))
        if isinstance(flat, ManifoldTensor):
            grouped = flat.tensor.view(rows, per_row, -1)
            return ManifoldTensor(grouped, manifold=model.manifold)
        return flat.view(rows, per_row, -1)

    history = []
    model.train()

    for epoch in range(num_epochs):
        running_total = 0.0
        batches_used = 0

        for anchor, positive, negatives in loader:
            anchor = anchor.to(device)
            positive = positive.to(device)
            negatives = negatives.to(device)

            anchor_emb = model(anchor)
            positive_emb = model(positive)
            negatives_emb = embed_negatives(negatives)

            loss = info_nce_loss(anchor_emb, positive_emb, negatives_emb,
                                 model.manifold, temperature=temperature)

            if torch.isnan(loss) or torch.isinf(loss):
                print("Skipping batch due to NaN/Inf loss")
                continue

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_total += loss.item()
            batches_used += 1

        # An epoch in which every batch was skipped records nothing and does
        # not advance the scheduler.
        if batches_used == 0:
            continue

        epoch_mean = running_total / batches_used
        history.append(epoch_mean)
        scheduler.step()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_mean:.4f}")

    return history

In [None]:
# Train the model (a fresh encoder replaces the one from the smoke test above)
model = HyperbolicIntervalEncoder(embedding_dim=2, c=1.0, euc_width=128, hyp_width=128).to(device)
losses = train_model(model, dataset, num_epochs=200, batch_size=32, lr=0.001, temperature=0.1)

# Plot training loss (one point per recorded epoch)
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

## Hitting Time Analysis

In [None]:
def compute_hitting_times(mdp, start_state, goal_state=5, n_simulations=10000, max_steps=10000):
    """
    Estimate hitting-time statistics by simulation.

    Runs ``n_simulations`` walks from ``start_state``, each calling
    ``mdp.step(state, action=None)`` until ``goal_state`` is reached or
    ``max_steps`` steps have been taken. Walks that do not reach the goal
    are left out of the statistics.

    Returns:
        (mean, variance, std) of the step counts of the walks that reached
        the goal, or three ``inf`` values when none did.
    """
    def simulate_once():
        # Number of steps to the goal, or None when the step budget ran out.
        current, elapsed = start_state, 0
        while elapsed < max_steps and current != goal_state:
            current = mdp.step(current, action=None)
            elapsed += 1
        return elapsed if current == goal_state else None

    outcomes = (simulate_once() for _ in range(n_simulations))
    samples = [count for count in outcomes if count is not None]

    if not samples:
        unreachable = float('inf')
        return unreachable, unreachable, unreachable

    return np.mean(samples), np.var(samples), np.std(samples)


def compute_all_hitting_times(mdp, n_simulations=10000):
    """
    Estimate hitting times to ``mdp.goal_state`` from every other state.

    Prints one summary line per start state.

    Returns:
        dict mapping each non-goal state to
        ``{'mean': ..., 'var': ..., 'std': ...}``.
    """
    target = mdp.goal_state
    summary = {}

    for origin in (s for s in range(mdp.n_states) if s != target):
        mean, var, std = compute_hitting_times(mdp, origin, target, n_simulations)
        summary[origin] = dict(mean=mean, var=var, std=std)
        print(f"{mdp.state_name(origin)} -> {mdp.state_name(target)}: mean={mean:.2f}, var={var:.2f}, std={std:.2f}")

    return summary


# Compute hitting times to the goal from every non-goal state
# (50000 simulated walks per start state).
print("Computing hitting time statistics...")
print("="*60)
hitting_stats = compute_all_hitting_times(mdp, n_simulations=50000)

## Visualization on Poincare Ball

In [None]:
def plot_mdp_embeddings(embeddings, hitting_stats, mdp):
    """
    Draw three Poincare-disk panels of the state embeddings and save them
    to 'mdp_embeddings.png'.

    Panels: points colored by mean hitting time, points colored by hitting
    time variance, and one distinctly colored point per state with a legend.

    Args:
        embeddings: Mapping state -> 2D coordinates (indexed with [0], [1]).
        hitting_stats: Mapping state -> dict with 'mean' and 'var'; it must
            contain every key of ``embeddings``.
        mdp: Object providing ``state_name``.

    NOTE(review): a later cell defines a function with this same name
    (larger markers and labels); once that cell runs it replaces this one.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    states = sorted(embeddings.keys())
    coords = np.array([embeddings[s] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])

    def tag(ax, s):
        # State name next to its point, nudged by a few screen points.
        point = embeddings[s]
        ax.annotate(mdp.state_name(s), (point[0], point[1]), fontsize=12, fontweight='bold',
                    xytext=(6, 6), textcoords='offset points')

    def colored_panel(ax, title, values, cmap, bar_label):
        # Disk background, value-colored scatter with colorbar, then labels.
        plot_poincare_disk(ax, title)
        points = ax.scatter(coords[:, 0], coords[:, 1], c=values, cmap=cmap,
                            s=120, edgecolors='black', linewidth=1.5, zorder=5)
        plt.colorbar(points, ax=ax, label=bar_label)
        for s in states:
            tag(ax, s)

    colored_panel(axes[0], "Colored by Mean Hitting Time", means, 'viridis', 'Mean Hitting Time')
    colored_panel(axes[1], "Colored by Variance of Hitting Time", vars_, 'plasma', 'Variance')

    # Third panel: one legend entry per state, showing its mean hitting time.
    ax = axes[2]
    plot_poincare_disk(ax, "All State Embeddings")

    palette = plt.cm.tab10(np.linspace(0, 1, len(states)))
    for shade, s in zip(palette, states):
        point = embeddings[s]
        ax.scatter([point[0]], [point[1]], c=[shade], s=120,
                   edgecolors='black', linewidth=1.5, zorder=5,
                   label=f"{mdp.state_name(s)}: T={hitting_stats[s]['mean']:.1f}")
        tag(ax, s)

    ax.legend(loc='upper right', fontsize=10)

    plt.tight_layout()
    plt.savefig('mdp_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_mdp_embeddings(embeddings, hitting_stats, mdp)

In [None]:
def plot_mdp_embeddings(embeddings, hitting_stats, mdp):
    """
    Large-marker variant: draw three Poincare-disk panels of the state
    embeddings and save them to 'mdp_embeddings.png'.

    Panels: points colored by mean hitting time, points colored by hitting
    time variance, and one distinctly colored point per state with a legend.

    Args:
        embeddings: Mapping state -> 2D coordinates (indexed with [0], [1]).
        hitting_stats: Mapping state -> dict with 'mean' and 'var'; it must
            contain every key of ``embeddings``.
        mdp: Object providing ``state_name``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    states = sorted(embeddings.keys())
    coords = np.array([embeddings[s] for s in states])
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])

    # Shared look of every marker and every state label.
    marker_style = dict(s=300, edgecolors='black', linewidth=2, zorder=5)
    label_style = dict(fontsize=14, fontweight='bold',
                       xytext=(8, 8), textcoords='offset points')

    # Panels 1 and 2: (axis, title, values, colormap, colorbar label).
    value_panels = [
        (axes[0], "Colored by Mean Hitting Time", means, 'viridis', 'Mean Hitting Time'),
        (axes[1], "Colored by Variance of Hitting Time", vars_, 'plasma', 'Variance'),
    ]
    for ax, title, values, cmap, bar_label in value_panels:
        plot_poincare_disk(ax, title)
        points = ax.scatter(coords[:, 0], coords[:, 1], c=values, cmap=cmap, **marker_style)
        plt.colorbar(points, ax=ax, label=bar_label)
        for s in states:
            x, y = embeddings[s][0], embeddings[s][1]
            ax.annotate(mdp.state_name(s), (x, y), **label_style)

    # Panel 3: one legend entry per state, showing its mean hitting time.
    ax = axes[2]
    plot_poincare_disk(ax, "All State Embeddings")

    palette = plt.cm.tab10(np.linspace(0, 1, len(states)))
    for idx, s in enumerate(states):
        x, y = embeddings[s][0], embeddings[s][1]
        ax.scatter([x], [y], c=[palette[idx]],
                   label=f"{mdp.state_name(s)}: T={hitting_stats[s]['mean']:.1f}",
                   **marker_style)
        ax.annotate(mdp.state_name(s), (x, y), **label_style)

    ax.legend(loc='upper right', fontsize=10)

    plt.tight_layout()
    plt.savefig('mdp_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()


plot_mdp_embeddings(embeddings, hitting_stats, mdp)

In [ ]:
def test_conjecture(embeddings, hitting_stats, mdp):
    """
    Test the conjecture:
    - Norm (distance from origin) correlates NEGATIVELY with variance in hitting times
      (low variance -> high norm, high variance -> low norm)
    - Angular coordinate correlates with mean hitting times

    Prints Spearman and Pearson correlations plus a per-state table.

    Args:
        embeddings: Mapping state -> 2D coordinates (indexed with [0], [1]).
        hitting_stats: Mapping state -> dict with 'mean', 'var' and 'std';
            it must contain every key of ``embeddings``.
        mdp: Object providing ``state_name``.

    Returns:
        dict with the Spearman correlations and p-values ('corr_norm_var',
        'p_norm_var', 'corr_angle_mean', 'p_angle_mean'), the per-state
        arrays ('norms', 'angles', 'means', 'vars'), the sorted 'states',
        and the two verdict flags. The norm/variance claim counts as
        supported when its Spearman correlation is below -0.3; the
        angle/mean claim when the absolute Spearman correlation exceeds 0.3.

    NOTE(review): angles come from ``np.arctan2`` and so lie in [-pi, pi];
    a rank correlation on them depends on where that range is cut — confirm
    this is intended.
    """
    states = sorted(embeddings.keys())

    # Geometry of each embedding. (The previously built `coords` array was
    # never used and has been dropped.)
    norms = np.array([hyperbolic_norm(embeddings[s]) for s in states])
    angles = np.array([np.arctan2(embeddings[s][1], embeddings[s][0]) for s in states])

    # Hitting-time statistics in the same state order.
    means = np.array([hitting_stats[s]['mean'] for s in states])
    vars_ = np.array([hitting_stats[s]['var'] for s in states])
    stds = np.array([hitting_stats[s]['std'] for s in states])

    # Compute correlations
    corr_norm_var, p_norm_var = stats.spearmanr(norms, vars_)
    corr_norm_std, p_norm_std = stats.spearmanr(norms, stds)
    corr_angle_mean, p_angle_mean = stats.spearmanr(angles, means)

    # Also compute Pearson
    pearson_norm_var = np.corrcoef(norms, vars_)[0, 1]
    pearson_angle_mean = np.corrcoef(angles, means)[0, 1]

    print("="*70)
    print("CONJECTURE TEST RESULTS")
    print("="*70)

    print("\n1. NORM vs VARIANCE Conjecture:")
    print(f"   Spearman correlation: {corr_norm_var:.4f} (p={p_norm_var:.4f})")
    print(f"   Pearson correlation:  {pearson_norm_var:.4f}")
    print(f"   Expected: NEGATIVE (low variance -> high norm)")
    # A NaN correlation compares False here and is reported as NOT SUPPORTED.
    norm_var_supported = corr_norm_var < -0.3
    print(f"   Result: {'SUPPORTED' if norm_var_supported else 'NOT SUPPORTED'}")

    print("\n2. ANGLE vs MEAN Conjecture:")
    print(f"   Spearman correlation: {corr_angle_mean:.4f} (p={p_angle_mean:.4f})")
    print(f"   Pearson correlation:  {pearson_angle_mean:.4f}")
    print(f"   Expected: Strong correlation (similar means -> similar angles)")
    angle_mean_supported = abs(corr_angle_mean) > 0.3
    print(f"   Result: {'SUPPORTED' if angle_mean_supported else 'NOT SUPPORTED'}")

    print("\n" + "="*70)
    print("STATE-BY-STATE ANALYSIS")
    print("="*70)
    print(f"{'State':<8} {'Norm':<10} {'Angle(deg)':<12} {'Mean T':<10} {'Var T':<10} {'Std T':<10}")
    print("-"*70)

    for i, s in enumerate(states):
        print(f"{mdp.state_name(s):<8} {norms[i]:<10.4f} {np.degrees(angles[i]):<12.1f} "
              f"{means[i]:<10.2f} {vars_[i]:<10.2f} {stds[i]:<10.2f}")

    return {
        'corr_norm_var': corr_norm_var,
        'corr_angle_mean': corr_angle_mean,
        'p_norm_var': p_norm_var,
        'p_angle_mean': p_angle_mean,
        'norms': norms,
        'angles': angles,
        'means': means,
        'vars': vars_,
        'states': states,
        'norm_var_supported': norm_var_supported,
        'angle_mean_supported': angle_mean_supported
    }


results = test_conjecture(embeddings, hitting_stats, mdp)

In [None]:
def _format_summary_box(title, sections, width=58):
    """
    Render a centered title and a list of sections as a box-drawing text panel.

    Every row is padded to the same inner width so the right border lines up
    when the text is drawn in a monospace font.

    Args:
        title: Text centered in the first row of the box.
        sections: Iterable of sections, each an iterable of row strings.
        width: Number of characters between the left and right borders.

    Returns:
        The panel as a single string with a leading and trailing newline.
    """
    def row(text=""):
        return f"║  {text:<{width - 2}}║"

    lines = ["╔" + "═" * width + "╗", f"║{title:^{width}}║"]
    for section in sections:
        lines.append("╠" + "═" * width + "╣")
        lines.append(row())
        lines.extend(row(text) for text in section)
        lines.append(row())
    lines.append("╚" + "═" * width + "╝")
    return "\n" + "\n".join(lines) + "\n"


def plot_summary(embeddings, results, hitting_stats, mdp):
    """
    Create a comprehensive summary plot with Poincare disk, scatter plots, and statistics.

    Args:
        embeddings: Mapping from state to its 2D embedding (indexed as emb[0], emb[1]).
        results: Dict produced by test_conjecture. The keys 'states', 'norms',
            'angles', 'means', 'vars', 'corr_norm_var', 'corr_angle_mean',
            'p_norm_var', 'p_angle_mean', 'norm_var_supported' and
            'angle_mean_supported' are read.
        hitting_stats: Not used by this function (all statistics come from
            `results`); kept so existing callers keep working.
        mdp: Object providing state_name(state), used for point labels.

    Side effects:
        Saves the figure to 'mdp_summary.png' and displays it.
    """
    fig = plt.figure(figsize=(16, 14))

    # 2x2 layout: disk, norm-vs-variance, angle-vs-mean, text summary
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.25)

    states = results['states']
    norms = results['norms']
    angles = results['angles']
    means = results['means']
    vars_ = results['vars']
    coords = np.array([embeddings[s] for s in states])
    angles_deg = np.degrees(angles)

    # Take the verdicts from `results` instead of re-deriving them from
    # hard-coded thresholds, so every panel agrees with test_conjecture.
    def verdict(supported):
        return ("SUPPORTED", 'green') if supported else ("NOT SUPPORTED", 'red')

    norm_var_status, norm_var_color = verdict(results['norm_var_supported'])
    angle_mean_status, angle_mean_color = verdict(results['angle_mean_supported'])

    # =========================================================================
    # Plot 1: Poincare Ball (top-left)
    # =========================================================================
    ax1 = fig.add_subplot(gs[0, 0])

    # Unit circle boundary and origin marker
    circle = plt.Circle((0, 0), 1, color='black', fill=False, linewidth=2, linestyle='--')
    ax1.add_patch(circle)
    ax1.scatter([0], [0], s=60, c='black', marker='+', linewidth=2, zorder=10)

    # Reversed colormap: low variance maps to the bright end
    scatter = ax1.scatter(coords[:, 0], coords[:, 1], c=vars_, cmap='plasma_r',
                         s=100, edgecolors='black', linewidth=1.5, zorder=5)
    plt.colorbar(scatter, ax=ax1, label='Variance', shrink=0.8)

    for s in states:
        emb = embeddings[s]
        ax1.annotate(mdp.state_name(s), (emb[0], emb[1]), fontsize=11, fontweight='bold',
                    xytext=(5, 5), textcoords='offset points')

    ax1.set_xlim(-1.15, 1.15)
    ax1.set_ylim(-1.15, 1.15)
    ax1.set_aspect('equal')
    ax1.grid(True, alpha=0.2)
    ax1.set_title('Poincaré Ball Embeddings\n(colored by variance)', fontsize=13, fontweight='bold')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')

    # =========================================================================
    # Plot 2: Norm vs Variance (top-right)
    # =========================================================================
    ax2 = fig.add_subplot(gs[0, 1])

    scatter = ax2.scatter(vars_, norms, c=means, cmap='viridis', s=100,
                         edgecolors='black', linewidth=1.5)
    plt.colorbar(scatter, ax=ax2, label='Mean Hitting Time', shrink=0.8)

    for i, s in enumerate(states):
        ax2.annotate(mdp.state_name(s), (vars_[i], norms[i]), fontsize=11, fontweight='bold',
                    xytext=(5, 5), textcoords='offset points')

    # Least-squares linear trend, drawn slightly beyond the data range
    z = np.polyfit(vars_, norms, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(vars_) - 5, max(vars_) + 5, 100)
    ax2.plot(x_trend, p(x_trend), 'r--', linewidth=2, alpha=0.7, label=f'Trend (slope={z[0]:.4f})')

    ax2.set_xlabel('Variance of Hitting Time', fontsize=11)
    ax2.set_ylabel('Hyperbolic Norm', fontsize=11)

    # Title carries the verdict and is colored by it (green/red)
    corr = results['corr_norm_var']
    ax2.set_title(f'Conjecture 1: Norm vs Variance\nρ = {corr:.3f} ({norm_var_status})',
                 fontsize=13, fontweight='bold', color=norm_var_color)
    ax2.legend(loc='upper right')
    ax2.grid(True, alpha=0.3)

    # =========================================================================
    # Plot 3: Angle vs Mean (bottom-left)
    # =========================================================================
    ax3 = fig.add_subplot(gs[1, 0])

    scatter = ax3.scatter(means, angles_deg, c=vars_, cmap='plasma', s=100,
                         edgecolors='black', linewidth=1.5)
    plt.colorbar(scatter, ax=ax3, label='Variance', shrink=0.8)

    for i, s in enumerate(states):
        ax3.annotate(mdp.state_name(s), (means[i], angles_deg[i]), fontsize=11, fontweight='bold',
                    xytext=(5, 5), textcoords='offset points')

    # Least-squares linear trend
    z = np.polyfit(means, angles_deg, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(means) - 1, max(means) + 1, 100)
    ax3.plot(x_trend, p(x_trend), 'r--', linewidth=2, alpha=0.7, label=f'Trend (slope={z[0]:.2f})')

    ax3.set_xlabel('Mean Hitting Time', fontsize=11)
    ax3.set_ylabel('Angular Coordinate (degrees)', fontsize=11)

    corr = results['corr_angle_mean']
    ax3.set_title(f'Conjecture 2: Angle vs Mean\nρ = {corr:.3f} ({angle_mean_status})',
                 fontsize=13, fontweight='bold', color=angle_mean_color)
    ax3.legend(loc='best')
    ax3.grid(True, alpha=0.3)

    # =========================================================================
    # Plot 4: Summary Statistics (bottom-right)
    # =========================================================================
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.axis('off')

    rule = "─" * 53
    state_rows = [
        f"{mdp.state_name(s)}: norm={norms[i]:.2f}, angle={angles_deg[i]:+6.1f}°, "
        f"μ={means[i]:5.1f}, σ²={vars_[i]:6.1f}"
        for i, s in enumerate(states)
    ]
    sections = [
        [
            "CONJECTURE 1: Norm ∝ 1/Variance",
            rule,
            "Expected: Low variance → High norm (negative corr)",
            f"Spearman ρ = {results['corr_norm_var']:+.4f}  (p = {results['p_norm_var']:.4f})",
            f"Status: {norm_var_status}",
        ],
        [
            "CONJECTURE 2: Angle ∝ Mean",
            rule,
            "Expected: Similar mean → Similar angle (strong corr)",
            f"Spearman ρ = {results['corr_angle_mean']:+.4f}  (p = {results['p_angle_mean']:.4f})",
            f"Status: {angle_mean_status}",
        ],
        ["STATE STATISTICS:", rule] + state_rows,
    ]
    summary_text = _format_summary_box("HYPOTHESIS TEST SUMMARY", sections)

    ax4.text(0.02, 0.98, summary_text, transform=ax4.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='lightyellow', alpha=0.9, edgecolor='black'))

    plt.suptitle('MDP Hyperbolic Embedding Analysis', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.savefig('mdp_summary.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()


# Draw the 2x2 summary figure (also written to mdp_summary.png).
plot_summary(embeddings, results, hitting_stats, mdp)

In [None]:
def plot_conjecture_analysis(results, mdp):
    """
    Detailed visualization of the conjecture analysis.

    Draws a 2x2 figure: norm vs variance (with linear trend), angle vs mean,
    norm vs mean, and a text panel with the verdict for each conjecture.

    Args:
        results: Dict with keys 'norms', 'angles', 'means', 'vars', 'states',
            'corr_norm_var' and 'corr_angle_mean'. If 'norm_var_supported' /
            'angle_mean_supported' are present they are used for the verdicts;
            otherwise the verdicts are derived from the correlations.
        mdp: Object providing state_name(state), used for point labels.

    Side effects:
        Saves the figure to 'conjecture_analysis.png' and displays it.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    norms = results['norms']
    angles = results['angles']
    means = results['means']
    vars_ = results['vars']
    states = results['states']
    angles_deg = np.degrees(angles)

    def label_points(ax, xs, ys):
        # Write each state's name next to its point
        for i, s in enumerate(states):
            ax.annotate(mdp.state_name(s), (xs[i], ys[i]), fontsize=12, fontweight='bold',
                       xytext=(5, 5), textcoords='offset points')

    # Plot 1: Norm vs Variance
    ax = axes[0, 0]
    scatter = ax.scatter(vars_, norms, c=means, cmap='viridis', s=200, edgecolors='black', linewidth=2)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time')
    label_points(ax, vars_, norms)

    # Least-squares linear trend over the observed variance range
    z = np.polyfit(vars_, norms, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(vars_), max(vars_), 100)
    ax.plot(x_trend, p(x_trend), 'r--', linewidth=2, label=f'Trend (slope={z[0]:.4f})')

    ax.set_xlabel('Variance of Hitting Time', fontsize=12)
    ax.set_ylabel('Hyperbolic Norm', fontsize=12)
    ax.set_title(f'Conjecture 1: Norm vs Variance\nCorr={results["corr_norm_var"]:.3f}', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 2: Angle vs Mean
    ax = axes[0, 1]
    scatter = ax.scatter(means, angles_deg, c=vars_, cmap='plasma', s=200, edgecolors='black', linewidth=2)
    plt.colorbar(scatter, ax=ax, label='Variance')
    label_points(ax, means, angles_deg)

    ax.set_xlabel('Mean Hitting Time', fontsize=12)
    ax.set_ylabel('Angular Coordinate (degrees)', fontsize=12)
    ax.set_title(f'Conjecture 2: Angle vs Mean\nCorr={results["corr_angle_mean"]:.3f}', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Plot 3: Norm vs Mean
    ax = axes[1, 0]
    scatter = ax.scatter(means, norms, c=vars_, cmap='plasma', s=200, edgecolors='black', linewidth=2)
    plt.colorbar(scatter, ax=ax, label='Variance')
    label_points(ax, means, norms)

    # Pearson correlation (np.corrcoef), unlike the values taken from `results`
    corr_norm_mean = np.corrcoef(norms, means)[0, 1]
    ax.set_xlabel('Mean Hitting Time', fontsize=12)
    ax.set_ylabel('Hyperbolic Norm', fontsize=12)
    ax.set_title(f'Additional: Norm vs Mean\nCorr={corr_norm_mean:.3f}', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Plot 4: Summary stats
    ax = axes[1, 1]
    ax.axis('off')

    corr_norm_var = results['corr_norm_var']
    corr_angle_mean = results['corr_angle_mean']
    # Conjecture 1 predicts a NEGATIVE norm/variance correlation, so support
    # means corr < -0.3 (same rule as test_conjecture), not corr > 0.3.
    norm_var_ok = results.get('norm_var_supported', corr_norm_var < -0.3)
    angle_mean_ok = results.get('angle_mean_supported', abs(corr_angle_mean) > 0.3)
    norm_var_verdict = 'SUPPORTED' if norm_var_ok else 'WEAK/NOT SUPPORTED'
    angle_mean_verdict = 'SUPPORTED' if angle_mean_ok else 'WEAK/NOT SUPPORTED'

    summary_text = f"""
    CONJECTURE ANALYSIS SUMMARY
    {'='*50}
    
    Conjecture 1: Norm ~ 1/Variance
    -------------------------------
    Spearman correlation: {corr_norm_var:.4f}
    Interpretation: {norm_var_verdict}
    
    Conjecture 2: Angle ~ Mean
    ---------------------------
    Spearman correlation: {corr_angle_mean:.4f}
    Interpretation: {angle_mean_verdict}
    
    {'='*50}
    
    Note: The MDP has structure that affects embeddings:
    - S4, S5: 1-step to goal (low var, low mean)
    - S2, S3: Geometric waiting (high var, high mean)
    - S1: Depends on action choice
    """

    ax.text(0.05, 0.95, summary_text, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig('conjecture_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()


# Draw the detailed conjecture figure (also written to conjecture_analysis.png).
plot_conjecture_analysis(results, mdp)

In [ ]:
def plot_all_state_pairs(model, mdp):
    """
    Embed every ordered (start, goal) pair with start != goal and draw them on one disk.

    Each point is colored by its goal index and labeled "(start,goal)" using
    mdp.state_name. The figure is saved to 'all_state_pairs.png' and then shown.
    """
    model.eval()
    n_states = mdp.n_states

    fig, ax = plt.subplots(figsize=(10, 10))
    plot_poincare_disk(ax, "All (State, Goal) Pair Embeddings")

    # Ordered pairs in row-major order, with start == goal excluded
    pairs = [(src, dst) for src in range(n_states) for dst in range(n_states) if src != dst]

    points, tags, goal_ids = [], [], []
    with torch.no_grad():
        for src, dst in pairs:
            # Model input: both state indices divided by (n_states - 1)
            features = torch.tensor(
                [[src / (n_states - 1), dst / (n_states - 1)]], dtype=torch.float32
            ).to(device)
            out = model(features)

            # Unwrap a hypll ManifoldTensor to its underlying tensor
            out = out.tensor if isinstance(out, ManifoldTensor) else out

            points.append(out.squeeze(0).cpu().numpy())
            tags.append(f"({mdp.state_name(src)},{mdp.state_name(dst)})")
            goal_ids.append(dst)

    points = np.array(points)
    goal_ids = np.array(goal_ids)

    scatter = ax.scatter(points[:, 0], points[:, 1],
                        c=goal_ids, cmap='tab10', s=80, edgecolors='black', linewidth=1, zorder=5)

    # One small text label per pair, offset from its point
    for tag, point in zip(tags, points):
        ax.annotate(tag, point, fontsize=7,
                   xytext=(3, 3), textcoords='offset points')

    plt.colorbar(scatter, ax=ax, label='Goal State Index')
    plt.tight_layout()
    plt.savefig('all_state_pairs.png', dpi=150, bbox_inches='tight')
    plt.show()


# Plot embeddings of all (start, goal) pairs (also written to all_state_pairs.png).
plot_all_state_pairs(model, mdp)

In [None]:
def plot_polar_representation(embeddings, hitting_stats, mdp):
    """
    Draw each state on a polar axis at (embedding angle, hyperbolic norm).

    Marker color encodes the mean hitting time; marker size shrinks as the
    variance grows. The figure is saved to 'polar_embeddings.png' and shown.
    """
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': 'polar'})

    states = sorted(embeddings.keys())

    # Gather per-state coordinates and hitting-time statistics in state order
    angles, norms, means, vars_ = [], [], [], []
    for s in states:
        point = embeddings[s]
        angles.append(np.arctan2(point[1], point[0]))
        norms.append(hyperbolic_norm(point))
        means.append(hitting_stats[s]['mean'])
        vars_.append(hitting_stats[s]['var'])

    # Marker area runs from 150 (zero variance) down to 50 (largest variance);
    # a non-positive maximum falls back to 1 to avoid dividing by zero
    peak_var = max(vars_)
    max_var = peak_var if peak_var > 0 else 1
    sizes = 50 + 100 * (1 - np.array(vars_) / max_var)

    scatter = ax.scatter(angles, norms, c=means, cmap='viridis', s=sizes,
                        edgecolors='black', linewidth=1.5)
    plt.colorbar(scatter, ax=ax, label='Mean Hitting Time', pad=0.1)

    for theta, radius, s in zip(angles, norms, states):
        ax.annotate(mdp.state_name(s), (theta, radius), fontsize=11, fontweight='bold')

    ax.set_title('Polar View: Angle vs Hyperbolic Norm\n(Size ~ 1/Variance, Color ~ Mean)', fontsize=13, pad=20)

    plt.tight_layout()
    plt.savefig('polar_embeddings.png', dpi=150, bbox_inches='tight')
    plt.show()


# Draw the polar view of the embeddings (also written to polar_embeddings.png).
plot_polar_representation(embeddings, hitting_stats, mdp)

## Discussion

### MDP Structure Analysis

The MDP has distinct structural properties that affect hitting times:

1. **States 4 and 5 (S4, S5)**: Deterministic 1-step transitions to goal
   - Mean hitting time: 1
   - Variance: 0
   
2. **States 2 and 3 (S2, S3)**: Geometric waiting times due to self-loops
   - Mean hitting time: 11 in expectation (a Geometric(p=0.1) wait averaging 10 steps, plus 1 final step to the goal)
   - Variance: 90 in expectation ((1 − p)/p² with p = 0.1; the final deterministic step adds no variance)
   
3. **State 1 (S1)**: Depends on action selection
   - Under random policy: mix of 2-step (via a11) and longer paths (via a12, a13)

### Conjecture Interpretation

The conjecture states:
- **Norm ~ 1/Variance**: States with **lower** hitting time variance should be embedded **farther** from the origin (higher norm)
- **Angle ~ Mean**: States with similar mean hitting times should have similar angular coordinates

This would create a structure where:
- S4, S5 (low variance, deterministic) should have **high norms** (far from origin)
- S2, S3 (high variance, geometric waiting) should have **low norms** (near origin)
- States with similar means lie along similar radial directions

### Intuition

In hyperbolic space, the origin represents maximal uncertainty/abstraction. States with:
- **Low variance** (predictable hitting times) are more "specific" → farther from origin
- **High variance** (unpredictable hitting times) are more "abstract" → closer to origin

In [ ]:
# Final summary: recap the training run and the verdict on each conjecture.
print(70 * "=")
print("FINAL SUMMARY")
print(70 * "=")
print("\nTraining completed with {} trajectories".format(len(trajectories)))
print("Final training loss: {:.4f}".format(losses[-1]))

# Conjecture 1: the verdict flag comes from test_conjecture via `results`
print("\n--- Conjecture 1: Norm ~ 1/Variance ---")
print("Spearman correlation: {:.4f}".format(results['corr_norm_var']))
print("Expected: Negative (low variance -> high norm)")
print("Status: " + ('SUPPORTED' if results['norm_var_supported'] else 'NOT SUPPORTED'))

# Conjecture 2
print("\n--- Conjecture 2: Angle ~ Mean ---")
print("Spearman correlation: {:.4f}".format(results['corr_angle_mean']))
print("Expected: Strong correlation")
print("Status: " + ('SUPPORTED' if results['angle_mean_supported'] else 'NOT SUPPORTED'))

print("\n" + 70 * "=")
print("See the summary plot above for detailed visualization.")

In [None]:
# Final summary (condensed): training recap plus the two conjecture correlations.
# Conjecture 1 is an inverse relation (low variance -> high norm), so it is
# labeled "Norm ~ 1/Variance" rather than "Norm ~ Variance".
print("="*70)
print("FINAL SUMMARY")
print("="*70)
print(f"\nTraining completed with {len(trajectories)} trajectories")
print(f"Final training loss: {losses[-1]:.4f}")
print(f"\nConjecture 1 (Norm ~ 1/Variance): correlation = {results['corr_norm_var']:.4f}")
print(f"Conjecture 2 (Angle ~ Mean): correlation = {results['corr_angle_mean']:.4f}")
print("\nSee visualizations above for detailed analysis.")