In [1]:
!git clone https://github.com/reemkhattarr/3D_Highlighter
%cd 3D_Highlighter

fatal: destination path '3D_Highlighter' already exists and is not an empty directory.
/content/3D_Highlighter


In [2]:
!pip install git+https://github.com/openai/CLIP.git
!pip install git+https://github.com/NVIDIAGameWorks/kaolin.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-6ch36c37
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-6ch36c37
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting git+https://github.com/NVIDIAGameWorks/kaolin.git
  Cloning https://github.com/NVIDIAGameWorks/kaolin.git to /tmp/pip-req-build-0ggee0rz
  Running command git clone --filter=blob:none --quiet https://github.com/NVIDIAGameWorks/kaolin.git /tmp/pip-req-build-0ggee0rz
  Resolved https://github.com/NVIDIAGameWorks/kaolin.git to commit 342bac64b612ee0659a68c888602d20cd787b594
  Running command git submodule update --init --recursive -q
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [3]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision

from itertools import permutations, product
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh

class NeuralHighlighter(nn.Module):
    """Per-vertex MLP mapping input coordinates to class logits.

    Architecture: max(num_layers, 1) blocks of (Linear -> ReLU -> LayerNorm),
    the first taking ``input_dim`` features and the rest ``hidden_dim``, followed
    by a final Linear projection to ``output_dim`` logits. Elsewhere in this
    notebook the argmax over the logits is used as the class, with class 0
    painted as the highlight colour.

    The built network is printed on construction.
    """

    def __init__(self, num_layers=3, input_dim=3, hidden_dim=64, output_dim=2):
        super().__init__()
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        def hidden_block(in_features):
            # One Linear -> ReLU -> LayerNorm stage producing hidden_dim features.
            return [nn.Linear(in_features, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim)]

        stack = hidden_block(input_dim)
        for _ in range(num_layers - 1):
            stack += hidden_block(hidden_dim)
        stack.append(nn.Linear(hidden_dim, output_dim))

        # Keep this exact ordering: it fixes the Sequential indices (model.0,
        # model.3, ...) that appear in state_dict keys, and the order in which
        # parameters are initialised from the RNG.
        self.model = nn.Sequential(*stack)
        print(self.model)

    def forward(self, x):
        """Return unnormalised logits with last dimension ``output_dim``."""
        return self.model(x)

