# VAE Reconstruction Test

This notebook loads a trained `VectorGraphRVQVAE` model from a checkpoint and, for a batch of QuickDraw validation samples, plots each ground-truth graph next to the model's reconstruction of it.

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import sys
import os

# Add the current working directory to sys.path so the `src` package can be imported
sys.path.append(os.path.abspath("."))

from src.model.vae import VectorGraphRVQVAE
from src.dataset.dataset import QuickDrawGraphDataModule

# Set device: prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
# (Probing only MPS would leave a CUDA machine running on the CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# --- Notebook settings ---
BATCH_SIZE = 8
DATA_PATH = "./data/quickdraw_graphs.pkl"
# Note: the "/" inside "val/loss" is a path separator, so the checkpoint
# file sits in a sub-directory of ckpt_rvqvae/.
CKPT_PATH = "ckpt_rvqvae/rvqvae-epoch=011-val/loss=4.973774.ckpt"

# --- Restore the trained model ---
# Load from the checkpoint, move the module to the selected device and
# put it in eval mode, all in one chained expression.
print(f"Loading model from {CKPT_PATH}...")
model = VectorGraphRVQVAE.load_from_checkpoint(CKPT_PATH).to(device).eval()
print("Model loaded.")

In [None]:
# Build the data module, reusing the node/neighbor limits stored in the
# model's hyperparameters, then fetch the validation loader.
print(f"Loading data from {DATA_PATH}...")
datamodule_kwargs = dict(
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=0,  # no worker processes: avoids multiprocessing issues in notebooks
    val_split=0.1,
    max_nodes=model.hparams.max_nodes,
    max_neighbors=model.hparams.max_neighbors,
)
dm = QuickDrawGraphDataModule(**datamodule_kwargs)
dm.setup()
val_loader = dm.val_dataloader()
print("Data loaded.")

In [None]:
def _draw_graph(ax, coords, node_idx, edges, color, label, title):
    """Draw one graph on `ax`: scatter `coords[node_idx]`, then one line per `(u, v)` in `edges`.

    `edges` index into `coords` (not into `node_idx`). Nothing is drawn when
    `node_idx` is empty, but the title and axis styling are always applied.
    """
    if len(node_idx) > 0:
        nodes = coords[node_idx]
        ax.scatter(nodes[:, 0], nodes[:, 1], c=color, s=10, label=label)
        for u, v in edges:
            ax.plot([coords[u, 0], coords[v, 0]],
                    [coords[u, 1], coords[v, 1]],
                    color=color, linestyle='-', alpha=0.3, linewidth=0.5)
    ax.set_title(title)
    # Flip the y axis so that y grows downwards.
    ax.invert_yaxis()
    ax.axis('equal')
    ax.axis('off')


def plot_reconstruction_grid(model, batch, num_samples=4, exist_thr=0.5, edge_thr=0.5):
    """Plot ground-truth graphs (left column) next to model reconstructions (right column).

    Args:
        model: Module called as `model(x_pad, mask=..., neighbors=...)` that
            returns a dict with "coords_pred", "exist_logits" and "q_embed",
            and that exposes an `edge_head(q_embed, coords, active_mask=...)`
            callable returning pairwise edge logits.
        batch: Dict with "x_pad", "mask" and "neighbors" tensors plus
            "edge_index", a per-sample sequence of edge tensors.
            Assumes each edge tensor is shaped (2, num_edges) and that
            "exist_logits" is (batch, nodes) -- TODO confirm against the
            dataset/model code.
        num_samples: Maximum number of samples (rows) to plot; clamped to
            the batch size.
        exist_thr: A predicted node is kept when sigmoid(exist_logit) > exist_thr.
        edge_thr: A predicted edge is drawn when sigmoid(edge_logit) > edge_thr
            and both of its endpoints are kept nodes.

    Returns:
        None. The figure is displayed with `plt.show()`.
    """
    # Use the device the model actually lives on instead of a notebook global.
    run_device = next(model.parameters()).device

    x_pad = batch["x_pad"].to(run_device)
    mask = batch["mask"].to(run_device)
    neighbors = batch["neighbors"].to(run_device)
    ei_list = batch["edge_index"]

    num_samples = min(num_samples, x_pad.size(0))
    if num_samples < 1:
        # plt.subplots rejects a zero-row grid, so bail out early.
        print("No samples to plot.")
        return

    with torch.no_grad():
        out = model(x_pad, mask=mask, neighbors=neighbors)

    coords_pred = out["coords_pred"]
    exist_logits = out["exist_logits"]
    q_embed = out["q_embed"]

    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, 4 * num_samples), squeeze=False)

    for i in range(num_samples):
        # --- Ground Truth ---
        # .bool() guards against an integer 0/1 mask being used as an index tensor.
        gt_mask = mask[i].bool()
        # Convert to float32 before numpy for bfloat16 compatibility
        gt_coords_np = x_pad[i][gt_mask].detach().float().cpu().numpy()
        gt_edges_np = ei_list[i].detach().cpu().numpy()

        # Keep only edges whose endpoints both fall inside the valid node range
        # (this also drops negative indices, which would otherwise wrap around).
        n_gt = len(gt_coords_np)
        in_range = ((gt_edges_np >= 0) & (gt_edges_np < n_gt)).all(axis=0)
        gt_pairs = gt_edges_np[:, in_range].T

        _draw_graph(axes[i, 0], gt_coords_np, np.arange(n_gt), gt_pairs,
                    color='blue', label='Node', title=f"Sample {i} - Ground Truth")

        # --- Prediction ---
        pred_coords = coords_pred[i]
        pred_exist = torch.sigmoid(exist_logits[i])
        keep_mask = pred_exist > exist_thr

        pred_coords_np = pred_coords.detach().float().cpu().numpy()
        keep_idx = torch.where(keep_mask)[0].cpu().numpy()

        pred_pairs = []
        if len(keep_idx) > 0:
            # Edge prediction on a batch of one; no gradients are needed for plotting.
            with torch.no_grad():
                edge_logits = model.edge_head(
                    q_embed[i].unsqueeze(0),
                    pred_coords.unsqueeze(0),
                    active_mask=keep_mask.unsqueeze(0),
                )
                edge_probs = torch.sigmoid(edge_logits[0])

            # Candidate edges: strict upper triangle, so each pair appears once.
            n_nodes = len(pred_exist)
            r, c = torch.triu_indices(n_nodes, n_nodes, offset=1, device=edge_probs.device)
            # Draw an edge only if it is likely enough AND both endpoints were
            # kept; otherwise lines would end at nodes that are not plotted.
            valid_edges = (edge_probs[r, c] > edge_thr) & keep_mask[r] & keep_mask[c]
            pred_pairs = list(zip(r[valid_edges].cpu().numpy(),
                                  c[valid_edges].cpu().numpy()))

        _draw_graph(axes[i, 1], pred_coords_np, keep_idx, pred_pairs,
                    color='red', label='Pred', title=f"Sample {i} - Reconstruction")

    plt.tight_layout()
    plt.show()

In [None]:
# Take the first validation batch and plot ground truth vs. reconstruction
val_iter = iter(val_loader)
batch = next(val_iter)
plot_reconstruction_grid(
    model,
    batch,
    num_samples=8,
    exist_thr=0.5,
    edge_thr=0.5,
)