# ShapeNet PointCloud Visualization
- By running this script, you can get images of pointclouds.
## To run this code...
- You should prepare the summary file by running sample_and_summarize.py with a trained checkpoint.
- You should install below libraries.
    - matplotlib
    - open3d
    - numpy
    - torch
    - torchvision
    - tqdm

In [1]:
import os
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
#import open3d as o3d

import numpy as np
import torch
from torchvision.utils import save_image

from draw import draw, draw_pointcloud

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Set directories
1. experiment_name: log_name.lstrip('gen/') in scripts
2. save_dir: path to save images

In [2]:
save_dir = 'images_attn'
experiment_name = 'shapenet15k-airplane/n_samples_2000/'
summary_name = os.path.join('/data/rna_rep_learning/nmsingh/scSet_ckpts/sample_complexity/', experiment_name, 'summary.pth')
summary_train_name = os.path.join('/data/rna_rep_learning/nmsingh/scSet_ckpts/sample_complexity/', experiment_name, 'summary_train_recon.pth')

imgdir = os.path.join(save_dir, experiment_name)
imgdir_gt = os.path.join(imgdir, 'gt')
imgdir_recon = os.path.join(imgdir, 'recon')
imgdir_gen = os.path.join(imgdir, 'gen')
imgdir_gt_train = os.path.join(imgdir, 'gt_train')

os.makedirs(save_dir, exist_ok=True)
os.makedirs(imgdir_gt, exist_ok=True)
os.makedirs(imgdir_recon, exist_ok=True)
os.makedirs(imgdir_gen, exist_ok=True)
os.makedirs(imgdir_gt_train, exist_ok=True)

In [3]:
summary = torch.load(summary_name)
for k, v in summary.items():
    try:
        print(f"{k}: {v.shape}")
    except AttributeError:
        print(f"{k}: {len(v)}")

smp_set: torch.Size([405, 2500, 3])
smp_mask: torch.Size([405, 2500])
smp_att: 7
priors: 8
recon_set: torch.Size([405, 2500, 3])
recon_mask: torch.Size([405, 2500])
posteriors: 8
dec_att: 7
enc_att: 7
enc_hiddens: 13
gt_set: torch.Size([405, 2048, 3])
gt_mask: torch.Size([405, 2048])
mean: torch.Size([405, 1, 3])
std: torch.Size([405, 1, 1])
sid: 13
mid: 13
pid: 13
cardinality: 13


In [None]:
# train
summary_train = torch.load(summary_train_name)
for k, v in summary_train.items():
    try:
        print(f"{k}: {v.shape}")
    except AttributeError:
        print(f"{k}: {len(v)}")
len_att_train = len(summary_train['dec_att'])

## Select the samples to visualize
- parse the samples by index.
- below default code will visualize all samples. **Warning: Requires Huge Memory**

In [5]:
n_viz=10
gen_targets=list(range(len(summary['smp_mask'])))[:n_viz]
recon_targets=list(range(len(summary['gt_mask'])))[:n_viz]

gen = summary['smp_set'][gen_targets]
if 'smp_mask' in summary.keys():
    gen_mask = summary['smp_mask'][gen_targets]
else:
    gen_mask = torch.zeros_like(gen)[:,:,0].bool()
gt = summary['gt_set'][recon_targets]
gt_mask = summary['gt_mask'][recon_targets]
recon = summary['recon_set'][recon_targets]
recon_mask = summary['recon_mask'][recon_targets]

In [None]:
recon_targets_train = list(range(len(summary_train['gt_mask'])))[:n_viz]

gt_train = summary_train['gt_set'][recon_targets_train]
gt_mask_train = summary_train['gt_mask'][recon_targets_train]
enc_att_train = [summary_train['enc_att'][l][:, :, recon_targets_train] for l in range(len_att_train)]

## Visualize

In [9]:
def visualize(gt, gt_mask):
    return draw_pointcloud(gt, gt_mask)

### Visualize Recon

In [11]:
recon_imgs = visualize(recon, recon_mask)
for idx in range(len(recon_targets)):
    data_idx = recon_targets[idx]
    if torch.nonzero(recon_imgs[idx].float().mean(0) != 1).shape[0] == 0:
        print("SKIP")
        continue
    pos_min = torch.nonzero(recon_imgs[idx].float().mean(0) != 1).min(0)[0]
    pos_max = torch.nonzero(recon_imgs[idx].float().mean(0) != 1).max(0)[0]
    recon_img = recon_imgs[idx][:, pos_min[0]:pos_max[0]+1, pos_min[1]:pos_max[1]+1]
    save_image(recon_img.float(), os.path.join(imgdir_recon, f'{data_idx}.png'))
del recon_img

RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Byte

### Visualize GT

In [None]:
gt_imgs = visualize(gt, gt_mask)
for idx in range(len(recon_targets)):
    data_idx = recon_targets[idx]
    if torch.nonzero(gt_imgs[idx].mean(0) != 1).shape[0] == 0:
        print("SKIP")
        continue
    pos_min = torch.nonzero(gt_imgs[idx].mean(0) != 1).min(0)[0]
    pos_max = torch.nonzero(gt_imgs[idx].mean(0) != 1).max(0)[0]
    gt_img = gt_imgs[idx][:, pos_min[0]:pos_max[0]+1, pos_min[1]:pos_max[1]+1]
    save_image(gt_img, os.path.join(imgdir_gt, f'{data_idx}.png'))
del gt_imgs
print("DONE")

### Visualize Generated Samples

In [None]:
gen_imgs = visualize(gen, gen_mask)
for idx in range(len(gen_targets)):
    if torch.nonzero(gen_imgs[idx].mean(0) != 1).shape[0] == 0:
        print("SKIP")
        continue
    data_idx = gen_targets[idx]
    pos_min = torch.nonzero(gen_imgs[idx].mean(0) != 1).min(0)[0]
    pos_max = torch.nonzero(gen_imgs[idx].mean(0) != 1).max(0)[0]
    gen_img = gen_imgs[idx][:, pos_min[0]:pos_max[0]+1, pos_min[1]:pos_max[1]+1]
    save_image(gen_img.float(), os.path.join(imgdir_gen, f'{data_idx}.png'))
del gen_imgs
print("DONE")

### Visualize Train Data

In [None]:
gt_imgs = visualize(gt_train, gt_mask_train)
for idx in range(len(recon_targets_train)):
    data_idx = recon_targets_train[idx]
    if torch.nonzero(gt_imgs[idx].mean(0) != 1).shape[0] == 0:
        print("SKIP")
        continue
    pos_min = torch.nonzero(gt_imgs[idx].mean(0) != 1).min(0)[0]
    pos_max = torch.nonzero(gt_imgs[idx].mean(0) != 1).max(0)[0]
    gt_img = gt_imgs[idx][:, pos_min[0]:pos_max[0]+1, pos_min[1]:pos_max[1]+1]
    save_image(gt_img, os.path.join(imgdir_gt_train, f'{data_idx}.png'))
del gt_imgs
print('DONE')