In [1]:
# %load_ext autoreload
# %autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import skopi as sk

import torch
from tqdm import tqdm 
import time
import h5py

from neurorient import uniform_points_on_sphere
from neurorient.reconstruction.slicing import get_real_mesh

import os

# Request skopi's CuPy backend via the USE_CUPY environment variable.
# NOTE(review): this runs *after* `import skopi` in the cell above. If skopi reads
# USE_CUPY only at import time, this assignment has no effect in the current kernel.
# Confirm that, and move it above the skopi import if GPU arrays are really wanted.
# Downstream `torch.from_numpy(...)` calls would then need host (numpy) arrays.
os.environ.update(USE_CUPY="1")
os.environ['USE_CUPY']

  def get_phase(atom_pos, q_xyz):
  def cal(f_hkl, atom_pos, q_xyz, xyz_ind, pixel_number):
  def euler_to_quaternion(psi, theta, phi):
  def rotmat_to_quaternion(rotmat):
  def quaternion2rot3d(quat):
  def cross_talk_effect(dbase, photons, shape, dbsize, boundary):


'1'

In [2]:
# Output directory for the generated datasets. It can be overridden with the
# NEURORIENT_DATA_DIR environment variable so the notebook runs outside this
# user's scratch space; the default keeps the original location.
save_dir = os.environ.get('NEURORIENT_DATA_DIR', '/pscratch/sd/z/zhantao/neurorient_repo/data')

In [3]:
# Square detector: n_pixels per side, side length det_size, placed det_dist from
# the sample (units as expected by skopi.SimpleSquareDetector — presumably metres; confirm).
n_pixels = 128
det_size = 0.1
det_dist = 0.2
det = sk.SimpleSquareDetector(n_pixels, det_size, det_dist)

In [4]:
# Run configuration; these values also appear in the output file names.
pdb = '1BXR'           # structure ID; read from input/pdb/<pdb>.pdb
increase_factor = 1    # multiplier applied to the beam's photons per pulse
poisson = False        # True: store photon counts; False: store intensities
num_images = 10_000    # number of random orientations / images to simulate

In [5]:
# Set up the x-ray beam and scale its photons per pulse by `increase_factor`.
# NOTE(review): input paths are relative, so this cell must run with the
# repository's working directory as cwd — confirm.
beam = sk.Beam("input/beam/amo86615.beam")
beam.set_photons_per_pulse(increase_factor * beam.get_photons_per_pulse())

# Set up the particle from the PDB file selected by `pdb` in the config cell.
# `ff='WK'` selects skopi's form-factor table (presumably Waasmaier–Kirfel; confirm).
# NOTE(review): for 1BXR skopi prints "Unknown element or wrong line" for several
# HETATM records (Mn and K ions). Confirm whether those atoms are dropped from
# the scattering model.
particle = sk.Particle()
particle.read_pdb(f"input/pdb/{pdb}.pdb", ff='WK')

# Combine detector, beam and particle into a single-particle-imaging experiment.
exp = sk.SPIExperiment(det, beam, particle)

Unknown element or wrong line: 
 HETATM44708 MN    MN A1074      18.497  28.122  73.953  1.00 24.31          MN  

Unknown element or wrong line: 
 HETATM44709  K     K A1075      19.175  40.400  68.787  1.00 27.23           K  

Unknown element or wrong line: 
 HETATM44710  K     K A1076      16.179  26.429  76.392  1.00 28.97           K  

Unknown element or wrong line: 
 HETATM44711 MN    MN A1077      46.831  24.868  51.269  1.00 29.70          MN  

Unknown element or wrong line: 
 HETATM44712 MN    MN A1078      48.316  28.297  52.317  1.00 33.25          MN  

Unknown element or wrong line: 
 HETATM44713  K     K A1079      42.460  17.035  48.978  1.00 26.16           K  

Unknown element or wrong line: 
 HETATM44798  K     K B 984      39.799  64.742  81.201  1.00 31.65           K  

Unknown element or wrong line: 
 HETATM44799 MN    MN C1901       7.038  40.795 -48.602  1.00 29.00          MN  

Unknown element or wrong line: 
 HETATM44800  K     K C1903       9.575  28.579 



In [6]:
# Draw `num_images` random unit quaternions with a fixed seed and register them
# with the experiment. Each generate_image_stack call below yields the image for
# the next registered orientation (presumably consumed in order — confirm in skopi).
np.random.seed(42)
orientations = sk.get_random_quat(num_images)
print(orientations)

exp.set_orientations(orientations)

n_orient = len(orientations)
# One (ny, nx) image per orientation; det.shape[1:] drops the leading detector axis.
images = np.zeros((n_orient,) + det.shape[1:])

