In [1]:
import math

import numpy as np
import plotly.express as px
from scipy.special import hermite

Step 1: Take an arbitrary set of expansion coefficients over the harmonic-oscillator eigenstates and plot the resulting wavefunction.

Step 2 (open question): extend this to 2D, or add time evolution?


In [2]:
def qho_n(n, x, m, h_bar, omega):
    """Evaluate the n-th energy eigenfunction of the 1D quantum harmonic oscillator.

    psi_n(x) = (m*omega / (pi*h_bar))**(1/4) / sqrt(2**n * n!) * H_n(xi) * exp(-xi**2 / 2)
    with xi = sqrt(m*omega / h_bar) * x and H_n the Hermite polynomial returned by
    ``scipy.special.hermite`` (physicists' convention).

    Parameters
    ----------
    n : int
        Quantum number; must be >= 0 (``math.factorial`` raises ValueError otherwise).
    x : float or ndarray
        Position(s) at which to evaluate the eigenfunction.
    m, h_bar, omega : float
        Mass, reduced Planck constant and angular frequency.

    Returns
    -------
    float or ndarray
        psi_n evaluated at ``x``.
    """
    xi = np.sqrt(m * omega / h_bar) * x
    # np.math was only an alias for the stdlib math module; it is deprecated in
    # NumPy 1.25 and removed in NumPy 2.0, so call math directly.
    A_n = ((m * omega) / (np.pi * h_bar)) ** 0.25 / math.sqrt(math.factorial(n) * 2**n)
    return A_n * hermite(n)(xi) * np.exp(-np.square(xi) / 2)

class WaveFunction:
    """A 1D harmonic-oscillator wavefunction built as a superposition of eigenstates.

    psi(x) = sum_n c_n * psi_n(x), where psi_n comes from ``qho_n`` and the
    coefficients c_n are rescaled so that sum_n |c_n|**2 == 1.

    Parameters
    ----------
    coeff : dict[int, float] or list/tuple/ndarray of float
        Expansion coefficients. A dict maps quantum number n -> c_n; a sequence
        gives c_n at position n.
    x_min, x_max : float
        Extent of the spatial grid.
    resolution : int
        Number of grid points passed to ``np.linspace``; assumed >= 2.
    m, omega, h_bar : float
        Mass, angular frequency and reduced Planck constant.

    Attributes
    ----------
    coeffs : ndarray
        Dense, normalized coefficient vector indexed by quantum number.
    states : list[int]
        Quantum numbers in the superposition: the dict keys in insertion order,
        or the indices of the non-zero entries when a sequence is given.
    x : ndarray
        Spatial grid.
    psi : ndarray
        Wavefunction sampled on ``x``.
    """

    def __init__(self, coeff,
                 x_min=-20,
                 x_max=20,
                 resolution=10000,
                 m=1,
                 omega=1,
                 h_bar=1):
        self.coeffs = self.normalize(coeff)
        self.x_max, self.x_min, self.resolution = x_max, x_min, resolution
        self.x = np.linspace(x_min, x_max, resolution)

        if isinstance(coeff, dict):
            self.states = list(coeff)
        else:
            # Iterating a sequence yields values, not indices; use the indices instead.
            self.states = [int(n) for n in np.flatnonzero(self.coeffs)]

        # Build psi from the *normalized* coefficients so psi itself has unit norm.
        self.psi = np.zeros(resolution)
        for n in self.states:
            self.psi += self.coeffs[n] * qho_n(n, self.x, m, h_bar, omega)

    def normalize(self, coeffs):
        """Return ``coeffs`` as a dense float ndarray rescaled to unit L2 norm.

        Parameters
        ----------
        coeffs : dict[int, float] or list/tuple/ndarray of float
            A dict maps quantum number n -> c_n (missing n are treated as 0);
            a sequence gives c_n at position n.

        Returns
        -------
        ndarray
            Array c with c[n] the coefficient of state n and sum(c**2) == 1.

        Raises
        ------
        TypeError
            If ``coeffs`` is not a dict, list, tuple or ndarray.
        ValueError
            If a dict key is negative, the input is empty, or all coefficients are zero.
        """
        if isinstance(coeffs, dict):
            if not coeffs:
                raise ValueError("coeffs must not be empty")
            if min(coeffs) < 0:
                # A negative key would silently index from the end of the dense array.
                raise ValueError("quantum numbers (dict keys) must be non-negative")
            dense = np.zeros(max(coeffs) + 1)
            for n, c in coeffs.items():
                dense[n] = c
        elif isinstance(coeffs, (list, tuple, np.ndarray)):
            dense = np.asarray(coeffs, dtype=float)
        else:
            raise TypeError(
                f"coeffs must be a dict or a sequence, got {type(coeffs).__name__}")

        norm = np.linalg.norm(dense)
        if norm == 0:
            raise ValueError("at least one coefficient must be non-zero")
        # Divide by the L2 norm (sqrt of the sum of squares), not the sum of squares.
        return dense / norm

    def is_normalized(self, tol=1e-4):
        """Return True if the Riemann sum of |psi|**2 over the grid is within ``tol`` of 1."""
        # np.linspace(x_min, x_max, resolution) has resolution - 1 intervals.
        dx = (self.x_max - self.x_min) / (self.resolution - 1)
        return np.abs(np.sum(np.abs(self.psi) ** 2) * dx - 1) < tol


