<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/05-multi-modal/CLIP-Mesh.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CLIP-Mesh

## 0. Info

### Paper
* title: CLIP-Mesh: Generating textured meshes from text using pretrained image-text models
* author: Nasir Mohammad Khalid et al.
* url: https://arxiv.org/abs/2203.13333

### Features
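* Does not use the diffusion prior: the CLIP text embedding is used directly as the optimisation target
* Uses the CLIP model from Hugging Face `transformers` (`openai/clip-vit-base-patch16`)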

### Reference
* https://github.com/NasirKhalid24/CLIP-Mesh
* https://colab.research.google.com/drive/15Fm4EhLlB20EugLUnTdhSJElvGVCU7Ys?usp=sharing

## 1. Setup

In [None]:
# !git clone --recurse-submodules https://github.com/NasirKhalid24/CLIP-Mesh.git
# !cd /content/CLIP-Mesh/; git submodule update --init --recursive
# !cd /content/CLIP-Mesh/; pip install -r requirements.txt

# !cd /content/CLIP-Mesh/loop_limitation/; pip install .
# !cd /content/CLIP-Mesh/DALLE2-pytorch/; pip install .

# !python -m pip uninstall matplotlib --y
# !pip install -q matplotlib==3.1.3 
# !pip install -q transformers mediapy trimesh

In [None]:
import sys
# Make packages inside the cloned CLIP-Mesh checkout importable (presumably
# including `nvdiffmodeling`, imported in the next cell). The path is relative
# to the current working directory, assumed to be /content on Colab — TODO confirm
# when running elsewhere.
sys.path.append('CLIP-Mesh')

In [None]:
import os
import glm
import math
import einops
import random
import easydict
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime

import kornia
import trimesh
import loop_limitation
import nvdiffrast.torch as dr

import clip
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from transformers import CLIPProcessor, CLIPModel

from nvdiffmodeling.src import (
    obj,
    util,
    mesh,
    render,
    texture,
    regularizer
)

Using /root/.cache/torch_extensions/py37_cu102 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py37_cu102/renderutils_plugin/build.ninja...
Building extension module renderutils_plugin...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module renderutils_plugin...


In [None]:
# Optimisation / rendering configuration. Every knob a reader might tune lives here.
cfg = easydict.EasyDict(
    device = 'cuda',
    clip_model = 'openai/clip-vit-base-patch16',
    seed = 0,                       # seeds python `random`, numpy and torch (see below)

    mesh = '/content/CLIP-Mesh/primitives/sphere.obj',   # starting primitive
    epochs = 2000,                  # optimisation steps; one camera batch per step
    lr = 0.01,

    batch_size = 25,                # random views rendered per step
    texture_resolution = 512,       # side length of the kd / ks / normal maps
    train_resolution = 356,         # render resolution before resizing for CLIP
    kernel_size = (7, 7),           # Gaussian blur applied to the maps before rendering
    blur_sigma = (3, 3),

    light_power = 5.0,
    num_layers = 2,                 # forwarded to render.render_mesh(num_layers=...)

    clip_weight = 1.0,
    laplacian_weight = 30.0,        # initial Laplacian-regulariser weight ...
    laplacian_min = 0.6,            # ... annealed toward this floor during training
)


# Random camera sampling (consumed by CameraBatch(**camera_cfg)).
camera_cfg = easydict.EasyDict(
    image_resolution = cfg.train_resolution,
    distances = [5.0, 8.0], # (min, max)
    azimuths = [-360.0, 360.0], # (min, max) in degrees
    elevation_params = [1.0, 5.0, 60.0], # (alpha, beta, max): elevation = Beta(alpha, beta) * max degrees
    fovs = [30.0, 90.0], # (min, max) in degrees
    aug_loc = True,      # random x/y offset of the object
    aug_light = True,    # jitter the light direction around the camera
    aug_bkg = True,      # random backgrounds instead of all-ones (white)
    bs = cfg.batch_size, # batch size
)

# Seed every RNG this notebook draws from (texture init, camera sampling, random
# backgrounds) so Restart & Run All is repeatable, up to GPU nondeterminism.
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

## 2. Utils

In [None]:
# limit
class limitation_evaluate(torch.autograd.Function):
    """Differentiable wrapper around an external limit-surface evaluator.

    ``loop_obj`` must expose ``compute_limitation(x)`` (forward evaluation) and
    ``get_J()`` (a matrix). Backward treats the map as linear in ``input`` and
    returns ``J.T @ grad_output``. ``loop_obj`` is not a tensor, so it gets no
    gradient.
    """

    @staticmethod
    def forward(ctx, input, loop_obj):
        limit_points = loop_obj.compute_limitation(input)
        # Fetch the Jacobian only after evaluation and keep it for backward.
        ctx.jacobian = loop_obj.get_J()
        return limit_points

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule for y = J x: dL/dx = J^T dL/dy.
        return ctx.jacobian.T @ grad_output, None


class LimitSubdivide():
    """Evaluate a subdivision limit surface for a fixed mesh topology.

    Presumably Loop subdivision, judging by the ``loop_limitation`` extension name.
    The operator is initialised once from a base mesh. Vertices are moved to CPU as
    float64 and faces as int32. New vertex positions are then mapped through the
    differentiable ``limitation_evaluate`` wrapper.
    """

    def __init__(self, vertices, faces) -> None:
        cpu_vertices = vertices.to('cpu').double()
        cpu_faces = faces.to('cpu').int()
        self.loop_limit = loop_limitation.loop_limitation()
        self.loop_limit.init_J(cpu_vertices, cpu_faces)
        self.compute_limit = limitation_evaluate.apply

    def get_limit(self, vertices):
        """Return limit positions for ``vertices``; the input is first moved to CPU float64."""
        return self.compute_limit(vertices.to('cpu').double(), self.loop_limit)

In [None]:
# helpers
def cosine_avg(features, targets):
    """Loss = negative mean cosine similarity between ``features`` and ``targets`` rows.

    Uses ``F.cosine_similarity`` with its default ``dim=1`` (broadcasting allowed),
    so perfectly aligned embeddings give -1.
    """
    similarity = F.cosine_similarity(features, targets)
    return -similarity.mean()


def _merge_attr_idx(a, b, a_idx, b_idx, scale_a=1.0, scale_b=1.0, add_a=0.0, add_b=0.0):
    if a is None and b is None:
        return None, None
    elif a is not None and b is None:
        return (a*scale_a)+add_a, a_idx
    elif a is None and b is not None:
        return (b*scale_b)+add_b, b_idx
    else:
        return torch.cat(((a*scale_a)+add_a, (b*scale_b)+add_b), dim=0), torch.cat((a_idx, b_idx + a.shape[0]), dim=0)


def create_scene(meshes, sz=1024):
    """Merge textured meshes into one Mesh that shares a single texture atlas.

    Vertex attributes and index buffers are concatenated with ``_merge_attr_idx``.
    Mesh ``i`` gets the atlas cell in row ``i // 2``, column ``i % 2`` of a grid with
    2 columns. Its UVs are scaled and offset into that cell, and its kd / ks / normal
    maps are copied into the same cell of the corresponding atlas.

    Args:
        meshes: non-empty sequence of meshes. Each material map is written into an
            ``sz x sz`` cell, so the maps are assumed to be ``sz x sz`` — TODO confirm
            for new callers.
        sz: side length, in texels, of one atlas cell.

    Returns:
        A ``mesh.Mesh`` with a ``'diffuse'`` material built from the atlases. Its
        geometry and UVs come from the merged meshes.

    Raises:
        ValueError: if ``meshes`` is empty.
    """
    if len(meshes) == 0:
        raise ValueError("create_scene() needs at least one mesh")

    scene = mesh.Mesh()

    # Grid of nx columns x ny rows; ny is rounded up to an even number.
    tot = len(meshes) if len(meshes) % 2 == 0 else len(meshes)+1
    nx = 2
    ny = math.ceil(tot / 2) if math.ceil(tot / 2) % 2 == 0 else math.ceil(tot / 2) + 1

    atlas_h = int(sz*ny)  # texel rows (ny cells tall)
    atlas_w = int(sz*nx)  # texel columns (nx cells wide)

    dev = meshes[0].v_tex.device

    kd_atlas = torch.ones ( (1, atlas_h, atlas_w, 4), device=dev )
    ks_atlas = torch.zeros( (1, atlas_h, atlas_w, 3), device=dev )
    kn_atlas = torch.ones ( (1, atlas_h, atlas_w, 3), device=dev )

    sc_x = 1./nx
    sc_y = 1./ny

    for i, m in enumerate(meshes):
        v_pos, t_pos_idx = _merge_attr_idx(scene.v_pos, m.v_pos, scene.t_pos_idx, m.t_pos_idx)
        v_nrm, t_nrm_idx = _merge_attr_idx(scene.v_nrm, m.v_nrm, scene.t_nrm_idx, m.t_nrm_idx)
        v_tng, t_tng_idx = _merge_attr_idx(scene.v_tng, m.v_tng, scene.t_tng_idx, m.t_tng_idx)

        # Row-major cell of mesh i in a grid with nx columns. The previous
        # int(i / ny) put two meshes into the same cell whenever ny != nx (5+ meshes).
        pos_x = i % nx
        pos_y = i // nx

        v_tex, t_tex_idx = _merge_attr_idx(
            scene.v_tex,
            m.v_tex,
            scene.t_tex_idx,
            m.t_tex_idx,
            scale_a=1.,
            scale_b=torch.tensor([sc_x, sc_y]).to(dev),
            add_a=0.,
            add_b=torch.tensor([sc_x*pos_x, sc_y*pos_y]).to(dev)
        )

        rows = slice(pos_y*sz, (pos_y*sz)+sz)
        cols = slice(pos_x*sz, (pos_x*sz)+sz)
        kd_atlas[:, rows, cols, :m.material['kd'].data.shape[-1]] = m.material['kd'].data
        ks_atlas[:, rows, cols, :m.material['ks'].data.shape[-1]] = m.material['ks'].data
        kn_atlas[:, rows, cols, :m.material['normal'].data.shape[-1]] = m.material['normal'].data

        scene = mesh.Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            v_nrm=v_nrm,
            t_nrm_idx=t_nrm_idx,
            v_tng=v_tng,
            t_tng_idx=t_tng_idx,
            v_tex=v_tex,
            t_tex_idx=t_tex_idx,
            base=scene
        )

    scene = mesh.Mesh(
        material={
            'bsdf': 'diffuse',
            'kd': texture.Texture2D(kd_atlas),
            'ks': texture.Texture2D(ks_atlas),
            'normal': texture.Texture2D(kn_atlas),
        },
        base=scene # gets uvs etc from here
    )

    return scene

In [None]:
# camera
# Background blur augmentations, one per (kernel size, fixed sigma) pair.
blurs = [
    transforms.Compose([transforms.GaussianBlur(kernel, sigma=(sigma, sigma))])
    for kernel, sigma in ((11, 5), (11, 2), (5, 5), (5, 2))
]


def get_random_bg(h, w):
    """Return a random background image of shape (1, h, w, 3) (channels last).

    A ``torch.rand`` draw picks one of three styles with roughly equal
    probability:
      * blurred per-pixel uniform noise,
      * a small two-colour checkerboard, upsampled to (h, w) and blurred,
      * a single flat colour.
    The blur is drawn at random from ``blurs``.
    """
    style = torch.rand(1)

    if style > 0.66666:
        blur = blurs[random.randint(0, 3)]
        return blur(torch.rand((1, 3, h, w))).permute(0, 2, 3, 1)

    if style > 0.333333:
        cells = random.randint(5, 10)
        # Base colour: each channel uniform in [0, 0.5).
        board = torch.vstack([
            torch.full((1, cells, cells), torch.rand(1).item() / 2)
            for _ in range(3)
        ]).unsqueeze(0)

        # Second colour on the (even, even) and (odd, odd) cells -> checkerboard.
        accent = torch.rand(3)
        for channel in range(3):
            board[:, channel, ::2, ::2] = accent[channel]
            board[:, channel, 1::2, 1::2] = accent[channel]

        blur = blurs[random.randint(0, 3)]
        return blur(F.interpolate(board, size=(h, w))).permute(0, 2, 3, 1)

    flat = torch.vstack([torch.full((1, h, w), torch.rand(1).item()) for _ in range(3)])
    return flat.unsqueeze(0).permute(0, 2, 3, 1)


def cosine_sample(N : np.ndarray) -> np.ndarray:
    """Draw a random unit direction from the cosine-weighted hemisphere around ``N``.

    ``N`` need not be normalised. Two ``np.random.uniform()`` draws are consumed:
    the azimuth first, then the radial term.
    """
    normal = N/np.linalg.norm(N)

    # Orthonormal tangent frame. Of two candidate perpendiculars, use the longer
    # one (the second on ties) so the normalisation is well conditioned.
    cand_a = np.array([0, normal[2], -normal[1]])
    cand_b = np.array([-normal[2], 0, normal[0]])
    tangent = cand_a if np.dot(cand_a, cand_a) > np.dot(cand_b, cand_b) else cand_b
    tangent = tangent/np.linalg.norm(tangent)
    bitangent = np.cross(normal, tangent)
    bitangent = bitangent/np.linalg.norm(bitangent)

    # Cosine-weighted sample in the local frame: cos(theta) = sqrt(u).
    phi = 2.0*np.pi*np.random.uniform()
    u = np.random.uniform()
    sin_theta = np.sqrt(1.0 - u)
    local_x = np.cos(phi)*sin_theta
    local_y = np.sin(phi)*sin_theta
    local_z = np.sqrt(u)

    return tangent*local_x + bitangent*local_y + normal*local_z


def persp_proj(fov_x=45, ar=1, near=1.0, far=50.0):
    """Build a 4x4 OpenGL-style perspective projection matrix (float64).

    Args:
        fov_x: field of view in degrees. It sets the vertical extent
            ``tan(fov_x / 2) * near``; the horizontal extent is that times ``ar``.
        ar: aspect ratio applied to the x extent.
        near, far: clipping-plane distances.

    Returns:
        ``np.ndarray`` of shape (4, 4), dtype float64.
    """
    fov_rad = np.deg2rad(fov_x)

    tanhalffov = np.tan( (fov_rad / 2) )
    max_y = tanhalffov * near
    min_y = -max_y
    max_x = max_y * ar
    min_x = -max_x

    z_sign = -1.0
    # Must be a float array: the previous integer literal matrix silently truncated
    # every entry below (e.g. a focal term of 2.414 became 2).
    proj_mat = np.zeros((4, 4), dtype=np.float64)

    proj_mat[0, 0] = 2.0 * near / (max_x - min_x)
    proj_mat[1, 1] = 2.0 * near / (max_y - min_y)
    proj_mat[0, 2] = (max_x + min_x) / (max_x - min_x)
    proj_mat[1, 2] = (max_y + min_y) / (max_y - min_y)
    proj_mat[3, 2] = z_sign

    proj_mat[2, 2] = z_sign * far / (far - near)
    proj_mat[2, 3] = -(far * near) / (far - near)

    return proj_mat


def get_camera_params(elev_angle, azim_angle, distance, resolution, fov=60, look_at=[0, 0, 0], up=[0, -1, 0]):
    """Build render parameters for one fixed camera pose.

    The eye is placed ``distance`` from the origin at the given elevation and
    azimuth (degrees) and looks at ``look_at``. The light is placed at the eye.

    Returns:
        dict with ``'mvp'`` of shape (1, 4, 4) float32, ``'lightpos'`` and ``'campos'``
        of shape (1, 3) (the same array object), and ``'resolution'`` as
        ``[resolution, resolution]``.
    """
    elev = np.radians(elev_angle)
    azim = np.radians(azim_angle)

    # Length of the eye vector projected onto the xz-plane.
    ground = distance * np.cos(elev)
    eye = glm.vec3(ground * np.cos(azim), distance * np.sin(elev), ground * np.sin(azim))

    view = glm.lookAt(
        eye,
        glm.vec3(look_at[0], look_at[1], look_at[2]),
        glm.vec3(up[0], up[1], up[2]),
    )
    model_view = np.array((view * glm.mat4()).to_list()).T

    mvp = np.matmul(persp_proj(fov), model_view).astype(np.float32)[None, ...]
    eye_pos = np.linalg.inv(model_view)[None, :3, 3]

    return {
        'mvp' : mvp,
        'lightpos' : eye_pos,
        'campos' : eye_pos,
        'resolution' : [resolution, resolution],
        }


class CameraBatch(torch.utils.data.Dataset):
    """Random camera / light / background sampler exposed as a Dataset.

    ``__getitem__`` ignores ``index`` and draws a fresh random view on every call.
    Loading all ``len(self) == bs`` items therefore yields ``bs`` independent views.

    Args:
        image_resolution: side length of the sampled background images.
        distances: (min, max) camera distance from the origin.
        azimuths: (min, max) azimuth in degrees.
        elevation_params: (alpha, beta, max); elevation = Beta(alpha, beta) * max degrees.
        fovs: (min, max) field of view in degrees, passed to ``persp_proj``.
        aug_loc: add a random x/y translation to the model matrix.
        aug_light: jitter the light direction around the camera with ``cosine_sample``.
        aug_bkg: use ``get_random_bg`` backgrounds instead of all-ones (white).
        bs: number of items (views) in the dataset.
        look_at, up: target and up vector passed to ``glm.lookAt``.
    """
    def __init__(
        self,
        image_resolution,
        distances,
        azimuths,
        elevation_params,
        fovs,
        aug_loc,
        aug_light,
        aug_bkg,
        bs,
        look_at=[0, 0, 0], up=[0, -1, 0]
    ):

        self.res = image_resolution

        self.dist_min = distances[0]
        self.dist_max = distances[1]

        self.azim_min = azimuths[0]
        self.azim_max = azimuths[1]

        self.fov_min = fovs[0]
        self.fov_max = fovs[1]

        self.elev_alpha = elevation_params[0]
        self.elev_beta  = elevation_params[1]
        self.elev_max   = elevation_params[2]

        self.aug_loc   = aug_loc
        self.aug_light = aug_light
        self.aug_bkg   = aug_bkg

        self.look_at = look_at
        self.up = up

        self.batch_size = bs

    def __len__(self):
        return self.batch_size

    def __getitem__(self, index):
        """Sample one view: dict with 'mvp' (4, 4), 'lightpos' (3,), 'campos' (3,), 'bkgs' (res, res, 3)."""
        elev = np.radians(np.random.beta( self.elev_alpha, self.elev_beta ) * self.elev_max)
        # NOTE(review): the upper bound is azim_max + 1 degree, kept as-is.
        azim = np.radians(np.random.uniform( self.azim_min, self.azim_max+1.0 ))
        dist = np.random.uniform(self.dist_min, self.dist_max)
        fov = np.random.uniform(self.fov_min, self.fov_max)
        proj_mtx = persp_proj(fov)

        # Camera position at distance `dist` from the origin.
        cam_z = dist * np.cos(elev) * np.sin(azim)
        cam_y = dist * np.sin(elev)
        cam_x = dist * np.cos(elev) * np.cos(azim)

        if self.aug_loc:
            # Random x/y translation of the model, within +/- floor(dist_min / 2).
            limit  = self.dist_min // 2
            rand_x = np.random.uniform( -limit, limit )
            rand_y = np.random.uniform( -limit, limit )

            modl = glm.translate(glm.mat4(), glm.vec3(rand_x, rand_y, 0))
        else:
            modl = glm.mat4()

        view  = glm.lookAt(
            glm.vec3(cam_x, cam_y, cam_z),
            glm.vec3(self.look_at[0], self.look_at[1], self.look_at[2]),
            glm.vec3(self.up[0], self.up[1], self.up[2]),
        )

        r_mv = view * modl
        r_mv = np.array(r_mv.to_list()).T

        mvp     = np.matmul(proj_mtx, r_mv).astype(np.float32)
        campos  = np.linalg.inv(r_mv)[:3, 3]

        if self.aug_light:
            # cosine_sample returns a unit vector, so the light sits at distance `dist`.
            lightpos = cosine_sample(campos)*dist
        else:
            # Light along the camera direction at the same distance `dist` as the
            # augmented branch. The previous `campos*dist` placed it ~dist**2 away.
            lightpos = campos / np.linalg.norm(campos) * dist

        if self.aug_bkg:
            bkgs = get_random_bg(self.res, self.res).squeeze(0)
        else:
            bkgs = torch.ones(self.res, self.res, 3)

        return {
            'mvp': torch.from_numpy( mvp ).float(),
            'lightpos': torch.from_numpy( lightpos ).float(),
            'campos': torch.from_numpy( campos ).float(),
            'bkgs': bkgs
        }

## 3. Run

In [None]:
# Text prompt whose CLIP embedding the rendered mesh is optimised to match.
TEXT_PROMPT = 'a cup of coffee'

In [None]:
processor = CLIPProcessor.from_pretrained(cfg.clip_model)
model = CLIPModel.from_pretrained(cfg.clip_model).to(cfg.device)
# CLIP is only a frozen critic: its weights are not in the optimizer. Disabling
# their grads avoids computing and accumulating unused gradients every step.
# Gradients still flow through the model to the rendered images.
model.requires_grad_(False)

# CLIP's per-channel RGB normalisation, shaped (1, 3, 1, 1) to broadcast over
# (B, C, H, W) renders. The red mean was previously mistyped as 0.48154660.
clip_mean = torch.tensor([0.48145466, 0.45782750, 0.40821073], device=cfg.device)[None, :, None, None]
clip_std  = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=cfg.device)[None, :, None, None]

In [None]:
# Encode the prompt once; the unit-normalised text embedding is the fixed CLIP target.
text_inputs = processor(text=TEXT_PROMPT, return_tensors='pt').to(cfg.device)
with torch.no_grad():
    texts_embeds = model.get_text_features(**text_inputs)
    texts_embeds = texts_embeds.div(texts_embeds.norm(dim=1, keepdim=True))

In [None]:
# nvdiffrast rasterizer context, reused by every render.render_mesh call in the training loop.
glctx = dr.RasterizeGLContext()

In [None]:
# Load the starting primitive and rescale it with mesh.unit_size.
load_mesh = mesh.unit_size(obj.load_obj(cfg.mesh))

In [None]:
# Trainable control vertices (fresh leaf tensor) and fixed face indices.
vertices = load_mesh.v_pos.clone().detach().requires_grad_(True)
faces = load_mesh.t_pos_idx.clone().detach()

resolution = [cfg.texture_resolution] * 2

# Initial values: RGBA uniform noise for the diffuse map, constant (0, 0, 1) for
# the normal map and zeros for the specular map.
texture_init = np.random.uniform(low=0.0, high=1.0, size=[cfg.texture_resolution, cfg.texture_resolution, 4])
normal_init = np.array([0, 0, 1])
specular_init = np.array([0, 0, 0])

texture_map, normal_map, specular_map = (
    texture.create_trainable(init, res=resolution, auto_mipmaps=True)
    for init in (texture_init, normal_init, specular_init)
)

print(texture_map.data.shape, normal_map.data.shape, specular_map.data.shape)

torch.Size([1, 512, 512, 4]) torch.Size([1, 512, 512, 3]) torch.Size([1, 512, 512, 3])


In [None]:
# Optimise the control vertices together with every mip level of the three maps.
train_params = [
    vertices,
    *texture_map.getMips(),
    *normal_map.getMips(),
    *specular_map.getMips(),
]

optimizer  = torch.optim.Adam(train_params, lr=cfg.lr)
# Exponential LR decay: the factor drops 10x every 5000 scheduler steps.
scheduler  = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: max(0.0, 10**(-step*0.0002)))

In [None]:
# Mesh being optimised: trainable vertices and maps, UVs etc. from the loaded mesh.
train_mesh = mesh.Mesh(
    vertices,
    faces,
    material = dict(
        bsdf='diffuse',
        kd=texture_map,
        ks=specular_map,
        normal=normal_map,
    ),
    base = load_mesh, # Get UVs from original loaded mesh
)

# Limit-surface operator, initialised from the loaded mesh's topology and positions.
train_subdiv = LimitSubdivide(
    load_mesh.v_pos.clone().detach(),
    load_mesh.t_pos_idx.clone().detach(),
)

In [None]:
camera_ds = CameraBatch(**camera_cfg)
# One batch spans the whole dataset, so each next(iter(...)) yields len(camera_ds) fresh views.
camera_loader = torch.utils.data.DataLoader(
    camera_ds, batch_size=len(camera_ds), num_workers=0, pin_memory=True,
)

In [None]:
def _blur_texture(tex):
    """Gaussian-blur a trainable map's (1, H, W, C) data and wrap it as a Texture2D."""
    blurred = kornia.filters.gaussian_blur2d(
        tex.data.permute(0, 3, 1, 2),  # kornia works on (B, C, H, W)
        kernel_size = cfg.kernel_size,
        sigma = cfg.blur_sigma,
    )
    return texture.Texture2D(blurred.permute(0, 2, 3, 1).contiguous())


# Laplacian weight starts at cfg.laplacian_weight and is annealed toward laplacian_min.
laplacian_weight = cfg.laplacian_weight
laplacian_min = cfg.laplacian_min

t_loop = tqdm(range(cfg.epochs), leave=False)
for it in t_loop:
    base_mesh = train_mesh
    # Limit-surface positions of the current control vertices (evaluated on CPU).
    n_vert = train_subdiv.get_limit(base_mesh.v_pos.to('cpu').double()).to(cfg.device)

    # Render the limit surface with blurred copies of the trainable maps.
    # Named `limit_mesh` so the global `load_mesh` (the mesh read from disk) is not clobbered.
    limit_mesh = mesh.Mesh(
        n_vert,
        base_mesh.t_pos_idx,
        material = {
            'bsdf': 'diffuse',
            'kd': _blur_texture(base_mesh.material['kd']),
            'ks': _blur_texture(base_mesh.material['ks']),
            'normal': _blur_texture(base_mesh.material['normal']),
        },
        base = base_mesh # gets uvs etc from here
    )

    rendered_mesh = limit_mesh.eval()
    lapl_funcs = [regularizer.laplace_regularizer_const(base_mesh)]

    complete_scene = create_scene([rendered_mesh], sz=cfg.texture_resolution)
    complete_scene = mesh.auto_normals(complete_scene)
    complete_scene = mesh.compute_tangents(complete_scene)

    # A fresh batch of random cameras / lights / backgrounds every step.
    params_camera = {key: value.to(cfg.device) for key, value in next(iter(camera_loader)).items()}

    params = {
        'mvp': params_camera['mvp'],
        'lightpos': params_camera['lightpos'],
        'campos': params_camera['campos'],
        'resolution': [cfg.train_resolution, cfg.train_resolution]
    }

    train_render = render.render_mesh(
        ctx = glctx,
        mesh = complete_scene.eval(params),
        mtx_in = params["mvp"],
        view_pos = params["campos"],
        light_pos = params["lightpos"],
        light_power = cfg.light_power,
        resolution = cfg.train_resolution,
        spp = 1, # no upscale here / render at any resolution then use resize_right to downscale
        num_layers = cfg.num_layers,
        msaa = False,
        background = params_camera["bkgs"],
    ).permute(0, 3, 1, 2) # switch to B, C, H, W

    # Resize to CLIP's 224x224 input and apply CLIP's pixel normalisation.
    train_render = F.interpolate(train_render, size=(224, 224), mode='bicubic', align_corners=False)
    train_render = (train_render - clip_mean) / clip_std
    image_embeds = model.get_image_features(train_render)
    image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
    clip_loss  = cosine_avg(image_embeds, texts_embeds)

    lapls = [fn_l.eval(params) for fn_l in lapl_funcs if fn_l is not None]

    if it > 0:
        laplacian_weight = (laplacian_weight - laplacian_min) * 10**(-it*0.000001) + laplacian_min

    # sum() starts at 0, so an empty regulariser list gives a zero loss term.
    lapls_loss = sum(laplacian_weight * lap_l for lap_l in lapls)

    loss = cfg.clip_weight * clip_loss + lapls_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Project the maps back into their valid ranges after each step.
    normal_map.clamp_(min=-1, max=1)
    specular_map.clamp_(min=0, max=1)
    texture_map.clamp_(min=0, max=1)

    # float() handles both a tensor and the plain 0 from an empty sum.
    t_loop.set_postfix({'clip': clip_loss.item(), 'lapl': float(lapls_loss)})

  0%|          | 0/2000 [00:00<?, ?it/s]

In [None]:
# Export the optimised mesh and its texture maps to ./mesh, then zip the folder.
# NOTE(review): this writes `train_mesh`, i.e. the control vertices and the
# unblurred maps. The training loop rendered the limit surface of those
# vertices with blurred maps — confirm this is the intended export.
os.makedirs('mesh', exist_ok=True)
obj.write_obj('mesh', train_mesh)
!zip -r mesh.zip mesh/*

Writing mesh:  mesh/mesh.obj
    writing 642 vertices
    writing 729 texcoords
    writing 1280 normals
    writing 1280 faces
Writing material:  mesh/mesh.mtl
Done exporting mesh
updating: mesh/mesh.mtl (deflated 27%)
updating: mesh/mesh.obj (deflated 75%)
updating: mesh/texture_kd.png (deflated 0%)
updating: mesh/texture_ks.png (deflated 90%)
updating: mesh/texture_n.png (deflated 0%)


In [None]:
m = trimesh.load('mesh/mesh.obj')
im = Image.open('mesh/texture_kd.png')

material = trimesh.visual.texture.SimpleMaterial(image=im)
color_visuals = trimesh.visual.TextureVisuals(uv=m.visual.uv, image=im, material=material)
mesh = trimesh.Trimesh(vertices=m.vertices, faces=m.faces, visual=color_visuals, validate=True, process=False)

In [None]:
mesh.show()