In [None]:
import numpy as np
import plotly.express as px
import os
import nibabel as nib
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import math
import json
# MATHEMATICAL ITALIC SMALL PI (U+1D70B); used as the tick label for +/- pi on phase axes/colorbars
PI_UNICODE = "\U0001D70B"

In [None]:
# Parameters
num_frames = 200
omega_0 = 1  # Larmor frequency [Hz] (multiplied by 2*pi below)
omega_1 = 0.9  # Frequency of the inhomogeneous spin [Hz]
time_max = 5  # [s]

# Initial phase of the spins [rad]
initial_phase = 0.5

# Time array [s]
time = np.linspace(0, time_max, num_frames)

# Colours shared by the initial traces and every animation frame, so nothing
# changes colour when "Play" is pressed.
SPIN_COLOR_F0 = '#636EFA'
SPIN_COLOR_INHOM = '#fa6363'
PHASE_COLOR_F0 = 'blue'
PHASE_COLOR_INHOM = 'red'


def _spin_xy(freq, t, phase0):
    """Return the (x, y) tip coordinates of a unit spin precessing at ``freq``.

    Args:
        freq (float): Precession frequency in Hz (converted to rad/s here).
        t (numpy.ndarray): Time points in seconds.
        phase0 (float): Phase at t = 0 in radians.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray]: cos and sin of the accumulated phase.
    """
    accumulated_phase = freq * (2 * math.pi) * t + phase0
    return np.cos(accumulated_phase), np.sin(accumulated_phase)


def _arrow_trace(x_tip, y_tip, color, name, **kwargs):
    """Spin drawn as a segment from the origin to (x_tip, y_tip).

    Extra keyword arguments (e.g. ``visible``) are forwarded to ``go.Scatter``.
    """
    return go.Scatter(x=[0, x_tip], y=[0, y_tip], mode='lines+markers',
                      marker=dict(size=5), line=dict(color=color, width=3),
                      name=name, **kwargs)


def _phase_line_trace(t, phase, color, name):
    """Phase history up to the current animation step, drawn as a line."""
    return go.Scatter(x=t, y=phase, mode='lines', marker=dict(size=5),
                      line=dict(color=color, width=3), name=name)


def _phase_start_trace(t0, phase0, color, name, visible):
    """Single marker at the first time point, shown before the animation starts."""
    return go.Scatter(x=[t0], y=[phase0], mode='markers',
                      marker=dict(color=color, size=5), name=name, visible=visible)


# Spins in the laboratory frame
x, y = _spin_xy(omega_0, time, initial_phase)
x1, y1 = _spin_xy(omega_1, time, initial_phase)

# Spins in the frame rotating at omega_0: only the frequency offset remains
x_rot, y_rot = _spin_xy(omega_0 - omega_0, time, initial_phase)
x1_rot, y1_rot = _spin_xy(omega_1 - omega_0, time, initial_phase)

# Wrapped phase of every spin, in (-pi, pi]
angles = np.arctan2(y, x)
angles1 = np.arctan2(y1, x1)
angles_rot = np.arctan2(y_rot, x_rot)
angles1_rot = np.arctan2(y1_rot, x1_rot)

# (x, y, x1, y1, phase, phase1) per frame of reference. Trace order everywhere:
# 0-3 laboratory frame, 4-7 rotating frame; within each group
# (spin f0, inhomogeneous spin, phase f0, inhomogeneous phase).
lab_frame = (x, y, x1, y1, angles, angles1)
rot_frame = (x_rot, y_rot, x1_rot, y1_rot, angles_rot, angles1_rot)

# Create figure
fig = make_subplots(rows=1, cols=2, shared_xaxes=False, horizontal_spacing=0.1,
                    subplot_titles=("Spin Rotating Through Time", "Phase Evolution of the Signal (rad)"))

# Initial state: laboratory-frame traces visible, rotating-frame traces hidden
for (sx, sy, sx1, sy1, ph, ph1), shown in ((lab_frame, True), (rot_frame, False)):
    fig.add_trace(_arrow_trace(sx[0], sy[0], SPIN_COLOR_F0, 'Spin at f0', visible=shown),
                  row=1, col=1)
    fig.add_trace(_arrow_trace(sx1[0], sy1[0], SPIN_COLOR_INHOM, 'Inhomogeneous Spin', visible=shown),
                  row=1, col=1)
    fig.add_trace(_phase_start_trace(time[0], ph[0], PHASE_COLOR_F0, 'Phase', shown),
                  row=1, col=2)
    fig.add_trace(_phase_start_trace(time[0], ph1[0], PHASE_COLOR_INHOM, 'Inhomogeneous Phase', shown),
                  row=1, col=2)

