In [None]:
%matplotlib widget
import numpy as np
import time
import matplotlib
from matplotlib import animation, rc
import matplotlib.pyplot as plt

In [None]:
import numpy as np
import cmath
from scipy.sparse import diags, linalg

class Wavepacket:
    """Gaussian wave packet on a 1-D grid, moving through a periodic array of barriers.

    Time evolution uses a simplified Crank-Nicolson scheme.  With the tridiagonal
    matrix Q = (1/2)(1 + i dt H / (2 hbar)), each step solves Q chi = psi and sets
    psi <- chi - psi.  This equals (1 + i dt H/(2 hbar))^-1 (1 - i dt H/(2 hbar)) psi.
    Q only couples neighbouring grid points, so psi is implicitly zero just
    outside both ends of the grid.

    Parameters
    ----------
    N : int
        Number of grid points.  The default is 10x the author's earlier value,
        and so is the default for L.
    L : float
        Length of the system; x runs from 0 to L.
    dt : float
        Time step.
    gaussian : bool
        True for Gaussian-shaped barriers, False for rectangular (step) barriers.
    """

    def __init__(self, N=6000, L=1000., dt=0.05, gaussian=False, ):
        self.h_bar = 1.0             # Planck's constant / 2pi in natural units
        self.mass = 1.0              # particle mass in natural units

        # The spatial grid
        self.N = N                   # number of interior grid points
        self.L = L                   # system extends from x = 0 to x = L
        self.dx = L / float(N + 1)   # grid spacing
        self.dt = dt                 # time step

        # The potential V(x): barriers of height V_0 and width V_width, centred on
        # every multiple of `dist` from 0 to L inclusive.
        self.V_0 = 0.5               # height of potential barrier
        self.V_width = 20.0          # width of potential barrier
        self.dist = 50               # spacing between barrier centres
        self.V_center = list(range(0, int(L) + 1, self.dist))

        # Honour the constructor argument (it used to be hard-coded to False).
        self.gaussian = gaussian     # True = Gaussian potential, False = step potential

        # Initial wave packet
        self.x_0 = L / 4.0           # location of center
        self.E = 50                  # average energy
        self.sigma_0 = L / 10.0      # initial width of wave packet
        self.t = 0.0                 # time
        # NOTE(review): points are j*dx for j = 0..N-1, so the first one sits at x = 0
        # rather than at dx.  The Animator's index-based barrier markers assume x = j*dx.
        self.x = np.arange(self.N) * self.dx  # vector of grid points

        # Wavenumber chosen so the packet's mean kinetic energy equals E, given the
        # hbar^2 / (2 sigma^2) momentum spread of the Gaussian envelope.
        self.k_0 = np.sqrt(2 * self.mass * self.E - self.h_bar**2 / 2 / self.sigma_0**2) / self.h_bar
        # The packet is built with exp(-i k_0 x), so it travels toward -x.
        self.velocity = -self.h_bar * self.k_0 / self.mass
        # Makes the integral of |psi|^2 over x equal 1 for the envelope below.
        self.psi_norm = 1 / np.sqrt(self.sigma_0 * np.sqrt(np.pi))
        exp_factor = np.exp(-(self.x - self.x_0)**2 / (2 * self.sigma_0**2))
        self.psi = np.exp(-1j * self.k_0 * self.x) * exp_factor * self.psi_norm  # complex wavefunction
        # Kept for backward compatibility; step_psi does not use it.
        self.chi = np.zeros(N, dtype=complex)

        # Elements of the tridiagonal matrix Q = (1/2)(1 + i dt H / (2 hbar)),
        # with H discretised by the 3-point second-difference stencil.
        off_diag = -1j * self.dt * self.h_bar / (8 * self.mass * self.dx**2)
        self.a = np.full(self.N - 1, off_diag)  # sub-diagonal
        self.b = 0.5 + 1j * self.dt / (4 * self.h_bar) * (self.V(self.x) + self.h_bar**2 / (self.mass * self.dx**2))
        self.c = np.full(self.N - 1, off_diag)  # super-diagonal
        # Cached LU factorisation of Q, built on the first step_psi call.
        self._lu = None

    def V(self, x):
        """Return the potential energy at position(s) x.

        Step mode: V_0 wherever x lies within V_width/2 of a multiple of `dist`,
        otherwise 0.  Gaussian mode: a sum of Gaussians of height V_0 and standard
        deviation V_width/2, one centred on each entry of V_center.
        """
        x = np.asarray(x, dtype=float)
        half_width = np.absolute(0.5 * self.V_width)
        if self.gaussian:
            centers = np.asarray(self.V_center, dtype=float)
            # The extra trailing axis broadcasts every x against every barrier centre.
            bumps = np.exp(-(x[..., np.newaxis] - centers)**2 / (2 * half_width**2))
            return self.V_0 * bumps.sum(axis=-1)
        # x % dist <= half_width: x lies in the right half of a barrier.
        # (x + half_width) % dist <= half_width: x lies in the left half of a barrier.
        in_barrier = (x % self.dist <= half_width) | ((x + half_width) % self.dist <= half_width)
        return np.where(in_barrier, self.V_0, 0.0)

    def step_psi(self):
        """Advance psi by one time step dt, store it in self.psi and return it.

        Q does not change between steps, so it is LU-factorised once and reused.
        If a, b or c are modified after the first step, set `self._lu = None`
        so the factorisation is rebuilt.
        """
        lu = getattr(self, "_lu", None)
        if lu is None:
            T = diags([self.b, self.a, self.c], [0, -1, 1], format='csc')
            lu = linalg.splu(T)
            self._lu = lu
        chi = lu.solve(self.psi)
        self.psi = chi - self.psi
        return self.psi


