# **Numerical Solution of 1D Time Dependent Schrödinger Equation by Split Operator Fourier Transform (SOFT) Method**

[Go back to index](./index.ipynb)

<hr style="height:1px;border:none;color:#cccccc;background-color:#cccccc;" />

<p style="text-align: justify;font-size:15px">  
    In this exercise we will see a <strong>rare</strong> example of an algorithm that solves the 
    quantum dynamical problem exactly. Let's consider a time-independent Hamiltonian and its 
    associated time-dependent Schrödinger equation for a system of one particle in one dimension:
</p>

$i\hbar\frac{d}{dt}|\psi\rangle = \hat{H}|\psi\rangle \quad \text{where} \quad \hat{H} = \frac{\hat{P}^2}{2m} + E(\hat{X})$

<hr style="height:1px;border:none;color:#cccccc;background-color:#cccccc;" />

In [10]:
import numpy as np
from math import pi
from scipy.fftpack import fft, ifft
from ipywidgets import Button, FloatSlider, HBox, VBox

In [11]:
class SOFT(object):
    """
    Split Operator Fourier Transform (SOFT) propagator for the 1D
    time-dependent Schrödinger equation.

    Parameters
    ----------
    x : array of N equally spaced grid points (spacing taken from x[1] - x[0])
    psi_0 : initial wavefunction sampled on ``x`` (real or complex).
        It is copied, so the caller's array is never modified.
    V : the potential, either an array sampled on ``x`` or a callable
        that is evaluated once as ``V(x)``
    dt : the time step
    hbar : Planck constant used in both propagators (default 1.0)
    m : particle mass (default 1.0)

    Attributes read by callers
    --------------------------
    x, k_x : position grid and the matching wavenumber grid
    V : the sampled potential
    psi_x, psi_k : wavefunction in position / wavenumber space; ``psi_k``
        stays zero until ``evolution`` has been called at least once
    t : elapsed simulation time
    """

    def __init__(self, x, psi_0, V, dt, hbar=1.0, m=1.0):
        self.N = len(x)
        self.hbar = hbar
        self.m = m

        dx = x[1] - x[0]
        dk = (2 * pi) / (self.N * dx)

        self.x = x
        self.dx = dx
        self.dk = dk
        self.t = 0.0
        self.dt = dt

        # Wavenumber grid starting at -N*dk/2 and increasing in steps of dk.
        self.k_x = -0.5 * self.N * self.dk + dk * np.arange(self.N)
        self.psi_k = np.zeros((self.N), dtype=np.complex128)
        self.psi_mod_x = np.zeros((self.N), dtype=np.complex128)
        self.psi_mod_k = np.zeros((self.N), dtype=np.complex128)

        # Work on a private complex copy: _periodic_bc and _compute_psi write
        # into psi_x in place, which must not leak back into the caller's
        # psi_0 (it may be reused to build another solver).
        self.psi_x = np.array(psi_0, dtype=np.complex128)
        self._periodic_bc()

        if callable(V):
            self.V = V(x)
        else:
            self.V = V

        # Scaled and phase-shifted wavefunction: multiplying by
        # exp(-i*k_x[0]*x) makes the plain FFT output line up with the
        # shifted wavenumber grid k_x.
        self.psi_mod_x[:] = (self.dx / np.sqrt(2 * pi) * self.psi_x
                             * np.exp(-1.0j * self.k_x[0] * self.x))

    def _periodic_bc(self):
        """Copy the first grid value onto the last grid point."""
        self.psi_x[-1] = self.psi_x[0]

    def _half_pot_prop(self, ft=True):
        """
        Apply half a time step of the potential propagator
        exp(-i*V*dt/(2*hbar)) in position space.

        ft : when True, first transform psi_mod_k back to position space.
        """
        if ft:
            self.psi_mod_x[:] = ifft(self.psi_mod_k[:])
        self.psi_mod_x[:] = self.psi_mod_x[:] * np.exp(
            -1.0j * (self.dt / (2.0 * self.hbar)) * self.V)

    def _full_kinetic_prop(self):
        """
        Transform to wavenumber space and apply a full time step of the
        kinetic propagator exp(-i*hbar*k^2*dt/(2*m)).
        """
        self.psi_mod_k[:] = fft(self.psi_mod_x[:])
        self.psi_mod_k[:] = self.psi_mod_k[:] * np.exp(
            -1.0j * self.hbar * self.k_x ** 2 * self.dt / (2.0 * self.m))

    def _compute_psi(self):
        """Undo the scaling/phase shift to refresh psi_x and psi_k."""
        self.psi_x[:] = ((np.sqrt(2 * pi) / self.dx) * self.psi_mod_x
                         * np.exp(1.0j * self.k_x[0] * self.x))
        self.psi_k[:] = self.psi_mod_k * np.exp(
            -1.0j * self.x[0] * self.dk * np.arange(self.N))

    def evolution(self, Nsteps=1):
        """
        Advance the wavefunction by ``Nsteps`` time steps of size ``dt``.

        Each step is half potential / full kinetic / half potential.
        Afterwards psi_x and psi_k are refreshed, the internal state is
        renormalised and ``t`` is advanced by ``dt * Nsteps``.
        """
        for _ in range(Nsteps):
            self._half_pot_prop(ft=False)
            self._full_kinetic_prop()
            self._half_pot_prop(ft=True)
        self._compute_psi()
        # With psi_mod_x = dx/sqrt(2*pi) * psi_x (up to a phase), this is
        # sqrt(dx * sum |psi_x|^2), i.e. the grid norm of the wavefunction.
        norm = np.sqrt((abs(self.psi_mod_x) ** 2).sum() * 2 * pi / self.dx)
        if norm > 0:  # leave an identically-zero state untouched
            self.psi_mod_x /= norm
        self.t += self.dt * Nsteps
        

