In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pyphase as phase 

In [36]:
import pyphase
from pyphase import *
import numpy as np 
from PIL import Image
import numpy as np
from math import *
import pyphase.parallelizer as Parallelizer
import pyphase.propagator as Propagator
import pyphase.tomography as Tomography
#from vendor.EdfFile import EdfFile #TODO: This should not be necessary here!
import matplotlib.pyplot as pyplot
import scipy.ndimage
from pyphase.config import *
from matplotlib.pyplot import pause
from scipy import interpolate
import pickle #TODO: Is there a better way to handle imports? Centralised?
from scipy import ndimage
import pyphase.dataset as Dataset


In [37]:
x_data_dir = "0 mm"
y_data_dir = "1 mm"

In [38]:
for x_tif in sorted(os.listdir(x_data_dir))[:180]:
    x_data = np.array([np.array(Image.open(os.path.join(x_data_dir, x_tif)))])  
for y_tif in sorted(os.listdir(y_data_dir))[:180]: 
    y_data = np.array([np.array(Image.open(os.path.join(y_data_dir, y_tif)))])

In [41]:
class PhaseRetrievalAlgorithm2D:
    """Base class for 2D phase retrieval algorithms.

    Parameters
    ----------
    dataset : pyphase.Dataset, optional 
        A Dataset type object.
    shape : tuple of ints, optional
        Size of images (ny, nx) for creation of frequency variables etc.
    pixel_size : float, optional
        In m.
    distance : list of floats, optional
        Effective propagation distances in m.
    energy : float, optional
        Effective energy in keV.
    alpha : tuple of floats, optional
        Regularisation parameters. First entry for LF, second for HF. 
        Typically [1e-8, 1e-10].
    pad : int
        Padding factor (default 2).

    Attributes
    ----------
    nx : int
        Number of pixels in horizontal direction.
    ny : int
        Number of pixels in horizontal direction.
    pixel_size : tuple of floats
        Pixel size [x, y] in µm.
    ND : int
        Number of positions.
    energy : float
        Energy in keV.
    alpha : tuple of floats
        First entry for LF, second for HF. Tyically [1e-8, 1e-10].
    distance : numpy array
        Effective propagation distances in m.
    padding : int
        Padding factor.
    sample_frequency : float
        Reciprocal of pixel size in lengthscale.
    nfx : int
        Number of samples in Fourier domain (horizontal).
    nfy : int
        Number of samples in Fourier domain (vertical).
    fx : numpy array
        Frequency variable (horizontal), calculated by frequency_variable.
    fy : numpy array
        Freuency variable (vertical).
    alpha_cutoff : float
        Cutoff frequency for regularization parameter in normalized frequency.
    alpha_slope : float
        Slope in regularization parameter.
    
    Notes
    -----
    Takes either a Dataset type object with keyword dataset (which contains
    all necessary parameters), or parameters as above.
    
    """
    
    def __init__(self, dataset=None, **kwargs):
        self.lengthscale=10e-6 # TODO where should this be... Makes formulae dimensionless
        #print(kwargs)
        # Check whether usage is with dataset or external images 
        #TODO: Which parameters in constructor?
        if (isinstance(dataset, Dataset.Dataset) or isinstance(dataset, Dataset.ESRF)): #TODO: needs integration 
            #dataset = args[0]
            self.nx=dataset.nx # TODO: Probably make into an array [x, y]
            self.ny=dataset.ny
            self.pixel_size = dataset.pixel_size*1e-6
            self.ND = len(dataset.position)
            self.energy = dataset.energy
            self.distance = np.array(dataset.effective_distance) #TODO: wallah I've mixed up position, distance, effective distance. Sort out
        else:
            self.nx = kwargs["shape"][1]
            self.ny = kwargs["shape"][0]
            self.pixel_size = kwargs["pixel_size"]
            self.distance = np.array(kwargs["distance"])
            self.energy = kwargs["energy"]
            self.ND = len(self.distance)

        if 'pad' in kwargs:
            self.padding = kwargs['pad']

        self.alpha=[-8, -10]
 
        if (type(self.pixel_size) == float) or (type(self.pixel_size) == np.float64) : # If x,y pixelsizes are not given, assign them
            self.pixel_size = np.array([self.pixel_size, self.pixel_size])
        elif type(self.pixel_size) == list:
            self.pixel_size = np.array(self.pixel_size)
        self.sample_frequency = self.lengthscale/self.pixel_size #TODO: Should be attribute
        self.fx, self.fy = self.frequency_variable(self.nfx, self.nfy, self.sample_frequency)
        self._compute_factors()
        self.alpha_cutoff = .5
        self.alpha_cutoff_frequency=self.alpha_cutoff*self.sample_frequency # TODO: should be a property (dynamically calculated from alpha_cutoff)
        self.alpha_slope = .1e3

    def _algorithm(self, image, positions=None):
        """'Pure virtual' method containing purely the algorithm code part. Should be defined by each subclass. """

    def _compute_factors(self):
        """
        Computes factors used in phase retrieval. 
        
        Motivation is to save time and memory when reconstructing a series of 
        projections on a processor (usually in parallel with others). Default 
        are the sin and cos chirps in CTF. CTF and Mixed approaches extend 
        this method, while TIE methods override it.
        """
        
        self.coschirp = np.zeros((self.ND, self.nfy, self.nfx))
        self.sinchirp = np.zeros_like(self.coschirp)
        for distance in range(self.ND):
            self.coschirp[distance] = np.cos((pi*self.Fresnel_number[distance]) * (self.fx**2) + (pi*self.Fresnel_number[distance]) * (self.fy**2))
            self.sinchirp[distance] = np.sin((pi*self.Fresnel_number[distance]) * (self.fx**2) + (pi*self.Fresnel_number[distance]) * (self.fy**2))

    def __getstate__(self):
        """
        Used in parallel computing to override writing of voluminous variables
        when serializing using pickle by parallelizer. They are instead 
        re-calculated by each process (see __setstate__()).
        """
        state = self.__dict__.copy()
        del state['fx'], state['fy'], state['sinchirp'], state['coschirp']
        return state

    def __setstate__(self, state):
        """
        Used in parallel computing to re-calculate voluminous variables when
        serializing using pickle by parallelizer instead of saving them to disk
        (see __setstate__()).
        """
        self.__dict__.update(state)
        self._compute_factors()
             
    @property
    def nfx(self):
        return self.padding*self.nx
    
    @property
    def nfy(self):
            return self.padding*self.ny
                
    @property
    def Lambda(self):
        """Wavelength based on energy (float)"""
        return 12.4e-10 / self.energy

    @property
    def Fresnel_number(self):
        """Fresnel number at each position, calculated from energy and distance (float)"""
        return self.Lambda * self.distance / (self.lengthscale**2)

    @property
    def Alpha(self):
        """Image implementation of regularisation parameter (np.array)"""
        x=np.linspace(-1,1,self.nfx)
        y=np.linspace(-1,1,self.nfy)
        xv, yv = np.meshgrid(x,y)
        R=np.sqrt(np.square(xv) + np.square(yv))
        R=np.fft.fftshift(R)
        if self.alpha[0] > self.alpha[1]:
            # Logistic function instead of error function (to be seen)
            Alpha = self.alpha[0] - ((self.alpha[0] - self.alpha[1]) / (1 + np.exp(-self.alpha_slope * (R-self.alpha_cutoff))))
        elif self.alpha[0] < self.alpha[1]:
            Alpha = self.alpha[0] + ((self.alpha[1] - self.alpha[0]) / (1 + np.exp(-self.alpha_slope * (R-self.alpha_cutoff))))
        else:
            Alpha = self.alpha[0] * R**0 
        return 10**Alpha 

    @Parallelize
    def reconstruct_projections(self, *, dataset, projections):
        """
        Reconstruct a range of projections (parallelized function).
        
        Parameters
        ----------
        dataset : pyphase.Dataset
            Dataset to reconstruct.
        projections : list of int
            In the form [start, end]
        
        """
        
        for projection in range(projections[0], projections[1]+1):
            #print("Projection: {}".format(projection))
            self.reconstruct_projection(dataset=dataset, projection=projection)

    def frequency_variable(self, nfx, nfy, sample_frequency):
        """
        Calculate frequency variables.
        
        Parameters
        ----------
        nfx : int
            Number of samples in x direction
        nfy : int
            Number of samples in y direction
        sample_frequency : float
            Reciprocal of pixel size in 1/m
            
        Returns
        -------
        nparray
            Frequency variables as an array of size  [nfy, nfx, 2]
            
        Notes
        -----
        Follows numpy FFT convention. Zero frequency at [0,0], [1:n//2] 
        contain the positive frequencies, [n//2 + 1:] n the negative 
        frequencies in increasing order starting from the most negative 
        frequency.
        """
        
        if type(sample_frequency) == int:
            sample_frequency = np.array([sample_frequency, sample_frequency]) #TODO: refactor
        
        x=0
        x=np.append(x, np.linspace(sample_frequency[0]/nfx, sample_frequency[0]/2, nfx//2))
        x=np.append(x, np.linspace(-sample_frequency[0]/2+sample_frequency[0]/nfx, -sample_frequency[0]/nfx, int(nfx//2-1+(np.ceil(nfx/2)-nfx//2))))
        
        y=0
        y=np.append(y, np.linspace(sample_frequency[1]/nfy, sample_frequency[1]/2, nfy//2))
        y=np.append(y, np.linspace(-sample_frequency[1]/2+sample_frequency[1]/nfy, -sample_frequency[1]/nfy, int(nfy//2-1+(np.ceil(nfy/2)-nfy//2))))
        return np.meshgrid(x, y)

    def simple_propagator(self, pxs, Lambda, z):
        """
        Creates a Fresnel propagator.
        
        Parameters
        ----------
        pxs : float
            Pixel size in µm.
        Lambda : float
            Wavelength in m.
        z : float
            Effective propagation distance in m.
            
        Returns
        -------
        H : nparray
            Fresnel propagator.
        
        Notes
        -----
        Temporary implementation by Y. Zhang. Will be integrated with the
        propagator module.
        
        """
        
        # TODO: need to be replaced by the propagator class
        # TODO: Needs refactoring to meet coding standards
        # Generates the ifftshift version of the propagators
        x = np.arange(-np.fix(self.nfx / 2), np.ceil(self.nfx / 2))
        y = np.arange(-np.fix(self.nfy / 2), np.ceil(self.nfy / 2))
        fx = np.fft.ifftshift(x / (self.nfx * pxs))
        fy = np.fft.ifftshift(y / (self.nfy * pxs))
        Fx, Fy = np.meshgrid(fx, fy)
        f2 = Fx ** 2 + Fy ** 2
        H = np.zeros([self.ND, self.nfy, self.nfx], dtype="complex_")
        for distance in range(self.ND):
            H[distance,:,:] = np.exp(-1j * np.pi * Lambda * z[distance] * f2) 
        return H

    def reconstruct_projection(self, dataset, projection=0, positions=None, pad=True):
        """
        Reconstruct one projection from a Dataset object and saves the result.
        
        Parameters
        ----------
        dataset : Dataset
            Dataset object to use (not necessarily the same as initialised).
        projection : int, optional
            Number of projection to reconstruct.
        positions : int or list of ints, optional
            Subset of positions to use for reconstruction
            
        Returns
        -------
        phase : np.array
            Reconstructed phase.
        attenuation : np.array
            Reconstructed attenuation.        
        """      
                    
        ID = np.zeros((self.ND, self.nfy, self.nfx))
        #dataset.padding = self.padding #TODO: needs proper sorting out...
        for position in range(self.ND):
#                if difference:
#                    FID[position] = dataset.get_image(projection=projection, difference=True, Fourier=True, position=position) #TODO: difference not really necessary, just give the difference image as input...
#                else:
                    ID[position] = dataset.get_projection(projection=projection, position=position, pad=pad)
#       elif image.any():
        
        # Call algorithm part
        phase, attenuation = self.reconstruct_image(ID, positions=positions, pad=pad)
        
        
#            if difference:
#                dataset.write_image(phase, 'phase update', projection)
#            else:
        dataset.write_image(image=Utilities.resize(phase, [self.ny, self.nx]), projection=projection)
        dataset.write_image(image=Utilities.resize(attenuation, [self.ny, self.nx]), projection=projection, projection_type='attenuation')

        return phase, attenuation

    def reconstruct_image(self, image, positions=None, pad=False):
        """
        Template for reconstructing an image given as argument. 
        
        Arguments
        ---------
        image : numpy.array
            A phase contrast image or an ndarray of images stacked along the
            first dimension.
        positions : int or list of ints, optional
            Subset of positions to use for reconstruction.
        
        Note
        ----
        Calls _algorithm (container purely for algorithm part).
        
        Returns
        -------
        phase : numpy.array
        attenuation : numpy.array
        """

        if len(image.shape) == 2: # If only one image is given, add 3rd dimention for compatibility with loops
            image = image[np.newaxis]

        if not positions:
            positions = list(range(self.ND))
        elif not isinstance(positions, list):
            positions = [ positions ]
            
        image = Utilities.resize(image, [self.nfy, self.nfx])
            
        phase, attenuation = self._algorithm(image, positions=positions)
        
        if not pad:
            phase = Utilities.resize(phase, [self.ny, self.nx])
            attenuation = Utilities.resize(attenuation, [self.ny, self.nx])
        
        return phase, attenuation


In [73]:
class TIEHOM:
    """
    Transport of Intensity Equation for homogeneous objects (or "Paganin's algorithm") [1]
    
    Parameters
    ----------
    delta_beta : float, optional
        Material dependent ratio delta over beta.
    
    References
    ----------
    [1] Paganin et al. J. Microsc. 206 (2002) 33
    """
        
    def __init__(self, image=None, delta_beta=500, **kwargs):    
        self.lengthscale=10e-6
        self._delta_beta=delta_beta
        self.padding = 2 
        self.nx=image.shape[0]
        self.ny=image.shape[1]
        self.pixel_size = 1e-6
        self.ND = len(image.shape)
        self.distance = np.array(1) 
        self.nfx = self.padding*self.nx 
        self.nfy = self.padding*self.ny  
        self.energy = 500 #KeV 
        self.Lambda = 12.4e-10 / self.energy
        self.Fresnel_number = self.Lambda * self.distance / (self.lengthscale**2) 
  
    def _compute_factors(self):
        """Calculate TIEHOM factors. Overrides PhaseRetrievalAlgorithm2D."""
        self.TIEHOM_factor = [0 for xxx in range(self.ND)] 
        for distance in range(self.ND):
            self.TIEHOM_factor[distance] = 1 + self.Fresnel_number[distance] * np.pi * self.delta_beta * ((self.fx ** 2) + (self.fy ** 2))

    def _algorithm(self, image, positions=None):
        """Reconstruct one image or a set of images using TIEHOM."""
        #TODO: Needs verification on simpler images

        FID=np.fft.fft2(image)

        numerator_TIEHOM = np.zeros((self.nfy, self.nfx))
        denominator_TIEHOM = numerator_TIEHOM.copy()  
        
        TIEHOM_factor = self._compute_factors() 
        
        if len(positions) == 1:
            phase = 1/2 * self.delta_beta * np.log(np.real(np.fft.ifft2(FID[positions[0]] / TIEHOM_factor[positions[0]])))
        else:  
            for position in positions:
                numerator_TIEHOM = numerator_TIEHOM + (self.TIEHOM_factor[position] * FID[position]) 
                denominator_TIEHOM = denominator_TIEHOM + (self.TIEHOM_factor[position])**2 
            phase = 1/2 * self.delta_beta * np.log(np.real(np.fft.ifft2(numerator_TIEHOM / denominator_TIEHOM)))
        
        attenuation = -1/(self.delta_beta) * phase

        return phase, attenuation


In [74]:
TIEHOM(image=x_data[0])._algorithm(image=x_data[0], positions=[1])

AttributeError: 'TIEHOM' object has no attribute 'lengthscale'