In [1]:
# Fetch the 3D_Highlighter project and make its checkout the working directory,
# so that the local modules imported by the following cells can be found.
!git clone https://github.com/reemkhattarr/3D_Highlighter
%cd 3D_Highlighter
# Install CLIP and kaolin straight from their GitHub repositories.
# NOTE(review): no commit or tag is pinned, so re-running may pull newer code.
!pip install git+https://github.com/openai/CLIP.git
!pip install git+https://github.com/NVIDIAGameWorks/kaolin.git

Cloning into '3D_Highlighter'...
remote: Enumerating objects: 30, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 30 (delta 2), reused 25 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (30/30), 1.81 MiB | 30.45 MiB/s, done.
Resolving deltas: 100% (2/2), done.
/content/3D_Highlighter
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-vcntygjz
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-vcntygjz
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x8

In [2]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision

from itertools import permutations, product
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh

class NeuralHighlighter(nn.Module):
    """MLP mapping per-point features to a probability distribution over classes.

    The network is one ``Linear -> ReLU -> LayerNorm`` input block, followed by
    ``depth`` hidden blocks of the same shape, a final ``Linear`` projection to
    ``out_dim`` and a softmax over the last (class) axis.

    Args:
        depth: Number of hidden ``width -> width`` blocks after the input block.
        width: Number of features in every hidden layer.
        out_dim: Number of output classes.
        input_dim: Number of input features per point (default 3).
        sigma: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, depth, width, out_dim, input_dim=3, sigma=5.0):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        layers = [nn.Linear(input_dim, width), nn.ReLU(), nn.LayerNorm([width])]
        for _ in range(depth):
            layers.extend([nn.Linear(width, width), nn.ReLU(), nn.LayerNorm([width])])
        layers.append(nn.Linear(width, out_dim))
        # Normalise over the class (last) axis. For (N, input_dim) inputs this is
        # the same as dim=1, and it also handles unbatched (input_dim,) and
        # extra-batched (..., input_dim) inputs correctly.
        layers.append(nn.Softmax(dim=-1))

        self.mlp = nn.ModuleList(layers)
        print(self.mlp)

    def forward(self, x):
        """Return class probabilities for ``x``; the last axis of the result sums to 1."""
        for layer in self.mlp:
            x = layer(x)
        return x

def get_clip_model(clipmodel):
    """Load the CLIP checkpoint named ``clipmodel`` onto the shared ``device``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    model, image_preprocess = clip.load(clipmodel, device=device)
    return model, image_preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Render and export the final hard (argmax) class assignment.

    Switches ``mlp`` to eval mode (it is not switched back), predicts class
    probabilities for ``vertices``, and takes the argmax per row. It then:

    * colours ``mesh`` with the one-hot assignment and renders 5 views, saved
      as ``final_render.jpg`` in ``log_dir``;
    * exports the mesh to ``<log_dir>/<name>.ply`` with rows whose argmax is
      class 0 coloured (204, 255, 0) and all other rows (180, 180, 180).

    Args:
        log_dir: Directory that receives the ``.ply`` file and the render.
        name: Base file name (without extension) of the exported mesh.
        mesh: Mesh object to colour, render and export (coloured in place).
        mlp: Trained highlighter network.
        vertices: Input passed to ``mlp``.
        colors: Unused. Kept for call compatibility; the render palette is
            rebuilt here from the fixed highlight / gray values.
        render: Renderer exposing ``render_views``.
        background: Background colour forwarded to ``render_views``.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # for renders: one-hot encoding of the argmax, allocated directly on device
        one_hot = torch.zeros(probs.shape, device=device).scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        # Local palette in [0, 1]; a distinct name so the `colors` argument is
        # not silently shadowed.
        render_colors = torch.stack((highlight/255, gray/255)).to(device)
        color_mesh(one_hot, mesh, render_colors)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # for mesh: integer 0-255 colours, broadcast from (N, 1) against (3,)
        final_color = torch.where(max_idx==0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')


def clip_loss(rendered_images, encoded_text, clip_model, n_augs):
    """Negative CLIP cosine similarity between rendered images and the text prompt.

    The image embeddings are averaged over their first axis before being
    compared with the text embedding. If ``encoded_text`` has more than one
    row, it is averaged over its first axis as well.

    Reads the module-level global ``res`` (the CLIP input resolution) at call
    time.

    Args:
        rendered_images: Image batch accepted by the torchvision transforms and
            by ``clip_model.encode_image``.
        encoded_text: Text embedding(s) with the prompts along the first axis.
        clip_model: Model exposing ``encode_image``.
        n_augs: Number of random augmentations to accumulate. When ``n_augs``
            is 0 or negative the images are only resized and normalised (no
            random augmentation) and a single similarity is used, so the result
            is always a tensor that supports ``backward()``.

    Returns:
        A tensor holding minus the summed cosine similarity; lower is better.
    """
    clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    clip_transform = transforms.Compose([
        transforms.Resize((res, res)),
        clip_normalizer
    ])
    augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
    ])

    def _similarity(images):
        # Cosine similarity between the mean image embedding and the text embedding.
        encoded_renders = clip_model.encode_image(images)
        if encoded_text.shape[0] > 1:
            return torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                           torch.mean(encoded_text, dim=0), dim=0)
        return torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                       encoded_text)

    if n_augs <= 0:
        # Previously this path returned the float 0.0, which has no backward().
        return -_similarity(clip_transform(rendered_images))

    loss = 0.0
    for _ in range(n_augs):
        loss -= _similarity(augment_transform(rendered_images))
    return loss

