In [2]:
import warnings

import numpy as np
from scipy.integrate import quad


In [3]:
def subf_baseline_removal(im, bval):
    """
    Remove a baseline from an image and truncate negative values to zero.

    Parameters
    ----------
    im : array_like
        Input image (any shape). The input itself is not modified.
    bval : scalar or array_like
        Baseline to subtract; must broadcast against ``im``.

    Returns
    -------
    ndarray
        ``im - bval`` with every element where ``im < bval`` set to 0.
        The dtype follows NumPy's promotion rules for ``im - bval``.
    """
    im = np.asanyarray(im)
    out = im - bval
    # Build the mask from the inputs rather than from ``out < 0``: for
    # unsigned-integer images (e.g. uint16 camera frames) minus an integer
    # baseline, the subtraction wraps around instead of going negative, so
    # ``out < 0`` would never select the under-baseline pixels.
    out[im < bval] = 0
    return out
def el_rang(a, b, phi):
    """
    Polar radius of an origin-centred ellipse along direction ``phi``.

    ``a`` is the semi-axis along phi = 0 and ``b`` the semi-axis along
    phi = pi/2:

        r(phi) = a*b / sqrt((b*cos(phi))**2 + (a*sin(phi))**2)

    Works element-wise on NumPy arrays.
    """
    horiz = b * np.cos(phi)
    vert = a * np.sin(phi)
    radius_denom = np.sqrt(horiz**2 + vert**2)
    return a * b / radius_denom

def estimate_collection_efficiency(p):
    """
    Estimate the objective's geometric collection efficiency versus radial
    distance from the image centre.

    Parameters
    ----------
    p : dict
        Nested parameters. Reads p['obj']['WkDist'], p['obj']['NA'],
        p['obj']['RefIndex'], p['cam']['PixNumX'], p['cam']['PixNumY'],
        p['cam']['PixSizeX'] and p['cam']['PixSizeY']. ``p`` is not modified.

    Returns
    -------
    dx : ndarray
        Radial sample positions from 0 to the sensor half-diagonal, in the
        same length unit as the pixel sizes.
    e_col : ndarray
        Estimated collection efficiency at each ``dx`` as a fraction of the
        full sphere (4*pi sr), scaled by an illumination-uniformity factor.
    """
    obj, cam = p['obj'], p['cam']

    # Working-distance unit heuristic: values below 100 are assumed to be mm
    # and scaled by 1000 (presumably to um, to match the pixel sizes -- confirm).
    # Work on a local value: scaling p['obj']['WkDist'] in place mutated the
    # caller's dict, and a repeated call could rescale a still-small value again.
    wk_dist = obj['WkDist']
    if wk_dist < 100:
        wk_dist = wk_dist * 1000

    # Half-angle of the collection cone and its radius at the lens plane.
    thet = np.arcsin(obj['NA'] / obj['RefIndex'])
    lens_half = np.tan(thet) * wk_dist

    # Radial samples from the centre to the sensor half-diagonal, roughly one
    # sample per pixel along the half-diagonal.
    sensor_w = cam['PixNumX'] * cam['PixSizeX']
    sensor_h = cam['PixNumY'] * cam['PixSizeY']
    dx_max = np.sqrt(sensor_w**2 + sensor_h**2) / 2
    npts = int(np.ceil(np.sqrt(cam['PixNumX']**2 + cam['PixNumY']**2) / 2))
    dx = np.linspace(0, dx_max, npts)

    # Vectors from a point at offset dx to the two lens edges lying in the
    # plane that contains the offset and the optical axis.
    wd_col = np.full(npts, wk_dist, dtype=float)
    A = np.column_stack((dx + lens_half, wd_col))
    B = np.column_stack((dx - lens_half, wd_col))

    # Angular semi-axes of the collection cone seen from each point:
    # ta perpendicular to the offset direction, tb half the angle between A and B.
    ta = np.arctan(lens_half / np.sqrt(wk_dist**2 + dx**2))
    cos_ab = np.sum(A * B, axis=1) / (np.linalg.norm(A, axis=1) * np.linalg.norm(B, axis=1))
    # Clip guards against round-off pushing the cosine just outside [-1, 1].
    tb = np.arccos(np.clip(cos_ab, -1.0, 1.0)) / 2

    # Solid angle of a cone whose angular radius is approximated by an ellipse
    # with semi-axes (ta, tb):  Omega = integral_{-pi}^{pi} (1 - cos r(phi)) dphi,
    # expressed as a fraction of 4*pi sr.
    cos_integrals = np.array([
        quad(lambda phi, a=ta_i, b=tb_i: np.cos(el_rang(a, b, phi)), -np.pi, np.pi)[0]
        for ta_i, tb_i in zip(ta, tb)
    ])
    e_col = (2 * np.pi - cos_integrals) / (4 * np.pi)

    # Illumination-uniformity adjustment. The 2*pi*(1 - cos(thet)) factors
    # cancel, so RI == R_0**2 / R**2 with R = hypot(R_0, dx).
    # NOTE(review): lens_half == dx_max divides by zero and yields NaN.
    WD_il = wk_dist * lens_half / (lens_half - dx_max)
    R_0 = WD_il - wk_dist
    R = np.sqrt(R_0**2 + dx**2)
    SA = 2 * np.pi * R**2 * (1 - np.cos(thet))
    RI = 2 * np.pi * R_0**2 * (1 - np.cos(thet)) / SA
    e_col = e_col * RI
    return dx, e_col

