In [1]:
import torch
import os
# Ask the CUDA runtime to launch kernels synchronously so that device-side
# errors surface at the call that caused them (debugging aid; slows GPU work).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Notebook-wide default device: the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Placeholder control value: PendulumEnv.trajectories skips the dynamics update
# for any environment whose control equals this value at a given time step.
sentinel = 1337

In [2]:
def search_alpha_parallel(sol, V_start, radius, t_eval, L):
    """
    Bisect, independently for every batch entry, for the rate ``alpha`` at which

        min_t [ sol(t) * exp(alpha * t) + radius * exp((L + alpha) * t) ]

    crosses ``V_start - radius``. The first sample of ``sol`` and ``t_eval`` is
    dropped before the search.

    Args:
        sol: Values along time; 2-D input is sliced as (batch, time), 1-D input
            as (time,).
        V_start: 1-D tensor of starting values, one per batch entry.
        radius: Per-entry radius; broadcast along the time axis of ``sol``.
        t_eval: 1-D tensor of sample times aligned with the time axis of ``sol``.
        L: Scalar (or tensor broadcastable against ``alpha``) added to ``alpha``
            in the exponent of the radius term.

    Returns:
        (alpha_mid, indices): the midpoint of the final bracket for every entry
        and the time index (relative to the truncated grid) where the envelope
        attains its minimum at ``alpha_mid``.

    Raises:
        RuntimeError: if no bracketing pair of rates is found with
            ``|alpha| < 1e5`` (the original loop would spin forever here).
    """
    batched = sol.dim() > 1
    sol = sol[:, 1:] if batched else sol[1:]
    t_eval = t_eval[1:]

    ceiling = 1e5          # largest |alpha| tried while looking for a bracket
    tolerance = 0.0001     # target width of the final bracket
    max_bisections = 64    # far more than needed; only guards against float stalls

    # Work on the same device/dtype as the inputs instead of a notebook global.
    alpha_h = torch.ones_like(V_start)
    alpha_l = -torch.ones_like(V_start)
    r = radius.unsqueeze(1 if batched else 0).expand(sol.size())
    threshold = V_start - radius

    def envelope(alpha):
        # sol(t) * e^{alpha t} + radius * e^{(L + alpha) t}, one row per entry.
        growth = torch.exp(torch.outer(alpha, t_eval))
        inflated_radius = torch.mul(r, torch.exp(torch.outer(L + alpha, t_eval)))
        return sol * growth + inflated_radius

    def min_envelope(alpha):
        return torch.min(envelope(alpha), dim=1).values

    # Phase 1: grow the bracket until alpha_h is above and alpha_l is below the threshold.
    condition_h = min_envelope(alpha_h) > threshold
    condition_l = min_envelope(alpha_l) < threshold
    while torch.any(~condition_h) or torch.any(~condition_l):
        if torch.any(alpha_h[~condition_h] >= ceiling) or torch.any(alpha_l[~condition_l] <= -ceiling):
            raise RuntimeError(
                "search_alpha_parallel: could not bracket alpha within +/-%g" % ceiling
            )
        alpha_h[~condition_h] = alpha_h[~condition_h] * 2
        alpha_l[~condition_l] = alpha_l[~condition_l] * 2
        condition_h = min_envelope(alpha_h) > threshold
        condition_l = min_envelope(alpha_l) < threshold

    # Phase 2: bisection on every entry simultaneously.
    alpha_mid = (alpha_h + alpha_l) / 2
    for _ in range(max_bisections):
        if torch.max(torch.abs(alpha_h - alpha_l)) <= tolerance:
            break
        condition = min_envelope(alpha_mid) < threshold
        alpha_l[condition] = alpha_mid[condition]
        alpha_h[~condition] = alpha_mid[~condition]
        alpha_mid = (alpha_h + alpha_l) / 2

    indices = torch.argmin(envelope(alpha_mid), dim=1)
    return alpha_mid, indices

In [3]:
def certify_alpha(sol, V_start, radius, t_eval, L, alpha):
    """
    Check, per batch entry, whether a given rate ``alpha`` satisfies

        min_t [ sol(t) * exp(alpha * t) + radius * exp((L + alpha) * t) ] < V_start - radius

    The first sample of ``sol`` and ``t_eval`` is dropped before the check.

    Args:
        sol: 2-D tensor sliced as (batch, time).
        V_start: 1-D tensor of starting values, one per batch entry.
        radius: 1-D tensor of radii, broadcast along the time axis.
        t_eval: 1-D tensor of sample times aligned with the time axis of ``sol``.
        L: Scalar (or broadcastable tensor) added to ``alpha`` in the radius term.
        alpha: Scalar or per-entry rate to certify.

    Returns:
        (condition, indices): boolean tensor telling which entries satisfy the
        inequality, and the time index (relative to the truncated grid) at which
        each row's minimum is attained.
    """
    sol = sol[:, 1:]
    t_eval = t_eval[1:]
    # Broadcast alpha to one value per entry on the inputs' own device/dtype
    # rather than on a notebook-global device.
    alpha = alpha * torch.ones_like(V_start)
    r = radius.unsqueeze(1).expand(sol.size())

    envelope = sol * torch.exp(torch.outer(alpha, t_eval)) + torch.mul(r, torch.exp(torch.outer(L + alpha, t_eval)))

    # A single reduction yields both the minimum and where it occurs.
    row_min = torch.min(envelope, dim=1)
    condition = row_min.values < V_start - radius
    return condition, row_min.indices

In [4]:
import torch
import math

def wrap_to_pi(tensor):
    """
    Return a copy of ``tensor`` whose component 0 along the last axis (the
    angle) is wrapped into [-pi, pi); all other components are left untouched.

    The input is not modified.
    """
    full_turn = 2 * math.pi
    out = tensor.clone()
    out[..., 0] = torch.remainder(tensor[..., 0] + math.pi, full_turn) - math.pi
    return out


In [5]:
import torch

def find_hypercubes_old(points, centers, radii):
    """
    For every point, return the lowest index of an axis-aligned hypercube that
    contains it, or -1 when no hypercube does.

    Component 0 of each point is first wrapped with ``wrap_to_pi``. ``radii``
    may be 1-D (one half-width per hypercube) or already 2-D; it is expanded to
    one half-width per hypercube and dimension. Bounds are inclusive.
    """
    wrapped = wrap_to_pi(points)
    _, num_dims = wrapped.shape
    num_cubes, _ = centers.shape

    half_widths = radii.unsqueeze(1) if radii.dim() == 1 else radii
    half_widths = half_widths.expand(num_cubes, num_dims)

    lo = centers - half_widths  # (m, n)
    hi = centers + half_widths  # (m, n)

    # (k, 1, n) against (m, n) broadcasts to (k, m, n); reduce over dimensions.
    pts = wrapped.unsqueeze(1)
    inside = ((pts >= lo) & (pts <= hi)).all(dim=2)  # (k, m)

    # Non-matching slots get the out-of-range index m so the row minimum is the
    # first matching hypercube, or m when there is none.
    candidates = torch.where(inside, torch.arange(num_cubes, device=pts.device), num_cubes)
    first_hit = candidates.min(dim=1).values

    return torch.where(first_hit == num_cubes, torch.full_like(first_hit, -1), first_hit)


In [6]:
import torch

def find_hypercubes(points, centers, radii):
    """
    Locate, for each point, the first axis-aligned box that contains it.

    Args:
        points: (k, d) tensor of points.
        centers: (m, d) tensor of box centers.
        radii: (m,) tensor of box half-widths (same half-width in every dimension).

    Returns:
        (k,) tensor holding the lowest index of a containing box, or -1 when the
        point lies in none of them. Box bounds are inclusive.
    """
    half = radii[None, :, None]            # (1, m, 1)
    box_centers = centers[None, :, :]      # (1, m, d)
    pts = points[:, None, :]               # (k, 1, d)

    above_lower = pts >= box_centers - half
    below_upper = pts <= box_centers + half
    inside = (above_lower & below_upper).all(dim=-1)   # (k, m)

    # argmax over 0/1 flags yields the first True column (0 when the row is all False).
    first_hit = inside.int().argmax(dim=1)
    has_hit = inside.any(dim=1)

    return torch.where(has_hit, first_hit, torch.full_like(first_hit, -1))


In [7]:
import torch

def inverted_pendulum_2d_torch(x, u, m=0.1, l=10, g=9.81):
    """
    Batched vector field of a simplified inverted pendulum.

    Computes ``theta_ddot = (g / l) * sin(theta) + (u / (m * l)) * |cos(theta)|``.

    Args:
        x: Tensor of shape (n, 2) holding [theta, theta_dot] per row.
        u: Control input of shape (n,) or (n, 1); a trailing singleton axis is squeezed.
        m: Mass parameter.
        l: Length parameter.
        g: Gravitational acceleration.

    Returns:
        Tensor of shape (n, 2) holding [theta_dot, theta_ddot] per row.
    """
    angle, angular_velocity = x[:, 0], x[:, 1]

    # Drop a trailing singleton axis so (n, 1) controls line up with (n,) states.
    control = u.squeeze(-1)

    gravity_term = (g / l) * torch.sin(angle)
    control_term = (control / (m * l)) * torch.abs(torch.cos(angle))
    angular_acceleration = gravity_term + control_term

    return torch.stack([angular_velocity, angular_acceleration], dim=1)


In [8]:
import torch

def simplified_pendulum_derivatives(x, u):
    """
    Batched pendulum vector field: ``theta_ddot = 3 * g * sin(theta) + 15 * u``
    with ``g = 10``.

    NOTE(review): the original comment attributes these dynamics to OpenAI Gym's
    Pendulum-v1; the coefficients used here should be checked against that
    implementation before relying on the equivalence.

    Args:
        x: Tensor of shape (batch_size, 2) holding [theta, theta_dot] per row.
        u: Tensor of shape (batch_size,) with the torque input.

    Returns:
        Tensor of shape (batch_size, 2) holding [dtheta, dtheta_dot] per row.
    """
    gravity = 10.0
    torque_gain = 15.0

    angle, angular_velocity = x[:, 0], x[:, 1]
    angular_acceleration = 3 * gravity * torch.sin(angle) + torque_gain * u

    return torch.stack([angular_velocity, angular_acceleration], dim=1)


In [9]:
import torch
import math

def count_hypercube_intersections_wrap(centers, radii, query_centers, query_radii):
    """
    Count how many hypercubes intersect each query hypercube.

    Dimension 0 is angular and periodic with period 2*pi; the remaining
    dimensions are ordinary. Two boxes intersect when, in every dimension, the
    distance between their centers is at most the sum of their half-widths
    (touching boxes count as intersecting).

    Args:
        centers: (N, D) tensor of box centers.
        radii: (N,), (N, 1) or (N, D) tensor of half-widths.
        query_centers: (M, D) tensor of query box centers.
        query_radii: (M,), (M, 1) or (M, D) tensor of query half-widths.

    Returns:
        (M,) integer tensor with the number of intersecting boxes per query.
    """
    N, D = centers.shape
    M = query_centers.shape[0]

    if radii.ndim == 1:
        radii = radii.unsqueeze(1).expand(N, D)
    elif radii.shape[1] == 1:
        radii = radii.expand(N, D)

    if query_radii.ndim == 1:
        query_radii = query_radii.unsqueeze(1).expand(M, D)
    elif query_radii.shape[1] == 1:
        query_radii = query_radii.expand(M, D)

    # Expand for broadcasting
    centers = centers.unsqueeze(0)             # (1, N, D)
    radii = radii.unsqueeze(0)                 # (1, N, D)
    query_centers = query_centers.unsqueeze(1) # (M, 1, D)
    query_radii = query_radii.unsqueeze(1)     # (M, 1, D)

    # Wrapped angular distance. The raw difference is reduced modulo 2*pi first:
    # without it, angles more than one turn apart give a negative "distance"
    # and are always reported as intersecting.
    two_pi = 2 * math.pi
    angle_diff = torch.remainder(torch.abs(centers[..., 0] - query_centers[..., 0]), two_pi)  # (M, N)
    angle_dist = torch.minimum(angle_diff, two_pi - angle_diff)                                # (M, N)
    intersects = angle_dist <= radii[..., 0] + query_radii[..., 0]                             # (M, N)

    if D > 1:
        # Per-axis overlap test in the non-angular dimensions.
        offset = torch.abs(centers[..., 1:] - query_centers[..., 1:])  # (M, N, D-1)
        reach = radii[..., 1:] + query_radii[..., 1:]                  # (M, N, D-1)
        intersects = intersects & (offset <= reach).all(dim=-1)        # (M, N)

    return intersects.sum(dim=1)  # (M,)

In [10]:
import torch
import math

def find_hypercube_intersections_wrap(centers, radii, query_centers, query_radii):
    """
    For each query hypercube, list the indices of the hypercubes it intersects.

    Dimension 0 is angular and periodic with period 2*pi; the remaining
    dimensions are ordinary. Two boxes intersect when, in every dimension, the
    distance between their centers is strictly less than the sum of their
    half-widths (touching boxes do not count).

    Args:
        centers: (N, D) tensor of box centers.
        radii: (N,), (N, 1) or (N, D) tensor of half-widths.
        query_centers: (M, D) tensor of query box centers.
        query_radii: (M,), (M, 1) or (M, D) tensor of query half-widths.

    Returns:
        A list of M 1-D index tensors; entry i holds, in increasing order, the
        indices of the boxes intersecting query i (possibly empty).
    """
    N, D = centers.shape
    M = query_centers.shape[0]
    if M == 0:
        return []

    # Expand radii to (N, D)
    if radii.ndim == 1:
        radii = radii.unsqueeze(1).expand(N, D)
    elif radii.shape[1] == 1:
        radii = radii.expand(N, D)

    if query_radii.ndim == 1:
        query_radii = query_radii.unsqueeze(1).expand(M, D)
    elif query_radii.shape[1] == 1:
        query_radii = query_radii.expand(M, D)

    # Expand for broadcasting
    qc = query_centers.unsqueeze(1)  # (M, 1, D)
    qr = query_radii.unsqueeze(1)    # (M, 1, D)
    c = centers.unsqueeze(0)         # (1, N, D)
    r = radii.unsqueeze(0)           # (1, N, D)

    # Wrapped distance in the angular dimension. Reducing modulo 2*pi first
    # keeps the distance non-negative when the angles are more than one full
    # turn apart (otherwise such pairs would always be reported as intersecting).
    two_pi = 2 * math.pi
    angle_diff = torch.remainder(torch.abs(c[..., 0] - qc[..., 0]), two_pi)  # (M, N)
    angle_dist = torch.minimum(angle_diff, two_pi - angle_diff)               # (M, N)
    total_mask = angle_dist < r[..., 0] + qr[..., 0]                          # (M, N)

    if D > 1:
        # Regular per-axis distance in the remaining dimensions
        dist = torch.abs(c[..., 1:] - qc[..., 1:])                  # (M, N, D-1)
        threshold = r[..., 1:] + qr[..., 1:]                        # (M, N, D-1)
        total_mask = total_mask & (dist < threshold).all(dim=-1)    # (M, N)

    # (query_idx, box_idx) pairs come back sorted by query, then by box.
    query_idx, box_idx = torch.nonzero(total_mask, as_tuple=True)

    # Cut the flat box index list into one chunk per query.
    counts = torch.bincount(query_idx, minlength=M)
    return list(torch.split(box_idx, counts.tolist()))


In [11]:
def check_parallel(batch_states, t, r, alpha,
                               L, rate, num_samples,
                               unverified_centers, unverified_radii,
                               verified_centers, verified_radii,
                               verified_alphas, verified_indices, trajectories):
    """
    Select the rows of ``batch_states`` that pass the reuse check against the
    verified hypercubes.

    Steps, as implemented below:
      1. Inflate each row's radius ``r`` by ``exp(L * t * rate)``.
      2. Keep only rows whose inflated box intersects no unverified hypercube.
      3. For each kept row, collect the verified hypercubes its box intersects.
      4. For every (row, verified hypercube) pair evaluate ``inequality``; a row
         fails if any of its pairs yields a value greater than 1.

    Args:
        batch_states: (B, D) tensor of current states.
        t: Current step index (scalar).
        r: Per-row radius, indexed with row indices of ``batch_states``.
        alpha: Scalar rate subtracted from the verified alphas in ``term1``.
        L, rate: Scalars used in the radius inflation and the exponentials.
        num_samples: Not used by this function.
        unverified_centers, unverified_radii: Boxes passed to
            ``count_hypercube_intersections_wrap``.
        verified_centers, verified_radii: Boxes passed to
            ``find_hypercube_intersections_wrap``.
        verified_alphas, verified_indices: Per-verified-box values gathered for
            every intersecting pair.
        trajectories: Indexed on axis 0 with row indices of ``batch_states``, so
            it must be row-aligned with ``batch_states``.

    Returns:
        1-D tensor of row indices into ``batch_states`` that passed. Empty when
        no eligible row intersects any verified hypercube.
    """
    device = batch_states.device
    B = batch_states.shape[0]

    # Scale radius
    scaled_r = r * np.exp(L * t * rate)
    radii_tensor = scaled_r * torch.ones(B, device=device)

    # Count unverified intersections
    unverified_count = count_hypercube_intersections_wrap(
         unverified_centers, unverified_radii, batch_states, radii_tensor
    )
    # A row is eligible only if its inflated box touches no unverified box.
    eligible_mask = unverified_count == 0
    eligible_indices = torch.nonzero(eligible_mask, as_tuple=False).squeeze(1)

    # Filter current candidates
    query_states = batch_states[eligible_indices]
    query_radii = radii_tensor[eligible_indices]


    # Get verified intersections
    verified_lists = find_hypercube_intersections_wrap(
       verified_centers, verified_radii, query_states, query_radii,
    )

    # Flatten into (K_total,) for all (k, k_i) pairs
    k_offsets = torch.arange(len(verified_lists), device=device)
    row_counts = torch.tensor([v.numel() for v in verified_lists], device=device)
    # NOTE(review): when no pair exists at all, nothing is returned, whereas below
    # an eligible row with zero pairs is never marked failed and is returned.
    # Confirm that this asymmetry is intended.
    if row_counts.sum() == 0:
        return eligible_indices.new_empty(0)

    # k_idx: position in the eligible list for each pair; k_i_idx: verified box index.
    k_idx = torch.repeat_interleave(k_offsets, row_counts)
    k_i_idx = torch.cat(verified_lists, dim=0)


    # Gather data
    # NOTE(review): q_states and q_radii are not used further down.
    q_states = query_states[k_idx]                         # (K, D)
    q_radii = query_radii[k_idx]                           # (K,)

    v_centers = verified_centers[k_i_idx]                  # (K, D)
    v_radii = verified_radii[k_i_idx]                      # (K,)
    v_alphas = verified_alphas[k_i_idx]                    # (K,)
    v_indices = verified_indices[k_i_idx]                  # (K,)

    # Component 0 of every stored time sample for the row behind each pair.
    traj_k = trajectories[eligible_indices[k_idx], :, 0]   # (K, T+1)

    # Norm taken along the time axis of traj_k, minus the row's un-inflated radius.
    norm_traj = torch.norm(traj_k, dim=1) - r[eligible_indices][k_idx]         # (K,)
    term1 = torch.exp(-(v_alphas - alpha) * (v_indices) * rate)
    term2 = torch.exp(v_alphas * t * rate)
    numerator = (v_centers + v_radii.unsqueeze(1)).norm(dim=1)
    inequality = term1 * term2 * numerator / norm_traj     # (K,)

    # Per k: check if all inequality ≤ 1
    failed = inequality > 1
    failed_mask = torch.zeros(len(verified_lists), dtype=torch.bool, device=device)
    # Mark an eligible row as failed as soon as any of its pairs failed.
    failed_mask.index_put_((k_idx[failed],), torch.ones_like(k_idx[failed], dtype=torch.bool), accumulate=True)

    # Final: only those k where all passed
    passed_mask = ~failed_mask
    return eligible_indices[passed_mask]


In [12]:
from typing_extensions import final
import torch
import numpy as np
import itertools
import gc

