In [123]:
import numpy as np
import math as mt
import cmath
import numba
import matplotlib.pyplot as plt
import scipy

#############################################
#Define a tri-diagonal matrix solver function
#############################################

def solveTriDiagonalMatrix(a, b, c, rhs, Nx):
    
    #The y_vector we will return as the result
    y=np.zeros(Nx, dtype=complex)
    
    #step 1: c0 <-- c0/b0 and d0 <-- d0/b0
    c[0]=c[0]/b[0]
    rhs[0]=rhs[0]/b[0]
    
    #step 2: 
    for i in range(1, Nx-1):
        
        #ci <-- ci/(bi-ai*c_{i-1})
        c[i]=c[i]/(b[i]-a[i]*c[i-1])
        
    for i in range(1, Nx):
        #di <-- (di-ai*d_{i-1})/(bi-ai*c_{i-1})
        rhs[i]=(rhs[i]-a[i]*rhs[i-1])/(b[i]-a[i]*c[i-1])
            
    #step 3:
    
    #y_{Nx-1} <-- d_{Nx-1}
    y[Nx-1]=rhs[Nx-1]
    
    #yi <-- di-ci*y_{i+1}
    for i in range(Nx-2, -1, -1):
        y[i]=rhs[i]-c[i]*y[i+1]
    
    return y
    
    
#############################################################
#Define the necessary constants and auxillary data structures
#############################################################

#The constants
hbar=6.5821e-16
hbar_squared_over_2m=3.801
two_m_over_hbar=1/(hbar_squared_over_2m/hbar)
k0=2
s=20
L=1000
V0=3.90
wv=7
mult=10
Ttot=mult*1e-14

# #Mass of the electron
# m=0.5*hbar*hbar/KEfactor

#Number of time and space steps
numXsteps=10000
numTsteps=mult*100

#Space and time increments
dx=L/numXsteps
dt=Ttot/numTsteps

w=2*dx*dx/dt

#Defining the x vector holding the x positions
xvec=np.zeros([numXsteps])
for i in range(len(xvec)):
    xvec[i]=i*dx

#The a and c coefficients in our tridiagonal matrix
aj=1
cj=1

#The potential our wavefunction encounters
def potential(x):
    
#     return V0/(1+np.exp((0.5*L-x)/wv))
    
#     if x==1 or x==999:
#         return 10000
    
#     elif x>450:
#         return -100
#     else:
#         return 0

#     if x<550 and x>450:
#         return -5
#     else:
#         return 0

#     return 0.0001*(x-500)**2-2

    if 0.35*L<x<0.4*L:
        return 10
    else:
        return -0.25

#The initial wavefunction for t=0
def psi0(x):
    return np.exp(-((x-0.1*L)/s)*((x-0.1*L)/s)+x*k0*(1j))

#The b coefficients in our tridiagonal matrix
def bval(i):
    return (two_m_over_hbar)*w*complex(0,1)\
    -2-((two_m_over_hbar/hbar)*(dx*dx))*potential(i*dx) 

#Initializing the wavefunction at time t=0
y0=np.zeros(numXsteps, dtype=complex)

for i in range(0, len(y0)):
    y0[i]=psi0(i*dx)

#Finding the modulus squared of y0 the initial wavefunction
ySquared0=np.zeros([len(y0)])

for i in range(len(ySquared0)):
    ySquared0[i]=(abs(y0[i]))**2
 
#The vector of b values
bvec=np.zeros([len(y0)], dtype=complex)

for i in range(len(y0)):
    bvec[i]=bval(i)

#The vector of a values
avec=np.zeros([len(y0)], dtype=complex)

for i in range(1,len(y0)):
    avec[i]=aj

#The vector of c values
cvec=np.zeros([len(y0)-1], dtype=complex)

for i in range(len(y0)-1):
    cvec[i]=cj

##############################################################
#Repeatedly apply the matrix solver for t_total/dt time steps
##############################################################

t=0
index=0
yt=y0

