In [None]:
import math
import json
import nibabel as nib
import numpy as np
from numpy.fft import ifftn, fftn, ifft, fftshift, ifftshift
import os
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.signal import butter, lfilter, freqz, filtfilt
from scipy.io import loadmat
import warnings
# Reduced gyromagnetic ratio (gamma / 2pi) of the hydrogen nucleus.
GYRO_BAR_RATIO_H = 42.6e6  # [Hz/T]

# Unicode glyphs used in plot titles, colorbar titles and tick labels.
PI_UNICODE = "\U0001D70B"     # mathematical italic small pi
CHI_UNICODE = "\U0001D712"    # mathematical italic small chi
MICRO_UNICODE = "\u00B5"      # micro sign

In [None]:
# Example field-map acquisition (BIDS-style file names).
fname_mag_e1 = os.path.join("fmap", "sub-fmap_magnitude1.nii.gz")
fname_mag_e1_json = os.path.join("fmap", "sub-fmap_magnitude1.json")
fname_phase_e1 = os.path.join("fmap", "sub-fmap_phase1.nii.gz")
fname_phase_e2 = os.path.join("fmap", "sub-fmap_phase2.nii.gz")
fname_mask = os.path.join("fmap", "mask.nii.gz")  # not loaded in this cell
fname_fmap = os.path.join("fmap", "fmap.nii.gz")

# Crop every volume to the same 2-D region: indices 30:-30 and 8:108 on the
# first two axes, index 30 on the third. Use
# (slice(None), slice(None), slice(None)) instead to work on full volumes.
FMAP_CROP = (slice(30, -30), slice(8, 108), 30)
mag_e1 = nib.load(fname_mag_e1).get_fdata()[FMAP_CROP]
phase_e1 = nib.load(fname_phase_e1).get_fdata()[FMAP_CROP]  # loaded but not displayed
phase_e2 = nib.load(fname_phase_e2).get_fdata()[FMAP_CROP]
fmap_hz = nib.load(fname_fmap).get_fdata()[FMAP_CROP]

with open(fname_mag_e1_json, 'r') as json_data:
    data = json.load(json_data)

# NOTE(review): ImagingFrequency is presumably in MHz (DICOM convention), which
# is what makes fmap_hz / freq come out in ppm -- confirm for this sidecar.
freq = data["ImagingFrequency"]
fmap_t = fmap_hz / GYRO_BAR_RATIO_H * 1e6  # Hz / (Hz/T) = T; x1e6 -> microtesla
fmap_ppm = fmap_hz / freq

# NOTE(review): assumes the raw phase is stored as integers in [0, 4095];
# this maps it linearly onto [-pi, pi].
phase_e2_rad = phase_e2 / 4095 * 2 * math.pi - math.pi


def _gray_heatmap(image, unit, visible=True, **colorbar_extra):
    """Build a grayscale heatmap of ``np.rot90(image, k=-1)`` with a top-titled colorbar.

    Args:
        image: 2-D array to display.
        unit: Text shown as the colorbar title.
        visible: Initial visibility of the trace.
        **colorbar_extra: Extra colorbar properties (e.g. tick values/text).

    The title is passed as ``dict(text=..., side=...)`` because the older
    ``titleside`` shorthand is deprecated and was removed in Plotly 6.
    """
    return go.Heatmap(
        z=np.rot90(image, k=-1),
        colorscale='gray',
        colorbar=dict(title=dict(text=unit, side="top"), **colorbar_extra),
        visible=visible,
    )


# One trace per view; only the magnitude is visible initially.
fig = go.Figure()
fig.add_trace(_gray_heatmap(mag_e1, "a.u."))
fig.add_trace(_gray_heatmap(phase_e2_rad, "Rad", visible=False,
                            tickmode="array",
                            # presumably pi - 0.01 keeps the top tick inside the data range
                            tickvals=[-math.pi, 0, math.pi - 0.01],
                            ticktext=[f"-{PI_UNICODE}", 0, f'{PI_UNICODE}']))
fig.add_trace(_gray_heatmap(fmap_hz, "Hz", visible=False))
fig.add_trace(_gray_heatmap(fmap_t, f"{MICRO_UNICODE}T", visible=False))
fig.add_trace(_gray_heatmap(fmap_ppm, "ppm", visible=False))

# Rounded-square badge (corner radius h) in data coordinates; with 100 rows
# after rotation it should sit in the top-left corner of the image.
x0, y0, x1, y1 = 0, 89, 10, 99
h = 2
path = (
    f' M {x0+h}, {y0} Q {x0}, {y0} {x0}, {y0+h}'
    f' L {x0}, {y1-h} Q {x0}, {y1} {x0+h}, {y1}'
    f' L {x1-h}, {y1} Q {x1}, {y1} {x1}, {y1-h}'
    f' L {x1}, {y0+h} Q {x1}, {y0} {x1-h}, {y0}Z'
)

