# 3D Particle Structure Generator Exploration

A step-by-step walkthrough of AI-generated 3D particle structures.
We compare the original data against samples from the Baseline VAE and from the Improved VAE, which was trained on sorted data to break permutation symmetry.

In [None]:
import torch
import matplotlib.pyplot as plt
from models.vae import VAE
from data.generator import generate_particles
from utils.metrics import compute_pairwise_distances

# Fix torch's global RNG state so torch-based random draws made later in the
# notebook (such as torch.randn) repeat across runs when cells run top to bottom.
SEED = 42
torch.manual_seed(SEED)

### Model Loading
Two models are available: the Baseline VAE and the Improved Beta-VAE.

In [None]:
# Shared configuration: each structure is n_particles points in 3D, flattened
# into a single vector of length input_dim for the VAE.
n_particles = 50
input_dim = n_particles * 3

# map_location='cpu' lets checkpoints that were saved from a GPU session load on
# a CPU-only machine; the rest of this notebook works with CPU tensors.
# weights_only=True restricts unpickling to tensors and plain containers.

# Baseline model (latent_dim=16; latent samples drawn later must match this size)
baseline_model = VAE(input_dim, hidden_dim=128, latent_dim=16)
baseline_model.load_state_dict(
    torch.load('model.pth', map_location='cpu', weights_only=True)
)
baseline_model.eval()  # inference mode for sampling

# Improved model (latent_dim=32; latent samples drawn later must match this size)
improved_model = VAE(input_dim, hidden_dim=256, latent_dim=32)
improved_model.load_state_dict(
    torch.load('model_improved.pth', map_location='cpu', weights_only=True)
)
improved_model.eval()  # inference mode for sampling
print('Models loaded.')

### Data Generation
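First, a reference structure is generated. Then a random latent vector is drawn from a standard normal distribution and decoded by each VAE. These are new samples, not reconstructions of the reference structure.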

In [None]:
# Reseed here so this cell is reproducible on its own. Model construction in the
# loading cell may already have consumed torch random draws, and re-running this
# cell would otherwise draw fresh latent vectors every time.
torch.manual_seed(42)

# Reference structure to compare against.
# NOTE(review): if generate_particles draws from numpy/random rather than torch,
# the seed above does not make orig_pts reproducible — confirm.
orig_pts = generate_particles(n_particles, min_distance=0.08)

# Draw one latent vector from a standard normal and decode it into n_particles
# 3D points. These are fresh samples from the prior, not reconstructions of
# orig_pts (nothing is encoded).
with torch.no_grad():
    # Latent sizes must match the latent_dim each model was built with (16 / 32).
    z_base = torch.randn(1, 16)
    gen_pts_base = baseline_model.decode(z_base).reshape(n_particles, 3)

    z_improved = torch.randn(1, 32)
    gen_pts_improved = improved_model.decode(z_improved).reshape(n_particles, 3)

### 3D Visualization

In [None]:
# One 3D scatter panel per point cloud: (points, panel title, marker colour).
panels = (
    (orig_pts, 'Original Data', 'b'),
    (gen_pts_base, 'Baseline VAE', 'r'),
    (gen_pts_improved, 'Improved VAE', 'g'),
)

fig = plt.figure(figsize=(15, 5))
for position, (pts, title, colour) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 3, position, projection='3d')
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], alpha=0.6, c=colour)
    ax.set_title(title)
    # Same unit-cube view on every panel so the clouds are directly comparable.
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits(0, 1)

plt.show()

### Distance Evaluation
Quantitative check using pairwise-distance histograms. The more closely a generated distribution overlaps the original one, the better the model reproduces the spacing between particles.

In [None]:
# Pairwise-distance distributions for the reference structure and for one sample
# from each model.
# NOTE(review): assumes compute_pairwise_distances returns a torch tensor
# (.numpy() is called on it) — confirm whether it holds the condensed pair list
# or a full N x N matrix including zero self-distances.
orig_dists = compute_pairwise_distances(orig_pts)
gen_dists_base = compute_pairwise_distances(gen_pts_base)
gen_dists_impr = compute_pairwise_distances(gen_pts_improved)

# Flatten defensively so hist() always treats each input as a single dataset.
orig_np = orig_dists.numpy().ravel()
base_np = gen_dists_base.numpy().ravel()
impr_np = gen_dists_impr.numpy().ravel()

# Use one shared range so every histogram has identical bin edges. With bins=30
# alone, each call picks its own edges from its own min/max, and the visual
# overlap between original and generated distributions would be misleading.
all_dists = (orig_np, base_np, impr_np)
dist_range = (
    min(float(d.min()) for d in all_dists),
    max(float(d.max()) for d in all_dists),
)
hist_kwargs = dict(bins=30, range=dist_range, alpha=0.5)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

for ax, gen_np, label, colour in (
    (ax1, base_np, 'Baseline', 'r'),
    (ax2, impr_np, 'Improved', 'g'),
):
    ax.hist(orig_np, label='Original', **hist_kwargs)
    ax.hist(gen_np, label=label, color=colour, **hist_kwargs)
    ax.set_title(f'{label} Comparison')
    ax.set_xlabel('Distance')
    ax.legend()

plt.show()