# fasta algorithm

> ADMM solver (`admm`) together with its options (`ADMMOptions`) and its result container (`SMPRSolution`).

In [None]:
#default_exp algorithm

In [None]:
#export 
import numpy as np
import torch as th

from __future__ import annotations

from smpr3d.operators import Qoverlap_real,  calc_psi_denom, \
    calc_psi, A_realspace, Qsplit, Qoverlap, SubpixShift, gradient_fourier_2d, prox_D_gaussian

from tqdm import trange
from smpr3d.util import plotAbsAngle, plot 
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Union, Callable
from dataclasses import dataclass
from .data import LinearIndexEncoded4DDataset, Dense4DDataset, SMeta
from torch.utils.data import BatchSampler, SequentialSampler
import logging

@dataclass
class ADMMOptions:
    """Tuning parameters consumed by the `admm` solver in this module.

    The constructor is written by hand (``@dataclass`` keeps a user-defined
    ``__init__``), so the positional order of the constructor arguments is the
    one in the ``__init__`` signature, not the field order below.

    Fields, as they are used by `admm`:

    - ``beta``: scales the multiplier ``Lambda`` and is passed to the data
      proximal step.
    - ``r22`` / ``eps2``: build the damping term added to the S-matrix update
      (``eps2`` if the normalisation maximum is <= ``eps2``, otherwise that
      maximum times ``r22``).
    - ``r12`` / ``eps1``: stored here; `admm` reads them but does not use them.
    - ``dL_dr_step`` / ``dL_dr_momentum``: step size and momentum weight of the
      scan-position update.
    - ``do_position_correction``: callable ``iteration -> truthy`` deciding
      whether positions are refined in that iteration.
    - ``do_subpix``: select the sub-pixel-shift code path.
    - ``non_blocking``: forwarded to every host/device ``Tensor.to`` transfer.
    - ``verbose``: enable the initial model plot and per-iteration logging.
    - ``margin``: extra border cropped from the plotted S-matrix slice.
    - ``do_smoothing``: whether Gaussian smoothing of the model is requested.
    - ``kernel_size`` / ``sigma``: kernel size and standard deviation of that
      Gaussian smoothing.

    Note: the default ``sigma`` array object is shared by every instance that
    does not pass its own; copy it before mutating it in place.
    """
    r12 : float
    r22 : float
    eps1 : float
    eps2 : float
    dL_dr_momentum : float
    dL_dr_step : float
    beta : float
    do_position_correction : Callable
    do_subpix : bool
    non_blocking : bool
    margin : int
    do_smoothing : bool
    kernel_size : tuple
    sigma : np.ndarray
    # Declared so that `verbose` shows up in fields()/repr()/asdict() like
    # every other option; appended last to keep the existing field order.
    verbose : bool

    def __init__(self,
                 beta = 0.9,
                 r12 = 1e-6,
                 r22 = 1e-3,
                 eps1 = 1e-6,
                 eps2 = 1e-6,
                 dL_dr_momentum = 1.0,
                 dL_dr_step = 1e-3,
                 do_position_correction = lambda it: True if it > 1 else False,
                 do_subpix = True,
                 non_blocking = False,
                 verbose = False,
                 do_smoothing = True,
                 margin = 5,
                 kernel_size = (9, 9),
                 sigma = np.array((2.5, 2.5)),
                 ):
        self.r12 = r12
        self.r22 = r22
        self.eps1 = eps1
        self.eps2 = eps2
        self.dL_dr_momentum = dL_dr_momentum
        self.dL_dr_step = dL_dr_step
        self.do_position_correction = do_position_correction
        self.do_subpix = do_subpix
        self.non_blocking = non_blocking
        self.beta = beta
        self.verbose = verbose
        self.do_smoothing = do_smoothing
        self.margin = margin
        self.kernel_size = kernel_size
        self.sigma = sigma
    
