In [8]:
import numpy as np
import matplotlib.pyplot as plt 
from scipy.fft import fft2, ifft2
from tqdm import tqdm
from random import random

In [44]:
Nx = 512
Ny = 512
Lx = 8 * np.pi
Ly = 8 * np.pi
dx = Lx / Nx
dy = Ly / Ny

dt = 0.01
t = 10

c = np.ones((Nx, Ny, int(t/dt))) * 0.7
h = np.ones((Nx, Ny, int(t/dt))) * 0.2
w = np.ones((Nx, Ny, int(t/dt))) * 0.1

theta = np.linspace(0, Lx/Nx, Nx)
'''
for i in range(Nx):
    for j in range(Ny):
        c[i, j, 0] = np.sin(2 * dx * i) * np.sin(dy * j) * 2
        h[i, j, 0] = np.sin(dx * i) * np.sin(dy * j) * 2
        w[i, j, 0] = np.sin(dx * i) * np.sin(dy * j) * 1
'''
'''
for i in range(int(Nx/10), int(Nx/5)):
    for j in range(int(Nx/10), int(Nx/5)):
        c[i, j, 0] = c[i, j, 0] + np.sin(5 * dx * i) * np.sin(5 * dy * j) * 0.2
'''    

'''
for i in range(Nx):
    for j in range(Ny):
        c[i, j, 0] = c[i, j, 0] + np.sin(5 * dx * i) * np.sin(5 * dy * j) * 0.01
        w[i, j, 0] = w[i, j, 0] + np.cos(5 * dx * i) * np.cos(5 * dy * j) * 0.01
        h[i, j, 0] = h[i, j, 0] + np.sin(5 * dx * i) * np.cos(5 * dy * j) * 0.01
'''

for i in range(Nx):
    for j in range(Ny):
        c[i, j, 0] = c[i, j, 0] + random() * 0.2
        w[i, j, 0] = w[i, j, 0] + random() * 0.2
        h[i, j, 0] = h[i, j, 0] + random() * 0.2

Dc = 0.1
Dh = 0.005
Dw = 0.1

A = 0.3
B = 6

ck = fft2(c[:, :, 0])
hk = fft2(h[:, :, 0])
wk = fft2(w[:, :, 0])

kx = np.fft.fftfreq(Nx, dx) * np.pi * 2
ky = np.fft.fftfreq(Ny, dy) * np.pi * 2
k2 = np.zeros((Nx, Ny), float)

for i in range(Nx):
    for j in range(Ny):
        k2[i, j] = kx[i]**2 + ky[j]**2



# Initialize

# Linear term for c, h, w
Lc = Dc * (-k2) * dt
Lh = Dh * (-k2) * dt
Lw = Dw * (-k2) * dt

# Nonlinear term, N(t), N(t-\delta t) for c, h, w
N_c = fft2(A * w[:, :, 0] - B * c[:, :, 0] * h[:, :, 0]**2)
N_h = fft2(B * c[:, :, 0] * h[:, :, 0]**2 - h[:, :, 0] * w[:, :, 0])
N_w = fft2(h[:, :, 0] * w[:, :, 0] - A * w[:, :, 0] - B * c[:, :, 0])

ck = (Lc * ck + N_c) * dt + ck
hk = (Lh * hk + N_h) * dt + hk
wk = (Lw * wk + N_w) * dt + wk

c[:,:,0] = np.real(ifft2(ck))
h[:,:,0] = np.real(ifft2(hk))
w[:,:,0] = np.real(ifft2(wk))


# Iterate from t = dt
for i in tqdm(range(1, int(t / dt)), position=0):
    c[:,:,i] = np.real(ifft2(ck))
    h[:,:,i] = np.real(ifft2(hk))
    w[:,:,i] = np.real(ifft2(wk))
    
    # Nonlinear term, N(t), N(t-\delta t) for c, h, w
    N_c_past = N_c
    N_c = fft2(A * w[:, :, i] - B * c[:, :, i] * h[:, :, i]**2)
    N_h_past = N_h
    N_h = fft2(B * c[:, :, i] * h[:, :, i]**2 - h[:, :, i] * w[:, :, i])
    N_w_past = N_w
    N_w = fft2(h[:, :, i] * w[:, :, i] - A * w[:, :, i] - B * c[:, :, i])
    
    ck = ((1 + Lc * (1/2)) * ck + (3 * N_c - 3 * N_c_past) * (1/2) * dt) / (1 - Lc * (1/2))
    hk = ((1 + Lh * (1/2)) * hk + (3 * N_h - 3 * N_h_past) * (1/2) * dt) / (1 - Lh * (1/2))
    wk = ((1 + Lw * (1/2)) * wk + (3 * N_w - 3 * N_w_past) * (1/2) * dt) / (1 - Lw * (1/2))

if np.isnan(ck).any():
    print("Error, encounter NaN")


  3%|▎         | 27/999 [41:30<21:08:21, 78.29s/it] 

In [31]:
#print(c[100,:,:])
print(c[100,0, 100] + w[100,0, 100] + h[100,0, 100])

1.4046121770477427


In [43]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig, ax = plt.subplots()
#line = ax.pcolormesh(c[:,:,0]) 
#line = plt.imshow(c[:, :, 0])
img = ax.imshow(c[:, :, 0],  cmap='plasma')
#colorbar = fig.colorbar(img, ax=ax)
div = make_axes_locatable(ax)
cax = div.append_axes('right', '5%', '5%')

#colorbar = plt.colorbar()
#cbar = fig.colorbar(line, ax=ax)

def init():
    img = ax.imshow(c[:, :, 0],  cmap='plasma')
    #colorbar = fig.colorbar(img)
    fig.colorbar(img, cax=cax)
    #colorbar.mappable.set_array(c[:, :, 0])
    return [img]
    #return [line]

def plotFrame(i):
    img = ax.imshow(c[:, :, i], cmap='plasma')
    fig.colorbar(img, cax=cax)
    #colorbar.mappable.set_array(c[:, :, i])
    
    return [img]
    #return [line]

video = FuncAnimation(fig, plotFrame, interval=100, frames=tqdm(range(0, int(t/dt), 50), initial=0, position=0), blit=False, init_func=init)
#video.save('euler_adam.mp4', fps=10) 
plt.close()
display(HTML(video.to_html5_video()))

 99%|█████████▉| 99/100 [01:04<00:00,  1.10it/s]