In [7]:
# IMPORTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
from RMT_Lib import *
from TamLib import *
from tqdm import tqdm # for progress bar

# for animations
import imageio as iio # for animations
from PIL import Image # for animations
from matplotlib import animation
from JSAnimation.IPython_display import display_animation


import time

# SETTINGS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
SMALL_SIZE = 14
MEDIUM_SIZE = 18
LARGE_SIZE = 20

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=LARGE_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=LARGE_SIZE)  # fontsize of the figure title
plt.rcParams.update({"text.usetex": True})

%matplotlib notebook
%matplotlib notebook


In [8]:

# equal axis in 3d
def set_axes_equal(ax):
    '''Make axes of 3D plot have equal scale so that spheres appear as spheres,
    cubes as cubes, etc.

    This is one possible solution to Matplotlib's ax.set_aspect('equal') and
    ax.axis('equal') not working for 3D. Every axis keeps its current midpoint
    and is given a half-width equal to half of the largest of the three
    current ranges, so the plotted bounding box becomes a cube.

    Input
      ax: a matplotlib 3D axis, e.g., as output from plt.gca().
    '''

    x_limits = ax.get_xlim3d()
    y_limits = ax.get_ylim3d()
    z_limits = ax.get_zlim3d()

    x_range = abs(x_limits[1] - x_limits[0])
    x_middle = np.mean(x_limits)
    y_range = abs(y_limits[1] - y_limits[0])
    y_middle = np.mean(y_limits)
    z_range = abs(z_limits[1] - z_limits[0])
    z_middle = np.mean(z_limits)

    # The plot bounding box is a sphere in the sense of the infinity
    # norm, hence I call half the max range the plot radius.
    plot_radius = 0.5*max([x_range, y_range, z_range])

    ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
    ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
    # z must be rescaled too, otherwise the box is not cubic along z
    ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])

# get cyl along axis given by vec in cartesian coord (= normal to cross section of cyl; vec that defines cyl axis)
# base = center of base of cylinder (x0,y0,z0)
# r = radius of cylinder; distance to axis
# h = height of cylinder
# nt = # samples
def getcylinder(r,h, nt=100, base=(0,0,0), vec = (0,0,1)):
    """Return surface grids of an open cylinder whose axis points along ``vec``.

    Parameters
    ----------
    r : float
        Radius of the cylinder (distance of the surface from its axis).
    h : float
        Height of the cylinder, measured from ``base`` along ``vec``.
    nt : int
        Number of samples in both the angular and the axial direction.
    base : 3-sequence
        Cartesian coordinates (x0, y0, z0) of the centre of the bottom face.
    vec : 3-sequence
        Direction of the cylinder axis; only its direction is used.

    Returns
    -------
    X, Y, Z : ndarray, each of shape (nt, nt)
        Grids for ``ax.plot_surface``. Rows vary along the axis and columns
        vary around the circumference.

    Raises
    ------
    ValueError
        If ``vec`` is the zero vector (the axis direction is undefined).
    """
    # spherical angles of vec ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    vx, vy, vz = (float(c) for c in vec)
    vnorm = np.sqrt(vx**2 + vy**2 + vz**2)
    if vnorm == 0:
        raise ValueError("vec must be non-zero to define the cylinder axis")
    pol = np.arccos(np.clip(vz/vnorm, -1.0, 1.0))        # polar angle from +z
    azi = np.arctan2(vy, vx) if (vx or vy) else 0.0       # azimuth in the xy-plane

    # cylinder of the right size standing on the xy-plane along +z ~~~~~~~~~~~
    th = np.linspace(0, 2*np.pi, nt)
    THETA, Z = np.meshgrid(th, np.linspace(0, h, nt))
    X, Y = r*np.cos(THETA), r*np.sin(THETA)

    # rotate +z onto vec ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Tilting by pol about y only reaches directions in the +x half of the
    # xz-plane; the subsequent spin by azi about z covers every direction.
    c_p, s_p = np.cos(pol), np.sin(pol)
    c_a, s_a = np.cos(azi), np.sin(azi)
    Ry = np.array([[c_p, 0, s_p], [0, 1, 0], [-s_p, 0, c_p]])
    Rz = np.array([[c_a, -s_a, 0], [s_a, c_a, 0], [0, 0, 1]])
    T = Rz @ Ry
    newcoord = np.einsum('ij,kmj -> kmi', T, np.dstack((X, Y, Z)))
    X, Y, Z = newcoord[:, :, 0], newcoord[:, :, 1], newcoord[:, :, 2]

    # shift so the bottom face is centred on base ~~~~~~~~~~~~~~~~~~~~~~~~~~
    x0, y0, z0 = base
    return X + x0, Y + y0, Z + z0