d=np.zeros(len(y0), dtype=complex)

frames=[]

#Looping through the t-values
while(t<=Ttot):
    
    for j in range(len(y0)):
        
        #Updating the d vector
        constant=(two_m_over_hbar)*w*complex(0,1)\
        + 2 + ((two_m_over_hbar/hbar)*(dx*dx))*potential(j*dx)
        
        if (j!=0 and j!=(len(d)-1)):
            d[j] = -1*yt[j-1] + constant*yt[j] -1*yt[j+1]
        elif j==0:
            d[j] = constant*yt[j] -1*yt[j+1]
        elif (j==len(d)-1):
            d[j] = constant*yt[j] -1*yt[j-1]
    
    #Get yt, the solution after one time step
    yt=solveTriDiagonalMatrix(avec, bvec, cvec, d, len(yt))
    
    #Update the time
    index=index+1
    t=dt*index

    #Update the arrays
    avec=np.ones(len(yt), dtype=complex)
    avec[0]=0
    
    cvec=np.ones(len(yt)-1, dtype=complex)
    
    bvec=np.zeros(len(yt), dtype=complex)

    for i in range(len(yt)):
        bvec[i]=bval(i)
        
    frames.append(yt)
        
#########################################
#Plot the solutions at various time steps
#########################################

#Setting up the potential vector
potvec=[]

for i in range(len(xvec)):
    potvec.append(potential(xvec[i]))

#Finding the modulus squared of yt
ySquared=np.zeros([len(yt)])

for i in range(len(ySquared)):
    ySquared[i]=(abs(yt[i]))**2
    
#Calculating the integral of |ySquared0|^2:
integral1=0

for i in range(len(ySquared0)):
    
    integral1=integral1+ySquared0[i]*dx
    
print("integral of initial |y0|^2 is "+str(integral1))

#Calculating the integral of |yt|^2:
integral=0

for i in range(len(ySquared)):
    
    integral=integral+ySquared[i]*dx
    
print("integral of current |yt|^2 is "+str(integral))

#Plotting the above data for y0 at t=0
plt.plot(xvec, potvec, color='k', label='potential')
plt.plot(xvec, np.real(y0), color='b', label='real part')
#plt.plot(xvec, np.imag(y0), color='c', label='imaginary part')
plt.plot(xvec, ySquared0, color='r', label='modulus squared')
plt.xlabel("x-position in Angstroms")
plt.ylabel("Amplitude")
plt.xlim(200, 600)
plt.ylim(-1.1,1.5)
plt.title("Schrodinger Problem at t=0")
plt.legend()
plt.show()

#Plotting the above data for yt at t=tTot
plt.plot(xvec, potvec, color='k', label='potential')
plt.plot(xvec, np.real(yt), color='b', label='real part')
#plt.plot(xvec, np.imag(yt), color='c', label='imaginary part')
plt.plot(xvec, ySquared, color='r', label='modulus squared')
plt.xlabel("x-position in Angstroms")
plt.ylabel("Amplitude")
plt.xlim(0, 1000)
plt.ylim(-1.5,1.5)
plt.title("Schrodinger Problem at t="+str(Ttot))
plt.legend()
plt.show()

integral of initial |y0|^2 is 25.066282746309977
integral of current |yt|^2 is 25.066282746313945


In [126]:
#Animation cell
%matplotlib

for i in range(len(frames)):
    
    plt.plot(xvec, potvec, color='k', label='potential')
    plt.plot(xvec, np.real(frames[i]), color='b', label='real part')
    #plt.plot(xvec, np.imag(frames[i]), color='c', label='imaginary part')
    plt.plot(xvec, (abs(frames[i]))**2, color='r', label='Mod. Squared')
    #plt.xlim(200, 800)
    plt.ylim(-1.5,1.5)
    plt.legend()
    plt.title("1-D T.D.S.E. via Crank-Nicholson Method")
    plt.xlabel("X-Position")
    plt.ylabel("Amplitude")
    plt.show()
    plt.pause(0.001)
    
    if i<len(frames)-1:
        plt.clf()
    

