In [None]:
import run
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from omegaconf import OmegaConf
import gc
import os
from huggingface_hub import hf_hub_download
from src.utils.train_util import instantiate_from_config
import rembg
import numpy as np
from PIL import Image
from src.utils.infer_util import remove_background, resize_foreground, save_video
from einops import rearrange
from src.utils.camera_util import get_zero123plus_input_cameras
from torchvision.transforms import v2
from kornia.filters import bilateral_blur

# Sanity check: later cells hard-code a CUDA device, so this should display True.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Select the InstantMesh variant. The config file's base name doubles as the
# checkpoint name and the output sub-directory, so derive it from the path
# instead of repeating it by hand (the two could silently drift apart).
CONFIG_PATH = "configs/instant-mesh-base.yaml"
config = OmegaConf.load(CONFIG_PATH)
config_name = os.path.splitext(os.path.basename(CONFIG_PATH))[0]   # "instant-mesh-base"
model_config = config.model_config
infer_config = config.infer_config

# Follows the InstantMesh run.py convention: "instant-mesh-*" configs use the
# FlexiCubes geometry path. NOTE(review): verify this holds for custom configs.
IS_FLEXICUBES = config_name.startswith("instant-mesh")

# The rest of the notebook assumes a CUDA GPU.
device = torch.device("cuda")

In [None]:
output_path = "out"
input_path = "examples/cartoon_dinosaur.png"

# Extensions accepted when input_path is a directory (matched case-insensitively).
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

# make output directories: <output_path>/<config_name>/{images,meshes,videos}
image_path = os.path.join(output_path, config_name, 'images')
mesh_path = os.path.join(output_path, config_name, 'meshes')
video_path = os.path.join(output_path, config_name, 'videos')
os.makedirs(image_path, exist_ok=True)
os.makedirs(mesh_path, exist_ok=True)
os.makedirs(video_path, exist_ok=True)

# process input files: a directory is scanned for images, anything else is
# treated as a single image path. os.listdir order is arbitrary, so sort it to
# make the processing order (and the progress log) deterministic.
if os.path.isdir(input_path):
    input_files = [
        os.path.join(input_path, file)
        for file in sorted(os.listdir(input_path))
        if file.lower().endswith(IMAGE_EXTENSIONS)
    ]
else:
    input_files = [input_path]
print(f'Total number of input images: {len(input_files)}')

In [None]:
# Load the Zero123++ v1.2 diffusion pipeline in half precision, with an
# Euler-ancestral scheduler using 'trailing' timestep spacing.
pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch.float16,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing='trailing'
)

# Swap in the InstantMesh fine-tuned UNet: prefer a local checkpoint, else download it.
if os.path.exists(infer_config.unet_path):
    unet_ckpt_path = infer_config.unet_path
else:
    unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
unet_state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(unet_state_dict, strict=True)
# The weights now live in the UNet; drop the CPU copy so it does not stay
# resident in kernel memory for the rest of the session.
del unet_state_dict

pipeline = pipeline.to(device)

In [None]:
no_rembg = True        # True: use the input image as-is; False: strip its background with rembg
diffusion_steps = 75
SEED = 42              # initial latents and ancestral-sampler noise are random; fix them for reproducibility

rembg_session = None if no_rembg else rembg.new_session()

outputs = []
for idx, image_file in enumerate(input_files):
    name = os.path.basename(image_file).split('.')[0]
    print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')

    # remove background optionally
    input_image = Image.open(image_file)
    if not no_rembg:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)

    # Seed per image so each result is reproducible regardless of how many
    # images were processed before it (torch.manual_seed seeds all devices).
    torch.manual_seed(SEED)

    # sampling: the pipeline returns one image that tiles all generated views
    output_image = pipeline(
        input_image,
        num_inference_steps=diffusion_steps,
    ).images[0]

    out_file = os.path.join(image_path, f'{name}.png')
    output_image.save(out_file)
    print(f"Image saved to {out_file}")

    # Split the tiled image into its 3 x 2 grid of views.
    images = np.asarray(output_image, dtype=np.float32) / 255.0
    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()     # (3, 960, 640)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)        # (6, 3, 320, 320)

    outputs.append({'name': name, 'images': images})

