In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from matplotlib.animation import FuncAnimation
from scipy.fft import fft2, ifft2
from scipy.signal import convolve2d
from tqdm import tqdm

In [231]:

# Pseudo-spectral simulation on a periodic [0, Lx) x [0, Ly) grid.
# Currently only diffusion of c is integrated (Crank–Nicolson in Fourier space);
# see the TODO in the time loop for the reaction terms that are not applied yet.

# Grid parameters. The first array axis is x (spacing dx), the second is y (dy).
Nx = 128
Ny = 128
Lx = 2 * np.pi
Ly = 2 * np.pi
dx = Lx / Nx
dy = Ly / Ny

# Time stepping. round() instead of int(): t/dt can land just below an integer
# in floating point (e.g. 0.3/0.1 == 2.9999999999999996) and int() would drop a step.
dt = 0.01
t = 1
n_steps = round(t / dt)

# Diffusion coefficients and reaction rates of the c/h/w model.
Dc = 1
Dh = 0.01
Dw = 0.01

A = 0
B = 0

# Initial conditions sampled at x_i = i*dx, y_j = j*dy.
x_grid = np.arange(Nx) * dx
y_grid = np.arange(Ny) * dy
X, Y = np.meshgrid(x_grid, y_grid, indexing="ij")

# Snapshot storage: slice [:, :, n] holds the field after n time steps.
# Every slice is overwritten in the time loop below.
c = np.zeros((Nx, Ny, n_steps))
h = np.zeros((Nx, Ny, n_steps))
w = np.zeros((Nx, Ny, n_steps))
c[:, :, 0] = np.sin(2 * X) * np.sin(Y)
h[:, :, 0] = np.sin(X) * np.sin(Y)
w[:, :, 0] = np.sin(X) * np.sin(Y)

# Fourier transforms of the initial fields.
ck = fft2(c[:, :, 0])
hk = fft2(h[:, :, 0])
wk = fft2(w[:, :, 0])

# Angular wavenumbers: fftfreq returns cycles per unit length, hence the 2*pi.
kx = np.fft.fftfreq(Nx, dx) * np.pi * 2
ky = np.fft.fftfreq(Ny, dy) * np.pi * 2
# |k|^2 laid out on the same (x, y) axes as the fields, via broadcasting.
k2 = kx[:, None] ** 2 + ky[None, :] ** 2

# Crank–Nicolson for c_t = Dc * lap(c). In Fourier space lap -> -|k|^2, so each
# mode is multiplied by (1 + L*dt/2) / (1 - L*dt/2) with L = -Dc*|k|^2 per step.
# The factor is loop-invariant (computed once) and has magnitude <= 1 for every k.
Lk = -Dc * k2 * dt / 2
cn_factor = (1 + Lk) / (1 - Lk)

for i in tqdm(range(n_steps), position=0):
    # Record the current state, then advance one step.
    c[:, :, i] = np.real(ifft2(ck))
    h[:, :, i] = np.real(ifft2(hk))
    w[:, :, i] = np.real(ifft2(wk))

    # TODO: the full model also has reaction terms and h/w diffusion:
    #   c_t = A*w - B*c*h**2       + Dc*lap(c)
    #   h_t = B*c*h**2 - h*w       + Dh*lap(h)
    #   w_t = h*w - A*w            + Dw*lap(w)
    # None of these beyond Dc*lap(c) are applied yet, so h and w stay frozen
    # at their initial values.
    ck = cn_factor * ck

if np.isnan(ck).any():
    print("Error, encounter NaN")


100%|██████████| 100/100 [00:00<00:00, 1235.58it/s]


In [232]:
# Animate the stored c snapshots and embed the result as an HTML5 video.
# x (first index of c) is drawn horizontally and y (second index) vertically.
frame_stride = 3  # render every 3rd stored time step to keep the video short

fig, ax = plt.subplots()
# A single image artist is created once and updated in place every frame;
# calling imshow inside the update function would stack a new artist per frame.
# A fixed colour range over all snapshots keeps the colorbar truthful and makes
# the diffusive decay of the amplitude visible.
image = ax.imshow(
    c[:, :, 0].T,
    origin="lower",
    extent=(0, Lx, 0, Ly),
    vmin=c.min(),
    vmax=c.max(),
)
cbar = fig.colorbar(image, ax=ax, label="c")
ax.set(xlabel="x", ylabel="y")
title = ax.set_title("t = 0.00")


def plotFrame(i):
    """Display stored time step ``i`` of ``c``; return the artists that changed."""
    image.set_data(c[:, :, i].T)
    title.set_text(f"t = {i * dt:.2f}")
    return [image, title]


def init():
    """Draw the initial state (time step 0) as the animation's first frame."""
    return plotFrame(0)


video = FuncAnimation(
    fig,
    plotFrame,
    interval=50,
    # Frame count comes from the data itself, so it always matches the simulation.
    frames=tqdm(range(0, c.shape[2], frame_stride), initial=0, position=0),
    blit=True,
    init_func=init,
)
# To write the animation to disk instead: video.save("euler_adam.mp4", fps=10)
plt.close(fig)  # keep the static figure from rendering below the video
display(HTML(video.to_html5_video()))

  0%|          | 0/34 [00:00<?, ?it/s]

 97%|█████████▋| 33/34 [00:07<00:00,  2.87it/s]