In [None]:
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import os
import numpy as np
import imageio

OUTPUT_DIR = os.path.abspath("./data")

# Prefer the GPU when one is available; all simulation tensors target `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using torch device: {device}")

WIDTH = 512  # Side length of the cube the initial positions are drawn from

SEED = 0
# Seed torch's generators (all devices) so a Restart & Run All is reproducible.
# The trailing semicolon stops Jupyter from echoing the returned torch.Generator.
torch.manual_seed(SEED);

Using torch device: cuda


<torch._C.Generator at 0x1c6a3311230>

In [19]:
def get_grav_acc(x, m, G, epsilon=8):
    """
    Calculate the softened gravitational acceleration of every point.

    The acceleration of point j is

        a_j = G * sum_i m_i * (x_i - x_j) / (|x_i - x_j|^2 + epsilon^2)^(3/2)

    i.e. Newtonian gravity in which `epsilon` smooths out close encounters so
    they do not produce arbitrarily large accelerations.

    arguments:
        x:       (n, d) tensor of positions (d = 3 in this notebook)
        m:       (n,) tensor of masses, on the same device as x
        G:       gravitational constant
        epsilon: softening length (default 8). Zero-distance pairs, including
                 each point with itself, contribute nothing even when
                 epsilon == 0.

    returns:
        (n, d) tensor of accelerations.
    """
    n = x.shape[0]
    assert x.ndim == 2 and m.shape == (n,)

    # dx[i, j] = x_i - x_j: displacement from target j towards source i.
    dx = x[:, None, :] - x[None, :, :]  # (n, n, d)
    dist_sq = torch.sum(dx * dx, dim=-1)  # (n, n)

    d3 = (dist_sq + epsilon**2) ** 1.5
    # Pairs at zero distance (the diagonal, or any pair when epsilon == 0) have
    # dx == 0; give them weight 0 instead of producing 0 / 0 = NaN.
    inv_d3 = torch.where(d3 > 0, d3.reciprocal(), torch.zeros_like(d3))

    # Sum the pull of every source i on each target j (reduce over dim 0).
    # G * m_i * dx / r^3 is already an acceleration (force / m_j), so it must
    # NOT be divided by the target mass m_j again.
    return G * torch.sum(m[:, None, None] * dx * inv_d3[:, :, None], dim=0)

In [20]:
def update_system(x, v, m, dt, G):
    """
    Advance the particle system by a single time step of size `dt`.

    Uses the velocity Verlet scheme
    (https://en.wikipedia.org/wiki/Verlet_integration#Velocity_Verlet):
    positions are moved using the current acceleration, then velocities are
    updated with the average of the accelerations before and after the move.

    arguments:
        x:  (n, 3) tensor of positions
        v:  (n, 3) tensor of velocities
        m:  (n,) tensor of masses
        dt: time step
        G:  gravitational constant

    returns:
        (x_next, v_next), the (n, 3) positions and velocities after the step.
    """
    num_points = x.shape[0]
    expected = (num_points, 3)
    assert x.shape == expected and v.shape == expected and m.shape == (num_points,)

    acc_before = get_grav_acc(x, m, G)
    x_next = x + v * dt + 0.5 * acc_before * dt**2

    acc_after = get_grav_acc(x_next, m, G)
    v_next = v + 0.5 * (acc_before + acc_after) * dt

    return x_next, v_next

In [21]:
def generate_timeline(x0, v0, m, G, dt, F):
    """
    Generate one timeline of frames for gravity simulation.

    Frame 0 holds the initial state; every later frame is one `update_system`
    step of size `dt` from the frame before it. The simulation runs on the
    module-level `device` and the results are returned on the CPU.

    arguments:
        x0: (n, 3) tensor of initial positions
        v0: (n, 3) tensor of initial velocities
        m:  (n,) tensor of masses
        G:  gravitational constant
        dt: timestep per frame
        F:  number of frames to produce (>= 1; frame 0 is x0 / v0)

    returns:
        dict with keys
            "dt", "G": the parameters passed in,
            "m": (n,) CPU tensor of masses,
            "X": (F, n, 3) CPU tensor of positions per frame,
            "V": (F, n, 3) CPU tensor of velocities per frame.
    """
    n = x0.shape[0]
    assert x0.shape == (n, 3) and v0.shape == (n, 3) and m.shape == (n,)
    assert F >= 1, "F must be at least 1 (frame 0 holds the initial state)"

    # Allocate straight on the target device rather than building the buffers
    # on the CPU and copying them over with .to(device).
    X = torch.zeros((F, n, 3), device=device)
    V = torch.zeros((F, n, 3), device=device)

    # Slice assignment copies x0 / v0 onto the buffers' device.
    X[0] = x0
    V[0] = v0
    m = m.to(device)

    for i in range(1, F):
        X[i], V[i] = update_system(X[i - 1], V[i - 1], m, dt, G)

    return {
        "dt": dt,
        "G": G,
        "m": m.cpu(),
        "X": X.cpu(),
        "V": V.cpu(),
    }

In [80]:
F = 512             # Frames per timeline
dt = 0.5            # Timestep per frame
G = 100             # Gravitational constant
n_samples = 10      # Number of timelines to generate

n = 512             # Number of particles

# Re-seed (same seed as the config cell) so this cell draws identical initial
# conditions no matter how often or in which order earlier cells were run.
torch.manual_seed(0)

# Initial state: positions uniform in [0, WIDTH)^3, velocities Gaussian with
# std 2, masses exp(N(1, 0.5^2)) so every mass is strictly positive.
x0 = torch.rand((n, 3)) * WIDTH
v0 = torch.randn((n, 3)) * 2
m = torch.exp(torch.randn((n,)) * 0.5 + 1)

x, v = x0, v0

In [None]:
FPS = 30                # Frames per second of the output video
DURATION_S = 10         # Length of the output video in seconds
VIDEO_PATH = "3d_test.mp4"

# Always start from the initial conditions so re-running this cell reproduces
# the same video instead of continuing from wherever the last run stopped.
# The simulation runs on `device`; positions are copied to the CPU only to plot.
x, v = x0.to(device), v0.to(device)
m_sim = m.to(device)
marker_sizes = torch.sqrt(m).numpy()  # constant per particle, so compute once

frames = []

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.set_facecolor("black")
# Figure-level layout does not change between frames; set it once.
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

for _ in tqdm(range(FPS * DURATION_S)):
    x, v = update_system(x, v, m_sim, dt, G)

    ax.clear()
    ax.set_xlim(0, WIDTH)
    ax.set_ylim(0, WIDTH)
    ax.set_zlim(0, WIDTH)
    pos = x.cpu().numpy()
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], color="white", s=marker_sizes)
    ax.set_axis_off()

    fig.canvas.draw()
    # buffer_rgba() already has shape (height, width, 4) in physical pixels, so
    # no reshape via the (logical, HiDPI-scaled) get_width_height() is needed.
    # Drop alpha for the MP4 encoder and copy, because the buffer is reused by
    # the next draw.
    frames.append(np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy())

plt.close(fig)
print("Finished rendering, saving to MP4...")

# Encode the collected frames as an MP4 video.
imageio.mimsave(VIDEO_PATH, frames, fps=FPS)

100%|██████████| 300/300 [00:08<00:00, 36.27it/s]


Finished rendering, saving to MP4...