fig.update_xaxes(range=[-1.1, 1.1], row=1, col=1)
fig.update_yaxes(range=[-1.1, 1.1], row=1, col=1)
fig.update_xaxes(range=[np.min(time), np.max(time)], row=1, col=2)

# Cover the phases of both frames of reference; pad each bound by 10% of its
# magnitude so the padding widens the range whatever the sign of the bound.
all_phases = np.concatenate([angles, angles1, angles_rot, angles1_rot])
phase_min, phase_max = all_phases.min(), all_phases.max()
fig.update_yaxes(range=[phase_min - 0.1 * abs(phase_min), phase_max + 0.1 * abs(phase_max)],
                 row=1, col=2)


def _frame_traces(i):
    """Eight traces for animation step ``i``, in the same order as ``fig.data``.

    ``visible`` is deliberately not set, so the visibility chosen with the
    frame-of-reference dropdown is (presumably) left untouched by the animation.
    """
    traces = []
    for sx, sy, sx1, sy1, ph, ph1 in (lab_frame, rot_frame):
        traces += [
            _arrow_trace(sx[i], sy[i], SPIN_COLOR_F0, 'Spin at f0'),
            _arrow_trace(sx1[i], sy1[i], SPIN_COLOR_INHOM, 'Inhomogeneous Spin'),
            # i + 1 so the phase line ends at the sample the arrow currently shows
            _phase_line_trace(time[:i + 1], ph[:i + 1], PHASE_COLOR_F0, 'Phase'),
            _phase_line_trace(time[:i + 1], ph1[:i + 1], PHASE_COLOR_INHOM, 'Inhomogeneous Phase'),
        ]
    return traces


# Add frames
frames = [dict(data=_frame_traces(i), name=str(i), traces=list(range(8)))
          for i in range(num_frames)]
fig.frames = frames

fig.update_xaxes(title_text="x", row=1, col=1)
fig.update_xaxes(title_text="time", row=1, col=2)
fig.update_yaxes(title_text="y", row=1, col=1)
fig.update_yaxes(title_text="rad", tickmode='array',
                 tickvals=[-math.pi, 0, math.pi],
                 ticktext=[f"-{PI_UNICODE}", 0, f'{PI_UNICODE}'], row=1, col=2)

# Visibility masks for the frame-of-reference dropdown (see trace order above)
lab_visible = [True] * 4 + [False] * 4
rot_visible = [False] * 4 + [True] * 4

