<a href="https://colab.research.google.com/github/patbolan/MPHY5178_F22/blob/main/pulsetool/RF_OOP_vM11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

RF OOP vM11
* Final from Madeline Mcginnis for Spring 2022
* With modifications by PJB for defaults




In [1]:
#####################################
#Imports

import numpy as np
import matplotlib.pyplot as plt
import math
import scipy.integrate as integrate
import time
import ipywidgets as widgets
from ipywidgets import interact, interactive, Layout
from IPython.display import display
from IPython.core.display import clear_output

# These lines turn off the automatic plotting, so I can control it
plt.ioff()
%config InlineBackend.close_figures=False

#####################################
# RF Pulse classes and definitions

class RFPulse:
  """Base class for an interactive RF pulse shown in the notebook UI.

  Provides the widgets shared by every pulse (number of points and duration),
  a VBox that groups them, and a plotting helper. Subclasses add their own
  widgets, rebuild ``self.control_ui`` to include them, and implement
  ``get_pulse_shape``.
  """

  def __init__(self, pulse_type="null"):
    # Keep the name the subclass passed in (previously the literal string
    # 'pulse_type' was stored and the argument was discarded).
    self.rf_type = pulse_type
    self.widget_style = {'description_width': 'initial'}
    self.duration_slider = widgets.FloatSlider(value=4, min=0.1, max=20, step=0.1, description='Pulse Duration (ms):', style=self.widget_style)
    self.npts_slider = widgets.IntSlider(value=128, min=32, max=512, step=16, description='Npts:', style=self.widget_style)

    self.control_ui = widgets.VBox([self.npts_slider, self.duration_slider])

  def get_control_ui(self):
    """Return the VBox holding this pulse's parameter widgets."""
    return self.control_ui

  def register_observer(self, callback_fn):
    """Call ``callback_fn(change)`` whenever any widget in ``control_ui`` changes value.

    Only the widgets that are children of ``self.control_ui`` at call time are
    observed, so subclasses must rebuild ``control_ui`` before this is called.
    """
    for child_widget in self.control_ui.children:
      child_widget.observe(callback_fn, 'value')

  def get_pulse_shape(self) -> tuple:
    """Return ``(t, b1_amp, b1_phase)`` arrays describing the pulse.

    ``t`` is in seconds. Abstract here; every subclass must override it.

    Raises:
      NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()

  def plot_pulse(self):
    """Plot the pulse amplitude (from ``get_pulse_shape``) against time in ms.

    Draws on the current matplotlib axes and displays the figure inside the
    notebook-global ``out_rf_display`` Output widget, which must exist.
    """
    t, b1_amp, _b1_phase = self.get_pulse_shape()

    plt.ioff()
    ax1 = plt.gca()
    ax1.clear()
    ax1.plot(t*1000, b1_amp)  # s -> ms for display
    ax1.set_xlabel('time (ms)')
    ax1.set_ylabel('gamma-bar B1 amplitude ')
    ax1.set_title('RF Pulse Shape')

    with out_rf_display:
      clear_output(wait=True)
      display(ax1.figure)


class WindowedSinc(RFPulse):
  """Sinc-shaped pulse with an optional raised-cosine window.

  Adds sliders for the number of lobes and the window alpha. Alpha = 1 means
  no windowing.
  """

  def __init__(self):
    RFPulse.__init__(self, "Windowed Sinc")
    self.nlobes_slider = widgets.FloatSlider(value=5, min=1, max=11, step=1, description='Nlobes:', style=self.widget_style)
    self.window_alpha_slider = widgets.FloatSlider(value=1, min=0.001, max=1, step=0.001, description='Window Alpha:', style=self.widget_style)

    # Replace the base-class control box so the new sliders are shown and observed.
    self.control_ui = widgets.VBox([self.npts_slider, self.duration_slider, self.nlobes_slider, self.window_alpha_slider])

  def get_pulse_shape(self):
    """Return ``(t, amp, phase)``: time in s, windowed-sinc amplitude, zero phase."""
    duration = self.duration_slider.value / 1000  # display in ms, calc in seconds
    nlobes = self.nlobes_slider.value
    window_alpha = self.window_alpha_slider.value
    npts = self.npts_slider.value

    # tau is an internal normalized time from -0.5 to 0.5, centered at zero;
    # it defines the shape independently of the real duration.
    tau = np.linspace(-0.5, 0.5, npts)

    # t is the time axis, in s
    t = np.linspace(0, duration, npts)

    # np.sinc(x) = sin(pi*x)/(pi*x) and is defined as 1 at x = 0. This avoids
    # the divide-by-zero (and the NaN patch-up) hit when npts is odd and tau
    # contains 0.
    amp = np.sinc((nlobes + 1) * tau)

    # Raised-cosine window; it is identically 1 when alpha == 1, so skip it then.
    if window_alpha < 1.0:
      window_function = window_alpha + (1 - window_alpha) * np.cos(2 * np.pi * tau)
      amp = amp * window_function

    return (t, amp, np.zeros_like(amp))  # phase is always zero


class Square(RFPulse):
  """Rectangular (hard) pulse: constant amplitude and zero phase.

  Has no extra parameters beyond the base-class duration and Npts sliders.
  """

  def __init__(self):
    super().__init__("Square")

  def get_pulse_shape(self):
    """Return ``(t, amp, phase)`` with ``t`` in seconds, amp all ones, phase all zeros."""
    n_samples = self.npts_slider.value
    duration_s = self.duration_slider.value / 1000  # slider is in ms

    t = np.linspace(0, duration_s, n_samples)
    amplitude = np.ones_like(t)
    phase = np.zeros_like(amplitude)
    return (t, amplitude, phase)


class HSn(RFPulse):
  """Hyperbolic-secant (HSn) pulse: sech amplitude envelope with a frequency sweep.

  Extra parameters:
    * R: time-bandwidth product (sweep bandwidth = R / duration).
    * N Exp: the exponent n in sech(beta * |tau|**n).
    * Truncation Value: amplitude at the pulse edges (sets beta).
    * Frequency Sweep: True sweeps low -> high, False high -> low.
  """

  def __init__(self):
    RFPulse.__init__(self, "HSn")
    self.r_slider = widgets.FloatSlider(value=10, min=0.1, max=100, step=0.1, description='R:', style=self.widget_style)
    self.n_exp_slider = widgets.FloatSlider(value=1, min=1, max=20, step=0.1, description='N Exp:', style=self.widget_style)
    self.trunc_value_slider = widgets.FloatSlider(value=0.01, min=0.0001, max=0.1, step=0.0001, description='Truncation Value:', style=self.widget_style)
    self.freq_sweep_drop = widgets.Dropdown(options=[True, False], value=True, description='Frequency Sweep (low to high):', style=self.widget_style)

    self.control_ui = widgets.VBox([self.npts_slider, self.duration_slider, self.r_slider, self.n_exp_slider, self.trunc_value_slider, self.freq_sweep_drop])

  def get_pulse_shape(self):
    """Return ``(t, amplitude, phase)``: time in s, sech envelope, phase in radians."""
    pulse_duration = self.duration_slider.value / 1000  # display in ms, calc in seconds
    npts = self.npts_slider.value
    r = self.r_slider.value
    n_exp = self.n_exp_slider.value
    truncation_value = self.trunc_value_slider.value
    freq_sweep = self.freq_sweep_drop.value

    # scipy renamed cumtrapz to cumulative_trapezoid (1.6) and removed the old
    # name in 1.14; prefer the new name and fall back for older installs.
    cumtrapz = getattr(integrate, 'cumulative_trapezoid', None) or integrate.cumtrapz

    BW = r / pulse_duration  # sweep bandwidth, Hz

    # For HSn the dummy time variable tau goes from -1 to 1, centered at zero
    tau = np.linspace(-1.0, 1.0, npts)

    # t is the time axis, in s
    t = np.linspace(0, pulse_duration, npts)

    # Chosen so that sech(beta) == truncation_value, i.e. the amplitude at
    # tau = +/-1 equals the truncation value.
    beta = np.log((1 + np.sqrt(1 - truncation_value**2)) / truncation_value)

    # Amplitude envelope F1. sech is even, so using |tau|**n gives the same
    # result as tau**n for integer n, and avoids NaNs for fractional n (the
    # slider allows 0.1 steps) where a negative base cannot be raised.
    F1 = 1 / np.cosh(beta * np.abs(tau)**n_exp)

    # Frequency sweep follows the running integral of F1**2 (unit spacing is
    # fine because it is normalized by its range below).
    F2 = cumtrapz(F1**2, initial=0)
    F2_range = F2.max() - F2.min()
    omega_Hz = F2 * BW / F2_range - BW / 2  # spans -BW/2 .. +BW/2
    if not freq_sweep:
      omega_Hz = -omega_Hz

    omega_radians_per_s = omega_Hz * 2 * math.pi

    # Phase is the time integral of the instantaneous frequency; the sample
    # spacing matches t = linspace(0, pulse_duration, npts).
    dt = pulse_duration / (npts - 1)
    phs_radians = cumtrapz(omega_radians_per_s, dx=dt, initial=0)

    return (t, F1, phs_radians)


class Gaussian(RFPulse):
  """Truncated Gaussian pulse with zero phase.

  The "Truncation Sigma" slider sets how many standard deviations from the
  center the pulse is cut off at each end.
  """

  def __init__(self):
    super().__init__("Gaussian")
    self.trunc_slider = widgets.FloatSlider(value=3, min=0.5, max=10, step=0.5, description='Truncation Sigma:', style=self.widget_style)

    self.control_ui = widgets.VBox([self.npts_slider, self.duration_slider, self.trunc_slider])

  def get_pulse_shape(self):
    """Return ``(t, amp, phase)`` with ``t`` in seconds and phase all zeros."""
    n_samples = self.npts_slider.value
    duration_s = self.duration_slider.value / 1000  # slider is in ms
    n_sigma = self.trunc_slider.value

    # Normalized time in [-1, 1]; multiplying by n_sigma places the pulse ends
    # at +/- n_sigma standard deviations of the Gaussian.
    tau = np.linspace(-1.0, 1.0, n_samples)
    t = np.linspace(0, duration_s, n_samples)

    amplitude = np.exp(-0.5 * (tau * n_sigma)**2)
    return (t, amplitude, np.zeros_like(amplitude))

#####################################
# Sim classes and definitions

class sim():
  """Bloch simulation of an RF pulse's response versus off-resonance frequency.

  Owns the simulation-parameter widgets. ``plot_sim`` integrates the Bloch
  equations over the current pulse for a range of frequency offsets and plots
  the magnetization at the end of the pulse.
  """

  def __init__(self, my_pulse):
    # The RFPulse to simulate; may be None until a pulse type is chosen.
    self.pulse = my_pulse
    self.widget_style = {'description_width': 'initial'}

    self.simulation_bw_slider = widgets.IntSlider(value=4000, min=100, max=20000, step=100, description='Simulation BW (Hz):', style=self.widget_style)
    self.offset_steps_slider = widgets.IntSlider(value=101, min=30, max=1000, step=10, description='Offset Steps:', style=self.widget_style)

    self.gamma_max_slider = widgets.IntSlider(value=353, min=0, max=4000, step=5, description='Gamma b1 Max (Hz):', style=self.widget_style)
    # Phase is continuous: a float slider over one full turn. The previous
    # integer slider only allowed 1-radian (~57 degree) steps.
    self.pulse_phase_slider = widgets.FloatSlider(value=0, min=0, max=6.28, step=0.01, description='Pulse Phase (radians):', style=self.widget_style)

    self.t1_slider = widgets.IntSlider(value=2000, min=50, max=10000, step=10, description='T1 (ms):', style=self.widget_style)
    self.t2_slider = widgets.IntSlider(value=100, min=5, max=1000, step=5, description='T2 (ms):', style=self.widget_style)

    self.control_ui = widgets.VBox([self.simulation_bw_slider, self.offset_steps_slider, self.gamma_max_slider, self.pulse_phase_slider, self.t1_slider, self.t2_slider])

  def get_control_ui(self):
    """Return the VBox holding the simulation-parameter widgets."""
    return self.control_ui

  def register_observer(self, callback_fn):
    """Call ``callback_fn(change)`` whenever any simulation widget changes value."""
    for child_widget in self.control_ui.children:
      child_widget.observe(callback_fn, 'value')

  def plot_sim(self, show_2d=False):
    """Simulate ``self.pulse`` and display the result in ``out_rf_sim``.

    Does nothing when no pulse has been assigned. Relies on the module-level
    ``blochRK4_arrayform`` and the notebook-global ``out_rf_sim`` Output widget.

    Args:
      show_2d: if True, show |Mxy| for every offset (rows) and time step
        (columns) as an image. Otherwise (default) plot the end-of-pulse Mz
        and |Mxy| against off-resonance frequency.
    """
    if self.pulse is None:
      return

    t, b1_amp, b1_phase = self.pulse.get_pulse_shape()
    npts = len(t)

    # Simulation parameters
    simulation_bw = self.simulation_bw_slider.value  # Hz
    offset_steps = self.offset_steps_slider.value
    Tp = self.pulse.duration_slider.value / 1000  # slider in ms, simulate in s

    gamma_b1_max = self.gamma_max_slider.value  # peak gamma-bar B1, Hz
    pulse_phase = self.pulse_phase_slider.value  # global phase added to pulse, rad

    # Relaxation times come from the sliders in ms; rates below are in 1/s.
    T1 = self.t1_slider.value
    T2 = self.t2_slider.value
    R1 = 1000 / T1
    R2 = 1000 / T2
    M0 = np.array([0, 0, 1.0])  # initial magnetization

    offsets = np.linspace(-simulation_bw/2, simulation_bw/2, offset_steps)
    offsets_rad = offsets * 2 * np.pi  # Hz -> rad/s, hoisted out of the loop

    # Convert polar RF (amplitude, phase) into cartesian components, rad/s
    b1x = 2 * np.pi * gamma_b1_max * b1_amp * np.cos(pulse_phase + b1_phase)
    b1y = 2 * np.pi * gamma_b1_max * b1_amp * np.sin(pulse_phase + b1_phase)

    # The pulses sample t as linspace(0, Tp, npts), so samples are Tp/(npts-1)
    # apart and the npts-1 steps below cover exactly Tp. The old Tp/npts
    # shortened the simulated pulse by one step.
    deltaTp = Tp / max(npts - 1, 1)

    Mt = np.zeros([offset_steps, npts, 3])
    Mt[:, 0, :] = M0
    for tdx in range(1, npts):
      Mt[:, tdx, :] = blochRK4_arrayform(Mt[:, tdx-1, :], b1x[tdx], b1y[tdx], offsets_rad, R1, R2, deltaTp)

    plt.ioff()
    ax1 = plt.gca()
    ax1.clear()
    if show_2d:
      Mxyt = np.hypot(Mt[:, :, 0], Mt[:, :, 1])
      ax1.imshow(Mxyt, vmin=0, vmax=1)
    else:
      Mend = Mt[:, -1, :]
      Mxy = np.hypot(Mend[:, 0], Mend[:, 1])
      Mz = Mend[:, 2]
      ax1.plot(offsets, Mz, '-k')
      ax1.plot(offsets, Mxy, '-m')
      ax1.set_xlabel('Off resonance (Hz)')
      ax1.legend(['Mz', 'Mxy'])
      # Both Mz and |Mxy| are plotted, so label the axis generically.
      ax1.set_ylabel('Magnetization (M/M0)')

    with out_rf_sim:
      clear_output(wait=True)
      display(ax1.figure)

# Vectorized right-hand side of the Bloch equations, used by the RK4 stepper
# below. It evaluates dM/dt for many off-resonance values at once.
def dM_dt_function_arrayform(b1x, b1y, R1, R2, offsets, M):
  """Return dM/dt for every offset, in the rotating frame.

  Args:
    b1x, b1y: RF field components in rad/s. Scalars, or arrays broadcastable
      against ``offsets``.
    R1, R2: longitudinal / transverse relaxation rates in 1/s (scalars or
      broadcastable arrays).
    offsets: 1-D array of off-resonance frequencies in rad/s.
    M: array of shape ``[len(offsets), 3]`` with columns (Mx, My, Mz).
      Equilibrium magnetization is taken to be Mz = 1.

  Returns:
    Array of shape ``[len(offsets), 3]`` with (dMx/dt, dMy/dt, dMz/dt).
  """
  Mx, My, Mz = M[:, 0], M[:, 1], M[:, 2]
  dMdt = np.zeros([offsets.shape[0], 3])

  # Precession about (b1x, b1y, offset) plus relaxation.
  dMdt[:, 0] = offsets * My - b1y * Mz - R2 * Mx
  dMdt[:, 1] = b1x * Mz - offsets * Mx - R2 * My
  # T1 recovery pulls Mz back toward equilibrium (1.0), so the relaxation
  # term is +R1*(1 - Mz). It was previously subtracted, driving Mz away.
  dMdt[:, 2] = b1y * Mx - b1x * My + R1 * (1.0 - Mz)

  return dMdt

def blochRK4_arrayform(Minit, B1x, B1y, offsets, R1, R2, deltaT):
  """Advance the magnetization of every offset by one classical RK4 step.

  The RF (B1x, B1y) and relaxation rates are held constant over the step.
  Minit has shape [len(offsets), 3]; the return value has the same shape.
  """
  def increment(M):
    # One RK4 stage: deltaT times the Bloch derivative evaluated at M.
    return deltaT * dM_dt_function_arrayform(B1x, B1y, R1, R2, offsets, M)

  k1 = increment(Minit)
  k2 = increment(Minit + k1/2)
  k3 = increment(Minit + k2/2)
  k4 = increment(Minit + k3)

  # Standard 1-2-2-1 weighting of the four stages.
  return Minit + 1/6 * (k1 + 2*k2 + 2*k3 + k4)

#####################################
# Callback for parameter changes

def onPulseParameterChange(change):
  """Redraw the pulse plot and re-run the simulation after a pulse widget changes.

  ``change`` is the widget change notification and is not inspected; any change
  triggers a full redraw. Uses the notebook globals ``the_pulse``, ``the_sim``,
  ``out_rf_display`` and ``out_rf_sim``.
  """
  global the_pulse, out_rf_display, out_rf_sim

  with out_rf_display:
    clear_output(wait=True)
    if the_pulse is None:
      return
    the_pulse.plot_pulse()
    the_sim.pulse = the_pulse
    the_sim.plot_sim()

# Callback for when RF pulse type changes
def onPulseChange(change):
  """Build the pulse chosen in the dropdown and refresh every panel.

  ``change['new']`` is the selected dropdown label. Unknown labels (including
  the '-Select Pulse-' placeholder) leave ``the_pulse`` as None and blank the
  panels. Uses the notebook globals ``the_sim`` and the ``out_rf_*`` Output
  widgets.
  """
  global the_pulse, out_rf_config, out_rf_display, out_rf_params, out_rf_sim

  pulse_classes = {
    'Windowed Sinc': WindowedSinc,
    'Square': Square,
    'HSn': HSn,
    'Gaussian': Gaussian,
  }
  pulse_cls = pulse_classes.get(change['new'])
  the_pulse = pulse_cls() if pulse_cls is not None else None

  if the_pulse is None:
    # Previously the display panel tested the_sim instead of the_pulse and
    # crashed calling register_observer on None; now every panel is blanked.
    with out_rf_config:
      clear_output()
      print('Not implemented yet')
    for out in (out_rf_params, out_rf_display, out_rf_sim):
      with out:
        clear_output()
    return

  with out_rf_config:
    clear_output(wait=True)
    display(the_pulse.get_control_ui())

  with out_rf_params:
    # Clear first so repeated pulse changes don't stack duplicate copies of
    # the simulation controls.
    clear_output(wait=True)
    display(the_sim.get_control_ui())

  with out_rf_display:
    clear_output(wait=True)
    the_pulse.register_observer(onPulseParameterChange)
    the_pulse.plot_pulse()

  with out_rf_sim:
    the_sim.pulse = the_pulse
    the_sim.plot_sim()

def on_rf_param_change(change):
  """Refresh the simulation panels after a simulation widget changes value.

  Re-displays the simulation controls and re-runs the simulation for the
  notebook-global ``the_pulse``. Both panels are cleared when no pulse is
  selected. ``change`` is not used.
  """
  global out_rf_sim, out_rf_params, the_sim

  have_pulse = the_pulse is not None

  with out_rf_params:
    if not have_pulse:
      clear_output()
    else:
      clear_output(wait=True)
      display(the_sim.get_control_ui())

  with out_rf_sim:
    if have_pulse:
      the_sim.pulse = the_pulse
      the_sim.plot_sim()
    else:
      clear_output()


#####################################
# Configure and display the RF Pulse & Sim (Interface)

# Pulse panel: three stacked Output areas (picker, pulse controls, pulse
# plot), each outlined in its own colour so the layout is easy to see.
out_rf_picker, out_rf_config, out_rf_display = (
  widgets.Output(layout={'border': border})
  for border in ('1px solid red', '1px solid orange', '1px solid yellow')
)
vb = widgets.VBox([out_rf_picker, out_rf_config, out_rf_display])
display(vb)

# No pulse exists until the user picks one from the dropdown.
the_pulse = None

# Pulse-type picker; onPulseChange builds the chosen pulse and redraws.
style = {'description_width': 'initial'}
dropPulseType = widgets.Dropdown(
  options=['-Select Pulse-', 'Windowed Sinc', 'HSn', 'Square', 'Gaussian'],
  value='-Select Pulse-',
  description='Type of Pulse:',
  style=style,
  disabled=False,
)
dropPulseType.observe(onPulseChange, 'value')

with out_rf_picker:
  display(dropPulseType)

# Simulation panel: simulation controls above the simulation plot.
out_rf_params, out_rf_sim = (
  widgets.Output(layout={'border': border})
  for border in ('1px solid green', '1px solid blue')
)
vb2 = widgets.VBox([out_rf_params, out_rf_sim])
display(vb2)

# The simulator starts with no pulse; onPulseChange assigns the_sim.pulse.
the_sim = sim(the_pulse)
the_sim.register_observer(on_rf_param_change)





VBox(children=(Output(layout=Layout(border='1px solid red')), Output(layout=Layout(border='1px solid orange'))…

VBox(children=(Output(layout=Layout(border='1px solid green')), Output(layout=Layout(border='1px solid blue'))…