In [9]:
%load_ext autoreload
%autoreload 2
from inxss.utils_spectrum import calc_Sqw_from_Syy_Szz
from inxss.experiment import SimulatedExperiment

import torch
import numpy as np
from scipy.interpolate import RegularGridInterpolator

from inxss import SpectrumDataset, SpecNeuralRepr, Particle, PsiMask, OnlineVariance, linspace_2D_equidistant
from inxss.utils_visualization import arc_arrow, rad_arrow

import matplotlib.pyplot as plt

from tqdm import tqdm 
from inxss.experiment import Background, SimulatedExperiment
from inxss.steer_neutron import NeutronExperimentSteerer
from sklearn.model_selection import train_test_split

import os
from datetime import datetime

# Newly created floating-point tensors default to single precision (torch.float is torch.float32).
torch.set_default_dtype(torch.float)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# SpinW-simulated spectra used for benchmarking.
# `weights_only=False` is explicit because torch>=2.6 defaults torch.load to
# weights_only=True, which refuses arbitrary pickled objects. This file is produced
# by our own pipeline and is trusted. Never pass weights_only=False for untrusted
# files: it unpickles arbitrary objects and can execute code.
spinw_data = torch.load(
    '/pscratch/sd/z/zhantao/inxs_steering/SpinW_data/summarized_AFM_data_2023Sep13.pt',
    weights_only=False,
)

# Deterministic split of sample indices: ~80% train, then the remaining ~20%
# halved into validation and test.
SPLIT_SEED = 42
num_spinw_samples = spinw_data['Syy'].shape[0]
train_idx, val_test_idx = train_test_split(
    np.arange(num_spinw_samples), test_size=0.2, random_state=SPLIT_SEED
)
val_idx, test_idx = train_test_split(val_test_idx, test_size=0.5, random_state=SPLIT_SEED)

# NOTE(review): not used in the cells visible here; kept in case later cells rely on it.
result_dict = {}

In [3]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Benchmark knobs: steering steps per test sample, and the likelihood settings
# forwarded to NeutronExperimentSteerer (also encoded in the output directory name).
num_steps = 50

likelihood_type = 'poisson'
scale_likelihood = True

In [11]:
# One output directory per benchmark run. The name records the likelihood settings,
# the number of steering steps and a minute-resolution timestamp.
time_stamp = datetime.now().strftime("%Y%m%d-%H%M")
output_path = f'/pscratch/sd/z/zhantao/inxs_steering_production/benchmarks/lkhd_{likelihood_type}_scaled_{scale_likelihood}_steps_{num_steps}_{time_stamp}'

# exist_ok=True replaces the racy "check, then create" pattern. Note that two runs
# started within the same minute share this directory, and per-sample files may be overwritten.
os.makedirs(output_path, exist_ok=True)

print('output_path:', output_path)

output_path: /pscratch/sd/z/zhantao/inxs_steering_production/benchmarks/lkhd_poisson_scaled_True_steps_50_20240131-1920


In [6]:
# Trained neural-surrogate checkpoint and the measured neutron data bundle.
model_path = '/pscratch/sd/z/zhantao/inxs_steering_production/models/version_14896845/checkpoints/epoch=7160-step=343728.ckpt'

# The bundle holds NumPy arrays (data['background'] is later passed to torch.from_numpy).
# torch>=2.6's default weights_only=True loader rejects these, so opt out explicitly.
# This file is our own, trusted output. Do not do this for files from untrusted sources.
data = torch.load(
    '/pscratch/sd/z/zhantao/inxs_steering_production/experiment_data/summarized_neutron_data_w_bkg_260meV_ML.pt',
    weights_only=False,
)
print(data.keys())

dict_keys(['grid', 'S', 'background', 'background_dict'])


In [7]:
# Boolean mask of voxels with strictly positive measured intensity. The benchmark
# loop multiplies the simulated spectrum by it before adding the background.
global_mask = (data['S'] > 0).bool()

hkl_grids = tuple(data['grid'][axis] for axis in ('h_grid', 'k_grid', 'l_grid'))
background = Background(hkl_grids, data['grid']['w_grid'], data['background'])

In [8]:
# Particle filter: 1000 particles over a 2-D parameter space, with uniform
# priors on [20, 40] for the first parameter and [-5, 5] for the second.
prior_configs = {
    'types': ['uniform', 'uniform'],
    'args': [{'low': 20, 'high': 40}, {'low': -5, 'high': 5}],
}
particle_filter_config = dict(
    num_particles=1000,
    dim_particles=2,
    prior_configs=prior_configs,
)