def _pixel_radial_distance(shape, pix_dx, pix_dy):
    """Distance of each pixel centre from the image centre for a (rows, cols) grid."""
    rows, cols = shape
    # With 0-based indices the centre of N pixels sits at (N - 1) / 2;
    # (N + 1) / 2 is the 1-based (MATLAB-style) centre and shifts the map.
    dist_r = np.abs(np.arange(rows) - (rows - 1) / 2) * pix_dx
    dist_c = np.abs(np.arange(cols) - (cols - 1) / 2) * pix_dy
    return np.hypot(dist_r[:, np.newaxis], dist_c[np.newaxis, :])


def subf_objective_correction(im, p, e_inv=None):
    """
    Correct an image for the objective's spatially varying collection efficiency.

    Parameters
    ----------
    im : ndarray or str
        Image to correct. The first two axes are the spatial axes; any
        trailing axes (e.g. channels) share the same correction. Passing the
        string ``'version'`` returns ``('v1.0', None)``.
    p : dict or None
        Objective/camera parameters, only used when ``e_inv`` is None. Passed
        to ``estimate_collection_efficiency``; also reads p['cam']['PixSizeX'],
        p['cam']['PixSizeY'], p['obj']['NA'] and p['obj']['RefIndex'].
    e_inv : ndarray, optional
        Pre-computed correction map to multiply with ``im``.

    Returns
    -------
    (imout, e_inv) : tuple of ndarray
        When ``e_inv`` is None: the corrected image and the correction map
        (same shape as ``im``), so the map can be reused on later images.
    imout : ndarray
        When ``e_inv`` is given: ``im * e_inv`` only. (The asymmetric return
        type is kept because callers unpack it differently per branch.)

    Raises
    ------
    ValueError
        If a correction map must be computed and ``im`` has fewer than 2 dims.
    """
    # Version check provision. Check the type first: ``ndarray == 'version'``
    # is element-wise and its truth value is ambiguous inside ``if``.
    if isinstance(im, str) and im == 'version':
        return 'v1.0', None

    # A supplied correction map is applied directly.
    if e_inv is not None:
        return im * e_inv

    im = np.asarray(im)
    if im.ndim < 2:
        raise ValueError('subf_objective_correction: image must have at least 2 dimensions')

    # Estimated collection efficiency versus radial distance from the centre.
    dx, e_col = estimate_collection_efficiency(p)

    # Radial distance of each pixel centre, then linear interpolation onto it
    # (np.interp clamps to the end values beyond the sampled range).
    imdx = _pixel_radial_distance(im.shape[:2], p['cam']['PixSizeX'], p['cam']['PixSizeY'])
    e_col_im = np.interp(imdx, dx, e_col)  # TODO: consider a smoother interpolant

    # Nominal efficiency (1 - cos(theta)) / 2 with sin(theta) = NA / n, i.e. the
    # fraction of the sphere covered by the collection cone.
    ceff_nom = (1 - np.sqrt(1 - (p['obj']['NA'] / p['obj']['RefIndex'])**2)) / 2
    e_inv_2d = ceff_nom / e_col_im

    # Expand over trailing axes so e_inv.shape == im.shape; a (H, W, 1) map
    # would otherwise broadcast a 2-D image to (H, H, W).
    trailing = (1,) * (im.ndim - 2)
    e_inv = np.broadcast_to(e_inv_2d.reshape(e_inv_2d.shape + trailing), im.shape).copy()

    imout = im * e_inv
    return imout, e_inv





In [4]:
def iman_refine(im, bval, X=None, c=None):
    """
    Refine a raw image: subtract a baseline, then optionally flat-field it.

    Uses ``subf_baseline_removal`` and ``subf_objective_correction`` from the
    previous cell.

    Parameters
    ----------
    im : ndarray or str
        Raw image. Passing the string ``'version'`` returns ``('v2.0', X)``.
    bval : scalar or array_like
        Baseline passed to ``subf_baseline_removal``.
    X : dict, ndarray, flattener object or None
        Flat-field specification:
          * dict with 'obj' and 'cam' keys: objective/camera parameters. The
            correction map is computed and returned in place of ``X`` so it
            can be reused on later frames;
          * ndarray with the same shape as ``im``: a pre-computed correction map;
          * object with a callable ``Flatten(im, c)`` and a ``model`` sequence
            (e.g. a FieldFlattener): delegated to ``X.Flatten``;
          * None: no flat-field correction.
        Any other value triggers a warning and is returned unchanged.
    c : iterable of int, optional
        Channels passed to ``X.Flatten``; defaults to ``range(len(X.model))``.

    Returns
    -------
    (im, X) : tuple
        Refined image and the (possibly updated) flat-field specification.
    """
    # Version check provision. Check the type first: ``ndarray == 'version'``
    # is element-wise and cannot be used as an ``if`` condition.
    if isinstance(im, str) and im == 'version':
        return 'v2.0', X

    # Remove baseline value from image
    im = subf_baseline_removal(im, bval)

    if X is None:
        return im, X

    if isinstance(X, dict) and 'obj' in X and 'cam' in X:
        # Parameter dict: compute the correction and keep the map for reuse.
        im, X = subf_objective_correction(im, X)
    elif isinstance(X, np.ndarray) and X.shape == im.shape:
        # Pre-computed correction map.
        im = subf_objective_correction(im, None, X)
    elif callable(getattr(X, 'Flatten', None)):
        # Field-flattener object. Duck-typed so this cell does not need a
        # FieldFlattener class in scope (a bare name raised NameError here).
        if c is None:
            # Omitted channel specification: use all channels.
            c = range(len(X.model))
        im, X = X.Flatten(im, c)
    else:
        warnings.warn(
            f'iman_refine: unrecognised flat-field specification of type '
            f'{type(X).__name__!r}; no flat-field correction applied.',
            stacklevel=2,
        )

    return im, X
    


In [5]:
#im = iman_refine(im, bk.bval, flat(h), c)