In [None]:
%matplotlib widget

import sys
import os
from pathlib import Path

# Parent of the current working directory.
SCRIPT_DIR = Path(os.getcwd()).parent

# Put the directory one level above SCRIPT_DIR on the import path; presumably
# this is what lets the `python.*` package imports in the next cell resolve.
# The membership check keeps sys.path from growing every time the cell is re-run.
_import_root = os.path.dirname(SCRIPT_DIR)
if _import_root not in sys.path:
    sys.path.append(_import_root)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from python.fluorophores import FlStatic
from python.psfs import PsfVectorial
from python.estimators import est_quad2Diter
from python.simulators import Simulator

In [None]:
# PSF model whose phase mask is shifted in the sweep further down.
psf_vecpp = PsfVectorial()
psf_vecpp.setpinhole(AU=1)  # AU: presumably Airy units -- TODO confirm

# Sweep axes.
phaseplateposmm = np.arange(6) / 10  # phase-plate displacement: 0 ... 0.5 mm
zpos = [-200, -100, 0, 100, 200]     # axial fluorophore positions (nm)

# Displacement relative to the pupil radius (the pupil function has radius 1).
pupil_radius_mm = 2.5  # pupil diameter assumed to be 5 mm
phaseplateposrel = phaseplateposmm / pupil_radius_mm

# The result arrays (stdx, crb1, biasx, rmsex, phot) are allocated in the
# simulation cell right before they are filled, so they are not created here.

In [None]:
# --- scan pattern parameters -------------------------------------------------
L = 75                 # orbit size handed to the pattern and the estimator
probecenter = True     # should we also probe the center?
orbitpoints = 6
laserpower = 10        # relative, increases brightness
pointdwelltime = 0.1   # ms, measurement time in each point
repetitions = 1        # how often to repeat the pattern scan

# NOTE(review): zeroposx and numberOfLocalizations are not referenced by any
# of the visible cells (runSequence is called without a localization count)
# -- confirm whether they are still needed.
zeroposx = np.atleast_2d(np.array([-1, 1, 0]) * L / 2)
numberOfLocalizations = 1000

# --- static fluorophore and simulator ----------------------------------------
fl = FlStatic(brightness=1000)
fl.pos = [10, 0, 0]  # overwritten for every z position in the sweep cell

sim = Simulator(fluorophores=fl)

# Register the estimator component; it gets the same L / probecenter values
# that are used to build the scan pattern.
sim.defineComponent(
    "estdonut",
    "estimator",
    est_quad2Diter,
    parameters=[L, probecenter],
    dim=(0, 1),
)

In [None]:
# Result containers, indexed [phase-plate shift, z position, coordinate];
# phot holds one value per (shift, z) pair.
n_shifts, n_zpos = len(phaseplateposrel), len(zpos)
psfall = []
stdx, crb1, biasx, rmsex = (np.zeros((n_shifts, n_zpos, 3)) for _ in range(4))
phot = np.zeros((n_shifts, n_zpos))

# Pattern settings that stay the same for every phase-plate shift.
pattern_kwargs = dict(
    phasemask="vortex",
    makepattern="orbitscan",
    orbitpoints=orbitpoints,
    probecenter=probecenter,
    orbitL=L,
    pointdwelltime=pointdwelltime,
    laserpower=laserpower,
    repetitions=repetitions,
)

for k, shift in enumerate(phaseplateposrel):
    # Shift the phase mask along the first axis only.
    sys_mis = {'maskshift': [shift, 0]}  # radius of pupil function is 1
    psf_vecpp.setpar(**sys_mis)

    # (Re)define the scan pattern with the shifted PSF and keep its image stack.
    sim.definePattern("donut_misaligned", psf_vecpp, **pattern_kwargs)
    stack, gridv = psf_vecpp.imagestack("vortex")
    psfall.append(stack)

    for z, zval in enumerate(zpos):
        # Fluorophore on the optical axis at the current z position.
        sim.fluorophores.pos = [0, 0, zval]
        seq = ["donut_misaligned", "estdonut"]
        out = sim.runSequence(seq)
        sr = sim.summarize_results(out)

        stdx[k, z, :] = sr.std
        crb1[k, z, :] = sr.sCRB1
        biasx[k, z, :] = sr.bias
        rmsex[k, z, :] = sr.rmse
        phot[k, z] = sr.phot

# Normalized to perfectly aligned phaseplate (first sweep entry) and photon numbers.
crb_aligned = crb1[0, :, :]
sqrt_phot = np.sqrt(phot)[..., None]
stdxrel = stdx / crb_aligned * sqrt_phot
rmserel = rmsex / crb_aligned * sqrt_phot

psfall = np.stack(psfall, axis=3)  # list of stacks -> array, shift index last

In [None]:
# 3x3 layout: this cell fills the top row; the following cells add the PSF
# montages to the two rows below.
fig = plt.figure()
gs = GridSpec(3, 3, figure=fig)

# Left panel: relative standard deviation vs. misalignment (x solid, y dashed).
# NOTE(review): only z index 0 (zpos[0]) is shown here, whereas the rmse panel
# shows every z position -- confirm this is intended.
ax = fig.add_subplot(gs[0, 0])
ax.plot(phaseplateposmm, stdxrel[:, 0, 0])
ax.plot(phaseplateposmm, stdxrel[:, 0, 1], '--')
ax.set(
    xlabel='misalignment of phase plate (mm)',
    ylabel='std / CRB aligned',
    title="Standard deviation",
)
ax.legend(["x", "y"])

# Right panel: bias vs. z, one line per misalignment (x solid, y dashed).
# The legend labels are matched to the first (solid) group of lines.
ax = fig.add_subplot(gs[0, 2])
ax.plot(zpos, biasx[:, :, 0].T)
ax.plot(zpos, biasx[:, :, 1].T, '--')
ax.set(xlabel='z position (nm)', ylabel='bias (nm)', title="Bias")
ax.legend(phaseplateposmm)

# Middle panel: relative rmse vs. misalignment, one line per z position
# (x solid, y dashed); legend labels again match the solid lines.
ax = fig.add_subplot(gs[0, 1])
ax.plot(phaseplateposmm, rmserel[:, :, 0])
ax.plot(phaseplateposmm, rmserel[:, :, 1], '--')
ax.set(
    xlabel='misalignment of phase plate (mm)',
    ylabel='rmse / CRB aligned',
    title="Root mean square error (rmse)",
)
ax.legend(zpos)

In [None]:

# For every requested z position take the first grid plane at or above it.
indz = np.array(
    [np.where(gridv[2] >= zval)[0][0] for zval in zpos],
    dtype=int,
)

# Lay the selected planes of the last (largest-shift) PSF side by side.
psfz = np.hstack([psfall[:, :, plane, -1] for plane in indz])

# Middle row of the figure created in the plotting cell above.
ax = fig.add_subplot(gs[1, :])
ax.imshow(psfz)
ax.set_title('z pos')

In [None]:
# First grid plane at or above z = 0.
indz0 = np.where(gridv[2] >= 0)[0][0].astype(int)

# That plane for every phase-plate shift, placed side by side.
psfx = np.hstack(
    [psfall[:, :, indz0, k] for k in range(len(phaseplateposmm))]
)

# Bottom row of the figure created in the plotting cell above.
ax = fig.add_subplot(gs[2, :])
ax.imshow(psfx)
ax.set_title('misalignment')