def _axis_extent(values):
    """Return [min, max, number of points] for one axis of the data grid."""
    return [values.min().item(), values.max().item(), len(values)]


grid_info = {name: _axis_extent(values) for name, values in data['grid'].items()}

# Angle-dependent coverage mask. A memmap is used if present, otherwise (per the
# flag below) it is built from the raw mask files.
mask_config = dict(
    raw_mask_path='/pscratch/sd/z/zhantao/inxs_steering/La2NiO4_bool',
    memmap_mask_path='/pscratch/sd/z/zhantao/inxs_steering/mask_data',
    grid_info=grid_info,
    preload=False,
    build_from_scratch_if_no_memmap=True,
    global_mask=None,
)
psi_mask = PsiMask(**mask_config)

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy


In [None]:

# Benchmark loop. For each held-out test sample:
#   1. simulate a measurement from its SpinW spectrum,
#   2. steer a freshly constructed particle filter for `num_steps` steps,
#   3. save the posterior trajectory to `<output_path>/<idx_sample>.pt`.

NEUTRON_FLUX = 300               # neutron_flux passed to every SimulatedExperiment
LIKELIHOOD_SAMPLE_RATIO = 0.25   # forwarded to NeutronExperimentSteerer
STEER_MODE = 'unique_optimal'    # mode passed to step_steer

# The (h, k, l) grids are identical for every sample, so build the tuple once.
q_grid_tuple = tuple(data['grid'][_grid] for _grid in ['h_grid', 'k_grid', 'l_grid'])


def _build_configs(idx_sample):
    """Simulate one test sample and return (experiment_config, background_config).

    The simulated spectrum is multiplied by ``global_mask`` and added to the measured
    background. Fresh dicts are built on every call so nothing the steerer does to
    them can carry over to the next sample.
    """
    sim_experiment = SimulatedExperiment(
        spinw_data['q_grid'], spinw_data['w_grid'],
        spinw_data['Syy'][idx_sample], spinw_data['Szz'][idx_sample],
        neutron_flux=NEUTRON_FLUX,
    )
    sim_experiment.prepare_experiment(psi_mask.hklw_grid)
    experiment_config = {
        "q_grid": q_grid_tuple,
        "w_grid": data['grid']['w_grid'],
        "S_grid": torch.from_numpy(data['background']) + global_mask * sim_experiment.Sqw,
        "S_scale_factor": 1.,
    }
    background_config = {
        "q_grid": q_grid_tuple,
        "w_grid": data['grid']['w_grid'],
        "bkg_grid": data['background'],
    }
    return experiment_config, background_config


def _particle_snapshot(particle_filter):
    """Return CPU versions of the filter's transposed positions and its weights, each with a new leading step axis."""
    return (
        particle_filter.positions.data.T[None].cpu(),
        particle_filter.weights.data[None].cpu(),
    )