class Plotter:
    """Plot a wavefunction's amplitude and probability density with plotly express.

    The plotted object is expected to expose ``x``, ``psi`` and ``states``
    (as ``WaveFunction`` does).

    Parameters
    ----------
    figsize : sequence of two numbers, optional
        Stored on the instance as a list. NOTE(review): not currently applied
        to the plotly figures.
    """

    def __init__(self, figsize=(20, 6)):
        # Copy into a fresh list so instances never share one mutable default object.
        self.figsize = list(figsize)

    @staticmethod
    def _image_stem(wf):
        """Return the output path prefix, e.g. './qho_state_0' or './qho_state_0_2'."""
        # Include every state so superpositions don't overwrite single-state images.
        return "./qho_state_" + "_".join(str(n) for n in wf.states)

    @staticmethod
    def _show(fig, path, save_image):
        """Display ``fig`` and, if ``save_image`` is truthy, write it to ``path``."""
        fig.show()
        if save_image:
            fig.write_image(path)

    def plotWaveFunction(self, wf, save_image=False):
        """Plot psi(x); optionally save it as '<stem>.jpeg'."""
        fig = px.line(x=wf.x, y=wf.psi,
                      labels={"x": "x", "y": "ψ(x)"},
                      title=f"QHO wavefunction, states {wf.states}")
        self._show(fig, f"{self._image_stem(wf)}.jpeg", save_image)

    def plotProbability(self, wf, save_image=False):
        """Plot |psi(x)|^2; optionally save it as '<stem>_prob.jpeg'."""
        fig = px.line(x=wf.x, y=np.abs(wf.psi) ** 2,
                      labels={"x": "x", "y": "|ψ(x)|²"},
                      title=f"QHO probability density, states {wf.states}")
        self._show(fig, f"{self._image_stem(wf)}_prob.jpeg", save_image)


In [3]:
# Plot the first five eigenstates and their probability densities, saving both images.
# A single Plotter suffices: it holds no per-plot state.
plotter = Plotter()
for n in range(5):
    wf = WaveFunction({n: 1})
    plotter.plotWaveFunction(wf, save_image=True)
    plotter.plotProbability(wf, save_image=True)

In [4]:
# Sanity check: each of the first five eigenstates integrates to 1 on the grid.
# Rebuilt here so the result doesn't depend on loop variables from the previous cell.
all(WaveFunction({n: 1}).is_normalized() for n in range(5))

True