In [1]:
%load_ext autoreload
%autoreload 2
from inxss.utils_spectrum import calc_Sqw_from_Syy_Szz
from inxss.experiment import SimulatedExperiment

import torch
import numpy as np
from scipy.interpolate import RegularGridInterpolator

from inxss import SpectrumDataset, SpecNeuralRepr, Particle, PsiMask, OnlineVariance, linspace_2D_equidistant
from inxss.utils_visualization import arc_arrow, rad_arrow

import matplotlib.pyplot as plt

from tqdm import tqdm 

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.set_default_dtype(torch.float32)

In [2]:
from inxss.experiment import Background, SimulatedExperiment
from inxss.steer_neutron import NeutronExperimentSteerer
from sklearn.model_selection import train_test_split

In [None]:
num_steps = 50

In [3]:
model_path = '/pscratch/sd/z/zhantao/inxs_steering_production/models/version_14896845/checkpoints/epoch=7160-step=343728.ckpt'
data = torch.load('/pscratch/sd/z/zhantao/inxs_steering_production/experiment_data/summarized_neutron_data_w_bkg_260meV_ML.pt')
print(data.keys())

dict_keys(['grid', 'S', 'background', 'background_dict'])


In [4]:
global_mask = (data['S']>0).bool()

background = Background(
    tuple([data['grid'][_grid] for _grid in ['h_grid', 'k_grid', 'l_grid']]), 
    data['grid']['w_grid'], 
    data['background']
)

In [10]:
particle_filter_config = {
    "num_particles": 1000,
    "dim_particles": 2,
    "prior_configs": {'types': ['uniform', 'uniform'], 'args': [{'low': 20, 'high': 40}, {'low': -5, 'high': 5}]}
}


grid_info = {
    k: [v.min().item(), v.max().item(), len(v)] for k,v in data['grid'].items()
}

mask_config = {
    "raw_mask_path": '/pscratch/sd/z/zhantao/inxs_steering/La2NiO4_bool',
    "memmap_mask_path": '/pscratch/sd/z/zhantao/inxs_steering/mask_data',
    "grid_info": grid_info,
    "preload": False,
    "build_from_scratch_if_no_memmap": True,
    "global_mask": None
}
psi_mask = PsiMask(**mask_config)

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy


In [7]:
spinw_data = torch.load('/pscratch/sd/z/zhantao/inxs_steering/SpinW_data/summarized_AFM_data_2023Sep13.pt')

train_idx, val_test_idx = train_test_split(np.arange(spinw_data['Syy'].shape[0]), test_size=0.2, random_state=42)
val_idx, test_idx = train_test_split(val_test_idx, test_size=0.5, random_state=42)

result_dict = {}

In [8]:
# idx_sample = 13
for idx_sample in tqdm(test_idx):
    sim_experiment = SimulatedExperiment(
        spinw_data['q_grid'], spinw_data['w_grid'], 
        spinw_data['Syy'][test_idx[idx_sample]], spinw_data['Szz'][test_idx[idx_sample]],
        neutron_flux=300
    )
    sim_experiment.prepare_experiment(psi_mask.hklw_grid)
    experiment_config = {
        "q_grid": tuple([data['grid'][_grid] for _grid in ['h_grid', 'k_grid', 'l_grid']]),
        "w_grid": data['grid']['w_grid'],
        "S_grid": torch.from_numpy(data['background']) + \
            global_mask * sim_experiment.Sqw,
        "S_scale_factor": 1.
    }

    background_config = {
        "q_grid": tuple([data['grid'][_grid] for _grid in ['h_grid', 'k_grid', 'l_grid']]),
        "w_grid": data['grid']['w_grid'],
        "bkg_grid": data['background']
    }

    model = SpecNeuralRepr.load_from_checkpoint(model_path).to(device)

    steer = NeutronExperimentSteerer(
        model, particle_filter_config=particle_filter_config,
        mask_config=mask_config, experiment_config=experiment_config, background_config=background_config,
        likelihood_sample_ratio=0.25, tqdm_pbar=False, device='cuda')
    
    mean_list = []
    std_list = []

    posisition_list = []
    weights_list = []

    steer.reset()

    true_params = spinw_data['params'][test_idx[idx_sample]].numpy()

    print('true params: ', true_params)
    with torch.no_grad():
        progress_bar = tqdm(range(num_steps))
        for i in progress_bar:
            steer.step_steer(mode='unique_optimal')
            current_mean = steer.particle_filter.mean().detach().cpu()
            current_std = steer.particle_filter.std().detach().cpu()
            progress_bar.set_description(
                f'means: [{current_mean[0]:.3f}, {current_mean[1]:.3f}] '
                f' stds: [{current_std[0]:.3f}, {current_std[1]:.3f}]'
            )
            mean_list.append(current_mean)
            std_list.append(current_std)
            
            posisition_list.append(steer.particle_filter.positions.data.T[None])
            weights_list.append(steer.particle_filter.weights.data.T[None])

    sub_result_dict = {
        'positions': torch.vstack(posisition_list).float(),
        'weights': torch.vstack(weights_list).float(),
        'measured_angles': torch.from_numpy(np.vstack(steer.measured_angles_history).squeeze()).float(),
        'background_signal_factors': torch.stack(steer.sig_bkg_factors_history).float(),
        'true_params': torch.from_numpy(true_params).float(),
    }
    
    result_dict[idx_sample] = sub_result_dict
    torch.save(result_dict, f'/pscratch/sd/z/zhantao/inxs_steering_production/benchmark/{idx_sample}.pt')

  0%|          | 0/60 [00:00<?, ?it/s]

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [22.54245013 -3.92064747]


  weights_list.append(steer.particle_filter.weights.data.T[None])
means: [27.225, -0.432]  stds: [3.897, 2.649]: 100%|██████████| 5/5 [01:09<00:00, 13.95s/it]
  0%|          | 0/60 [02:58<?, ?it/s]
