In [None]:
! pip install torch
! pip install git+https://github.com/openai/shap-e.git
! pip install gradio
! pip install trimesh

In [None]:
import os

import gradio as gr
import numpy as np
import torch
import trimesh
from trimesh.visual import ColorVisuals

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import decode_latent_mesh

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Shap-E components shared by the cells below:
#   xm        -- "transmitter" model; passed to decode_latent_mesh to turn latents into meshes
#   model     -- "text300M" model; sampled with text prompts via sample_latents
#   diffusion -- diffusion process built from the "diffusion" config
xm = load_model('transmitter', device=device)
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

def _latent_to_trimesh(latent):
    """Decode one Shap-E latent into a vertex-coloured ``trimesh.Trimesh``."""
    tri_mesh = decode_latent_mesh(xm, latent).tri_mesh()
    # One row per vertex, columns R, G, B.
    vertex_colors = np.column_stack(
        [tri_mesh.vertex_channels[channel] for channel in "RGB"]
    )
    return trimesh.Trimesh(
        vertices=tri_mesh.verts,
        faces=tri_mesh.faces,
        # NOTE(review): may be None or per-vertex normals depending on the
        # decoder -- confirm it really holds per-face normals.
        face_normals=tri_mesh.normals,
        visual=ColorVisuals(vertex_colors=vertex_colors),
    )


def generate_mesh(prompt, batch_size=1, guidance_scale=15.0):
    """Generate 3D mesh(es) from a text prompt with Shap-E and export them as GLB.

    One file ``models_{i}.glb`` is written to the current working directory
    for every sampled latent ``i``.

    Args:
        prompt: Text description fed to the text-conditional model.
        batch_size: Number of latents (and therefore GLB files) to produce.
            Defaults to 1.
        guidance_scale: Guidance scale forwarded to ``sample_latents``.
            Defaults to 15.0.

    Returns:
        ``(path, path)`` for the first exported mesh -- one value for each
        Gradio output (3D viewer and file download).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    latents = sample_latents(
        batch_size=batch_size,
        model=model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[prompt] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )

    exported_paths = []
    for i, latent in enumerate(latents):
        scene = trimesh.Scene()
        scene.add_geometry(_latent_to_trimesh(latent))
        # NOTE: fixed filenames mean concurrent requests overwrite each other's output.
        path = f"models_{i}.glb"
        with open(path, "wb") as f:
            f.write(trimesh.exchange.gltf.export_glb(scene))
        exported_paths.append(path)

    first_path = exported_paths[0]
    return first_path, first_path

# Gradio UI: a prompt textbox in; the generated mesh out twice -- once in an
# interactive 3D viewer and once as a downloadable file.
# Top-level component classes are used because the legacy `gr.inputs` /
# `gr.outputs` namespaces were removed in Gradio 4.
inputs = gr.Textbox(label="Prompt")

outputs = [
    gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
    gr.File(label="Generated Mesh"),
]

demo = gr.Interface(generate_mesh, inputs, outputs)
# NOTE(review): share=True publishes a public share URL to this model; drop it
# if the app should only be reachable locally.
demo.launch(share=True, server_port=8082)


In [None]:
# Inspect the first exported mesh inside the notebook.
# models_0.glb is only written by generate_mesh (e.g. after submitting a prompt
# in the Gradio app above), so on a fresh kernel fail with an explicit message.
MESH_PATH = "models_0.glb"
if not os.path.isfile(MESH_PATH):
    raise FileNotFoundError(
        f"{MESH_PATH} not found; generate a mesh first "
        "(e.g. submit a prompt in the Gradio app above)."
    )

mesh = trimesh.load(MESH_PATH)

# Left as the cell's last expression so the notebook renders the viewer.
mesh.show()
