In [1]:
import os
import torch
import matplotlib.pyplot as plt

from cosmo_compression.data import data
from cosmo_compression.model import represent

In [None]:
# 1) Setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt_path = '~/latent_ablation_workshop_outc/no_hierarchical_8/step=step=15500-val_loss=0.327.ckpt'   # ← change to your .ckpt
dataset_path = '/monolith/global_data/astro_compression/CAMELS/'

# 2) Load model
model = represent.CosmoFlow.load_from_checkpoint(ckpt_path)
model = model.to(device).eval()

# 3) Prepare dataset (single‐sample access)
cdm_dataset = data.CAMELS(
    root=dataset_path,
    idx_list=range(0, 1),
    map_type='Mcdm',
    suite='Astrid',
    dataset='1P',
    parameters=['Omega_m','sigma_8','A_SN1','A_SN2','A_AGN1','A_AGN2','Omega_b'],
)

# 4) Grab one image
img, cosmo_params = cdm_dataset[0]         # img shape = [C, H, W]
img_tensor = torch.tensor(img).unsqueeze(0).to(device)   # add batch dim → [1, C, H, W]

# 5) Reconstruct
with torch.no_grad():
    latent = model.encoder(img_tensor)

    # Available solvers are 'euler', 'rk4', and 'dopri5'
    recon = model.decoder.predict(
        x0=torch.randn_like(img_tensor),
        h=latent,
        n_sampling_steps=30,
        solver='rk4',
    )

# 6) Plot
orig = img_tensor.cpu().squeeze().permute(1,2,0).numpy()
rec  = recon.cpu().squeeze().permute(1,2,0).numpy()

fig, axes = plt.subplots(1,2, figsize=(8,4))
axes[0].imshow(orig, cmap='viridis')
axes[0].set_title('Original')
axes[0].axis('off')

axes[1].imshow(rec, cmap='viridis')
axes[1].set_title('Reconstruction')
axes[1].axis('off')

AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'