In [1]:
%load_ext autoreload
%autoreload 2
from inxss.utils_spectrum import calc_Sqw_from_Syy_Szz
from inxss.experiment import SimulatedExperiment

import torch
import numpy as np
from scipy.interpolate import RegularGridInterpolator

from inxss import SpectrumDataset, SpecNeuralRepr, Particle, PsiMask, OnlineVariance, linspace_2D_equidistant
from inxss.utils_visualization import arc_arrow, rad_arrow

import matplotlib.pyplot as plt

from tqdm import tqdm 
from inxss.experiment import Background, SimulatedExperiment
from inxss.steer_neutron import NeutronExperimentSteerer
from sklearn.model_selection import train_test_split

import os
from datetime import datetime

torch.set_default_dtype(torch.float32)

In [2]:
import hydra
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

# Compose the run configuration (`config_gauss` from the `conf` directory) via Hydra's
# compose API, which works inside a notebook where @hydra.main cannot be used.
# Clearing the global Hydra instance first makes this cell re-runnable: initializing
# while GlobalHydra is still initialized raises an error.
GlobalHydra.instance().clear()

# version_base="1.1" pins the defaults Hydra was already assuming ("Will assume defaults
# for version 1.1") and silences the "version_base parameter is not specified" warning.
with initialize(config_path="conf", version_base="1.1"):
    cfg = compose(config_name="config_gauss")

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="conf"):


In [3]:
# Load the SpinW simulation library and split its sample indices 80/10/10 into
# train / validation / test with a fixed seed, so `test_idx` is reproducible.
# weights_only=False pins torch.load's pre-2.6 default (torch>=2.6 defaults to True and
# refuses non-tensor pickled objects). Only do this for files you trust — this one is
# assumed to be produced locally by this project.
spinw_data = torch.load(cfg['paths']['spinw_data_path'], weights_only=False)

# 80% train; the remaining 20% is halved into 10% validation and 10% test.
train_idx, val_test_idx = train_test_split(np.arange(spinw_data['Syy'].shape[0]), test_size=0.2, random_state=42)
val_idx, test_idx = train_test_split(val_test_idx, test_size=0.5, random_state=42)

# NOTE(review): not referenced by any cell shown here; kept in case later cells use it.
result_dict = {}

In [4]:
# Use the current CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Number of steering iterations performed for each test sample.
num_steps = 50

# Likelihood options from the composed config; both are also encoded in the
# output directory name built in the next cell.
scale_likelihood = cfg['likelihood']['scale']
likelihood_type = cfg['likelihood']['type']

In [6]:
# Each run writes into its own directory under cfg['paths']['output_path'], named after
# the likelihood settings, the number of steps and a minute-resolution timestamp.
time_stamp = datetime.now().strftime("%Y%m%d-%H%M")

# The likelihood std is only part of the name for the Gaussian likelihood.
if likelihood_type == 'gaussian':
    run_name = f"lkhd_{likelihood_type}_std_{cfg['likelihood']['std']}_scaled_{scale_likelihood}_steps_{num_steps}_{time_stamp}"
else:
    run_name = f"lkhd_{likelihood_type}_scaled_{scale_likelihood}_steps_{num_steps}_{time_stamp}"
output_path = os.path.join(cfg['paths']['output_path'], run_name)

# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(output_path, exist_ok=True)

print('output_path:', output_path)

output_path: /pscratch/sd/z/zhantao/inxs_steering_production/benchmarks/lkhd_gaussian_scaled_True_steps_50_20240202-1459


In [7]:
# Checkpoint of the trained SpecNeuralRepr; it is reloaded for every test sample in the
# steering loop.
model_path = cfg['paths']['model_path']

# Grid / spectrum / background tables (keys printed below).
# weights_only=False pins torch.load's pre-2.6 default: this file holds non-tensor
# objects (data['background'] is later passed to torch.from_numpy, so it is a numpy
# array), which torch>=2.6's default weights_only=True may reject. Trusted files only.
data = torch.load(cfg['paths']['data_path'], weights_only=False)
print(data.keys())

dict_keys(['grid', 'S', 'background', 'background_dict'])


In [8]:
# Boolean mask of voxels where the reference spectrum data['S'] is positive; the
# steering loop uses it to zero the simulated signal outside this region.
global_mask = (data['S'] > 0).bool()

# Background model on the (h, k, l) axes plus the energy axis w.
background = Background(
    tuple(data['grid'][axis] for axis in ('h_grid', 'k_grid', 'l_grid')),
    data['grid']['w_grid'],
    data['background'],
)

In [9]:
# Particle filter over the two model parameters, each with a uniform prior:
# parameter 0 in [20, 40] and parameter 1 in [-5, 5].
particle_filter_config = dict(
    num_particles=1000,
    dim_particles=2,
    prior_configs={
        'types': ['uniform'] * 2,
        'args': [
            {'low': 20, 'high': 40},
            {'low': -5, 'high': 5},
        ],
    },
)

# [min, max, number of points] for every axis in data['grid']; the memmap mask file
# name printed below encodes exactly these values.
grid_info = {
    axis_name: [axis_values.min().item(), axis_values.max().item(), len(axis_values)]
    for axis_name, axis_values in data['grid'].items()
}

# Angular-coverage mask settings. `psi_mask` is used below for its hklw_grid, and the
# same mask_config is handed to every NeutronExperimentSteerer in the steering loop.
mask_config = dict(
    raw_mask_path=cfg['paths']['raw_mask_path'],
    memmap_mask_path=cfg['paths']['memmap_mask_path'],
    grid_info=grid_info,
    preload=False,
    build_from_scratch_if_no_memmap=True,
    global_mask=None,
)
psi_mask = PsiMask(**mask_config)

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy


