In [237]:
import numpy as np
import matplotlib.pyplot as plt 
from scipy.fft import fft2, ifft2
from tqdm import tqdm

In [246]:

Nx = 128
Ny = 128
Lx = 2 * np.pi
Ly = 2 * np.pi
dx = Lx / Nx
dy = Ly / Ny

dt = 0.01
t = 1

c = np.ones((Nx, Ny, int(t/dt))) * 3
h = np.ones((Nx, Ny, int(t/dt))) * 1
w = np.ones((Nx, Ny, int(t/dt))) * 1

theta = np.linspace(0, Lx/Nx, Nx)
for i in range(Nx):
    for j in range(Ny):
        c[i, j, 0] = np.sin(2 * dx * i) * np.sin(dy * j) * 10
        h[i, j, 0] = np.sin(dx * i) * np.sin(dy * j) * 10
        w[i, j, 0] = np.sin(dx * i) * np.sin(dy * j) * 10


Dc = 1
Dh = 1
Dw = 1

A = 0
B = 0

ck = fft2(c[:, :, 0])
hk = fft2(h[:, :, 0])
wk = fft2(w[:, :, 0])

kx = np.fft.fftfreq(Nx, dx) * np.pi * 2
ky = np.fft.fftfreq(Ny, dy) * np.pi * 2
k2 = np.zeros((Nx, Ny), float)

for i in range(Nx):
    for j in range(Ny):
        k2[i, j] = kx[i]**2 + ky[j]**2



# Initialize

# Linear term for c, h, w
Lc = Dc * (-k2) * dt
Lh = Dh * (-k2) * dt
Lw = Dw * (-k2) * dt

# Nonlinear term, N(t), N(t-\delta t) for c, h, w
N_c = fft2(A * w[:, :, 0] - B * c[:, :, 0] * h[:, :, 0]**2)
N_h = fft2(B * c[:, :, 0] * h[:, :, 0]**2 - h[:, :, 0] * w[:, :, 0])
N_w = fft2(h[:, :, 0] * w[:, :, 0] - A * w[:, :, 0] - B * c[:, :, 0])

ck = (Lc * ck + N_c) * dt + ck
hk = (Lh * hk + N_h) * dt + hk
wk = (Lw * wk + N_w) * dt + wk

c[:,:,0] = np.real(ifft2(ck))
h[:,:,0] = np.real(ifft2(hk))
w[:,:,0] = np.real(ifft2(wk))

print(c[:, :, 0])

# Iterate from t = dt
for i in tqdm(range(1, int(t / dt)), position=0):
    c[:,:,i] = np.real(ifft2(ck))
    h[:,:,i] = np.real(ifft2(hk))
    w[:,:,i] = np.real(ifft2(wk))
    
    # Nonlinear term, N(t), N(t-\delta t) for c, h, w
    N_c_past = N_c
    N_c = fft2(A * w[:, :, i] - B * c[:, :, i] * h[:, :, i]**2)
    N_h_past = N_h
    N_h = fft2(B * c[:, :, i] * h[:, :, i]**2 - h[:, :, i] * w[:, :, i])
    N_w_past = N_w
    N_w = fft2(h[:, :, i] * w[:, :, i] - A * w[:, :, i] - B * c[:, :, i])
    
    ck = ((1 + Lc * (1/2)) * ck + (3 * N_c - 3 * N_c_past) * (1/2) * dt) / (1 - Lc * (1/2))
    hk = ((1 + Lh * (1/2)) * hk + (3 * N_h - 3 * N_h_past) * (1/2) * dt) / (1 - Lh * (1/2))
    wk = ((1 + Lw * (1/2)) * wk + (3 * N_w - 3 * N_w_past) * (1/2) * dt) / (1 - Lw * (1/2))

if np.isnan(ck).any():
    print("Error, encounter NaN")


[[-4.97919828e-19 -1.18819725e-17 -2.39408572e-17 ...  3.13396753e-17
   2.38177121e-17  7.36919272e-18]
 [ 1.43951597e-16  4.80706838e-02  9.60255612e-02 ... -1.43749105e-01
  -9.60255612e-02 -4.80706838e-02]
 [-1.47033031e-16  9.56784207e-02  1.91126344e-01 ... -2.86113827e-01
  -1.91126344e-01 -9.56784207e-02]
 ...
 [-4.56234900e-16 -1.42364722e-01 -2.84386475e-01 ...  4.25723116e-01
   2.84386475e-01  1.42364722e-01]
 [-7.28534553e-16 -9.56784207e-02 -1.91126344e-01 ...  2.86113827e-01
   1.91126344e-01  9.56784207e-02]
 [-1.47839581e-16 -4.80706838e-02 -9.60255612e-02 ...  1.43749105e-01
   9.60255612e-02  4.80706838e-02]]


100%|██████████| 99/99 [00:00<00:00, 619.03it/s]


In [247]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
fig, ax = plt.subplots()
line = ax.pcolormesh(c[:,:,0]) 
cbar = fig.colorbar(line, ax=ax)

def init():
    #line = ax.pcolormesh(c[:,:,0])
    line = ax.imshow(c[:, :, i])
    return [line]

def plotFrame(i):
    #line = ax.pcolormesh(c[:,:,i])
    line = ax.imshow(c[:, :, i])

    cbar.mappable.set_array(c[:, :, i])
    #fig.colorbar(line, ax=ax)
    return [line]

video = FuncAnimation(fig, plotFrame, interval=50, frames=tqdm(range(0, int(t/dt), 3), initial=0, position=0), blit=True, init_func=init)
#video.save('euler_adam.mp4', fps=10) 
plt.close()
display(HTML(video.to_html5_video()))

 97%|█████████▋| 33/34 [00:07<00:00,  2.85it/s]