# SimOpt: quantum wave packet

In [None]:
import csv, math, os, shutil, sys
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FuncAnimation

# !pip install celluloid
from celluloid import Camera
from IPython.display import HTML
from base64 import b64encode

In [None]:
def V(x, N):
    """Return a linear potential sampled at N points, running from 1 down to -1.

    Only the grid size ``N`` is used; the positions ``x`` are ignored.
    """
    start, stop = 1, -1
    return np.linspace(start, stop, num=N)

def gaussian(x, mu, sig):
    """Normal probability density with mean ``mu`` and std ``sig``, evaluated at ``x``."""
    offset = x - mu
    exponent = -(offset * offset) / (2 * sig * sig)
    return np.exp(exponent) / (math.sqrt(2 * math.pi) * sig)

def construct_operators(N, alpha, beta, potential_scale=0):
    """Build the phase matrices and the one-step propagator on an N-point grid.

    Args:
        N: number of grid points.
        alpha: kinetic phase per squared index distance,
            so ``ST[i, j] = alpha * (i - j)**2``.
        beta: potential phase per unit of ``V``.
        potential_scale: multiplier applied to the potential phase matrix.
            The default 0 switches the potential off, matching the previous
            hard-coded behaviour; pass 1 to include it.

    Returns:
        ST: real ``[N x N]`` float64 tensor of kinetic phases.
        SV: real ``[N x N]`` float64 tensor of (scaled) potential phases.
        K: complex ``[N x N]`` tensor ``exp(i * (ST - SV))``; one propagation
            step is ``q @ K``.
    """
    idx = np.arange(N)
    pot = V(idx, N)

    # Pairwise matrices by broadcasting. This replaces a Python loop over N
    # rows/columns that also re-evaluated V twice per iteration.
    dist = (idx[:, None] - idx[None, :]).astype(np.float64)  # i - j
    # Sample the potential at the midpoint of each (i, j) hop. A difference
    # V[i] - V[j] makes K = A T A^-1 with A diagonal, so K^n = A T^n A^-1:
    # the potential would cancel between steps and act only once overall.
    v_mid = 0.5 * (pot[:, None] + pot[None, :])

    ST = torch.tensor(alpha * dist * dist)
    SV = torch.tensor(beta * v_mid) * potential_scale
    K = (torch.complex(0 * ST, ST) - torch.complex(0 * SV, SV)).exp()
    return ST, SV, K

def normalize(psi):
    """Return ``psi`` rescaled to unit 2-norm (works for real or complex tensors)."""
    magnitude = psi.norm()
    return psi / magnitude

class ObjectView:
    """Attribute-style access to a dict's keys.

    The instance's ``__dict__`` is ``d`` itself (not a copy), so attribute
    reads and writes go straight to the underlying dict.
    """

    def __init__(self, d):
        self.__dict__ = d
        
def init(constants):
    """Set up the grid, propagator and initial wave packet.

    Reads ``xrange, numdx, trange, numdt, m, hbar, mom`` from ``constants``.
    Optional ``x0`` (default 1.5) and ``sigma`` (default 0.5) set the centre
    and width of the gaussian envelope.

    Returns:
        x: numpy array of ``numdx`` grid positions on ``[0, xrange]``.
        t: numpy array of ``numdt`` times on ``[0, trange]``.
        ST, SV, K: operators from ``construct_operators`` (all [numdx x numdx]).
        q: complex128 tensor of shape [numdx], the (unnormalized) packet.
    """
    # Take dx from the grid itself: linspace with numdx points has step
    # xrange / (numdx - 1), not xrange / numdx.
    x, dx = np.linspace(0, constants.xrange, constants.numdx, retstep=True)
    dt = constants.trange / constants.numdt  # length of one propagation step
    # Kinetic phase m (x - x')^2 / (2 hbar dt) and potential phase V dt / hbar
    # both divide by hbar (not hbar**2).
    alpha = constants.m * dx ** 2 / (2 * constants.hbar * dt)
    beta = dt / constants.hbar
    # NOTE(review): t's spacing is trange / (numdt - 1), which differs from dt;
    # t is returned as-is for compatibility.
    t = np.linspace(0, constants.trange, constants.numdt)

    ST, SV, K = construct_operators(constants.numdx, alpha, beta)  # all [numdx x numdx]

    x0 = getattr(constants, 'x0', 1.5)
    sigma = getattr(constants, 'sigma', .5)
    r0 = gaussian(x, x0, sigma)[:, None]  # r0: [N x 1]
    r0[0, 0] = r0[-1, 0] = 0  # force zero amplitude at both grid endpoints
    # Columns are (real, imag) of exp(i * mom * x) * r0.
    q0 = np.concatenate((np.cos(constants.mom * x)[:, None] * r0,
                         np.sin(constants.mom * x)[:, None] * r0), axis=1)
    q = torch.view_as_complex(torch.tensor(q0))  # [N x 2] float64 -> [N] complex128
    return x, t, ST, SV, K, q

