In [2]:
import sys
sys.path.append('../')
import torch
from torch import nn
from dataset import AEDataset, RobustAEDataset
from torch.utils.data import DataLoader
from lightning_model.autoencoder import LitAE 
from lightning_model.imle import LitIMLEGenerator
from lightning_model.clip import LitTextPointCloudCLIP

In [3]:
from types import SimpleNamespace
# Model hyperparameters, handed as `config=` to both checkpoint loaders below
# (LitAE and LitIMLEGenerator). Exposed as attributes, e.g. `config.latent_dim`.
config = SimpleNamespace(
    enc_filters=(64, 128, 128, 256),
    latent_dim=128,
    enc_bn=True,
    dec_features=(256, 256),
    n_pts=256,
    dec_bn=False,
    noise_dim=32,
    num_latent=80,
    imle_features=(256, 512),
    latent_loss_weight=1000,
)

In [4]:
# Trained point-cloud autoencoder; `config` is forwarded to the LitAE constructor
# as an extra keyword argument of load_from_checkpoint.
checkpoint = '../lightning_logs/ae_model_20240410-103850/version_0/checkpoints/epoch=472-step=84667.ckpt'
autoencoder = LitAE.load_from_checkpoint(checkpoint_path=checkpoint, config=config)

In [5]:
# Switch the autoencoder to inference mode (BatchNorm layers use their running
# statistics) and move it to the GPU; the cell displays the module repr.
autoencoder.eval().cuda()

