In [None]:
# Clone the 3D_Highlighter repository and make it the working directory.
# NOTE(review): the local imports used later (mesh, render, utils, Normalization)
# and the data/*.obj paths are presumably provided by this repository - confirm.
!git clone https://github.com/reemkhattarr/3D_Highlighter
%cd 3D_Highlighter
# Install CLIP and Kaolin directly from their GitHub repositories (no version pin).
!pip install git+https://github.com/openai/CLIP.git
!pip install git+https://github.com/NVIDIAGameWorks/kaolin.git

In [None]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision

from itertools import permutations, product
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh

class NeuralHighlighter(nn.Module):
    """MLP that turns per-point features into a probability distribution over classes.

    Architecture: an input projection ``Linear -> ReLU -> LayerNorm``, followed by
    ``depth`` hidden blocks of the same three layers, a final ``Linear`` to
    ``out_dim`` and a softmax over the last (class) axis.

    Args:
        depth: number of hidden ``Linear -> ReLU -> LayerNorm`` blocks that follow
            the input projection.
        width: feature width of the input projection and every hidden block.
        out_dim: number of output classes.
        input_dim: size of the last axis of the input (defaults to 3).
        sigma: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, depth, width, out_dim, input_dim=3, sigma=5.0):
        super(NeuralHighlighter, self).__init__()
        # Modules are created in the same order as before so that parameter
        # initialisation consumes the RNG identically for a given seed.
        layers = [nn.Linear(input_dim, width), nn.ReLU(), nn.LayerNorm([width])]
        for _ in range(depth):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm([width]))
        layers.append(nn.Linear(width, out_dim))
        # Normalise over the class axis. Linear and LayerNorm already act on the
        # last axis, so using dim=-1 (instead of a hard-coded dim=1) keeps the
        # result identical for (N, input_dim) input while also being correct for
        # a single point (input_dim,) or batched input (B, N, input_dim).
        layers.append(nn.Softmax(dim=-1))

        self.mlp = nn.ModuleList(layers)
        print(self.mlp)

    def forward(self, x):
        """Apply every layer in order and return class probabilities.

        Args:
            x: tensor whose last axis has size ``input_dim``.

        Returns:
            Tensor with the same leading axes as ``x`` and a last axis of size
            ``out_dim`` whose entries sum to 1.
        """
        for layer in self.mlp:
            x = layer(x)
        return x

def get_clip_model(clipmodel):
    """Load the CLIP checkpoint named ``clipmodel`` onto the shared ``device``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    loaded = clip.load(clipmodel, device)
    model, image_preprocess = loaded
    return model, image_preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Save the final hard highlight prediction as a coloured mesh and a render.

    The per-vertex arg-max of ``mlp(vertices)`` is used to
      * colour ``mesh`` via ``color_mesh`` and render 5 views, written to
        ``<log_dir>/<name>.jpg``;
      * export the mesh to ``<log_dir>/<name>.ply`` with a per-vertex colour
        (highlight where class 0 wins, gray otherwise).

    Args:
        log_dir: directory that receives both output files.
        name: file stem used for the ``.ply`` and ``.jpg`` outputs.
        mesh: mesh object; ``color_mesh`` and ``export`` are called on it directly
            (no copy is made here).
        mlp: network mapping ``vertices`` to per-vertex class probabilities.
        vertices: input passed to ``mlp``.
        colors: ignored. The highlight/gray palette is rebuilt inside this
            function; the parameter is kept for interface compatibility.
        render: renderer exposing ``render_views``.
        background: background colour forwarded to ``render_views``.

    Side effects:
        ``mlp`` is switched to eval mode and left that way.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # for renders: hard (one-hot) class assignment, allocated on the target
        # device directly instead of on the CPU followed by a copy
        one_hot = torch.zeros(probs.shape, device=device)
        one_hot.scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        palette = torch.stack((highlight / 255, gray / 255)).to(device)
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(
            mesh,
            num_views=5,
            show=False,
            center_azim=0,
            center_elev=0,
            std=1,
            return_views=True,
            lighting=True,
            background=background,
        )

        # for mesh: the exported colours are the raw integer RGB triples.
        # (max_idx == 0) has a trailing axis of size 1 (keepdim=True), so it
        # broadcasts against the 3-element colour vectors.
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name + '.jpg')


