In [2]:
'''
Notebook to analyze and compare across sweep runs 
'''

import os

import cmws
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from cmws import util
from cmws.examples.scene_understanding import data, render, run, plot
from cmws.examples.scene_understanding import util as scene3d_util 

In [11]:
experiment_name = "cmws_vs_rws"

save_dir = f"../save/{experiment_name}"
checkpoint_paths = [] 
for config_name in sorted(os.listdir(save_dir)):
    checkpoint_paths.append(util.get_checkpoint_path(experiment_name, config_name, -1)) 

In [12]:
checkpoint_paths

['save/cmws_vs_rws/cmws_2_2_0/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_2_1/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_2_2/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_2_3/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_2_4/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_3_0/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_3_1/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_3_2/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_3_3/checkpoints/latest.pt',
 'save/cmws_vs_rws/cmws_2_3_4/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_2_0/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_2_1/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_2_2/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_2_3/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_2_4/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_3_0/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_3_1/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_3_2/checkpoints/latest.pt',
 'save/cmws_vs_rws/rws_3_3/checkpoints/latest.pt',
 

In [20]:
for grid_size in [2,3]: # grid size
    fig, axs = plt.subplots(1, 2, figsize=(2 * 6, 1 * 4))

    colors = {"cmws_2": "C0", "rws": "C1"}
    for checkpoint_path in checkpoint_paths:
        checkpoint_path = f"../{checkpoint_path}"
        
        # Fix seed
        util.set_seed(1)

        if os.path.exists(checkpoint_path):
            # Load checkpoint
            model, optimizer, stats, run_args = scene3d_util.load_checkpoint(
                checkpoint_path, device="cpu"
            )
            
            if run_args.num_grid_cols != grid_size: continue 
            
            generative_model, guide = model["generative_model"], model["guide"]
            num_iterations = len(stats.losses)
            if run_args.model_type != model_type:
                continue

            label = run_args.algorithm if run_args.seed == 0 else None
            color = colors[run_args.algorithm]
            plot_kwargs = {"label": label, "color": color, "alpha": 0.8, "linewidth": 1.5}

            # Logp
            ax = axs[0]
            ax.plot([x[0] for x in stats.log_ps], [x[1] for x in stats.log_ps], **plot_kwargs)

            # KL
            ax = axs[1]
            ax.plot([x[0] for x in stats.kls], [x[1] for x in stats.kls], **plot_kwargs)
    ax = axs[0]
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Log p")

    ax = axs[1]
    ax.set_xlabel("Iteration")
    ax.set_ylabel("KL")
    # ax.set_ylim(0, 2000)
    ax.legend()
    for ax in axs:
        # ax.set_xlim(0, 20000)
        sns.despine(ax=ax, trim=True)
    util.save_fig(fig, f"{save_dir}/losses_{grid_size}.png", dpi=200)

12:54:40 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:40 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:40 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1023.94it/s]


12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:41 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1043.96it/s]

12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:41 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1064.18it/s]


12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:41 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1053.68it/s]


12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:41 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1054.35it/s]


12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:41 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1014.83it/s]


12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:41 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:41 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1034.74it/s]


12:54:42 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:42 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:42 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1063.13it/s]


12:54:42 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:42 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:42 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1034.40it/s]

12:54:42 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:42 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:42 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1036.39it/s]


12:54:43 | /om/user/katiemc/continuous_mws/cmws/util.py:289 | INFO: Saved to ../save/cmws_vs_rws/losses_2.png
12:54:43 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:43 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:43 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1061.19it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1057.33it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1030.79it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1053.62it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1046.60it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory



100%|██████████| 100/100 [00:00<00:00, 1006.50it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...





12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1025.62it/s]

12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...





12:54:44 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:44 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 996.20it/s]


12:54:45 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:45 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:45 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1022.00it/s]


12:54:45 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:184 | INFO: Loading dataset (test = False)...
12:54:45 | /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data.py:188 | INFO: Dataset (test = False) loaded /om/user/katiemc/continuous_mws/cmws/examples/scene_understanding/data/1_1/train.pt
12:54:45 | /om/user/katiemc/continuous_mws/cmws/memory.py:20 | INFO: Initializing memory


100%|██████████| 100/100 [00:00<00:00, 1015.54it/s]


12:54:46 | /om/user/katiemc/continuous_mws/cmws/util.py:289 | INFO: Saved to ../save/cmws_vs_rws/losses_3.png


In [19]:
run_args.num_grid_cols

3