def save_renders(dir, i, rendered_images, name=None):
    """Write ``rendered_images`` to an image file inside ``dir``.

    The file is ``name`` when one is given; otherwise it is
    ``renders/iter_<i>.jpg`` (the ``renders`` sub-directory must already exist).
    """
    filename = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, filename))


In [4]:
# Pin every random number generator used here to one seed so runs are as
# repeatable as possible (some torch backward functions within CLIP are
# non-deterministic, so exact reproducibility is not guaranteed).
seed = 42
for _seed_fn in (torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all,
                 random.seed,
                 np.random.seed):
    _seed_fn(seed)

# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Exploration hyperparameters
learning_rate = 0.0001
depth = 4
n_augs = 5
n_views = 5

render_res = 224
n_iter = 2500
res = 224  # read as a module-level global by clip_loss
obj_path = 'data/candle.obj'
output_dir = './output/'
# Keep the checkpoint name separate from the loaded model so that `clip_model`
# only ever refers to the model object.
clip_model_name = 'ViT-L/14'

# save_renders writes per-iteration images into <output_dir>/renders
Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))
mesh = Mesh(obj_path)
MeshNormalizer(mesh)()

# Initialize variables
background = torch.tensor((1., 1., 1.)).to(device)

log_dir = output_dir


# MLP Settings: hidden width 256, two output classes (highlight / gray)
mlp = NeuralHighlighter(depth, 256, 2).to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# list of possible colors (RGB in [0, 1])
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)


# --- Prompt ---
# encode prompt with CLIP and L2-normalise the embedding
clip_model, preprocess = get_clip_model(clip_model_name)
prompt = 'a gray candle with highlighted hat.'
with torch.no_grad():
    prompt_token = clip.tokenize([prompt]).to(device)
    encoded_text = clip_model.encode_text(prompt_token)
    encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)


# Work on a copy of the vertices as the network input
vertices = copy.deepcopy(mesh.vertices)

losses = []

# Optimization loop
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                      show=False,
                                                      center_azim=0,
                                                      center_elev=0,
                                                      std=1,
                                                      return_views=True,
                                                      lighting=True,
                                                      background=background)

    # Calculate CLIP Loss
    loss = clip_loss(rendered_images, encoded_text, clip_model, n_augs)
    # The forward pass above builds a fresh graph every iteration and backward
    # runs once per forward, so the graph does not need to be retained.
    loss.backward()

    optim.step()

    # record loss (.item() already yields a plain Python number)
    losses.append(loss.item())

    # report results
    if i % 100 == 0:
        # Mean over the most recent (up to) 100 recorded losses
        recent_mean = np.mean(losses[-100:])
        print("Last 100 CLIP score: {}".format(recent_mean))
        save_renders(log_dir, i, rendered_images)
        # Appended to, so the log accumulates across runs
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {recent_mean}, CLIP score {losses[-1]}\n")