In [11]:

# Steering benchmark over the held-out test set.
# For each test sample: simulate an experiment from the sample's SpinW spectra on top of
# the measured background, run `num_steps` steps of NeutronExperimentSteerer.step_steer,
# and save the particle-filter history to `{output_path}/{idx_sample}.pt`.
#
# Each sample is slow (~14 s per step in the logged run). With `skip_existing`, samples
# whose result file already exists in `output_path` are skipped, so re-running this cell
# after an interruption resumes the benchmark. `output_path` is timestamped, so this only
# applies when the cell that creates `output_path` is not re-run.
skip_existing = True


def _snapshot_particle_filter(particle_filter):
    """Return detached CPU copies of the filter's (mean, std, positions, weights).

    `positions` is transposed and gets a leading singleton axis; `weights` gets a leading
    singleton axis, so successive snapshots stack with `torch.vstack`. Positions and
    weights are cloned because `.cpu()` is a no-op for CPU tensors: without the copy the
    history could alias the filter's live `.data` if it is updated in place.
    """
    return (
        particle_filter.mean().detach().cpu(),
        particle_filter.std().detach().cpu(),
        particle_filter.positions.data.T[None].cpu().clone(),
        particle_filter.weights.data[None].cpu().clone(),
    )


for idx_sample in tqdm(test_idx):
    result_file = os.path.join(output_path, f'{idx_sample}.pt')
    if skip_existing and os.path.exists(result_file):
        continue

    sim_experiment = SimulatedExperiment(
        spinw_data['q_grid'], spinw_data['w_grid'],
        spinw_data['Syy'][idx_sample], spinw_data['Szz'][idx_sample],
        neutron_flux=300
    )
    sim_experiment.prepare_experiment(psi_mask.hklw_grid)

    # Intensity the steerer "measures": background plus the simulated signal, with the
    # signal zeroed outside global_mask.
    experiment_config = {
        "q_grid": tuple([data['grid'][_grid] for _grid in ['h_grid', 'k_grid', 'l_grid']]),
        "w_grid": data['grid']['w_grid'],
        "S_grid": torch.from_numpy(data['background']) + global_mask * sim_experiment.Sqw,
        "S_scale_factor": 1.
    }

    background_config = {
        "q_grid": tuple([data['grid'][_grid] for _grid in ['h_grid', 'k_grid', 'l_grid']]),
        "w_grid": data['grid']['w_grid'],
        "bkg_grid": data['background']
    }

    # A fresh model is loaded from the checkpoint for every sample.
    model = SpecNeuralRepr.load_from_checkpoint(model_path).to(device)

    steer = NeutronExperimentSteerer(
        model, particle_filter_config=particle_filter_config,
        mask_config=mask_config, experiment_config=experiment_config, background_config=background_config,
        use_utility_sf=cfg['utility']['use_utility_sf'], utility_sf=cfg['utility']['utility_sf_sigma'],
        tqdm_pbar=False, lkhd_dict=cfg['likelihood'], device=device)

    # Histories start with the prior state (before any steering step).
    current_mean, current_std, positions, weights = _snapshot_particle_filter(steer.particle_filter)
    mean_list, std_list = [current_mean], [current_std]
    position_list, weights_list = [positions], [weights]

    true_params = spinw_data['params'][idx_sample].numpy()
    print('true params: ', true_params)

    # The context manager closes the inner progress bar even on KeyboardInterrupt.
    with torch.no_grad(), tqdm(range(num_steps)) as progress_bar:
        for _ in progress_bar:
            steer.step_steer(mode='unique_optimal')
            current_mean, current_std, positions, weights = _snapshot_particle_filter(steer.particle_filter)
            progress_bar.set_description(
                f'means: [{current_mean[0]:.3f}, {current_mean[1]:.3f}] '
                f' stds: [{current_std[0]:.3f}, {current_std[1]:.3f}]'
            )
            mean_list.append(current_mean)
            std_list.append(current_std)
            position_list.append(positions)
            weights_list.append(weights)

    sub_result_dict = {
        'means': torch.vstack(mean_list).double(),
        # Previously collected but never saved.
        'stds': torch.vstack(std_list).double(),
        'positions': torch.vstack(position_list).double(),
        'weights': torch.vstack(weights_list).double(),
        'measured_angles': torch.from_numpy(np.vstack(steer.measured_angles_history).squeeze()).double(),
        'background_signal_factors': torch.stack(steer.sig_bkg_factors_history).double(),
        'utility': torch.from_numpy(np.vstack(steer.utility_history).squeeze()).double(),
        'true_params': torch.from_numpy(true_params).double(),
    }

    # Save to a temporary file, then atomically rename. An interrupt during saving then
    # cannot leave a truncated .pt that skip_existing would treat as complete.
    tmp_file = result_file + '.tmp'
    torch.save(sub_result_dict, tmp_file)
    os.replace(tmp_file, result_file)

  0%|          | 0/60 [00:00<?, ?it/s]

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [26.40645264 -3.44447412]


  self.rng.multivariate_normal(
means: [26.514, -3.451]  stds: [0.341, 0.240]: 100%|██████████| 50/50 [11:43<00:00, 14.08s/it]
  2%|▏         | 1/60 [13:44<13:30:44, 824.49s/it]

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [22.4045718   3.47758564]


means: [22.370, 3.415]  stds: [0.350, 0.252]:  72%|███████▏  | 36/50 [08:40<03:22, 14.46s/it]
  2%|▏         | 1/60 [24:24<24:00:05, 1464.50s/it]


KeyboardInterrupt: 