# Shap-E

Objective: generate 3D objects conditioned on either a text prompt or an image.

Reference: https://github.com/openai/shap-e

## Install Python packages

In [None]:
# Install Shap-E directly from its GitHub repository, plus mediapy, which is
# used further down to read and display the example image prompt.
%pip install -qq git+https://github.com/openai/shap-e.git 
%pip install -qq mediapy

## Prepare the models

In [None]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# `xm` ('transmitter') is what the render and mesh-export cells pass to
# decode_latent_images / decode_latent_mesh; `model` ('text300M') and
# `diffusion` are what sample_wrapper passes to sample_latents.
xm = load_model('transmitter', device=device)
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

## Define a sampling wrapper

In [None]:
# Image-conditional checkpoint, loaded on first use. The global `model` is the
# text-conditional 'text300M' checkpoint; image prompts should go to
# 'image300M' instead (as in upstream Shap-E's image-to-3D example).
_image_model = None


def _get_image_model():
    """Return the 'image300M' model, loading it onto `device` on first call."""
    global _image_model
    if _image_model is None:
        _image_model = load_model('image300M', device=device)
    return _image_model


def sample_wrapper(
    text="a shark",
    image=None,
    guidance_scale=None,
    batch_size=4,
):
    """Sample Shap-E latents conditioned on a text prompt or an image.

    Args:
        text: Text prompt; used only when ``image`` is None.
        image: Image prompt. When given it takes precedence over ``text`` and
            the image-conditional 'image300M' model is used for sampling.
        guidance_scale: Classifier-free guidance scale. Defaults to 15.0 for
            text prompts and 3.0 for image prompts.
        batch_size: Number of samples; the prompt is repeated this many times.

    Returns:
        Whatever ``sample_latents`` returns for the batch (iterated one latent
        at a time by the render and export cells).
    """
    if image is None:
        sampling_model = model
        model_kwargs = dict(texts=[text] * batch_size)
        default_scale = 15.0
    else:
        # Do not feed image prompts to the text-conditional checkpoint.
        sampling_model = _get_image_model()
        model_kwargs = dict(images=[image] * batch_size)
        default_scale = 3.0

    if guidance_scale is None:
        guidance_scale = default_scale

    return sample_latents(
        batch_size=batch_size,
        model=sampling_model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=model_kwargs,
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )

## Sample a 3D model

### Conditioned on a text prompt

In [None]:
# Sample a batch of latents from a text prompt (15.0 is also sample_wrapper's
# default guidance scale for text prompts).
text = "a shark"
latents = sample_wrapper(text=text, guidance_scale=15.0)

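### Conditioned on an image prompt

Note: this cell overwrites the `latents` produced by the text-prompt cell above, so the render and mesh-export cells below use whichever sampling cell ran last.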

In [None]:
import mediapy as media

# mediapy reads the example image straight from its URL; preview it, then
# sample latents conditioned on it (guidance scale 3.0 for image prompts).
image_url = "https://raw.githubusercontent.com/openai/shap-e/main/shap_e/examples/example_data/corgi.png"
image = media.read_image(image_url)
media.show_image(image)

latents = sample_wrapper(image=image, guidance_scale=3.0)

## Render

In [None]:
render_mode = 'nerf'  # you can change this to 'stf'
size = 64  # size of the renders; higher values take longer to render

# Render each sampled latent from the panning cameras and show it as a GIF.
# The enumerate index in the original loop was never used, so iterate directly.
cameras = create_pan_cameras(size, device)
for latent in latents:
    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    display(gif_widget(images))

## Save the latents as meshes

In [None]:
from shap_e.util.notebooks import decode_latent_mesh

# Write each latent to example_mesh_<i>.ply.
for i, latent in enumerate(latents):
    # Decode before opening the output file: opening with 'wb' truncates or
    # creates it, so a failed decode would otherwise leave an empty .ply behind.
    tri_mesh = decode_latent_mesh(xm, latent).tri_mesh()
    with open(f'example_mesh_{i}.ply', 'wb') as f:
        tri_mesh.write_ply(f)