In [None]:
import   os
import   sys
sys.path.insert(0, os.path.abspath('.'))
from     ismrmrdUtils   import   rawMRutils
import   numpy          as       np
import   ismrmrd

## First, read in the (fully sampled) EPI data acquired with an array coil

In [None]:
# Load the converted scan archive; returnHeaderAndData yields the ISMRMRD header, the imaging data
# and the EPI reference data (names as used throughout this notebook).
epiRawDataFile = './ScanArchive_EPI_converted.h5'
dataHeader, dataArray, refDataArray = rawMRutils.returnHeaderAndData(epiRawDataFile)

In [None]:
np.shape(dataArray)   # indexed later as [coil, readout sample, line, slice, repetition]

In [None]:
np.shape(refDataArray)   # indexed later as [coil, readout sample, reference line, slice, repetition]

In [None]:
# Implement ramp-sampling regridding intended to mirror the "computeTrajectory()" method in:
# gadgetron_sources/toolboxes/mri/epi/EPIReconXObjectTrapezoid.h - mostly to allow comparisons to be made.

# The basic idea is to take the acquired data, sampled on the ramps and flat top of the acquisition, i.e.
#
#                               Data Acquisition Window
#                           <------------------------------>
#                           .                              .
#                           .     --------------------     .
#                           .    /                    \    .
#                           .   /                      \   .
#                           .  /                        \  .   
#                           . /                          \ .
#                           ./                            \.
#                           /                              \
#                          /                                \
#                         /                                  \
#                        /                                    \
#                       /                                      \
#
# where, because of the ramps, the data is now non-uniformly sampled in k-space.  Since position / location in
# k-space is determined by the time integral of the gradient waveform up to that point, we can take the total
# area under the ramp-sampled acquisition and divide that area evenly to determine the location of k-space
# points after regridding / interpolation.

# First, extract ramp parameters from the header.  Start from None so that a missing parameter is reported
# below instead of silently reusing a value left in the kernel by an earlier run.
acqDelayTime_ = rampUpTime_ = flatTopTime_ = rampDownTime_ = numSamples_ = dwellTime_ = None

traj = dataHeader.encoding[0].trajectory
if (traj == 'epi'):
   trajID = dataHeader.encoding[0].trajectoryDescription.identifier
   if (trajID == 'ConventionalEPI'):
      print ("Trajectory for this dataset is %s, and trajectory ID is %s" % (traj, trajID))

      trajDesc = dataHeader.encoding[0].trajectoryDescription

      # Each user parameter's orderedContent() holds its name followed by its value; build name -> value
      # lookups.  If a name is repeated, the last occurrence wins.
      longParams   = {p.orderedContent()[0].value: p.orderedContent()[1].value
                      for p in trajDesc.userParameterLong[:]}
      doubleParams = {p.orderedContent()[0].value: p.orderedContent()[1].value
                      for p in trajDesc.userParameterDouble[:]}

      # 'long' parameters needed for EPI.
      acqDelayTime_ = longParams.get('acqDelayTime')
      rampUpTime_   = longParams.get('rampUpTime')
      flatTopTime_  = longParams.get('flatTopTime')
      rampDownTime_ = longParams.get('rampDownTime')
      numSamples_   = longParams.get('numSamples')

      # 'double' parameters needed for EPI.
      dwellTime_    = doubleParams.get('dwellTime')

# The ramp-sampling computation that follows cannot run without these; report exactly what is missing.
missingEpiParams = [name for name, value in (('rampUpTime',   rampUpTime_),
                                             ('flatTopTime',  flatTopTime_),
                                             ('rampDownTime', rampDownTime_),
                                             ('dwellTime',    dwellTime_))
                    if value is None]
if missingEpiParams:
   raise ValueError("EPI trajectory parameter(s) %s not found in header (trajectory = %s); these are only "
                    "read for 'epi' / 'ConventionalEPI' datasets." % (missingEpiParams, traj))

# Code for ramp-sampling from @jad11nih.
numReadoutAcqPoints   = dataArray.shape[1]
numReadoutReconPoints = dataHeader.encoding[0].reconSpace.matrixSize.x

totalGradientTime     = rampUpTime_ + flatTopTime_ + rampDownTime_
totalReadoutTime      = numReadoutAcqPoints * dwellTime_

# Gradient ramp-up and flat-top durations expressed in dwell-time samples (truncated to whole samples).
numGradientRampPoints = int(rampUpTime_  / dwellTime_)
numGradientFlatPoints = int(flatTopTime_ / dwellTime_)

