# In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import math

# -----------------------------------------------------------------------------#
# Load images
# -----------------------------------------------------------------------------#
# Spinal-cord ROI first, then the single-slice T1w image that drives the
# simulation.
sc_mask = nib.load('/spinal_cord_seg.nii')
sim_img = nib.load('/GRE-T1w.nii')

sim_img_data = sim_img.get_fdata()
sc_mask_data = sc_mask.get_fdata()
# Zero out every ROI value below 1, in place; values >= 1 (and NaNs) are kept.
np.putmask(sc_mask_data, sc_mask_data < 1, 0)

# -----------------------------------------------------------------------------#
# Ideal k-space
# Centred 2-D DFT of the undistorted image: shift, transform, shift again.
sim_FFT = np.fft.fft2(np.fft.fftshift(sim_img_data))
sim_FFT = np.fft.fftshift(sim_FFT)

# -----------------------------------------------------------------------------#
# Image acquisition parameters (adjusted for GRE-EPI)
# -----------------------------------------------------------------------------#
matrix = sim_img.shape
image_res = sim_img.header.get_zooms()  # Voxel size in mm
fov = np.array(image_res) * np.array(matrix)  # Field of view in mm

TE = 500e-3  # [s], longer TE for EPI
TR = 1000e-3  # [s]
readout_time = 300e-3  # [s], total EPI readout duration
# Centre the readout window on TE.
readout_start = TE - (readout_time / 2)
# One kx line is acquired per phase-encode step, so the readout window holds
# matrix[1] kx lines (not matrix[0]; the two differ for non-square matrices).
delta_t = readout_time / matrix[1]  # Time per kx line

# -----------------------------------------------------------------------------#
# Define constants for time-varying magnetic field
# -----------------------------------------------------------------------------#
resp_period = 3  # Respiratory period [s]
w_r = 2 * math.pi / resp_period  # Respiratory angular frequency [rad/s]
RIROmax_uniform = 12  # Peak RIRO frequency offset [Hz]

# -----------------------------------------------------------------------------#
# Define spatial distribution for RIROmax (through-plane)
# -----------------------------------------------------------------------------#
# Voxel-index grid centred on the image middle (axis 0 <-> matrix[0],
# axis 1 <-> matrix[1]).
[x, y] = np.meshgrid(
    np.linspace(-(matrix[0]-1)/2, (matrix[0]-1)/2, matrix[0]),
    np.linspace(-(matrix[1]-1)/2, (matrix[1]-1)/2, matrix[1]),
    indexing='ij'
)
r = np.sqrt((x * image_res[0])**2 + (y * image_res[1])**2)  # Radial position [mm]
# Weight that is 1 at the centre and 0 at the farthest voxel, sharpened by ^4.
r = abs((r - np.max(r)) / np.max(r)) ** 4  # Normalized radial decay
sim_RIROmax = RIROmax_uniform * r  # Through-plane RIRO [Hz]

# Limit RIROmax to signal regions.
# Noise level is estimated from four noise_patch x noise_patch squares in the
# image corners. Negative slices ([-noise_patch:]) reach the last row/column,
# mirroring the [:noise_patch] patches at the near edges.
noise_patch = 4
noise_mask = np.zeros(matrix)
noise_mask[:noise_patch, :noise_patch] = 1
noise_mask[:noise_patch, -noise_patch:] = 1
noise_mask[-noise_patch:, :noise_patch] = 1
noise_mask[-noise_patch:, -noise_patch:] = 1
noise_data = np.multiply(sim_img_data, noise_mask)
# Index by the mask itself rather than by "value != 0": zero-valued corner
# voxels are real samples, and if every corner voxel were zero the old
# filter produced an empty array, a NaN sigma and an all-false mask.
sigma = sim_img_data[noise_mask.astype(bool)].std()
bkgrnd_mask = (sim_img_data > 15 * sigma).astype(float)
sim_RIROmax = np.multiply(bkgrnd_mask, sim_RIROmax)

# -----------------------------------------------------------------------------#
# Define k-space constants for EPI trajectory
# -----------------------------------------------------------------------------#
k_max = 1 / (2 * np.array(image_res))  # Max spatial frequencies [mm^-1]
delta_k = 1 / fov  # k-space step size [mm^-1]

# k-space coordinates on the same grid as the fftshift-ordered DFT:
# step delta_k, running from -k_max up to k_max - delta_k (even N).
# np.linspace(-k_max, k_max, N) would use a step of 2*k_max/(N-1) != delta_k.
kx = np.fft.fftshift(np.fft.fftfreq(matrix[0], d=image_res[0]))  # Frequency encode
ky = np.fft.fftshift(np.fft.fftfreq(matrix[1], d=image_res[1]))  # Phase encode

