In [None]:
"""
Text to 3D Point Cloud Generator

This script generates detailed 3D point clouds from text descriptions using
a combination of Stable Diffusion for image generation and MiDaS for depth estimation.
The output is a complete, textured point cloud in PLY format.

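Requirements:
- torch
- numpy
- matplotlib
- open3d
- timm
- Pillow (imported as PIL)
- opencv-python (imported as cv2)
- omegaconf
- stable-diffusion (CompVis; provides the `ldm` package)
- MiDaS (source code and weights are downloaded at runtime)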
"""

# Notebook shell magics: install the runtime dependencies into the kernel's environment.
!pip install -q torch torchvision
!pip install -q open3d timm
!pip install -q opencv-python matplotlib
# NOTE(review): the captured cell output ends with "No module named 'ldm'", so this
# install did not make `ldm` importable in that run - confirm how the package should be installed.
!pip install -q git+https://github.com/CompVis/stable-diffusion.git@main
# NOTE(review): nothing visible in this cell imports pytorch3d, and the captured output shows
# this install failing with "Could not find a version that satisfies the requirement pytorch3d".
!pip install -q pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/$(python -c "import torch; print(f'torch-{torch.__version__}_cu' + torch.version.cuda.replace('.','') + '_pyt' + torch.__version__.split('+')[0])").wheels.html

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
from PIL import Image
import cv2
import time
import requests
import urllib.request
from io import BytesIO
import zipfile
from IPython.display import display, HTML

# Required packages are installed by the `!pip install` lines at the top of this cell.

# Download MiDaS (source tree + pretrained weights) next to the notebook.
print("Downloading MiDaS model...")
midas_path = "MiDaS"
midas_archive = "midas.zip"
midas_weights = os.path.join(midas_path, "weights", "dpt_beit_large_512.pt")

# Fetch the source tree. The target directory is created only by the final
# rename, so an interrupted download or extraction does not leave behind an
# empty "MiDaS" folder that would make later runs skip this step.
if not os.path.exists(midas_path):
    urllib.request.urlretrieve(
        "https://github.com/isl-org/MiDaS/archive/refs/heads/master.zip",
        midas_archive,
    )

    with zipfile.ZipFile(midas_archive, 'r') as zip_ref:
        zip_ref.extractall(".")

    # The archive unpacks to "MiDaS-master"; give it the expected name.
    os.rename("MiDaS-master", midas_path)
    os.remove(midas_archive)

# Fetch the weights independently of the source tree so that a run which
# failed after the code was extracted still retries the weights download.
if not os.path.exists(midas_weights):
    os.makedirs(os.path.dirname(midas_weights), exist_ok=True)
    # Download to a temporary name and move it into place afterwards, so a
    # truncated file is never mistaken for a complete checkpoint.
    partial_weights = midas_weights + ".part"
    urllib.request.urlretrieve(
        "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
        partial_weights,
    )
    os.replace(partial_weights, midas_weights)

# Add MiDaS to the import path (once, even if the cell is re-run).
import sys
if midas_path not in sys.path:
    sys.path.append(midas_path)

# Import MiDaS modules
from MiDaS.midas.model_loader import load_model
from MiDaS.midas.transforms import Resize, NormalizeImage, PrepareForNet
from MiDaS.midas.dpt_depth import DPTDepthModel

# Import Stable Diffusion
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf

class TextTo3DPointCloud:
    """Generate a coloured 3D point cloud from a text prompt.

    Pipeline: Stable Diffusion renders one image per view prompt, MiDaS
    estimates a depth map for each image, every depth map is back-projected
    into 3D and rotated according to its view index, and the merged cloud is
    cleaned, re-sampled from a Poisson surface and written to a PLY file.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Store the generation settings and load both models.

        Args:
            device: Torch device the models are moved to.
        """
        self.device = device
        self.resolution = 512
        self.num_inference_steps = 50
        self.strength = 0.75  # not read by any method of this class
        self.angle_views = 8  # Number of views to generate
        self.setup_models()

    @staticmethod
    def _to_rgb_array(image):
        """Return *image* as an (H, W, 3) array, whatever its channel count.

        Greyscale input is replicated across three channels and an alpha
        channel is dropped, so callers can always reshape to (-1, 3).

        Raises:
            ValueError: If the input is neither 2-D nor has >= 3 channels.
        """
        pixels = np.asarray(image)
        if pixels.ndim == 2:
            pixels = np.stack((pixels,) * 3, axis=-1)
        if pixels.ndim != 3 or pixels.shape[2] < 3:
            raise ValueError(
                f"Expected a greyscale, RGB or RGBA image, got shape {pixels.shape}")
        return pixels[..., :3]

    def setup_models(self):
        """Set up the necessary models for text-to-image and depth estimation."""
        print("Setting up models...")

        # Download and load the Stable Diffusion model
        print("Setting up Stable Diffusion model...")
        sd_config = OmegaConf.load("stable-diffusion/configs/stable-diffusion/v1-inference.yaml")
        # OmegaConf.update creates any missing intermediate node, so this works
        # whether or not the config already has a `params` section for the
        # text encoder (plain attribute assignment fails when it is absent).
        encoder_params = "model.params.cond_stage_config.params"
        OmegaConf.update(sd_config, f"{encoder_params}.version", "openai/clip-vit-large-patch14")
        # NOTE(review): assumes the text encoder accepts a `device` argument
        # (as the CompVis FrozenCLIPEmbedder does) - confirm against the
        # installed `ldm` package.
        OmegaConf.update(sd_config, f"{encoder_params}.device", str(self.device))

        # Download Stable Diffusion checkpoint if not exists
        if not os.path.exists("sd-v1-4.ckpt"):
            print("Downloading Stable Diffusion checkpoint...")
            urllib.request.urlretrieve(
                "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
                "sd-v1-4.ckpt"
            )

        # Load Stable Diffusion model. map_location="cpu" lets a checkpoint
        # saved from a GPU be deserialised on a CPU-only machine; the model is
        # moved to the requested device right after.
        sd_model = instantiate_from_config(sd_config.model)
        checkpoint = torch.load("sd-v1-4.ckpt", map_location="cpu")
        sd_model.load_state_dict(checkpoint["state_dict"], strict=False)
        sd_model.to(self.device)
        sd_model.eval()
        self.sd_model = sd_model
        self.sampler = DDIMSampler(sd_model)

        # Load MiDaS model
        print("Setting up MiDaS depth estimation model...")
        model_type = "dpt_beit_large_512"  # Model type
        model_path = os.path.join("MiDaS", "weights", "dpt_beit_large_512.pt")  # Path to model weights

        # MiDaS' load_model takes the device first and returns
        # (model, transform, net_w, net_h); only the network is kept because
        # the transform is rebuilt below from self.resolution.
        # NOTE(review): signature as in MiDaS v3.1 `midas/model_loader.py` -
        # confirm against the downloaded sources.
        self.midas_model, _, _, _ = load_model(
            self.device, model_path, model_type, optimize=False)
        self.midas_model.eval()

        # The MiDaS transforms are plain callables operating on a
        # {"image": ndarray} dict, not nn.Modules, so they are chained with a
        # small closure instead of torch.nn.Sequential.
        midas_steps = (
            Resize(
                width=self.resolution,
                height=self.resolution,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        )

        def apply_midas_steps(sample):
            for step in midas_steps:
                sample = step(sample)
            return sample

        self.midas_transform = apply_midas_steps

        print("Models loaded successfully!")

    def generate_multiview_images(self, prompt, negative_prompt=None):
        """Generate multiple views of the same object from different angles.

        Args:
            prompt: Text description of the object.
            negative_prompt: Text used as the unconditional conditioning;
                a built-in default is used when it is None or empty.

        Returns:
            A pair ``(images, angle_prompts)``: the generated PIL images and
            the view label used for each of them, in the same order.
        """
        images = []
        angle_prompts = []

        base_prompt = prompt
        negative_prompt = negative_prompt or "blurry, low quality, incomplete, cropped, disfigured, deformed"

        # Create angle prompts
        view_angles = ["front view", "back view", "side view", "top view", "bottom view",
                       "45 degree angle view", "three-quarter view", "isometric view"]

        # Take only the specified number of angles
        view_angles = view_angles[:self.angle_views]

        # Latent shape: 4 channels at 1/8 of the image resolution.
        shape = [4, self.resolution // 8, self.resolution // 8]

        # Generate images for each view
        print("Generating multiple views...")
        # Inference only: make sure no autograd graph is kept alive.
        with torch.no_grad():
            # The negative conditioning is identical for every view.
            uc = self.sd_model.get_learned_conditioning([negative_prompt])

            for angle in view_angles:
                print(f"Generating {angle}...")
                angle_prompt = f"{base_prompt}, {angle}, detailed, high quality, 3D object, centered"
                c = self.sd_model.get_learned_conditioning([angle_prompt])

                # Sample image
                samples, _ = self.sampler.sample(
                    S=self.num_inference_steps,
                    conditioning=c,
                    batch_size=1,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=7.5,
                    unconditional_conditioning=uc,
                    eta=0.0
                )

                # Decode samples and map them from [-1, 1] to [0, 1]
                x_samples = self.sd_model.decode_first_stage(samples)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                # Convert to PIL image (CHW float -> HWC uint8)
                x_sample = x_samples[0].cpu().numpy()
                x_sample = (255. * x_sample.transpose(1, 2, 0)).clip(0, 255).astype(np.uint8)

                images.append(Image.fromarray(x_sample))
                angle_prompts.append(angle)

        return images, angle_prompts

    def estimate_depth(self, image):
        """Estimate depth map from a single image using MiDaS.

        Args:
            image: PIL image (or array) with the picture to analyse.

        Returns:
            An 8-bit greyscale PIL image of size ``resolution x resolution``
            holding the min-max normalised inverse of the MiDaS prediction.
        """
        # The transform normalises with mean/std 0.5, so it needs pixel values
        # in [0, 1] rather than 0-255.
        rgb = self._to_rgb_array(image).astype(np.float32) / 255.0

        # The transform pipeline yields a NumPy array; convert it explicitly
        # before adding the batch dimension.
        net_input = self.midas_transform({"image": rgb})["image"]
        img_tensor = torch.from_numpy(np.ascontiguousarray(net_input))
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        # Inference with MiDaS
        with torch.no_grad():
            prediction = self.midas_model.forward(img_tensor)
            # Resample the prediction to the working resolution
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(self.resolution, self.resolution),
                mode="bicubic",
                align_corners=False,
            )
            prediction = prediction.squeeze().cpu().numpy()

        # Bicubic resampling can overshoot below zero; clamp so the inversion
        # below never divides by a non-positive number.
        prediction = np.maximum(prediction, 0.0)

        # MiDaS returns inverse depth, so we need to invert it
        # NOTE(review): pixels whose prediction is ~0 map to a huge depth and
        # dominate the min-max normalisation below - confirm this is intended.
        depth_map = 1.0 / (prediction + 1e-6)

        # Normalize depth map; a constant map has no range to stretch.
        depth_span = depth_map.max() - depth_map.min()
        if depth_span > 0:
            depth_map = (depth_map - depth_map.min()) / depth_span
        else:
            depth_map = np.zeros_like(depth_map)

        # Convert to uint8 for visualization
        depth_map = (depth_map * 255).astype(np.uint8)

        return Image.fromarray(depth_map)

    def create_point_cloud_from_depth(self, image, depth_map, view_idx=0, total_views=1):
        """Create a point cloud from an image and its depth map.

        Every pixel becomes one point. Pixel coordinates are centred and
        scaled so the longer image side spans [-1, 1]; depth is normalised by
        its maximum and scaled to [0, 2], so all three axes share a comparable
        scale. The points are then rotated about the y axis by
        ``2 * pi * view_idx / total_views``.

        Args:
            image: Colour source (PIL image or array); greyscale and RGBA
                inputs are accepted.
            depth_map: 2-D depth values with the same height/width as *image*.
            view_idx: Index of this view, used for the rotation angle.
            total_views: Number of views sharing a full turn.

        Returns:
            ``(points, colors)``: an (N, 3) float array of positions and an
            (N, 3) float array of colours in [0, 1].

        Raises:
            ValueError: If the depth map is not 2-D or its size differs from
                the image size.
        """
        img_np = self._to_rgb_array(image)
        depth_np = np.asarray(depth_map, dtype=float)

        if depth_np.ndim != 2:
            raise ValueError(f"Depth map must be 2-D, got shape {depth_np.shape}")
        if img_np.shape[:2] != depth_np.shape:
            raise ValueError(
                f"Image size {img_np.shape[:2]} does not match depth map size {depth_np.shape}")
        if depth_np.size == 0:
            return np.empty((0, 3)), np.empty((0, 3))

        # Normalize depth; an all-zero map is left as is instead of producing
        # NaNs through 0/0.
        peak = depth_np.max()
        if peak > 0:
            depth_np = depth_np / peak

        # Create coordinates grid
        h, w = depth_np.shape
        cols, rows = np.meshgrid(np.arange(w), np.arange(h))

        # Centre the pixel grid and bring it to the same scale as the depth
        # axis; raw pixel units (hundreds) would dwarf a depth range of 2.
        half_extent = max(w, h) / 2
        x = (cols.reshape(-1) - w / 2) / half_extent
        y = (rows.reshape(-1) - h / 2) / half_extent
        z = depth_np.reshape(-1) * 2  # Scale depth

        # Rotate points about the y axis based on the view index
        # NOTE(review): every view is rotated about the vertical axis by its
        # list index, independent of the view label - confirm this is intended.
        theta = 2 * np.pi * view_idx / total_views
        x_rot = x * np.cos(theta) - z * np.sin(theta)
        z_rot = x * np.sin(theta) + z * np.cos(theta)

        points = np.vstack((x_rot, y, z_rot)).T

        # Get colors from the image
        colors = img_np.reshape(-1, 3) / 255.0

        # Remove points with invalid (NaN/inf) coordinates
        valid_points = np.isfinite(points).all(axis=1)
        return points[valid_points], colors[valid_points]

    def merge_point_clouds(self, all_points, all_colors):
        """Merge multiple point clouds into a single, cleaned one.

        The stacked cloud is voxel-downsampled (voxel size 0.01) and then
        filtered with Open3D's statistical outlier removal.

        Returns:
            ``(points, colors)`` arrays of the cleaned cloud.
        """
        merged_points = np.vstack(all_points)
        merged_colors = np.vstack(all_colors)

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(merged_points)
        pcd.colors = o3d.utility.Vector3dVector(merged_colors)

        # Voxel downsampling to remove duplicate points
        voxel_size = 0.01  # Adjust based on your scale
        pcd_down = pcd.voxel_down_sample(voxel_size)

        # Statistical outlier removal
        pcd_clean, _ = pcd_down.remove_statistical_outlier(
            nb_neighbors=20, std_ratio=2.0)

        return np.asarray(pcd_clean.points), np.asarray(pcd_clean.colors)

    def apply_mesh_reconstruction(self, points, colors):
        """Reconstruct a surface mesh from the point cloud for better quality.

        Runs Poisson surface reconstruction, drops the 10% of vertices with
        the lowest density, and samples ``len(points)`` points back from the
        mesh.

        Returns:
            ``(points, colors)`` arrays sampled from the reconstructed mesh.
        """
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(colors)

        # Estimate normals for better mesh reconstruction
        pcd.estimate_normals()
        pcd.orient_normals_consistent_tangent_plane(100)

        # Reconstruct mesh using Poisson surface reconstruction
        print("Reconstructing mesh surface...")
        mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            pcd, depth=9, width=0, scale=1.1, linear_fit=False)

        # Open3D returns the densities as its own vector type; convert to an
        # ndarray so the element-wise comparison yields a boolean mask.
        densities = np.asarray(densities)
        vertices_to_remove = densities < np.quantile(densities, 0.1)
        mesh.remove_vertices_by_mask(vertices_to_remove)

        # Convert back to point cloud for consistency
        refined_pcd = mesh.sample_points_uniformly(number_of_points=len(points))

        return np.asarray(refined_pcd.points), np.asarray(refined_pcd.colors)

    def save_to_ply(self, points, colors, filename):
        """Save the point cloud to a PLY file and return *filename*."""
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(colors)

        o3d.io.write_point_cloud(filename, pcd)
        print(f"Point cloud saved to {filename}")

        return filename

    def visualize_point_cloud(self, points, colors):
        """Show the point cloud in an interactive Open3D window.

        If no window can be created (for example on a machine without a
        display), a message is printed and the method returns.
        """
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(colors)

        vis = o3d.visualization.Visualizer()
        # create_window reports failure through its return value; without a
        # window there is no render option object to configure.
        if not vis.create_window():
            print("Could not open a visualization window; skipping visualization.")
            vis.destroy_window()
            return
        vis.add_geometry(pcd)

        # Configure visualization settings
        opt = vis.get_render_option()
        opt.point_size = 1.0
        opt.background_color = np.array([0.1, 0.1, 0.1])

        # Run the visualizer (blocks until the window is closed)
        vis.run()
        vis.destroy_window()

    def process(self, prompt, negative_prompt=None, output_filename="output.ply", visualize=True):
        """Process a text prompt to generate a 3D point cloud.

        Args:
            prompt: Text description of the object.
            negative_prompt: Optional negative prompt for image generation.
            output_filename: Path of the PLY file to write.
            visualize: Whether to open the interactive viewer at the end.

        Returns:
            The path of the written PLY file.

        Raises:
            ValueError: If no view image was generated.
        """
        start_time = time.time()

        print(f"Processing prompt: '{prompt}'")

        # Generate multi-view images
        images, angle_prompts = self.generate_multiview_images(prompt, negative_prompt)
        if not images:
            raise ValueError("No views were generated; angle_views must be at least 1")

        # Display generated images. squeeze=False always yields a 2-D axes
        # array, so a single view needs no special case.
        fig, axes = plt.subplots(1, len(images), figsize=(15, 3), squeeze=False)
        for ax, img, angle in zip(axes[0], images, angle_prompts):
            ax.imshow(img)
            ax.set_title(angle)
            ax.axis('off')
        plt.tight_layout()
        plt.show()

        # Process each image to get point clouds
        all_points = []
        all_colors = []

        print("Converting images to point clouds...")
        for i, image in enumerate(images):
            depth_map = self.estimate_depth(image)

            # Display depth map
            plt.figure(figsize=(5, 5))
            plt.imshow(depth_map, cmap='plasma')
            plt.title(f"Depth Map - {angle_prompts[i]}")
            plt.axis('off')
            plt.show()

            # Create point cloud from this view
            points, colors = self.create_point_cloud_from_depth(
                image, depth_map, view_idx=i, total_views=len(images))

            all_points.append(points)
            all_colors.append(colors)

        # Merge all point clouds
        print("Merging point clouds from all views...")
        merged_points, merged_colors = self.merge_point_clouds(all_points, all_colors)

        # Apply mesh reconstruction for better quality
        refined_points, refined_colors = self.apply_mesh_reconstruction(merged_points, merged_colors)

        # Save point cloud
        output_path = self.save_to_ply(refined_points, refined_colors, output_filename)

        # Optionally visualize
        if visualize:
            print("Visualizing point cloud...")
            self.visualize_point_cloud(refined_points, refined_colors)

        elapsed_time = time.time() - start_time
        print(f"Processing completed in {elapsed_time:.2f} seconds")

        return output_path

# Example usage in Colab
def main():
    """Interactive entry point: prompt for a description and build a point cloud.

    Loads the models, asks for a text prompt and an output filename, runs the
    full pipeline, offers the result as a browser download when running in
    Google Colab, and finally opens the saved file in an Open3D viewer.
    """
    # Initialize the model
    generator = TextTo3DPointCloud()

    # Get user input; surrounding whitespace is dropped so a blank answer
    # falls back to the default filename.
    text_prompt = input("Enter a description of the 3D object you want to generate: ").strip()
    output_filename = input("Enter output filename (default: output.ply): ").strip() or "output.ply"

    # Generate point cloud
    output_path = generator.process(
        prompt=text_prompt,
        negative_prompt="blurry, low quality, incomplete, cropped, disfigured, deformed",
        output_filename=output_filename,
        visualize=True
    )

    print(f"3D point cloud saved to: {output_path}")

    # The download helper only exists inside Google Colab; elsewhere the file
    # simply stays on disk at output_path.
    try:
        from google.colab import files
    except ImportError:
        print("google.colab is not available; skipping browser download.")
    else:
        files.download(output_path)

    # Reload the saved file and show it with Open3D's draw_geometries viewer
    print("\nDisplaying 3D point cloud in a web viewer...")
    pcd = o3d.io.read_point_cloud(output_path)
    o3d.visualization.draw_geometries([pcd])

if __name__ == "__main__":
    # Frame the title between two full-width rules, then start the
    # interactive workflow.
    banner_rule = "=" * 80
    for banner_line in (
        banner_rule,
        "Text to 3D Point Cloud Generator with MiDaS Depth Estimation",
        banner_rule,
    ):
        print(banner_line)
    main()

   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 447.7/447.7 MB 3.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.9/7.9 MB 51.1 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 101.7/101.7 kB 6.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 139.8/139.8 kB 14.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 228.0/228.0 kB 22.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 88.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 72.6 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
ERROR: Could not find a version that satisfies the requirement pytorch3d (from versions: none)



ModuleNotFoundError: No module named 'ldm'