import h5py
@dataclass
class SMPRSolution:
    """Result bundle of a reconstruction: model, probe, positions and fit history.

    The constructor is the one generated by ``@dataclass``; it accepts the
    fields below, positionally or by keyword, in declaration order.
    """
    converged: bool
    smatrix: th.Tensor
    probe: th.Tensor
    positions: th.Tensor
    r_factor: float
    r_factor_history: th.Tensor
    s_matrix_meta : SMeta

    def to_h5(self, file_path, key):
        """Append this solution to the HDF5 file at `file_path`.

        A new group named `key` is created (the file is opened in ``'a'``
        mode) holding one dataset per field; tensor fields are moved to the
        CPU and stored as NumPy arrays. The metadata object then writes itself
        through its own ``to_h5`` under the name ``key + 's_meta'``.
        """
        # (dataset name, needs tensor -> numpy conversion), in write order
        layout = (
            ('converged', False),
            ('smatrix', True),
            ('probe', True),
            ('positions', True),
            ('r_factor', False),
            ('r_factor_history', True),
        )
        with h5py.File(file_path, 'a') as h5_file:
            group = h5_file.create_group(key)
            for name, is_tensor in layout:
                value = getattr(self, name)
                if is_tensor:
                    value = value.cpu().numpy()
                group.create_dataset(name, data=value)
        # written after the file handle above has been closed
        self.s_matrix_meta.to_h5(file_path, key + 's_meta')

    @staticmethod
    def from_h5(file_path, key):
        """Load a solution previously stored with `to_h5`.

        Reads the datasets of group `key` from `file_path` (opened read-only),
        wraps array datasets with ``th.as_tensor``, and restores the metadata
        via ``SMeta.from_h5(file_path, key + 's_meta')``.
        """
        with h5py.File(file_path, 'r') as h5_file:
            group = h5_file[key]

            def load_tensor(name):
                return th.as_tensor(group[name][...])

            loaded = dict(
                converged=bool(group['converged'][()]),
                smatrix=load_tensor('smatrix'),
                probe=load_tensor('probe'),
                positions=load_tensor('positions'),
                r_factor=float(group['r_factor'][()]),
                r_factor_history=load_tensor('r_factor_history'),
            )

        meta = SMeta.from_h5(file_path, key + 's_meta')
        return SMPRSolution(s_matrix_meta=meta, **loaded)

from kornia.filters import  gaussian_blur2d
def gaussian(x, kernel_size, sigma):
    """Gaussian-blur a complex tensor, keeping the peak of each component.

    The real and imaginary parts of `x` are blurred separately with
    ``gaussian_blur2d`` (reflecting borders) and each blurred part is rescaled
    so that its maximum equals the maximum of the corresponding part of the
    input.

    Args:
        x: complex tensor; a leading axis of size 1 is added before blurring
           and removed again afterwards.
        kernel_size: kernel size forwarded to ``gaussian_blur2d``.
        sigma: standard deviation forwarded to ``gaussian_blur2d``.

    Returns:
        A new complex tensor with the same shape as `x`.
    """
    def _blur_keep_peak(part):
        # Blur one real-valued component and restore its original maximum.
        peak = part.max()
        blurred = gaussian_blur2d(part.unsqueeze(0), kernel_size, sigma, border_type='reflect')
        blurred_peak = blurred.max()
        # An all-zero component (e.g. a purely real input) would otherwise
        # turn into NaN through 0/0; leave it unscaled instead.
        return th.where(blurred_peak != 0, blurred / blurred_peak * peak, blurred)

    smr = _blur_keep_peak(x.real)
    # The imaginary part is normalised with its own maxima, not the real part's.
    smi = _blur_keep_peak(x.imag)
    ret = smr + 1j * smi
    return th.clone(ret[0])