In [12]:
%matplotlib widget

import matplotlib.pyplot as plt
from matplotlib import animation



######################################################################
# Helper functions for gaussian wave-packets
def gauss_x(x, a, x0, k0):
    """
    Gaussian wave packet in position space.

    x  : position(s) at which to evaluate the packet
    a  : width of the packet
    x0 : centre of the packet
    k0 : momentum (wavenumber) carried by the packet
    """
    prefactor = (a * np.sqrt(np.pi)) ** (-0.5)
    scaled_offset = (x - x0) * 1. / a
    exponent = -0.5 * scaled_offset ** 2 + 1j * x * k0
    return prefactor * np.exp(exponent)


def gauss_k(k, a, x0, k0):
    """
    Analytical Fourier transform of the ``gauss_x`` wave packet.

    k  : wavenumber(s) at which to evaluate the transform
    a  : width of the position-space packet
    x0 : centre of the position-space packet
    k0 : momentum (wavenumber) carried by the packet
    """
    prefactor = (a / np.sqrt(np.pi)) ** 0.5
    dk = k - k0
    exponent = -0.5 * (a * dk) ** 2 - 1j * dk * x0
    return prefactor * np.exp(exponent)


######################################################################
# Utility functions for running the animation
def theta(x):
    """
    Step function, evaluated element-wise:
      0.0 where x <= 0, and 1.0 where x > 0
    The result is a float array with the same shape as ``x``.
    """
    positive = np.asarray(x) > 0
    return np.where(positive, 1.0, 0.0)


def square_barrier(x, width, height):
    """
    Rectangular barrier centred on x = 0: equal to ``height`` where
    -width/2 < x <= width/2 and zero elsewhere.
    """
    half_width = 0.5 * width
    rising_edge = theta(x + half_width)
    falling_edge = theta(x - half_width)
    return height * (rising_edge - falling_edge)

def parabola(x):
    """Parabolic potential 0.001 * x**2 evaluated at ``x``."""
    curvature = 0.001
    return curvature * x ** 2

######################################################################
# Create the animation

# Time discretisation: each animation frame advances the solver by N_steps
# steps of size dt, and `frames` is chosen so that all frames together cover
# a simulated time of t_max.
dt = 0.01
N_steps = 50
t_max = 120
frames = int(t_max / float(N_steps * abs(dt)))

