In [13]:
import os
import os.path as osp
import subprocess

from PIL import Image
import torch
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from urllib.request import urlretrieve
from urllib.error import HTTPError

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

# Output locations for each pipeline stage, relative to the notebook's working directory.
originals_folder = 'images/original'           # downloaded source images
postprocessed_folder = 'images/postprocessed'  # padded, resized, background-flattened images
pc_img_folder = 'images/pc/img'                # rendered point cloud previews
pc_npy_folder = 'images/pc/npy'                # saved point cloud arrays

for _stage_dir in (originals_folder, postprocessed_folder, pc_img_folder, pc_npy_folder):
    os.makedirs(_stage_dir, exist_ok=True)

Download images from URLs I collected manually

In [2]:
import shutil
from urllib.request import Request, urlopen

# Many image hosts answer urllib's default "Python-urllib/x.y" User-Agent with
# HTTP 403, so identify as a regular browser instead.
REQUEST_HEADERS = {'User-Agent': 'Mozilla/5.0'}
DOWNLOAD_TIMEOUT_S = 30

web_images = pd.read_csv('web_images.csv')

for prompt, url in tqdm(zip(web_images['prompt'], web_images['url']), total=len(web_images)):
    # Empty CSV cells come back as NaN floats; skip them instead of crashing on .strip().
    if not isinstance(prompt, str) or not isinstance(url, str):
        print(f'Skipping incomplete row: prompt={prompt!r}, url={url!r}')
        continue
    prompt = prompt.strip()
    fname = os.path.join(originals_folder, f"{prompt.lower().replace(' ', '_')}.png")
    if osp.exists(fname):
        continue

    # Download to a side file and rename only on success, so an interrupted
    # transfer never leaves a truncated image that the exists-check would skip.
    partial_path = fname + '.part'
    try:
        request = Request(url.strip(), headers=REQUEST_HEADERS)
        with urlopen(request, timeout=DOWNLOAD_TIMEOUT_S) as response, open(partial_path, 'wb') as out:
            shutil.copyfileobj(response, out)
        os.replace(partial_path, fname)
    except OSError as err:
        # OSError covers HTTPError/URLError as well as timeouts and connection resets;
        # report and move on so one bad link doesn't stop the whole batch.
        print(f'Could not retrieve {prompt} from {url} ({err})')
    finally:
        if osp.exists(partial_path):
            os.remove(partial_path)

  0%|          | 0/38 [00:00<?, ?it/s]

Could not retrieve Goomba from https://mario.wiki.gallery/images/7/7d/SMBW_Goomba.png
Could not retrieve Space Invader from https://png.pngtree.com/png-clipart/20230823/original/pngtree-space-invaders-character-game-play-picture-image_8233346.png
Could not retrieve Taj Mahal from https://png.pngtree.com/background/20230621/original/pngtree-white-background-with-a-3d-glossy-taj-mahal-monument-picture-image_3893624.jpg


Generate new images with Stable Diffusion

In [16]:
# The original cell used StableDiffusionPipeline, autocast and sd_folder without
# defining them; bring them into scope so the cell runs on a fresh kernel.
from diffusers import StableDiffusionPipeline

# NOTE(review): `sd_folder` was undefined. Writing into `originals_folder` means the
# preprocessing cell (which scans only that folder) picks these images up. Confirm
# this is the intended destination.
sd_folder = originals_folder
os.makedirs(sd_folder, exist_ok=True)
SD_DEVICE = 'cuda:2'

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1").to(SD_DEVICE)
pipe.enable_attention_slicing()  # per diffusers docs: lower VRAM use at some speed cost

with open('stable_diffusion_prompts.txt') as f:
    # Drop blank lines so they don't turn into empty prompts saved as ".png".
    sd_prompts = [line.strip() for line in f if line.strip()]

for prompt in tqdm(sd_prompts):
    fname = os.path.join(sd_folder, f"{prompt.lower().replace(' ', '_')}.png")
    if osp.exists(fname):
        continue
    # NOTE(review): a previous run emitted a numpy cast warning from the pipeline's
    # uint8 conversion, which may indicate NaN (black) images under autocast.
    # Inspect the outputs.
    with torch.autocast("cuda"):
        image = pipe(prompt).images[0]
    image.save(fname)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  images = (images * 255).round().astype("uint8")


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 

: 

Preprocess images

In [None]:
# Extra margin around the object when padding to a square canvas, and the side
# length (pixels) of the images saved for the point cloud model.
PAD_RATIO = 1.2
TARGET_SIZE = 256