# get cone which points in direction given by vec
# rb = radius of base
# h = height of cone
# nt = # samples
def getcone(rb, h, nt=100, base_center=(0,0,0), vec=(0,0,1)):
    """Return surface grids of an open cone whose apex points along ``vec``.

    Parameters
    ----------
    rb : float
        Radius of the cone's base (must be non-zero).
    h : float
        Height of the cone; the apex lies at ``base_center + h * vec/|vec|``.
    nt : int
        Number of samples in both the angular and the radial direction.
    base_center : 3-sequence
        Cartesian coordinates (x0, y0, z0) of the centre of the base.
    vec : 3-sequence
        Direction from the base centre to the apex; only its direction is used.

    Returns
    -------
    X, Y, Z : ndarray, each of shape (nt, nt)
        Grids for ``ax.plot_surface``. Row 0 is the apex and the last row is
        the base rim.

    Raises
    ------
    ValueError
        If ``vec`` is the zero vector (the direction is undefined).
    """
    # spherical angles of the direction vector ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    vx, vy, vz = (float(c) for c in vec)
    vnorm = np.sqrt(vx**2 + vy**2 + vz**2)
    if vnorm == 0:
        raise ValueError("vec must be non-zero to define the cone direction")
    pol = np.arccos(np.clip(vz/vnorm, -1.0, 1.0))        # polar angle from +z
    azi = np.arctan2(vy, vx) if (vx or vy) else 0.0       # azimuth in the xy-plane

    # cone pointing along +z with its base on the xy-plane ~~~~~~~~~~~~~~~~~~
    th = np.linspace(0, 2*np.pi, nt)
    rho = np.linspace(0, rb, nt)
    THETA, RHO = np.meshgrid(th, rho)
    X, Y = RHO*np.cos(THETA), RHO*np.sin(THETA)
    Z = h - (h/rb)*RHO        # height falls linearly from h (apex) to 0 (rim)

    # rotate +z onto vec: tilt by pol about y, then spin by azi about z ~~~~~
    c_p, s_p = np.cos(pol), np.sin(pol)
    c_a, s_a = np.cos(azi), np.sin(azi)
    Ry = np.array([[c_p, 0, s_p], [0, 1, 0], [-s_p, 0, c_p]])
    Rz = np.array([[c_a, -s_a, 0], [s_a, c_a, 0], [0, 0, 1]])
    T = Rz @ Ry
    newcoord = np.einsum('ij,kmj -> kmi', T, np.dstack((X, Y, Z)))
    X, Y, Z = newcoord[:, :, 0], newcoord[:, :, 1], newcoord[:, :, 2]

    # shift centre of base ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    x0, y0, z0 = base_center
    return X + x0, Y + y0, Z + z0
    