class PendulumEnv:
    def __init__(self,
                 num_envs=10,    # Number of parallel environments
                 target = None,   # Allow custom target location
                 max_radius=5.0,         # Define boundary size
                 max_speed=1.0,           # Bound applied to actions (see step)
                 precision = 0.03,        # Define the allowed distance from the target
                 step_penalization = 2,   # How strongly each step taken is penalized
                 rate = 0.1,              # Rate of movement
                 field_function = inverted_pendulum_2d_torch,  # Allow different field functions
                 reward_type="distance",   # Allow different reward strategies
                 dimension = 2,
                 max_starting_speed = 1.0,
                 min_alpha = 0.01,
                 L = 5,
                 eval_func = False
                ):
        """
        Create a batch of ``num_envs`` environments that share one vector field.

        Args:
            num_envs: Number of environments simulated in parallel.
            target: Target state (anything ``torch.tensor`` accepts); defaults to
                the origin of a ``dimension``-dimensional space.
            max_radius: Half-width of the sampling hypercube and out-of-bounds
                distance used by ``step``.
            max_speed: Actions are clamped to [-max_speed, max_speed].
            precision: Distance to the target below which ``step`` reports done.
            step_penalization: Per-step penalty factor in the reward.
            rate: Step size of the explicit Euler update.
            field_function: Callable ``f(states, actions)`` returning state derivatives.
            reward_type: Stored on the instance; not read by the methods of this class.
            dimension: State dimension.
            max_starting_speed: Scale of the second coordinate in ``sample_points``.
            min_alpha: Stored as ``self.alpha`` (rate used in the alpha-distance).
            L: Stored as ``self.L`` (constant used in the radius growth term).
            eval_func: If truthy, trajectory scoring uses the weighted 2-norm
                instead of the infinity norm.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.num_envs = num_envs

        self.max_speed = max_speed
        self.max_radius = max_radius
        self.precision = precision
        self.step_penalization = step_penalization
        self.rate = rate
        self.field_function = field_function
        self.reward_type = reward_type
        self.dim = dimension
        self.f = field_function
        self.max_starting_speed = max_starting_speed
        self.alpha = min_alpha
        self.L = L
        self.eval_func = eval_func

        # Identity check: `target == None` is evaluated element-wise for array-like
        # targets (e.g. numpy arrays) and then fails in the `if`.
        if target is None:
            self.target = torch.zeros(self.dim, device=self.device)
        else:
            self.target = torch.tensor(target, device=self.device, dtype=torch.float32)

        # Initial states for all environments. The first draw is immediately
        # overwritten by the hypercube draw; it is kept only so that the random
        # number stream (and hence seeded runs) stays the same as before.
        self.states = self.sample_points_in_circle(self.num_envs).to(self.device)
        self.states = self.sample_points_in_hypercube(self.num_envs).to(self.device)

    def sample_points_in_circle(self, batch_size):
        """ Draw ``batch_size`` points uniformly from the ball of radius ``max_radius`` in ``self.dim`` dimensions. """

        # Gaussian samples, once normalised, give uniformly distributed directions.
        directions = torch.randn((batch_size, self.dim), device=self.device)
        directions = directions / torch.norm(directions, dim=1, keepdim=True)

        # Radii follow U^(1/dim) so that the density is uniform over the volume.
        scale = torch.rand((batch_size, 1), device=self.device) ** (1 / self.dim)

        return self.max_radius * directions * scale

    def sample_points(self, batch_size):
        """ Draw ``batch_size`` starting points, scaling coordinate 0 by ``max_radius`` and coordinate 1 by ``max_starting_speed``. """

        # Uniform draws in [0, 1) mapped to [-1, 1).
        unit_box = torch.rand((batch_size, self.dim), device=self.device) * 2 - 1

        # Two-entry scale vector, broadcast against the (batch_size, dim) samples.
        axis_scale = torch.tensor([self.max_radius, self.max_starting_speed], device=self.device)

        return unit_box * axis_scale

    def sample_points_in_hypercube(self, batch_size):
        """ Draw ``batch_size`` points uniformly from the hypercube [-max_radius, max_radius)^dim. """

        # Uniform draws in [0, 1) mapped to [-1, 1), then stretched to the box size.
        unit_box = torch.rand((batch_size, self.dim), device=self.device) * 2 - 1
        return self.max_radius * unit_box


    def reset(self, seed = None):
        """ Reset every environment: use ``seed`` as the new states when given, otherwise draw fresh states from the hypercube. """
        self.states = self.sample_points_in_hypercube(self.num_envs) if seed is None else seed


    def step(self, actions):
        """
        Advance all environments by one explicit Euler step.

        Args:
            actions: Per-environment actions; clamped to [-max_speed, max_speed]
                before being passed to the field function.

        Returns:
            (states, rewards, stop): a copy of the new states, the per-environment
            reward, and a boolean tensor that is True where the environment
            reached the target or left the allowed region.
        """
        actions = torch.clamp(actions, -self.max_speed, self.max_speed)

        movement = self.f(self.states, actions)
        # Out-of-place update: self.states may be the very tensor the caller
        # handed to reset(seed=...); an in-place += would silently modify it.
        self.states = self.states + movement * self.rate

        distances = torch.linalg.norm(self.states - self.target, dim=1, ord = float('inf'))
        rewards = -distances - self.step_penalization * self.rate

        # Out-of-bounds penalty (overrides the distance-based reward)
        out_of_bounds = distances > self.max_radius
        rewards[out_of_bounds] = -10000

        # Bonus for reaching the target
        done = distances < self.precision
        rewards[done] += 10000

        stop = torch.logical_or(done, out_of_bounds)

        return self.states.clone(), rewards, stop

    def sample_trajectory(self, time_steps=5, control_seed = None, variance: int = 1, num_samples: int = 1, r = None):
        """ Sample control sequences, roll them out, and keep the best one per environment.

        For every environment ``num_samples`` control sequences are drawn from a
        normal distribution (centred on ``control_seed`` when given, else on zero),
        clamped to [-max_speed, max_speed] and integrated with explicit Euler
        steps. Each rollout is scored by

            min_t [ dist(t) * exp(alpha * t) + r * exp((L + alpha) * t) ]

        and the sample with the smallest score is selected.

        Args:
            time_steps: Number of Euler steps per rollout.
            control_seed: Optional (num_envs, time_steps) tensor of mean controls.
            variance: Currently ignored; the spread is always ``max_speed / 3``.
            num_samples: Number of control sequences drawn per environment.
            r: Optional (num_envs,) tensor of radii; defaults to zeros.

        Returns:
            - Best control sequence per environment, (num_envs, time_steps)
            - Best trajectory per environment, (num_envs, time_steps+1, dim),
              with the angle component wrapped by ``wrap_to_pi``
            - Best (minimum) score per environment, (num_envs,)
            - Time index at which the best sample attains its score, (num_envs,)
        """

        # NOTE(review): this override makes the `variance` argument a no-op. It is
        # kept because callers currently get max_speed/3 whatever they pass.
        variance = self.max_speed/3

        # NOTE(review): t_eval has time_steps points spanning [0, time_steps*rate],
        # while it weights the states after steps 1..time_steps — confirm the grid.
        t_eval = torch.linspace(0.0, time_steps*self.rate, steps=time_steps, device=self.device)
        if r is None:
            # One radius per environment (a per-time-step default only had the
            # right shape when num_envs happened to equal time_steps).
            r = torch.zeros(self.num_envs, device=self.device)

        # One copy of every start state per sample: (num_envs * num_samples, dim)
        batch_states = self.states.clone().repeat_interleave(num_samples, dim=0)

        # Handle control seeding
        if control_seed is not None:
            assert control_seed.shape == (self.num_envs, time_steps), \
                "control_seed must have shape (num_envs, time_steps)"
            mu = control_seed.unsqueeze(1).repeat(1, num_samples, 1)  # Expand for num_samples
        else:
            mu = torch.zeros((self.num_envs, num_samples, time_steps), device=self.device)

        sigma = torch.ones_like(mu) * variance  # used as the standard deviation
        action_lists = torch.normal(mu, sigma)  # Sampled action sequences

        action_lists = action_lists.view(-1, time_steps)  # (num_envs * num_samples, time_steps)
        action_lists = torch.clamp(action_lists, -self.max_speed, self.max_speed)

        # Roll out all samples, recording every state
        trajectories = [batch_states.clone().unsqueeze(1)]  # Initial states
        for t in range(time_steps):
            actions = action_lists[:, t]
            movement = self.f(batch_states, actions)
            batch_states = batch_states + movement * self.rate
            trajectories.append(batch_states.clone().unsqueeze(1))

        # (num_envs * num_samples, time_steps+1, dim), angle wrapped
        trajectories = torch.cat(trajectories, dim=1)
        trajectories = wrap_to_pi(trajectories)

        # Distance to the target at every step after the initial state
        if self.eval_func:
            weights = torch.tensor([1.0, 0.01], device=self.device)
            dist = torch.sqrt(((trajectories - self.target)**2 * weights).sum(dim=2))[:, 1:]
        else:
            dist = torch.linalg.norm(trajectories - self.target, dim=2, ord = float('inf'))[:, 1:]

        r = torch.repeat_interleave(r, num_samples)
        firstpart = dist*torch.exp(self.alpha*t_eval)
        secondpart = torch.mul(r.unsqueeze(1), torch.exp(((self.L + self.alpha)*t_eval)).unsqueeze(0))
        alphadist = firstpart + secondpart

        alpha_distances = torch.min(alphadist, dim = 1).values
        taus = torch.argmin(alphadist, dim = 1)

        # Reshape to (num_envs, num_samples)
        alpha_distances = alpha_distances.view(self.num_envs, num_samples)
        taus = taus.view(self.num_envs, num_samples)

        # Best sample for each environment
        env_range = torch.arange(self.num_envs)
        best_indices = torch.argmin(alpha_distances, dim=1)
        taus = taus[env_range, best_indices]

        # Gather best actions and trajectories
        best_actions = action_lists.view(self.num_envs, num_samples, time_steps)[env_range, best_indices]
        best_trajectories = trajectories.view(self.num_envs, num_samples, time_steps+1, self.dim)[env_range, best_indices]

        return best_actions, best_trajectories, alpha_distances.min(dim=1).values, taus

    def sample_trajectory_reuse(self, time_steps=5, control_seed = None, variance: int = 1, num_samples: int = 1, r = None, centers = None, radii = None, verified_controls = None, verified_indices = None, splits = None):
        """ Sample control sequences with optional reuse of verified controls, and keep the best one per environment.

        Works like ``sample_trajectory``. In addition, when ``centers``, ``radii``,
        ``verified_controls`` and ``verified_indices`` are all given, a rollout
        whose current state falls inside one of the boxes (``find_hypercubes``)
        has the remainder of its control sequence replaced by that box's stored
        controls, provided the box's ``verified_indices`` entry is smaller than
        the number of remaining steps. Each rollout switches at most once.

        Args:
            time_steps: Number of Euler steps per rollout.
            control_seed: Optional (num_envs, time_steps) tensor of mean controls.
            variance: Spread passed to ``torch.normal`` as the standard deviation.
            num_samples: Number of control sequences drawn per environment.
            r: Optional (num_envs,) tensor of radii; defaults to zeros.
            centers, radii: Boxes passed to ``find_hypercubes``.
            verified_controls: Stored control sequences, one row per box.
            verified_indices: Per-box step counts compared with the remaining horizon.
            splits: Optional per-environment tensor; environments with
                ``splits < 5`` are treated as already switched. When omitted, no
                environment is treated as switched.

        Returns:
            - Best control sequence per environment, (num_envs, time_steps)
            - Best trajectory per environment, (num_envs, time_steps+1, dim),
              with the angle component wrapped by ``wrap_to_pi``
            - Best (minimum) score per environment, (num_envs,)
            - Time index at which the best sample attains its score, (num_envs,)
        """

        total = self.num_envs * num_samples

        # NOTE(review): t_eval has time_steps points spanning [0, time_steps*rate],
        # while it weights the states after steps 1..time_steps — confirm the grid.
        t_eval = torch.linspace(0.0, time_steps*self.rate, steps=time_steps, device=self.device)
        if r is None:
            # One radius per environment (a per-time-step default only had the
            # right shape when num_envs happened to equal time_steps).
            r = torch.zeros(self.num_envs, device=self.device)

        # One copy of every start state per sample: (num_envs * num_samples, dim)
        batch_states = self.states.clone().repeat_interleave(num_samples, dim=0)
        if splits is not None:
            has_switched = (splits < 5).repeat_interleave(num_samples, dim=0)
        else:
            # Without split information nobody has switched yet (previously this
            # left has_switched undefined and the reuse branch raised NameError).
            has_switched = torch.zeros(total, dtype=torch.bool, device=self.device)

        # Handle control seeding
        if control_seed is not None:
            assert control_seed.shape == (self.num_envs, time_steps), \
                "control_seed must have shape (num_envs, time_steps)"
            mu = control_seed.unsqueeze(1).repeat(1, num_samples, 1)  # Expand for num_samples
        else:
            mu = torch.zeros((self.num_envs, num_samples, time_steps), device=self.device)

        sigma = torch.ones_like(mu) * variance  # used as the standard deviation
        action_lists = torch.normal(mu, sigma)  # Sampled action sequences

        action_lists = action_lists.view(-1, time_steps)  # (num_envs * num_samples, time_steps)
        action_lists = torch.clamp(action_lists, -self.max_speed, self.max_speed)

        # Loop-invariant: is all the data needed for control reuse available?
        reuse_enabled = (centers is not None and radii is not None
                         and verified_controls is not None and verified_indices is not None)

        # Roll out all samples, recording every state
        trajectories = [batch_states.clone().unsqueeze(1)]  # Initial states
        for t in range(time_steps):
            actions = action_lists[:, t]
            movement = self.f(batch_states, actions)
            batch_states = batch_states + movement * self.rate
            trajectories.append(batch_states.clone().unsqueeze(1))
            if reuse_enabled:
                # Only rollouts that have not switched yet are looked up
                need_switch = ~has_switched

                # Box index per rollout; -1 for "no box" or "not looked up"
                idx_all = torch.full((total,), -1, dtype=torch.long, device=self.device)
                idx_all[need_switch] = find_hypercubes(batch_states[need_switch], centers, radii)

                # Eligible: inside a box whose stored horizon fits in the remaining steps.
                # (idx_all == -1 indexes the last box here but is masked out by idx_all > -1.)
                eligible = need_switch & (idx_all > -1) & (verified_indices[idx_all] < time_steps - t)
                valid_indices = torch.nonzero(eligible, as_tuple=True)[0]

                # Overwrite the remaining controls with the stored ones
                if valid_indices.numel() > 0 and (time_steps - t - 1) > 0:
                    action_lists[valid_indices, t+1:] = verified_controls[idx_all[valid_indices], :time_steps - t - 1]

                # Mark them as switched
                has_switched[valid_indices] = True

        # (num_envs * num_samples, time_steps+1, dim), angle wrapped
        trajectories = torch.cat(trajectories, dim=1)
        trajectories = wrap_to_pi(trajectories)

        # Distance to the target at every step after the initial state
        if self.eval_func:
            weights = torch.tensor([1.0, 0.01], device=self.device)
            dist = torch.sqrt(((trajectories - self.target)**2 * weights).sum(dim=2))[:, 1:]
        else:
            dist = torch.linalg.norm(trajectories - self.target, dim=2, ord = float('inf'))[:, 1:]

        r = torch.repeat_interleave(r, num_samples)
        firstpart = dist*torch.exp(self.alpha*t_eval)
        secondpart = torch.mul(r.unsqueeze(1), torch.exp(((self.L + self.alpha)*t_eval)).unsqueeze(0))
        alphadist = firstpart + secondpart

        alpha_distances = torch.min(alphadist, dim = 1).values
        taus = torch.argmin(alphadist, dim = 1)

        # Reshape to (num_envs, num_samples)
        alpha_distances = alpha_distances.view(self.num_envs, num_samples)
        taus = taus.view(self.num_envs, num_samples)

        # Best sample for each environment
        env_range = torch.arange(self.num_envs)
        best_indices = torch.argmin(alpha_distances, dim=1)
        taus = taus[env_range, best_indices]

        # Gather best actions and trajectories
        best_actions = action_lists.view(self.num_envs, num_samples, time_steps)[env_range, best_indices]
        best_trajectories = trajectories.view(self.num_envs, num_samples, time_steps+1, self.dim)[env_range, best_indices]

        return best_actions, best_trajectories, alpha_distances.min(dim=1).values, taus




    def sample_trajectory_reuse_new(self, time_steps=5, control_seed = None, variance: int = 1, num_samples: int = 1, r = None, centers = None, radii = None, verified_alphas = None, verified_controls = None, verified_indices = None, splits = None, unverified_centers = None, unverified_radii = None):
        """ Sample control sequences, stop early for environments certified through verified boxes, and keep the best sample per environment.

        ``num_samples`` control sequences per environment ("class") are drawn and
        integrated with explicit Euler steps. After every step, while the inflated
        radius is below twice the largest verified radius, ``check_parallel`` is
        asked which active rollouts pass the reuse check. The first time a class
        has a passing rollout, that rollout and step are recorded and all rollouts
        of the class stop being integrated. Classes that never succeed fall back
        to the score used by ``sample_trajectory``.

        Args:
            time_steps: Number of Euler steps per rollout.
            control_seed: Optional (num_envs, time_steps) tensor of mean controls.
            variance: Spread passed to ``torch.normal`` as the standard deviation.
            num_samples: Number of control sequences drawn per environment.
            r: Optional (num_envs,) tensor of radii; defaults to zeros.
            centers, radii, verified_alphas, verified_indices: Verified boxes and
                their per-box data, forwarded to ``check_parallel``.
            verified_controls: Only checked for presence here.
            splits: Accepted for signature compatibility; not used.
            unverified_centers, unverified_radii: Unverified boxes, forwarded to
                ``check_parallel``.

        Returns:
            - Best control sequence per environment, (num_envs, time_steps)
            - Best trajectory per environment, (num_envs, time_steps+1, dim),
              with the angle component wrapped by ``wrap_to_pi``
            - Minimum score per environment over all its samples, (num_envs,)
            - Per environment: the step at which it succeeded, or the time index
              of the best score for environments that did not, (num_envs,)
            - Boolean success flag per environment, (num_envs,)
        """

        total = self.num_envs * num_samples

        # NOTE(review): t_eval has time_steps points spanning [0, time_steps*rate],
        # while it weights the states after steps 1..time_steps — confirm the grid.
        t_eval = torch.linspace(0.0, time_steps*self.rate, steps=time_steps, device=self.device)
        if r is None:
            # One radius per environment (a per-time-step default only had the
            # right shape when num_envs happened to equal time_steps).
            r = torch.zeros(self.num_envs, device=self.device)
        r = torch.repeat_interleave(r, num_samples)

        # One copy of every start state per sample: (num_envs * num_samples, dim)
        batch_states = self.states.clone().repeat_interleave(num_samples, dim=0)

        # Rollout i belongs to class (environment) i // num_samples
        class_ids = torch.arange(total, device=self.device) // num_samples

        # Which classes have already succeeded
        class_success = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)

        # Which rollouts are still being integrated
        active_mask = torch.ones(total, dtype=torch.bool, device=self.device)
        taus = torch.zeros(self.num_envs, dtype=torch.long, device=self.device)
        best_indices = torch.zeros(self.num_envs, dtype=torch.long, device=self.device)

        # Handle control seeding
        if control_seed is not None:
            assert control_seed.shape == (self.num_envs, time_steps), \
                "control_seed must have shape (num_envs, time_steps)"
            mu = control_seed.unsqueeze(1).repeat(1, num_samples, 1)  # Expand for num_samples
        else:
            mu = torch.zeros((self.num_envs, num_samples, time_steps), device=self.device)

        sigma = torch.ones_like(mu) * variance  # used as the standard deviation
        action_lists = torch.normal(mu, sigma)  # Sampled action sequences

        action_lists = action_lists.view(-1, time_steps)  # (num_envs * num_samples, time_steps)
        action_lists = torch.clamp(action_lists, -self.max_speed, self.max_speed)

        # Loop-invariant reuse data. The None checks come first so that calling
        # without verified boxes no longer fails inside torch.max(radii).
        reuse_enabled = (centers is not None and radii is not None
                         and verified_controls is not None and verified_indices is not None)
        if reuse_enabled:
            smallest_r = torch.min(r)
            reuse_cutoff = 2*torch.max(radii)

        # Roll out, recording every state
        trajectories = [batch_states.clone().unsqueeze(1)]  # Initial states
        for t in range(time_steps):
            active_indices = active_mask.nonzero(as_tuple=True)[0]
            actions = action_lists[active_indices, t]
            movement = self.f(batch_states[active_indices], actions)
            batch_states[active_indices] = batch_states[active_indices] + movement * self.rate
            trajectories.append(batch_states.clone().unsqueeze(1))

            # self.L (not a bare global L) is the constant the check itself receives.
            if reuse_enabled and smallest_r*np.exp(self.L*t*self.rate) < reuse_cutoff:
                torch.cuda.empty_cache()
                gc.collect()
                # check_parallel indexes its trajectories argument with row indices
                # of the states it is given, so pass the active rows only.
                active_trajectories = torch.cat(trajectories, dim=1)[active_indices]
                idx = check_parallel(batch_states[active_indices], t, r[active_indices], self.alpha,
                                     self.L, self.rate, num_samples,
                                     unverified_centers, unverified_radii,
                                     centers, radii,
                                     verified_alphas, verified_indices, active_trajectories)

                global_success_indices = active_indices[idx]
                global_success_classes = class_ids[global_success_indices]

                # Classes that succeeded this step and had not succeeded before
                new_success_class_mask = ~class_success[global_success_classes]
                new_success_indices = global_success_indices[new_success_class_mask]
                new_success_classes = global_success_classes[new_success_class_mask]

                # Mark class as succeeded
                class_success[new_success_classes] = True

                # Remember one successful rollout (global index) and the step per class
                best_indices[new_success_classes] = new_success_indices
                taus[new_success_classes] = t

                # Deactivate all rollouts of succeeded classes
                active_mask = active_mask & (~class_success[class_ids])

        # Classes that never succeeded fall back to score-based selection
        mask = ~class_success

        # (num_envs * num_samples, time_steps+1, dim), angle wrapped
        trajectories = torch.cat(trajectories, dim=1)
        trajectories = wrap_to_pi(trajectories)

        # Distance to the target at every step after the initial state
        if self.eval_func:
            weights = torch.tensor([1.0, 0.01], device=self.device)
            dist = torch.sqrt(((trajectories - self.target)**2 * weights).sum(dim=2))[:, 1:]
        else:
            dist = torch.linalg.norm(trajectories - self.target, dim=2, ord = float('inf'))[:, 1:]

        firstpart = dist*torch.exp(self.alpha*t_eval)
        secondpart = torch.mul(r.unsqueeze(1), torch.exp(((self.L + self.alpha)*t_eval)).unsqueeze(0))
        alphadist = firstpart + secondpart

        alpha_distances = torch.min(alphadist, dim = 1).values
        newtaus = torch.argmin(alphadist, dim = 1)

        # Reshape to (num_envs, num_samples)
        alpha_distances = alpha_distances.view(self.num_envs, num_samples)
        newtaus = newtaus.view(self.num_envs, num_samples)

        # Unsuccessful classes: best sample by score. Successful classes hold a
        # global rollout index, which the modulo turns into a per-class sample index.
        env_range = torch.arange(self.num_envs)
        best_indices[mask] = torch.argmin(alpha_distances, dim=1)[mask]
        best_indices = best_indices % num_samples

        taus[mask] = newtaus[env_range, best_indices][mask]

        # Gather best actions and trajectories
        best_actions = action_lists.view(self.num_envs, num_samples, time_steps)[env_range, best_indices]
        best_trajectories = trajectories.view(self.num_envs, num_samples, time_steps+1, self.dim)[env_range, best_indices]

        return best_actions, best_trajectories, alpha_distances.min(dim=1).values, taus, class_success

    def trajectories(self, controls):
        """
        Roll the current states forward under the given controls, leaving an
        environment's state unchanged at every step where its control equals
        the module-level ``sentinel`` value.

        Parameters:
        - controls (torch.Tensor): Shape (num_envs, time_steps)

        Returns:
        - trajectories (torch.Tensor): Shape (num_envs, time_steps+1, dim),
          starting with the current states. ``self.states`` is not modified.
        """
        assert controls.shape[0] == self.num_envs, "Controls must match (num_envs, time_steps)"

        current = self.states.clone()              # (num_envs, dim)
        history = [current.clone().unsqueeze(1)]   # initial state first

        for step_idx in range(controls.shape[1]):
            step_controls = controls[:, step_idx]
            valid = step_controls != sentinel      # True where a real control is given

            # Work on a fresh copy so earlier snapshots are never altered.
            advanced = current.clone()
            if valid.any():
                advanced[valid] += self.f(current[valid], step_controls[valid]) * self.rate

            history.append(advanced.unsqueeze(1))
            current = advanced

        return torch.cat(history, dim=1)





In [13]:
from typing_extensions import final
import torch
import numpy as np
import itertools

class PendulumEnv_Old:
    """
    Batch of pendulum environments advanced in parallel with forward-Euler steps.

    Each environment owns one row of ``self.states``. The dynamics come from
    ``field_function`` and are integrated as
    ``state += field_function(state, action) * rate``.
    """

    def __init__(self,
                 num_envs=10,    # Number of parallel environments
                 target = None,   # Allow custom target location
                 max_radius=5.0,         # Define boundary size
                 max_speed=1.0,           # Limit agent's velocity
                 precision = 0.03,        # Define the allowed distance from the target
                 step_penalization = 2,   # How strongly each step taken is penalized
                 rate = 0.1,              # Rate of movement
                 field_function = inverted_pendulum_2d_torch,  # Allow different field functions
                 reward_type="distance",   # Allow different reward strategies
                 dimension = 2,
                 max_starting_speed = 1.0,
                 min_alpha = 0.01,
                 L = 5,
                 eval_func = False
                ):
        """
        Store the configuration and draw an initial state for every environment.

        ``target`` defaults to the origin of the ``dimension``-dimensional state
        space. ``min_alpha`` and ``L`` are the exponents used by
        ``sample_trajectory`` when scoring candidates, and ``eval_func`` selects
        the weighted Euclidean score instead of the infinity norm there.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.num_envs = num_envs

        self.max_speed = max_speed
        self.max_radius = max_radius
        self.precision = precision
        self.step_penalization = step_penalization
        self.rate = rate
        self.field_function = field_function
        self.reward_type = reward_type
        self.dim = dimension
        self.f = field_function  # short alias used by the integration loops
        self.max_starting_speed = max_starting_speed
        self.alpha = min_alpha
        self.L = L
        self.eval_func = eval_func

        # Identity check: `== None` on an array-like target is elementwise/ambiguous.
        if target is None:
            self.target = torch.zeros(self.dim, device=self.device)
        else:
            self.target = torch.tensor(target, device=self.device, dtype=torch.float32)

        # The hypercube sample is the one that is kept. The circle sample is
        # overwritten immediately; it is retained only so that construction
        # consumes the same random numbers as before.
        self.states = self.sample_points_in_circle(self.num_envs).to(self.device)
        self.states = self.sample_points_in_hypercube(self.num_envs).to(self.device)

    def sample_points_in_circle(self, batch_size):
        """ Sample random starting points within a ball of radius max_radius. """

        # Step 1: Sample from normal distribution (for uniform direction)
        points = torch.randn((batch_size, self.dim), device=self.device)

        # Step 2: Normalize to lie on the unit n-sphere
        points /= torch.norm(points, dim=1, keepdim=True)

        # Step 3: Sample radii with correct volume scaling
        radii = torch.rand((batch_size, 1), device=self.device) ** (1 / self.dim)  # Ensure uniform density

        # Step 4: Scale points by radii
        return self.max_radius * points * radii

    def sample_points(self, batch_size):
        """
        Sample starting points with per-coordinate ranges.

        Coordinates are drawn from U(-1, 1) and scaled by
        ``[max_radius, max_starting_speed]``; the two-element scale means this
        only broadcasts for a two-dimensional state.
        """

        # Sample each coordinate independently from U(-1, 1)
        points = (torch.rand((batch_size, self.dim), device=self.device) * 2 - 1)  # Rescale [0,1] to [-1,1]

        points = points * torch.tensor([self.max_radius, self.max_starting_speed], device=self.device)

        return points

    def sample_points_in_hypercube(self, batch_size):
        """ Sample random starting points uniformly within a hypercube of half-width max_radius. """

        # Sample each coordinate independently from U(-1, 1)
        points = (torch.rand((batch_size, self.dim), device=self.device) * 2 - 1)  # Rescale [0,1] to [-1,1]

        # Scale by max_radius
        return self.max_radius * points

    def reset(self, seed = None):
        """
        Reset all environments in parallel.

        With ``seed`` given, it becomes the new state tensor as-is (no copy is
        made, and ``step`` updates ``self.states`` in place); otherwise fresh
        states are drawn from the hypercube sampler.
        """
        if seed is not None:
            self.states = seed
        else:
            self.states = self.sample_points_in_hypercube(self.num_envs)

    def step(self, actions):
        """
        Advance every environment by one Euler step.

        Actions are clamped to ``[-max_speed, max_speed]``. The reward is the
        negative infinity-norm distance to the target minus
        ``step_penalization * rate``; it is replaced by -10000 when the
        distance exceeds ``max_radius`` and increased by 10000 when the
        distance is below ``precision``.

        Returns:
        - copy of the new states
        - rewards, one per environment
        - stop flags (target reached or out of bounds)
        """
        actions = torch.clamp(actions, -self.max_speed, self.max_speed)

        movement = self.f(self.states, actions)
        self.states += movement * self.rate  # Update positions

        distances = torch.linalg.norm(self.states - self.target, dim=1, ord=float('inf'))
        rewards = -distances - self.step_penalization * self.rate

        # Out-of-bounds penalty
        out_of_bounds = distances > self.max_radius
        rewards[out_of_bounds] = -10000

        # Check if agents have reached the target
        done = distances < self.precision
        rewards[done] += 10000  # Bonus for reaching target

        stop = torch.logical_or(done, out_of_bounds)

        return self.states.clone(), rewards, stop

    def sample_trajectory(self, time_steps=5, control_seed = None, variance: int = 1, num_samples: int = 1, r = None):
        """ Sample control sequences for every environment and keep the best-scoring one.

        For each environment ``num_samples`` control sequences are drawn from a
        normal distribution centred on ``control_seed`` (or on zero) and clamped
        to ``[-max_speed, max_speed]``. ``variance`` is passed to
        ``torch.normal`` as the standard deviation. Each rollout is scored by
        the minimum over time of
        ``dist(t) * exp(alpha * t) + r * exp((L + alpha) * t)``.

        Parameters:
        - time_steps: number of Euler steps per rollout.
        - control_seed: optional tensor of shape (num_envs, time_steps).
        - variance: standard deviation of the sampling noise.
        - num_samples: candidates per environment.
        - r: optional tensor with one value per environment (num_envs,);
          defaults to zeros.

        Returns:
        - Best control sequence (num_envs, time_steps) -> **Direct input for `trajectories`**
        - Best trajectory per environment (num_envs, time_steps+1, dim)
        - Best score per environment (num_envs,)
        - Time index at which the best rollout attains its score (num_envs,)
        """
        # NOTE(review): this grid has `time_steps` points spread over
        # [0, time_steps*rate], whereas the distances it is paired with belong
        # to the states after steps 1..time_steps. Left unchanged because
        # altering it would change every score; confirm whether
        # linspace(rate, time_steps*rate, time_steps) was intended.
        t_eval = torch.linspace(0.0, time_steps*self.rate, steps=time_steps, device=self.device)
        if r is None:
            # One value per environment; the previous default had one value per
            # time step and only broadcast when num_envs == time_steps.
            r = torch.zeros(self.num_envs, device=self.device)

        # One copy of each start state per candidate: (num_envs * num_samples, dim)
        batch_states = self.states.clone().repeat_interleave(num_samples, dim=0)

        # Handle control seeding
        if control_seed is not None:
            assert control_seed.shape == (self.num_envs, time_steps), \
                "control_seed must have shape (num_envs, time_steps)"
            mu = control_seed.unsqueeze(1).repeat(1, num_samples, 1)  # Expand for num_samples
        else:
            mu = torch.zeros((self.num_envs, num_samples, time_steps), device=self.device)

        sigma = torch.ones_like(mu) * variance
        action_lists = torch.normal(mu, sigma)  # Sampled action sequences

        action_lists = action_lists.view(-1, time_steps)  # (num_envs * num_samples, time_steps)
        action_lists = torch.clamp(action_lists, -self.max_speed, self.max_speed)

        # Track all trajectories
        trajectories = [batch_states.clone().unsqueeze(1)]  # Initial states
        for t in range(time_steps):
            actions = action_lists[:, t]
            movement = self.f(batch_states, actions)
            batch_states = batch_states + movement * self.rate
            trajectories.append(batch_states.clone().unsqueeze(1))

        # (num_envs * num_samples, time_steps+1, dim)
        trajectories = torch.cat(trajectories, dim=1)

        # wrap_to_pi is defined elsewhere in the notebook.
        trajectories = wrap_to_pi(trajectories)

        # Distance to the target at every step after the initial state.
        if self.eval_func:
            weights = torch.tensor([1.0, 0.01], device=self.device)
            dist = torch.sqrt(((trajectories - self.target)**2 * weights).sum(dim=2))[:, 1:]
        else:
            dist = torch.linalg.norm(trajectories - self.target, dim=2, ord=float('inf'))[:, 1:]

        r = torch.repeat_interleave(r, num_samples)
        firstpart = dist*torch.exp(self.alpha*t_eval)
        secondpart = torch.mul(r.unsqueeze(1), torch.exp(((self.L + self.alpha)*t_eval)).unsqueeze(0))
        alphadist = firstpart + secondpart

        alpha_distances = torch.min(alphadist, dim=1).values
        taus = torch.argmin(alphadist, dim=1)

        # Reshape to (num_envs, num_samples)
        alpha_distances = alpha_distances.view(self.num_envs, num_samples)
        taus = taus.view(self.num_envs, num_samples)

        # Find the best candidate for each environment
        env_index = torch.arange(self.num_envs)
        best_indices = torch.argmin(alpha_distances, dim=1)
        taus = taus[env_index, best_indices]

        # Gather best actions and trajectories
        best_actions = action_lists.view(self.num_envs, num_samples, time_steps)[env_index, best_indices]
        best_trajectories = trajectories.view(self.num_envs, num_samples, time_steps+1, self.dim)[env_index, best_indices]

        return best_actions, best_trajectories, alpha_distances.min(dim=1).values, taus

    def trajectories(self, controls):
        """
        Run forward only on environments where controls are not sentinel for 'None'.

        Environments whose control equals the module-level ``sentinel`` at a
        given step keep their state for that step.

        Parameters:
        - controls (torch.Tensor): Shape (num_envs, time_steps)

        Returns:
        - trajectories (torch.Tensor): Shape (num_envs, time_steps+1, dim)
        """
        assert controls.shape[0] == self.num_envs, "Controls must match (num_envs, time_steps)"
        time_steps = controls.shape[1]
        batch_states = self.states.clone()
        trajectories = [batch_states.clone().unsqueeze(1)]  # start with initial state

        for t in range(time_steps):
            actions = controls[:, t]
            mask = actions != sentinel  # True where control is valid

            next_states = batch_states.clone()  # Clone before modification
            if mask.any():
                movement = self.f(batch_states[mask], actions[mask])  # only apply f to valid envs
                next_states[mask] += movement * self.rate

            trajectories.append(next_states.unsqueeze(1))
            batch_states = next_states

        return torch.cat(trajectories, dim=1)





