In [37]:
import sys
import os
import functools
import h5py
import torch
import logging
import argparse
import shutil
import yaml
from datetime import datetime
from tqdm import tqdm
from utils import *
from flowdas import ScoreNet, marginal_prob_std, Euler_Maruyama_sampler
import subprocess
import matplotlib.pyplot as plt
from glob import glob


def prepare():
    set_seed(427)

    parser = argparse.ArgumentParser(description='FlowDAS Evaluation')
    parser.add_argument('--config', type=str, default='eval_win1_G', 
                        help='Path to evaluation config')
    parser.add_argument('--N_trajectory', type=int, default=64, 
                        help='Number of trajectories to evaluate')
    parser.add_argument('--LT', type=int, default=15, 
                        help='Number of testing states of each trajectory')
    args = parser.parse_args()
    
    config_path = PATH / 'config' / f'{args.config}.yml'
    config = get_config(config_path)

    if args.N_trajectory is not None:
        config['N_trajectory'] = args.N_trajectory
    if args.LT is not None:
        config['LT'] = args.LT

    # Create a path that includes both timestamp and config name (without extension)
    timestamp = datetime.now().strftime("%m%d_%H%M%S")
    runpath = PATH / 'runs_eval' / f'run_{timestamp}_{args.config}'
    runpath.mkdir(parents=True, exist_ok=True)
    setup_evaluation_logging(runpath)


    # Log the configuration
    logging.info("Evaluation configuration:")
    for key, value in config.items():
        logging.info(f"  {key}: {value}")
    
    logging.info("\n\n")    


    # Update config with runpath and num_workers
    # assert config['LT'] >= config['window'], "LT must be greater than or equal to window"
    config['runpath'] = runpath
    config['num_workers'] = int(subprocess.check_output(['nproc']).strip())
    config['marginal_prob_std_fn'] = functools.partial(marginal_prob_std, sigma=config['sigma'])


def get_flow_prior(config):
    flow_prior = ScoreNet(
        marginal_prob_std=config['marginal_prob_std_fn'],
        x_dim=config['x_dim'],
        extra_dim=config['extra_dim']*config['window'],
        hidden_depth=config['hidden_depth'],
        embed_dim=config['embed_dim'],
        use_bn=config['use_bn']
    ).to(config['device'])

    try:
        ckp_path = config[f'checkpoint_path']
    except KeyError:
        raise ValueError(f"Window size {config['window']} not supported")

    flow_prior = load_checkpoint(flow_prior, ckp_path)
    return flow_prior


def _to_torch(arr_like, dtype=np.float32):
    # 1) materialize as a real ndarray with desired dtype
    arr = np.asarray(arr_like, dtype=dtype)

    # 2) ensure writable + C-contiguous (HDF5 slices are often read-only)
    if (not arr.flags['C_CONTIGUOUS']) or (not arr.flags['WRITEABLE']):
        arr = arr.copy(order='C')  # makes it contiguous & writeable

    # 3) try zero-copy; if PyTorch still complains, fall back to copy
    try:
        return torch.from_numpy(arr)
    except TypeError:
        # last resort—copy into a fresh tensor (handles weird subclasses)
        return torch.tensor(arr, dtype=torch.float32)


