In [None]:
# %load_ext autoreload
# %autoreload 2

import os

# Select the array backend BEFORE importing skopi / neurorient.
# NOTE(review): skopi is expected to read USE_CUPY when it is first imported to decide
# between NumPy and CuPy; setting it after `import skopi` (as this cell used to) may
# silently leave the NumPy backend in place. Confirm against the installed skopi.
os.environ["USE_CUPY"] = "1"

import matplotlib.pyplot as plt
import numpy as np
import skopi as sk

import torch
from tqdm import tqdm
import time
import h5py

from neurorient import uniform_points_on_sphere
from neurorient.reconstruction.slicing import get_real_mesh

# Display the flag actually in effect for this kernel.
os.environ.get('USE_CUPY')

In [None]:
# Output directory for the generated datasets. Defaults to the original scratch location;
# set NEURORIENT_DATA_DIR in the environment to write elsewhere without editing the notebook.
save_dir = os.environ.get('NEURORIENT_DATA_DIR', '/pscratch/sd/z/zhantao/neurorient_repo/data')

In [None]:
# Square detector geometry passed to skopi.
# NOTE(review): sizes/distances are presumably in metres (skopi convention) — confirm.
n_pixels = 128   # pixels per side
det_size = 0.1   # detector edge length
det_dist = 0.2   # sample-to-detector distance
det = sk.SimpleSquareDetector(n_pixels, det_size, det_dist)

In [None]:
# Dataset parameters (also encoded in the output file names below).
pdb = '5UOE'           # PDB ID; structure is read from input/pdb/<pdb>.pdb
increase_factor = 1    # multiplier applied to the beam's photons per pulse
poisson = False        # True -> store photon counts, False -> store intensities
num_images = 10_000    # number of random orientations / diffraction patterns

In [None]:
# X-ray beam loaded from a beam file (path is relative to the kernel's working directory),
# with its photons per pulse scaled by increase_factor.
beam = sk.Beam("input/beam/amo86615.beam")
base_photons = beam.get_photons_per_pulse()
beam.set_photons_per_pulse(increase_factor * base_photons)

# Particle built from the PDB entry selected in the config cell ('WK' form factors).
# NOTE(review): the original comment described this as "lidless mmCpn in open state";
# confirm that description matches the selected PDB ID.
particle = sk.Particle()
particle.read_pdb(f"input/pdb/{pdb}.pdb", ff='WK')

# Single-particle-imaging experiment combining detector, beam and particle.
exp = sk.SPIExperiment(det, beam, particle)

In [None]:
# Draw num_images random orientation quaternions. The global NumPy seed is intended to
# make the draw reproducible (assumes sk.get_random_quat uses np.random — TODO confirm).
np.random.seed(42)
orientations = sk.get_random_quat(num_images)
print(orientations)

# Hand the orientations to the experiment; each generate_image_stack() call below is
# presumably served the next orientation in turn (skopi behaviour — confirm).
exp.set_orientations(orientations)

# Choose the requested output type once, rather than re-testing `poisson` per image.
if poisson:
    stack_kwargs = dict(return_photons=True, return_intensities=False)
else:
    stack_kwargs = dict(return_photons=False, return_intensities=True)

# One 2-D pattern per orientation; keep the first entry of each generated stack
# (presumably the single detector panel).
n_orient = len(orientations)
images = np.zeros((n_orient, ) + det.shape[1:])
for idx in tqdm(range(n_orient)):
    images[idx] = exp.generate_image_stack(**stack_kwargs)[0]

In [None]:
# Real-space grid matching one detector image: side length det.shape[1], extent taken
# from the largest reciprocal pixel coordinate.
img_extent_reciprocal = det.pixel_position_reciprocal.max()
img_real_mesh = get_real_mesh(det.shape[1], img_extent_reciprocal)

# Real-space grid matching the experiment's reciprocal-space volume mesh
# (_len is returned by get_reciprocal_mesh but not used here).
_mesh, _len = exp.det.get_reciprocal_mesh(voxel_number_1d=exp.mesh_size)
vol_real_mesh = get_real_mesh(_mesh.shape[0], _mesh.max())

In [None]:
# Bundle patterns, ground-truth orientations, detector geometry, the reference volume and
# the real-space meshes into one torch file named after the dataset parameters.
assert images.shape[0] == orientations.shape[0], "one image per orientation expected"
# torch.save does not create missing parent directories; make sure save_dir exists.
os.makedirs(save_dir, exist_ok=True)
pt_fpath = os.path.join(save_dir, f'{pdb}_increase{increase_factor:d}_poisson{poisson}_num{num_images//1000:d}K.pt')
torch.save(
    {
        'orientations': torch.from_numpy(orientations).float(),
        'intensities': torch.from_numpy(images).float(),
        'pixel_position_reciprocal': torch.from_numpy(det.pixel_position_reciprocal).float(),
        'pixel_index_map': torch.from_numpy(det.pixel_index_map).long(),
        'volume': torch.from_numpy(exp.volumes[0]).to(torch.complex64),
        'img_real_mesh': img_real_mesh.float(),
        'vol_real_mesh': vol_real_mesh.float(),
        'time_stamp': time.strftime("%Y%m%d-%H%M")
    }, pt_fpath
)
print("data wrote to: \n", pt_fpath)

In [None]:
# Write the same dataset (without volume/meshes) to HDF5; images get a singleton
# channel axis via images[:, None].
# h5py.File(..., 'w') does not create missing parent directories; make sure save_dir exists.
os.makedirs(save_dir, exist_ok=True)
h5_fpath = os.path.join(
    save_dir, 
    f'{pdb}_increase{increase_factor:d}_poisson{poisson}_num{num_images//1000}K.h5')
with h5py.File(h5_fpath, 'w') as f:
    f.create_dataset('intensities', data=images[:, None])
    f.create_dataset('orientations', data=orientations)
    f.create_dataset('pixel_position_reciprocal', data=det.pixel_position_reciprocal)
    f.create_dataset('pixel_distance_reciprocal', data=det.pixel_distance_reciprocal)
    f.create_dataset('pixel_index_map', data=det.pixel_index_map)
print("data wrote to: \n", h5_fpath)