In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# Print the path of every file found under the read-only Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Installation

In [None]:
!git clone https://github.com/openai/point-e
%cd point-e
!pip install point -e .

# Import Libraries

In [None]:
import torch
import plotly.graph_objects as go
from point_e.models.download import load_checkpoint
from point_e.diffusion.sampler import PointCloudSampler
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

In [None]:
# Run the models on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

# Load Models

In [None]:
# Build the text-conditioned base model from its config on `device` and switch it
# to inference mode. Pretrained weights are applied separately via load_state_dict.
base_model = model_from_config(MODEL_CONFIGS["base40M-textvec"], device)
# .eval() returns the module itself; the trailing `;` keeps the notebook from
# echoing the full (very long) module repr as this cell's output.
base_model.eval();

In [None]:
# Look up the diffusion settings registered for the base model and build the process from them.
base_diffusion_config = DIFFUSION_CONFIGS["base40M-textvec"]
base_diffusion = diffusion_from_config(base_diffusion_config)

In [None]:
# Build the upsampler model from its config on `device` and switch it to inference
# mode. Pretrained weights are applied separately via load_state_dict.
upsampler_model = model_from_config(MODEL_CONFIGS["upsample"], device)
# .eval() returns the module itself; the trailing `;` keeps the notebook from
# echoing the full (very long) module repr as this cell's output.
upsampler_model.eval();

In [None]:
# Look up the diffusion settings registered for the upsampler and build the process from them.
upsampler_diffusion_config = DIFFUSION_CONFIGS["upsample"]
upsampler_diffusion = diffusion_from_config(upsampler_diffusion_config)

In [None]:
# Copy the pretrained weights into each model. strict=True is torch's default and is
# spelled out here: load_state_dict raises on missing or unexpected keys.
base_model.load_state_dict(
    load_checkpoint("base40M-textvec", device),
    strict=True,
)
upsampler_model.load_state_dict(
    load_checkpoint("upsample", device),
    strict=True,
)

# Generate a Point Cloud from a Text Prompt

In [None]:
# Two-stage sampler: every per-stage list below is ordered [base, upsampler].
stage_models = [base_model, upsampler_model]
stage_diffusions = [base_diffusion, upsampler_diffusion]
# 1024 points from the base stage plus 3072 from the upsampler (4096 in all,
# presumably the final cloud size — NOTE(review): confirm with PointCloudSampler).
base_points = 1024
total_points = 4096
stage_points = [base_points, total_points - base_points]

sampler = PointCloudSampler(
    device=device,
    models=stage_models,
    diffusions=stage_diffusions,
    num_points=stage_points,
    aux_channels=["R", "G", "B"],           # extra per-point channels besides xyz
    guidance_scale=[3.0, 0.0],              # per-stage guidance; presumably 0.0 disables it for the upsampler
    model_kwargs_key_filter=("texts", ""),  # per-stage kwarg filter; presumably only the base stage sees the prompt
)

In [None]:
# Text prompt conditioning the base model; edit it to generate a different object.
prompt = 'a BLUE truck'
batch_size = 1
SEED = 0

# Diffusion sampling is stochastic: seed torch (all devices) so re-running this
# cell, or the whole notebook, reproduces the same point cloud.
torch.manual_seed(SEED)

# Iterate the progressive sampler and keep only its last yield, presumably the
# fully denoised result of the final stage.
samples = None
for x in sampler.sample_batch_progressive(batch_size=batch_size, model_kwargs=dict(texts=[prompt])):
    samples = x
if samples is None:
    # Fail here with a clear message instead of an opaque error inside output_to_point_clouds.
    raise RuntimeError("sample_batch_progressive yielded no samples")

pred = sampler.output_to_point_clouds(samples)[0]

# Visualize the Point Cloud

In [None]:
# Convert the per-point R/G/B channels into plotly colour strings.
# Plotly reads "rgb(r, g, b)" like CSS, i.e. with 0-255 components, whereas the
# channels here are assumed to be floats in [0, 1] (NOTE(review): confirm against
# PointCloudSampler.output_to_point_clouds). Formatting the raw floats rendered
# every marker near-black, so scale to 0-255, round, and clip before formatting.
rgb_values = np.stack(
    [pred.channels["R"], pred.channels["G"], pred.channels["B"]], axis=-1
)
rgb_255 = np.clip(np.rint(rgb_values * 255), 0, 255).astype(int)
color = [f"rgb({r}, {g}, {b})" for r, g, b in rgb_255]

In [None]:
# Interactive 3-D scatter of the generated cloud: one small marker per point,
# coloured with the per-point `color` strings, with all three axes hidden.
xs, ys, zs = (pred.coords[:, axis] for axis in range(3))
point_trace = go.Scatter3d(
    x=xs,
    y=ys,
    z=zs,
    mode="markers",
    marker=dict(size=2, color=color),
)
hidden_axes = {name: dict(visible=False) for name in ("xaxis", "yaxis", "zaxis")}
fig = go.Figure(data=[point_trace], layout=dict(scene=hidden_axes))
fig.show()


In [None]:
# Static render via point-e's own plotting helper: grid_size=3 presumably gives a
# 3x3 grid of views, and the fixed bounds keep every view on the same
# [-0.75, 0.75] cube.
PLOT_BOUNDS = ((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75))
fig = plot_point_cloud(
    pred,
    grid_size=3,
    fixed_bounds=PLOT_BOUNDS,
)
fig.show()