from timeit import default_timer as timer
def admm(measurements : Union[LinearIndexEncoded4DDataset, Dense4DDataset], 
         r : th.Tensor, 
         psi0 : th.Tensor, 
         s_meta : SMeta, 
         n_iter : int, 
         n_batches : int, 
         options : ADMMOptions) -> SMPRSolution:
    """Run `n_iter` batched ADMM iterations and return the reconstruction.

    Each iteration (1) forms the target exit waves from the auxiliary variable
    and the multipliers, (2) optionally refines the scan positions, (3) updates
    the probe `psi0`, (4) updates the model `S_model` with a damped
    least-squares step and optional Gaussian smoothing, (5) recomputes the
    model exit waves and (6) applies the data proximal step and the multiplier
    update, recording an R-factor.

    Args:
        measurements: dataset indexable by a list of scan indices; a batch must
            support ``.to(device)``. Only ``Dense4DDataset`` and
            ``LinearIndexEncoded4DDataset`` are accepted.
        r: scan positions, one row per scan position with the (y, x)
            coordinate in columns 0 and 1. NOTE: if `r` already lives on
            ``cuda:0`` it may be updated in place (``th.as_tensor`` can return
            the same tensor).
        psi0: initial probe, broadcastable to ``(Bp, K_b, MY, MX)`` after
            inserting a scan axis, i.e. ``(Bp, MY, MX)``.
        s_meta: metadata providing ``M``, ``N``, ``Bp`` and ``beamlets``.
        n_iter: number of iterations.
        n_batches: number of batches the scan positions are split into.
        options: solver options, see ``ADMMOptions``.

    Returns:
        ``SMPRSolution`` with the final model, probe, positions and the
        R-factor history. ``converged`` is always ``True`` (no convergence
        test is performed); ``r_factor`` is NaN when ``n_iter`` is 0.

    Raises:
        TypeError: if `measurements` is not one of the supported dataset types.

    The computation is hard-wired to device ``cuda:0``; the per-position
    variables ``z``, ``z_hat`` and ``Lambda`` are kept on the CPU and streamed
    batch by batch.
    """
    do_position_correction = options.do_position_correction
    do_subpix = options.do_subpix
    do_smoothing = options.do_smoothing
    beta = options.beta
    dL_dr_momentum = options.dL_dr_momentum
    dL_dr_step = options.dL_dr_step
    non_blocking = options.non_blocking
    r22 = options.r22
    eps2 = options.eps2
    verbose = options.verbose
    margin = options.margin
    kernel_size = options.kernel_size
    # copy: sigma is decayed in place below and must not leak back into `options`
    sigma = options.sigma.copy()
    
    cx_dtype = th.complex64
    dev_z = th.device('cpu')
    dev_compute = [th.device(f'cuda:{i}') for i in [0]]

    M = s_meta.M
    MY, MX = M[0].item(), M[1].item()
    N = s_meta.N
    NY, NX = N[0].item(), N[1].item()
    K = r.shape[0]
    batch_size = int(np.ceil(K / n_batches))
    # Only used for the verbose plot.
    # NOTE(review): the stop offsets use `-MY * 2` / `-MX * 2` while the start
    # offsets use `MY // 2` / `MX // 2` -- confirm this asymmetry is intended.
    slic = np.s_[0, MY // 2 + margin:-MY *2 - margin, MX // 2 + margin:-MX *2 - margin]
    
    shift = SubpixShift(MY, MX, dev_compute[0])
    sampler = BatchSampler(SequentialSampler(range(K)), batch_size=batch_size, drop_last=False)
    
    # Total measured intensity, used to normalise the R-factor.
    if isinstance(measurements, Dense4DDataset):
        sum_I = 0
        for batch_inds in sampler:
            I_b = measurements[batch_inds]
            sum_I += th.sum(I_b).item()   
    elif isinstance(measurements, LinearIndexEncoded4DDataset):
        sum_I = 0
        for batch_inds in sampler:
            I_b = measurements[batch_inds]
            sum_I += th.sum(I_b.counts).item()        
    else:
        raise TypeError(
            "measurements must be a Dense4DDataset or a LinearIndexEncoded4DDataset, "
            f"got {type(measurements).__name__}")
            
    a_norm = np.sqrt(sum_I)   
    Bp = s_meta.Bp
    
    z = th.zeros((K, MY, MX), dtype=cx_dtype, device=dev_z).pin_memory()
    S_model = th.zeros((Bp, NY, NX), dtype=cx_dtype, device=dev_compute[0])
    # small offset keeps the later division by AtA finite where nothing overlaps
    AtA = th.zeros((Bp, NY, NX), dtype=th.float32, device=dev_compute[0]) + 1e-6
    z_hat = th.zeros_like(z)
    Lambda = th.zeros_like(z)
    r = th.as_tensor(r, device=dev_compute[0])
    dL_dr_old = th.zeros_like(r)
    
    # split positions into an integer pixel part and a sub-pixel remainder
    r_int = th.round(r).long()
    dr = r - r_int
    
    # --- initialisation: z from the data, using the unshifted probe as model wave
    for batch_inds in sampler:
        zb = z[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
        I_b = measurements[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
        K_b = zb.shape[0]
        
        psi = th.broadcast_to(psi0[:, None, ...], (Bp, K_b, MY, MX))
        zhb = th.fft.fft2(th.sum(psi, 0), norm='ortho') + 1e-4
    
        zb, _ = prox_D_gaussian(zb, zhb, I_b, 0)
        z[batch_inds] = zb.to(dev_z, non_blocking=non_blocking)
        
    # --- initialisation: normalisation (accumulated probe intensity)
    for batch_inds in sampler:
        psi = shift(psi0, dr[batch_inds])
        AtA = Qoverlap_real(r_int[batch_inds], th.abs(psi) ** 2, AtA)

    # --- initialisation: first model estimate
    for batch_inds in sampler:
        psi = shift(psi0, dr[batch_inds])
        zb = z[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
        zb = th.fft.ifft2(zb, norm='ortho')
        zb = th.conj(psi) * zb
        zb = shift(zb, dr[batch_inds])
        S_model = Qoverlap(r_int[batch_inds], zb, S_model)
    S_model /= AtA  

    if verbose:
        plotAbsAngle(S_model[slic].cpu(), 'S_model init')
    
    R_factors = []
    for i in range(n_iter):
        start = timer()
        # target exit waves in real space: ifft(z + Lambda / beta)
        for batch_inds in sampler:
            zz = z[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            LL = Lambda[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            z_hat[batch_inds] = th.fft.ifft2(zz + LL / beta, norm='ortho').to(dev_z, non_blocking=non_blocking)
        
        # --- scan position refinement
        if do_subpix and do_position_correction(i):
            for batch_inds in sampler:
                K_b = len(batch_inds)
                S_split = th.zeros((Bp, K_b, MY, MX), dtype=cx_dtype, device=dev_compute[0])
                z_hatb = z_hat[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
                
                psi = shift(psi0, dr[batch_inds])
                S_split = Qsplit(r_int[batch_inds], S_model, S_split)
    
                # 2 x Bp x K x MY x MX
                dS_split_dr = gradient_fourier_2d(S_split)
                # 2 x Bp x K x MY x MX
                dS_split_drP = dS_split_dr * psi[None, ...]
    
                # 2 x Bp x K x MY x MX
                nom = th.real(th.conj(dS_split_drP) * z_hatb[None, None, ...])
                denom = th.abs(dS_split_drP) ** 2
    
                # 2 x K
                dL_dr = th.sum(nom, (1, 3, 4)) / th.sum(denom, (1, 3, 4))
                # K x 2
                dL_dr = dL_dr.transpose(0, 1)
                # max shift of +/-0.2 pixels
                dL_dr = th.clamp(dL_dr, -0.2, 0.2)
    
                dL_dr_up = dL_dr * dL_dr_step + dL_dr_old[batch_inds] * dL_dr_momentum
                r[batch_inds] += dL_dr_up
                # keep every probe window inside the model array
                r[batch_inds, 0] = th.clamp(r[batch_inds, 0], 0, NY - MY)
                r[batch_inds, 1] = th.clamp(r[batch_inds, 1], 0, NX - MX)
    
                dL_dr_old[batch_inds] = dL_dr_up
            # Release the last batch's buffer. This must stay inside the branch:
            # S_split does not exist when position correction was skipped.
            del S_split

        # --- probe update
        new_psi = th.zeros_like(psi0)
        new_psi_denom = th.zeros(psi0.shape, device=dev_compute[0])
        
        for batch_inds in sampler:
            # both branches need the target waves on the compute device
            z_hatb = z_hat[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            if do_subpix:
                K_b = len(batch_inds)
                S_split = th.zeros((Bp, K_b, MY, MX), dtype=cx_dtype, device=dev_compute[0])
                S_split = Qsplit(r_int[batch_inds], S_model, S_split)
    
                psi_b = th.conj(S_split) * z_hatb
                # shift update in opposite direction
                psi_b = shift(psi_b, -dr[batch_inds])
                # sum over K
                new_psi += th.sum(psi_b, 1)
    
                S_split = shift(S_split, -dr[batch_inds])
                # sum over K
                new_psi_denom += th.sum(th.abs(S_split) ** 2, 1)
                del S_split, psi_b
            else:
                new_psi_denom += calc_psi_denom(r_int[batch_inds], S_model, th.zeros(psi0.shape, device=dev_compute[0]))
                new_psi += calc_psi(r_int[batch_inds], S_model, z_hatb, th.zeros_like(psi0))
        # Bp x MY x MX
        psi0 = new_psi / new_psi_denom
        # multiply the probe's Fourier transform by the beamlet masks
        Psi0 = s_meta.beamlets * th.fft.fft2(psi0, norm='ortho')
        psi0 = th.fft.ifft2(Psi0, norm='ortho')
        del Psi0
        del new_psi
        del new_psi_denom
    
        # --- update normalisation
        AtA[:] = 1e-6
        for batch_inds in sampler:
            psi = shift(psi0, dr[batch_inds])
            AtA = Qoverlap_real(r_int[batch_inds], th.abs(psi) ** 2, AtA)
    
        # damping term: eps2 if the normalisation is tiny, else r22 * max(AtA)
        h2 = th.max(AtA)
        M2 = (h2 <= eps2) * eps2 + (h2 > eps2) * h2 * r22
        M2 = M2.to(th.float32)
        
        # --- model update
        S_model_new = th.zeros_like(S_model)
        for batch_inds in sampler:
            psi = shift(psi0, dr[batch_inds])
            zhb = z_hat[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            S_model_new = Qoverlap(r_int[batch_inds], th.conj(psi) * zhb, S_model_new)
    
        S_model_new += M2 * S_model
        S_model = S_model_new / (AtA + M2)
    
        # Smoothing is applied only when requested and with more than one beam;
        # its width shrinks by 3 % per iteration.
        if do_smoothing and Bp > 1:
            S_model = gaussian(S_model, kernel_size, tuple(sigma))
            sigma *= 0.97
            
        # --- update model exit waves
        for batch_inds in sampler:
            psi = shift(psi0, dr[batch_inds])
            zh = A_realspace(r_int[batch_inds], S_model, psi, th.zeros_like(z_hat[batch_inds], device=dev_compute[0]))
            z_hat[batch_inds] = th.fft.fft2(zh, norm='ortho').to(dev_z, non_blocking=non_blocking)
        losses = []
        # --- update model from data, update auxiliary variables
        for batch_inds in sampler:
            zhb = z_hat[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            Lb = Lambda[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            zb = z[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
            I_b = measurements[batch_inds].to(dev_compute[0], non_blocking=non_blocking)
    
            # the proximal step sees z_hat - Lambda / beta; undone right after
            zhb -= Lb / beta
            zb, loss = prox_D_gaussian(zb, zhb, I_b, beta)
            zhb += Lb / beta
    
            # multiplier update: Lambda += beta * (z - z_hat)
            Lb += beta * zb
            Lb -= beta * zhb
    
            z[batch_inds] = zb.to(dev_z, non_blocking=non_blocking)
            z_hat[batch_inds] = zhb.to(dev_z, non_blocking=non_blocking)
            Lambda[batch_inds] = Lb.to(dev_z, non_blocking=non_blocking)
            losses.append(loss)
    
        losses = np.concatenate(losses)
        R_factor = np.sqrt(np.sum(losses)) / a_norm
        R_factors.append(R_factor)
        end = timer()
        if verbose:
            logging.info(f"{i:03d}/{n_iter:03d} [{(end - start):-02.2f}s] R-factor: {R_factor:3.3g}")
    
        # re-split the (possibly updated) positions into integer + sub-pixel part
        if do_position_correction(i):
            r_int = th.round(r).long()
            dr = r - r_int

    out = SMPRSolution(
        converged=True,
        smatrix = S_model,
        probe = psi0,
        positions = r,
        # no iteration ran -> there is no R-factor to report
        r_factor = R_factors[-1] if R_factors else float('nan'),
        r_factor_history = th.tensor(R_factors),
        s_matrix_meta = s_meta
    )
    
    return out 

ImportError: attempted relative import with no known parent package