# 3D Particle Structure Generator Exploration
Welcome! This notebook walks you, step by step, through how our generative models create 3D particle configurations inside a box.
We compare the original data against our Baseline VAE and the Improved VAE (where the data was sorted to break permutation symmetry).

In [None]:
import torch
import matplotlib.pyplot as plt
from models.vae import VAE
from data.generator import generate_particles
from utils.metrics import compute_pairwise_distances

# Fix torch's global RNG so the latent samples drawn later in the notebook are
# repeatable across "Restart & Run All".
# NOTE(review): only torch is seeded here; if generate_particles draws from
# numpy or `random`, the reference data will still vary — confirm.
SEED = 42
torch.manual_seed(SEED)

### Load the models
We trained two versions: the Baseline VAE and the Improved Beta-VAE.

In [None]:
# Shared geometry: every structure has n_particles points with 3 coordinates
# each, flattened into one input vector of length input_dim for the VAE.
n_particles = 50
input_dim = n_particles * 3


def _load_vae(checkpoint_path, hidden_dim, latent_dim):
    """Build a VAE, restore its weights from a checkpoint, and put it in eval mode.

    Args:
        checkpoint_path: Path of the saved state dict to restore.
        hidden_dim: Hidden width passed to the `VAE` constructor.
        latent_dim: Latent width passed to the `VAE` constructor.

    Returns:
        The `VAE` instance with the checkpoint weights loaded, in eval mode.
    """
    model = VAE(input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
    # map_location='cpu' lets a checkpoint that was saved from a GPU run be
    # deserialized on a CPU-only machine. The notebook never moves the models
    # off the CPU, so this does not change where inference happens.
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# The architecture arguments must match the ones each checkpoint was trained
# with, otherwise load_state_dict raises on mismatched parameter shapes.
baseline_model = _load_vae('model.pth', hidden_dim=128, latent_dim=16)
improved_model = _load_vae('model_improved.pth', hidden_dim=256, latent_dim=32)
print('Models loaded successfully!')

### Generate the data
We generate reference data, then sample random latent vectors to ask the VAEs to build new particle setups.

In [None]:
# Reference ("real") structure from the project's generator.
# NOTE(review): min_distance presumably enforces a minimum particle spacing —
# confirm against data.generator.
orig_pts = generate_particles(n_particles, min_distance=0.08)

# Decoder output is viewed as one row per particle, one column per coordinate.
coords_shape = (n_particles, 3)

# Pure inference: no gradients are needed for sampling.
with torch.no_grad():
    # Baseline model: a single latent draw of width 16 (its latent_dim).
    z_base = torch.randn(1, 16)
    flat_base = baseline_model.decode(z_base)
    gen_pts_base = flat_base.view(coords_shape)

    # Improved model: a single latent draw of width 32 (its latent_dim).
    z_improved = torch.randn(1, 32)
    flat_improved = improved_model.decode(z_improved)
    gen_pts_improved = flat_improved.view(coords_shape)

### Plot the 3D results visually

In [None]:
# One row of three 3D panels: reference data, baseline sample, improved sample.
fig = plt.figure(figsize=(15, 5))
ax1, ax2, ax3 = (fig.add_subplot(1, 3, idx, projection='3d') for idx in (1, 2, 3))

# (axis, points, panel title, marker colour) for each panel.
panels = (
    (ax1, orig_pts, 'Original Data', 'b'),
    (ax2, gen_pts_base, 'Baseline VAE Generated', 'r'),
    (ax3, gen_pts_improved, 'Improved VAE Generated', 'g'),
)

for ax, pts, title, colour in panels:
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], alpha=0.6, c=colour)
    ax.set_title(title)
    # Same fixed [0, 1] limits on every axis so the three panels share a frame.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_zlim(0, 1)

plt.show()

### Pairwise Distance Evaluation
Visually, all three just look like boxes of dots. To quantify how well the generated structures reproduce the particle spacing of the original data, we compare histograms of the distances between *every pair of particles*.

In [None]:
# Pairwise distances for the reference structure and for each generated one.
orig_dists = compute_pairwise_distances(orig_pts)
gen_dists_base = compute_pairwise_distances(gen_pts_base)
gen_dists_impr = compute_pairwise_distances(gen_pts_improved)

# Convert once and reuse for both the range computation and the plots.
orig_np = orig_dists.numpy()
base_np = gen_dists_base.numpy()
impr_np = gen_dists_impr.numpy()

# Use ONE set of bin edges for every histogram. With a bare `bins=30`, each
# call derives its edges from its own data range, so the overlaid bars end up
# with different widths and positions and their heights are not comparable.
n_bins = 30
dist_range = (
    float(min(orig_np.min(), base_np.min(), impr_np.min())),
    float(max(orig_np.max(), base_np.max(), impr_np.max())),
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# (axis, generated distances, legend label, colour, panel title) per panel.
comparisons = (
    (ax1, base_np, 'Baseline', 'r', 'Baseline Comparison'),
    (ax2, impr_np, 'Improved', 'g', 'Improved Comparison'),
)

for ax, gen_np, label, colour, title in comparisons:
    ax.hist(orig_np, bins=n_bins, range=dist_range, alpha=0.5, label='Original')
    ax.hist(gen_np, bins=n_bins, range=dist_range, alpha=0.5, label=label, color=colour)
    ax.set_title(title)
    ax.set_xlabel('Distance')
    ax.set_ylabel('Count')
    ax.legend()

plt.show()