In [None]:
from main import process_input, retrieve_3d
import open3d as o3d
from glob import glob
import os.path as osp
from utils.misc import load_config
import open_clip
import os
import pickle
import numpy as np
from utils.refine import TextRefiner
from utils.renderer import Renderer
import matplotlib.pyplot as plt

In [None]:
# ---------------------------------------------------------------------------
# One-time setup: configuration, text encoder, shape-embedding table and the
# helper objects used by the interactive retrieval loop.
# ---------------------------------------------------------------------------
config = load_config()

# Everything below runs on the CPU; `device` is handed to process_input later.
device = 'cpu'

print("loading OpenCLIP model...")
clip_cache_dir = './clip_cache'
os.makedirs(clip_cache_dir, exist_ok=True)

# The middle element of the returned triple is not used by this script.
open_clip_model, _, open_clip_preprocess = open_clip.create_model_and_transforms(
    'ViT-bigG-14',
    pretrained='laion2b_s39b_b160k',
    cache_dir=clip_cache_dir,
)
open_clip_model = open_clip_model.cpu()
open_clip_model = open_clip_model.eval()

print('loading Shape Embeddings...')

# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load an embedding file that comes from a trusted source.
with open('./modelnet_embed/modelnet.pkl', 'rb') as f:
    shape_embeddings = pickle.load(f)

# Split the mapping into two parallel arrays (same iteration order).
# presumably shape_ids is (N,) and embeddings is (N, embed_dim) - TODO confirm
shape_ids = np.array(list(shape_embeddings.keys()))
embeddings = np.array(list(shape_embeddings.values()))

refiner = TextRefiner()
renderer = Renderer(config.rendering_width, config.rendering_height)

# Interactive retrieval loop.
#
# Each iteration:
#   1. asks for a free-text shape description and the number of results k,
#   2. refines the description with `refiner` and encodes it via process_input,
#   3. retrieves results with retrieve_3d,
#   4. opens every result in an Open3D window and shows the rendered images
#      side by side in a single matplotlib figure.
#
# The loop has no exit condition of its own; interrupt the kernel/process
# (Ctrl-C) to stop it.
while True:
    user_input = input("Enter a user description of shape to retrieve: ")

    # Re-prompt on a bad count instead of crashing: int() raises ValueError on
    # non-numeric text, and plt.subplots cannot build a grid with < 1 column.
    try:
        k = int(input("Enter the number of shapes to retrieve: "))
    except ValueError:
        print("The number of shapes must be an integer; please try again.")
        continue
    if k < 1:
        print("The number of shapes must be at least 1; please try again.")
        continue

    refined_text = [refiner.refine(user_input)]
    print(f'user input: {user_input}, refined_text: {refined_text[0]}')

    # NOTE(review): the first element of process_input's return value is used
    # as the query feature; its exact shape is not visible here - confirm.
    text_feature = process_input(refined_text, open_clip_model, device)[0]
    results = retrieve_3d(text_feature, embeddings, shape_ids, config, k=k)

    # squeeze=False always yields a 2-D array of Axes. Without it, k == 1
    # returns a bare Axes object and indexing it raises TypeError.
    fig, axes = plt.subplots(1, k, figsize=(k*5, 5), squeeze=False)
    axes_row = axes[0]

    shown = 0
    # zip() bounds the loop by the number of subplots, so a result list that
    # is longer than k cannot index past the last axis.
    for i, (ax, result) in enumerate(zip(axes_row, results)):
        # presumably each result is a mesh file path - verify against retrieve_3d
        file_name = osp.basename(result).split(".")[0]
        mesh = o3d.io.read_triangle_mesh(result)
        # Opens an Open3D viewer window titled with the file's base name.
        o3d.visualization.draw_geometries([mesh], window_name=file_name)
        rgb = renderer.render(mesh)
        ax.imshow(rgb)
        ax.axis('off')
        ax.set_title(f'Rank {i+1}')
        shown = i + 1

    # Fewer than k results: hide the leftover subplots rather than showing
    # empty framed axes.
    for ax in axes_row[shown:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()