In [None]:

class Animator:
    """Live matplotlib animation of |psi| for a Wavepacket.

    Barrier edges are drawn as red vertical lines.  The curve is plotted against
    grid index rather than physical x, so barrier positions are converted to
    index units by dividing by the grid spacing dx.
    """

    def __init__(self, wavepacket=None):
        """Create the figure, the barrier markers and the initial |psi| curve.

        Parameters
        ----------
        wavepacket : Wavepacket, optional
            Simulation to animate.  If omitted, a default ``Wavepacket()`` is
            created; previously ``None`` crashed on the first attribute access.
        """
        if wavepacket is None:
            wavepacket = Wavepacket()
        self.avg_times = []          # not used in this class; kept for existing callers
        self.wavepacket = wavepacket
        self.t = 0.                  # simulated time advanced by time_step
        self.fig, self.ax = plt.subplots()

        # Mark both edges of every barrier on this figure's Axes.  Use the Axes
        # explicitly rather than the pyplot "current axes" state machine.
        half_width = 0.5 * wavepacket.V_width
        for center in wavepacket.V_center:
            for edge in (center - half_width, center + half_width):
                self.myline = self.ax.axvline(x=edge / wavepacket.dx, color="r")

        self.ax.set_ylim(0, 0.5)
        self.ax.set_xlabel("grid index")
        self.ax.set_ylabel("|psi|")
        self.ax.set_title("Wave packet magnitude |psi|")
        self.line, = self.ax.plot(np.absolute(wavepacket.psi))

    def update(self, data):
        """FuncAnimation draw callback: replace the curve's y data with `data`."""
        self.line.set_ydata(data)
        return self.line,

    def time_step(self):
        """Infinite frame generator: advance the packet one dt per frame and yield |psi|.

        It must never be exhausted.  A generator that yields only once makes
        FuncAnimation restart the sequence on every other tick.  That restart
        runs an extra hidden physics step and alternates the frame interval.
        """
        while True:
            self.wavepacket.psi = self.wavepacket.step_psi()
            self.t += self.wavepacket.dt
            yield np.absolute(self.wavepacket.psi)

    def animate(self):
        """Start the animation.

        The FuncAnimation is stored on self.ani; matplotlib stops an animation
        that is no longer referenced.
        """
        self.ani = animation.FuncAnimation(self.fig,        # Animate our figure
                                           self.update,     # Update function draws our data
                                           self.time_step,  # "frames" generator does one time step per frame
                                           interval=50,     # 50 ms between iterations
                                           blit=False,      # don't blit anything
                                           cache_frame_data=False  # infinite generator: don't cache frames
                                           )

In [None]:
# Build the default simulation and start the live animation.  `animator` must stay
# referenced at module level: it holds the FuncAnimation, which stops if collected.
wavepacket = Wavepacket()
animator = Animator(wavepacket)
animator.animate()