In [14]:
import torch

def findpath_pendulum(env,
             seeds=None,  # Now handles multiple starting points (batch)
             countermax: int = 1,
             num_samples: int = 1000,
             variance: int = 1,
             time_steps: int = 20,
             control_seed=None,
             r = None):
    """
    Random-search refinement of one control sequence per environment.

    Each round resets ``env.states`` to ``seeds``, asks
    ``env.sample_trajectory`` for a candidate per environment (seeded with the
    best controls so far) and keeps the candidate wherever its score is
    strictly lower than the best score seen. The search ends once every
    environment has gone ``countermax`` consecutive rounds without improving.

    Parameters:
    - env: object exposing ``device``, ``num_envs``, ``states``,
      ``sample_points_in_hypercube`` and ``sample_trajectory``
    - seeds: start states, one row per environment; sampled from the env if omitted
    - countermax: rounds without improvement before an environment is considered done
    - num_samples: candidates per environment and round (forwarded to the env)
    - variance: sampling noise (forwarded to the env)
    - time_steps: number of steps per trajectory
    - control_seed: initial control guess handed to the first round
    - r: forwarded unchanged to ``env.sample_trajectory``

    Returns:
    - Best control sequence per environment
    - Number of rounds executed (same value for every environment)
    - The ``taus`` entry reported alongside each environment's best candidate
    """
    device = env.device
    num_envs = env.num_envs

    if seeds is None:
        seeds = env.sample_points_in_hypercube(num_envs).to(device)

    env.states = seeds.clone()

    best_controls = control_seed
    best_score = torch.full((num_envs,), float('inf'), device=device)
    stall_count = torch.zeros(num_envs, device=device)
    round_count = torch.zeros(num_envs, dtype=torch.int32, device=device)
    best_taus = torch.zeros(num_envs, dtype=torch.int32, device=device)

    while (stall_count < countermax).any():
        env.states = seeds.clone()  # every round starts from the same states
        cand_controls, _cand_trajectories, cand_score, cand_taus = env.sample_trajectory(
            time_steps=time_steps,
            control_seed=best_controls,
            num_samples=num_samples,
            variance=variance,
            r=r,
        )

        better = cand_score < best_score
        best_score[better] = cand_score[better]

        # Before the first round there may be no incumbent; the candidate stands in for it.
        incumbent = cand_controls if best_controls is None else best_controls
        best_controls = torch.where(better.unsqueeze(-1), cand_controls, incumbent)
        best_taus = torch.where(better, cand_taus, best_taus)

        # Improvement resets the stall counter, anything else extends it.
        stall_count = torch.where(better, torch.zeros_like(stall_count), stall_count + 1)

        round_count += 1

    env.states = seeds.clone()  # leave the env at the start states
    return best_controls, round_count, best_taus


In [15]:
import torch