# Acquired samples on the flat top cannot exceed the number of samples acquired.
numAcqFlatPoints    = min(numGradientFlatPoints, numReadoutAcqPoints)

# Acquired samples on each ramp; the down-ramp is built as a mirror of the up-ramp below.
numAcqRampPoints    = int ((numReadoutAcqPoints - numAcqFlatPoints) / 2)

doRampSampledRecon  = 0

if (numAcqRampPoints > 0):    # i.e. Data is being acquired on EPI ramps, we have ramp-sampled data!
   doRampSampledRecon     = 1 # Use this as flag for recon later on.

   # Build the gradient waveform at the acquisition points, in units where the flat-top amplitude
   # equals numGradientRampPoints: ramp up, flat top, then the mirrored ramp down.
   gx = np.zeros(numReadoutAcqPoints, dtype=np.float64)

   # Ramp samples that precede the start of acquisition.
   # NOTE(review): this goes negative if more samples are acquired on the ramp than the ramp holds;
   # confirm such data cannot occur.
   numGradientSkipPoints  = numGradientRampPoints - numAcqRampPoints
   # NOTE(review): linspace includes its endpoint, so the last ramp sample equals the flat-top value
   # rather than numGradientRampPoints - 0.5 (ramp spacing slightly above one unit); confirm intended.
   gx[:numAcqRampPoints]  = np.linspace(numGradientSkipPoints + 0.5, numGradientRampPoints, numAcqRampPoints)
   gx[-numAcqRampPoints:] = np.flip(gx[:numAcqRampPoints])

   # Everything between the two ramps is flat top.  Filling up to the start of the down-ramp (rather
   # than exactly numAcqFlatPoints samples) also covers the sample left over when
   # numReadoutAcqPoints - numAcqFlatPoints is odd, which would otherwise stay at zero gradient.
   gx[numAcqRampPoints:numReadoutAcqPoints-numAcqRampPoints] = numGradientRampPoints

   # Now, compute k-space along this gradient trajectory (cumulative gradient area).
   kx    = np.cumsum(gx)

   # Rescale kx so the total area maps to numReadoutReconPoints, then shift it to end at +N/2.
   scale = numReadoutReconPoints / kx[-1]
   kx    = kx * scale - numReadoutReconPoints/2

   # Create 'physical' pixel locations (endpoint excluded, so unit steps starting at -N/2).
   x     = np.linspace (-numReadoutReconPoints/2, numReadoutReconPoints/2, numReadoutReconPoints, False)

   # Now create the encoding matrix E[k, x] = exp(2*pi*i * kx * x / N).  Its pseudo-inverse maps one
   # acquired (ramp-sampled) readout to numReadoutReconPoints image-space samples.
   E     = np.exp(2 * np.pi * 1.0j * np.outer(kx, x)/numReadoutReconPoints)
   pinvE = np.linalg.pinv(E)

In [None]:
# Small complex-valued test list: 0+1j, 2+3j, 4+5j, 6+7j, 8+9j.
testCplxArray = [complex(2 * k, 2 * k + 1) for k in range(5)]

In [None]:
import matplotlib
import matplotlib.pyplot as mplt

In [None]:

def computeN2GhostCorrection (angleData):
   """Fit a magnitude-weighted straight line to the phase of ``angleData``.

   Parameters
   ----------
   angleData : 1-D complex array
      Product of averaged odd / even reference lines whose phase is to be modelled.

   Returns
   -------
   ndarray of complex
      ``exp(1j * (c0 + c1*x + pi))`` where ``x`` runs linearly over [-0.5, 0.5] across the
      samples and (c0, c1) are the weighted least-squares intercept and slope of the wrapped
      phase.  A constant pi is added to the fitted phase before exponentiating.
   """
   position = np.linspace(-0.5, 0.5, len(angleData))
   weights  = np.abs(angleData)

   # Scaling both the design matrix and the target by |angleData| weights the fit towards
   # high-signal samples, where the phase estimate is reliable.
   design   = np.column_stack((weights, weights * position))
   target   = weights * np.angle(angleData)

   (offset, slope), *_ = np.linalg.lstsq(design, target, rcond=None)

   return np.exp(1j * (offset + slope * position + np.pi))



