# **Numerical Solution of 1D Time Dependent Schrödinger Equation by Split Operator Fourier Transform (SOFT) Method**

<i class="fa fa-home fa-2x"></i><a href="../index.ipynb" style="font-size: 20px"> Go back to index</a>

**Source code:** https://github.com/osscar-org/quantum-mechanics/blob/develop/notebook/quantum-mechanics/soft.ipynb

<hr style="height:1px;border:none;color:#cccccc;background-color:#cccccc;" />

## **Goals**

<hr style="height:1px;border:none;color:#cccccc;background-color:#cccccc;" />

## Interactive visualization
(be patient, it might take a few seconds to load)

In [3]:
%matplotlib widget

import numpy as np
from math import pi
from scipy.fftpack import fft, ifft
from ipywidgets import Button, FloatSlider, HBox, VBox, IntProgress, Dropdown, Layout, Accordion
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import animation

In [4]:
class SOFT(object):
    """
    This is the Split Operator Fourier Transform (SOFT)
    x: the grid space 
    psi_0: inital wavefunction, included real and image parts
    V: the potential
    dt: the time interval
    hbar: the Plank constant (default value 1.0 as atomic unit)
    m: the mass (default value 1.0 as atomic unit)
    """
    
    def __init__(self, x, psi_0, V, dt, hbar=1.0, m = 1.0):
        self.N = len(x);
        self.hbar = hbar;
        self.m = m;
        
        dx = x[1] -x [0];
        dk = (2*pi)/(self.N*dx);
        
        self.x = x;
        self.dx = dx;
        self.dk = dk;
        self.t = 0.0;
        self.dt = dt;
        self.m = m;
        
        self.k_x = -0.5*self.N*self.dk + dk*np.arange(self.N);
        self.psi_x = np.zeros((self.N), dtype=np.complex128)
        self.psi_k = np.zeros((self.N), dtype=np.complex128)
        self.psi_mod_x = np.zeros((self.N), dtype=np.complex128)
        self.psi_mod_k = np.zeros((self.N), dtype=np.complex128)
        
        self.psi_x = psi_0;
        self._periodic_bc();
        
        if callable(V):
            self.V = V(x)
        else:
            self.V = V
            
        self.psi_mod_x[:] = self.dx/np.sqrt(2*pi)*self.psi_x[:]*np.exp(-1.0j*self.k_x[0]*self.x);
        
    def _periodic_bc(self):
        self.psi_x[-1] = self.psi_x[0]
               
    def _half_pot_prop(self, ft=True):
        if ft == True:
            self.psi_mod_x[:] = ifft(self.psi_mod_k[:])
        self.psi_mod_x[:] = self.psi_mod_x[:]*np.exp(-1.0j*(self.dt/2.0)*self.V)

    def _full_kinetic_prop(self):
        self.psi_mod_k[:] = fft(self.psi_mod_x[:])
        self.psi_mod_k[:] = self.psi_mod_k[:]*np.exp(-1.0j*self.k_x**2*self.dt/(2.0*self.m))
        
    def _compute_psi(self):
        self.psi_x[:] = (np.sqrt(2*pi)/self.dx)*self.psi_mod_x*np.exp(1.0j*self.k_x[0]*self.x)
        self.psi_k[:] = self.psi_mod_k*np.exp(-1.0j*self.x[0]*self.dk*np.arange(self.N))
        
    def get_kinetic_energy(self):
        self.ekint = sum(np.conj(self.psi_x)*ifft(self.k_x**2/(2.0*self.m)*fft(self.psi_x)))*self.dx
        return self.ekint
    
    def get_potential_energy(self):
        self.epot = sum(np.conj(self.psi_x)*self.V*self.psi_x)*self.dx
        return self.epot
    
    def get_norm(self):
        self.norm = sum(np.conj(self.psi_x)*self.psi_x)*self.dx
        return self.norm
             
    def evolution(self, Nsteps=1):
        for i in range(Nsteps):
            self._half_pot_prop(ft = False)
            self._full_kinetic_prop()
            self._half_pot_prop(ft = True)
        self._compute_psi()
        self.psi_mod_x /= np.sqrt((abs(self.psi_mod_x)**2).sum()*2*pi/self.dx);
        self.t += self.dt*Nsteps
        