def clip_loss(rendered_images, encoded_text, clip_model, n_augs):
    """Negative CLIP similarity between augmented renders and the text embedding.

    For each of ``n_augs`` rounds the renders are randomly augmented, encoded with
    ``clip_model.encode_image`` and averaged over the first axis; the cosine
    similarity of that mean with the text embedding is subtracted from the loss.

    Reads the module-level ``res`` for the crop size, so ``res`` must be defined
    before this function is called.

    Args:
        rendered_images: image batch accepted by the torchvision transforms and
            by ``clip_model.encode_image``.
        encoded_text: text embedding(s). When it holds more than one row, the rows
            are averaged before the comparison.
        clip_model: model exposing ``encode_image``.
        n_augs: number of augmentation rounds.

    Returns:
        The accumulated negative similarity. This is the Python float ``0.0`` when
        ``n_augs`` is 0, otherwise a tensor (0-dim in the multi-row text branch;
        in the other branch its shape follows ``torch.cosine_similarity`` along
        its default ``dim=1``).
    """
    clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    # scale=(1, 1) keeps the crop area at 100%, so the variation comes from
    # RandomResizedCrop's aspect-ratio sampling and the random perspective warp
    # (applied with probability 0.8, fill value 1).
    augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
    ])
    loss = 0.0
    for _ in range(n_augs):
        augmented_image = augment_transform(rendered_images)
        encoded_renders = clip_model.encode_image(augmented_image)
        if encoded_text.shape[0] > 1:
            # several text embeddings: compare mean image vs. mean text vector
            loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                            torch.mean(encoded_text, dim=0), dim=0)
        else:
            # keepdim=True keeps a leading axis so the mean lines up with
            # encoded_text along cosine_similarity's default dim=1
            loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                            encoded_text)
    return loss

def save_renders(dir, i, rendered_images, name=None):
    """Write ``rendered_images`` to an image file below ``dir``.

    Args:
        dir: base output directory (the name shadows the builtin but is part of
            the signature).
        i: iteration number, used only for the default file name.
        rendered_images: images handed to ``torchvision.utils.save_image``.
        name: file name relative to ``dir``; when ``None`` the file is written to
            ``renders/iter_<i>.jpg`` instead.
    """
    if name is None:
        relative_path = 'renders/iter_{}.jpg'.format(i)
    else:
        relative_path = name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, relative_path))


In [None]:
# Sizes and iteration count shared by the functions in this notebook
res = 224          # crop size read by clip_loss
render_res = 224   # renderer resolution read by optimize
n_iter = 2500      # optimisation steps per optimize() call

# Constrain most sources of randomness
# (some torch backward functions within CLIP are non-deterministic)
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def optimize(obj_path='data/candle.obj', prompt='a gray candle with highlighted hat.',
             learning_rate=0.0001, depth=4, n_augs=5, n_views=5):
    """Optimise a NeuralHighlighter so CLIP matches renders of the mesh to ``prompt``.

    Loads and normalises the mesh at ``obj_path``, trains a fresh
    ``NeuralHighlighter`` for ``n_iter`` steps (module-level constant) with Adam on
    the CLIP loss of rendered views, and writes all results below ``./output/``:

      * ``renders/iter_<i>.jpg`` and a line in ``training_info.txt`` every 100 steps;
      * ``<objbase>.ply`` / ``<objbase>.jpg`` with the final hard assignment;
      * an empty file named after ``prompt``;
      * ``<objbase>/<lr>_<depth>_<n_augs>_<n_views>_<last loss>.ply|.jpg`` so runs
        with different hyperparameters do not overwrite each other.

    Reads the module-level ``render_res`` and ``n_iter`` (and ``res`` indirectly
    through ``clip_loss``).

    Args:
        obj_path: path of the mesh file to load.
        prompt: text prompt encoded with CLIP. It is also used verbatim as a file
            name, so it must be a valid file name.
        learning_rate: Adam learning rate.
        depth: number of hidden blocks of the NeuralHighlighter.
        n_augs: augmentation rounds per step in ``clip_loss`` (must be > 0, since
            the loss is back-propagated).
        n_views: number of views rendered per step.
    """
    output_dir = './output/'
    clip_model_name = 'ViT-L/14'

    Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

    objbase = os.path.splitext(os.path.basename(obj_path))[0]

    render = Renderer(dim=(render_res, render_res))
    mesh = Mesh(obj_path)
    MeshNormalizer(mesh)()

    # Initialize variables
    background = torch.tensor((1., 1., 1.)).to(device)
    log_dir = output_dir

    # MLP Settings
    mlp = NeuralHighlighter(depth, 256, 2).to(device)
    optim = torch.optim.Adam(mlp.parameters(), learning_rate)

    # per-class colours: row 0 = highlighter, row 1 = gray
    colors = torch.tensor([[204/255, 1., 0.], [180/255, 180/255, 180/255]]).to(device)

    # --- Prompt ---
    # encode prompt with CLIP
    clip_model, _ = get_clip_model(clip_model_name)
    # Only the highlighter is optimised. Freezing CLIP keeps backward from
    # computing and accumulating gradients for its weights; gradients still flow
    # through CLIP to the rendered images.
    for param in clip_model.parameters():
        param.requires_grad_(False)
    with torch.no_grad():
        prompt_token = clip.tokenize([prompt]).to(device)
        encoded_text = clip_model.encode_text(prompt_token)
        encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)

    vertices = copy.deepcopy(mesh.vertices)

    losses = []

    # Optimization loop
    for i in tqdm(range(n_iter)):
        optim.zero_grad()

        # predict highlight probabilities
        pred_class = mlp(vertices)

        # color and render mesh (color_mesh is applied to `mesh` itself)
        color_mesh(pred_class, mesh, colors)
        rendered_images, _, _ = render.render_views(
            mesh,
            num_views=n_views,
            show=False,
            center_azim=0,
            center_elev=0,
            std=1,
            return_views=True,
            lighting=True,
            background=background,
        )

        # Calculate CLIP Loss
        loss = clip_loss(rendered_images, encoded_text, clip_model, n_augs)
        # The graph is rebuilt from scratch on every iteration, so there is no
        # reason to retain it after the backward pass.
        loss.backward()

        optim.step()

        # record loss
        losses.append(loss.item())

        # report results
        if i % 100 == 0:
            recent_mean = np.mean(losses[-100:])
            print("Last 100 CLIP score: {}".format(recent_mean))
            save_renders(log_dir, i, rendered_images)
            with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
                f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {recent_mean}, CLIP score {losses[-1]}\n")

    # save results
    save_final_results(log_dir, objbase, mesh, mlp, vertices, colors, render, background)

    # Save prompts (an empty marker file named after the prompt)
    with open(os.path.join(log_dir, prompt), "w") as f:
        f.write('')

    # Exploration directories: one result per hyperparameter combination
    file_name = '_'.join(str(value) for value in (learning_rate, depth, n_augs, n_views, losses[-1]))
    exploration_dir = os.path.join(output_dir, objbase)
    Path(exploration_dir).mkdir(parents=True, exist_ok=True)
    save_final_results(exploration_dir, file_name, mesh, mlp, vertices, colors, render, background)