# One labelled badge per view (A-E), all sharing the same outline.
annotations = ['A', 'B', 'C', 'D', 'E']
shapes = [
    dict(type='path',
         path=path,
         fillcolor='white',
         layer='above',
         line=dict(width=1),
         label=dict(text=f"<b>{annotation}</b>"))
    for annotation in annotations
]
fig.add_shape(shapes[0])

# Dropdown entries, in trace order: (button label, figure title).
fmap_views = [
    ("Magnitude", "Magnitude"),
    ("Phase", "Phase"),
    ("B0 field map (Hz)", "B0 Fieldmap (Hz)"),
    (f"B0 field map ({MICRO_UNICODE}Tesla)", f"B0 Fieldmap ({MICRO_UNICODE}Tesla)"),
    ("B0 field map (ppm)", "B0 Fieldmap (ppm)"),
]
fmap_buttons = [
    dict(
        method="update",
        # Show only trace i and its badge. "title.text" (rather than "title")
        # keeps the centred title_x and still works in Plotly 6, which no
        # longer accepts a plain-string title.
        args=[{"visible": [j == i for j in range(len(fmap_views))]},
              {"shapes": [shapes[i]], "title.text": title}],
        label=label,
    )
    for i, (label, title) in enumerate(fmap_views)
]

fig.update_layout(
    title_text="Magnitude",
    title_x=0.5,
    height=500,
    width=600,
    updatemenus=[dict(buttons=fmap_buttons, direction="down", showactive=True)],
)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.show()