def findpath_pendulum_reuse(env,
             seeds=None,  # Now handles multiple starting points (batch)
             countermax: int = 1,
             num_samples: int = 1000,
             variance: int = 1,
             time_steps: int = 20,
             control_seed=None,
             r = None,
             centers = None,
             radii = None,
             verified_alphas = None,
             verified_controls = None,
             verified_indices = None,
             splits = None,
             unverified_centers = None,
             unverified_radii = None):
    """
    Variant of ``findpath_pendulum`` that can use the env's reuse-aware sampler.

    When ``centers``, ``radii``, ``verified_controls``, ``verified_indices`` and
    ``splits`` are all provided, each round calls
    ``env.sample_trajectory_reuse_new`` and forwards the verified/unverified
    cell data to it; otherwise it falls back to ``env.sample_trajectory``.
    A candidate replaces the incumbent wherever its score is strictly lower,
    and the search ends once every environment has gone ``countermax``
    consecutive rounds without improving.

    Parameters:
    - env: object exposing ``device``, ``num_envs``, ``states``,
      ``sample_points_in_hypercube`` and the sampler(s) named above
    - seeds: start states, one row per environment; sampled from the env if omitted
    - countermax: rounds without improvement before an environment is considered done
    - num_samples, variance, time_steps, r: forwarded to the sampler
    - control_seed: initial control guess handed to the first round
    - centers, radii, verified_alphas, verified_controls, verified_indices,
      splits, unverified_centers, unverified_radii: forwarded unchanged to
      ``env.sample_trajectory_reuse_new``

    Returns:
    - Best control sequence per environment
    - Number of rounds executed (same value for every environment)
    - The ``taus`` entry reported alongside each environment's best candidate
    - ``reusing``: the fifth value returned by the reuse-aware sampler in the
      last round, or an all-False boolean mask of length ``num_envs`` when
      that sampler was never called
    """

    device = env.device
    num_envs = env.num_envs

    # If no seed is provided, generate random ones
    if seeds is None:
        seeds = env.sample_points_in_hypercube(num_envs).to(device)

    env.states = seeds.clone()  # Reset all environments to their start positions

    winner = control_seed
    endval = torch.full((num_envs,), float('inf'), device=device)  # Best score so far
    counter = torch.zeros(num_envs, device=device)  # Rounds without improvement
    iters = torch.zeros(num_envs, dtype=torch.int32, device=device)
    realtaus = torch.zeros(num_envs, dtype=torch.int32, device=device)

    # Always defined, so the fallback sampler (or a loop that never runs) can
    # still return a value: "nothing was reused".
    reusing = torch.zeros(num_envs, dtype=torch.bool, device=device)

    can_reuse = all(
        item is not None
        for item in (centers, radii, verified_controls, verified_indices, splits)
    )

    while torch.any(counter < countermax):  # Stop when all envs meet countermax
        env.states = seeds.clone()  # Reset start positions
        if can_reuse:
            # NOTE(review): `reusing` reflects the most recent round, while
            # `winner` may stem from an earlier one - confirm that the flag is
            # identical across rounds for a given environment.
            sol_actions, sol_trajectories, sol_distances, taus, reusing = env.sample_trajectory_reuse_new(
                time_steps=time_steps, control_seed=winner, num_samples=num_samples, variance=variance, r=r,
                centers=centers, radii=radii, verified_alphas=verified_alphas,
                verified_controls=verified_controls, verified_indices=verified_indices, splits=splits,
                unverified_centers=unverified_centers, unverified_radii=unverified_radii
            )
        else:
            sol_actions, sol_trajectories, sol_distances, taus = env.sample_trajectory(
                time_steps=time_steps, control_seed=winner, num_samples=num_samples, variance=variance, r=r
            )

        # Update control sequences based on improvement
        improved = sol_distances < endval
        endval[improved] = sol_distances[improved]
        winner = torch.where(improved.unsqueeze(-1), sol_actions, winner if winner is not None else sol_actions)
        realtaus = torch.where(improved, taus, realtaus)

        # Update counter: If no improvement, increment; else, reset
        counter[improved] = 0
        counter[~improved] += 1

        iters += 1  # Count total iterations

    env.states = seeds.clone()  # Reset one final time
    return winner, iters, realtaus, reusing


In [16]:
import torch

def findpath_pendulum_RL(env,
             seeds=None,  # Now handles multiple starting points (batch)
             countermax: int = 1,
             num_samples: int = 1000,
             variance: int = 1,
             time_steps: int = 20,
             control_seed=None,
             r = None,
             model = None):
    """
    Roll out a trained policy from each seed and record its actions and states.

    For every seed the environment is reset, ``env.unwrapped.state`` is
    overwritten with the seed, and ``model.predict`` is queried for
    ``time_steps`` consecutive steps.

    Parameters:
    - env: environment exposing ``reset``, ``step`` (returning five values),
      ``unwrapped.state`` and ``unwrapped._get_obs``
    - seeds: tensor of start states, one row per rollout. If omitted, the env
      must also provide ``sample_points_in_hypercube`` and ``num_envs``.
    - time_steps: number of policy steps per rollout
    - model: policy object with ``predict(obs) -> (action, extra)``; required
    - countermax, num_samples, variance, control_seed, r: unused here; kept so
      the signature matches the other ``findpath_pendulum*`` functions

    Returns:
    - Tensor of the actions taken, one row per seed (trailing singleton dim squeezed)
    - Tensor of visited states, ``time_steps + 1`` entries per seed

    Raises:
    - ValueError: if ``model`` is None

    Relies on the notebook-level ``device`` for the output tensors.
    """

    if model is None:
        raise ValueError("No model provided")
    # If no seed is provided, generate random ones
    if seeds is None:
        # Previously referenced an undefined `num_envs`; take it from the env.
        seeds = env.sample_points_in_hypercube(env.num_envs).to(device)

    winners = []
    trajectories = []
    for seed in seeds:
        env.reset()
        winner = []
        # Force the start state, then re-read the observation that matches it.
        env.unwrapped.state = seed.cpu().numpy()
        obs = env.unwrapped._get_obs()
        traj = [env.unwrapped.state]
        for _ in range(time_steps):
            action, _ = model.predict(obs)
            winner.append(action)
            obs, _reward, _, _, _ = env.step(action)
            traj.append(env.unwrapped.state)
        winners.append(winner)
        trajectories.append(traj)

    winners = torch.tensor(winners, device = device).squeeze(-1)
    trajectories = torch.tensor(trajectories, device = device)
    return winners, trajectories


In [17]:
import gc
import math
import itertools
import time

def pendulum_algo(d: int = 2, R: float = 2, epsilon: float = 0.01, L: float = 1.8, tau: float = 2, min_alpha: float = 0, batch_size: int = 500, dt: float = 0.1, speed: float = 2, max_splits: int = 3, num_samples: int = 1000, function = None, target = None, reuse = False, wrapper: bool = True) -> list:
  """
  Verify a layered grid of cells around the origin, splitting cells that fail.

  The initial grid has ``ceil(log3(R/epsilon))`` layers; layer ``i`` consists of
  the ``3**d - 1`` cells with centre coordinates in ``{-c, 0, c}``
  (``c = 2*epsilon*3**i``, all-zero centre excluded) and radius
  ``epsilon*3**i``. Cells are processed in batches of ``batch_size``: a control
  sequence is searched for each cell centre, the resulting trajectory is passed
  to ``search_alpha_parallel``, and a cell is accepted when the returned alpha
  exceeds ``min_alpha``. Rejected cells are replaced by ``3**d`` children of one
  third the radius. Cells are discarded when they lie outside the ``R`` box,
  outside ``[-pi, pi]`` in the first coordinate, or have reached
  ``max_splits`` splits.

  Parameters:
  - d: state dimension
  - R: half-width of the region of interest (infinity norm)
  - epsilon: radius of the innermost cell, which is treated as verified up front
  - L, min_alpha: forwarded to the environment and to ``search_alpha_parallel``
  - tau, dt: horizon and step size; ``int(tau/dt)`` controls per trajectory
  - batch_size: number of cells handled per iteration
  - speed: control bound (``max_speed`` of the environment)
  - max_splits: cells whose split count reaches this value are dropped
  - num_samples: candidates per cell in the control search
  - function: dynamics passed to ``PendulumEnv`` as ``field_function``
  - target: currently unused beyond defaulting to the origin
  - reuse: use ``findpath_pendulum_reuse`` with the verified cells found so far
  - wrapper: drop initial cells beyond ``[-pi, pi]`` in the first coordinate

  Returns:
  - (verified_points, verified_radius, verified_controls, verified_alphas,
    verified_indices); entry 0 of each is the innermost cell.

  Depends on ``PendulumEnv``, ``findpath_pendulum``, ``findpath_pendulum_reuse``,
  ``wrap_to_pi`` and ``search_alpha_parallel`` defined elsewhere in the notebook.
  """

  if target is None:
    target = torch.zeros(d)
  start = time.time()

  # Initialize the starting grid

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Row-wise cartesian product: one product per layer / per rejected cell.
  prod_fun = torch.vmap(torch.cartesian_prod)

  num_layers = math.ceil(math.log(R/epsilon, 3))

  points = torch.tensor([round(2*epsilon*3**i, 4) for i in range(num_layers)], device = device)
  points = torch.vstack([-points, torch.zeros_like(points), points]).transpose(0,1).reshape(-1, 3)
  # d-fold product of the per-axis offsets (was hard-coded for d = 2..6).
  points = prod_fun(*([points] * d)).reshape(-1, d)
  points = points[torch.where((points != 0).any(dim = 1))].to(device)  # drop the all-zero centre of every layer

  radius = [[round(epsilon*3**i, 4)]* (3**d - 1) for i in range (num_layers)]
  radius = list(itertools.chain.from_iterable(radius))
  radius = torch.tensor(radius, device = device)
  splits = torch.zeros_like(radius)
  controls = torch.zeros(len(points), int(tau/dt)).to(device)

  # The innermost cell seeds the verified set.
  verified_points = torch.tensor([np.zeros(d)]).to(device)
  verified_radius = torch.tensor([epsilon]).to(device)
  verified_controls = controls[0].unsqueeze(0).to(device)
  verified_alphas = torch.tensor([0]).to(device)
  verified_indices = torch.tensor([1]).to(device)

  if wrapper:
    loc_mask = torch.abs(points[:, 0]) - radius <= torch.pi
    points = points[loc_mask]
    radius = radius[loc_mask]
    controls = controls[loc_mask]
    splits = splits[loc_mask]

  t_eval = torch.linspace(0.0, tau, steps=int(tau/dt + 1), device = device)

  while len(points) > 0:
    print('Points:' + str(len(points)))
    # Pop the next batch off the front of the work queue.
    split_points = points[:batch_size]
    split_radius = radius[:batch_size]
    split_controls = controls[:batch_size]
    split_splits = splits[:batch_size]
    env = PendulumEnv(num_envs = len(split_points), dimension = d, rate = dt, max_speed = speed, min_alpha = min_alpha, L = L, eval_func = False, field_function = function)
    points = points[batch_size:]
    radius = radius[batch_size:]
    controls = controls[batch_size:]
    splits = splits[batch_size:]
    env.reset(split_points)
    if reuse:
      # The reuse keywords are only understood by findpath_pendulum_reuse, and
      # torch.cat takes a sequence of tensors.
      winner, _, taus, _reusing = findpath_pendulum_reuse(env = env, seeds = split_points, time_steps = int(tau/dt), control_seed = split_controls, num_samples = num_samples, r = split_radius, centers = verified_points, radii = verified_radius, verified_alphas = verified_alphas, verified_controls=verified_controls, verified_indices=verified_indices, splits = split_splits, unverified_centers = torch.cat((points, split_points)), unverified_radii = torch.cat((radius, split_radius)))
    else:
      winner, _, taus = findpath_pendulum(env = env, seeds = split_points, time_steps = int(tau/dt), control_seed = split_controls, num_samples = num_samples, r = split_radius)
    env.reset(split_points)
    sol = env.trajectories(winner)
    sol = torch.linalg.norm(wrap_to_pi(sol), dim = 2, ord = float('inf'))

    alpha, indices = search_alpha_parallel(sol, torch.linalg.norm(split_points, dim = 1, ord = float('inf')), split_radius, t_eval, L)
    indices = indices + 1
    mask = alpha > min_alpha
    verified_points = torch.cat((verified_points, split_points[mask]), dim = 0)
    verified_radius = torch.cat((verified_radius, split_radius[mask]), dim = 0)
    verified_controls = torch.cat((verified_controls, winner[mask]), dim = 0)
    verified_alphas = torch.cat((verified_alphas, alpha[mask]), dim = 0)
    verified_indices = torch.cat((verified_indices, indices[mask]), dim = 0)
    if len(split_radius[~mask]) > 0:
        # Child centres sit at offsets {-2/3 r, 0, +2/3 r} along every axis.
        temp_combinations = torch.vstack([-2/3*split_radius[~mask], torch.zeros_like(split_radius[~mask]), 2/3*split_radius[~mask]]).transpose(0,1).reshape(-1, 3)
        temp_combinations = prod_fun(*([temp_combinations] * d)).reshape(-1, d)
        new_points = split_points[~mask].repeat_interleave(3**d, dim = 0) + temp_combinations
        new_radius = split_radius[~mask].repeat_interleave(3**d, dim = 0)*1/3
        new_splits = split_splits[~mask].repeat_interleave(3**d, dim = 0) + 1
        new_controls = winner[~mask].repeat_interleave(3**d, dim=0)  # children start from the parent's controls

        points = torch.cat((points, new_points), 0)
        radius = torch.cat((radius, new_radius), 0)
        controls = torch.cat((controls, new_controls), 0)
        splits = torch.cat((splits, new_splits), 0)
        mask1 = torch.linalg.norm(points, dim = 1, ord = float('inf')) - radius <= R
        mask2 = torch.abs(points[:, 0]) - radius <= torch.pi
        mask3 = splits < max_splits
        mask = mask1 & mask2 & mask3
        points = points[mask]
        radius = radius[mask]
        controls = controls[mask]
        splits = splits[mask]

  print(time.time() - start)

  return verified_points, verified_radius, verified_controls, verified_alphas, verified_indices



In [18]:
import gc
import math
import itertools
import time