# tip_coord and origin are in cartesian coordinates
# nt = # of angles to sample
def getarrow3d(tip_coord,origin=np.array([0,0,0]),nt=100, shaft_radius = 0.05, arrowtip_height=0.5, arrow_base_ratio = 2.5):
    """Return surface grids for a 3D arrow from ``origin`` to ``tip_coord``.

    The arrow is a cylindrical shaft (from ``getcylinder``) topped by a cone
    (from ``getcone``). Both are built along +z and then rotated onto the
    arrow direction and translated to ``origin``.

    Parameters
    ----------
    tip_coord, origin : 3-sequence
        Cartesian coordinates of the arrow tip and tail.
    nt : int
        Number of samples per grid direction for both shaft and tip.
    shaft_radius : float
        Radius of the cylindrical shaft.
    arrowtip_height : float
        Height of the conical tip. The shaft length is
        ``|tip_coord - origin| - arrowtip_height``. If the arrow is shorter
        than the tip, this is negative and the shaft extends backwards from
        ``origin``.
    arrow_base_ratio : float
        Ratio of the cone's base radius to ``shaft_radius``.

    Returns
    -------
    (xc, yc, zc), (xt, yt, zt)
        Grids for ``ax.plot_surface``: the shaft and the tip respectively.

    Raises
    ------
    ValueError
        If ``tip_coord`` coincides with ``origin`` (no direction defined).
    """
    # unit vector and spherical angles of the arrow direction
    origin = np.asarray(origin, dtype=float)
    v = np.asarray(tip_coord, dtype=float) - origin
    mag = np.linalg.norm(v)
    if mag == 0:
        raise ValueError("tip_coord must differ from origin")
    u = v/mag
    pol = np.arccos(np.clip(u[2], -1.0, 1.0))                # polar angle from +z
    azi = np.arctan2(u[1], u[0]) if (u[0] or u[1]) else 0.0  # azimuth in the xy-plane

    # shaft and tip, both aligned with +z for now
    h = mag - arrowtip_height  # shaft height
    xc, yc, zc = getcylinder(shaft_radius, h, nt=nt)
    xt, yt, zt = getcone(shaft_radius*arrow_base_ratio, arrowtip_height, nt=nt)

    # rotate +z onto u: tilt by pol about y, then spin by azi about z
    c_p, s_p = np.cos(pol), np.sin(pol)
    c_a, s_a = np.cos(azi), np.sin(azi)
    Ry = np.array([[c_p, 0, s_p], [0, 1, 0], [-s_p, 0, c_p]])
    Rz = np.array([[c_a, -s_a, 0], [s_a, c_a, 0], [0, 0, 1]])
    T = Rz @ Ry
    cylcoord = np.einsum('ij,kmj -> kmi', T, np.dstack((xc, yc, zc)))
    tipcoord = np.einsum('ij,kmj -> kmi', T, np.dstack((xt, yt, zt)))

    # shaft starts at origin; the tip's base sits at the end of the shaft,
    # origin + h*u, along the arrow direction (not just along z)
    x0, y0, z0 = origin
    tx, ty, tz = origin + h*u
    shaft = (cylcoord[:, :, 0] + x0, cylcoord[:, :, 1] + y0, cylcoord[:, :, 2] + z0)
    tip = (tipcoord[:, :, 0] + tx, tipcoord[:, :, 1] + ty, tipcoord[:, :, 2] + tz)
    return shaft, tip
    
    
def set3daxparams(ax,alpha_bg =0,lims=[-2,2],viewAngle=[0,0],axesLabels=False,showAxes=False):
    """Apply the common styling used for the 3D spin frames.

    Parameters
    ----------
    ax : 3D matplotlib Axes
        Axes to style. The background of its parent figure is also changed.
    alpha_bg : float
        Alpha applied to both the figure and the axes background patches.
    lims : 2-sequence
        (min, max) used for all three axes via ``auto_scale_xyz``.
    viewAngle : 2-sequence
        [elevation, azimuth] passed to ``view_init``.
    axesLabels : bool
        If True, label the axes "x", "y" and "z".
    showAxes : bool
        If False the axes are switched off; if True they are kept but
        without tick marks.
    """
    ax.auto_scale_xyz(lims,lims,lims)

    # transparent background -- style the figure that owns ax; plt.gcf()
    # may be a different figure when several figures are open
    ax.figure.patch.set_alpha(alpha_bg)
    ax.patch.set_alpha(alpha_bg)

    ax.view_init(elev=viewAngle[0],azim=viewAngle[1])
    if not showAxes:
        ax.set_axis_off()
    else:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

    if axesLabels:
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")

    for spine in ax.spines.values():
        spine.set_visible(False)


