In [16]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.fftpack import fft2, ifft2
from scipy.integrate import solve_ivp

# Initializing Parameters
tspan = np.arange(0, 4.5, 0.5) 
beta = 1
D1 = D2 = 0.1
Lx, Ly = 20, 20
nx, ny = 64, 64
N = nx * ny

# Defining the spatial domain
x2 = np.linspace(-Lx / 2, Lx / 2, nx + 1)
x = x2[:nx]
y2 = np.linspace(-Ly / 2, Ly / 2, ny + 1)
y = y2[:ny]
X, Y = np.meshgrid(x, y)

#  Initial Conditons
m = 1 # Number of spirals
U0 = np.tanh(np.sqrt(X**2 + Y**2)) * np.cos(m * np.angle(X + 1j * Y) - np.sqrt(X**2 + Y**2)) + 1j * np.zeros((nx, ny))
V0 = np.tanh(np.sqrt(X**2 + Y**2)) * np.sin(m * np.angle(X + 1j * Y) - np.sqrt(X**2 + Y**2)) + 1j * np.zeros((nx, ny))

# Define spectral k values
kx = (2 * np.pi / Lx) * np.concatenate((np.arange(0, nx / 2), np.arange(-nx / 2, 0)))
ky = (2 * np.pi / Ly) * np.concatenate((np.arange(0, ny / 2), np.arange(-ny / 2, 0)))
kx[0] = ky[0] = 1e-6
KX, KY = np.meshgrid(kx, ky)
K = KX**2 + KY**2

def rhs(t, UVt, nx, ny, N, K, D1, D2, beta):
    # Split U and V
    Utc = UVt[:N] + 1j * UVt[N:2*N]
    Vtc = UVt[2*N:3*N] + 1j * UVt[3*N:]

    Ut = Utc.reshape((nx, ny))
    Vt = Vtc.reshape((nx, ny))

    U = np.real(ifft2(Ut))
    V = np.real(ifft2(Vt))

    A_squared = U**2 + V**2
    lambda_A = 1 - A_squared
    omega_A = -beta * A_squared

    rhs_U = fft2(lambda_A * U - omega_A * V) - D1 * K * Ut
    rhs_V = fft2(omega_A * U + lambda_A * V) - D2 * K * Vt

    return np.hstack([np.real(rhs_U).ravel(), np.imag(rhs_U).ravel(),
                      np.real(rhs_V).ravel(), np.imag(rhs_V).ravel()])

U0t = fft2(U0)
V0t = fft2(V0)
UV0t = np.hstack([np.real(U0t).ravel(), np.imag(U0t).ravel(),
                    np.real(V0t).ravel(), np.imag(V0t).ravel()])


sol = solve_ivp(rhs, [tspan[0], tspan[-1]], UV0t, t_eval=tspan, args=(nx, ny, N, K, D1, D2, beta))

for i, t in enumerate(tspan):
    U_hat, V_hat = sol.y[:nx * ny, i].reshape((nx, ny)), sol.y[nx * ny:, i].reshape((nx, ny))
    U = np.real(ifft2(U_hat))
    plt.figure(figsize=(6, 5))
    plt.pcolor(x, y, U, shading='auto')
    plt.title(f'Time: {t:.2f}')
    plt.colorbar(label='U')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()



ValueError: cannot reshape array of size 12288 into shape (64,64)