In [None]:
# Delete the diffusion pipeline to free memory before loading the reconstruction
# model. `del` only drops the Python reference: collect garbage and release
# PyTorch's cached CUDA blocks so the GPU memory is actually available again.
del pipeline
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the reconstruction model from the config and load its weights:
# prefer a local checkpoint, else download the one named after the config.
model = instantiate_from_config(model_config)
if os.path.exists(infer_config.model_path):
    model_ckpt_path = infer_config.model_path
else:
    model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(model_ckpt_path, map_location='cpu')

# The checkpoint stores the model's weights under the 'lrm_generator.' prefix;
# keep only those entries and strip the prefix so keys match `model`.
ckpt_prefix = 'lrm_generator.'
model_state_dict = {
    k[len(ckpt_prefix):]: v
    for k, v in checkpoint['state_dict'].items()
    if k.startswith(ckpt_prefix)
}
model.load_state_dict(model_state_dict, strict=True)
# Weights are copied into `model`; free the CPU-side checkpoint tensors.
del checkpoint, model_state_dict

model = model.to(device)
if IS_FLEXICUBES:
    model.init_flexicubes_geometry(device, fovy=30.0)
model = model.eval()

In [None]:
scale = 1.0            # multiplier on the input-camera radius (4.0 * scale)
view = 6               # views fed to the reconstructor: 6 (all) or 4 (views 0, 2, 4, 5)
export_texmap = True   # True: mesh + UV texture map (.obj/.mtl); False: per-vertex colors
do_save_video = True
distance = 4.5         # camera radius for the turntable video

# Reload the multi-view images saved by the generation step from disk, so this
# cell also works on its own. Every input is reloaded, not just the first.
outputs = []
for image_file in input_files:
    name = os.path.basename(image_file).split('.')[0]
    img_path = os.path.join(image_path, f'{name}.png')
    with Image.open(img_path) as img:
        images = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()     # (3, 960, 640)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)        # (6, 3, 320, 320)
    outputs.append({'name': name, 'images': images})

input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0*scale).to(device)
chunk_size = 20 if IS_FLEXICUBES else 1

# Select the view subset ONCE. Re-indexing input_cameras inside the per-sample
# loop shrank it to 4 views, so a second sample would index out of range.
if view == 4:
    view_indices = torch.tensor([0, 2, 4, 5]).long().to(device)
    input_cameras = input_cameras[:, view_indices]

# Bilateral-filter settings for the smoothed "freeplanes" passed to extract_mesh.
blur_kernel = (3, 3)
blur_sigma_color = 10
blur_sigma_space = (10, 10)

# Turntable cameras do not depend on the sample; build them once.
if do_save_video:
    render_size = infer_config.render_resolution
    render_cameras = run.get_render_cameras(
        batch_size=1,
        M=120,
        radius=distance,
        elevation=20.0,
        is_flexicubes=IS_FLEXICUBES,
    ).to(device)

for idx, sample in enumerate(outputs):
    name = sample['name']
    print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')

    images = sample['images'].unsqueeze(0).to(device)
    images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)
    if view == 4:
        images = images[:, view_indices]

    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        # Get freeplanes: bilateral_blur wants 4-D (N, C, H, W) input, so move
        # dim 2 into the batch position and drop the size-1 batch dim, then undo
        # both afterwards. assumes planes is (1, 3, C, H, W) -- TODO confirm.
        # squeeze(2) (not squeeze()) so no other size-1 dim is collapsed by accident.
        planes_4d = planes.transpose(0, 2).squeeze(2)
        freeplanes = bilateral_blur(
            planes_4d, blur_kernel, blur_sigma_color, blur_sigma_space
        ).unsqueeze(2).transpose(0, 2)

        # get mesh
        mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            freeplanes=freeplanes,
            **infer_config,
        )
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            run.save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            run.save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        # get video
        if do_save_video:
            video_path_idx = os.path.join(video_path, f'{name}.mp4')
            frames = run.render_frames(
                model,
                planes,
                render_cameras=render_cameras,
                render_size=render_size,
                chunk_size=chunk_size,
                is_flexicubes=IS_FLEXICUBES,
            )

            save_video(
                frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")