def computeN2GhostCorrectionAhnMethod (angleData):
   """Estimate an N/2-ghost phase correction with the Ahn & Cho autocorrelation method.

   Based on the algorithm in IEEE Trans on Med. Imag., Vol MI-6, No. 1, Pg. 32 - 36 (1987).
   Still start with the averaged difference between the odd and even lines.  Its phase is
   modelled as phi(x) = phi0 + phi1 * x, with x in unit sample steps starting at -nPoints/2:

   * phi1 (first order) is the angle of the lag-1 autocorrelation sum(d[i+1] * conj(d[i])),
     which needs no phase unwrapping;
   * phi0 (zero order) is the angle of the sum of the data after the linear term is removed,
     which is likewise insensitive to 2*pi wraps.

   Parameters
   ----------
   angleData : 1-D complex array-like
      Odd/even reference-line product whose phase is to be removed.

   Returns
   -------
   ndarray of complex
      ``exp(-1j * (phi0 + phi1 * x))`` - the factor that cancels the fitted phase, so
      ``angleData * result`` is (ideally) real and non-negative.
   """
   angleData = np.asarray(angleData)
   nPoints   = len(angleData)
   # Unit spacing so the lag-1 autocorrelation angle is directly the per-sample slope.
   x         = np.linspace(-nPoints/2, nPoints/2, nPoints, False)

   if (nPoints > 1):
      # Conjugate the *earlier* sample so the angle is +phi1.  Conjugating the later sample gives
      # -phi1, which doubles the linear term instead of removing it.
      lag1Corr    = np.sum(angleData[1:] * np.conj(angleData[:-1]))
      linearSlope = np.angle(lag1Corr)
   else:
      # A single sample carries no slope information.
      linearSlope = 0.0

   linearPhaseFit = np.exp(-1.0j * linearSlope * x)

   # Remove linear phase trend to compute constant phase offset; the angle of the (magnitude-
   # weighted) sum is robust to wraps, unlike the mean of wrapped angles.
   linearRemoved  = angleData * linearPhaseFit
   constantPhase  = np.exp(-1.0j * np.angle(np.sum(linearRemoved)))

   return (constantPhase * linearPhaseFit)



def loadNonComplexToComplex (nonComplexData):
   """Pack interleaved (real, imag, real, imag, ...) samples into complex values.

   Even-indexed entries become the real parts and odd-indexed entries the imaginary parts.
   ``nonComplexData`` must support numpy-style slicing and scalar arithmetic (e.g. an ndarray).
   """
   realPart = nonComplexData[::2]
   imagPart = nonComplexData[1::2]
   return realPart + 1j * imagPart



def oneDimTransform (data2Transform):
   """Centred 1-D DFT: ``fftshift(fft(ifftshift(x)))``.

   ``ifftshift`` moves the centre sample (index N//2) to index 0 before the FFT, and
   ``fftshift`` afterwards puts zero frequency back in the centre.  For even N this equals
   ``fftshift(fft(fftshift(x)))``; for odd N, shifting with ``fftshift`` first would leave a
   residual linear phase ramp on the result.
   """
   return np.fft.fftshift(np.fft.fft(np.fft.ifftshift(data2Transform)))


In [None]:
nCoils   = refDataArray.shape[0]
nNavs    = refDataArray.shape[2]
nPhases  =    dataArray.shape[2]
nSlices  = refDataArray.shape[3]
nReps    = refDataArray.shape[4]

# Transform the EPI reference lines along the readout (axis 1) for every coil / line / slice / rep.
if (doRampSampledRecon == 1):
   # Ramp-sampling correction already does the equivalent of a DFT to image space.  One einsum applies
   # pinvE to every readout: corrData[c, i, n, s, r] = sum_j pinvE[i, j] * refDataArray[c, j, n, s, r].
   corrData = np.einsum('ij,cjnsr->cinsr', pinvE, refDataArray, optimize=True).astype(np.complex64)
else:
   # Transform and store reference lines to compute phase corrections.
   corrData = np.fft.fftshift(np.fft.fft(np.fft.fftshift(refDataArray, axes=[1]), axis=1),
                              axes=[1]).astype(np.complex64)

# Check vendor as each vendor handles packing EPI reference data a little differently.  This does not
# depend on coil / slice / rep, so decide it once rather than inside the loops.
if (dataHeader.acquisitionSystemInformation.systemVendor == 'GE MEDICAL SYSTEMS'):
   # GE has 4, but first line (index value == 0) doesn't contain useful EPI reference data.
   epiRefStart = 1
