In [35]:
"""
    adopted from: https://github.com/shubhtuls/PixelTransformer/blob/03b65b8612fe583b3e35fc82b446b5503dd7b6bd/data/shapenet.py
"""
import os.path
import json

import h5py
import numpy as np
from termcolor import colored, cprint

import torch
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.transforms.functional import InterpolationMode
from utils.util_3d import sdf_to_mesh
from pytorch3d.ops import sample_points_from_meshes
from datasets.base_dataset import BaseDataset
from utils.demo_util import preprocess_image
import glob

# from https://github.com/laughtervv/DISN/blob/master/preprocessing/info.json
class ShapData(BaseDataset):

    def initialize(self, opt, phase='train', cat='all', res=64, few=False):
        self.opt = opt
        self.max_dataset_size = 1000000
        self.res = res
        self.few = few
        self.phase = phase
        dataroot = "/home/amac/SDFusion/data"
        # with open(f'{dataroot}/ShapeNet/info.json') as f:
        with open(f'dataset_info_files/info-shapenet.json') as f:
            self.info = json.load(f)
            
        self.cat_to_id = self.info['cats']
        self.id_to_cat = {v: k for k, v in self.cat_to_id.items()}
        
        if cat == 'all':
            all_cats = self.info[phase + "_cats"]
        else:
            all_cats = [cat]
    
        all_imgs = glob.glob("/home/amac/data/ShapeNet55_3DOF-VC_LRBg/*/*/image_output/*.png")

        self.model_list = []
        self.cats_list = []
        model_ids = []
        self.cat2model_list = {}
        with open("data_split.json", "r") as f:
            splo = json.load(f)
        splo = splo[phase]
        self.splo = splo
        for c in all_cats:
            synset = self.info['cats'][c]
            # with open(f'{dataroot}/ShapeNet/filelists/{synset}_{phase}.lst') as f:
            model_list_s = []
            for l in splo[synset]:
                model_id = l.rstrip('\n')
                if res == 64:

                    path = f'{dataroot}/ShapeNet/SDF_v1_64/{synset}/{model_id}/ori_sample.h5'
                else:

                    path = f'{dataroot}/ShapeNet/SDF_v2/resolution_{self.res}/{synset}/{model_id}/ori_sample_grid.h5'
                
                if os.path.exists(path):
                    model_list_s.append(path)
                    if model_id not in self.cat2model_list:

                        self.cat2model_list[synset] = [path]
                    else:
                        self.cat2model_list[synset].append(path)
                model_ids.append(model_id)


            model_list_s = list(np.random.default_rng(seed=0).choice(model_list_s, int(len(model_list_s)*0.1)))
            print(model_list_s[:10])
            self.model_list += model_list_s
            self.cats_list += [synset] * len(model_list_s)
            print('[*] %d samples for %s (%s).' % (len(model_list_s), self.id_to_cat[synset], synset))
        
        self.model2views = {model_id: glob.glob(f"/home/amac/data/ShapeNet55_3DOF-VC_LRBg/*/{model_id}/image_output/*.png") for model_id in model_ids}
        np.random.default_rng(seed=0).shuffle(self.model_list)
        np.random.default_rng(seed=0).shuffle(self.cats_list)
        self.model_list = self.model_list[:self.max_dataset_size]
        cprint('[*] %d samples loaded.' % (len(self.model_list)), 'yellow')

        self.N = len(self.model_list)
        
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.Resize((256, 256)),
        ])
    
    def __getitem__(self, index):
                
        synset = self.cats_list[index]
        sdf_h5_file = self.model_list[index]
        assert synset == sdf_h5_file.split("/")[-3]
        h5_f = h5py.File(sdf_h5_file, 'r')
        sdf = h5_f['pc_sdf_sample'][:].astype(np.float32)

        sdf = torch.Tensor(sdf).view(1, self.res, self.res, self.res)
        # print(sdf.shape)
        # sdf = sdf[:, :64, :64, :64]

        thres = 0.2
        if thres != 0.0:
            sdf = torch.clamp(sdf, min=-thres, max=thres)
        z = np.load(sdf_h5_file.replace("ori_sample.h5", "latent_code.npy"), allow_pickle=True).squeeze(0)
        if self.phase == "test" or self.phase == "val":
            view_idx = index % len(self.model2views[sdf_h5_file.split("/")[-2]])
            view = self.model2views[sdf_h5_file.split("/")[-2]][view_idx]
        else:
            view = np.random.choice(self.model2views[sdf_h5_file.split("/")[-2]], 1)[0]
        
        _, img = preprocess_image(str(view), str(view).replace("image_output", "segmentation"))
        img = self.transforms(img)
       
        ret = {
            'sdf': sdf,
            'z': z,
            'img': img,
            'cat_id': synset,
            'cat_str': self.id_to_cat[synset],
            'path': sdf_h5_file,
        }
        if self.few:
            listo = self.cat2model_list[synset]
            sup_path = np.random.choice(listo, 1)[0]
            sup_code = np.load(sup_path.replace("ori_sample.h5", "latent_code.npy")).squeeze()
            ret["sup_z"] = sup_code
        return ret

    def __len__(self):
        return self.N

    def name(self):
        return 'ShapeNetImg2ShapeDataset'