# Acquisition time of each trajectory sample: samples are spread evenly
# across the readout window [readout_start, readout_start + readout_time).
n_samples = matrix[0] * matrix[1]
time = readout_start + np.arange(n_samples) * (readout_time / n_samples)

# EPI trajectory: zigzag pattern. Row s of k_traj is [kx, ky] for sample s.
# Even phase-encode lines traverse kx forward, odd lines in reverse.
reverse_line = (np.arange(matrix[1]) % 2 == 1)[:, np.newaxis]
kx_per_line = np.where(reverse_line, kx[::-1], kx)  # (matrix[1], matrix[0])
k_traj = np.column_stack((kx_per_line.ravel(), np.repeat(ky, matrix[0])))

# -----------------------------------------------------------------------------#
# Apply RIRO Distortion in k-space
# -----------------------------------------------------------------------------#
# Each k-space sample is taken from the FFT of the image weighted by the RIRO
# phase accumulated up to that sample's acquisition time.
k_space_distorted = np.zeros_like(sim_FFT, dtype=complex)

# EPI timing: the matrix[1] phase-encode lines share the readout window
# equally, and each line's matrix[0] samples share the line duration equally.
# Lines and samples need different time steps. A single step for both would
# place samples (i, j) and (i + 1, j - 1) at the same instant and run roughly
# twice past the end of the readout window.
line_duration = readout_time / matrix[1]
sample_dwell = line_duration / matrix[0]

# Phase accumulated from t = 0 under a frequency offset RIROmax * sin(w_r t):
#   phi(t) = -2*pi * RIROmax * (1 - cos(w_r t)) / w_r
# The spatial prefactor does not depend on t, so compute it once.
phase_scale = -2 * np.pi * sim_RIROmax / w_r

for i in range(matrix[1]):  # Loop over phase-encode lines
    # Zigzag: even lines traverse kx forward, odd lines in reverse.
    if i % 2 == 0:
        m_indices = np.arange(matrix[0])  # Forward: 0 to matrix[0]-1
    else:
        m_indices = np.arange(matrix[0] - 1, -1, -1)  # Reverse: matrix[0]-1 to 0

    for j in range(matrix[0]):  # j = acquisition order within the line
        t_acq = readout_start + i * line_duration + j * sample_dwell

        # Cumulative phase shift from start of sequence to this readout point
        phase_RIRO = phase_scale * (1 - np.cos(w_r * t_acq))

        # NOTE(review): a full 2-D FFT is computed per sample but only one
        # value is kept; this dominates runtime for large matrices.
        distorted_img = sim_img_data * np.exp(1j * phase_RIRO)
        k_space_temp = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(distorted_img)))

        # Store the k-space data in the correct EPI zigzag order
        m = m_indices[j]  # kx index sampled at this instant
        k_space_distorted[m, i] = k_space_temp[m, i]

# -----------------------------------------------------------------------------#
# Reconstruct Image
# -----------------------------------------------------------------------------#
# Undo the centred forward transform: inverse-shift, inverse FFT,
# inverse-shift again, then keep the magnitude.
reconstructed_img = np.fft.ifft2(np.fft.ifftshift(k_space_distorted))
reconstructed_img = np.abs(np.fft.ifftshift(reconstructed_img))

# -----------------------------------------------------------------------------#
# Visualization
# Three side-by-side panels (RIROmax map, ideal image, simulated EPI); a 1x3
# grid avoids leaving an empty fourth slot. Images are rotated 90 degrees
# with np.rot90 for display only.
fig = plt.figure(figsize=(25, 15))

ax1 = fig.add_subplot(1, 3, 1)
im1 = ax1.imshow(np.rot90(sim_RIROmax), cmap='jet', vmin=0, vmax=15)
ax1.set_title('RIROmax [Hz]')
plt.setp(ax1, xticks=[], yticks=[])
plt.colorbar(im1, fraction=0.046, pad=0.05)

ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(np.rot90(sim_img_data), cmap='gray', vmin=0, vmax=1200)
ax2.set_title('Ideal Image')
plt.setp(ax2, xticks=[], yticks=[])

ax3 = fig.add_subplot(1, 3, 3)
ax3.imshow(np.rot90(reconstructed_img), cmap='gray', vmin=0, vmax=1200)
ax3.set_title('Simulated EPI with RIRO')
plt.setp(ax3, xticks=[], yticks=[])

plt.show()
