In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import Normalize
import numpy as np
import h5py
import hdf5plugin
import torch

from crystfelparser.crystfelparser import streamfile_parser

In [None]:
# Inputs: the HDF5 file holding the raw detector frames and the CrystFEL
# stream holding the per-frame indexing results.
# NOTE(review): absolute, site-specific path - adjust for your environment.
h5_file = "/das/work/p16/p16371/MAXIV/20211201/raw/KR2/KR2-dark3/KR2-dark3_28_data_000001.h5"
stream_file = "ref/KR2_LB_YESALL_000001.stream"
stream = streamfile_parser(stream_file)

In [None]:
def raw_frame_plot(img,vmax=20,figsize=(12,12)):
    '''Plot a diffraction image on a log-compressed grey scale.

    Pixel values are clipped to [0, 10000] and transformed with log(x + 1).
    They are then rescaled so the brightest pixel maps to 255. The colour map
    saturates at ``vmax`` on that 0-255 scale, so a small ``vmax`` boosts the
    contrast of weak features.

    Args:
        img (array-like): 2D image to display.
        vmax (float, optional): upper colour limit on the rescaled 0-255
            scale. Defaults to 20.
        figsize (tuple, optional): matplotlib figure size in inches.
            Defaults to (12, 12).

    A new matplotlib figure is created on every call; nothing is returned.
    '''
    clipped = np.clip(img, 0, 10000)
    plt.figure(figsize=figsize, dpi=80)
    imlog = np.log(clipped + 1)
    peak = np.max(imlog)
    # A blank (all-zero) frame has peak == 0. Skip the rescale so we do not
    # divide by zero and draw an all-NaN image.
    if peak > 0:
        imlog = imlog / peak * 255
    plt.imshow(imlog, cmap="gist_yarg", vmin=0, vmax=vmax)
    plt.axis('off')
    
def predictBraggsReflections(reciprocal_basis_vectors, 
                             vwlen=0.9998725806451613, 
                             maxhkl=30, 
                             detector_distance_m=0.200, 
                             detectorCenter=[1122.215602,1170.680571], 
                             pixelLength_m=75e-6, 
                             nx=2068, ny=2164, 
                             eps=4e-4, 
                             device='cpu'):
    """Predict the detector positions of the ideal Bragg reflections.

    1. Every Miller index (h, k, l) with -maxhkl <= h, k, l <= maxhkl is
       turned into a reciprocal-lattice vector
       ``G = [h, k, l] @ reciprocal_basis_vectors``.
    2. A vector is kept when its distance from the beam vector
       ``(0, 0, 1/vwlen)`` differs from ``1/vwlen`` by less than ``eps``.
    3. The kept vectors are centrally projected onto a plane at
       ``detector_distance_m`` along the beam axis.
    4. The projections are converted to pixels and cropped to the detector
       panel.

    Args:
        reciprocal_basis_vectors (torch.Tensor or numpy.ndarray): 3x3 matrix
            whose rows are the reciprocal basis vectors (a*, b*, c*). It is
            converted to float64 on ``device``. Presumably expressed in the
            inverse unit of ``vwlen`` - TODO confirm.
        vwlen (float, optional): beam wavelength. Defaults to 0.9998725806451613.
        maxhkl (int, optional): largest |h|, |k|, |l| considered (inclusive).
            Defaults to 30.
        detector_distance_m (float, optional): detector distance in m. Defaults to 0.200.
        detectorCenter (list, optional): beam centre in pixels, ordered (x, y).
            Defaults to [1122.215602,1170.680571].
        pixelLength_m (float, optional): pixel size in m. Defaults to 75e-6.
        nx (int, optional): detector size along x, in pixels. Defaults to 2068.
        ny (int, optional): detector size along y, in pixels. Defaults to 2164.
        eps (float, optional): tolerance on the sphere condition (correlates
            with the non-monochromaticity). Defaults to 4e-4.
        device (str, optional): torch device used for every tensor. Defaults to 'cpu'.

    Returns:
        torch.Tensor: float64 tensor of shape (N, 2). It holds the (x, y) pixel
        coordinates of the predicted reflections lying strictly inside
        (0, nx) x (0, ny).
    """
    dtype = torch.float64

    # For now the beam is hard-coded along +z.
    beam_direction = torch.tensor([0.0, 0.0, 1.0 / vwlen], dtype=dtype, device=device)
    # Stored as (y, x) to match the column order used in the projection below.
    # Kept in float64 so the centre is not rounded to float32.
    center_yx = torch.tensor([detectorCenter[1], detectorCenter[0]], dtype=dtype, device=device)

    # Miller-index grid (h varies slowest, l fastest), both ends inclusive.
    # It is built directly in float64 so the matmul dtypes agree.
    hkl_range = torch.arange(-maxhkl, maxhkl + 1, dtype=dtype, device=device)
    millers = torch.cartesian_prod(hkl_range, hkl_range, hkl_range)

    # Accept both tensors and numpy arrays for the basis.
    basis = torch.as_tensor(reciprocal_basis_vectors, dtype=dtype, device=device)
    # Reciprocal-lattice vectors: row i is h*a* + k*b* + l*c*.
    reciprocalPeaks = millers @ basis

    # Remove the beam vector and keep the points whose distance from it is
    # 1/vwlen within eps.
    reciprocalPeaks = reciprocalPeaks - beam_direction
    cond = torch.abs(torch.linalg.norm(reciprocalPeaks, dim=1) - (1 / vwlen)) < eps
    reciprocalPeaks = reciprocalPeaks[cond]

    # Reflect the x-axis.
    rt = torch.tensor([
        [-1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]], dtype=dtype, device=device)
    reciprocalPeaks = reciprocalPeaks @ rt

    # Recenter the points.
    reciprocalPeaks = reciprocalPeaks + beam_direction

    # Flip x. NOTE(review): this undoes the reflection above, so the net
    # effect is the identity. Both steps are kept to preserve the numerics.
    reciprocalPeaks = reciprocalPeaks * torch.tensor([-1.0, 1.0, 1.0], dtype=dtype, device=device)

    # Central projection onto the detector plane, in (y, x) column order.
    projectedPeaks = reciprocalPeaks[:, [1, 0]] / (reciprocalPeaks[:, -1:] - (1 / vwlen)) * detector_distance_m
    # Convert to pixels, add the beam centre and swap back to (x, y).
    projectedPeaks = (projectedPeaks / pixelLength_m + center_yx)[:, [1, 0]]

    # Limit the points to the inside of the detector panel.
    inside = (
        (projectedPeaks[:, 0] > 0) & (projectedPeaks[:, 0] < nx)
        & (projectedPeaks[:, 1] > 0) & (projectedPeaks[:, 1] < ny)
    )
    return projectedPeaks[inside]