def _pad_to_square(image, pad_ratio=PAD_RATIO):
    """Center `image` on a white RGB square whose side is `pad_ratio` x its longest side."""
    rgba = image.convert('RGBA')
    width, height = rgba.size
    square_size = int(pad_ratio * max(width, height))
    canvas = Image.new("RGB", (square_size, square_size), color="white")
    offset = ((square_size - width) // 2, (square_size - height) // 2)
    # Use alpha as the paste mask. Without it, transparent pixels would show
    # whatever RGB they store (often black) instead of the white canvas.
    canvas.paste(rgba, offset, mask=rgba)
    return canvas


def _flatten_onto_white(path):
    """Composite the image at `path` over a white background and overwrite it in place."""
    with Image.open(path) as opened:
        cutout = opened.convert('RGBA')  # guarantees an alpha channel to use as mask
    background = Image.new("RGB", cutout.size, (255, 255, 255))
    background.paste(cutout, mask=cutout.getchannel('A'))
    background.save(path)


for image_name in tqdm(sorted(os.listdir(originals_folder))):
    src_path = osp.join(originals_folder, image_name)
    if not osp.isfile(src_path):
        continue  # e.g. .ipynb_checkpoints

    tmp_path = osp.join(postprocessed_folder, f'tmp_{image_name}')
    final_path = osp.join(postprocessed_folder, f'{image_name}')

    # Recrop to a padded square and downsize.
    with Image.open(src_path) as original:
        squared = _pad_to_square(original)
    squared.resize((TARGET_SIZE, TARGET_SIZE), Image.LANCZOS).save(tmp_path)

    # Make the background transparent. Always delete the temp file, otherwise
    # stray tmp_* images end up in the point cloud step.
    try:
        result = subprocess.run(
            ['backgroundremover', '-i', tmp_path, '-o', final_path],
            capture_output=True,
            text=True,
        )
    finally:
        os.remove(tmp_path)

    if result.returncode != 0 or not osp.exists(final_path):
        print(f'backgroundremover failed for {image_name}:\n{result.stderr}')
        continue
    if result.stderr:
        print(result.stderr)  # warnings from an otherwise successful run

    # Add back a white background.
    _flatten_onto_white(final_path)

Generate point clouds

In [3]:
# How many CUDA devices this kernel can see (displayed as the cell output).
num_cuda_devices = torch.cuda.device_count()
num_cuda_devices

8

In [4]:
# Preferred GPU on this multi-GPU machine. Fall back to the default CUDA device
# (or CPU) when that index doesn't exist, instead of failing on a smaller box.
GPU_INDEX = 2

if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f'cuda:{GPU_INDEX}')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('creating base model...')
base_name = 'base1B'  # largest base model; 'base300M' / 'base40M' are smaller and faster
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

print('creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

print('downloading base checkpoint...')
base_model.load_state_dict(load_checkpoint(base_name, device))

print('downloading upsampler checkpoint...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

creating base model...
creating upsample model...
downloading base checkpoint...


  0%|          | 0.00/4.98G [00:00<?, ?iB/s]

downloading upsampler checkpoint...


<All keys matched successfully>

In [5]:
# Two-stage sampler: the base model emits 1024 points and the upsampler adds
# 3072 more (4096 total). Each point carries RGB auxiliary channels.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    guidance_scale=[3.0] * 2,
    aux_channels=list('RGB'),
    num_points=[1024, 3072],
)

In [6]:
# Same axis limits for every preview so the renders are directly comparable.
PLOT_BOUNDS = ((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75))

for image_name in tqdm(sorted(os.listdir(postprocessed_folder))):
    img_path = osp.join(postprocessed_folder, image_name)
    if not osp.isfile(img_path):
        continue  # e.g. .ipynb_checkpoints
    # Load fully and close the file handle right away.
    with Image.open(img_path) as opened:
        img = opened.convert('RGB')

    # Drain the progressive generator; only the last yielded value is kept.
    samples = None
    for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[img]))):
        samples = x
    if samples is None:
        print(f'Sampler produced no output for {image_name}')
        continue

    pc = sampler.output_to_point_clouds(samples)[0]
    fig = plot_point_cloud(pc, grid_size=3, dot_size=0.1, fixed_bounds=PLOT_BOUNDS)
    # Save/close this specific figure rather than relying on pyplot's "current figure".
    fig.savefig(osp.join(pc_img_folder, image_name))
    plt.close(fig)

    # One row per point: coordinate columns followed by R, G, B.
    rgb = np.stack([pc.channels[c] for c in ('R', 'G', 'B')], axis=1)
    pc_save = np.concatenate([pc.coords, rgb], axis=1)
    # np.save writes .npy format and appends ".npy" to any other suffix. The old
    # ".npz" name therefore produced misleading "*.npz.npy" files.
    np.save(osp.join(pc_npy_folder, f'{osp.splitext(image_name)[0]}.npy'), pc_save)

  0%|          | 0/34 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]