def create_observations(config):
    """
    Create observation files for the combined-para dataset.
    
    The obs.h5 file contains the first (L+1) steps of the testing trajectory.
    
    Args:
        config: Configuration dictionary with parameters
        
    Input data shape: 
        x: (N, L+1, 3) - N trajectories with L+1 timesteps of 3D coordinates
        
    Output data shape:
        obs: (N, L+1, 1) - Observations for each trajectory and timestep
    """
    logging.info("Creating observations...")
    path_dataset = config[f'path_dataset']
    
    # Define the observation file path
    obs_file_path = f'{path_dataset}/obs_L{config["LT"]}_win{config["window"]}.h5'
    
    # Read input data
    with h5py.File(f'{path_dataset}/test.h5', mode='r') as f:
        x = f['data'][:, :config['LT']+config['window']]  # shape: (N, L+w, 3)
    
    # Check if the observation file already exists
    if os.path.exists(obs_file_path):
        # Generate what the new observations would be
        # x_tensor = torch.from_numpy(x)
        arr = x[...]            # shape e.g. (L, 3)
        x_tensor = _to_torch(arr, dtype=np.float32)
        new_obs = observation_generator(x_tensor, config['sigma_obs_hi'])
        
        # Load existing observations
        with h5py.File(obs_file_path, mode='r') as f:
            existing_obs = f['obs'][:]
        
        # Compare existing and new observations
        # if np.array_equal(existing_obs, new_obs.numpy()):
        if np.allclose(existing_obs, new_obs.numpy(), rtol=1e-3, atol=1e-3):
            logging.info(f"Existing observations file is identical. Skipping creation.")
        else:
            logging.error(f"Existing observations file contains different data!")
            logging.error(f"This requires attention. Exiting evaluation.")
            sys.exit(1)
    else:
        # Create new observations file
        with h5py.File(obs_file_path, mode='w') as f:
            # x_tensor = torch.from_numpy(x)
            arr = x[...]            # shape e.g. (L, 3)
            x_tensor = _to_torch(arr, dtype=np.float32)
            obs = observation_generator(x_tensor, config['sigma_obs_hi'])
            f.create_dataset('obs', data=obs)
            logging.info(f"Created new observations file at {obs_file_path}")

    logging.info("Observations created successfully.\n\n")

In [16]:
if 'SCRATCH' in os.environ:
    SCRATCH = os.environ['SCRATCH']
    PATH = Path(SCRATCH) / 'sda/lorenz'
else:
    PATH = Path('.')

config_path = PATH / 'config' / f'eval_win1.yml'
config = get_config(config_path)

dataset_class = TrajectoryDatasetV2 # if config['study_generalizability'] else TrajectoryDataset
data_dir = 'data'

# Load datasets
# data_dir = 
train_path = PATH / f'{data_dir}/dataset/train.h5'
valid_path = PATH / f'{data_dir}/dataset/valid.h5'
trainset = dataset_class(train_path, window=config['window'])
validset = dataset_class(valid_path, window=config['window'])

In [19]:
timestamp = datetime.now().strftime("%m%d_%H%M%S")
runpath = PATH / 'runs_eval' / f"run_{timestamp}_{'eval_win1'}"
runpath.mkdir(parents=True, exist_ok=True)
config['runpath'] = runpath
config['num_workers'] = int(subprocess.check_output(['nproc']).strip())
config['marginal_prob_std_fn'] = functools.partial(marginal_prob_std, sigma=config['sigma'])

flow_prior = get_flow_prior(config)

In [41]:
create_observations(config)

ValueError: object __array__ method not producing an array

In [30]:
# Ground truth
n = 0
path_dataset = config[f'path_dataset']
with h5py.File(f'{path_dataset}/test.h5', mode='r') as f:
    arr = f['data'][n][...]            # shape e.g. (L, 3)
    gt = _to_torch(arr, dtype=np.float32)
    # gt = torch.from_numpy(f['data'][n]).to(config['device']) 
    gt = gt[:config['LT']+config['window']] # shape: (L+1, 3)

# Observation
with h5py.File(f'{path_dataset}/obs_L{config["LT"]}_win{config["window"]}.h5', mode='r') as f:
    # TODO: deal with window size.
    arr = f['obs'][n][...]            # shape e.g. (L, 3)
    obs = _to_torch(arr, dtype=np.float32)
    # obs = torch.from_numpy(f['obs'][n]).to(config['device']) # shape: (L+1, 1)

flow_prior = get_flow_prior(config)

gt_win, obs_win = gt, obs
initial_cond = gt[:config['window']].reshape(1, -1)

# Monte Carlo sampling
cond_win = []
cond_win.append(initial_cond)
est_all_win = [gt[l, :].unsqueeze(0) for l in range(config['window'])]

x_t_gen = Euler_Maruyama_sampler(
                flow_prior,
                num_steps=config['num_steps'],
                device=config['device'],
                base=est_all_win[i+config['window']-1],
                cond=cond_win[i],
                measurement=obs_win[i+1, :],
                noisy_level=config['sigma_obs_hi'],
                MC_times=config['N_MC'], 
                batch_size=1, # TODO: why batch size is 1?
                step_size=config['step_size'] 
            ) # shape: (B, 3*window)


FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = './data/dataset/obs_L15_win1.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [31]:
f.close()

In [32]:
path_dataset

'./data/dataset'