def get_clip_model(clipmodel):
    """Load the CLIP checkpoint named ``clipmodel`` onto the shared ``device``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    loaded_model, image_transform = clip.load(clipmodel, device=device)
    return loaded_model, image_transform


# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background, num_views=5):
    """Render and export the final hard (argmax) segmentation of ``mesh``.

    Every vertex is assigned its most probable class under ``mlp`` and painted
    with row ``k`` of ``colors`` for class ``k``. ``num_views`` lit views are
    saved to ``<log_dir>/final_render.jpg`` and the coloured mesh is exported to
    ``<log_dir>/<name>.ply``.

    Args:
        log_dir: Output directory.
        name: Base file name (without extension) for the exported mesh.
        mesh: Mesh to colour and export. It is passed to ``color_mesh`` and
            then rendered, so it is presumably coloured in place — confirm
            against ``utils.color_mesh``.
        mlp: Per-vertex classifier; its training/eval mode is restored on exit.
        vertices: Vertex positions fed to ``mlp``.
        colors: Palette tensor, assumed (num_classes, 3) RGB in [0, 1] — the
            notebook builds it that way. It is used both for the render and,
            scaled to 0-255, for the exported vertex colours.
        render: Renderer exposing ``render_views``.
        background: Background colour forwarded to the renderer.
        num_views: Number of views in the final render.
    """
    was_training = mlp.training
    mlp.eval()
    try:
        with torch.no_grad():
            logits = mlp(vertices)
            labels = torch.argmax(logits, dim=1, keepdim=True)  # (N, 1) class index per vertex

            # Hard one-hot assignment so the render shows the final discrete segmentation.
            one_hot = torch.zeros_like(logits).scatter_(1, labels, 1.0)
            color_mesh(one_hot, mesh, colors)
            rendered_images, _, _ = render.render_views(mesh, num_views=num_views,
                                                        show=False,
                                                        center_azim=0,
                                                        center_elev=0,
                                                        std=1,
                                                        return_views=True,
                                                        lighting=True,
                                                        background=background)

            # Export colours are derived from the same palette as the render so the
            # two outputs always agree; for the notebook's palette this yields the
            # int64 triples [204, 255, 0] and [180, 180, 180].
            palette_255 = (colors.to(labels.device) * 255).round().long()
            final_color = palette_255[labels.squeeze(1)]  # (N, 3) per-vertex RGB
            mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
            save_renders(log_dir, 0, rendered_images, name='final_render.jpg')
    finally:
        mlp.train(was_training)


# Per-channel RGB normalisation statistics expected by CLIP's image encoder.
_CLIP_PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)


def clip_loss(language_embedding, rendered_images, clip_model):
    """Negative mean cosine similarity between a text embedding and rendered views.

    Args:
        language_embedding: Text embedding(s), broadcastable against the image
            embeddings along the leading dimension (the notebook passes the
            output of ``encode_text`` for a single prompt).
        rendered_images: Channel-first RGB images, assumed (B, 3, H, W) with
            values in [0, 1] and already at CLIP's input resolution — no
            resizing is done here.
        clip_model: Object exposing ``encode_image``.

    Returns:
        Scalar tensor in [-1, 1]. Lower means the views match the text better.
        It is differentiable with respect to ``rendered_images``.
    """
    # Same per-channel standardisation as torchvision's Normalize, written with
    # plain tensor ops so nothing is rebuilt on every call and the constants
    # follow the input's device and dtype.
    mean = torch.as_tensor(_CLIP_PIXEL_MEAN, dtype=rendered_images.dtype,
                           device=rendered_images.device).view(-1, 1, 1)
    std = torch.as_tensor(_CLIP_PIXEL_STD, dtype=rendered_images.dtype,
                          device=rendered_images.device).view(-1, 1, 1)
    processed_images = (rendered_images - mean) / std

    image_embedding = clip_model.encode_image(processed_images)
    # Unit-normalise the image embeddings. cosine_similarity also normalises
    # internally, so this does not change the result mathematically.
    image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

    return -torch.cosine_similarity(language_embedding, image_embedding, dim=-1).mean()

def save_renders(dir, i, rendered_images, name=None):
    """Write rendered views to disk via ``torchvision.utils.save_image``.

    Args:
        dir: Output directory. The name shadows the builtin but is kept for
            compatibility with existing callers.
        i: Iteration number, used only when ``name`` is None.
        rendered_images: Image tensor/batch accepted by ``save_image``.
        name: File name inside ``dir``. When None, the image is written to
            ``<dir>/renders/iter_<i>.jpg``.
    """
    if name is None:
        path = os.path.join(dir, 'renders', 'iter_{}.jpg'.format(i))
    else:
        path = os.path.join(dir, name)
    # Create the target directory on demand so the helper does not rely on the
    # caller having pre-created <dir>/renders.
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    torchvision.utils.save_image(rendered_images, path)


In [None]:
# Constrain most sources of randomness
# (some torch backward functions within CLIP are non-deterministic)
seed = 42
torch.manual_seed(seed)  # also seeds the CUDA generators
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# --- Configuration ---
render_res = 224              # side length of each rendered view; clip_loss does not resize
learning_rate = 0.0001
n_iter = 2500
res = 224                     # NOTE(review): not used in this cell
obj_path = 'data/horse.obj'
n_augs = 5                    # NOTE(review): not used in this cell
n_views = 5                   # views rendered per optimisation step
output_dir = './output/'
clip_model_name = 'ViT-L/14'  # kept separate from `clip_model`, which holds the loaded model
prompt = 'a gray horse with highlighted hat.'

Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))
mesh = Mesh(obj_path)
MeshNormalizer(mesh)()

# Initialize variables
background = torch.tensor((1., 1., 1.)).to(device)  # white

log_dir = output_dir


# MLP Settings
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# Palette: row 0 (highlighter) is class 0, row 1 (gray) is class 1.
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)


# --- Prompt ---
clip_model, preprocess = get_clip_model(clip_model_name)
# CLIP is a frozen scorer: only the MLP is optimised. Turning off parameter
# gradients stops every backward() from accumulating .grad buffers for all CLIP
# weights, while gradients still flow through the image encoder to the renders.
clip_model.requires_grad_(False)
tokenized_prompt = clip.tokenize([prompt]).to(device)
with torch.no_grad():
    # The text embedding is a constant target. Encoding it without a graph is
    # what lets the loop below call backward() without retain_graph=True.
    encoded_text = clip_model.encode_text(tokenized_prompt)


vertices = copy.deepcopy(mesh.vertices)

losses = []

# Optimization loop
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                      show=False,
                                                      center_azim=0,
                                                      center_elev=0,
                                                      std=1,
                                                      return_views=True,
                                                      lighting=True,
                                                      background=background)

    # Calculate CLIP Loss (a fresh graph each step, so no retain_graph is needed)
    loss = clip_loss(encoded_text, rendered_images, clip_model)
    loss.backward()

    optim.step()

    # record loss (.item() already returns a detached Python float)
    losses.append(loss.item())

    # report results; the logged "CLIP score" is the loss, i.e. the negative
    # mean cosine similarity, so lower is better
    if i % 100 == 0:
        recent = np.mean(losses[-100:])
        print("Last 100 CLIP score: {}".format(recent))
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {recent}, CLIP score {losses[-1]}\n")


# save results
save_final_results(log_dir, objbase, mesh, mlp, vertices, colors, render, background)

# Save prompt: an empty marker file whose name is the prompt text.
# NOTE(review): a prompt containing '/' would be interpreted as a path here.
with open(os.path.join(log_dir, prompt), "w") as f:
    f.write('')

Sequential(
  (0): Linear(in_features=3, out_features=64, bias=True)
  (1): ReLU()
  (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=64, out_features=64, bias=True)
  (4): ReLU()
  (5): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=64, out_features=64, bias=True)
  (7): ReLU()
  (8): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (9): Linear(in_features=64, out_features=2, bias=True)
)


  0%|          | 2/2500 [00:00<15:36,  2.67it/s]

Last 100 CLIP score: -0.2098388671875


  4%|▍         | 102/2500 [00:15<05:53,  6.78it/s]

Last 100 CLIP score: -0.200601806640625


  8%|▊         | 202/2500 [00:29<05:37,  6.80it/s]

Last 100 CLIP score: -0.21093017578125


 12%|█▏        | 302/2500 [00:44<05:21,  6.83it/s]

Last 100 CLIP score: -0.217347412109375


 16%|█▌        | 402/2500 [00:58<05:04,  6.90it/s]

Last 100 CLIP score: -0.2237158203125


 20%|██        | 502/2500 [01:13<04:50,  6.88it/s]

Last 100 CLIP score: -0.2323681640625


 24%|██▍       | 602/2500 [01:27<04:36,  6.87it/s]

Last 100 CLIP score: -0.2334375


 28%|██▊       | 702/2500 [01:41<04:21,  6.88it/s]

Last 100 CLIP score: -0.23607666015625


 32%|███▏      | 802/2500 [01:56<04:07,  6.87it/s]

Last 100 CLIP score: -0.237955322265625


 33%|███▎      | 817/2500 [01:58<04:00,  7.00it/s]