else:
   # Non-GE (e.g. Siemens) data has 3 EPI reference lines, all holding relevant data.
   epiRefStart = 0

for s in range (nSlices):
   for r in range (nReps):
      for c in range (nCoils):
         # Averaged lines epiRefStart, epiRefStart+2, ... times the conjugate of the averaged interleaved
         # lines; the phase of this product is the odd/even phase difference to be fitted.
         angleData = ((np.average(corrData[c, :,  epiRefStart::2,    s, r], axis=1)) *
               np.conj(np.average(corrData[c, :, (epiRefStart+1)::2, s, r], axis=1)))

         # Store computed odd - even phase difference fit back in line '0' of the reference data array.
         # Safe to overwrite: angleData for this coil / slice / rep has already been computed.

         # Use 'raw' phase difference between odd/even lines directly, normalized by magnitude (suggested by @jad11).
         # corrData[c, :, 0, s, r] = angleData / abs(angleData)

         # Or call polynomial phase fitting routine.
         corrData[c, :, 0, s, r] = computeN2GhostCorrection(angleData)

         # Use Cho/Ahn method, which should be more robust to phase wraps - linear term currently not correct.
         # corrData[c, :, 0, s, r] = computeN2GhostCorrectionAhnMethod(angleData)

In [None]:
# Debugging cell - checking contents of reference lines and results of computations to fit phase.
# After the previous cell, line 0 of corrData holds the fitted correction for each coil / slice / rep;
# the other lines hold the readout-transformed reference lines.

# Coil / slice / repetition to inspect (must lie within corrData's coil, slice and rep ranges).
DEBUG_COIL  = 4
DEBUG_SLICE = 1
DEBUG_REP   = 3

fig, ax     = mplt.subplots(1, 1)

item2Plot   = corrData[DEBUG_COIL, :, :, DEBUG_SLICE, DEBUG_REP]

# One labelled curve per reference line, so the legend distinguishes them.
for navLine in range (item2Plot.shape[1]):
   ax.plot(np.angle(item2Plot[:, navLine]), label='phase, line %d' % navLine)

ax.set_xlabel('readout pixel')
ax.set_ylabel('phase (rad)')
ax.set_title('corrData phase: coil %d, slice %d, rep %d' % (DEBUG_COIL, DEBUG_SLICE, DEBUG_REP))
ax.legend();

In [None]:
# Apply phase corrections.

# Transform the imaging data along the readout (axis 1).
if (doRampSampledRecon == 1):
   # pinvE regrids the ramp-sampled readout and transforms it to image space in one step; one einsum
   # applies it to every (coil, line, slice, rep) readout.
   xFormedInReadout = np.einsum('ij,cjnsr->cinsr', pinvE, dataArray, optimize=True).astype(np.complex64)
else:
   xFormedInReadout = np.fft.fftshift(np.fft.fft(np.fft.fftshift(dataArray, axes=[1]), axis=1), axes=[1])

# Start from a complex64 copy, so lines NOT corrected below (1, 3, 5, ...) are carried through unchanged.
correctedArray = xFormedInReadout.astype(np.complex64)

# Apply phase correction to alternating lines (0, 2, 4, ...): line 0 of corrData holds the per coil /
# slice / rep correction and broadcasts along the line axis.
correctedArray[:, :, 0::2, :, :] = xFormedInReadout[:, :, 0::2, :, :] * corrData[:, :, 0:1, :, :]

In [None]:
# Transform along the line (phase-encode) axis, axis 2, to finish the 2-D reconstruction.
imageSpace = np.fft.fftshift(np.fft.fft(np.fft.fftshift(correctedArray, axes=[2]), axis=2), axes=[2])

# Slice / repetition to display (must lie within the slice and repetition ranges of the data).
DISPLAY_SLICE = 1
DISPLAY_REP   = 3

plottedFigures = mplt.figure(figsize=(12,18))
subImages      = plottedFigures.add_subplot(1, 1, 1)

# Root-sum-of-squares combination over coils (axis 0) of the magnitude images.
disp = np.abs(imageSpace[:, :, :, DISPLAY_SLICE, DISPLAY_REP])
subImages.imshow(np.sqrt(np.sum(disp*disp, axis=0)), cmap='viridis')
subImages.set_title('Root-sum-of-squares image: slice %d, repetition %d' % (DISPLAY_SLICE, DISPLAY_REP));