Using matplotlib backend: MacOSX


In [117]:
#Now we do the 2D schrodinger equation
#Working in units where hbar=m=1

from numba import jit

import numpy as np
import math as mt
import matplotlib.pyplot as plt
import seaborn as sns

import time
start_time = time.time()

#Define the constants of the system

#The total length and width of the system
L=10

mult=40

numTsteps=250*mult
# numTsteps=1

#The spatial step
dz=0.1

#The time step
dt=0.0001

#The x wave vector
kx=100

#The y wave vector
ky=100

#Energy of the wavefunction
E=1

#The potential our wavefunction encounters
#Zero everywhere but "infinite" at boundaries
def V(i,j):
    
#     #Square in middle
#     if 0.45*L<i*dz<0.55*L and 0.45*L<j*dz<0.55*L:
#         return 1000
    
#     #Other space
#     else:
#         return 0

    #The double slit
    if i*dz==0.5*L and (0<=j*dz<=0.33*L or 0.43*L<=j*dz<=0.57*L or 0.67*L<j*dz<=L):
        return 1000
    else:
        return 0
    

#Define the initial form Y(x,y,t=0) of the Wavefunction
Y0=np.zeros((int(L/dz), int(L/dz)), dtype=complex)

for i in range(len(Y0)):
    for j in range(len(Y0)):
    
        Y0[i,j]=np.exp(-(i*dz-0.2*L)**2-(j*dz-0.5*L)**2+1j*i*dz*kx)
        #Y0[i,j]=np.exp(-(i*dz-0.5*L)**2-(j*dz-0.5*L)**2)

#Make sure the endpoints of the wavefunction are constant
#Update each point of the wavefunction after a step dt for a total time T
WavesOld=Y0

#The wave function after one time step dt is Y0*exp(-i*E*dt/hbar) if V(x) is time independent
#Waves=Y0*np.exp(-1j*E*dt)
Waves=np.zeros((int(L/dz), int(L/dz)), dtype=complex)

for i in range(len(Y0)):
    for j in range(len(Y0)):
    
        Waves[i,j]=np.exp(-(i*dz-0.2*L)**2/(1+2j*dt)-(j*dz-0.5*L)**2/(1+2j*dt)+1j*i*dz*kx)/(1+2j*dt)**0.5

#WavesNew we don't know yet
WavesNew=np.zeros((len(Y0), len(Y0)), dtype=complex)

# #Plotting the initial position of the wavefunction at t=0
# plt.figure()
# ax = sns.heatmap(np.transpose((abs(Y0))**2), vmin=0, vmax=1)
# ax.invert_yaxis()
# plt.xlabel("X")
# plt.ylabel("Y")
# plt.title("Waves at t=0")
# plt.show()

# #Plotting the position of the wavefunction at t=dt
# plt.figure()
# ax = sns.heatmap(np.transpose((abs(Waves))**2), vmin=0, vmax=1)
# ax.invert_yaxis()
# plt.xlabel("X")
# plt.ylabel("Y")
# plt.title("Waves at t=dt")
# plt.show()

t=0

frames=np.zeros((numTsteps+1,len(Y0),len(Y0)), dtype=complex)

frames[0]=Waves

index=1

#Keep evolving the Wavefunction until we reach time T
for t in range(numTsteps):
    
    #Calculating values for WavesNew
    for i in range(1,len(Y0)-1):
        for j in range(1,len(Y0)-1):
            
            WavesNew[i,j]=WavesOld[i,j]+(1j*dt/dz**2)*(Waves[i-1,j]+Waves[i,j-1]-4*Waves[i,j]+Waves[i+1,j]+Waves[i,j+1])-1j*2*dt*V(i,j)*Waves[i,j]
                        
            if np.array_equiv(Waves, WavesNew):
                print("WavesNew the same as Waves")
        
    #Overwrite WavesOld with Waves; and Waves with WavesNew
    for i in range(1,len(Y0)-1):
        for j in range(1,len(Y0)-1):
            
            WavesOld[i,j]=Waves[i,j]
            Waves[i,j]=WavesNew[i,j]
 
    frames[index]=Waves
    index=index+1