# specify constants
hbar = 1.0   # Planck constant hbar (set to 1 here)
m = 1.9      # particle mass

# Position grid: N points with spacing dx, running from -0.5*N*dx upwards
N = 2 ** 11
dx = 0.1
x = dx * (np.arange(N) - 0.5 * N)

# Barrier parameters: height V0, and a length scale L = hbar / sqrt(2*m*V0)
# used to size the barrier width (a) and the initial packet position (x0)
V0 = 1.5
L = hbar / np.sqrt(2 * m * V0)
a = 30 * L
x0 = -60 * L

# Sliders for the width and height of the square barrier; each starts at the
# value used below and ranges from 0 to twice that value
swide = FloatSlider(value = a, min = 0.0, max = 2*a, description = 'Wide: ');
sheight = FloatSlider(value = V0, min = 0.0, max = 2*V0, description = 'Height: ');


# Square barrier of width a and height V0 centred on x = 0
V_x = square_barrier(x, a, V0)
# Very large potential for |x| > 98, i.e. at both ends of the grid
# (presumably hard walls that keep the packet inside the plotted range)
V_x[x < -98] = 1E6
V_x[x > 98] = 1E6

#V_x = parabola(x)
# Initial momentum p0 corresponds to a kinetic energy of 0.2 * V0; the packet
# width d is derived from the momentum spread dp2 = p0**2 / 80
p0 = np.sqrt(2 * m * 0.2 * V0)
dp2 = p0 * p0 * 1. / 80
d = hbar / np.sqrt(2 * dp2)

k0 = p0 / hbar   # initial wavenumber
v0 = p0 / m      # initial velocity
psi_x0 = gauss_x(x, d, x0, k0)   # initial Gaussian packet centred at x0

# define the SOFT solver object which performs the calculations
S = SOFT(x = x, dt = dt, psi_0=psi_x0, V=V_x, hbar=hbar, m=m)
# Take one step right away: S.psi_k is all zeros until evolution() has run,
# and the plot setup reads abs(S.psi_k) to choose the k-space axis limits
S.evolution(1)
######################################################################
# Set up plot
# One figure with two stacked axes: position space on top, k space below
fig = plt.figure(figsize=(7, 7))
# Hide the canvas header (assumes the widget backend selected above exposes
# this attribute)
fig.canvas.header_visible = False

# plotting limits
xlim = (-100, 100)
klim = (-5, 5)

# top axes show the x-space data; the y-range spans the barrier height V0
# with a 20% margin below and above
ymin = 0
ymax = V0
ax1 = fig.add_subplot(211, xlim=xlim,
                      ylim=(ymin - 0.2 * (ymax - ymin),
                            ymax + 0.2 * (ymax - ymin)))
# Empty line artists; animate() fills them on every frame (the wavefunction
# line is drawn there as 4 * |psi(x)|)
psi_x_line, = ax1.plot([], [], c='r', label=r'$|\psi(x)|$', linewidth=1.2)
V_x_line, = ax1.plot([], [], c='k', label=r'$V(x)$')
# Vertical marker that animate() moves to x0 + t * p0 / m
center_line = ax1.axvline(0, c='k', ls=':', label=r"$x_0 + v_0t$")

title = ax1.set_title("")
ax1.legend(prop=dict(size=12))
ax1.set_xlabel('$x$')
ax1.set_ylabel(r'$|\psi(x)|$')

# bottom axes show the k-space data; the y-range is taken from the current
# |psi(k)| of the solver, again with a 20% margin
ymin = abs(S.psi_k).min()
ymax = abs(S.psi_k).max()
ax2 = fig.add_subplot(212, xlim=klim,
                      ylim=(ymin - 0.2 * (ymax - ymin),
                            ymax + 0.2 * (ymax - ymin)))
psi_k_line, = ax2.plot([], [], c='r', label=r'$|\psi(k)|$', linewidth=1.2)