def pendulum_algo_reuse(d: int = 2, R: float = 2, epsilon: float = 0.01, L: float = 1.8, tau: float = 2, min_alpha: float = 0, batch_size: int = 500, dt: float = 0.1, speed: float = 2, max_splits: int = 3, num_samples: int = 1000, function = None, target = None, reuse = False, wrapper: bool = True, savetime: int = 1800, cont: bool = False, conttime: int = None, contlist = None) -> list:
  """
  Grid verification with optional reuse of verified cells and periodic checkpoints.

  Builds the same layered grid as ``pendulum_algo`` and processes it in batches.
  For every cell a control sequence is searched; cells flagged as ``reusing`` by
  ``findpath_pendulum_reuse`` are assigned ``alpha = min_alpha`` and the ``taus``
  index, all other cells get their alpha and index from
  ``search_alpha_parallel``. Cells with ``alpha >= min_alpha`` are accepted; the
  rest are split into ``3**d`` children of one third the radius. Cells filtered
  out after a split are appended to the unverified lists whenever at least one
  cell in the queue has reached ``max_splits``.

  Parameters:
  - d, R, epsilon, L, tau, min_alpha, batch_size, dt, speed, max_splits,
    num_samples, function, target, wrapper: as in ``pendulum_algo``
  - reuse: use ``findpath_pendulum_reuse`` with the verified cells found so far
  - savetime: seconds between checkpoints written to ``timesave_<seconds>.pkl``
  - cont: resume from a previous state
  - conttime: with ``cont``, load ``timesave_<int(conttime)>.pkl``
  - contlist: with ``cont`` and no ``conttime``, a 7-tuple
    (verified_points, verified_radius, verified_controls, verified_alphas,
    verified_indices, points, radius)

  Returns:
  - (verified_points, verified_radius, verified_controls, verified_alphas,
    verified_indices, unverified_points, unverified_radius); the last two are
    None if nothing was recorded as unverified.

  Depends on ``PendulumEnv``, ``findpath_pendulum``, ``findpath_pendulum_reuse``,
  ``wrap_to_pi`` and ``search_alpha_parallel`` defined elsewhere in the notebook.
  """
  import pickle  # stdlib; imported here so checkpointing does not rely on another cell

  if target is None:
    target = torch.zeros(d)
  start = time.time()
  startreset = time.time()
  resetcount = 1

  # Initialize the starting grid

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Row-wise cartesian product: one product per layer / per rejected cell.
  prod_fun = torch.vmap(torch.cartesian_prod)

  num_layers = math.ceil(math.log(R/epsilon, 3))

  points = torch.tensor([round(2*epsilon*3**i, 4) for i in range(num_layers)], device = device)
  points = torch.vstack([-points, torch.zeros_like(points), points]).transpose(0,1).reshape(-1, 3)
  # d-fold product of the per-axis offsets (was hard-coded for d = 2..6).
  points = prod_fun(*([points] * d)).reshape(-1, d)
  points = points[torch.where((points != 0).any(dim = 1))].to(device)  # drop the all-zero centre of every layer

  radius = [[round(epsilon*3**i, 4)]* (3**d - 1) for i in range (num_layers)]
  radius = list(itertools.chain.from_iterable(radius))
  radius = torch.tensor(radius, device = device)
  splits = torch.zeros_like(radius)
  controls = torch.zeros(len(points), int(tau/dt)).to(device)

  # The innermost cell seeds the verified set.
  verified_points = torch.tensor([np.zeros(d)]).to(device)
  verified_radius = torch.tensor([epsilon]).to(device)
  verified_controls = controls[0].unsqueeze(0).to(device)
  verified_alphas = torch.tensor([0]).to(device)
  verified_indices = torch.tensor([1]).to(device)
  unverified_points = None
  unverified_radius = None

  if wrapper:
    loc_mask = torch.abs(points[:, 0]) - radius <= torch.pi
    points = points[loc_mask]
    radius = radius[loc_mask]
    controls = controls[loc_mask]
    splits = splits[loc_mask]

  if cont:
    if conttime is not None:
      name = 'timesave_' + str(int(conttime))
      # Only load checkpoints written by this function: unpickling executes code.
      with open(name + ".pkl", "rb") as f:
        verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, points, radius, controls, splits, unverified_points, unverified_radius = pickle.load(f)
      # Integer counter so later checkpoint names stay 'timesave_<int>' and
      # match the name pattern used for loading above.
      resetcount = int(conttime // savetime) + 1
    elif contlist is not None:
      verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, points, radius = contlist
      controls = torch.zeros(len(points), int(tau/dt)).to(device)
      splits = torch.zeros_like(radius)

  t_eval = torch.linspace(0.0, tau, steps=int(tau/dt + 1), device = device)

  while len(points) > 0:
    if time.time() - startreset > savetime:
      name = 'timesave_' + str(resetcount * savetime)
      with open(name + ".pkl", "wb") as f:
        pickle.dump([verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, points, radius, controls, splits, unverified_points, unverified_radius], f)
      # NOTE(review): `files` is not defined in this function; presumably
      # `google.colab.files` imported in another cell - confirm.
      files.download(name + ".pkl")
      startreset = time.time()
      resetcount += 1
    print('Points:' + str(len(points)))
    # Pop the next batch off the front of the work queue.
    split_points = points[:batch_size]
    split_radius = radius[:batch_size]
    split_controls = controls[:batch_size]
    split_splits = splits[:batch_size]
    env = PendulumEnv(num_envs = len(split_points), dimension = d, rate = dt, max_speed = speed, min_alpha = min_alpha, L = L, eval_func = False, field_function = function)
    points = points[batch_size:]
    radius = radius[batch_size:]
    controls = controls[batch_size:]
    splits = splits[batch_size:]
    env.reset(split_points)
    if reuse:
      winner, _, taus, reusing = findpath_pendulum_reuse(env = env, seeds = split_points, time_steps = int(tau/dt), control_seed = split_controls, num_samples = num_samples, r = split_radius, centers = verified_points, radii = verified_radius, verified_alphas = verified_alphas, verified_controls=verified_controls, verified_indices=verified_indices, splits = split_splits, unverified_centers = torch.cat((points, split_points)), unverified_radii = torch.cat((radius, split_radius)))
    else:
      winner, _, taus = findpath_pendulum(env = env, seeds = split_points, time_steps = int(tau/dt), control_seed = split_controls, num_samples = num_samples, r = split_radius)
      # Nothing is reused on this path; previously `reusing` was undefined here.
      reusing = torch.zeros(len(split_points), dtype = torch.bool, device = device)
    env.reset(split_points)
    sol = env.trajectories(winner)
    sol = torch.linalg.norm(wrap_to_pi(sol), dim = 2, ord = float('inf'))
    # Defaults for reused cells; the others are overwritten below.
    alpha = min_alpha * torch.ones_like(split_radius, device = device)
    indices = taus
    print("Reusing out of Batch: ", reusing.sum().item(), '/', len(split_points))
    if (~reusing).sum() > 0:
      alpha_calc, indices_calc = search_alpha_parallel(sol[~reusing], torch.linalg.norm(split_points, dim = 1, ord = float('inf'))[~reusing], split_radius[~reusing], t_eval, L)
      alpha[~reusing] = alpha_calc
      indices[~reusing] = indices_calc
    indices = indices + 1
    mask = alpha >= min_alpha
    verified_points = torch.cat((verified_points, split_points[mask]), dim = 0)
    verified_radius = torch.cat((verified_radius, split_radius[mask]), dim = 0)
    verified_controls = torch.cat((verified_controls, winner[mask]), dim = 0)
    verified_alphas = torch.cat((verified_alphas, alpha[mask]), dim = 0)
    verified_indices = torch.cat((verified_indices, indices[mask]), dim = 0)
    if len(split_radius[~mask]) > 0:
        # Child centres sit at offsets {-2/3 r, 0, +2/3 r} along every axis.
        temp_combinations = torch.vstack([-2/3*split_radius[~mask], torch.zeros_like(split_radius[~mask]), 2/3*split_radius[~mask]]).transpose(0,1).reshape(-1, 3)
        temp_combinations = prod_fun(*([temp_combinations] * d)).reshape(-1, d)
        new_points = split_points[~mask].repeat_interleave(3**d, dim = 0) + temp_combinations
        new_radius = split_radius[~mask].repeat_interleave(3**d, dim = 0)*1/3
        new_splits = split_splits[~mask].repeat_interleave(3**d, dim = 0) + 1
        new_controls = winner[~mask].repeat_interleave(3**d, dim=0)  # children start from the parent's controls

        points = torch.cat((points, new_points), 0)
        radius = torch.cat((radius, new_radius), 0)
        controls = torch.cat((controls, new_controls), 0)
        splits = torch.cat((splits, new_splits), 0)
        mask1 = torch.linalg.norm(points, dim = 1, ord = float('inf')) - radius <= R
        mask2 = torch.abs(points[:, 0]) - radius <= torch.pi
        mask3 = splits < max_splits
        mask = mask1 & mask2 & mask3
        if torch.any(~mask3):
          # NOTE(review): this records every filtered cell (including those
          # outside the R / pi bounds), not only the ones that hit max_splits -
          # confirm that `~mask` rather than `~mask3` is intended.
          if unverified_points is None:
            unverified_points = points[~mask]
            unverified_radius = radius[~mask]
          else:
            unverified_points = torch.cat((unverified_points, points[~mask]), dim = 0)
            unverified_radius = torch.cat((unverified_radius, radius[~mask]), dim = 0)
        points = points[mask]
        radius = radius[mask]
        controls = controls[mask]
        splits = splits[mask]

  print(time.time() - start)

  return verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, unverified_points, unverified_radius



In [19]:
import gc
import math
import itertools
import time

def pendulum_algo_alt(d: int = 2, R: float = 2, epsilon: float = 0.01, L: float = 1.8, tau: float = 2, min_alpha: float = 0, batch_size: int = 500, dt: float = 0.1, speed: float = 2, max_splits: int = 3, num_samples: int = 1000, function = None, target = None, reuse = False) -> list:
  """Adaptively grid the state space and certify cells with ``certify_alpha``.

  Builds ``num_layers`` layers of hypercube cells around the origin (cell
  radius ``epsilon * 3**i`` in layer ``i``) and processes them in batches of
  ``batch_size``. For every cell a control sequence is searched with
  ``findpath_pendulum``, the resulting trajectory is checked by
  ``certify_alpha``; certified cells are stored, every other cell is split into
  ``3**d`` children of one third the radius. Children are dropped when their
  split count reaches ``max_splits``, when they lie outside the infinity-norm
  ball of radius ``R``, or when their first coordinate leaves ``[-pi, pi]``
  (both geometric tests allow one cell radius of slack).

  Args:
    d: State dimension. NOTE(review): the trajectory norm below uses a
      hard-coded two-element weight vector, so only ``d == 2`` looks usable.
    R: Infinity-norm radius of the region to cover.
    epsilon: Radius of the innermost cells (and of the seed cell at the origin).
    L: Passed to ``PendulumEnv`` and ``certify_alpha``.
    tau: Time horizon; ``int(tau/dt)`` control steps are used.
    min_alpha: Passed to ``PendulumEnv`` and ``certify_alpha``.
    batch_size: Number of cells processed per iteration.
    dt: Integration step handed to ``PendulumEnv`` as ``rate``.
    speed: Passed to ``PendulumEnv`` as ``max_speed``.
    max_splits: Cells whose split count reaches this value are discarded.
    num_samples: Passed to ``findpath_pendulum``.
    function: Passed to ``PendulumEnv`` as ``field_function``.
    target: Currently unused; defaults to a zero vector of length ``d``.
    reuse: When true, the already verified cells and their controls are handed
      to ``findpath_pendulum``.

  Returns:
    The tuple ``(verified_points, verified_radius, verified_controls,
    verified_indices)`` (a tuple, despite the ``list`` annotation). Row 0 of
    each tensor is the seed entry for the origin cell of radius ``epsilon``.
  """
  if target is None:
    target = torch.zeros(d)  # NOTE(review): target is not read anywhere below
  start = time.time()

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  prod_fun = torch.vmap(torch.cartesian_prod)
  time_steps = int(tau/dt)

  # Initialize the starting grid: per layer the 1-D offsets (-c, 0, c), then the
  # d-fold cartesian product per layer, minus the all-zero centre point.
  num_layers = math.ceil(math.log(R/epsilon, 3))
  points = torch.tensor([round(2*epsilon*3**i, 4) for i in range(num_layers)], device = device)
  points = torch.vstack([-points, torch.zeros_like(points), points]).transpose(0,1).reshape(-1, 3)
  points = prod_fun(*([points] * d)).reshape(-1, d)
  points = points[torch.where((points != 0).any(dim = 1))].to(device)

  # Every layer contributes 3**d - 1 cells that share one radius.
  radius = [[round(epsilon*3**i, 4)]* (3**d - 1) for i in range (num_layers)]
  radius = list(itertools.chain.from_iterable(radius))
  radius = torch.tensor(radius, device = device)
  splits = torch.zeros_like(radius)
  controls = torch.zeros(len(points), time_steps).to(device)

  # Seed the verified set with the origin cell. float64 matches what
  # torch.tensor([np.zeros(d)]) produced, without the slow list-of-ndarrays path.
  verified_points = torch.zeros((1, d), dtype = torch.float64, device = device)
  verified_radius = torch.tensor([epsilon]).to(device)
  verified_controls = controls[0].unsqueeze(0).to(device)
  verified_indices = torch.tensor([1]).to(device)

  t_eval = torch.linspace(0.0, tau, steps=int(tau/dt + 1), device = device)

  while len(points) > 0:
    print('Points:' + str(len(points)))
    # Pop the next batch off the front of the work queue.
    split_points = points[:batch_size]
    split_radius = radius[:batch_size]
    split_controls = controls[:batch_size]
    split_splits = splits[:batch_size]
    env = PendulumEnv(num_envs = len(split_points), dimension = d, rate = dt, max_speed = speed, min_alpha = min_alpha, L = L, eval_func = False, field_function = function)
    points = points[batch_size:]
    radius = radius[batch_size:]
    controls = controls[batch_size:]
    splits = splits[batch_size:]
    env.reset(split_points)
    if reuse:
      winner, _, taus = findpath_pendulum(env = env, seeds = split_points, time_steps = time_steps, control_seed = split_controls, num_samples = num_samples, r = split_radius, centers = verified_points, radii = verified_radius, verified_controls=verified_controls, verified_indices=verified_indices)
    else:
      winner, _, taus = findpath_pendulum(env = env, seeds = split_points, time_steps = time_steps, control_seed = split_controls, num_samples = num_samples, r = split_radius)
    env.reset(split_points)
    sol = env.trajectories(winner)
    # Weighted Euclidean norm of the wrapped state along each trajectory.
    sol = torch.sqrt((wrap_to_pi(sol)**2 * torch.tensor([1.0, 0.01], device=device)).sum(dim=2))
    condition, indices = certify_alpha(sol, sol[:, 0], split_radius, t_eval, L, min_alpha)
    # presumably compensates for certify_alpha skipping the initial time sample - TODO confirm
    indices = indices + 1
    mask = condition
    failed = ~mask

    verified_points = torch.cat((verified_points, split_points[mask]), dim = 0)
    verified_radius = torch.cat((verified_radius, split_radius[mask]), dim = 0)
    verified_controls = torch.cat((verified_controls, winner[mask]), dim = 0)
    verified_indices = torch.cat((verified_indices, indices[mask]), dim = 0)

    if failed.any():
        # Offsets (-2r/3, 0, 2r/3) per axis give the 3**d child centres of each failed cell.
        failed_radius = split_radius[failed]
        temp_combinations = torch.vstack([-2/3*failed_radius, torch.zeros_like(failed_radius), 2/3*failed_radius]).transpose(0,1).reshape(-1, 3)
        temp_combinations = prod_fun(*([temp_combinations] * d)).reshape(-1, d)
        new_points = split_points[failed].repeat_interleave(3**d, dim = 0) + temp_combinations
        new_radius = failed_radius.repeat_interleave(3**d, dim = 0)*1/3
        new_splits = split_splits[failed].repeat_interleave(3**d, dim = 0) + 1
        # Children start their control search from the parent's best control.
        new_controls = winner[failed].repeat_interleave(3**d, dim=0)

        points = torch.cat((points, new_points), 0)
        radius = torch.cat((radius, new_radius), 0)
        controls = torch.cat((controls, new_controls), 0)
        splits = torch.cat((splits, new_splits), 0)
        # Keep only cells that touch the region of interest and may still be split.
        mask1 = torch.linalg.norm(points, dim = 1, ord = float('inf')) - radius <= R
        mask2 = torch.abs(points[:, 0]) - radius <= torch.pi
        mask3 = splits < max_splits
        mask = mask1 & mask2 & mask3
        points = points[mask]
        radius = radius[mask]
        controls = controls[mask]
        splits = splits[mask]

  print(time.time() - start)

  return verified_points, verified_radius, verified_controls, verified_indices



In [20]:
import gc
import math
import itertools
import time


def pendulum_algo_RL(d: int = 2, R: float = 2, epsilon: float = 0.01, L: float = 1.8, tau: float = 2, min_alpha: float = 0, batch_size: int = 500, dt: float = 0.1, speed: float = 2, max_splits: int = 3, num_samples: int = 1000, function = None, target = None, model = None, wrapper: bool = True) -> list:
  """Adaptively grid the state space and certify cells using a trained policy.

  Same refinement scheme as the other ``pendulum_algo*`` variants, but the
  control for every cell comes from ``findpath_pendulum_RL`` (driven by
  ``model`` on a gym ``Pendulum-v1`` environment) and a cell counts as verified
  when the rate returned by ``search_alpha_parallel`` exceeds ``min_alpha``.
  Unverified cells are split into ``3**d`` children of one third the radius;
  children are dropped when their split count reaches ``max_splits``, when
  they lie outside the infinity-norm ball of radius ``R``, or when their first
  coordinate leaves ``[-pi, pi]`` (both geometric tests allow one cell radius
  of slack).

  Args:
    d: State dimension.
    R: Infinity-norm radius of the region to cover.
    epsilon: Radius of the innermost cells (and of the seed cell at the origin).
    L: Passed to ``search_alpha_parallel``.
    tau: Time horizon; ``int(tau/dt)`` control steps are used.
    min_alpha: Threshold a cell's alpha must exceed to be verified.
    batch_size: Number of cells processed per iteration.
    dt: Time step used to derive the number of control steps and ``t_eval``.
    speed: Unused in this variant.
    max_splits: Cells whose split count reaches this value are discarded.
    num_samples: Passed to ``findpath_pendulum_RL``.
    function: Passed to ``findpath_pendulum_RL`` as ``field_function``.
    target: Currently unused; defaults to a zero vector of length ``d``.
    model: Policy handed to ``findpath_pendulum_RL``; required.
    wrapper: When true, initial cells whose first coordinate lies more than
      one radius outside ``[-pi, pi]`` are removed before the search starts.

  Returns:
    The tuple ``(verified_points, verified_radius, verified_controls,
    verified_alphas, verified_indices)`` (a tuple, despite the ``list``
    annotation). Row 0 of each tensor is the seed entry for the origin cell.

  Raises:
    ValueError: If ``model`` is ``None``.
  """
  if model is None:
    raise ValueError('No model provided')
  if target is None:
    target = torch.zeros(d)  # NOTE(review): target is not read anywhere below
  start = time.time()

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  prod_fun = torch.vmap(torch.cartesian_prod)
  time_steps = int(tau/dt)

  # Initialize the starting grid: per layer the 1-D offsets (-c, 0, c), then the
  # d-fold cartesian product per layer, minus the all-zero centre point.
  num_layers = math.ceil(math.log(R/epsilon, 3))
  points = torch.tensor([round(2*epsilon*3**i, 4) for i in range(num_layers)], device = device)
  points = torch.vstack([-points, torch.zeros_like(points), points]).transpose(0,1).reshape(-1, 3)
  points = prod_fun(*([points] * d)).reshape(-1, d)
  points = points[torch.where((points != 0).any(dim = 1))].to(device)

  # Every layer contributes 3**d - 1 cells that share one radius.
  radius = [[round(epsilon*3**i, 4)]* (3**d - 1) for i in range (num_layers)]
  radius = list(itertools.chain.from_iterable(radius))
  radius = torch.tensor(radius, device = device)
  splits = torch.zeros_like(radius)
  controls = torch.zeros(len(points), time_steps).to(device)

  # From layer 5 outwards every further layer starts one split "in credit"
  # (the decrement is cumulative), so larger cells may be refined more often.
  for i in range(5,num_layers):
    splits[i*(3**d-1):]  =  splits[i*(3**d-1):] - 1

  if wrapper:
    loc_mask = torch.abs(points[:, 0]) - radius <= torch.pi
    points = points[loc_mask]
    radius = radius[loc_mask]
    controls = controls[loc_mask]
    splits = splits[loc_mask]

  # Seed the verified set with the origin cell. float64 matches what
  # torch.tensor([np.zeros(d)]) produced, without the slow list-of-ndarrays path.
  verified_points = torch.zeros((1, d), dtype = torch.float64, device = device)
  verified_radius = torch.tensor([epsilon]).to(device)
  verified_controls = controls[0].unsqueeze(0).to(device)
  verified_alphas = torch.tensor([0]).to(device)
  verified_indices = torch.tensor([1]).to(device)

  t_eval = torch.linspace(0.0, tau, steps=int(tau/dt + 1), device = device)

  while len(points) > 0:
    print('Points:' + str(len(points)))
    # Pop the next batch off the front of the work queue.
    split_points = points[:batch_size]
    split_radius = radius[:batch_size]
    split_controls = controls[:batch_size]
    split_splits = splits[:batch_size]
    env = gym.make("Pendulum-v1")
    points = points[batch_size:]
    radius = radius[batch_size:]
    controls = controls[batch_size:]
    splits = splits[batch_size:]
    env.reset()
    winner, sol = findpath_pendulum_RL(env = env, seeds = split_points, time_steps = time_steps, control_seed = split_controls, num_samples = num_samples, r = split_radius, model = model, field_function = function)
    # Infinity norm of the state along each trajectory.
    sol = torch.linalg.norm(sol, dim = 2, ord = float('inf'))
    alpha, indices = search_alpha_parallel(sol, torch.linalg.norm(split_points, dim = 1, ord = float('inf')), split_radius, t_eval, L)
    # Alternative check kept for reference:
    #condition, indices = certify_alpha(sol, torch.linalg.norm(split_points, dim = 1, ord = float('inf')), split_radius, t_eval, L, min_alpha)
    # search_alpha_parallel drops the initial time sample, so shift back to t_eval positions.
    indices = indices + 1
    mask = alpha > min_alpha
    failed = ~mask

    verified_points = torch.cat((verified_points, split_points[mask]), dim = 0)
    verified_radius = torch.cat((verified_radius, split_radius[mask]), dim = 0)
    verified_controls = torch.cat((verified_controls, winner[mask]), dim = 0)
    verified_alphas = torch.cat((verified_alphas, alpha[mask]), dim = 0)
    verified_indices = torch.cat((verified_indices, indices[mask]), dim = 0)

    if failed.any():
        # Offsets (-2r/3, 0, 2r/3) per axis give the 3**d child centres of each failed cell.
        failed_radius = split_radius[failed]
        temp_combinations = torch.vstack([-2/3*failed_radius, torch.zeros_like(failed_radius), 2/3*failed_radius]).transpose(0,1).reshape(-1, 3)
        temp_combinations = prod_fun(*([temp_combinations] * d)).reshape(-1, d)
        new_points = split_points[failed].repeat_interleave(3**d, dim = 0) + temp_combinations
        new_radius = failed_radius.repeat_interleave(3**d, dim = 0)*1/3
        new_splits = split_splits[failed].repeat_interleave(3**d, dim = 0) + 1
        # Children start their control search from the parent's best control.
        new_controls = winner[failed].repeat_interleave(3**d, dim=0)

        points = torch.cat((points, new_points), 0)
        radius = torch.cat((radius, new_radius), 0)
        controls = torch.cat((controls, new_controls), 0)
        splits = torch.cat((splits, new_splits), 0)
        # Keep only cells that touch the region of interest and may still be split.
        mask1 = torch.linalg.norm(points, dim = 1, ord = float('inf')) - radius <= R
        mask2 = torch.abs(points[:, 0]) - radius <= torch.pi
        mask3 = splits < max_splits
        mask = mask1 & mask2 & mask3
        points = points[mask]
        radius = radius[mask]
        controls = controls[mask]
        splits = splits[mask]

  print(time.time() - start)

  return verified_points, verified_radius, verified_controls, verified_alphas, verified_indices



In [21]:
import torch

def pad_tensor_rows_1d(tensor, indices, sentinel=sentinel):
    """
    Keep the first k_i entries of every row of a 2-D tensor and overwrite the rest.

    Args:
        tensor (torch.Tensor): Input of shape (m, n).
        indices (torch.Tensor): 1-D tensor of length m; entry k_i is the number of
                                leading columns of row i that are left untouched.
        sentinel (float or int): Fill value written into every column >= k_i.

    Returns:
        torch.Tensor: A new tensor (the input is not modified) in which the
        trailing n - k_i entries of each row equal the sentinel.
    """
    num_rows, num_cols = tensor.shape

    # Column position of every element, viewed as an (m, n) grid
    positions = torch.arange(num_cols, device=tensor.device)[None, :].expand(num_rows, num_cols)

    # True where the column lies before the row's cut-off, i.e. where data is kept
    keep = positions < indices.unsqueeze(1)

    # Everything that is not kept becomes the sentinel
    return tensor.masked_fill(keep.logical_not(), sentinel)


In [22]:
from matplotlib.ticker import MultipleLocator, FuncFormatter
import numpy as np
import torch
from matplotlib import pyplot as plt

def f_printer(yes_points, yes_radius, d: int = 2, R: float = np.pi):
  """Draw every cell as a translucent blue square on axes ticked in multiples of pi.

  Args:
    yes_points: Iterable of 2-element centres ``(x, y)``, e.g. an ``(N, 2)`` array.
    yes_radius: Iterable of half side lengths, one per centre.
    d: Unused; kept so existing calls keep working.
    R: Both axes are limited to ``[-R, R]``.

  Prints the number of cells followed by a progress line every 5000 cells,
  then shows the figure. Returns nothing.
  """
  from matplotlib import pyplot as plt

  # A single 6x6 figure; calling plt.figure() and then plt.subplots() would
  # leave an empty extra figure behind and draw on a default-sized one.
  fig, ax = plt.subplots(figsize=(6, 6))

  def format_func(value, tick_number):
      # Label a tick by the nearest integer multiple of pi.
      N = int(np.round(value / np.pi))
      if N == 0:
          return "0"
      elif N == 1:
          return r"$\pi$"
      elif N == -1:
          return r"-$\pi$"
      else:
          return r"${0}\pi$".format(N)

  for axis in (ax.xaxis, ax.yaxis):
      axis.set_major_locator(MultipleLocator(base=np.pi))
      axis.set_major_formatter(FuncFormatter(format_func))

  print(len(yes_points))
  # Plot each square
  for i, ((x, y), r) in enumerate(zip(yes_points, yes_radius)):
      if i % 5000 == 0:
        print(i)
      # Square patch centered at (x, y) with side length = 2*r
      ax.add_patch(plt.Rectangle((x - r, y - r), 2 * r, 2 * r, color='blue', alpha=0.5))

  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.grid(True)

  ax.set_xlim(-R, R)
  ax.set_ylim(-R, R)

  plt.show()




In [23]:
import time
import torch

def interpolate_trajectory_pendulum(points, centers, radii, controls, indices, precision: float = 0.03, target=torch.tensor([0, 0])):
    """Steer a batch of points towards ``target`` by chaining stored cell controls.

    In every iteration each point is looked up in the verified cells
    (``find_hypercubes``), the control stored for its cell is cropped to the
    cell's step count (``indices``) and simulated with ``PendulumEnv``. A point
    is marked found once its end state is within ``precision`` (infinity norm)
    of ``target``; found points are frozen at that state. Points that lie in no
    cell are removed from the batch and collected in ``failed_points``. The
    loop ends when all remaining points are found or no point remains.

    Args:
        points: ``(N, dim)`` start states.
        centers, radii: Cell centres and radii handed to ``find_hypercubes``.
        controls: Per-cell control sequences, indexed by cell id.
        indices: Per-cell number of control steps to apply, indexed by cell id.
        precision: Infinity-norm distance to ``target`` that counts as arrival.
        target: Target state; moved to the device of ``points`` before use.

    Returns:
        Tuple ``(trajectory, counts, overall_points, meandists, maxdists,
        end_controls, failed_points)``. ``trajectory`` and ``end_controls`` are
        ``None`` if no point could be simulated at all. ``meandists`` and
        ``maxdists`` hold the mean / maximum infinity-norm distance from the
        origin per iteration (entry 0 is the start). Row 0 of ``failed_points``
        is a zero placeholder.
    """
    start = time.time()
    # The default target lives on the CPU; align it with the points so the
    # distance test below does not mix devices.
    target_seed = target.to(points.device)
    found = torch.zeros((points.shape[0],), dtype=torch.bool, device=points.device)
    frozen_points = points.clone()
    i = 0
    counts = torch.zeros(points.shape[0], device=points.device)
    meandists = []
    maxdists = []

    print("Iteration " + str(i) + ":")
    origin_dist = torch.linalg.norm(points, ord=float('inf'), dim=1)
    dmax = torch.max(origin_dist).item()
    dmean = torch.mean(origin_dist).item()
    meandists.append(dmean)
    maxdists.append(dmax)
    print("Maximal Distance from the Origin: " + str(dmax))
    print("Average Distance from the Origin: " + str(dmean))

    trajectory = None
    end_controls = None
    overall_points = points.clone()
    failed_points = torch.zeros(1, points.shape[1], device=points.device)

    while not torch.all(found):
        points = wrap_to_pi(points)
        print(len(found))
        i += 1
        target = target_seed.repeat(points.shape[0], 1)

        # Freeze already found points
        current_points = torch.where(found.unsqueeze(1), frozen_points, points)

        # -1 marks points that lie in no verified cell
        containing_indices = find_hypercubes(current_points, centers, radii)
        mask = containing_indices > -1
        failed_points = torch.cat((failed_points, current_points[~mask]), dim=0)

        if i == 1:
          print(failed_points)

        # Drop the uncovered points from every per-point tensor
        current_points = current_points[mask]
        containing_indices = containing_indices[mask]
        counts = counts[mask]
        target = target[mask]
        found = found[mask]
        frozen_points = frozen_points[mask]
        overall_points = overall_points[mask]
        if trajectory is not None:
            trajectory = trajectory[mask]
        if end_controls is not None:
            end_controls = end_controls[mask]

        # Nothing left to simulate: stop instead of reducing over empty tensors
        if current_points.shape[0] == 0:
            break

        times = indices[containing_indices]
        control = controls[containing_indices]
        cropped_control = pad_tensor_rows_1d(control, times)
        # NOTE(review): already-found rows keep accumulating here as well - confirm this is intended
        counts += (cropped_control != sentinel).sum(dim=1)
        if end_controls is None:
            end_controls = cropped_control
        else:
            end_controls = torch.cat((end_controls, cropped_control), dim=1)

        env = PendulumEnv(num_envs = len(current_points), dimension = current_points.shape[1], max_radius=torch.pi/2)
        env.reset(current_points)
        traj = env.trajectories(cropped_control)
        traj = wrap_to_pi(traj)

        # Replace future trajectory of found points with their frozen position
        traj[found] = frozen_points[found].unsqueeze(1).to(traj.dtype)

        # Update frozen status
        points = traj[:, -1, :]
        mask = torch.linalg.norm(points - target, dim=1, ord=float('inf')) < precision
        newly_found = ~found & mask
        frozen_points[newly_found] = points[newly_found]
        found |= mask

        if trajectory is None:
            trajectory = traj
        else:
            # Skip the first sample: it repeats the previous segment's end state
            trajectory = torch.cat((trajectory, traj[:, 1:, :]), dim=1)
            overall_points = torch.cat((overall_points, points), dim=1)

        print("Iteration " + str(i) + ":")
        origin_dist = torch.linalg.norm(points, ord=float('inf'), dim=1)
        dmax = torch.max(origin_dist).item()
        dmean = torch.mean(origin_dist).item()
        meandists.append(dmean)
        maxdists.append(dmax)
        print("Maximal Distance from the Origin: " + str(dmax))
        print("Average Distance from the Origin: " + str(dmean))

    # Wall-clock seconds per remaining point
    print((time.time() - start) / points.shape[0])
    return trajectory, counts, overall_points, meandists, maxdists, end_controls, failed_points



In [24]:
import pickle
from google.colab import files

# Hyper-parameters shared by both runs in this cell
target_velocity = 0.0

d = 2
epsilon = 0.01
L = 2
tau = 3
Rnum = 5
batch_size = 100
min_alpha = 0.0001
speed = 0.3
num_samples=5000
function = inverted_pendulum_2d_torch
dt = 0.05
save = True
R = Rnum * np.pi

i = 5  # passed as max_splits to both runs

# --- Run 1: pendulum_algo ---------------------------------------------------
verified_points, verified_radius, verified_controls, verified_alphas, verified_indices = pendulum_algo(d = d, epsilon = epsilon, L = L, tau = tau, R = R, batch_size = batch_size, min_alpha = min_alpha, speed = speed, max_splits = i, num_samples=num_samples, dt = dt, function = function, reuse = False)
name = 'A103'
if save:
  with open(name + ".pkl", "wb") as f:
    pickle.dump([verified_points, verified_radius, verified_controls, verified_alphas, verified_indices], f)
  files.download(name + ".pkl")

# Cells whose velocity interval contains the target velocity
at_target = torch.abs(verified_points[:, 1] - target_velocity) <= verified_radius
if torch.any(at_target):
  # NOTE(review): both bounds use centre + radius; the lower bound may have been meant as centre - radius - confirm
  upper_edges = (verified_points[:, 0] + verified_radius)[at_target]
  mi = torch.min(upper_edges).cpu().detach().numpy().item()
  ma = torch.max(upper_edges).cpu().detach().numpy().item()
  print((mi, ma))

f_printer(verified_points.detach().to('cpu').numpy(), verified_radius.detach().to('cpu').numpy(), R = (Rnum + 1)*np.pi)


# --- Run 2: pendulum_algo_alt -----------------------------------------------
# pendulum_algo_alt returns four tensors (no alphas).
verified_points_alt, verified_radius_alt, verified_controls_alt, verified_indices_alt = pendulum_algo_alt(d = d, epsilon = epsilon, L = L, tau = tau, R = R, batch_size = batch_size, min_alpha = min_alpha, speed = speed, max_splits = i, num_samples=num_samples, dt = dt, function = function, reuse = False)
# Separate file name so the second run does not overwrite the first pickle.
name = 'A103_alt'
if save:
  with open(name + ".pkl", "wb") as f:
    pickle.dump([verified_points_alt, verified_radius_alt, verified_controls_alt, verified_indices_alt], f)
  files.download(name + ".pkl")

at_target_alt = torch.abs(verified_points_alt[:, 1] - target_velocity) <= verified_radius_alt
if torch.any(at_target_alt):
  # NOTE(review): same centre + radius question as for run 1
  upper_edges_alt = (verified_points_alt[:, 0] + verified_radius_alt)[at_target_alt]
  mi = torch.min(upper_edges_alt).cpu().detach().numpy().item()
  ma = torch.max(upper_edges_alt).cpu().detach().numpy().item()
  print((mi, ma))

f_printer(verified_points_alt.detach().to('cpu').numpy(), verified_radius_alt.detach().to('cpu').numpy(), R = 4*np.pi)







  verified_points = torch.tensor([np.zeros(d)]).to(device)


Points:50
Points:408
Points:497


KeyboardInterrupt: 

In [25]:
import pickle
from google.colab import files

# Hyper-parameters for the three runs in this cell
target_velocity = 0.0

d = 2
epsilon = 0.01
L = 2
tau = 3
Rnum = 5
batch_size = 10
min_alpha = 0.0001
speed = 0.3
num_samples=1000
function = inverted_pendulum_2d_torch
dt = 0.05
save = True
R = Rnum * np.pi

i = 5  # passed as max_splits to every run

# --- Run A141: pendulum_algo_reuse, batch_size 10, tau 3 ---------------------
(verified_points, verified_radius, verified_controls, verified_alphas,
 verified_indices, unverified_points, unverified_radius) = pendulum_algo_reuse(
    d = d, epsilon = epsilon, L = L, tau = tau, R = R,
    batch_size = batch_size, min_alpha = min_alpha, speed = speed,
    max_splits = i, num_samples=num_samples, dt = dt, function = function,
    reuse = True, savetime = 1800, cont = True, conttime = 5400)
name = 'A141'
if save:
  with open(name + ".pkl", "wb") as f:
    pickle.dump([verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, unverified_points, unverified_radius], f)
  files.download(name + ".pkl")

f_printer(verified_points.detach().to('cpu').numpy(),
          verified_radius.detach().to('cpu').numpy(),
          R = (Rnum + 1)*np.pi)


# --- Run A132: pendulum_algo, batch_size 250, tau 1 --------------------------
batch_size = 250
tau = 1
(verified_points, verified_radius, verified_controls, verified_alphas,
 verified_indices) = pendulum_algo(
    d = d, epsilon = epsilon, L = L, tau = tau, R = R,
    batch_size = batch_size, min_alpha = min_alpha, speed = speed,
    max_splits = i, num_samples=num_samples, dt = dt, function = function,
    reuse = True)
name = 'A132'
if save:
  with open(name + ".pkl", "wb") as f:
    pickle.dump([verified_points, verified_radius, verified_controls, verified_alphas, verified_indices], f)
  files.download(name + ".pkl")

f_printer(verified_points.detach().to('cpu').numpy(),
          verified_radius.detach().to('cpu').numpy(),
          R = (Rnum + 1)*np.pi)


# --- Run A133: pendulum_algo, batch_size 100, tau 2 --------------------------
batch_size = 100
tau = 2
(verified_points, verified_radius, verified_controls, verified_alphas,
 verified_indices) = pendulum_algo(
    d = d, epsilon = epsilon, L = L, tau = tau, R = R,
    batch_size = batch_size, min_alpha = min_alpha, speed = speed,
    max_splits = i, num_samples=num_samples, dt = dt, function = function,
    reuse = True)
name = 'A133'
if save:
  with open(name + ".pkl", "wb") as f:
    pickle.dump([verified_points, verified_radius, verified_controls, verified_alphas, verified_indices], f)
  files.download(name + ".pkl")

f_printer(verified_points.detach().to('cpu').numpy(),
          verified_radius.detach().to('cpu').numpy(),
          R = (Rnum + 1)*np.pi)

Points:20491
Reusing out of Batch:  0 / 10
Points:20571
Reusing out of Batch:  0 / 10
Points:20642
Reusing out of Batch:  0 / 10
Points:20695
Reusing out of Batch:  0 / 10
Points:20775
Reusing out of Batch:  0 / 10
Points:20855
Reusing out of Batch:  0 / 10
Points:20935
Reusing out of Batch:  0 / 10
Points:21015
Reusing out of Batch:  0 / 10
Points:21095
Reusing out of Batch:  0 / 10
Points:21175
Reusing out of Batch:  0 / 10
Points:21255
Reusing out of Batch:  0 / 10
Points:21335
Reusing out of Batch:  0 / 10
Points:21415
Reusing out of Batch:  0 / 10
Points:21495
Reusing out of Batch:  0 / 10
Points:21575
Reusing out of Batch:  0 / 10
Points:21655
Reusing out of Batch:  0 / 10
Points:21735
Reusing out of Batch:  0 / 10
Points:21815
Reusing out of Batch:  0 / 10
Points:21895
Reusing out of Batch:  0 / 10
Points:21975
Reusing out of Batch:  0 / 10
Points:22055
Reusing out of Batch:  0 / 10
Points:22135
Reusing out of Batch:  0 / 10
Points:22215
Reusing out of Batch:  0 / 10
Points:2229

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:27655
Reusing out of Batch:  0 / 10
Points:27735
Reusing out of Batch:  0 / 10
Points:27815
Reusing out of Batch:  0 / 10
Points:27895
Reusing out of Batch:  0 / 10
Points:27975
Reusing out of Batch:  0 / 10
Points:28055
Reusing out of Batch:  0 / 10
Points:28135
Reusing out of Batch:  0 / 10
Points:28215
Reusing out of Batch:  0 / 10
Points:28295
Reusing out of Batch:  0 / 10
Points:28339
Reusing out of Batch:  0 / 10
Points:28329
Reusing out of Batch:  0 / 10
Points:28373
Reusing out of Batch:  0 / 10
Points:28390
Reusing out of Batch:  2 / 10
Points:28380
Reusing out of Batch:  0 / 10
Points:28433
Reusing out of Batch:  0 / 10
Points:28423
Reusing out of Batch:  4 / 10
Points:28431
Reusing out of Batch:  0 / 10
Points:28511
Reusing out of Batch:  0 / 10
Points:28591
Reusing out of Batch:  0 / 10
Points:28671
Reusing out of Batch:  0 / 10
Points:28751
Reusing out of Batch:  0 / 10
Points:28831
Reusing out of Batch:  0 / 10
Points:28911
Reusing out of Batch:  0 / 10
Points:2899

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:31027
Reusing out of Batch:  3 / 10
Points:31017
Reusing out of Batch:  6 / 10
Points:31007
Reusing out of Batch:  0 / 10
Points:30997
Reusing out of Batch:  8 / 10
Points:30987
Reusing out of Batch:  10 / 10
Points:30977
Reusing out of Batch:  10 / 10
Points:30967
Reusing out of Batch:  10 / 10
Points:30957
Reusing out of Batch:  4 / 10
Points:31001
Reusing out of Batch:  0 / 10
Points:31081
Reusing out of Batch:  0 / 10
Points:31161
Reusing out of Batch:  0 / 10
Points:31241
Reusing out of Batch:  0 / 10
Points:31321
Reusing out of Batch:  0 / 10
Points:31401
Reusing out of Batch:  0 / 10
Points:31454
Reusing out of Batch:  0 / 10
Points:31480
Reusing out of Batch:  0 / 10
Points:31515
Reusing out of Batch:  0 / 10
Points:31595
Reusing out of Batch:  0 / 10
Points:31675
Reusing out of Batch:  0 / 10
Points:31755
Reusing out of Batch:  0 / 10
Points:31835
Reusing out of Batch:  0 / 10
Points:31915
Reusing out of Batch:  0 / 10
Points:31941
Reusing out of Batch:  0 / 10
Points:3

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:32112
Reusing out of Batch:  6 / 10
Points:32102
Reusing out of Batch:  0 / 10
Points:32092
Reusing out of Batch:  0 / 10
Points:32082
Reusing out of Batch:  0 / 10
Points:32072
Reusing out of Batch:  0 / 10
Points:32062
Reusing out of Batch:  0 / 10
Points:32052
Reusing out of Batch:  0 / 10
Points:32042
Reusing out of Batch:  0 / 10
Points:32032
Reusing out of Batch:  0 / 10
Points:32022
Reusing out of Batch:  0 / 10
Points:32012
Reusing out of Batch:  0 / 10
Points:32002
Reusing out of Batch:  0 / 10
Points:31992
Reusing out of Batch:  0 / 10
Points:31982
Reusing out of Batch:  0 / 10
Points:31972
Reusing out of Batch:  0 / 10
Points:31962
Reusing out of Batch:  4 / 10
Points:31952
Reusing out of Batch:  0 / 10
Points:31942
Reusing out of Batch:  0 / 10
Points:31932
Reusing out of Batch:  0 / 10
Points:31985
Reusing out of Batch:  0 / 10
Points:32065
Reusing out of Batch:  0 / 10
Points:32145
Reusing out of Batch:  0 / 10
Points:32225
Reusing out of Batch:  0 / 10
Points:3230

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:33554
Reusing out of Batch:  0 / 10
Points:33634
Reusing out of Batch:  0 / 10
Points:33714
Reusing out of Batch:  0 / 10
Points:33794
Reusing out of Batch:  0 / 10
Points:33874
Reusing out of Batch:  0 / 10
Points:33954
Reusing out of Batch:  0 / 10
Points:34034
Reusing out of Batch:  0 / 10
Points:34114
Reusing out of Batch:  0 / 10
Points:34194
Reusing out of Batch:  0 / 10
Points:34274
Reusing out of Batch:  0 / 10
Points:34354
Reusing out of Batch:  0 / 10
Points:34434
Reusing out of Batch:  3 / 10
Points:34487
Reusing out of Batch:  10 / 10
Points:34477
Reusing out of Batch:  1 / 10
Points:34476
Reusing out of Batch:  6 / 10
Points:34484
Reusing out of Batch:  6 / 10
Points:34474
Reusing out of Batch:  0 / 10
Points:34500
Reusing out of Batch:  7 / 10
Points:34499
Reusing out of Batch:  1 / 10
Points:34489
Reusing out of Batch:  0 / 10
Points:34551
Reusing out of Batch:  0 / 10
Points:34631
Reusing out of Batch:  0 / 10
Points:34711
Reusing out of Batch:  0 / 10
Points:347

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:37727
Reusing out of Batch:  0 / 10
Points:37807
Reusing out of Batch:  0 / 10
Points:37887
Reusing out of Batch:  0 / 10
Points:37967
Reusing out of Batch:  0 / 10
Points:38047
Reusing out of Batch:  0 / 10
Points:38127
Reusing out of Batch:  0 / 10
Points:38207
Reusing out of Batch:  0 / 10
Points:38287
Reusing out of Batch:  0 / 10
Points:38367
Reusing out of Batch:  0 / 10
Points:38447
Reusing out of Batch:  0 / 10
Points:38527
Reusing out of Batch:  0 / 10
Points:38607
Reusing out of Batch:  0 / 10
Points:38687
Reusing out of Batch:  0 / 10
Points:38767
Reusing out of Batch:  0 / 10
Points:38847
Reusing out of Batch:  0 / 10
Points:38927
Reusing out of Batch:  0 / 10
Points:39007
Reusing out of Batch:  0 / 10
Points:39087
Reusing out of Batch:  0 / 10
Points:39167
Reusing out of Batch:  0 / 10
Points:39247
Reusing out of Batch:  0 / 10
Points:39327
Reusing out of Batch:  0 / 10
Points:39407
Reusing out of Batch:  0 / 10
Points:39487
Reusing out of Batch:  0 / 10
Points:3956

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:42527
Reusing out of Batch:  0 / 10
Points:42607
Reusing out of Batch:  0 / 10
Points:42687
Reusing out of Batch:  0 / 10
Points:42767
Reusing out of Batch:  0 / 10
Points:42847
Reusing out of Batch:  0 / 10
Points:42927
Reusing out of Batch:  0 / 10
Points:43007
Reusing out of Batch:  0 / 10
Points:43087
Reusing out of Batch:  3 / 10
Points:43095
Reusing out of Batch:  10 / 10
Points:43085
Reusing out of Batch:  9 / 10
Points:43075
Reusing out of Batch:  6 / 10
Points:43065
Reusing out of Batch:  10 / 10
Points:43055
Reusing out of Batch:  7 / 10
Points:43045
Reusing out of Batch:  7 / 10
Points:43044
Reusing out of Batch:  10 / 10
Points:43034
Reusing out of Batch:  10 / 10
Points:43024
Reusing out of Batch:  9 / 10
Points:43014
Reusing out of Batch:  8 / 10
Points:43004
Reusing out of Batch:  10 / 10
Points:42994
Reusing out of Batch:  10 / 10
Points:42984
Reusing out of Batch:  10 / 10
Points:42974
Reusing out of Batch:  10 / 10
Points:42964
Reusing out of Batch:  10 / 10
Po

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:44966
Reusing out of Batch:  0 / 10
Points:44992
Reusing out of Batch:  0 / 10
Points:44982
Reusing out of Batch:  0 / 10
Points:44990
Reusing out of Batch:  0 / 10
Points:45016
Reusing out of Batch:  0 / 10
Points:45006
Reusing out of Batch:  1 / 10
Points:45041
Reusing out of Batch:  2 / 10
Points:45103
Reusing out of Batch:  0 / 10
Points:45183
Reusing out of Batch:  9 / 10
Points:45182
Reusing out of Batch:  3 / 10
Points:45235
Reusing out of Batch:  2 / 10
Points:45297
Reusing out of Batch:  10 / 10
Points:45287
Reusing out of Batch:  8 / 10
Points:45295
Reusing out of Batch:  5 / 10
Points:45330
Reusing out of Batch:  0 / 10
Points:45410
Reusing out of Batch:  0 / 10
Points:45490
Reusing out of Batch:  0 / 10
Points:45570
Reusing out of Batch:  0 / 10
Points:45650
Reusing out of Batch:  0 / 10
Points:45730
Reusing out of Batch:  3 / 10
Points:45783
Reusing out of Batch:  0 / 10
Points:45863
Reusing out of Batch:  0 / 10
Points:45934
Reusing out of Batch:  0 / 10
Points:459

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:46523
Reusing out of Batch:  0 / 10
Points:46603
Reusing out of Batch:  0 / 10
Points:46683
Reusing out of Batch:  0 / 10
Points:46763
Reusing out of Batch:  0 / 10
Points:46843
Reusing out of Batch:  0 / 10
Points:46923
Reusing out of Batch:  0 / 10
Points:46985
Reusing out of Batch:  0 / 10
Points:47038
Reusing out of Batch:  8 / 10
Points:47046
Reusing out of Batch:  10 / 10
Points:47036
Reusing out of Batch:  1 / 10
Points:47107
Reusing out of Batch:  10 / 10
Points:47097
Reusing out of Batch:  7 / 10
Points:47114
Reusing out of Batch:  2 / 10
Points:47149
Reusing out of Batch:  9 / 10
Points:47139
Reusing out of Batch:  10 / 10
Points:47129
Reusing out of Batch:  10 / 10
Points:47119
Reusing out of Batch:  10 / 10
Points:47109
Reusing out of Batch:  9 / 10
Points:47108
Reusing out of Batch:  10 / 10
Points:47098
Reusing out of Batch:  9 / 10
Points:47097
Reusing out of Batch:  10 / 10
Points:47087
Reusing out of Batch:  10 / 10
Points:47077
Reusing out of Batch:  5 / 10
Poi

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Points:46797
Reusing out of Batch:  0 / 10
Points:46787
Reusing out of Batch:  0 / 10
Points:46777
Reusing out of Batch:  0 / 10
Points:46767
Reusing out of Batch:  0 / 10
Points:46757
Reusing out of Batch:  0 / 10
Points:46747
Reusing out of Batch:  0 / 10
Points:46737
Reusing out of Batch:  0 / 10
Points:46727
Reusing out of Batch:  0 / 10
Points:46717
Reusing out of Batch:  0 / 10
Points:46707
Reusing out of Batch:  0 / 10
Points:46697
Reusing out of Batch:  0 / 10
Points:46687
Reusing out of Batch:  0 / 10
Points:46677
Reusing out of Batch:  0 / 10
Points:46667
Reusing out of Batch:  0 / 10
Points:46657
Reusing out of Batch:  0 / 10
Points:46647
Reusing out of Batch:  0 / 10
Points:46637
Reusing out of Batch:  0 / 10
Points:46627
Reusing out of Batch:  0 / 10
Points:46626
Reusing out of Batch:  0 / 10
Points:46697
Reusing out of Batch:  0 / 10
Points:46777
Reusing out of Batch:  0 / 10
Points:46857
Reusing out of Batch:  0 / 10
Points:46937
Reusing out of Batch:  0 / 10
Points:4701

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.64 GiB. GPU 0 has a total capacity of 14.74 GiB of which 3.63 GiB is free. Process 2440 has 11.11 GiB memory in use. Of the allocated memory 10.95 GiB is allocated by PyTorch, and 44.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import pickle
from google.colab import files

# Not referenced again in this cell.
target_velocity = 0.0

# Hyper-parameters forwarded (by keyword) to pendulum_algo_reuse below.
d = 2
epsilon = 0.01
L = 5
tau = 2
Rnum = 5
batch_size = 8
min_alpha = 0.0001
speed = 0.4
num_samples=1000
function = simplified_pendulum_derivatives
dt = 0.05
# When True the results are pickled and offered for download after the run.
save = True
# Value passed as R to pendulum_algo_reuse: Rnum multiples of pi.
R = Rnum * np.pi

#for i in range(6):
  #print(i)
# The sweep over i above is disabled; a single run with i = 5 (passed as max_splits) is done.
if True:
  i = 5
  # NOTE(review): cont = True and contlist is built from verified_* / unverified_*, so those
  # names must already exist in the kernel (from an earlier run or a loaded checkpoint) before
  # this cell executes; it will raise NameError on a fresh kernel.
  verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, unverified_points, unverified_radius = pendulum_algo_reuse(d = d, epsilon = epsilon, L = L, tau = tau, R = R, batch_size = batch_size, min_alpha = min_alpha, speed = speed, max_splits = i, num_samples=num_samples, dt = dt, function = function, reuse = True, savetime = 1800, cont = True, conttime = None, contlist = [verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, unverified_points, unverified_radius])
  # Base name of the pickle file written below.
  name = 'A132'
  if save:
    # Persist the seven returned objects, in the same order as they were returned.
    with open(name + ".pkl", "wb") as f:
      pickle.dump([verified_points, verified_radius, verified_controls, verified_alphas, verified_indices, unverified_points, unverified_radius], f)
      #pickle.dump([verified_points, verified_radius, verified_controls, verified_indices], f)
    # google.colab helper: triggers a browser download of the file just written.
    files.download(name + ".pkl")

  # Plot the verified cells; note the plot uses (Rnum + 1)*pi, one pi larger than the R used for the run.
  f_printer(verified_points.detach().to('cpu').numpy(), verified_radius.detach().to('cpu').numpy(), R = (Rnum + 1)*np.pi)

In [None]:
# Display how many entries remain in unverified_points (as returned by the run cell above).
len(unverified_points)

In [None]:
# Re-plot the verified centres/radii (moved to CPU NumPy arrays) with f_printer's R set to 4*pi.
f_printer(verified_points.detach().to('cpu').numpy(), verified_radius.detach().to('cpu').numpy(), R = 4*np.pi)

In [None]:
import torch

def approximate_one_sided_lipschitz_controlled_linf(f, a, b, u_max, num_samples=100000, device="cpu", seed=0, l = 0.5, m = 0.1):
    """
    Monte-Carlo estimate of a one-sided Lipschitz bound for a controlled 2-D vector field.

    Random state pairs (x, y) and one control value per pair are drawn; the function returns
    the largest sampled value of

        sum((f(x, u) - f(y, u)) * (x - y)) / max(|x - y|)**2

    i.e. the coordinate-wise inner product divided by the squared L-infinity norm of x - y.

    Parameters:
    - f: vector field, called as f(states, controls, m=m, l=l) with states of shape (N, 2)
      and controls of shape (N,)
    - a: the first state coordinate is sampled uniformly from [-a, a]
    - b: the second state coordinate is sampled uniformly from [-b, b]
    - u_max: controls are sampled uniformly from [-u_max, u_max]
    - num_samples: number of (x, y, u) triples drawn before filtering
    - device: torch device on which the samples are created
    - seed: passed to torch.manual_seed (this reseeds torch's global generator)
    - l, m: forwarded unchanged to f as keyword arguments

    Returns:
    - the maximum sampled ratio as a Python float
    """
    torch.manual_seed(seed)

    def _sample_states():
        # Both columns start in [-a, a]; the second is then redrawn in [-b, b].
        states = torch.empty((num_samples, 2), device=device).uniform_(-a, a)
        states[:, 1] = torch.empty(num_samples, device=device).uniform_(-b, b)
        return states

    # Draw order (x, then y, then u) determines the random stream and must not change.
    x = _sample_states()
    y = _sample_states()

    diff = x - y
    sq_linf = diff.abs().max(dim=1).values ** 2

    # One control per pair, shared by both endpoints of the pair.
    u = torch.empty((num_samples,), device=device).uniform_(-u_max, u_max)

    # Drop (near-)coincident pairs so the division below is well defined.
    keep = sq_linf > 1e-8
    x, y, diff, sq_linf, u = x[keep], y[keep], diff[keep], sq_linf[keep], u[keep]

    field_at_x = f(x, u, m = m, l = l)
    field_at_y = f(y, u, m = m, l = l)
    field_gap = field_at_x - field_at_y

    ratios = (field_gap * diff).sum(dim=1) / sq_linf
    return ratios.max().item()

def timecheck(l = 0.05, u_max = 0.3, m = 0.1):
  """
  Count explicit-Euler steps until the first state coordinate has shrunk in magnitude.

  The state starts at [-pi/12, 0] and is advanced with step 0.05 using
  inverted_pendulum_2d_torch driven by the control -u_max (g = 9.81). The loop stops as soon
  as |state[0][0]| drops below |start[0][0]| - 0.01.

  Returns:
    The number of steps taken divided by 20 (i.e. steps * 0.05).
  """
  initial = -torch.tensor([[torch.pi/12, 0]])
  current = initial.clone()
  torque = torch.tensor([u_max])
  # The starting point never changes, so the exit threshold is fixed for the whole loop.
  exit_threshold = torch.abs(initial[0][0]) - 0.01
  steps = 0
  while torch.abs(current[0][0]) >= exit_threshold:
    current = wrap_to_pi(current)
    rate = inverted_pendulum_2d_torch(current, -torque, m = m, l = l, g = 9.81)
    current = current + rate * 0.05
    steps += 1
  return steps/20

# Control bound is c * m; both values are reused for every pendulum length below.
c = 3
m = 0.1
u_max = c * m
# For each value of l: estimate the one-sided Lipschitz ratio, measure timecheck(), and print their product.
for l in [0.1, 0.5, 1, 2, 5, 10, 20, 50, 100, 1000, 10000]:
  # Sampling box is [-pi, pi] x [-5*pi, 5*pi]; samples are created on "cuda" (requires a GPU).
  L_inf_os = approximate_one_sided_lipschitz_controlled_linf(inverted_pendulum_2d_torch, a=torch.pi, b=5*torch.pi, u_max=u_max,  device="cuda", num_samples=100000, l = l, m = m)
  print(f"Approximate L-infinity one-sided Lipschitz constant: {L_inf_os:.4f}")
  # timecheck builds its tensors without a device argument, i.e. on the default device.
  checktime = timecheck(l = l, u_max = u_max, m = m)
  #print(f"Approximate Time: {checktime:.4f}")
  print(f"l: {l}, Total:  {checktime * L_inf_os:.4f}")



In [None]:
# Manual explicit-Euler rollout (step 0.05) starting from [pi, 0], switching the sign of the
# control each time the second state component crosses a small threshold. i counts all steps.
state = torch.tensor([[torch.pi, 0]])
c = 3
m = 0.1
l = 10
control = torch.tensor([c * m])
i = 0
# Phase 1: apply +control until state[0][1] drops to -0.01 or below.
while state[0][1] > -0.01:
  state = wrap_to_pi(state)
  delta = inverted_pendulum_2d_torch(state, control, m = m, l = l, g = 9.81)
  state = state + delta * 0.05
  print(state)
  i+=1
# Phase 2: apply -control until state[0][1] rises to 0.01 or above.
while state[0][1] < 0.01:
  state = wrap_to_pi(state)
  delta = inverted_pendulum_2d_torch(state, -control, m = m, l = l, g = 9.81)
  state = state + delta * 0.05
  print(state)
  i+=1
# Phase 3: same as phase 1 (+control until state[0][1] <= -0.01).
while state[0][1] > -0.01:
  state = wrap_to_pi(state)
  delta = inverted_pendulum_2d_torch(state, control, m = m, l = l, g = 9.81)
  state = state + delta * 0.05
  print(state)
  i+=1
# Total number of Euler steps over the three phases.
print(i)

In [None]:
# Explicit-Euler rollout (step 0.05) from [-pi/10, 0] under constant +control, printing every
# state until the first state component becomes non-negative.
state = -torch.tensor([[torch.pi/10, 0]])
# Copy of the start state; not used again in this cell.
s = -torch.tensor([[torch.pi/10, 0]])
c = 3
m = 0.1
l = 10
control = torch.tensor([c * m])
# Step counter; incremented below but never printed in this cell.
i = 0
while state[0][0] < 0:
  state = wrap_to_pi(state)
  delta = inverted_pendulum_2d_torch(state, control, m = m, l = l, g = 9.81)
  state = state + delta * 0.05
  print(state)
  i+=1

In [None]:
import torch
# Value used to pad control sequences (matches pad_tensor_1d's default sentinel).
sentinel = 1337

# Roll one environment forward from [pi - 0.001, 0] by repeatedly looking up the verified
# hypercube containing the current state and applying that cube's stored control sequence.
spot = []
spot.append(torch.tensor([[torch.pi - 0.001, 0.0]], device = 'cuda'))
env1 = PendulumEnv(num_envs = 1, dimension = 2, max_radius=np.pi/2, max_speed = 2, rate = 0.05, field_function = simplified_pendulum_derivatives)
env1.reset(spot[-1])
controls = []
for i in range(100):
    s = wrap_to_pi(spot[-1])
    print(s)
    idx = find_hypercubes(s, verified_points, verified_radius)
    print(idx)
    # NOTE(review): when no cube is found nothing is appended to spot, so every remaining
    # iteration repeats the same lookup on the same state.
    if idx != -1:
      print('here')
      action = verified_controls[idx]
      t = verified_indices[idx]
      # Keep the first t controls and overwrite the rest with the sentinel value.
      action = pad_tensor_1d(action[0], t.item()).unsqueeze(0)
      controls.append(action)
      traj = env1.trajectories(action)
      traj = wrap_to_pi(traj)
      print(traj[:, :t])
      #reward = reward_func(traj[:, :t+1, :], action[:, :t])
      #rewards.extend(reward)
      # NOTE(review): the for-loop rebinds i at the start of every iteration, so this update
      # does not change how many iterations run.
      i = i + t
      # The next start state is the last state of the returned trajectory.
      spot.append(traj[:, -1, :])
      env1.reset(spot[-1])


# NOTE(review): spot is a list of tensors whose first entry lives on 'cuda' and controls is a
# list of tensors; converting/plotting them directly with NumPy/matplotlib normally requires
# moving them to CPU first - confirm this section runs as intended.
spot = np.array(spot)
plt.plot(controls)
plt.plot(spot[:,0])
plt.plot(spot[:,1])
plt.legend(['control', 'location', 'velocity'])

In [None]:
# Reshape overall_points to (num_rows, len(maxdists)-1, 2) and split the last axis into its
# two components (index 0 -> "location", index 1 -> "speed") as CPU NumPy arrays.
locp = overall_points.view(len(overall_points), len(maxdists)-1, 2)[:,:,0].cpu().detach().numpy()
speedp = overall_points.view(len(overall_points), len(maxdists)-1, 2)[:,:,1].cpu().detach().numpy()


# Per-sample L-infinity norm of each trajectory state; only its length is used below.
norms = torch.linalg.norm(trajectories, ord = float('inf'), dim = 2).cpu().detach().numpy()

loc = trajectories[:,:,0].cpu().detach().numpy()
speed = trajectories[:,:,1].cpu().detach().numpy()
control = end_controls.cpu().detach().numpy()

# Not used in this cell; kept because other cells may read them.
a1 = 3
a2 = 55
a3 = 89
a4 = 32
a5 = 84
a6 = 41

# Location of the "end" points, plotted against their index.
for i in range(len(norms)):
  plt.plot(np.array(range(len(locp[0]))), locp[i])

plt.xlabel('Time (s)')
plt.ylabel('Location')
plt.title('Location vs Time (Ends)')
plt.show()


# Full location traces: sample j is kept only if the control applied before it (index j-1)
# is not the sentinel padding value. The x axis is the kept-sample index divided by 10.
for i in range(len(norms)):
  l = []
  for j in range(len(loc[i])):
    if j == 0 or control[i][j-1] != sentinel:
      l.append(loc[i][j])
  plt.plot(np.array(range(len(l)))/10, l)

plt.xlabel('Time (s)')
plt.ylabel('Location')
plt.title('Location vs Time (Full)')
plt.show()


# Same filtering for the velocity traces. The former debug branch
# `control[i][j - 1] is sentinel` compared a NumPy scalar with an int by identity, which is
# always False, so it never printed anything and has been removed.
for i in range(len(speed)):
  s = []
  for j in range(len(speed[i])):
    if j == 0 or control[i][j-1] != sentinel:
      s.append(speed[i][j])
  plt.plot(np.array(range(len(s)))/10, s)

plt.xlabel('Time (s)')
plt.ylabel('Velocity')
plt.title('Velocity vs Time (Full)')
plt.show()


# Phase portrait uses the unfiltered arrays (padded samples included).
for i in range(len(norms)):
  plt.plot(loc[i], speed[i])


plt.xlabel('Location')
plt.ylabel('Speed')
plt.title('Phase Portrait')
plt.show()


# Control traces with the sentinel padding entries dropped.
for i in range(len(norms)):
  c = []
  for j in range(len(control[i])):
    if control[i][j] != sentinel:
      c.append(control[i][j])
  plt.plot(np.array(range(len(c)))/10, c)

plt.xlabel('Time (s)')
plt.ylabel('Control')
plt.title('Control vs Time (Full)')
plt.show()



In [None]:
import time
import matplotlib.pyplot as plt


def reward_func(state, control):
  """
  Negative quadratic stage cost for the first trajectory in `state`.

  Uses states 1..end of batch entry 0: first component squared, plus 0.1 times the second
  component squared, plus 0.001 times `control` squared (broadcast against the state terms).
  """
  first = state[0, 1:, 0]
  second = state[0, 1:, 1]
  cost = first**2 + 0.1*second**2 + 0.001*control**2
  return -cost

def reward_func_batched(states, controls, taus):
    """
    Sum the negative quadratic stage cost of each environment over its first taus[i] steps.

    Args:
        states: Tensor of shape (batch, T+1, 2); the state at index 0 is not charged
        controls: Tensor of shape (batch, T)
        taus: Tensor of shape (batch,); step k of environment i counts only if k < taus[i]

    Returns:
        Tensor of shape (batch,) with the (non-positive) total reward per environment.
    """
    # Unpacking enforces a 3-D input, exactly as a (B, T+1, D) shape requires.
    _, num_states, _ = states.shape
    horizon = num_states - 1

    # active[i, k] is True while step k is still inside environment i's horizon.
    step_ids = torch.arange(horizon, device=states.device).unsqueeze(0)
    active = step_ids < taus.unsqueeze(1)

    theta = states[:, 1:, 0]
    omega = states[:, 1:, 1]
    stage_cost = theta**2 + 0.1*omega**2 + 0.001 * controls**2

    # Multiplying by the boolean mask zeroes every step beyond taus[i].
    masked_cost = stage_cost * active
    return -masked_cost.sum(dim=1)


def find_step(obs, centers, radii, controls, indices):
  """
  Look up the control sequence stored for the hypercube selected for obs[0].

  Returns a pair (control, t): the cube's control sequence padded by pad_tensor_1d beyond
  its first t entries and given a leading batch axis, and t = indices[cube].
  """
  cube = find_hypercube(obs[0], centers, radii)
  horizon = indices[cube]
  padded = pad_tensor_1d(controls[cube], horizon)
  return padded.unsqueeze(0), horizon

def find_steps(obs, centers, radii, controls, indices):
  """
  Batched counterpart of find_step.

  Selects one hypercube per observation with find_hypercubes, gathers the matching rows of
  `indices` and `controls`, and pads each control row with pad_tensor_rows_1d.
  Returns (padded_controls, horizons).
  """
  cubes = find_hypercubes(obs, centers, radii)
  horizons = indices[cubes]
  padded = pad_tensor_rows_1d(controls[cubes], horizons)
  return padded, horizons

def auc_pytorch(centers, radii, controls, indices, episodes=10, render=False):
    """
    Evaluate the stored controller one episode at a time and plot the per-episode totals.

    Each episode starts from env.sample_points(1) and repeatedly applies the control
    sequence returned by find_step until the step counter reaches 200. The rewards from
    reward_func are accumulated per episode.

    Args:
        centers, radii, controls, indices: forwarded unchanged to find_step.
        episodes: number of episodes to run.
        render: when True, env.render() is called after every applied control sequence.

    Returns:
        (aucs, times, its): per-episode total reward, per-episode wall-clock seconds and
        per-episode number of control sequences applied.
    """
    env = PendulumEnv(num_envs = 1, dimension = 2, max_radius=np.pi/2, max_speed = 2, rate = 0.05)
    episode_totals, segment_counts, wall_times = [], [], []

    for ep in range(episodes):
        segments = 0
        began = time.time()
        # Light progress indicator: every tenth episode.
        if ep % 10 == 0:
          print(ep)
        obs = env.sample_points(1)
        env.reset(obs)
        reward_chunks = []
        elapsed = 0
        while elapsed < 200:
            action, t = find_step(obs, centers, radii, controls, indices)
            traj = wrap_to_pi(env.trajectories(action))
            # Only the first t controls (and the states they produce) are scored.
            reward_chunks.extend(reward_func(traj[:, :t+1, :], action[:, :t]))
            elapsed = elapsed + t + 1
            if render:
                env.render()
            # Continue from the last state of the returned trajectory.
            obs = traj[:, -1, :]
            env.reset(obs)
            segments += 1
        segment_counts.append(segments)

        flat_rewards = torch.cat(reward_chunks).tolist()
        episode_totals.append(np.sum(flat_rewards))
        wall_times.append(time.time() - began)

    #env.close()
    mean_auc = np.mean(episode_totals)
    print(f"Mean AUC over {episodes} randomly-initialized episodes: {mean_auc:.2f}")

    # Plot
    plt.plot(episode_totals, marker='o')
    plt.title("Episode-wise AUC (Random Init)")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward (AUC)")
    plt.grid(True)
    plt.show()

    return episode_totals, wall_times, segment_counts



def auc_pytorch_parallel(centers, radii, controls, indices, episodes=10, seed=None, function = simplified_pendulum_derivatives):
    """
    Evaluate the stored controller on `episodes` environments advanced together.

    Every environment keeps a step counter; on each pass only the environments whose
    counter is still below 200 are advanced, using the control sequences returned by
    find_steps, and their masked rewards (reward_func_batched) are accumulated.

    Args:
        centers, radii, controls, indices: forwarded unchanged to find_steps.
        episodes: number of environments / start states.
        seed: only used when episodes == 1; if given, it replaces the sampled start state
            (wrapped in a list and turned into a tensor, so a [theta, theta_dot] pair is expected).
        function: passed to PendulumEnv as field_function. The default is bound when this
            def statement runs.

    Returns:
        (aucs, times, it): NumPy array of accumulated reward per episode, a one-element list
        with the wall-clock seconds of the whole loop, and a tensor counting how many passes
        each episode took part in.

    Side effects: prints the mean AUC and shows matplotlib figures (two extra trajectory
    figures when episodes == 1). All tensors are moved to 'cuda', so a GPU is required.
    """
    # NOTE(review): this first env uses max_radius=np.pi while the envs rebuilt inside the
    # loop use np.pi/2 - confirm which value is intended.
    env = PendulumEnv(num_envs = episodes, dimension = 2, max_radius=np.pi, max_speed = 2, rate = 0.05, field_function = function)
    # aucs is overwritten after the loop and its is never read again in this function.
    aucs = []
    its = []
    times = []
    reward = torch.zeros(episodes).to('cuda')
    if True:
        # Number of passes each episode took part in.
        it = torch.zeros(episodes).to('cuda')
        start = time.time()
        obs = env.sample_points(episodes)
        if episodes == 1:
          if seed is not None:
            obs = torch.tensor([seed]).to('cuda')
          # Trajectory log rows are [state components..., control]; the first row pairs the
          # start state with a placeholder control of 0.
          trajectory = torch.cat((obs, torch.Tensor([[0]]).to('cuda')), dim = 1)
        env.reset(obs)
        # Per-episode step counter.
        i = torch.zeros(episodes).to('cuda')
        while torch.any(i < 200):
            # Only episodes that have not yet reached 200 steps are advanced.
            mask = i < 200
            # A fresh env sized to the number of active episodes is built on every pass.
            env = PendulumEnv(num_envs = torch.sum(mask).item(), dimension = 2, max_radius=np.pi/2, max_speed = 2, rate = 0.05, field_function = function)
            env.reset(obs[mask])
            action, t = find_steps(obs[mask], centers, radii, controls, indices)
            traj = env.trajectories(action)
            traj = wrap_to_pi(traj)
            if episodes == 1:
              # Append the t new states together with the t controls that produced them.
              snippet =  torch.cat((traj[0, 1:t+1, :], action[:, :t].transpose(0,1)), dim = 1)
              trajectory = torch.cat((trajectory, snippet), dim=0)
            reward[mask] += reward_func_batched(traj, action, t).to('cuda')
            i[mask] = i[mask] + t
            # Active episodes continue from the last state of their returned trajectory.
            obs[mask] = traj[:, -1, :]
            it[mask] += 1
        times.append(time.time() - start)

    if episodes == 1:
      # All logged columns against the row index.
      plt.plot(trajectory.cpu().detach().numpy())
      plt.title("Trajectory")
      plt.xlabel("Time")
      plt.ylabel("Location")
      plt.grid(True)
      legend = ["theta", "thetadot", "u"]
      plt.legend(legend)
      plt.show()



      # Column 0 against column 1 of the logged trajectory.
      plt.plot(trajectory[:, 0].cpu().detach().numpy(), trajectory[:, 1].cpu().detach().numpy())
      plt.title("Phase Portrait")
      plt.xlabel("Location")
      plt.ylabel("Velocity")
      plt.grid(True)
      plt.show()
    aucs = (reward).detach().cpu().numpy()
    #env.close()
    mean_auc = np.mean(aucs)
    print(f"Mean AUC over {episodes} randomly-initialized episodes: {mean_auc:.2f}")

    # Plot
    plt.plot(aucs, marker='o')
    plt.title("Episode-wise AUC (Random Init)")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward (AUC)")
    plt.grid(True)
    plt.show()

    return aucs, times, it


# Single-episode run from the fixed start state [pi/4, 0]; with episodes = 1 the function
# also draws the trajectory and phase-portrait figures.
auc, times, its = auc_pytorch_parallel(verified_points, verified_radius, verified_controls, verified_indices, episodes = 1, seed = [torch.pi/4, 0])
#auc, times, its = auc_pytorch_parallel(verified_points, verified_radius, verified_controls, verified_indices, episodes = 1)

# times is a one-element list of wall-clock seconds; its is the per-episode pass-count tensor.
print(times)
print(torch.mean(its))


# auc, times, its = auc_pytorch_parallel(verified_points_alt1, verified_radius_alt1, verified_controls_alt1, verified_indices_alt1, episodes = 1, seed = [torch.pi, 0])
# print(times)
# print(torch.mean(its))



In [None]:
# Spot check: evaluate simplified_pendulum_derivatives at the single state [-pi/2, 0] with
# control 0; the result is shown as the cell output.
simplified_pendulum_derivatives(torch.tensor([[-np.pi/2, 0]]), torch.tensor([0]))

In [None]:
import pickle
import torch


# Restore a saved run. pickle can execute arbitrary code while loading, so only open
# checkpoint files produced by this notebook.
with open('timesave_131400.pkl', 'rb') as file:
    ll = pickle.load(file)

# The checkpoint holds eleven objects in this fixed order.
(verified_points, verified_radius, verified_controls, verified_alphas, verified_indices,
 points, radius, controls, splits, unverified_points, unverified_radius) = ll

# In-place patch of the first stored index (this also changes ll[4], the same object).
verified_indices[0] = 1

f_printer(verified_points.detach().to('cpu').numpy(), verified_radius.detach().to('cpu').numpy(), R = 4*np.pi)

In [None]:
# Re-draw the verified centres/radii (as CPU NumPy arrays) with f_printer's R set to 4*pi.
f_printer(verified_points.detach().to('cpu').numpy(), verified_radius.detach().to('cpu').numpy(), R = 4*np.pi)

In [None]:
# Display the number of entries currently held in unverified_points.
len(unverified_points)

In [None]:
import torch

def find_hypercube(point, centers, radii):
    """
    Locate the hypercube that a single point belongs to.

    The lowest-index hypercube whose closed box [center - radius, center + radius] contains
    the point is returned. If no box contains it, 'AHHHH' is printed and the index of the
    centre nearest to the point (Euclidean distance) is returned instead.

    Args:
        point: Tensor with n elements (flattened before use)
        centers: Tensor of shape (m, n)
        radii: 0-d Tensor, Tensor of shape (m,), or Tensor expandable to (m, n)

    Returns:
        int: index of the containing (or, failing that, closest) hypercube
    """
    query = point.flatten()
    #query = wrap_to_pi(query)
    num_cubes, num_dims = centers.shape

    # Bring the radii to a shape that expands to one half-width per cube and axis.
    if radii.dim() < 2:
        radii = radii.view(1, 1) if radii.dim() == 0 else radii.unsqueeze(1)
    half_widths = radii.expand(num_cubes, num_dims)

    above_lower = query >= centers - half_widths
    below_upper = query <= centers + half_widths
    hits = torch.logical_and(above_lower, below_upper).all(dim=1)

    if not hits.any():
        offsets = torch.norm(centers - query, dim=1)
        print('AHHHH')
        return torch.argmin(offsets).item()
    return torch.nonzero(hits, as_tuple=False)[0].item()


In [None]:
import torch

def pad_tensor_1d(tensor, k, sentinel=1337):
    """
    Keep the first k entries of a 1D tensor and overwrite the rest with a sentinel.

    Args:
        tensor (torch.Tensor): Input tensor of shape (n,).
        k (int): Number of leading elements to keep; must satisfy 0 <= k <= n.
        sentinel (float or int): Fill value for positions k..n-1 (cast to the tensor's dtype).

    Returns:
        torch.Tensor: New tensor with the same shape and dtype; the input is not modified.
    """
    length = tensor.shape[0]
    n = length
    assert 0 <= k <= n, f"k must be between 0 and {n}, but got {k}"

    keep = torch.arange(length, device=tensor.device) < k
    filler = torch.full_like(tensor, sentinel)
    return torch.where(keep, tensor, filler)


In [None]:
import time
import numpy as np
# Benchmark find_hypercubes against all verified cells for a growing number of query points.
device = 'cuda'
ftimes = []
for n in [1,10,100,1000,10000]:
  t = []
  for i in range(10000):
    # Fresh random queries in [-pi, pi)^2; generating and copying them is excluded from the timing.
    points = 2*torch.pi*(torch.rand(n, 2)-0.5).to('cuda')
    # CUDA kernels are queued asynchronously by default, so wait for pending work before
    # starting the clock and again before stopping it. This keeps the measurement valid
    # whether or not CUDA_LAUNCH_BLOCKING is set.
    torch.cuda.synchronize()
    start = time.perf_counter()
    f1 = find_hypercubes(points, verified_points, verified_radius)
    torch.cuda.synchronize()
    t.append(time.perf_counter()-start)
  ftimes.append(t)
# Mean seconds per call, one entry per query-batch size.
print([np.mean(t) for t in ftimes])

In [None]:
import time
import numpy as np
# Benchmark find_hypercubes for 1000 query points against the first n verified cells.
device = 'cuda'
ftimes = []
for n in [1,10,100,1000,10000]:
  c = verified_points[:n, :]
  r = verified_radius[:n]
  t = []
  for i in range(10000):
    # Fresh random queries in [-pi, pi)^2; generating and copying them is excluded from the timing.
    points = 2*torch.pi*(torch.rand(1000, 2)-0.5).to('cuda')
    # CUDA kernels are queued asynchronously by default, so wait for pending work before
    # starting the clock and again before stopping it. This keeps the measurement valid
    # whether or not CUDA_LAUNCH_BLOCKING is set.
    torch.cuda.synchronize()
    start = time.perf_counter()
    f1 = find_hypercubes(points, c, r)
    torch.cuda.synchronize()
    t.append(time.perf_counter()-start)
  ftimes.append(t)
# Mean seconds per call, one entry per number of cells.
print([np.mean(t) for t in ftimes])

In [None]:
#!pip install rtree
import torch
import numpy as np
from rtree import index
import time
n = 1000

# Random query points in [-pi, pi)^2. (A two-point placeholder assignment that was
# immediately overwritten has been removed.)
points = 2*torch.pi*(torch.rand(n, 2)-0.5).to('cuda')

# Convert to NumPy for R-tree; detach() first so tensors that track gradients convert too.
points_np = points.cpu().numpy()
v_points_np = verified_points.detach().cpu().numpy()
radius_np = verified_radius.detach().cpu().numpy()

# The timer covers both building the index and running the n point queries.
start = time.time()
# Build R-tree index
p = index.Property()
idx = index.Index(properties=p)
for i, point in enumerate(v_points_np):
    x, y = point
    r = radius_np[i]
    idx.insert(i, (x-r, y-r, x+r, y+r))  # Rtree requires bounding boxes

# A point query is a degenerate box whose min and max corners coincide.
for i in range(n):
  point = points_np[i]
  matches = list(idx.intersection((point[0], point[1], point[0], point[1])))
  #print(len(matches))
print(time.time()-start)

In [None]:
import torch
import numpy as np
#from rtree import index
import time
n = 80000

# Random query centres in [-pi, pi)^2. (A two-point placeholder assignment that was
# immediately overwritten has been removed.)
points = 2*torch.pi*(torch.rand(n, 2)-0.5).to('cuda')

# Wait for the host-to-device copy and any queued kernels so only the call below is timed.
torch.cuda.synchronize()
start = time.time()

# Every query cube has radius 0.1.
find_hypercube_intersections_wrap(verified_points, verified_radius, points, 0.1+torch.zeros(points.shape[0], device = 'cuda'))

# Make sure all GPU work launched by the call has finished before reading the clock.
torch.cuda.synchronize()
print(time.time()-start)

In [None]:
import gc

# Collect unreachable Python objects first: tensors kept alive only by reference cycles are
# released by the collector, and only then can empty_cache() hand their cached blocks back
# to the GPU driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def count_hypercube_intersections(centers, radii, query_centers, query_radii):
    """
    For every query hypercube, count the stored hypercubes it overlaps (touching counts).

    Two axis-aligned cubes overlap when, on every axis, the distance between their centres
    is at most the sum of their radii.

    centers: (N, D)
    radii: (N,), (N, 1) or (N, D)
    query_centers: (M, D)
    query_radii: (M,), (M, 1) or (M, D)

    Returns an integer tensor of shape (M,). Note that an (M, N, D) intermediate is built.
    """
    N, D = centers.shape
    M, _ = query_centers.shape

    def _per_axis(r, rows):
        # Expand a per-cube radius to one value per axis; (rows, D) input passes through.
        if r.ndim == 1:
            return r.unsqueeze(1).expand(rows, D)
        if r.shape[1] == 1:
            return r.expand(rows, D)
        return r

    box_radii = _per_axis(radii, N)[None, :, :]            # (1, N, D)
    qry_radii = _per_axis(query_radii, M)[:, None, :]      # (M, 1, D)

    gap = (centers[None, :, :] - query_centers[:, None, :]).abs()   # (M, N, D)
    reach = box_radii + qry_radii                                   # (M, N, D)

    touching = (gap <= reach).all(dim=-1)                           # (M, N)
    return touching.sum(dim=1)                                      # (M,)


# Consistency check on the current `points`: a query with radius 0 should intersect no stored
# cube (c == 0) exactly when find_hypercubes reports no match (h == -1).
# Relies on `time` and `points` already being defined by earlier cells.
start = time.time()
c = count_hypercube_intersections(verified_points, verified_radius, points, torch.zeros(points.shape[0], device = 'cuda'))
h = find_hypercubes(points, verified_points, verified_radius)
# Element-wise agreement between the two methods (True where they agree).
print((h == -1) == (c == 0))
print(time.time()-start)

In [None]:
#!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

from torch_scatter import scatter

def find_hypercube_intersections(centers, radii, query_centers, query_radii):
    """
    For each query hypercube, return the indices of the stored hypercubes it intersects.

    Two axis-aligned cubes intersect when, on every axis, the distance between their centres
    is at most the sum of their radii (touching counts as intersecting).

    Inputs:
    - centers: (N, D)
    - radii: (N,), (N, 1) or (N, D)
    - query_centers: (M, D)
    - query_radii: (M,), (M, 1) or (M, D)

    Returns:
    - List of M 1-D integer tensors; entry i holds, in ascending order, the indices of the
      stored hypercubes intersecting query i (empty tensor if there are none).

    Note: an (M, N, D) intermediate is materialised, so memory grows with M * N.
    """
    N, D = centers.shape
    M = query_centers.shape[0]

    if radii.ndim == 1:
        radii = radii.unsqueeze(1).expand(N, D)
    elif radii.shape[1] == 1:
        radii = radii.expand(N, D)

    if query_radii.ndim == 1:
        query_radii = query_radii.unsqueeze(1).expand(M, D)
    elif query_radii.shape[1] == 1:
        query_radii = query_radii.expand(M, D)

    # No queries: nothing to split.
    if M == 0:
        return []

    # Broadcasting
    qc = query_centers.unsqueeze(1)  # (M, 1, D)
    qr = query_radii.unsqueeze(1)    # (M, 1, D)
    c = centers.unsqueeze(0)         # (1, N, D)
    r = radii.unsqueeze(0)           # (1, N, D)

    dist = torch.abs(c - qc)         # (M, N, D)
    threshold = r + qr               # (M, N, D)
    mask = (dist <= threshold).all(dim=-1)  # (M, N)

    # nonzero() lists the True entries in row-major order, so box_idx is already grouped by
    # query (and ascending within each query).
    query_idx, box_idx = torch.nonzero(mask, as_tuple=True)  # (K,), (K,)

    # Number of intersecting boxes per query, including zeros for queries without a hit.
    counts = torch.bincount(query_idx, minlength=M)

    # One device-to-host transfer of the counts and a single split, instead of M separate
    # slices that each convert tensor bounds to Python integers.
    return list(torch.split(box_idx, counts.tolist()))




# Time one call on the current `points`, giving every query cube a radius of 0.1.
# Relies on `time` and `points` already being defined by earlier cells.
start = time.time()
find_hypercube_intersections(verified_points, verified_radius, points, 0.1+torch.zeros(points.shape[0], device = 'cuda'))
print(time.time()-start)

In [None]:
import torch

def find_hypercubes_old(points, centers, radii):
    """
    Batched lookup of the lowest-index hypercube containing each point.

    The points are first passed through wrap_to_pi. A point belongs to a hypercube when it
    lies inside the closed box [center - radius, center + radius] on every axis.

    Args:
        points: Tensor of shape (k, n)
        centers: Tensor of shape (m, n)
        radii: Tensor of shape (m,) or one expandable to (m, n)

    Returns:
        Integer tensor of shape (k,): the matching index per point, or -1 where no
        hypercube contains the point. Builds a (k, m, n) intermediate.
    """
    wrapped = wrap_to_pi(points)
    num_points, num_dims = wrapped.shape
    num_cubes, _ = centers.shape

    half_widths = radii.unsqueeze(1) if radii.dim() == 1 else radii
    half_widths = half_widths.expand(num_cubes, num_dims)

    # (k, 1, n) compared with (m, n) broadcasts to (k, m, n).
    queries = wrapped.unsqueeze(1)
    above_lower = queries >= centers - half_widths
    below_upper = queries <= centers + half_widths
    inside = (above_lower & below_upper).all(dim=2)  # (k, m)

    # Non-matching cubes get the out-of-range index m, so the row minimum is the first match.
    cube_ids = torch.arange(num_cubes, device=queries.device)
    first_hit = torch.where(inside, cube_ids, num_cubes).min(dim=1).values

    # A minimum of m means the point matched nothing; report that as -1.
    return first_hit.masked_fill(first_hit == num_cubes, -1)