# Poisson mode stores photon counts; otherwise store noise-free intensities.
want_photons = bool(poisson)
stack_kwargs = {'return_photons': want_photons, 'return_intensities': not want_photons}
for idx in tqdm(range(n_orient)):
    # Keep element 0 of the returned stack (presumably the single panel — confirm).
    images[idx] = exp.generate_image_stack(**stack_kwargs)[0]

[[ 0.56397793 -0.55442653 -0.60717024 -0.07670999]
 [ 0.19255406 -0.11049266  0.89366123  0.38997937]
 [ 0.46295972  0.23167726  0.70263045 -0.48816431]
 ...
 [ 0.21783528 -0.07641909  0.11864684  0.96572814]
 [ 0.27062828 -0.72751106  0.36735123 -0.51238761]
 [ 0.78133746  0.41517632  0.46547994 -0.02165244]]


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [01:46<00:00, 93.54it/s]


In [7]:
# Mesh for a single detector image, built from the pixels per side and the
# largest reciprocal pixel coordinate (get_real_mesh semantics per neurorient — confirm).
img_real_mesh = get_real_mesh(det.shape[1], det.pixel_position_reciprocal.max())

# The experiment's reciprocal-space voxel mesh; only the mesh itself is used
# here, and the returned voxel length is ignored.
recip_mesh, _voxel_length = exp.det.get_reciprocal_mesh(voxel_number_1d=exp.mesh_size)
vol_real_mesh = get_real_mesh(recip_mesh.shape[0], recip_mesh.max())

In [8]:
# Count label for the file name: "10K" for exact multiples of 1000, otherwise the
# exact count. This stops e.g. 1500 images being labelled "1K" and 500 images
# "0K", which would collide with other runs' files.
if num_images and num_images % 1000 == 0:
    num_label = f'{num_images // 1000:d}K'
else:
    num_label = f'{num_images:d}'

# torch.save does not create missing parent directories.
os.makedirs(save_dir, exist_ok=True)
pt_fpath = os.path.join(save_dir, f'{pdb}_increase{increase_factor:d}_poisson{poisson}_num{num_label}.pt')

# Bundle the simulated dataset:
#   orientations: float32 quaternions, one per image
#   intensities:  float32 images, (num_images, ny, nx) as allocated in the generation cell
#   pixel_*:      detector geometry needed to map pixels to reciprocal space
#   volume:       complex64 copy of the experiment's first volume
#   *_real_mesh:  meshes from the mesh cell
torch.save(
    {
        'orientations': torch.from_numpy(orientations).float(),
        'intensities': torch.from_numpy(images).float(),
        'pixel_position_reciprocal': torch.from_numpy(det.pixel_position_reciprocal).float(),
        'pixel_index_map': torch.from_numpy(det.pixel_index_map).long(),
        'volume': torch.from_numpy(exp.volumes[0]).to(torch.complex64),
        'img_real_mesh': img_real_mesh.float(),
        'vol_real_mesh': vol_real_mesh.float(),
        'time_stamp': time.strftime("%Y%m%d-%H%M")
    }, pt_fpath
)
print("data wrote to: \n", pt_fpath)

data wrote to: 
 /pscratch/sd/z/zhantao/neurorient_repo/data/1BXR_increase1_poissonFalse_num10K.pt


In [9]:
# Count label for the file name: "10K" for exact multiples of 1000, otherwise the
# exact count (avoids "1K" for 1500 images or "0K" for 500).
if num_images and num_images % 1000 == 0:
    h5_num_label = f'{num_images // 1000:d}K'
else:
    h5_num_label = f'{num_images:d}'

# h5py.File(..., 'w') creates/truncates the file but not its parent directory.
os.makedirs(save_dir, exist_ok=True)
h5_fpath = os.path.join(
    save_dir,
    f'{pdb}_increase{increase_factor:d}_poisson{poisson}_num{h5_num_label}.h5')
with h5py.File(h5_fpath, 'w') as f:
    # images[:, None] inserts a singleton axis: (num_images, 1, ny, nx).
    f.create_dataset('intensities', data=images[:, None])
    f.create_dataset('orientations', data=orientations)
    f.create_dataset('pixel_position_reciprocal', data=det.pixel_position_reciprocal)
    f.create_dataset('pixel_distance_reciprocal', data=det.pixel_distance_reciprocal)
    f.create_dataset('pixel_index_map', data=det.pixel_index_map)
print("data wrote to: \n", h5_fpath)

data wrote to: 
 /pscratch/sd/z/zhantao/neurorient_repo/data/1BXR_increase1_poissonFalse_num10K.h5