def plot_frame_reflectins(frame_idx, stream, h5_file):
    '''Plot one diffraction frame and overlay spots and two sets of predicted reflections.

    The frame is read from ``/entry/data/data`` in ``h5_file`` and drawn with
    ``raw_frame_plot``. On top of it are drawn:

    * red circles: the spots from ``stream.get_spots_2d(frame_idx)``;
    * green diamonds: reflections predicted here by ``predictBraggsReflections``
      from the frame's ``reciprocal_cell_matrix``;
    * dark-green squares: the ``predicted_reflections`` stored in the stream
      (labelled 'indexmajig').

    The two prediction layers are drawn only when ``stream.parsed[frame_idx]``
    has a ``reciprocal_cell_matrix`` entry.

    Args:
        frame_idx (int): frame index, used both for the HDF5 dataset and for
            ``stream.parsed``. NOTE(review): this assumes the stream's frame
            index matches the position in the HDF5 dataset - confirm.
        stream: parsed CrystFEL stream (``streamfile_parser`` instance).
        h5_file (str): path of the HDF5 file holding the raw frames.
    '''
    # Read just this frame and release the HDF5 file right away; nothing
    # below needs it open.
    with h5py.File(h5_file, 'r') as f:
        frame = f['/entry/data/data'][frame_idx]

    # Check whether the frame has an indexing solution.
    frame_solution = stream.parsed[frame_idx]
    if 'reciprocal_cell_matrix' in frame_solution:
        # NOTE(review): the 0.1 factor presumably converts nm^-1 to A^-1 to
        # match `stream.wavelength`, and clen is presumably in mm - TODO confirm.
        reciprocal_cell_matrix = frame_solution['reciprocal_cell_matrix'] * 0.1
        predicted_reflections_2 = predictBraggsReflections(reciprocal_cell_matrix, 
                                 vwlen=stream.wavelength, 
                                 maxhkl=80, 
                                 detector_distance_m=stream.clen/1000, 
                                 detectorCenter=[stream.beam_center_x,stream.beam_center_y], 
                                 pixelLength_m=75e-6, 
                                 nx=2068, ny=2164, 
                                 eps=9e-4, 
                                 device='cpu')
    else:
        predicted_reflections_2 = None

    spots2d = stream.get_spots_2d(frame_idx)

    raw_frame_plot(frame,vmax=50,figsize=(10,10))
    plt.scatter(spots2d[:,0], spots2d[:,1], s=20,facecolors='none', edgecolors='r', label='spots')
    if predicted_reflections_2 is not None:
        plt.scatter(predicted_reflections_2[:,0],predicted_reflections_2[:,1],s=60,facecolors='none', marker="d",edgecolors='#00FF00',label='predicted')
        predicted_reflections_1 = frame_solution['predicted_reflections'][:,:2]
        plt.scatter(predicted_reflections_1[:,0],predicted_reflections_1[:,1],s=60,facecolors='none', marker="s",edgecolors='#009A29',label='indexmajig')
    plt.title(f"Frame {frame_idx}")
    plt.legend()

In [None]:
# Overlay the reflections for element 10 of the stream's indexable-frame list.
indexable_frames = stream.get_indexable_frames()
plot_frame_reflectins(indexable_frames[10], stream, h5_file)