#  ##Plotting the final position of the wavefunction at t=T
# plt.figure()
# ax = sns.heatmap(np.transpose((abs(Waves))**2), vmin=0, vmax=1)
# ax.invert_yaxis()
# plt.xlabel("X")
# plt.ylabel("Y")
# plt.title("Waves at t=T")
# plt.show()

print("--- %s seconds ---" % (time.time() - start_time))

--- 2443.4992389678955 seconds ---


In [96]:
#Test if the wavefunctions in the various frames are all different from each other

print(len(frames))
for k in range(1,len(frames)):
    print(k)
    if np.array_equiv(frames[k-1], frames[k]):
        print("There's no difference for frame "+str(k))
print("They were all different!")

251
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
They were all different!


In [90]:
#Plot each frame in the frames list

t=0

for frame in frames:
    
    plt.figure()
    ax = sns.heatmap(np.transpose((abs(frame))**2), vmin=0, vmax=1)
    ax.invert_yaxis()
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title("Waves at t="+str(t))
    plt.vlines(x = 5*L, ymin = 0, ymax = 3.3*L, colors = 'white')
    plt.vlines(x = 5*L, ymin = 4.3*L, ymax = 5.7*L, colors = 'white')
    plt.vlines(x = 5*L, ymin = 6.7*L, ymax = 10*L, colors = 'white')
    plt.show()
    
    t=t+dt

In [79]:
for i in range(1,len(frames)):

        ax = sns.heatmap(np.transpose((abs(frames[i]))**2), vmin=0, vmax=1)
        ax.invert_yaxis()
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.title("Heatmap")

        plt.show()
        plt.pause(0.001)

Exception ignored in: <function TransformNode.set_children.<locals>.<lambda> at 0x7ff255b4c3a0>
Traceback (most recent call last):
  File "/Users/zaneblood/opt/anaconda3/lib/python3.8/site-packages/matplotlib/transforms.py", line 200, in <lambda>
    self, lambda _, pop=child._parents.pop, k=id(self): pop(k))
KeyboardInterrupt: 


KeyboardInterrupt: 

In [118]:
#Animation cell 
%matplotlib

truth=True

plt.clf()

while truth==True:

    for i in range(1,len(frames)):

        if i%100==0:
        
            ax = sns.heatmap(np.transpose((abs(frames[i]))**2), vmin=0, vmax=1)
            ax.invert_yaxis()
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.title("Heatmap")

            plt.show()
            plt.pause(0.001)
        
            if i<len(frames)-1:
                plt.clf()
        

    plt.clf()

Using matplotlib backend: MacOSX


KeyboardInterrupt: 

In [101]:
plt.clf()

for i in range(1,len(frames)):

        if i%100==0:
        
            ax = sns.heatmap(np.transpose((abs(frames[i]))**2), vmin=0, vmax=1)
            ax.invert_yaxis()
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.title("Heatmap")

            plt.show()
            plt.pause(1)
        
            if i<len(frames)-1:
                plt.clf()

In [103]:

ax = sns.heatmap(np.transpose((abs(frames[len(frames)-1]))**2), vmin=0, vmax=1)
ax.invert_yaxis()
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Heatmap")

plt.show()

import os
print('Number of CPUs in the system: {}'.format(os.cpu_count()))

Number of CPUs in the system: 8


In [None]:
import math
import numpy as np
from timebudget import timebudget
from multiprocessing import Pool

iterations_count = round(1e7)

def complex_operation:
    

@timebudget
def run_complex_operations(operation, input, pool):
    pool.map(operation, input)

processes_count = 10

if __name__ == '__main__':
    processes_pool = Pool(processes_count)
    run_complex_operations(complex_operation, range(10), processes_pool)