# Export the final coloured mesh and its render
save_final_results(log_dir=log_dir,
                   name=objbase,
                   mesh=mesh,
                   mlp=mlp,
                   vertices=vertices,
                   colors=colors,
                   render=render,
                   background=background)

# Record the prompt as the name of an empty marker file in the log directory
Path(os.path.join(log_dir, prompt)).write_text('')

ModuleList(
  (0): Linear(in_features=3, out_features=256, bias=True)
  (1): ReLU()
  (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): ReLU()
  (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=256, out_features=256, bias=True)
  (7): ReLU()
  (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (9): Linear(in_features=256, out_features=256, bias=True)
  (10): ReLU()
  (11): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (12): Linear(in_features=256, out_features=256, bias=True)
  (13): ReLU()
  (14): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (15): Linear(in_features=256, out_features=2, bias=True)
  (16): Softmax(dim=1)
)


  0%|          | 1/2500 [00:00<14:17,  2.91it/s]

Last 100 CLIP score: -1.11328125


  4%|▍         | 101/2500 [00:34<13:48,  2.90it/s]

Last 100 CLIP score: -1.32294921875


  8%|▊         | 201/2500 [01:08<13:01,  2.94it/s]

Last 100 CLIP score: -1.334462890625


 12%|█▏        | 301/2500 [01:43<12:33,  2.92it/s]

Last 100 CLIP score: -1.3278125


 16%|█▌        | 401/2500 [02:17<11:57,  2.93it/s]

Last 100 CLIP score: -1.334951171875


 20%|██        | 501/2500 [02:52<11:57,  2.79it/s]

Last 100 CLIP score: -1.33328125


 24%|██▍       | 601/2500 [03:26<10:50,  2.92it/s]

Last 100 CLIP score: -1.34283203125


 28%|██▊       | 701/2500 [04:01<10:26,  2.87it/s]

Last 100 CLIP score: -1.339150390625


 32%|███▏      | 801/2500 [04:36<09:58,  2.84it/s]

Last 100 CLIP score: -1.329560546875


 36%|███▌      | 901/2500 [05:11<09:16,  2.87it/s]

Last 100 CLIP score: -1.33078125


 40%|████      | 1001/2500 [05:45<08:38,  2.89it/s]

Last 100 CLIP score: -1.3321484375


 44%|████▍     | 1101/2500 [06:20<08:04,  2.89it/s]

Last 100 CLIP score: -1.340029296875


 48%|████▊     | 1201/2500 [06:55<07:31,  2.88it/s]

Last 100 CLIP score: -1.343984375


 52%|█████▏    | 1301/2500 [07:30<06:54,  2.89it/s]

Last 100 CLIP score: -1.332802734375


 56%|█████▌    | 1401/2500 [08:04<06:21,  2.88it/s]

Last 100 CLIP score: -1.335751953125


 60%|██████    | 1501/2500 [08:38<05:48,  2.87it/s]

Last 100 CLIP score: -1.361748046875


 64%|██████▍   | 1601/2500 [09:13<05:13,  2.87it/s]

Last 100 CLIP score: -1.358115234375


 68%|██████▊   | 1701/2500 [09:48<04:42,  2.82it/s]

Last 100 CLIP score: -1.352275390625


 72%|███████▏  | 1801/2500 [10:23<04:02,  2.89it/s]

Last 100 CLIP score: -1.347080078125


 76%|███████▌  | 1901/2500 [10:57<03:30,  2.84it/s]

Last 100 CLIP score: -1.3337890625


 80%|████████  | 2001/2500 [11:32<02:51,  2.91it/s]

Last 100 CLIP score: -1.357587890625


 84%|████████▍ | 2101/2500 [12:07<02:18,  2.87it/s]

Last 100 CLIP score: -1.37265625


 88%|████████▊ | 2201/2500 [12:41<01:43,  2.90it/s]

Last 100 CLIP score: -1.3528125


 92%|█████████▏| 2301/2500 [13:16<01:09,  2.88it/s]

Last 100 CLIP score: -1.34623046875


 96%|█████████▌| 2401/2500 [13:51<00:34,  2.90it/s]

Last 100 CLIP score: -1.34693359375


100%|██████████| 2500/2500 [14:25<00:00,  2.89it/s]
