# Functions

In [None]:
def get_dictionary_spots(spots_reflections_raw,timerounded=True,shift_min=True,dt=3):
    '''
    Split spots into indexable / not-indexable sets and group each set by frame.

    Rows are expected in the layout consumed by map_3D,
    (x, y, frame, intensity, h, k, l); column 2 is the frame number and the
    last three columns are h, k, l. A row is "not indexable" when h, k and l
    are all zero.

    Parameters
    ----------
    spots_reflections_raw : array-like, shape (n_spots, n_columns)
        Spot list. It is copied and never modified.
    timerounded : bool
        When dt >= 1, replicate each spot over +/- dt frames around its rounded
        frame (round_time_SPOTXDS). Ignored when dt < 1, where rounding always
        happens.
    shift_min : bool
        Accepted for backward compatibility; currently unused.
    dt : float
        Half-width of the time window in frames. When dt < 1, frame numbers are
        multiplied by 10 and the window becomes dt*10 tenths of a frame. The
        returned keys are then divided by 10 again.

    Returns
    -------
    (dict_spots_indexable, dict_spots_notindexable)
        Two dicts, as built by get_dict_from_spotlists, mapping frame offset to
        an array of spot rows. A set with no rows gives an empty defaultdict.
    '''
    from collections import defaultdict

    # np.array copies, so the caller's array is never scaled in place below
    spots = np.array(spots_reflections_raw)
    if spots.size == 0:
        return defaultdict(list), defaultdict(list)

    fractional = False
    if dt < 1:
        # sub-frame window: work in tenths of a frame so the window is integral
        spots[:, 2] *= 10
        spots = round_time_SPOTXDS(spots, dt*10)
        fractional = True
    elif timerounded:
        # round only time up and down
        spots = round_time_SPOTXDS(spots, dt)

    # Only (0, 0, 0) means "not indexed". Testing h+k+l == 0 would wrongly
    # reject genuine reflections such as (1, -1, 0).
    not_indexed = np.all(spots[:, -3:] == 0, axis=1)
    indexable = spots[~not_indexed]
    notindexable = spots[not_indexed]

    def _group_by_frame(rows):
        # get_dict_from_spotlists needs at least one row, so short-circuit empty sets
        if len(rows) == 0:
            return defaultdict(list)
        return get_dict_from_spotlists(rows, fractional=fractional)

    return _group_by_frame(indexable), _group_by_frame(notindexable)

def get_dict_from_spotlists(spots_reflections_raw,fractional=False):
    """
    Group spot rows by integer frame number.

    Parameters
    ----------
    spots_reflections_raw : array-like, shape (n_spots, n_columns)
        Column 2 holds the frame number; it is truncated with ``astype(int)``.
    fractional : bool
        If True, frame numbers are in tenths of a frame and the keys are
        divided by 10.

    Returns
    -------
    defaultdict(list)
        Key: ``(frame - smallest_frame) / divide`` with divide = 10 if
        ``fractional`` else 1. Keys are inserted in ascending frame order.
        Value: array of the unique rows of that frame, in first-occurrence order.
        The dict is empty when the input has no rows.
    """
    from collections import defaultdict
    try:
        from tqdm import tqdm
    except ImportError:  # the progress bar is cosmetic; work without it
        def tqdm(iterable, **_kwargs):
            return iterable

    dict_spots = defaultdict(list)
    spots = np.asarray(spots_reflections_raw)
    if spots.size == 0:
        return dict_spots

    frame_ids = spots[:, 2].astype(int)
    min_framenum = frame_ids.min()
    divide = 10. if fractional else 1.

    # Sort once by frame and slice, instead of rescanning every row for each frame.
    # A stable sort keeps the input order of rows within a frame.
    order = np.argsort(frame_ids, kind="stable")
    sorted_spots = spots[order]
    frames, starts = np.unique(frame_ids[order], return_index=True)
    per_frame = np.split(sorted_spots, starts[1:])

    for fr, rows in tqdm(zip(frames, per_frame), total=len(frames)):
        # drop exact duplicate rows (e.g. from overlapping time windows)
        unique_rows = list(dict.fromkeys(map(tuple, rows)))
        dict_spots[(fr - min_framenum)/divide] = np.asarray(unique_rows)

    return dict_spots