In [5]:
######################################################################
# Helper functions for gaussian wave-packets
def gauss_x(x, a, x0, k0):
    """
    a gaussian wave packet of width a, centered at x0, with momentum k0
    """
    return ((a * np.sqrt(np.pi)) ** (-0.5)
            * np.exp(-0.5 * ((x - x0) * 1. / a) ** 2 + 1j * x * k0))


def gauss_k(k, a, x0, k0):
    """
    analytical fourier transform of gauss_x(x), above
    """
    return ((a / np.sqrt(np.pi)) ** 0.5
            * np.exp(-0.5 * (a * (k - k0)) ** 2 - 1j * (k - k0) * x0))


######################################################################
# Utility functions for running the animation
def theta(x):
    """
    theta function :
      returns 0 if x<=0, and 1 if x>0
    """
    x = np.asarray(x)
    y = np.zeros(x.shape)
    y[x > 0] = 1.0
    return y


def square_barrier(x, width, height):
    return height * (theta(x + 0.5*width) - theta(x - 0.5*width))

def parabola(x, a=1.0, x0 = 0.0):
    return a/1000.0*(x-x0)**2

def morse(x, D = 1.0, b = 0.03, x0 = -60.0):
    return D*(np.exp(-2.0*b*(x-x0))-2*np.exp(-b*(x-x0))) + D


In [6]:
######################################################################
# Create the animation

# specify time steps and duration
dt = 0.01
N_steps = 50
t_max = 120
frames = int(t_max / float(N_steps * abs(dt)))

# specify constants
hbar = 1.0   # planck's constant
m = 1.9      # particle mass

# specify range in x coordinate
N = 2 ** 11
dx = 0.1
x = dx * (np.arange(N) - 0.5 * N)

# specify potential
V0 = 1.5
L = hbar / np.sqrt(2 * m * V0)
a = 30 * L
x0 = -60 * L

# specify initial momentum and quantities derived from it
p0 = np.sqrt(2 * m * 0.2 * V0)
dp2 = p0 * p0 * 1. / 80
d = hbar / np.sqrt(2 * dp2)

k0 = p0 / hbar
v0 = p0 / m
psi_x0 = gauss_x(x, d, x0, k0)

In [21]:
style = {'description_width': 'initial'}

pot_select = Dropdown(
    options=['1. Box potential', '2. Morse potential', '3. Harmonic potential'],
    index = 0,
    description='Potential type:',
    disabled=False,
    style = style
)

layout_hidden  = widgets.Layout(visibility = 'hidden')
layout_visible = widgets.Layout(visibility = 'visible')

# Set the depth and wide of the box potential (pot1)
swide = FloatSlider(value = a, min = 0.0, max = 2*a, description = 'Width: ');
sheight = FloatSlider(value = V0, min = 0.0, max = 2*V0, description = 'Height: ');
swide.layout = layout_visible
sheight.layout = layout_visible

#Set the parameter for the Morse potential (pot2)
pot2_D  = FloatSlider(value = 1.0, min = 1.0, max = 5.0, description = 'D: ');
pot2_b  = FloatSlider(value = 0.03, min = 0.01, max = 0.10, step=0.01, description = 'b: ');
pot2_x0 = FloatSlider(value = -60.0, min = -100.0, max = 100.0, description = 'x0: ');

pot2_D.layout = layout_hidden
pot2_b.layout = layout_hidden
pot2_x0.layout = layout_hidden

#Set the parameter for the Harmonic potential (pot3)
pot3_a = FloatSlider(value = 1.0, min = 0.2, max = 2.0, description = 'a: ');
pot3_x0 = FloatSlider(value = 0.0, min = -100.0, max = 100.0, description = 'x0: ');
pot3_a.layout = layout_hidden
pot3_x0.layout = layout_hidden


#Show the potential image

file1 = open("images/pot1.png", "rb")
image1 = file1.read()
file2 = open("images/pot2.png", "rb")
image2 = file2.read()
file3 = open("images/pot3.png", "rb")
image3 = file3.read()

pot_img = widgets.Image(
    value=image1,
    format='png',
    width=300,
    height=400,
)

pot_img.layout = layout_visible