for idx_sample in tqdm(test_idx, desc='test samples'):
    experiment_config, background_config = _build_configs(idx_sample)

    # Reloaded for every sample so each run starts from the checkpoint weights.
    # NOTE(review): hoist out of the loop if the steerer is confirmed not to modify the model.
    model = SpecNeuralRepr.load_from_checkpoint(model_path).to(device)

    steer = NeutronExperimentSteerer(
        model, particle_filter_config=particle_filter_config,
        mask_config=mask_config, experiment_config=experiment_config,
        background_config=background_config,
        likelihood_sample_ratio=LIKELIHOOD_SAMPLE_RATIO, tqdm_pbar=False,
        likelihood_type=likelihood_type, scale_likelihood=scale_likelihood, device=device,
    )

    # Step 0 entries record the prior, before any measurement.
    mean_list = [steer.particle_filter.mean().detach().cpu()]
    std_list = [steer.particle_filter.std().detach().cpu()]
    positions, weights = _particle_snapshot(steer.particle_filter)
    position_list = [positions]
    weights_list = [weights]

    true_params = spinw_data['params'][idx_sample].numpy()
    print('true params: ', true_params)

    with torch.no_grad():
        # leave=False keeps finished inner bars from piling up under the outer bar.
        progress_bar = tqdm(range(num_steps), leave=False)
        for _ in progress_bar:
            steer.step_steer(mode=STEER_MODE)
            current_mean = steer.particle_filter.mean().detach().cpu()
            current_std = steer.particle_filter.std().detach().cpu()
            progress_bar.set_description(
                f'means: [{current_mean[0]:.3f}, {current_mean[1]:.3f}] '
                f' stds: [{current_std[0]:.3f}, {current_std[1]:.3f}]'
            )
            mean_list.append(current_mean)
            std_list.append(current_std)

            positions, weights = _particle_snapshot(steer.particle_filter)
            position_list.append(positions)
            weights_list.append(weights)

    sub_result_dict = {
        'means': torch.vstack(mean_list).double(),
        # Standard deviations were collected every step but previously never saved.
        'stds': torch.vstack(std_list).double(),
        'positions': torch.vstack(position_list).double(),
        'weights': torch.vstack(weights_list).double(),
        'measured_angles': torch.from_numpy(np.vstack(steer.measured_angles_history).squeeze()).double(),
        'background_signal_factors': torch.stack(steer.sig_bkg_factors_history).double(),
        'true_params': torch.from_numpy(true_params).double(),
    }

    torch.save(sub_result_dict, os.path.join(output_path, f'{idx_sample}.pt'))

  0%|          | 0/15 [00:00<?, ?it/s]

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [26.40645264 -3.44447412]



  0%|          | 0/50 [00:00<?, ?it/s][A
means: [31.495, -0.275]  stds: [5.056, 2.826]:   0%|          | 0/50 [00:12<?, ?it/s][A
means: [31.495, -0.275]  stds: [5.056, 2.826]:   2%|▏         | 1/50 [00:12<10:06, 12.38s/it][A
means: [31.223, -0.327]  stds: [4.793, 2.812]:   2%|▏         | 1/50 [00:24<10:06, 12.38s/it][A
means: [31.223, -0.327]  stds: [4.793, 2.812]:   4%|▍         | 2/50 [00:24<09:43, 12.15s/it][A
means: [31.210, -0.355]  stds: [4.477, 2.796]:   4%|▍         | 2/50 [00:36<09:43, 12.15s/it][A
means: [31.210, -0.355]  stds: [4.477, 2.796]:   6%|▌         | 3/50 [00:36<09:29, 12.11s/it][A
means: [30.872, -0.402]  stds: [4.256, 2.774]:   6%|▌         | 3/50 [00:48<09:29, 12.11s/it][A
means: [30.872, -0.402]  stds: [4.256, 2.774]:   8%|▊         | 4/50 [00:48<09:12, 12.01s/it][A
means: [30.784, -0.487]  stds: [4.096, 2.743]:   8%|▊         | 4/50 [01:00<09:12, 12.01s/it][A
means: [30.784, -0.487]  stds: [4.096, 2.743]:  10%|█         | 5/50 [01:00<08:57, 11.94s/it

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [22.4045718   3.47758564]



  0%|          | 0/50 [00:00<?, ?it/s][A
means: [29.874, -0.008]  stds: [5.851, 2.888]:   0%|          | 0/50 [00:11<?, ?it/s][A
means: [29.874, -0.008]  stds: [5.851, 2.888]:   2%|▏         | 1/50 [00:11<09:39, 11.84s/it][A
means: [26.805, 0.489]  stds: [5.643, 2.935]:   2%|▏         | 1/50 [00:23<09:39, 11.84s/it] [A
means: [26.805, 0.489]  stds: [5.643, 2.935]:   4%|▍         | 2/50 [00:23<09:35, 12.00s/it][A
means: [24.033, 1.436]  stds: [4.157, 2.578]:   4%|▍         | 2/50 [00:35<09:35, 12.00s/it][A
means: [24.033, 1.436]  stds: [4.157, 2.578]:   6%|▌         | 3/50 [00:35<09:19, 11.90s/it][A
means: [23.090, 2.019]  stds: [3.078, 2.207]:   6%|▌         | 3/50 [00:47<09:19, 11.90s/it][A
means: [23.090, 2.019]  stds: [3.078, 2.207]:   8%|▊         | 4/50 [00:47<09:06, 11.88s/it][A
means: [23.120, 1.991]  stds: [3.186, 2.224]:   8%|▊         | 4/50 [00:59<09:06, 11.88s/it][A
means: [23.120, 1.991]  stds: [3.186, 2.224]:  10%|█         | 5/50 [00:59<08:51, 11.82s/it][A
me

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [20.39750846 -1.98794413]



  0%|          | 0/50 [00:00<?, ?it/s][A
means: [29.861, -0.110]  stds: [5.676, 2.869]:   0%|          | 0/50 [00:11<?, ?it/s][A
means: [29.861, -0.110]  stds: [5.676, 2.869]:   2%|▏         | 1/50 [00:11<09:40, 11.85s/it][A
means: [26.312, 0.592]  stds: [4.244, 2.575]:   2%|▏         | 1/50 [00:23<09:40, 11.85s/it] [A
means: [26.312, 0.592]  stds: [4.244, 2.575]:   4%|▍         | 2/50 [00:23<09:33, 11.94s/it][A
means: [25.123, 0.687]  stds: [3.497, 2.461]:   4%|▍         | 2/50 [00:36<09:33, 11.94s/it][A
means: [25.123, 0.687]  stds: [3.497, 2.461]:   6%|▌         | 3/50 [00:36<09:27, 12.08s/it][A
means: [24.923, 0.651]  stds: [3.285, 2.382]:   6%|▌         | 3/50 [00:48<09:27, 12.08s/it][A
means: [24.923, 0.651]  stds: [3.285, 2.382]:   8%|▊         | 4/50 [00:48<09:17, 12.11s/it][A
means: [24.806, 0.674]  stds: [3.190, 2.301]:   8%|▊         | 4/50 [01:00<09:17, 12.11s/it][A
means: [24.806, 0.674]  stds: [3.190, 2.301]:  10%|█         | 5/50 [01:00<09:00, 12.00s/it][A
me

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [30.68027503  3.06292597]



  0%|          | 0/50 [00:00<?, ?it/s][A
means: [26.196, 0.977]  stds: [3.954, 2.554]:   0%|          | 0/50 [00:11<?, ?it/s][A
means: [26.196, 0.977]  stds: [3.954, 2.554]:   2%|▏         | 1/50 [00:11<09:38, 11.80s/it][A
means: [27.004, 0.927]  stds: [3.978, 2.580]:   2%|▏         | 1/50 [00:23<09:38, 11.80s/it][A
means: [27.004, 0.927]  stds: [3.978, 2.580]:   4%|▍         | 2/50 [00:23<09:29, 11.86s/it][A
means: [27.090, 0.901]  stds: [3.884, 2.562]:   4%|▍         | 2/50 [00:35<09:29, 11.86s/it][A
means: [27.090, 0.901]  stds: [3.884, 2.562]:   6%|▌         | 3/50 [00:35<09:17, 11.86s/it][A
means: [27.364, 0.918]  stds: [3.872, 2.559]:   6%|▌         | 3/50 [00:47<09:17, 11.86s/it][A
  self.rng.multivariate_normal(

means: [27.302, 0.882]  stds: [3.895, 2.579]:   8%|▊         | 4/50 [00:59<09:07, 11.91s/it][A
means: [27.302, 0.882]  stds: [3.895, 2.579]:  10%|█         | 5/50 [00:59<08:56, 11.92s/it][A
means: [27.739, 1.007]  stds: [3.731, 2.513]:  10%|█         | 5/50 

obtained memmap mask name as: mask_h_-2.0_2.0_121_k_-2.0_2.0_121_l_-10.0_4.5_30_w_20.0_200.0_91.npy
true params:  [38.24393509  0.8350447 ]



  0%|          | 0/50 [00:00<?, ?it/s][A
means: [33.646, -0.881]  stds: [4.506, 2.711]:   0%|          | 0/50 [00:11<?, ?it/s][A
means: [33.646, -0.881]  stds: [4.506, 2.711]:   2%|▏         | 1/50 [00:11<09:47, 11.98s/it][A
means: [33.730, -0.932]  stds: [4.552, 2.655]:   2%|▏         | 1/50 [00:24<09:47, 11.98s/it][A
means: [33.730, -0.932]  stds: [4.552, 2.655]:   4%|▍         | 2/50 [00:24<09:37, 12.02s/it][A
means: [34.550, -1.150]  stds: [4.042, 2.515]:   4%|▍         | 2/50 [00:36<09:37, 12.02s/it][A
means: [34.550, -1.150]  stds: [4.042, 2.515]:   6%|▌         | 3/50 [00:36<09:24, 12.00s/it][A
means: [35.126, -1.239]  stds: [3.625, 2.394]:   6%|▌         | 3/50 [00:48<09:24, 12.00s/it][A
means: [35.126, -1.239]  stds: [3.625, 2.394]:   8%|▊         | 4/50 [00:48<09:13, 12.03s/it][A
means: [35.493, -1.260]  stds: [3.288, 2.315]:   8%|▊         | 4/50 [00:59<09:13, 12.03s/it][A
means: [35.493, -1.260]  stds: [3.288, 2.315]:  10%|█         | 5/50 [00:59<08:57, 11.95s/it