In [None]:
# not using this code, even though better.
# Why? Hypothesis: when optimizing using Reinforcement Learning, it will become difficult to get gradients along multiple models. 

In [1]:
from PIL import Image
import torch
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('creating base model...')
base_name = 'base40M' # use base300M or base1B for better results
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

print('creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

print('downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))

print('downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

creating base model...


100%|███████████████████████████████████████| 890M/890M [02:22<00:00, 6.55MiB/s]


creating upsample model...
downloading base checkpoint...


100%|██████████| 162M/162M [01:45<00:00, 1.54MiB/s] 
  return torch.load(path, map_location=device)


downloading upsampler checkpoint...


100%|██████████| 162M/162M [01:44<00:00, 1.55MiB/s] 


<All keys matched successfully>

In [4]:
# Save the base model checkpoint
torch.save(base_model.state_dict(), 'base_model_checkpoint.pth')

# Save the upsampler model checkpoint
torch.save(upsampler_model.state_dict(), 'upsampler_model_checkpoint.pth')

print("Models have been saved successfully.")


Models have been saved successfully.


In [22]:
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096 - 1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 3.0],
)

#faster. reducing number of points. lowering guidance (it relates to how 'exact' or 'creative' we need to be)
# sampler = PointCloudSampler(
#     device=device,
#     models=[base_model, upsampler_model],
#     diffusions=[base_diffusion, upsampler_diffusion],
#     num_points=[512, 1024 - 512],
#     aux_channels=['R', 'G', 'B'],
#     guidance_scale=[2.0, 2.0],
# )


# #even faster. skipping upsampler
# sampler = PointCloudSampler(
#     device=device,
#     models=[base_model],
#     diffusions=[base_diffusion],
#     num_points=[512],
#     aux_channels=['R', 'G', 'B'],
#     guidance_scale=[2.0],
# )



In [16]:
attributes = [attr for attr in dir(base_diffusion) if not attr.startswith('_')]
print(attributes)


['alphas_cumprod', 'alphas_cumprod_next', 'alphas_cumprod_prev', 'betas', 'calc_bpd_loop', 'channel_biases', 'channel_scales', 'condition_mean', 'condition_score', 'ddim_reverse_sample', 'ddim_sample', 'ddim_sample_loop', 'ddim_sample_loop_progressive', 'discretized_t0', 'get_sigmas', 'log_one_minus_alphas_cumprod', 'loss_type', 'model_mean_type', 'model_var_type', 'num_timesteps', 'p_mean_variance', 'p_sample', 'p_sample_loop', 'p_sample_loop_progressive', 'posterior_log_variance_clipped', 'posterior_mean_coef1', 'posterior_mean_coef2', 'posterior_variance', 'q_mean_variance', 'q_posterior_mean_variance', 'q_sample', 'scale_channels', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod', 'sqrt_recip_alphas_cumprod', 'sqrt_recipm1_alphas_cumprod', 'training_losses', 'unscale_channels', 'unscale_out_dict']


In [23]:
base_diffusion.num_timesteps, upsampler_diffusion.num_timesteps

(50, 50)

In [24]:
#faster
base_diffusion.num_timesteps=50
upsampler_diffusion.num_timesteps=50

In [25]:
# Load an image to condition on.
img = Image.open('example_data/cube_stack.jpg')

# Produce a sample from the model.
samples = None
for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
    samples = x

0it [00:00, ?it/s]


ValueError: x and y arrays must be equal in length along interpolation axis.

In [None]:
pc = sampler.output_to_point_clouds(samples)[0]
fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))