In [None]:
def dipole_kernel(b0_dir, voxel_size, n_voxels):
    """Create a dipole kernel in image space together with its Fourier transform.

    The image-space kernel is

        d(r) = (3*cos(theta)**2 - 1) / (4*pi*|r|**3)
             = (3*(r . b0_dir)**2 - |r|**2) / (4*pi*(|r|**2)**2.5),

    where theta is the angle between r and ``b0_dir``. The result is multiplied
    by the voxel volume ``prod(voxel_size)``.

    Function inspired and derived from: https://onlinelibrary.wiley.com/doi/10.1002/mrm.28716

    Args:
        b0_dir: 3-vector giving the B0 direction. It is assumed to have unit
            length and is not normalised here.
        voxel_size: 3-sequence of voxel dimensions along (x, y, z).
        n_voxels: 3-sequence of grid sizes along (x, y, z).

    Returns:
        (d_slice, D_slice): the central slice along the first (x) axis of the
        image-space kernel ``d`` and of its real part in k-space ``D``
        (centred FFT). Both have shape ``(n_voxels[1], n_voxels[2])``.
    """
    # Small offset added to every coordinate so that r is never exactly zero
    # (avoids 0/0 at the grid centre).
    eps = 0.00001

    # Integer voxel offsets -n//2 .. (n-1)//2, so index n//2 is the origin as
    # fftshift/ifftshift expect. np.arange avoids round()'s banker's rounding,
    # which mis-centred even sizes such as 6 or 198.
    axes = [np.arange(n) - n // 2 for n in n_voxels]
    x, y, z = np.meshgrid(*axes, indexing='ij')

    x = x * voxel_size[0] + eps
    y = y * voxel_size[1] + eps
    z = z * voxel_size[2] + eps

    r2 = x**2 + y**2 + z**2
    # r . b0_dir, i.e. |r| * cos(theta) for a unit b0_dir.
    r_dot_b0 = x * b0_dir[0] + y * b0_dir[1] + z * b0_dir[2]

    d = np.prod(voxel_size) * (3 * (r_dot_b0**2) - r2) / (4 * math.pi * r2**2.5)

    # Defensive: replace any NaN (e.g. from 0/0) with a small finite value.
    d[np.isnan(d)] = eps
    D = np.real(fftshift(fftn(ifftshift(d))))

    # Central slice along the first (x) axis.
    mid_voxel = n_voxels[0] // 2
    return d[mid_voxel], D[mid_voxel]

# Example kernel: B0 along z, voxel size 1e-3 on every axis (1 mm if lengths
# are in metres), on a 201^3 grid.
b0_dir = (0, 0, 1)
voxel_size = np.array((1, 1, 1)) * 1e-3
n_voxels = (201, 201, 201)
d, D = dipole_kernel(b0_dir, voxel_size, n_voxels)

# Side-by-side view of the central slice in image space (d) and k-space (D).
fig = make_subplots(
    rows=1,
    cols=2,
    horizontal_spacing=0.13,
    subplot_titles=("Dipole Kernel (d)", "Dipole Kernel (D)"),
    specs=[[{"type": "heatmap"}, {"type": "heatmap"}]],
)
# d is clipped to +/-1e-6 so that the lobes away from the centre stay visible.
fig.add_trace(go.Heatmap(z=d, colorscale='gray', showscale=False, zmin=-1e-6, zmax=1e-6),
              row=1, col=1)
fig.add_trace(go.Heatmap(z=D, colorscale='gray', showscale=False), row=1, col=2)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(height=500, width=900)
fig.show()

In [None]:
def _load_simulation(*parts):
    """Load ``field_simulations/<parts...>`` with nibabel and return its data as floats."""
    return nib.load(os.path.join("field_simulations", *parts)).get_fdata()


def _sim_heatmap(image, unit, colorbar_x, visible=True, zmin=None, zmax=None):
    """Build a grayscale heatmap whose colorbar sits at `colorbar_x`, titled `unit` on top.

    The title is passed as ``dict(text=..., side=...)`` because the older
    ``titleside`` shorthand is deprecated and was removed in Plotly 6.
    ``zmin``/``zmax`` of None leave the colour range automatic.
    """
    return go.Heatmap(z=image, colorscale='gray', zmin=zmin, zmax=zmax, visible=visible,
                      colorbar=dict(x=colorbar_x, title=dict(text=unit, side="top")))


# Load cylinder (Y 64)  -- NOTE(review): presumably the slice extracted from the 3-D simulation; confirm
susc = _load_simulation("cylinder", "Chi.nii.gz")
fmap_hz_all = _load_simulation("cylinder", "fmap_hz.nii.gz")
local_field_cyl = _load_simulation("cylinder", "local_field.nii.gz")

# Load brain (Z 210)  -- NOTE(review): presumably the slice extracted from the 3-D simulation; confirm
susc_brain = _load_simulation("brain", "chi_masked.nii.gz")
fmap_hz_brain_all = _load_simulation("brain", "fmap_masked.nii.gz")
local_field_brain = _load_simulation("brain", "local_field.nii.gz")

fig = make_subplots(
    rows=1,
    cols=3,
    horizontal_spacing=0.13,
    subplot_titles=(f"Susceptibility distribution ({CHI_UNICODE})",
                    "Simulated B0 map",
                    "Simulated B0 map<br>no background field"),
    specs=[[{"type": "heatmap"} for _ in range(3)]],
)

# Colorbar x-positions for subplot columns 1-3, shared by both phantoms.
sim_cbar_x = (1/3 - 0.09, 2/3 - 0.045, 1 - 0.004)

# Cylinders: traces 0-2, visible initially.
fig.add_trace(_sim_heatmap(susc, "ppm", sim_cbar_x[0]), 1, 1)
fig.add_trace(_sim_heatmap(fmap_hz_all, "Hz", sim_cbar_x[1]), 1, 2)
fig.add_trace(_sim_heatmap(local_field_cyl, "Hz", sim_cbar_x[2]), 1, 3)
# Brain: traces 3-5, hidden until selected, with fixed display windows.
fig.add_trace(_sim_heatmap(np.rot90(susc_brain, k=-1), "ppm", sim_cbar_x[0],
                           visible=False, zmin=-0.5, zmax=0.5), 1, 1)
fig.add_trace(_sim_heatmap(np.rot90(fmap_hz_brain_all, k=-1), "Hz", sim_cbar_x[1],
                           visible=False, zmin=1100, zmax=2300), 1, 2)
fig.add_trace(_sim_heatmap(np.rot90(local_field_brain, k=-1), "Hz", sim_cbar_x[2],
                           visible=False, zmin=-4, zmax=4), 1, 3)

### Create buttons for drop down menu
# Traces are grouped three per phantom, in the order of `labels`.
labels = ["Cylinders", "Brain"]
buttons = [
    dict(label=label,
         method='update',
         # "title.text" (not "title") keeps the centred title_x and still
         # works in Plotly 6, which no longer accepts a plain-string title.
         args=[{'visible': [j // 3 == i for j in range(3 * len(labels))]},
               {'title.text': label}])
    for i, label in enumerate(labels)
]

updatemenus = [dict(active=0, x=-0.15, buttons=buttons, showactive=True)]

fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(height=380,
                  width=900,
                  title_text="B0 maps from susceptibility simulations",
                  title_x=0.5,
                  updatemenus=updatemenus,
                  showlegend=False)
fig.show()

In [None]:
# Note: the field is reduced drastically (b0 = 2 uT, i.e. a Larmor frequency
# of about 85 Hz) so that the individual oscillations of the sinusoid are visible.
# Note: the demodulated signal is multiplied by 2 after low-pass filtering.
# The simulation uses a single receive channel (a real sine instead of a
# complex exponential e^(-ix)). Mixing sin(wt) with sin(wt) gives
# (1 - cos(2wt)) / 2, so the low-passed product is half the envelope. With
# quadrature (x and y) detection, the full signal is recovered directly and no
# factor of 2 is needed.
# GYRO_BAR_RATIO_H [Hz/T] is defined with the other constants at the top of
# the notebook. It is not redefined here, so there is a single source of truth.
b0 = 0.000002  # [T]
T2 = 0.3  # [s]
y_0_cst = 100  # amplitude of the main isochromat [a.u.]
fs = 10000  # sampling rate [Hz]

f_larmor = b0 * GYRO_BAR_RATIO_H  # [Hz]
t = np.linspace(0, 1, fs + 1)  # 1 second sampled at fs

def butter_lowpass(cutoff, fs, order=5):
    """Design a digital Butterworth low-pass filter.

    `cutoff` is the -3 dB frequency and `fs` the sampling rate, in the same
    units. Returns the ``(b, a)`` transfer-function coefficients of a filter
    of the given `order`.
    """
    coefficients = butter(N=order, Wn=cutoff, btype='low', analog=False, fs=fs)
    return coefficients

def butter_lowpass_filter(data, cutoff, fs, order=5):
    """Apply a zero-phase Butterworth low-pass filter to `data`.

    The filter is run forward and backward (``filtfilt``), with Gustafsson's
    method for the initial conditions. The output therefore has no phase
    delay and the same length as `data`.
    """
    numerator, denominator = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(numerator, denominator, data, method='gust')

def _demodulate(signal, carrier):
    """Mix `signal` with `carrier`, low-pass at f_larmor and rescale by 2.

    The factor 2 compensates for the 1/2 produced by mixing two real
    sinusoids (sin^2(x) = (1 - cos(2x)) / 2) on a single receive channel.
    """
    return butter_lowpass_filter(signal * carrier, f_larmor, fs, order=5) * 2


def _add_fid_traces(fig, fid_signal, demod_signal, envelope, visible):
    """Append the (FID, demodulated FID, T2 decay) trio of traces, sampled at `t`, to `fig`."""
    fig.add_scatter(x=t, y=fid_signal, name="FID", visible=visible)
    fig.add_scatter(x=t, y=demod_signal, name="Demodulated FID", visible=visible)
    fig.add_scatter(x=t, y=envelope, name="T2 decay", visible=visible)


# Lab frame: single species at the Larmor frequency.
y_0 = y_0_cst * np.sin(2 * math.pi * f_larmor * t)
carrier = y_0 / y_0_cst  # unit-amplitude reference at f_larmor
exp = np.exp(-t/T2)  # T2 decay envelope
y = y_0 * exp / y_0_cst
y_demod = _demodulate(y, carrier)

fig = go.Figure()
_add_fid_traces(fig, y, y_demod, exp, visible=True)

# 2 isochromats: add a species 10x weaker, 10 Hz off resonance.
y_1amp = y_0_cst / 10
y_1 = y_1amp * np.sin(2 * math.pi * (f_larmor + 10) * t)
y = (y_0 + y_1) * exp / (y_0_cst + y_1amp)
y_demod = _demodulate(y, carrier)
_add_fid_traces(fig, y, y_demod, exp, visible=False)

# Multiple isochromats: n_freqs species of amplitude `amp`, evenly spaced
# around f_larmor over roughly freq_offset Hz.
n_freqs = 100
amp = 1
freq_offset = 10  # [Hz] total spread of the off-resonance frequencies
scale = freq_offset / n_freqs  # [Hz] spacing between neighbouring species
mid = n_freqs // 2
fid = y_0 * exp
y_sum_demod = _demodulate(fid, carrier)
for i in range(n_freqs):
    y_1 = amp * np.sin(2 * math.pi * (f_larmor + scale * (mid - i)) * t) * exp
    fid += y_1
    y_sum_demod += _demodulate(y_1, carrier)

# Normalise by the total amplitude so the curves start near 1, like the T2 envelope.
y_demod_scaled = y_sum_demod / (y_0_cst + (n_freqs * amp))
fid_scaled = fid / (y_0_cst + (n_freqs * amp))
_add_fid_traces(fig, fid_scaled, y_demod_scaled, exp, visible=False)
fig.update_traces(marker=dict(size=3))

# Dropdown: each entry shows its own trio of traces (three per scenario).
species_views = ["Single species", "Two species", "Multiple Species"]
fig.update_layout(
    title_text="Single species",
    title_x=0.5,
    updatemenus=[
        dict(
            buttons=[
                dict(
                    # "title.text" (not "title") keeps the centred title_x and
                    # still works in Plotly 6, which no longer accepts a plain-string title.
                    args=[{"visible": [j // 3 == i for j in range(3 * len(species_views))]},
                          {"title.text": label}],
                    label=label,
                    method="update",
                )
                for i, label in enumerate(species_views)
            ],
            direction="down",
            showactive=True,
        ),
    ])
fig.show()

In [None]:
def calc_dk(gx, gy, dt):
    """Return the k-space step ``(dkx, dky)`` accumulated during one time step.

    Each component is ``GYRO_BAR_RATIO_H * g * dt``. With GYRO_BAR_RATIO_H in
    [Hz/T], presumably g in [T/m] and dt in [s], the step is in [1/m].
    """
    return tuple(GYRO_BAR_RATIO_H * gradient * dt for gradient in (gx, gy))

# Spurious constant y-gradient used to model a B0 inhomogeneity
# (presumably T/m, like the other gradient amplitudes below).
gy_bad = 100e-6

end_time = 0.0912  # [s]
n_times = 913
t = np.linspace(0, end_time, n_times)
# Sample spacing of `t`. The loop below takes n_times - 1 steps, so dt must
# be end_time / (n_times - 1) for the trajectory to span end_time. The
# previous end_time / n_times was slightly too small.
dt = end_time / (n_times - 1)

nx = 64  # time points per readout line
n_prephase = 20  # time points of the initial prephasing lobe
n_steps = 138  # period of the readout + blip pattern [time points]


def _gradients_at(it):
    """Return the ``(gx, gy)`` gradient amplitudes at time index `it` (it >= 1).

    The waveform starts with `n_prephase` points of a negative lobe on both
    axes. It then repeats every `n_steps` points:
    - a +gx readout of `nx` points,
    - a gy blip,
    - a -gx readout of `nx` points,
    - another gy blip.
    """
    if it <= n_prephase:
        return -40e-3, -40e-3
    i = (it - n_prephase) % n_steps
    if i == 0:
        return 0, 25e-3
    if i <= nx:
        return 25e-3, 0
    if i <= nx + 5:
        return 0, 25e-3
    if i <= (2 * nx) + 5:
        return -25e-3, 0
    return 0, 25e-3


# Integrate the gradients step by step. The distorted trajectory sees the
# extra gy_bad on every step.
k = np.zeros([n_times, 2])
k_distorted = np.zeros([n_times, 2])
for it in range(1, n_times):
    gx, gy = _gradients_at(it)
    k[it] = k[it - 1] + calc_dk(gx, gy, dt)
    k_distorted[it] = k_distorted[it - 1] + calc_dk(gx, gy + gy_bad, dt)


def _trajectory_traces(stop):
    """Return line traces for the first `stop` samples of both trajectories."""
    return [
        go.Scatter(x=k[:stop, 0], y=k[:stop, 1],
                   mode='lines',
                   line=dict(color='#636EFA'),
                   name='Theoretical trajectory'),
        go.Scatter(x=k_distorted[:stop, 0], y=k_distorted[:stop, 1],
                   mode='lines',
                   line=dict(color='#fa6363'),
                   name='Inhomogeneous trajectory'),
    ]


fig = go.Figure(data=_trajectory_traces(n_times))
# Each frame reveals two more samples. ceil(n_times / 2) + 1 frames are needed
# so the final frame shows the whole trajectory; int(n_times / 2) frames
# stopped 3 samples short.
n_frames = (n_times + 1) // 2 + 1
frames = [dict(data=_trajectory_traces(2 * i), name=str(i), traces=[0, 1])
          for i in range(n_frames)]
fig.frames = frames

fig.update_xaxes(range=[-3500, 3500], showticklabels=False)
fig.update_yaxes(range=[-3700, 3700], showticklabels=False)
fig.update_layout(title="K-space Trajectory",
                  title_x=0.5,
                  height=600,
                  width=700,
                  updatemenus=[dict(
                                type='buttons',
                                buttons=[dict(label='Play',
                                              method='animate',
                                              args=[None, dict(frame=dict(duration=15, redraw=False), transition=dict(duration=15), fromcurrent=True, mode='immediate')]),
                                         dict(label='Pause',
                                              method='animate',
                                              args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')])
                                        ])])
fig.show()