def make_video(x, qs, path, interval=60, **kwargs):
    """Animate |q|^2 for each state in ``qs`` and save the animation to ``path``.

    Args:
        x: grid positions used as the horizontal axis.
        qs: iterable of complex state tensors, one per frame (e.g. the stacked
            output of ``simulate``, shape [time, N]).
        path: output file passed to ``anim.save``.
        interval: frame interval forwarded, together with ``**kwargs``, to
            celluloid's ``Camera.animate``.
    """
    # Draw on a fresh figure rather than whatever figure happens to be current.
    fig, ax = plt.subplots(figsize=(5, 3), dpi=100)
    ax.set_title('Dynamics of a 1D gaussian wave packet')
    camera = Camera(fig)
    try:
        for q in qs:
            probs = (q.conj() * q).real.detach()  # |q|^2 at each grid point
            ax.plot(x, probs, 'b', label='Spatial probability')
            camera.snap()
        anim = camera.animate(blit=True, interval=interval, **kwargs)
        anim.save(path)
    finally:
        # Close even when rendering or saving fails so figures don't pile up.
        plt.close(fig)
    
def simulate(q, K, steps=500):
    """Repeatedly apply ``q -> normalize(q @ K)`` for ``steps`` steps.

    Returns a tensor stacking the normalized initial state followed by every
    propagated state, so its leading dimension is ``steps + 1``.
    """
    state = normalize(q)
    history = [state.clone()]
    for _ in range(steps):
        state = normalize(state @ K)
        history.append(state.clone())
    return torch.stack(history)

In [None]:
%%time
constants = dict(hbar=1, m=1, numdx=1500, numdt=150, xrange=3.5, trange=1, mom=10)  # constants
constants = ObjectView(constants) # dict -> object

x, t, ST, SV, K, q = init(constants)

In [None]:
# plt.plot(x, (q.conj() * q).real.detach())

In [None]:
%%time
steps = constants.numdt // constants.trange
qs = simulate(q, K, steps)

In [None]:
q0, q1 = qs[:2].unbind()  # initial state and the state after one step

In [None]:
_x = torch.imag(torch.log(q0))  # imag(log z) is the phase arg(z) of each entry

In [None]:
torch.matmul(_x, ST).sum()

In [None]:
%%time
path = 'sim.mp4'
make_video(x, qs, path, interval=60)

mp4 = open(path,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML('<video width=400 controls><source src="{}" type="video/mp4"></video>'.format(data_url))

In [None]:
plt.figure(figsize=[6,3], dpi=130)
# Real and imaginary parts of the propagator side by side, colour range [-1, 1].
k_panels = [(K.real, 'K (real)'), (K.imag, 'K (imag)')]
for panel, (part, label) in enumerate(k_panels, start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(part)
    plt.title(label)
    plt.clim(-1, 1)
plt.tight_layout()

In [None]:
plt.figure(figsize=[6,3], dpi=130)
# Kinetic and potential phase matrices side by side, colour range [0, 1000].
phase_panels = [(ST, 'ST'), (SV, 'SV')]
for panel, (matrix, label) in enumerate(phase_panels, start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(matrix)
    plt.title(label)
    plt.clim(0, 1000)
plt.tight_layout()