In [None]:
# Hyperparameter exploration
# Each sweep varies one setting and keeps the others at the baseline
# (learning_rate=0.0001, depth=4, n_augs=5, n_views=5).
objects = {
    'data/horse.obj' : 'a gray horse with highlighted shoes.'
}
learning_rates = [0.001, 0.0001, 0.00001]
depths = [2, 4, 6]
n_augss = [2, 5, 8]
n_viewss = [2, 5, 8]

# learning_rate sweep
for obj_path, prompt in objects.items():
    for learning_rate in learning_rates:
        optimize(obj_path, prompt, learning_rate=learning_rate, depth=4, n_augs=5, n_views=5)

# depth sweep
for obj_path, prompt in objects.items():
    for depth in depths:
        optimize(obj_path, prompt, learning_rate=0.0001, depth=depth, n_augs=5, n_views=5)

# n_augs sweep
for obj_path, prompt in objects.items():
    for n_augs in n_augss:
        optimize(obj_path, prompt, learning_rate=0.0001, depth=4, n_augs=n_augs, n_views=5)

# n_views sweep
for obj_path, prompt in objects.items():
    for n_views in n_viewss:
        optimize(obj_path, prompt, learning_rate=0.0001, depth=4, n_augs=5, n_views=n_views)


In [None]:

# Hyperparameter testing
# Same one-setting-at-a-time sweeps, run on the test objects; the baseline is
# learning_rate=0.0001, depth=4, n_augs=5, n_views=5.
objects = {
    'data/candle.obj' : 'a gray candle with highlighted hat.',
    'data/dog.obj' : 'a gray dog with highlighted hat.'
}
learning_rates = [0.001, 0.0001, 0.00001]
depths = [2, 4, 6]
n_augss = [2, 5, 8]
n_viewss = [2, 5, 8]

# learning_rate sweep
for obj_path, prompt in objects.items():
    for learning_rate in learning_rates:
        optimize(obj_path, prompt, learning_rate=learning_rate, depth=4, n_augs=5, n_views=5)

# depth sweep
for obj_path, prompt in objects.items():
    for depth in depths:
        optimize(obj_path, prompt, learning_rate=0.0001, depth=depth, n_augs=5, n_views=5)

# n_augs sweep
for obj_path, prompt in objects.items():
    for n_augs in n_augss:
        optimize(obj_path, prompt, learning_rate=0.0001, depth=4, n_augs=n_augs, n_views=5)

# n_views sweep
for obj_path, prompt in objects.items():
    for n_views in n_viewss:
        optimize(obj_path, prompt, learning_rate=0.0001, depth=4, n_augs=5, n_views=n_views)
