In [None]:
'''
Notebook to create "inference w/o color" figure (Fig 8, right)
'''

import os

import cmws
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from cmws import util
from cmws.examples.scene_understanding import data, render, run
from cmws.examples.scene_understanding import util as scene3d_util 
import seaborn as sns
import numpy as np
from cmws.examples.scene_understanding.plot import *
from cmws.examples.scene_understanding.util import importance_sample_memory
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


In [None]:
experiment_name = "cmws_vs_rws_noColor"
device = "cpu"

# Checkpoint (index -1, presumably the latest) of every run/config in this
# experiment. Not used directly below; kept so runs can be inspected when
# choosing the best one.
experiment_save_dir = f"../save/{experiment_name}"
checkpoint_paths = [
    util.get_checkpoint_path(experiment_name, config_name, -1)
    for config_name in sorted(os.listdir(experiment_save_dir))
]

# Manually chosen best run (alternative: checkpoint_paths[0]).
best_config_name = "cmws_5_2_0.01_2"
checkpoint_path = f"{experiment_save_dir}/{best_config_name}/checkpoints/latest.pt"
# Load onto the same `device` as the dataset below so the tensors agree.
model, optimizer, stats, run_args = scene3d_util.load_checkpoint(checkpoint_path, device=device)
generative_model, guide = model["generative_model"], model["guide"]
num_iterations = len(stats.losses)  # note: can use to filter out jobs!

# Save directory of the loaded run (distinct from `experiment_save_dir`).
save_dir = util.get_save_dir(run_args.experiment_name, run.get_config_name(run_args))

# Plot training stats of the loaded run
plot_stats(f"{save_dir}/stats.png", stats)

# Observations used for reconstructions and the figure.
# NOTE: these come from the TRAIN split (test=False), not the test split.
train_dataset = data.SceneUnderstandingDataset(
    device, run_args.num_grid_rows, run_args.num_grid_cols, test=False,
    remove_color=(run_args.remove_color == 1),
    mode=run_args.mode,
)
obs, obs_id = train_dataset[:70]  # alternative subset: train_dataset[50:80]
memory = model["memory"]

In [None]:
# Drop axis 1 (squeeze only removes it when it has size 1, so re-running is harmless).
obs = obs.squeeze(1)
num_test_obs, num_channels, obs_height, obs_width = obs.shape

# Resolution passed to importance sampling. The expand at the end of this cell
# requires the observations to already be at this resolution, so check it up
# front instead of failing after the (expensive) sampling step.
im_size = 128
assert (obs_height, obs_width) == (im_size, im_size), (
    f"expected {im_size}x{im_size} observations, got {obs_height}x{obs_width}"
)
num_samples = 1  # not used in this cell

# One importance-sampling particle per memory entry.
num_particles = memory.size
latent, log_weight = importance_sample_memory(
    num_particles, obs, obs_id, generative_model, guide, memory, im_size
)

num_blocks, stacking_program, raw_locations = latent

# Rank particles per observation by log weight, highest first.
# [num_test_obs, num_particles] -> sorted_indices: [num_test_obs, num_particles]
# `.indices` is used instead of `_, sorted_indices = ...` because assigning `_`
# in IPython shadows the last-output history variable.
sorted_indices = torch.sort(log_weight.T, descending=True).indices

# Broadcast observations over particles: [num_particles, num_test_obs, 3, im_size, im_size].
# (expand also broadcasts a size-1 channel dim to 3 if obs is single-channel.)
obs_expanded = obs[None].expand(num_particles, num_test_obs, 3, im_size, im_size)

In [None]:
# Observation to inspect and the rank of the particle to look at
# (particles are sorted by descending log weight, so rank 0 is the best).
test_obs_id = 11
particle_id = 0

# Rank -> particle index, then pull that particle's latent components.
sorted_particle_id = sorted_indices[test_obs_id, particle_id]
particle_and_obs = (sorted_particle_id, test_obs_id)
num_blocks_selected = num_blocks[particle_and_obs]
stacking_program_selected = stacking_program[particle_and_obs]
raw_locations_selected = raw_locations[particle_and_obs]

# Show the observation, channels moved last for imshow.
img = obs_expanded[0, test_obs_id].permute(1, 2, 0)
plt.imshow(img)

In [None]:
# Render the particle selected above, calling get_obs_loc without an explicit
# (elevation, azimuth) camera argument.
sampled_latent = (
    num_blocks_selected,
    stacking_program_selected,
    raw_locations_selected,
)
sampled_obs = generative_model.get_obs_loc(sampled_latent)

# Channels-last numpy array for imshow.
img = sampled_obs.detach().numpy().transpose(1, 2, 0)
plt.imshow(img)

In [None]:
# Run this cell to generate the final PDF.

# Figure geometry (inches) and font sizes.
col_headersize = 11
text_width = 6.75
column_width = 6.5 / 2.
text_height = 9.
golden = (1 + 5 ** 0.5) / 2  # golden ratio

axis_size = 14
title_size = 20

# Posterior samples from memory: the top `num_samples` particles per example,
# each rendered from every camera azimuth in `views`.
views = [0, -40, 40]  # azimuths
num_views = len(views)
num_samples = 3

example_idxs = [25 + 7, 9, 50 + 16]

num_primitives = generative_model.num_primitives

rows = len(example_idxs)  # one outer row per example
cols = 3  # observation column + two columns spanned by the sample grid

# Render at a higher resolution for the figure. The original
# `generative_model.im_size` is restored in `finally` so later cells and
# re-runs are not affected by this cell.
high_res_img = 256
original_im_size = generative_model.im_size
generative_model.im_size = high_res_img