# renders one animation frame to an image array
# viewAngle = [elev/polar, azim]
# za = height of arrow
def getSpinFrame(za,lims=[-2,2],viewAngle=[0,0],shaft_width = 0.05,arrowtip_height = 0.5, arrow_base_ratio = 2.5, alpha_sphere=0.05, colors=['#1f77b4','tomato'],axesLabels=False,showAxes=False):
    """Render a translucent unit sphere plus a vertical arrow of height ``za``.

    Parameters
    ----------
    za : float
        Signed z-coordinate of the arrow tip; the arrow starts at the origin.
        ``colors[0]`` is used when ``za > 0``, otherwise ``colors[1]``.
    lims, viewAngle, axesLabels, showAxes
        Forwarded to ``set3daxparams`` (viewAngle = [elevation, azimuth]).
    shaft_width, arrowtip_height, arrow_base_ratio : float
        Arrow geometry, forwarded to ``getarrow3d`` as ``shaft_radius``,
        ``arrowtip_height`` and ``arrow_base_ratio``.
    alpha_sphere : float
        NOTE(review): accepted but unused -- the sphere is drawn with a fixed
        alpha of 0.3. Decide whether to honour it (this changes the rendering).
    colors : 2-sequence
        Arrow colours for positive / non-positive ``za``.

    Returns
    -------
    image : ndarray of uint8, shape (height, width, 4)
        The rendered canvas in RGBA channel order, as expected by imageio.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')

        # unit sphere
        phivec, thetavec = np.linspace(0, np.pi, 100), np.linspace(0, 2*np.pi, 100)
        phi, theta = np.meshgrid(phivec, thetavec)
        xs, ys, zs = np.sin(phi)*np.cos(theta), np.sin(phi)*np.sin(theta), np.cos(phi)
        ax.plot_surface(xs, ys, zs, rstride=1, color="lightsteelblue", alpha=0.3)

        # arrow from the origin to (0, 0, za), honouring the geometry arguments
        (xc, yc, zc), (xt, yt, zt) = getarrow3d((0, 0, za), (0, 0, 0),
                                                shaft_radius=shaft_width,
                                                arrowtip_height=arrowtip_height,
                                                arrow_base_ratio=arrow_base_ratio)
        color = colors[0] if za > 0 else colors[1]
        ax.plot_surface(xc, yc, zc, color=color, alpha=1)
        ax.plot_surface(xt, yt, zt, color=color, alpha=1)

        # transparent background, limits, view, axis visibility, labels
        set3daxparams(ax, alpha_bg=0., lims=lims, viewAngle=viewAngle,
                      axesLabels=axesLabels, showAxes=showAxes)

        # Grab the rendered pixels as RGBA. The previous tostring_argb() put
        # alpha first, which imageio misreads as red. np.array copies the
        # buffer so it stays valid after the figure is closed.
        fig.canvas.draw()
        image = np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)
    finally:
        # always release the figure, even if drawing fails mid-way
        plt.close(fig)

    return image


# creates the figure/axes (sphere only) that updateSpinFrame can later redraw
# keyword arguments are forwarded to set3daxparams
def initSpinFrame(alpha_sphere=0.05,**kwargs):
    """Create a new 3D figure containing only the translucent unit sphere.

    Parameters
    ----------
    alpha_sphere : float
        NOTE(review): accepted but unused; the sphere alpha is fixed at 0.3.
    **kwargs
        Styling options forwarded to ``set3daxparams`` (alpha_bg, lims,
        viewAngle, axesLabels, showAxes).

    Returns
    -------
    (fig, ax)
        The new figure and its 3D axes.
    """
    fig = plt.figure()
    ax = plt.axes(projection='3d')

    # unit sphere sampled on a 100 x 100 (polar, azimuthal) grid
    polar = np.linspace(0, np.pi, 100)
    azimuthal = np.linspace(0, 2*np.pi, 100)
    phi, theta = np.meshgrid(polar, azimuthal)
    sin_phi = np.sin(phi)
    ax.plot_surface(sin_phi*np.cos(theta), sin_phi*np.sin(theta), np.cos(phi),
                    rstride=1, color="lightsteelblue", alpha=0.3)

    set3daxparams(ax, **kwargs)

    return fig, ax


# clears an existing 3D axes and redraws it as one frame (sphere + arrow)
# viewAngle = [elev/polar, azim]
# za = height of arrow
def updateSpinFrame(za,ax,shaft_width = 0.05,arrowtip_height = 0.5, arrow_base_ratio = 2.5, alpha_sphere=0.05, colors=['#1f77b4','tomato'],**kwargs):
    """Redraw ``ax`` with the unit sphere and a vertical arrow of height ``za``.

    Parameters
    ----------
    za : float
        Signed z-coordinate of the arrow tip; the arrow starts at the origin.
        ``colors[0]`` is used when ``za > 0``, otherwise ``colors[1]``.
    ax : 3D matplotlib Axes
        Axes to clear and redraw in place.
    shaft_width, arrowtip_height, arrow_base_ratio : float
        Arrow geometry, forwarded to ``getarrow3d`` as ``shaft_radius``,
        ``arrowtip_height`` and ``arrow_base_ratio``.
    alpha_sphere : float
        NOTE(review): accepted but unused; the sphere alpha is fixed at 0.3.
    colors : 2-sequence
        Arrow colours for positive / non-positive ``za``.
    **kwargs
        Styling options forwarded to ``set3daxparams``.

    Returns
    -------
    ax
        The same axes, redrawn.
    """
    ax.clear()

    # unit sphere
    phivec, thetavec = np.linspace(0, np.pi, 100), np.linspace(0, 2*np.pi, 100)
    phi, theta = np.meshgrid(phivec, thetavec)
    xs, ys, zs = np.sin(phi)*np.cos(theta), np.sin(phi)*np.sin(theta), np.cos(phi)
    ax.plot_surface(xs, ys, zs, rstride=1, color="lightsteelblue", alpha=0.3)

    # arrow from the origin to (0, 0, za); geometry arguments are now honoured
    (xc, yc, zc), (xt, yt, zt) = getarrow3d((0, 0, za), (0, 0, 0),
                                            shaft_radius=shaft_width,
                                            arrowtip_height=arrowtip_height,
                                            arrow_base_ratio=arrow_base_ratio)
    color = colors[0] if za > 0 else colors[1]
    ax.plot_surface(xc, yc, zc, color=color, alpha=1)
    ax.plot_surface(xt, yt, zt, color=color, alpha=1)

    set3daxparams(ax, **kwargs)

    return ax


# renders one spin frame per arrow height in zvec and optionally saves a GIF
def animateSpin(zvec,filename=[],saveBool=False,fps=2,**kwargs):
    """Render ``getSpinFrame`` for every height in ``zvec``; optionally write a GIF.

    Parameters
    ----------
    zvec : sequence of float
        Arrow heights, one per frame.
    filename : str, optional
        Output path. When empty/falsy, a timestamped
        ``./spinAnim_<time>.gif`` in the working directory is used.
    saveBool : bool
        If True, write the frames with ``iio.mimwrite``.
    fps : float
        Frame rate of the written animation (previously ignored; 10 was
        always used).
    **kwargs
        Forwarded to ``getSpinFrame``.

    Returns
    -------
    list
        The frame image arrays returned by ``getSpinFrame``.
    """
    # render with interactive mode off, then restore whatever mode was active
    # before -- even if rendering raises
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        ImList = [getSpinFrame(za, **kwargs) for za in tqdm(zvec)]
    finally:
        if was_interactive:
            plt.ion()

    if saveBool:
        if not filename:
            timestr = time.strftime("%Y-%m-%d-%Hh%Mm%Ss")
            filename = "./spinAnim_%s.gif" % timestr
        iio.mimwrite(filename, ImList, fps=fps)
    return ImList


In [6]:
# min and max absolute values of arrow height. minz must stay above
# getarrow3d's default arrowtip_height (0.5), otherwise the shaft height
# (|z| - arrowtip_height) becomes negative.
minz, maxz = 0.6, 2
zstep = 0.2

# heights from maxz down to minz inclusive; linspace avoids the endpoint
# ambiguity of a float-step np.arange
nsteps = int(round((maxz - minz)/zstep)) + 1
zdown = np.linspace(maxz, minz, nsteps)

# shrink from +maxz, flip and grow to -maxz, shrink back to -minz,
# flip and grow to +maxz again
zvec = np.concatenate((zdown, -np.flip(zdown), -zdown[1:], np.flip(zdown)))

ImList = animateSpin(zvec,lims=[-2,2],viewAngle=[0,90],alpha_sphere=0.02,filename=[],saveBool=True,fps=10);

100%|██████████| 31/31 [00:14<00:00,  2.17it/s]