# Just round up and down the time, we can extend this to a larger time window
def round_time_SPOTXDS(raw_spots,dt=0):
    """
    Copy every spot onto each frame within +/- dt of its rounded frame.

    Column 2 of ``raw_spots`` holds the (fractional) frame number. Each spot is
    rounded to the nearest frame ``c`` (``np.round``). It is then copied once
    for every value of ``np.arange(max(0, c - dt), min(c + dt + 1, last + 1))``,
    where ``last`` is the largest rounded frame. Each copy has column 2 set to
    that value; all other columns are kept unchanged.

    Returns a new array; the input is not modified. Rows follow the input spot
    order, and each spot's copies are in ascending frame order. Empty input
    gives an empty ``(0, n_columns)`` array.
    """
    spots = np.asarray(raw_spots)
    if spots.size == 0:
        ncols = spots.shape[1] if spots.ndim == 2 else 7
        return np.empty((0, ncols))

    centres = np.round(spots[:, 2])
    # Clamp the windows to [0, last rounded frame]. Clamping to int(max time)
    # instead would drop (dt=0) or misplace any spot whose time rounds up past it.
    last_frame = centres.max()
    first = np.maximum(0, centres - dt)
    stop = np.minimum(centres + dt + 1, last_frame + 1)
    # the number of values np.arange(first, stop, 1) yields (0 for an empty window)
    counts = np.ceil(np.maximum(stop - first, 0)).astype(int)

    expanded = np.repeat(spots, counts, axis=0)
    # offset of each copy inside its own window: 0, 1, ..., count - 1
    window_offsets = np.arange(counts.sum()) - np.repeat(np.cumsum(counts) - counts, counts)
    expanded[:, 2] = np.repeat(first, counts) + window_offsets
    return expanded


def match_spots_frame_expanded(predicted,n_frame):
    '''
    Return the unique rows of ``predicted`` whose frame number equals n_frame.

    Column 2 of ``predicted`` is the frame number. It is truncated to int before
    comparing, so 3.2 and 3.7 both belong to frame 3. Duplicate rows are removed
    and the first-occurrence order is kept. When nothing matches, the result is
    an empty ``(0, n_columns)`` array.
    '''
    predicted = np.asarray(predicted)
    # Vectorized comparison; a Python list compared with a Python int is just
    # False, so it only "worked" when n_frame happened to be a NumPy scalar.
    in_frame = predicted[predicted[:, 2].astype(int) == n_frame]
    unique_rows = list(dict.fromkeys(map(tuple, in_frame)))
    if not unique_rows:
        return in_frame[:0]
    return np.asarray(unique_rows)

def rodrigues(h,phi,rot_ax):
    '''
    Rotate the 3-vector h by phi radians about rot_ax (Rodrigues' formula).

    Implements h*cos(phi) + (k x h)*sin(phi) + k*(k.h)*(1 - cos(phi)) with
    k = rot_ax. This is a proper rotation only when rot_ax has unit length;
    the axis is not normalized here.
    The previous element-wise matrix used rot_ax[1] instead of rot_ax[0] in the
    (z, y) entry, so the result was not a rotation.

    Returns a new float array of shape (3,).
    '''
    import numpy as np

    h = np.asarray(h, dtype=float)
    k = np.asarray(rot_ax, dtype=float)
    cp = np.cos(phi)
    sp = np.sin(phi)

    return h*cp + np.cross(k, h)*sp + k*np.dot(k, h)*(1. - cp)