# gridspec inside gridspec
f = plt.figure(figsize=(text_width, (3 / 2) * text_width), dpi=600)
try:
    gs0 = gridspec.GridSpec(rows, cols, figure=f)

    for x, test_obs_id in enumerate(example_idxs):
        # Column 0: the observation.
        ax = f.add_subplot(gs0[x, 0])
        if x == 0:
            ax.set_title("Observations", fontsize=axis_size, pad=10.)
        img = obs_expanded[0][test_obs_id].permute(1, 2, 0)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])

        # Columns 1..end: nested (num_samples x num_views) grid of renders.
        # Negative wspace shrinks the horizontal gap between the renders.
        gs00 = gridspec.GridSpecFromSubplotSpec(
            num_samples, num_views, subplot_spec=gs0[x, 1:], wspace=-0.57, hspace=0.0
        )
        for xx in range(num_samples):  # one inner row per sample
            # The latent depends only on the sample rank, not the view: look it up once.
            particle_id = xx
            sorted_particle_id = sorted_indices[test_obs_id, particle_id]
            num_blocks_selected = num_blocks[sorted_particle_id, test_obs_id]
            stacking_program_selected = stacking_program[sorted_particle_id, test_obs_id]
            raw_locations_selected = raw_locations[sorted_particle_id, test_obs_id]
            sampled_latent = (num_blocks_selected, stacking_program_selected, raw_locations_selected)

            for yy, camera_azimuth in enumerate(views):  # one inner column per view
                ax = f.add_subplot(gs00[xx, yy])
                # Azimuth 0 uses elevation 0.1 (marked "default" originally); other views use 30.
                camera_elevation = 0.1 if camera_azimuth == 0 else 30

                sampled_obs = generative_model.get_obs_loc(
                    sampled_latent, (camera_elevation, camera_azimuth)
                )
                img = sampled_obs.permute(1, 2, 0).detach().numpy()
                ax.imshow(img)
                ax.set_xticks([])
                ax.set_yticks([])

                if x == 0 and xx == 0 and yy == num_views // 2:  # top inner row, center view
                    ax.set_title("Posterior Samples", fontsize=axis_size, pad=6.)

                if yy == 0:  # leftmost view labels the sample
                    ax.set_ylabel(f'Sample {xx+1}', fontsize=10)

    f.tight_layout()
    f.suptitle("Inferring scene parse without color", fontsize=20)
    path = "noColor_samples.pdf"
    util.save_fig(f, path, dpi=400)
finally:
    # Always release the figure and undo the model mutation, even on error.
    plt.close(f)
    generative_model.im_size = original_im_size

In [None]:
# Original (superseded) layout, kept for reference only — every line in this cell is commented out.


# col_headersize = 11
# # inches
# text_width = 6.75
# column_width = 6.5 / 2.
# text_height = 9.
# golden = (1 + 5 ** 0.5) / 2 # golden ratio

# axis_size = 14
# title_size = 20

# # for the samples from memory
# views = [0, -40, 40] # azimuths
# num_views = len(views)
# num_samples = 3 

# example_idxs = [9,50+16,25+7]#[9, 11, 25+7]#[9, 15, 25+7] #[8, 14, 15, 20, 23, 26][10,11,12,13,27,37,39]

# num_primitives = generative_model.num_primitives

# high_res_img = 256

# rows = len(example_idxs) # number of examples 
# cols = 2 # observation + hmws 

# # gridspec inside gridspec
# f = plt.figure(figsize=(text_width, rows/cols * text_width), dpi=600)
# gs0 = gridspec.GridSpec(rows, cols, figure=f)

# for x in range(rows):
    
#     test_obs_id = example_idxs[x]
    
#     # Context images
#     ax = f.add_subplot(gs0[x, 0])
#     if not x: ax.set_title("Observations", fontsize=axis_size, pad=2.)
#     img = obs_expanded[0][test_obs_id].permute(1,2,0)
#     ax.imshow(img)
#     ax.set_xticks([])
#     ax.set_yticks([])
    
#     for y in range(cols - 1):
#         gs00 = gridspec.GridSpecFromSubplotSpec(num_samples, num_views, subplot_spec=gs0[x, y + 1], 
#                                                 wspace=-0.01, hspace=-0.3)
#         # One row per sample
#         for xx in range(3):
#             # One column per view
#             for yy in range(3):
#                 ax = f.add_subplot(gs00[xx, yy])
#                 particle_id = xx
#                 sorted_particle_id = sorted_indices[test_obs_id, particle_id]

#                 num_blocks_selected = num_blocks[sorted_particle_id, test_obs_id]
#                 stacking_program_selected = stacking_program[sorted_particle_id, test_obs_id]
#                 raw_locations_selected = raw_locations[sorted_particle_id, test_obs_id]

#                 sampled_latent = (num_blocks_selected, stacking_program_selected, raw_locations_selected)

                
#                 camera_azimuth = views[yy]
#                 if camera_azimuth == 0: camera_elevation = 0.1 # default
#                 else: camera_elevation = 30

#                 sampled_obs = generative_model.get_obs_loc(sampled_latent, (camera_elevation, camera_azimuth))

#                 img = sampled_obs.permute(1,2,0).detach().numpy()
#                 ax.imshow(img)
#                 ax.set_xticks([])
#                 ax.set_yticks([])

#                 if not x and not xx and yy == 1: # First outer row, first inner row, center plot
#                     ax.set_title("Posterior Samples", fontsize=axis_size, pad=10.)

#                 if not yy and not y: # Leftmost view of samples
#                     ax.set_ylabel(f'Sample {xx+1}', fontsize=10)

# f.suptitle("Inferring shape without color", fontsize=20)#, pad=10)
# f.tight_layout()
# path = "noColor_samples.pdf"
# util.save_fig(f, path, dpi=400)
# plt.close(f)