# Static vertical markers at k = -p0/hbar and k = +p0/hbar
p0_line1 = ax2.axvline(-p0 / hbar, c='k', ls=':', label=r'$\pm p_0$')
p0_line2 = ax2.axvline(p0 / hbar, c='k', ls=':')
#mV_line = ax2.axvline(np.sqrt(2 * V0) / hbar, c='k', ls='--',
#                      label=r'$\sqrt{2mV_0}$')
ax2.legend(prop=dict(size=12))
ax2.set_xlabel('$k$')
ax2.set_ylabel(r'$|\psi(k)|$')

# Draw the potential once so it is visible before the animation starts
V_x_line.set_data(S.x, S.V)


######################################################################
# Functions to Animate the plot
def init():
    """
    Initialisation callback for the animation: empty every line artist,
    clear the title and return all artists that will be redrawn.
    """
    line_artists = (psi_x_line, V_x_line, center_line, psi_k_line)
    for artist in line_artists:
        artist.set_data([], [])
    title.set_text("")
    return line_artists + (title,)

# The animation starts paused; onClick() toggles this flag.
pause = True

def animate(i):
    """
    Per-frame callback for the animation.

    i : frame index supplied by the animator (not used).

    While not paused, the solver S is advanced by N_steps time steps, the
    same step count that `frames` was derived from. The artists are then
    refreshed from the current solver state and returned for redrawing.
    """
    if not pause:
        S.evolution(N_steps)
    # The wavefunction modulus is scaled by 4 for display only.
    psi_x_line.set_data(S.x, 4 * abs(S.psi_x))
    V_x_line.set_data(S.x, S.V)
    # Vertical marker at x0 + t * p0 / m (position of a free particle
    # moving with the initial velocity).
    center_line.set_data(2 * [x0 + S.t * p0 / m], [0, 1])

    psi_k_line.set_data(S.k_x, abs(S.psi_k))
    title.set_text("t = %i" % S.t)
    return (psi_x_line, V_x_line, center_line, psi_k_line, title)

def onClick(event):
    """
    Click handler of the play/pause button: invert the `pause` flag and
    swap the button label between "Play" and "Pause".
    """
    global pause
    pause = not pause
    currently_shows_pause = button_pause.description == "Pause"
    button_pause.description = "Play" if currently_shows_pause else "Pause"
        
def on_update_change(event):
    """
    Click handler of the "Update potential" button.

    Rebuilds the barrier from the current slider values, replaces the
    solver S with a fresh one at t = 0 and resumes the animation.
    """
    global pause, S, V_x
    # Stop animate() from stepping the solver while it is being replaced.
    pause = True
    V_x = square_barrier(x, swide.value, sheight.value)
    V_x[x < -98] = 1E6
    V_x[x > 98] = 1E6
    # Rebuild the initial packet instead of reusing psi_x0, so the restart
    # never depends on whether an earlier solver wrote into that array.
    S = SOFT(x = x, dt = dt, psi_0=gauss_x(x, d, x0, k0), V=V_x, hbar=hbar, m=m)
    # One step so that S.psi_k holds data before the next frame is drawn.
    S.evolution(1)
    button_pause.description = "Pause"
    pause = False
    
# call the animator: animate() is invoked once per frame, init() sets up
# the empty artists.
# blit=True means only re-draw the parts that have changed.
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=frames, interval=30, blit=True)


# uncomment the following line to save the video in mp4 format.  This
# requires either mencoder or ffmpeg to be installed on your system
#anim.save('schrodinger_barrier.mp4', fps=15,
#          extra_args=['-vcodec', 'libx264'])

# Buttons: one rebuilds the solver from the slider values, the other
# toggles play/pause (its label starts as "Play" because pause starts True)
button_update = Button(description="Update potential")
button_pause = Button(description="Play")
button_pause.on_click(onClick)
button_update.on_click(on_update_change)

# `display` is not imported in the visible code; presumably it is the one
# provided by the IPython/Jupyter environment
display(HBox([swide, sheight]), button_update);
plt.show()

display(button_pause)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(FloatSlider(value=12.565617248750865, description='Wide: ', max=25.13123449750173), FloatSlider…

Button(description='Update potential', style=ButtonStyle())

Button(description='Play', style=ButtonStyle())