In [1]:
pip install torch open3d tqdm numpy matplotlib

Collecting torch
  Downloading torch-2.7.1-cp39-cp39-win_amd64.whl (216.0 MB)
     -------------------------------------- 216.0/216.0 MB 4.0 MB/s eta 0:00:00
Collecting open3d
  Downloading open3d-0.19.0-cp39-cp39-win_amd64.whl (69.5 MB)
     ---------------------------------------- 69.5/69.5 MB 3.0 MB/s eta 0:00:00
Collecting sympy>=1.13.3
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ---------------------------------------- 6.3/6.3 MB 5.7 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.18.0-py3-none-any.whl (16 kB)
Collecting networkx
  Downloading networkx-3.2.1-py3-none-any.whl (1.6 MB)
     ---------------------------------------- 1.6/1.6 MB 7.5 MB/s eta 0:00:00
Collecting fsspec
  Downloading fsspec-2025.5.1-py3-none-any.whl (199 kB)
     -------------------------------------- 199.1/199.1 kB 2.4 MB/s eta 0:00:00
Collecting typing-extensions>=4.10.0
  Downloading typing_extensions-4.14.0-py3-none-any.whl (43 kB)
     --------------------------------------


[notice] A new release of pip available: 22.2.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Prepare ShapeNet Data:
Download ShapeNet point clouds, or sample them from meshes with a tool such as trimesh.
Save each point cloud as an (N, 3) array in its own .npy file, organised as <root>/<class_name>/<file>.npy (one sub-directory per class).

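Training:
Replace "path_to_shapenet_point_clouds" in the dataset initialization with the path to your data directory.
Run every cell that defines the dataset, the models, and the helpers (including save_point_cloud), then call train_point_cloud_gan().
Monitor the printed losses and inspect the samples written to generated_samples/.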

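Generation:
After training, load the generator:
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 128  # must match the value used during training
generator = Generator(latent_dim=latent_dim).to(device)
generator.load_state_dict(torch.load("generator.pth", map_location=device))
generator.eval()  # use the BatchNorm running statistics; required for a batch of one
```
Generate new point clouds:
```python
with torch.no_grad():  # the output must not require grad before calling .numpy()
    z = torch.randn(1, latent_dim).to(device)
    new_pc = generator(z).cpu().numpy()[0]
visualize_point_cloud(new_pc)
```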

In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class PointCloudDataset(Dataset):
    """Fixed-size, normalized point clouds loaded from ``.npy`` files.

    The files are discovered under ``root_dir/<class_dir>/*.npy``. Each class
    directory is split 80/20 into train/test on its *sorted* file list, so the
    split is reproducible and both splits contain samples of every class
    (``os.listdir`` alone returns entries in arbitrary order).
    """

    def __init__(self, root_dir, num_points=2048, split='train'):
        """
        Args:
            root_dir (string): Directory with one sub-directory per class, each
                holding point clouds as .npy files
            num_points (int): Number of points to sample from each cloud
            split (str): 'train' selects the first 80% of every class; any
                other value selects the remaining 20%
        """
        self.root_dir = root_dir
        self.num_points = num_points
        self.files = []

        for class_dir in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, class_dir)
            if not os.path.isdir(class_path):
                continue

            class_files = sorted(
                os.path.join(class_path, name)
                for name in os.listdir(class_path)
                if name.endswith('.npy')
            )
            # Split inside each class so the test split is not made up of
            # whichever classes happen to be listed last.
            split_idx = int(0.8 * len(class_files))
            if split == 'train':
                self.files.extend(class_files[:split_idx])
            else:
                self.files.extend(class_files[split_idx:])

    def __len__(self):
        """Number of point-cloud files in the selected split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load cloud ``idx``, resample it to ``num_points`` rows and normalize it.

        Returns:
            torch.FloatTensor with ``num_points`` rows.

        Raises:
            ValueError: if the file contains no points.
        """
        pc = np.load(self.files[idx])  # Shape: (N, 3)
        available = pc.shape[0]
        if available == 0:
            raise ValueError(f"Point cloud file {self.files[idx]!r} contains no points")

        if available != self.num_points:
            # Subsample without replacement when there are enough points,
            # otherwise pad by drawing with replacement.
            idxs = np.random.choice(
                available, self.num_points, replace=available < self.num_points
            )
            pc = pc[idxs]

        pc = torch.from_numpy(pc).float()
        return self.normalize_point_cloud(pc)

    def normalize_point_cloud(self, pc):
        """Center the cloud on its centroid and scale it into the unit sphere.

        After scaling, the farthest point lies at distance 1 from the origin,
        so every coordinate is within [-1, 1]. A degenerate cloud whose points
        all coincide is returned centered (all zeros) instead of being divided
        by zero.
        """
        centroid = torch.mean(pc, dim=0)
        pc = pc - centroid
        max_dist = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=1)))
        if max_dist > 0:
            pc = pc / max_dist
        return pc

In [2]:
import torch.nn as nn

class Generator(nn.Module):
    """Fully connected generator that maps a latent vector to a point cloud.

    The network is a stack of Linear -> BatchNorm1d -> ReLU stages of widths
    256, 512 and 1024, followed by a Linear layer with ``num_points * 3``
    outputs and a Tanh, so every generated coordinate lies in [-1, 1].
    """

    def __init__(self, latent_dim=128, num_points=2048):
        super().__init__()
        self.num_points = num_points
        self.latent_dim = latent_dim

        # Layers are appended in the same order as a hand-written Sequential,
        # so the sub-module indices (and state_dict keys) are model.0 ... model.10.
        widths = (latent_dim, 256, 512, 1024)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ))
        stages.extend((nn.Linear(widths[-1], num_points * 3), nn.Tanh()))
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return a (batch_size, num_points, 3) tensor for latent input ``z``."""
        flat_coords = self.model(z)
        return flat_coords.view(z.size(0), self.num_points, 3)

In [3]:
class Discriminator(nn.Module):
    """PointNet-style critic that scores a whole point cloud.

    Per-point features are computed with 1x1 convolutions (a shared MLP),
    max-pooled over the points into one 512-d global feature and reduced to
    a single score per cloud.

    Args:
        num_points (int): Stored on the module; the layers themselves do not
            depend on it.
        apply_sigmoid (bool): When False (default) the raw, unbounded score is
            returned, which is what a Wasserstein loss with gradient penalty
            needs. Pass True to squash the score into (0, 1) for a
            probability-style (BCE) discriminator.
    """

    def __init__(self, num_points=2048, apply_sigmoid=False):
        super().__init__()
        self.num_points = num_points
        self.apply_sigmoid = apply_sigmoid

        # Shared MLP layers (similar to PointNet)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)

        # NOTE(review): the WGAN-GP paper advises against batch normalization
        # in the critic because the penalty is defined per sample; the layers
        # are kept here so existing checkpoints keep the same parameter layout.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)

        # Global feature layer
        self.global_conv = nn.Conv1d(256, 512, 1)
        self.global_bn = nn.BatchNorm1d(512)

        # Classification layers
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of clouds.

        Args:
            x: tensor of shape (batch_size, num_points, 3).

        Returns:
            Tensor of shape (batch_size, 1): raw scores, or values in (0, 1)
            when ``apply_sigmoid`` is True.
        """
        x = x.transpose(2, 1)  # (batch_size, 3, num_points)

        # Shared MLP
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))

        # Global features: max over the point axis -> (batch_size, 512)
        x = self.relu(self.global_bn(self.global_conv(x)))
        x = torch.max(x, 2)[0]

        # Classification
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        if self.apply_sigmoid:
            x = self.sigmoid(x)
        return x

In [4]:
def gradient_penalty(discriminator, real_samples, fake_samples, device):
    """Calculates the gradient penalty loss for WGAN-GP.

    The critic is evaluated on random per-sample interpolations between real
    and fake inputs, and the squared deviation of the input-gradient L2 norm
    from 1 is averaged over the batch.

    Args:
        discriminator: callable mapping a batch of samples to critic scores.
        real_samples: batch of real samples, batch dimension first.
        fake_samples: batch of generated samples with the same shape.
        device: device on which the interpolation weights are created.

    Returns:
        Scalar tensor holding the penalty; it stays attached to the autograd
        graph (``create_graph=True``) so it can be back-propagated into the
        critic parameters.
    """
    real = real_samples.detach()
    fake = fake_samples.detach()
    batch_size = real.size(0)

    # One interpolation weight per sample, broadcast over all remaining dims.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    d_interpolates = discriminator(interpolates)

    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # reshape (not view): the gradient may come back non-contiguous, e.g. when
    # the critic transposes its input, and view() raises on such layouts.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

def train_point_cloud_gan(
    data_root="path_to_shapenet_point_clouds",
    epochs=200,
    batch_size=32,
    latent_dim=128,
    num_points=2048,
    lr=0.0001,
    n_critic=5,
    lambda_gp=10,
):
    """Train the point-cloud GAN with a Wasserstein loss and gradient penalty.

    Every batch updates the discriminator; the generator is updated on every
    ``n_critic``-th batch. One generated sample is written to
    ``generated_samples/sample_<epoch>.ply`` every 10 epochs, and both
    state dicts are saved to ``generator.pth`` / ``discriminator.pth`` at
    the end.

    Args:
        data_root: directory passed to ``PointCloudDataset``.
        epochs: number of passes over the training split.
        batch_size: DataLoader batch size (must be at least 2).
        latent_dim: size of the generator's latent vector.
        num_points: points per cloud.
        lr: Adam learning rate for both networks.
        n_critic: the generator is updated once every ``n_critic`` batches.
        lambda_gp: gradient penalty coefficient.

    Raises:
        ValueError: if ``batch_size`` is below 2 or the training split holds
            fewer than 2 clouds.
    """
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 (the generator uses BatchNorm1d)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create dataset and dataloader
    dataset = PointCloudDataset(data_root, num_points=num_points)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 training point clouds under {data_root!r}, found {len(dataset)}"
        )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize networks
    generator = Generator(latent_dim, num_points).to(device)
    discriminator = Discriminator(num_points).to(device)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    # The sample files below are written into this directory.
    os.makedirs("generated_samples", exist_ok=True)

    # Training loop
    for epoch in range(epochs):
        for i, real_pcs in enumerate(tqdm(dataloader)):
            real_pcs = real_pcs.to(device)
            current_batch = real_pcs.size(0)
            if current_batch < 2:
                # The generator's BatchNorm1d layers cannot compute batch
                # statistics from a single sample in training mode.
                continue

            # Train Discriminator
            optimizer_D.zero_grad()

            # Fake clouds for the critic step; no generator graph is needed here.
            z = torch.randn(current_batch, latent_dim, device=device)
            with torch.no_grad():
                fake_pcs = generator(z)

            # Real and fake losses
            real_loss = -torch.mean(discriminator(real_pcs))
            fake_loss = torch.mean(discriminator(fake_pcs))

            # Gradient penalty
            gp = gradient_penalty(discriminator, real_pcs, fake_pcs, device)

            # Total discriminator loss
            d_loss = real_loss + fake_loss + lambda_gp * gp
            d_loss.backward()
            optimizer_D.step()

            # Train Generator every n_critic steps
            if i % n_critic == 0:
                optimizer_G.zero_grad()

                z = torch.randn(current_batch, latent_dim, device=device)
                fake_pcs = generator(z)

                # Generator loss
                g_loss = -torch.mean(discriminator(fake_pcs))
                g_loss.backward()
                optimizer_G.step()

        # Print progress
        print(f"[Epoch {epoch}/{epochs}] D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")

        # Save generated samples periodically
        if epoch % 10 == 0:
            # eval() makes BatchNorm use its running statistics, which is
            # required for the single-sample batch generated here.
            generator.eval()
            with torch.no_grad():
                sample_z = torch.randn(1, latent_dim, device=device)
                sample_pc = generator(sample_z).cpu().numpy()
            generator.train()
            save_point_cloud(sample_pc[0], f"generated_samples/sample_{epoch}.ply")

    # Save models
    torch.save(generator.state_dict(), "generator.pth")
    torch.save(discriminator.state_dict(), "discriminator.pth")

In [5]:
import open3d as o3d

def save_point_cloud(points, filename):
    """Save a point cloud to ``filename`` (e.g. a PLY file) with Open3D.

    Args:
        points: array-like of 3-D coordinates, one row per point.
        filename: output path; its parent directory is created if missing.

    Returns:
        The success flag returned by ``o3d.io.write_point_cloud``.
    """
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    pcd = o3d.geometry.PointCloud()
    # Open3D documents Vector3dVector as taking a float64 (n, 3) array.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points, dtype=np.float64))
    return o3d.io.write_point_cloud(filename, pcd)

def visualize_point_cloud(points):
    """Show ``points`` as an Open3D point cloud via ``draw_geometries``."""
    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
    o3d.visualization.draw_geometries([cloud])

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