def pot_change(change):
    global V_x
    if pot_select.index == 0:
        swide.layout = layout_visible
        sheight.layout = layout_visible
        pot2_D.layout = layout_hidden
        pot2_b.layout = layout_hidden
        pot2_x0.layout = layout_hidden
        pot3_a.layout = layout_hidden
        pot3_x0.layout = layout_hidden
        pot_img.layout = layout_visible
        pot_img.value = image1
    elif pot_select.index == 1:
        swide.layout = layout_hidden
        sheight.layout = layout_hidden
        pot2_D.layout = layout_visible
        pot2_b.layout = layout_visible
        pot2_x0.layout = layout_visible
        pot3_a.layout = layout_hidden
        pot3_x0.layout = layout_hidden
        pot_img.layout = layout_visible
        pot_img.value = image2                
    elif pot_select.index == 2:
        swide.layout = layout_hidden
        sheight.layout = layout_hidden
        pot2_D.layout = layout_hidden
        pot2_b.layout = layout_hidden
        pot2_x0.layout = layout_hidden
        pot3_a.layout = layout_visible
        pot3_x0.layout = layout_visible
        pot_img.layout = layout_visible
        pot_img.value = image3  
        
pot_select.observe(pot_change, names='value', type='change');

def on_pot_update(event):
    global V_x
    if pot_select.index == 0:
        V_x = square_barrier(x, swide.value, sheight.value)
        V_x[x < -98] = 1E6
        V_x[x > 98] = 1E6
    elif pot_select.index == 1:
        V_x = morse(x, pot2_D.value, pot2_b.value, pot2_x0.value)
    elif pot_select.index == 2:
        V_x = parabola(x, pot3_a.value, pot3_x0.value)

pot_update = Button(description="Update potential")
pot_update.on_click(on_pot_update)

#display(pot_select, pot_img)
#display(HBox([swide, sheight]))
#display(HBox([pot2_D, pot2_b]), pot2_x0)
#display(HBox([pot3_a, pot3_x0]))
#display(pot_update)

pot_accordion = Accordion(children=[VBox([pot_select, pot_img,
                   HBox([swide, sheight]),
                   HBox([pot2_D, pot2_b]), pot2_x0,
                   HBox([pot3_a, pot3_x0]),
                   pot_update])], selected_index = None)

pot_accordion.set_title(0, "Select potential and set parameters")
display(pot_accordion)