LitAE(
  (autoencoder): PointAE(
    (encoder): EncoderPointNet(
      (model): Sequential(
        (0): Conv1d(39, 64, kernel_size=(1,), stride=(1,))
        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
        (3): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
        (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): LeakyReLU(negative_slope=0.01, inplace=True)
        (6): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
        (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): LeakyReLU(negative_slope=0.01, inplace=True)
        (9): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
        (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (11): LeakyReLU(negative_slope=0.01, inplace=True)
        (12): Conv1d(256, 128, kernel_size=(1,), stride=(1

In [30]:
import clip
from dataset import shapenetcore_cat2id
# Reverse lookup: integer label -> category name.
shapenetcore_id2cat = {v: k for k, v in shapenetcore_cat2id.items()}

pretrained_model, preprocess = clip.load("ViT-B/32", device='cuda', jit=False)

# Encode every category name with CLIP's text encoder. Row j of `label_latents`
# corresponds to the j-th key of `shapenetcore_cat2id` (dict iteration order).
# NOTE(review): later code indexes these rows by dataset label id, so this
# assumes shapenetcore_cat2id is ordered by its integer ids — confirm.
label_list = list(shapenetcore_cat2id.keys())
text = clip.tokenize(label_list).to('cuda')
with torch.no_grad():
    # Under no_grad the result carries no graph, so no .detach() is needed;
    # .float() makes the embeddings float32 whatever dtype the CLIP weights use.
    label_latents = pretrained_model.encode_text(text).float()

In [7]:
# Trained IMLE generator. The CLIP label embeddings and the already-loaded
# autoencoder are forwarded to the LitIMLEGenerator constructor as extra
# keyword arguments of load_from_checkpoint.
imle_checkpoint = '../lightning_logs/imle/lightning_logs/version_13/checkpoints/epoch=315-step=352656.ckpt'
lit_imle_model = LitIMLEGenerator.load_from_checkpoint(
    checkpoint_path=imle_checkpoint,
    config=config,
    label_latents=label_latents,
    autoencoder=autoencoder,
)

In [8]:
# Inference mode for the whole IMLE model (train(False) is what eval() calls);
# returns the module, so the cell displays its repr.
lit_imle_model.train(False)

LitIMLEGenerator(
  (autoencoder): LitAE(
    (autoencoder): PointAE(
      (encoder): EncoderPointNet(
        (model): Sequential(
          (0): Conv1d(39, 64, kernel_size=(1,), stride=(1,))
          (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
          (3): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): LeakyReLU(negative_slope=0.01, inplace=True)
          (6): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
          (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (8): LeakyReLU(negative_slope=0.01, inplace=True)
          (9): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
          (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (11): LeakyReLU(negative_slope=0.01, inplace

In [59]:
import os
from visualize import mitsuba
def generate_category(category_i, n_samples=1, out_root='results'):
    """Generate point clouds for one category and save them as Mitsuba scenes.

    Each sample is written to ``<out_root>/<category name>/<sample index>.xml``.

    Args:
        category_i: integer category label. It is passed to
            ``lit_imle_model.generate`` and looked up in the notebook-level
            ``id2cat`` table to name the output folder.
        n_samples: number of point clouds to generate.
        out_root: root output directory. Defaults to ``'results'``, the
            previously hard-coded location.

    Raises:
        KeyError: if ``category_i`` is not a key of ``id2cat``.
    """
    out_dir = f'{out_root}/{id2cat[category_i]}'
    os.makedirs(out_dir, exist_ok=True)
    # Pure inference: skip building an autograd graph for every sample.
    with torch.no_grad():
        for i in range(n_samples):
            # Take the first item returned by generate() and transpose it for
            # mitsuba. NOTE(review): presumably (channels, points) ->
            # (points, channels) — confirm.
            pts = lit_imle_model.generate(category_i)[0].detach().cpu().numpy().T
            mitsuba(pts, f'{out_dir}/{i}.xml')

In [60]:
# Category name <-> integer label tables (id2cat names the output folders in
# generate_category). They are built from the dataset module's
# `shapenetcore_cat2id` (imported with CLIP above), not from `test_dataset`.
# `test_dataset` is only created further down, so the old loop raised NameError
# on a fresh kernel. That later dataset is also created with
# class_choice='birdhouse', so it would not yield the full table either.
# NOTE(review): assumes shapenetcore_cat2id uses the same integer labels as the
# dataset's `label` field — confirm.
cat2id = dict(shapenetcore_cat2id)
id2cat = {v: k for k, v in cat2id.items()}

In [61]:
# Display the integer label -> category name table built in the previous cell.
id2cat

{14: 'chair',
 22: 'train',
 0: 'airplane',
 18: 'table',
 33: 'speaker',
 48: 'sofa',
 17: 'monitor',
 13: 'car',
 29: 'jar',
 44: 'remote_control',
 31: 'lamp',
 54: 'bookshelf',
 42: 'pot',
 19: 'telephone',
 3: 'bathtub',
 9: 'cabinet',
 10: 'can',
 50: 'vessel',
 28: 'helmet',
 45: 'rifle',
 24: 'earphone',
 41: 'pistol',
 49: 'stove',
 5: 'bench',
 15: 'clock',
 2: 'basket',
 8: 'bus',
 25: 'faucet',
 39: 'piano',
 27: 'guitar',
 6: 'bottle',
 43: 'printer',
 38: 'mug',
 26: 'file',
 7: 'bowl',
 32: 'laptop',
 21: 'tower',
 37: 'motorcycle',
 36: 'microwave',
 34: 'mailbox',
 30: 'knife',
 23: 'keyboard',
 20: 'tin_can',
 53: 'birdhouse',
 40: 'pillow',
 16: 'dishwasher',
 51: 'washer',
 47: 'skateboard',
 4: 'bed',
 12: 'cap',
 46: 'rocket',
 11: 'camera',
 52: 'cellphone',
 1: 'bag',
 35: 'microphone'}

In [62]:
# Display the category name -> integer label table (id2cat is its inverse).
cat2id 

{'chair': 14,
 'train': 22,
 'airplane': 0,
 'table': 18,
 'speaker': 33,
 'sofa': 48,
 'monitor': 17,
 'car': 13,
 'jar': 29,
 'remote_control': 44,
 'lamp': 31,
 'bookshelf': 54,
 'pot': 42,
 'telephone': 19,
 'bathtub': 3,
 'cabinet': 9,
 'can': 10,
 'vessel': 50,
 'helmet': 28,
 'rifle': 45,
 'earphone': 24,
 'pistol': 41,
 'stove': 49,
 'bench': 5,
 'clock': 15,
 'basket': 2,
 'bus': 8,
 'faucet': 25,
 'piano': 39,
 'guitar': 27,
 'bottle': 6,
 'printer': 43,
 'mug': 38,
 'file': 26,
 'bowl': 7,
 'laptop': 32,
 'tower': 21,
 'motorcycle': 37,
 'microwave': 36,
 'mailbox': 34,
 'knife': 30,
 'keyboard': 23,
 'tin_can': 20,
 'birdhouse': 53,
 'pillow': 40,
 'dishwasher': 16,
 'washer': 51,
 'skateboard': 47,
 'bed': 4,
 'cap': 12,
 'rocket': 46,
 'camera': 11,
 'cellphone': 52,
 'bag': 1,
 'microphone': 35}

In [101]:
# Render 10 samples of category 13 ('car' in the id2cat table displayed above).
generate_category(category_i=13, n_samples=10)

In [70]:
# Test split of the 'shapenetcorev2' dataset with class_choice='birdhouse',
# served one cloud per batch in a fixed order.
root = '..'
dataset_name = 'shapenetcorev2'

test_dataset = RobustAEDataset(
    root=root,
    dataset_name=dataset_name,
    split='test',
    class_choice='birdhouse',
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,  # fixed order, so the "first batch" below is reproducible
    num_workers=4,
)

In [71]:
# Take the first test batch and run the full IMLE model on it.
batch = next(iter(test_loader))
ori_pc = batch['points'][0].numpy().T
# Inference only: no autograd graph is needed for these outputs.
# NOTE(review): the second argument (0) looks like a batch index — confirm.
with torch.no_grad():
    selected_generated_pc_latent, real_pc_latent, generated_pc, real_pc = lit_imle_model(batch, 0)
generated_pc = generated_pc[0].detach().cpu().numpy().T

In [37]:
# Re-run the generation pipeline step by step on `batch`:
# CLIP label embedding + encoded real cloud -> IMLE generator -> decoder.
# NOTE: rebinds label_list, real_pc, real_pc_latent and generated_pc, which
# earlier cells also define.
label_list = batch['label'].squeeze(-1).tolist()
real_pc = batch['points'].cuda()
real_pc_enc = batch['points_encoded'].cuda()

# The generator and decoder also run under no_grad: the result is only
# converted to NumPy, so building an autograd graph would just waste memory.
with torch.no_grad():
    label_latent = torch.stack([lit_imle_model.label_latents[i] for i in label_list]).squeeze(-1).cuda()
    real_pc_latent = lit_imle_model.autoencoder.autoencoder.encoder(real_pc_enc)
    generated_pc_latent = lit_imle_model.imle_gen(label_latent, real_pc_latent)
    # NOTE(review): [0] takes the first element of imle_gen's output — confirm
    # whether that is a tuple member or a batch item.
    generated_pc = lit_imle_model.autoencoder.autoencoder.decoder(generated_pc_latent[0])
generated_pc = generated_pc[0].detach().cpu().numpy().T

In [38]:
# Dataset label(s) of the batch used above (last expression -> rich display).
label_list

[27]


In [72]:
# Write Mitsuba scene files for the original test cloud and the generated one.
# Create the output folder first: otherwise this cell only works if
# generate_category has already created `results/` (in case mitsuba() does not
# create missing directories itself).
# NOTE(review): in top-to-bottom order `generated_pc` comes from the manual
# generation cell just above; the recorded run (In [72]) followed the model
# forward cell (In [71]) instead.
os.makedirs('results', exist_ok=True)
mitsuba(ori_pc, 'results/original.xml')
mitsuba(generated_pc, 'results/generated.xml')