# SimOpt: quantum wave packet

In [None]:
import csv, math, os, shutil, sys
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FuncAnimation

# !pip install celluloid
from celluloid import Camera
from IPython.display import HTML
from base64 import b64encode

In [None]:
# Simulation constants (natural units: hbar = m = 1).
constants = dict(hbar=1, m=1, numdx=10000, numdt=500, xrange=5, trange=1)

dx = constants['xrange'] / constants['numdx']   # grid spacing
dt = constants['trange'] / constants['numdt']   # time step
# Phase coefficients of the short-time propagator
#   exp(i * m * (x - x')**2 / (2 * hbar * dt) - i * V * dt / hbar).
# Both carry a single power of hbar; the previous hbar**2 was dimensionally
# inconsistent (it only went unnoticed because hbar == 1 here).
alpha = constants['m'] * (dx ** 2 / (2 * constants['hbar'] * dt))  # kinetic phase per (index difference)**2
beta = dt / constants['hbar']                                       # potential phase per unit of potential
x = np.arange(constants['numdx']) * dx   # positions 0 .. xrange - dx
t = np.arange(constants['numdt']) * dt   # times 0 .. trange - dt


def V(x):
    """Return a linear potential ramp from +1 down to -1, one sample per entry of ``x``.

    Only the number of samples in ``x`` matters, not its values (the kernel
    construction passes grid indices). Previously the argument was ignored and
    the length was read from the global ``constants['numdx']``; for that caller
    the result is unchanged.
    """
    return np.linspace(1, -1, np.size(x))


# One-step propagator on the index grid, applied as q_next = q @ K.
# With i, j grid indices:
#   ST[i, j] = alpha * (i - j)**2        (free-particle phase)
#   SV[i, j] = beta * (V[i] - V[j])
#   K[i, j]  = exp(1j * (ST[i, j] - SV[i, j]))
# The outer differences are built by broadcasting. The previous loop performed
# the same arithmetic one column/row at a time in Python and re-evaluated V on
# every pass; the results are identical.
temp = np.arange(constants['numdx'])
grid = temp.astype(np.float64)
potential = V(temp)  # evaluated once

ST = grid[:, None] - grid[None, :]
ST *= alpha * ST
SV = potential[:, None] - potential[None, :]
SV *= beta
# NOTE(review): the usual short-time propagator phase uses the potential at the
# endpoints (e.g. (V[i] + V[j]) / 2), not their difference -- confirm before
# switching the potential term back on.

# Experiment toggles: 1 keeps the kinetic term, 0 switches the potential term off.
kinetic_scale, potential_scale = 1, 0
ST = torch.from_numpy(ST) * kinetic_scale
SV = torch.from_numpy(SV) * potential_scale
K = torch.exp(torch.complex(0 * ST, ST) - torch.complex(0 * SV, SV))

# Zero the last 2500 rows. Since q_next[j] = sum_i q[i] * K[i, j], amplitude on
# those grid points does not propagate. This is presumably an absorbing region
# at the right edge (TODO confirm intent; 2500 == numdx // 4 with the current
# constants).
K[-2500:] = 0


def gaussian(x, mu, sig):
    """Evaluate the normal density with mean ``mu`` and standard deviation ``sig`` at ``x``.

    Works elementwise on NumPy arrays as well as on scalars.
    """
    offset = x - mu
    exponent = -(offset * offset) / (2 * sig * sig)
    return np.exp(exponent) / (math.sqrt(2 * math.pi) * sig)


# Initial state: Gaussian envelope (centre 2, width 0.5) modulated by a plane
# wave with wavenumber `mom`.
r0 = gaussian(x, 2, .5)[:, np.newaxis]       # (numdx, 1) column
r0 = r0 + 0 * np.ones_like(r0)               # constant-offset knob, currently 0
r0[0, 0] = r0[-1, 0] = 0                     # pin both grid endpoints to zero
mom = 10
phase = mom * x
# Real and imaginary parts side by side -> (numdx, 2), the layout view_as_complex expects.
q0 = np.hstack((np.cos(phase)[:, np.newaxis] * r0, np.sin(phase)[:, np.newaxis] * r0))
q = torch.view_as_complex(torch.from_numpy(q0))

In [None]:
def normalize(psi):
    """Rescale ``psi`` to unit Euclidean norm.

    ``Tensor.norm()`` with no arguments is the 2-norm over all elements and
    uses the modulus for complex entries, so the result has norm 1 up to
    rounding. An all-zero ``psi`` produces NaNs (0 / 0).
    """
    magnitude = psi.norm()
    return psi.div(magnitude)

def simulate(q, K, steps=500):
    """Apply the one-step kernel ``K`` to ``q`` ``steps`` times, recording every state.

    ``q`` is treated as a row vector (``q_next = q @ K``). The state is
    normalized before the first step and after every step, so each recorded
    state has unit norm. Returns ``torch.stack`` of the initial normalized state
    followed by one state per step (``steps + 1`` entries along dim 0).
    """
    state = normalize(q)
    history = [state.clone()]
    for _ in range(steps):
        state = normalize(torch.matmul(state, K))
        history.append(state.clone())
    return torch.stack(history)
    
# Fresh complex initial state from q0, so re-running this cell restarts at t = 0.
q = torch.view_as_complex(torch.from_numpy(q0))
# qs[k] is the normalized state after k steps; 501 entries along dim 0.
qs = simulate(q, K, steps=500)

In [None]:
def make_video(x, qs, path, interval=60, density=False, **kwargs):
    """Draw one line plot per state in ``qs`` and save the animation to ``path``.

    Args:
        x: positions used as the horizontal axis.
        qs: sequence of wavefunction snapshots, one per frame (e.g. rows of the
            complex tensor returned by ``simulate``).
        path: output file passed to ``Animation.save``.
        interval: delay between frames, forwarded to ``Camera.animate``.
        density: if True, plot ``|psi|**2`` (spatial probability density);
            by default plot the amplitude ``|psi|``, as before.
        **kwargs: forwarded to ``Camera.animate``.

    Draws into the current figure (``plt.gcf()``), so a title set beforehand is
    kept. That figure is always closed, even if saving fails.
    """
    fig = plt.gcf()
    fig.set_dpi(150)
    fig.set_size_inches(5, 3)
    camera = Camera(fig)
    # The label must describe what is actually plotted: |psi| is an amplitude,
    # only |psi|**2 is a probability density.
    label = 'Spatial probability' if density else 'Amplitude |psi|'
    try:
        for q in qs:
            y = q.abs().detach()
            if density:
                y = y ** 2
            plt.plot(x, y, 'b', label=label)
            camera.snap()
        anim = camera.animate(blit=True, interval=interval, **kwargs)
        anim.save(path)
    finally:
        plt.close(fig)

# make_video draws into the current figure, so this title appears in the clip.
plt.title('Dynamics of a 1D Quantum Wave Packet')
path = 'sim.mp4'
make_video(x, qs[::3], path, interval=60)  # every 3rd state only
# Read the rendered clip with a context manager so the file handle is closed.
with open(path, 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

# Embed the clip inline as a base64 data URL; as the cell's last expression it is displayed.
HTML("""
<video width=600 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

In [None]:
# def normalize(psi):
#     return psi / psi.norm() #.abs().pow(2).sum().detach()

# q = torch.view_as_complex(torch.from_numpy(q0))
# q_ = q
# q_ = normalize(q_)

# q_ @ q_, q_.norm()