Accordion(children=(VBox(children=(Dropdown(description='Potential type:', options=('1. Box potential', '2. Mo…

In [16]:
w_mass = FloatSlider(value = 1.0, min = 0.2, max=5.0, description="mass: ") 
w_dt = FloatSlider(value = 0.01, min = 0.01, max = 0.1, step=0.01, description="dt: ")

def on_mass_change(change):
    global m
    m = w_mass.value;
    
def on_dt_change(change):
    global dt
    dt = w_dt.value;

w_mass.observe(on_mass_change, names='value')
w_dt.observe(on_dt_change, names='value')

In [20]:
V_x = square_barrier(x, a, V0)
V_x[x < -98] = 1E6
V_x[x > 98] = 1E6

S = SOFT(x = x, dt = dt, psi_0=psi_x0, V=V_x, hbar=hbar, m=m)
S.evolution(1)

def on_init_change(b):
    global S
    psi_x0 = gauss_x(x, d, x0, k0)
    S = SOFT(x = x, dt = dt, psi_0=psi_x0, V=V_x, hbar=hbar, m=m)
    S.evolution(1)
    setup_plot()
    

w_init = Button(description="Update parameters");
w_init.on_click(on_init_change)

para_accordion = Accordion(children=[VBox([HBox([w_mass, w_dt]), w_init])], 
                           selected_index = None)
para_accordion.set_title(0, "Set simulation parameters")
display(para_accordion)

Accordion(children=(VBox(children=(HBox(children=(FloatSlider(value=1.0, description='mass: ', max=5.0, min=0.…

In [None]:
######################################################################
# Set up plot

fig = plt.figure(figsize=(7, 6))
fig.canvas.header_visible = False

# plotting limits
xlim = (-100, 100)
klim = (-5, 5)

# top axes show the x-space data
ymin = 0
ymax = V0
ax1 = fig.add_subplot(211, xlim=xlim,
                      ylim=(ymin - 0.2 * (ymax - ymin),
                            ymax + 0.2 * (ymax - ymin)))
psi_x_line, = ax1.plot([], [], c='r', label=r'$|\psi(x)|$', linewidth=1.2)
psi_x_real, = ax1.plot([], [], c='b', label=r'$\psi(x)_r$', linewidth=0.8)
psi_x_imag, = ax1.plot([], [], c='orange', label=r'$\psi(x)_i$', linewidth=0.8)
V_x_line, = ax1.plot([], [], c='k', label=r'$V(x)$')

psi_x_line.set_visible(True)
psi_x_real.set_visible(False)
psi_x_imag.set_visible(False)

title = ax1.set_title("")
ax1.legend(prop=dict(size=8), ncol=4, loc=1)
ax1.set_xlabel('$x$')
ax1.set_ylabel(r'$|\psi(x)|$')

# bottom axes show the k-space data
ymin = abs(S.psi_k).min()
ymax = abs(S.psi_k).max()
ax2 = fig.add_subplot(212, xlim=klim,
                      ylim=(ymin - 0.2 * (ymax - ymin),
                            ymax + 0.2 * (ymax - ymin)))
psi_k_line, = ax2.plot([], [], c='r', label=r'$|\psi(k)|$', linewidth=1.2)

p0_line1 = ax2.axvline(-p0 / hbar, c='k', ls=':', label=r'$\pm p_0$')
p0_line2 = ax2.axvline(p0 / hbar, c='k', ls=':')


ax2.legend(prop=dict(size=12))
ax2.set_xlabel('$p$')
ax2.set_ylabel(r'$|\psi(k)|$')

V_x_line.set_data(S.x, S.V)
psi_x_line.set_data(S.x, 4 * abs(S.psi_x))
V_x_line.set_data(S.x, S.V)


psi_k_line.set_data(S.k_x, abs(S.psi_k))
title.set_text("t = %-5i" % S.t)
plt.show()

######################################################################
# Functions to Animate the plot

pause = True

def init():
    psi_x_line.set_data([], [])
    psi_x_real.set_data([], [])
    psi_x_imag.set_data([], [])
    V_x_line.set_data([], [])
    psi_k_line.set_data([], [])
    title.set_text("")

    return (psi_x_line, V_x_line, psi_k_line, title)

def animate(i):
    global S
    if not pause:
        S.evolution(50)
        psi_x_line.set_data(S.x, 4 * abs(S.psi_x))
        psi_x_real.set_data(S.x, 4 * S.psi_x.real)
        psi_x_imag.set_data(S.x, 4 * S.psi_x.imag)
        V_x_line.set_data(S.x, S.V)

        psi_k_line.set_data(S.k_x, abs(S.psi_k))
        title.set_text("t = %-5i" % S.t)
        return (psi_x_line, V_x_line, psi_k_line, title)
    else:
        anim.event_source.stop()

def onClick(event):
    global pause
    
    pause ^= True
    if button_pause.description == "Pause":
        button_pause.description = "Play"
        anim.event_source.stop()
    else:
        button_pause.description = "Pause"
        anim.event_source.start()


# call the animator.
# blit=True means only re-draw the parts that have changed.
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=frames, interval=30, blit=True)

psi1 = widgets.Checkbox(value=True, description=r'$|\psi(x)|$',    disabled=False, layout=Layout(width='25%'))
psi2 = widgets.Checkbox(value=False, description=r'$\psi(x)_r$', disabled=False, layout=Layout(width='25%'))
psi3 = widgets.Checkbox(value=False, description=r'$\psi(x)_i$', disabled=False, layout=Layout(width='25%'))


def on_psi1_change(b):
    psi_x_line.set_visible(psi1.value)
    
def on_psi2_change(b):
    psi_x_real.set_visible(psi2.value)
    
def on_psi3_change(b):
    psi_x_imag.set_visible(psi3.value)

psi1.observe(on_psi1_change, names='value')
psi2.observe(on_psi2_change, names='value')
psi3.observe(on_psi3_change, names='value')

button_pause = Button(description="Play");
button_pause.on_click(onClick)

display(HBox([psi1, psi2, psi3]));
display(button_pause);