def map_3D(spotslist,
           wavelength=0.999857,
           incident_beam=np.asarray([0.,0.,1.]),
           # size of the panel
           nx=4148,
           ny=4362,
           # size of the pixels
           qx=0.075000,
           qy=0.075000,
           orgx=2120.750488,
           orgy=2146.885498,
           det_dist=300.0,
           det_x=np.asarray([1.0,0.0,0.0]),
           det_y=np.asarray([0.0,1.0,0.0]),
           resolmax=0.1,
           resolmin=999,
           starting_angle=0,
           oscillation_range=0.0,
           rot_ax=np.asarray([1.000000,-0.000268,0.000392])
          ):
    """
    Map detector spot positions to rotated reciprocal-space vectors.

    Each element of spotslist is (x, y, phi, intensity) or
    (x, y, phi, intensity, h, k, l). For every spot:

    1. convert the pixel position to a lab-frame vector using the origin
       (orgx, orgy), pixel sizes (qx, qy), det_dist and the axes det_x, det_y;
    2. scale it to length 1/wavelength (S1) and form S = S1 - S0;
    3. skip it if |S| > 1/resolmax or |S| < 1/resolmin;
    4. rotate S about rot_ax by (starting_angle + oscillation_range*phi - 180)
       degrees.

    The parameter names follow XDS (ORGX, QX, DETECTOR_DISTANCE, ...), so the
    values are presumably taken from XDS.INP/XPARM.XDS — confirm against the
    producer of the input. nx and ny are accepted but not used.

    Returns
    -------
    ndarray, shape (n_kept, 7)
        Rows (Sx, Sy, Sz, intensity, h, k, l). h, k, l are 0 for 4-element
        spots. The shape is (0, 7) when no spot survives the resolution cuts.
    """
    import numpy as np

    det_x = np.asarray(det_x, dtype=float)
    det_y = np.asarray(det_y, dtype=float)
    # Detector normal. XDS.INP does not give it, and det_x/det_y need not be
    # perpendicular to the beam.
    det_z = np.cross(det_x, det_y)
    det_z = det_z / np.linalg.norm(det_z)

    # S0 must have the same length (1/wavelength) as S1 below. Normalizing makes
    # a unit direction and an already-scaled 1/lambda vector give the same S0.
    s0 = np.asarray(incident_beam, dtype=float)
    s0 = s0 / (wavelength * np.linalg.norm(s0))

    # Rodrigues' formula only describes a rotation for a unit axis
    axis = np.asarray(rot_ax, dtype=float)
    axis = axis / np.linalg.norm(axis)

    s_outer = 1. / resolmax  # outer resolution limit on |S|
    s_inner = 1. / resolmin  # inner resolution limit on |S|

    spots = []
    for line in spotslist:
        ih, ik, il = 0., 0., 0.
        if len(line) == 4:
            x, y, phi, intensity = line
        else:
            x, y, phi, intensity, ih, ik, il = line

        # detector coordinates -> lab-frame vector from the crystal to the pixel
        r = (x - orgx)*qx*det_x + (y - orgy)*qy*det_y + det_dist*det_z

        # scattered wave vector S1 (length 1/wavelength), then S = S1 - S0
        s = r / (wavelength * np.linalg.norm(r)) - s0
        s_len = np.linalg.norm(s)
        if s_len > s_outer or s_len < s_inner:
            continue

        # NB: the "-180" term (found by trial and error) seems to make the result
        # match dials.rs_mapper. That comparison also needs the vector scaled by
        # 100 and shifted by 100/resolmax.
        angle = np.deg2rad(starting_angle + oscillation_range*phi - 180.)
        rot_r = rodrigues(s, angle, axis)

        spots.append(np.hstack([rot_r, [intensity], [ih, ik, il]]))

    if not spots:
        return np.empty((0, 7))
    return np.asarray(spots)

# Load NPZ 

In [None]:
# NpzFile: members are read from the archive on access
lyso_19 = np.load("lyso_19_rawspots.npz")
list(lyso_19.files)  # names of the arrays stored in the archive

# Group the SPOT.XDS spots by frame (rotation angle) and split them into indexable / not-indexable sets

In [None]:
indexable, notindexable = get_dictionary_spots(
    lyso_19["spots"],
    timerounded=True,
    shift_min=True,
    dt=1,
)

## Map them to 3D

In [None]:
from tqdm import tqdm

# Read the geometry from the NPZ once. Each lyso_19[key] access reads the
# member from the archive again, so looking these up per frame is wasted work.
geometry = {key: lyso_19[key] for key in (
    "starting_angle", "rot_ax", "wavelength", "incident_beam",
    "nx", "ny", "qx", "qy", "orgx", "orgy",
    "det_dist", "det_x", "det_y",
)}

pdbs = []

for i, frame_spots in tqdm(indexable.items()):
    sp3d = map_3D(
        frame_spots,
        # 0 makes the rotation phi-independent: every spot gets the same fixed rotation
        oscillation_range=0.0,
        **geometry,
    )
    if len(sp3d) > 0:
        pdbs.append(sp3d[:, :3])

In [None]:
import ase
from ase.visualize import view

# one dummy "X" atom per spot; positions scaled by 100 for display
traj = [ase.Atoms(["X"] * len(frame), positions=frame * 100) for frame in pdbs]

view(traj)

# Map ideal spots in 3D

In [None]:
indexable_id, notindexable_id = get_dictionary_spots(
    lyso_19["ideal_spots"],
    timerounded=True,
    shift_min=True,
    dt=0.5,
)

In [None]:
from tqdm import tqdm

# Read the geometry from the NPZ once. Each lyso_19[key] access reads the
# member from the archive again, so looking these up per frame is wasted work.
geometry = {key: lyso_19[key] for key in (
    "starting_angle", "rot_ax", "wavelength", "incident_beam",
    "nx", "ny", "qx", "qy", "orgx", "orgy",
    "det_dist", "det_x", "det_y",
)}

pdbs = []

for i, frame_spots in tqdm(indexable_id.items()):
    sp3d = map_3D(
        frame_spots,
        # 0 makes the rotation phi-independent: every spot gets the same fixed rotation
        oscillation_range=0.0,
        **geometry,
    )
    if len(sp3d) > 0:
        pdbs.append(sp3d[:, :3])

In [None]:
import ase
from ase.visualize import view

# one dummy "X" atom per spot; positions scaled by 100 for display
traj = [ase.Atoms(["X"] * len(frame), positions=frame * 100) for frame in pdbs]

view(traj)