In [60]:
import torch
import time

device = "cuda"
N = 10_000_000
S = 1_000_000
points = torch.rand(N, 3, device=device)

start = time.perf_counter()
indices = torch.randint(0, N, (S,))  # Sample indices directly
subset = points[indices]
end = time.perf_counter()
print(f"Time taken: {(end - start) * 1000:.6f} ms")

torch.cuda.synchronize()  # Ensure previous operations are done
start = time.perf_counter()
points[torch.randperm(N, device=device)[:S]]
torch.cuda.synchronize()  # Ensure operation is completed
end = time.perf_counter()
print(f"Time taken: {(end - start) * 1000:.6f} ms")

# start = time.perf_counter()
# indices = torch.multinomial(torch.ones(N), S, replacement=False)  # Sample without replacement
# subset = points[indices]
# end = time.perf_counter()
# print(f"Time taken: {(end - start) * 1000:.6f} ms")

Time taken: 30.559619 ms
Time taken: 7.338509 ms


In [None]:
import torch
from neural_poisson.data.prepare import select_random_points

device = "cuda"
N = 100_000
S = 100_000
points = torch.rand(N, 3, device=device)

torch.cuda.synchronize()
mem_before = torch.cuda.memory_allocated(device)

# Run function
points = select_random_points(points=points, max_samples=S)

torch.cuda.synchronize()
mem_after = torch.cuda.memory_allocated(device)
print(f"Memory used after: {mem_after / 1e6:.2f} MB")

# Free unused cached memory
torch.cuda.empty_cache()

print(f"Memory used before: {mem_before / 1e6:.2f} MB")
print(f"Memory difference: {(mem_after - mem_before) / 1e6:.2f} MB")


Memory used before: 1.20 MB
Memory used after: 1.20 MB
Memory difference: 0.00 MB


In [1]:
import torch
import math
from collections import defaultdict
from neural_poisson.data.prepare import compute_chunks

num_chunks = 4
chunk_size = 50_000
points = torch.rand(200_000, 3) 
values = [points]
c1, c2 = compute_chunks(num_chunks=num_chunks, chunk_size=chunk_size, values=[points, points])


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


tensor([[0.2134, 0.0792, 0.3794],
        [0.1790, 0.8268, 0.6921],
        [0.3217, 0.7847, 0.8935],
        ...,
        [0.9791, 0.7663, 0.2652],
        [0.6422, 0.0134, 0.0133],
        [0.0294, 0.0601, 0.4536]])