# Update layout
fig.update_layout(
    height=600,
    title="Spins rotating",
    xaxis=dict(autorange=False),
    yaxis=dict(autorange=False),
    updatemenus=[
        dict(
            type='buttons',
            buttons=[dict(label='Play',
                          method='animate',
                          args=[None, dict(frame=dict(duration=50, redraw=False), fromcurrent=True, mode='immediate')]),
                     dict(label='Pause',
                          method='animate',
                          args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')])
                     ],
        ),
        dict(
            buttons=[dict(args=[{"visible": lab_visible}],
                          label="Laboratory Frame",
                          method="update"),
                     dict(args=[{"visible": rot_visible}],
                          label="Rotating Frame",
                          method="update")],
            direction="down",
            pad={"b": 70},
            showactive=True,
            x=-0.13,
            xanchor="left",
            y=-0.15,
            yanchor="top"
        )
    ]
)

# Show figure
fig.show()

In [None]:
# Input files for the field-mapping demo, all in the relative "fmap" folder.
(fname_mag_e1, fname_phase_e1, fname_phase_e1_json,
 fname_phase_e2, fname_mask, fname_fmap) = (
    os.path.join("fmap", basename)
    for basename in (
        "sub-fmap_magnitude1.nii.gz",
        "sub-fmap_phase1.nii.gz",
        "sub-fmap_phase1.json",
        "sub-fmap_phase2.nii.gz",
        "mask.nii.gz",
        "fmap.nii.gz",
    )
)

# Open every NIfTI image with nibabel (the JSON sidecar is read later with json.load).
# NOTE(review): nii_mag_e1 is not used by any of the plotting cells in this notebook.
nii_mag_e1, nii_phase_e1, nii_phase_e2, nii_mask, nii_fmap = (
    nib.load(path)
    for path in (fname_mag_e1, fname_phase_e1, fname_phase_e2, fname_mask, fname_fmap)
)

In [None]:
# Phase evolution through different echo times
# Work on one 2D crop (rows 30:-30, columns 8:105, slice 30) of each volume.
mask = nii_mask.get_fdata()[30:-30, 8:105, 30]
fmap = nii_fmap.get_fdata()[30:-30, 8:105, 30] * mask  # [Hz]
# Rescale the raw phase to radians; the /4095 presumably reflects 12-bit phase
# values in 0..4095 — TODO confirm for this acquisition.
phase1 = (nii_phase_e1.get_fdata()[30:-30, 8:105, 30] / 4095 * 2 * math.pi - math.pi) * mask

with open(fname_phase_e1_json, 'r') as json_data1:
    data1 = json.load(json_data1)

echo_time1 = data1['EchoTime']  # [s]
# Extrapolate back to TE = 0 assuming linear phase accumulation:
# phase(TE) = phase0 + 2*pi*fmap*TE
phase0 = phase1 - (echo_time1 * (fmap * 2 * math.pi))
zmin = -math.pi
zmax = math.pi

n_echoes = 31
last_echo_time = 0.03  # [s]
echo_times = np.linspace(0.0, last_echo_time, n_echoes)
last_index = len(echo_times) - 1

fig = go.Figure()
for i_echo, echo_time in enumerate(echo_times):
    phase = phase0 + (fmap * echo_time * 2 * math.pi)
    phase = np.angle(np.exp(1j * phase))  # wrap into (-pi, pi]
    # Only the last echo is visible initially; the slider toggles the others.
    fig.add_trace(go.Heatmap(z=np.rot90(phase, k=-1), visible=(i_echo == last_index),
                             coloraxis="coloraxis"))

fig.update_layout(coloraxis={'colorscale': 'gray'},
                  coloraxis_cmin=zmin, coloraxis_cmax=zmax)
fig.update_coloraxes(
    # title.side replaces the `titleside` property removed in plotly 6
    colorbar=dict(title=dict(text="Rad", side="top"),
                  tickmode="array",
                  # top tick nudged just below cmax, presumably so it is not clipped
                  tickvals=[-math.pi, 0, math.pi - 0.01],
                  ticktext=[f"-{PI_UNICODE}", 0, f'{PI_UNICODE}']))

echo_times_str = [f"{te:.2}" for te in echo_times]
steps = []
for i in range(len(fig.data)):
    visible = [False] * len(fig.data)  # trace attribute, one flag per heatmap
    visible[i] = True  # show only the i'th echo
    steps.append(dict(method="update", label=echo_times_str[i], args=[{"visible": visible}]))

sliders = [dict(
    active=last_index,  # matches the trace made visible above
    currentvalue={"prefix": "Echo Time: ", "suffix": " s"},
    steps=steps,
)]

fig.update_layout(
    sliders=sliders
)

fig.update_layout(
    title=dict(text="Phase at different echo times", x=0.5)
)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout({"height": 550, "width": 500})
fig.show()

In [None]:
def complex_difference(phase1, phase2):
    """Return the wrapped phase difference ``phase2 - phase1``.

    The subtraction is done in the complex domain: the unit phasor of
    ``phase2`` is multiplied by the conjugate phasor of ``phase1``. Taking the
    angle of the product yields a difference already wrapped into (-pi, pi].

    Args:
        phase1 (numpy.ndarray): Phase values in radians.
        phase2 (numpy.ndarray): Phase values in radians; same shape as phase1.

    Returns:
        numpy.ndarray: Element-wise phase2 - phase1 in radians, within (-pi, pi].
    """
    return np.angle(np.exp(-1j * phase1) * np.exp(1j * phase2))

In [None]:
def _phase_heatmap(phase, colorbar_x):
    """Gray-scale heatmap of a phase image with a [-pi, pi] colour range.

    Args:
        phase (numpy.ndarray): 2D phase image in radians.
        colorbar_x (float): Horizontal colorbar position (paper coordinates).

    Returns:
        plotly.graph_objects.Heatmap: Trace of the image rotated 90 degrees clockwise.
    """
    return go.Heatmap(z=np.rot90(phase, k=-1), colorscale='gray', zmin=-math.pi, zmax=math.pi,
                      colorbar=dict(x=colorbar_x,
                                    # title.side replaces `titleside`, removed in plotly 6
                                    title=dict(text="Rad", side="top"),
                                    tickmode="array",
                                    tickvals=[-math.pi, 0, math.pi - 0.01],
                                    ticktext=[f"-{PI_UNICODE}", 0, f'{PI_UNICODE}']))


def plot_2_echo_fmap(phase1, phase2, echotime1, echotime2):
    """Show two phase images, their wrapped difference and the derived B0 field map.

    The field map is ``complex_difference(phase1, phase2) / (2*pi*(echotime2 - echotime1))``.
    The figure is displayed with ``fig.show()``; nothing is returned, so a call
    as the last line of a cell does not render the figure twice.

    Args:
        phase1 (numpy.ndarray): 2D phase image of the first echo [rad].
        phase2 (numpy.ndarray): 2D phase image of the second echo [rad]; same shape as phase1.
        echotime1 (float): Echo time of phase1 [s].
        echotime2 (float): Echo time of phase2 [s].

    Raises:
        ValueError: If the two echo times are equal (the field map is undefined).
    """
    if echotime2 == echotime1:
        raise ValueError("echotime1 and echotime2 must differ to compute a field map")

    phase_diff = complex_difference(phase1, phase2)
    fmap = phase_diff / (echotime2 - echotime1) / 2 / math.pi  # [Hz]

    n_cols = 4
    fig = make_subplots(rows=1, cols=n_cols, shared_xaxes=False, horizontal_spacing=0.1,
                        subplot_titles=("Phase 1", "Phase 2", "Phase difference", "B0 field map"),
                        specs=[[{"type": "Heatmap"} for _ in range(n_cols)]])

    # Colorbar offsets place each phase colorbar just right of its own panel
    fig.add_trace(_phase_heatmap(phase1, 1 / n_cols - 0.05), 1, 1)
    fig.add_trace(_phase_heatmap(phase2, 2 / n_cols - 0.02), 1, 2)
    fig.add_trace(_phase_heatmap(phase_diff, 3 / n_cols - 0.02), 1, 3)
    # The field map is not bounded to [-pi, pi], so its colour range stays automatic
    fig.add_trace(go.Heatmap(z=np.rot90(fmap, k=-1), colorscale='gray',
                             colorbar=dict(title=dict(text="Hz", side="top"))), 1, 4)

    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_layout({"height": 450, "width": 1700})

    fig.show()

In [None]:
def get_circle(x, y, r):
    """Return a binary disc mask of shape (x, y).

    Pixel (i, j) is 1 when its Euclidean distance to the point (x / 2, y / 2)
    is strictly smaller than ``r``, otherwise 0. Note that for even sizes this
    centre is half a pixel off the array's geometric centre ((x - 1) / 2).

    Args:
        x (int): Number of rows of the output.
        y (int): Number of columns of the output.
        r (float): Disc radius in pixels.

    Returns:
        numpy.ndarray: float64 array of shape (x, y) containing 0.0 and 1.0.

    Raises:
        ValueError: If x, y or r is smaller than 1.
    """
    if x < 1 or y < 1 or r < 1:
        raise ValueError("Input parameters are too small")

    # Open grids broadcast to (x, y): one vectorised pass instead of a
    # per-pixel Python double loop, with identical results.
    rows, cols = np.ogrid[:x, :y]
    distance = np.sqrt((rows - x / 2) ** 2 + (cols - y / 2) ** 2)
    return (distance < r).astype(np.float64)

# Synthetic example: a disc whose phase is -1 rad at the first echo and +2 rad at the second
echo1, echo2 = (get_circle(100, 100, 30) * gain for gain in (-1, 2))
echo_time1, echo_time2 = 0.005, 0.01  # [s]

plot_2_echo_fmap(echo1, echo2, echo_time1, echo_time2)

In [None]:
# 2D crop (rows 30:-30, columns 8:105, slice 30); raw phase rescaled to radians.
# The /4095 presumably reflects 12-bit phase values in 0..4095 — TODO confirm.
mask = nii_mask.get_fdata()[30:-30, 8:105, 30]
phase1, phase2 = (
    (nii_phase.get_fdata()[30:-30, 8:105, 30] / 4095 * 2 * math.pi - math.pi) * mask
    for nii_phase in (nii_phase_e1, nii_phase_e2)
)
# NOTE(review): echo times are hard-coded; echo 1's EchoTime is also in sub-fmap_phase1.json.
echo_time1, echo_time2 = 0.00338, 0.00558  # [s]

plot_2_echo_fmap(phase1, phase2, echo_time1, echo_time2)