In [37]:
from torch.utils.data import DataLoader
data = ShapData()
data.initialize(None, "test", few=True)

['/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/dbe05209a14fca8fdf72e713dd4f492a/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/2212a794bfca650384d5ba37e7a649b7/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/355a85d0156984c75e559927dcb9417c/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/1dce61f6dd85dc469811751e3fab8939/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/8e1778cdc0bfec3e18693dd92ffa710d/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/5aac718c51fc73ca00223dcc18ecf69/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/63f170670881b2deaf6320700e3cf173/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/4f245403e6366d48fb3294f1e40c8a29/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/8b25d01f3fd3d5a373e9b20a29bc1d50/ori_sample.h5', '/home/amac/SDFusion/data/ShapeNet/SDF_v1_64/04530566/4d8ae6327ab4ed301e6

In [38]:
dl = DataLoader(data, batch_size=8, num_workers=8)

In [39]:
# options for the model. please check `utils/demo_util.py` for more details
from utils.demo_util import SDFusionImage2ShapeOpt

seed = 2023
opt = SDFusionImage2ShapeOpt(gpu_ids=0, seed=seed)
device = "cuda"


[*] SDFusionImage2ShapeOption initialized.


In [40]:
from models.base_model import create_model

ckpt_path = 'logs/img2shapeshapenet_frozen_clip/ckpt/df_400k.pth'

opt.init_model_args(ckpt_path=ckpt_path, vq_ckpt_path="/home/amac/develop/SDFusion/logs_home/2023-06-25T13-48-37-vqvae-snet-all-res64-LR1e-4-T0.2-release/ckpt/vqvae_epoch-best.pth")
opt.model = "sdfusion-img2shape"
SDFusion = create_model(opt)
cprint(f'[*] "{SDFusion.name()}" loaded.', 'cyan')

Working with z of shape (1, 3, 16, 16, 16) = 12288 dimensions.
[34m[*] VQVAE: weight successfully load from: /home/amac/develop/SDFusion/logs_home/2023-06-25T13-48-37-vqvae-snet-all-res64-LR1e-4-T0.2-release/ckpt/vqvae_epoch-best.pth[0m
[34m[*] weight successfully load from: logs/img2shapeshapenet_frozen_clip/ckpt/df_400k.pth[0m
[34m[*] setting ddim_steps=100[0m
[34m[*] Model has been created: SDFusionImage2ShapeModel[0m
[36m[*] "SDFusionImage2ShapeModel" loaded.[0m


In [41]:
ngen = 1 # number of generated shapes
ddim_steps = 100
ddim_eta = 0.
uc_scale = 3.
SDFusion.eval()

In [None]:
from utils.util_3d import render_sdf, render_mesh, sdf_to_mesh, save_mesh_as_gif
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
from tqdm import tqdm
import json
cd = {}
for data in tqdm(dl):
    SDFusion.inference(data, ddim_steps = ddim_steps, ddim_eta=ddim_eta, uc_scale=uc_scale)
    cat_strs = data["cat_str"]
    gen_df = SDFusion.gen_df
    meshes_gen = sdf_to_mesh(gen_df)
    gt_df = SDFusion.x
    meshes_gt = sdf_to_mesh(gt_df)
    points_gen = sample_points_from_meshes(meshes_gen).cuda()
    points_gt = sample_points_from_meshes(meshes_gt).cuda()


    cdist =  chamfer_distance(points_gen, points_gt, batch_reduction=None)[0].detach().cpu().numpy()
    print(cdist.shape)
    for i, cat in enumerate(cat_strs):
        if cat not in cd:
            cd[cat] = [cdist[i]]
        else:
            cd[cat].append(cdist[i])

result_dict = {}
for key,val in cd.items():
    result_dict[key] = sum(val) / len(val)

with open("result_baseline.json", "w+") as f:

    json.dump(result_dict, f)    


  0%|                                                                                            | 0/61 [00:00<?, ?it/s]

Data shape for DDIM sampling is (8, 3, 16, 16, 16), eta 0.0
Running DDIM Sampling with 100 timesteps



DDIM Sampler:   0%|                                                                             | 0/100 [00:00<?, ?it/s][A
DDIM Sampler:   1%|▋                                                                    | 1/100 [00:00<00:26,  3.70it/s][A
DDIM Sampler:   2%|█▍                                                                   | 2/100 [00:00<00:25,  3.79it/s][A
DDIM Sampler:   3%|██                                                                   | 3/100 [00:00<00:25,  3.82it/s][A
DDIM Sampler:   4%|██▊                                                                  | 4/100 [00:01<00:25,  3.83it/s][A
DDIM Sampler:   5%|███▍                                                                 | 5/100 [00:01<00:24,  3.84it/s][A
DDIM Sampler:   6%|████▏                                                                | 6/100 [00:01<00:24,  3.84it/s][A
DDIM Sampler:   7%|████▊                                                                | 7/100 [00:01<00:24,  3.85it/s][A
DDIM Sa

(8,)
Data shape for DDIM sampling is (8, 3, 16, 16, 16), eta 0.0
Running DDIM Sampling with 100 timesteps



DDIM Sampler:   0%|                                                                             | 0/100 [00:00<?, ?it/s][A
DDIM Sampler:   1%|▋                                                                    | 1/100 [00:00<00:25,  3.82it/s][A
DDIM Sampler:   2%|█▍                                                                   | 2/100 [00:00<00:25,  3.81it/s][A
DDIM Sampler:   3%|██                                                                   | 3/100 [00:00<00:25,  3.81it/s][A
DDIM Sampler:   4%|██▊                                                                  | 4/100 [00:01<00:25,  3.80it/s][A
DDIM Sampler:   5%|███▍                                                                 | 5/100 [00:01<00:24,  3.81it/s][A
DDIM Sampler:   6%|████▏                                                                | 6/100 [00:01<00:24,  3.81it/s][A
DDIM Sampler:   7%|████▊                                                                | 7/100 [00:01<00:24,  3.80it/s][A
DDIM Sa

(8,)
Data shape for DDIM sampling is (8, 3, 16, 16, 16), eta 0.0
Running DDIM Sampling with 100 timesteps



DDIM Sampler:   0%|                                                                             | 0/100 [00:00<?, ?it/s][A
DDIM Sampler:   1%|▋                                                                    | 1/100 [00:00<00:26,  3.80it/s][A
DDIM Sampler:   2%|█▍                                                                   | 2/100 [00:00<00:25,  3.80it/s][A
DDIM Sampler:   3%|██                                                                   | 3/100 [00:00<00:25,  3.80it/s][A
DDIM Sampler:   4%|██▊                                                                  | 4/100 [00:01<00:25,  3.80it/s][A
DDIM Sampler:   5%|███▍                                                                 | 5/100 [00:01<00:25,  3.80it/s][A
DDIM Sampler:   6%|████▏                                                                | 6/100 [00:01<00:24,  3.80it/s][A
DDIM Sampler:   7%|████▊                                                                | 7/100 [00:01<00:24,  3.80it/s][A
DDIM Sa