In [None]:
import io

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
from google.colab import files   # file download
from scipy.fftpack import fft2, ifft2
from scipy.integrate import solve_ivp
from scipy.sparse import spdiags

# ---- Part 1 - Solve with periodic boundary conditions using fft ----

# ---- Parameters ----
n = 64                                  # grid points per axis
L = 20                                  # domain length: x, y in [-L/2, L/2)
x2 = np.linspace(-L/2, L/2, n+1)
x = x2[:n]                              # drop the right endpoint (periodic copy of the left one)
y = x
X, Y = np.meshgrid(x, y, indexing='xy')
N = n * n                               # unknowns per field
D1 = .1                                 # diffusion coefficient for U
D2 = .1                                 # diffusion coefficient for V
B = 1                                   # coefficient multiplying (U^2 + V^2) in the reaction terms (see rhs1)
m = 1                                   # multiplies the polar angle in the initial spiral (number of arms)
tspan = np.arange(0, 4.5, .5)           # 0, 0.5, ..., 4.0

# ---- Initial Conditions ----
# Spiral: tanh(r) * (cos, sin)(m*theta - r); stored complex-typed for the FFT below.
# Polar coordinates are computed once instead of once per term.
R = np.sqrt(X**2 + Y**2)
Theta = np.angle(X + 1j * Y)
U0 = (np.tanh(R) * np.cos(m * Theta - R)).astype(complex)
V0 = (np.tanh(R) * np.sin(m * Theta - R)).astype(complex)

# ---- Wavenumbers for spectral derivatives (FFT ordering: 0..n/2-1, -n/2..-1) ----
kx = (2 * np.pi / L) * np.concatenate((np.arange(0, n/2), np.arange(-n/2, 0)))
# Zero mode nudged off 0 (the usual guard when dividing by k). Nothing divides by
# k here, so this only perturbs k by ~1e-12; kept so results match the reference.
kx[0] = 1e-6
ky = kx.copy()                          # square grid: same wavenumbers in y
Kx, Ky = np.meshgrid(kx, ky, indexing='xy')
k = Kx**2 + Ky**2                       # |k|^2, so -k * U_hat is the Laplacian in Fourier space

# ---- FFT initial conditions; pack as [Re U_hat, Im U_hat, Re V_hat, Im V_hat] ----
# Each transform is computed once and reused for its real and imaginary halves.
U0_hat = fft2(U0).reshape(N)
V0_hat = fft2(V0).reshape(N)
Ut0 = np.hstack([np.real(U0_hat), np.imag(U0_hat)])
Vt0 = np.hstack([np.real(V0_hat), np.imag(V0_hat)])
UVt0 = np.hstack([Ut0, Vt0])
# ---- Define RHS ----
def rhs1(t, UVt):
    """Time derivative of the packed Fourier-space state for Part 1.

    ``UVt`` is the flat vector ``[Re(U_hat), Im(U_hat), Re(V_hat), Im(V_hat)]``,
    each segment of length ``N = n*n``.

    - The reaction terms are evaluated in physical space (after ``ifft2``) and
      transformed back with ``fft2``.
    - Diffusion is applied spectrally as ``-D * k * hat``.
    - ``t`` is unused.
    - Reads module globals ``n``, ``N``, ``B``, ``D1``, ``D2`` and ``k``.

    Returns the derivative packed in the same layout as ``UVt``.
    """
    grid = (n, n)
    # Reassemble complex coefficients from their stacked real/imag halves
    u_hat = (UVt[:N] + 1j * UVt[N:2*N]).reshape(grid)
    v_hat = (UVt[2*N:3*N] + 1j * UVt[3*N:]).reshape(grid)

    u = ifft2(u_hat)
    v = ifft2(v_hat)

    # Shared pieces of the reaction terms (same operation order as the
    # expanded formulas, so results are bit-for-bit unchanged)
    lam = 1 - u**2 - v**2
    amp2 = u**2 + v**2
    react_u = lam * u + B * amp2 * v
    react_v = -1 * B * amp2 * u + lam * v

    du_hat = fft2(react_u) + D1 * -k * u_hat
    dv_hat = fft2(react_v) + D2 * -k * v_hat
    return np.concatenate([
        du_hat.real.ravel(), du_hat.imag.ravel(),
        dv_hat.real.ravel(), dv_hat.imag.ravel(),
    ])

# ---- Integrate in time with RK45 (Dormand-Prince, SciPy's counterpart to MATLAB's ode45) ----
solution = solve_ivp(
    rhs1,
    [tspan[0], tspan[-1]],
    UVt0,
    t_eval=tspan,
    method='RK45'
)
# Fail loudly: after a failed integration solution.y may hold fewer than
# len(tspan) columns, which would otherwise silently give a truncated A1.
if not solution.success:
    raise RuntimeError(f"Part 1 (FFT) integration failed: {solution.message}")

# Reassemble the complex Fourier coefficients: rows 0..N-1 are U_hat and
# rows N..2N-1 are V_hat, with one column per time in tspan.
sol = solution.y
A1 = np.vstack([sol[0:N] + 1j * sol[N:2*N], sol[2*N:3*N] + 1j * sol[3*N:]])


# ---- Visualization ----

# # Display static images over tspan
# def plot_snapshots(solution, tspan, n, X, Y):
#     N = n * n
#     # Iterate over specific time snapshots
#     for i, t in enumerate(tspan):
#         # Extract solutions for U and V at time t
#         UVt = solution.y[:, i]
#         Ut = np.reshape(UVt[:N] + 1j * UVt[N:2*N], (n, n))
#         Vt = np.reshape(UVt[2*N:3*N] + 1j * UVt[3*N:], (n, n))
#         # Transform back to physical space using ifft
#         U = np.real(ifft2(Ut))
#         V = np.real(ifft2(Vt))
#         # Plot the real parts of U and V
#         fig, axes = plt.subplots(1, 2, figsize=(12, 5))
#         c1 = axes[0].contourf(X, Y, U, cmap='viridis')
#         c2 = axes[1].contourf(X, Y, V, cmap='plasma')
#         fig.colorbar(c1, ax=axes[0])
#         fig.colorbar(c2, ax=axes[1])
#         axes[0].set_title(f"Real Part of U at t={t:.1f}")
#         axes[1].set_title(f"Real Part of V at t={t:.1f}")
#         plt.tight_layout()
#         plt.show()
# # Call the function to generate snapshots
# plot_snapshots(solution, tspan, n, X, Y)

# Create and download a gif over tspan
def create_gif(solution, tspan, n, X, Y, filename="reaction_diffusion.gif"):
    """Render U and V from the Part 1 (FFT) solution into an animated GIF.

    Parameters
    ----------
    solution : OdeResult
        Result of ``solve_ivp``. Column ``i`` of ``solution.y`` is unpacked as
        ``[Re(U_hat), Im(U_hat), Re(V_hat), Im(V_hat)]``, each of length ``n*n``.
    tspan : array_like
        Times used to label the frames; frame ``i`` shows ``solution.y[:, i]``.
    n : int
        Grid points per axis.
    X, Y : ndarray
        Meshgrid coordinates passed to ``contourf``.
    filename : str
        Output GIF path.

    Frames are rendered to an in-memory PNG buffer rather than a scratch
    file. Nothing is left behind in, or overwritten in, the working directory.
    """
    N = n * n
    images = []  # one image array per time
    for i, t in enumerate(tspan):
        # Unpack the stacked real/imag Fourier coefficients at time t
        UVt = solution.y[:, i]
        Ut = np.reshape(UVt[:N] + 1j * UVt[N:2*N], (n, n))
        Vt = np.reshape(UVt[2*N:3*N] + 1j * UVt[3*N:], (n, n))
        # Transform back to physical space; keep only the real part for plotting
        U = np.real(ifft2(Ut))
        V = np.real(ifft2(Vt))

        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        axes[0].contourf(X, Y, U, cmap="viridis")
        axes[1].contourf(X, Y, V, cmap="plasma")
        axes[0].set_title(f"U at t={t:.1f}")
        axes[1].set_title(f"V at t={t:.1f}")
        for ax in axes:
            ax.axis("off")
        fig.tight_layout()

        # Render this figure (explicitly, not "the current figure") into memory
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
        plt.close(fig)
        buf.seek(0)
        images.append(imageio.imread(buf))
    # Save all frames as a GIF
    imageio.mimsave(filename, images, fps=10)
    print(f"GIF saved as {filename}")

# Export the Part 1 animation, then report the shape and a few entries of A1
create_gif(solution, tspan, n, X, Y, filename="reaction_diffusion.gif")
files.download("reaction_diffusion.gif")

print(
    f"A1 shape: {A1.shape}\n"
    f"A1 first entry: {A1[0,0]}\n"
    f"A1 second entry: {A1[0, 1]}"
)
print(
    f"\nA1 second to last entry: {A1[-1, -2]}\n"
    f"A1 last entry: {A1[-1, -1]}"
)



# ---- Part 2 - Solve with no-flux boundary conditions using Chebyshev collocation ----

# ---- Chebyshev differentiation matrix ----
def cheb(N):
    """Return the Chebyshev differentiation matrix and Chebyshev points on [-1, 1].

    Uses the same construction as Trefethen's ``cheb.m``.

    Parameters
    ----------
    N : int
        Polynomial degree; the grid has ``N + 1`` points.

    Returns
    -------
    D : ndarray, shape (N+1, N+1)
        Matrix such that ``D @ f(x)`` approximates ``f'(x)`` at the points ``x``.
    x : ndarray, shape (N+1,)
        Points ``cos(pi * j / N)`` for ``j = 0..N``, ordered from 1 down to -1.
    """
    if N == 0:
        # Single point: only constants are representable, so the derivative is 0.
        # Return arrays (not Python scalars) so the shapes match the general case.
        return np.zeros((1, 1)), np.array([1.0])

    j = np.arange(0, N + 1)
    x = np.cos(np.pi * j / N).reshape(N + 1, 1)
    # Endpoint weights are 2, interior weights 1, with alternating signs
    c = (np.hstack(([2], np.ones(N - 1), [2])) * (-1) ** j).reshape(N + 1, 1)
    X = np.tile(x, (1, N + 1))
    dX = X - X.T
    # Off-diagonal entries c_i / c_j / (x_i - x_j); the identity avoids 0/0 on the diagonal
    D = np.dot(c, 1 / c.T) / (dX + np.eye(N + 1))
    # Diagonal chosen so every row sums to zero (derivative of a constant is 0)
    D -= np.diag(np.sum(D.T, axis=0))
    return D, x.reshape(N + 1)

# ---- Parameters ----
N = 30                 # Chebyshev degree: N+1 collocation points per axis
NN = (N+1)**2          # number of 2-D grid points per field
Ds, xcs = cheb(N)      # Ds = derivative matrix on [-1, 1], xcs = Chebyshev points on [-1, 1] (clustered near the ends)
# No-flux BCs: zero the first and last rows so the computed first derivative
# is 0 at the two boundary nodes before it is differentiated a second time.
Ds[N, :] = 0
Ds[0, :] = 0
sf = 10                # scale factor mapping the Chebyshev domain [-1, 1] to the physical domain [-10, 10]
D = (1 / sf) * Ds      # d/dx on the physical domain (chain rule with x = sf * s)
# d^2/dx^2 on the physical domain. Named Dxx so it cannot be confused with
# (or silently overwritten by) the diffusion coefficient D2 defined below.
Dxx = (1 / sf**2) * np.dot(Ds, Ds)
I = np.eye(len(Dxx))
L = np.kron(I, Dxx) + np.kron(Dxx, I)   # 2D Laplacian on row-major-flattened (N+1, N+1) fields
xc = sf * xcs          # scale points up to the physical domain [-10, 10]
yc = xc
X, Y = np.meshgrid(xc, yc, indexing='xy')
D1 = .1                # diffusion coefficient for U
D2 = .1                # diffusion coefficient for V
B = 1
m = 1
tspan = np.arange(0, 4.5, .5)

# ---- Initial Conditions ----
# Same spiral as Part 1 on the Chebyshev grid; complex-typed and flattened to length NN.
Rc = np.sqrt(X**2 + Y**2)
Thetac = np.angle(X + 1j * Y)
U0 = (np.tanh(Rc) * np.cos(m * Thetac - Rc)).astype(complex).reshape(NN)
V0 = (np.tanh(Rc) * np.sin(m * Thetac - Rc)).astype(complex).reshape(NN)
UV0 = np.hstack([U0, V0])

# ---- Define RHS ----
def rhs2(t, UVs):
    """Time derivative of the Chebyshev-grid state for Part 2.

    ``UVs`` is ``[U, V]``, with each field flattened row-major from shape
    ``(N+1, N+1)``, i.e. ``NN`` values each.

    - Diffusion uses the dense Laplacian matrix ``L``.
    - The reaction terms are applied pointwise on the grid.
    - ``t`` is unused.
    - Reads module globals ``N``, ``NN``, ``L``, ``B``, ``D1`` and ``D2``.

    Returns the derivative flattened in the same ``[U, V]`` layout.
    """
    side = N + 1
    u_flat, v_flat = UVs[:NN], UVs[NN:]

    lap_u = np.dot(L, u_flat).reshape(side, side)
    lap_v = np.dot(L, v_flat).reshape(side, side)
    u = u_flat.reshape(side, side)
    v = v_flat.reshape(side, side)

    # Shared reaction pieces, combined in the same order as the expanded formulas
    lam = 1 - u**2 - v**2
    amp2 = u**2 + v**2
    du = lam * u + B * amp2 * v + D1 * lap_u
    dv = -1 * B * amp2 * u + lam * v + D2 * lap_v
    return np.concatenate([du.ravel(), dv.ravel()])

# ---- Integrate in time with RK45 (Dormand-Prince, SciPy's counterpart to MATLAB's ode45) ----
solution = solve_ivp(
    rhs2,
    [tspan[0], tspan[-1]],
    UV0,
    t_eval=tspan,
    method='RK45'
)
# Fail loudly: after a failed integration solution.y may hold fewer than
# len(tspan) columns, which would otherwise silently give a truncated A2.
if not solution.success:
    raise RuntimeError(f"Part 2 (Chebyshev) integration failed: {solution.message}")

# Rows 0..NN-1 are U and rows NN..2NN-1 are V, with one column per time in tspan
A2 = solution.y

# # ---- Visualization ----

# Create static images of solutions
# def plot_snapshots_chebyshev(solution, tspan, N, X, Y):
#     NN = (N + 1) ** 2  # Total grid points
#     # Iterate over specific time snapshots
#     for i, t in enumerate(tspan):
#         # Extract U and V at time t
#         UVt = solution.y[:, i]
#         U = np.reshape(UVt[:NN], (N + 1, N + 1))
#         V = np.reshape(UVt[NN:], (N + 1, N + 1))
#         # Plot U and V
#         fig, axes = plt.subplots(1, 2, figsize=(12, 5))
#         c1 = axes[0].contourf(X, Y, U.real, cmap='viridis')
#         c2 = axes[1].contourf(X, Y, V.real, cmap='plasma')
#         fig.colorbar(c1, ax=axes[0])
#         fig.colorbar(c2, ax=axes[1])
#         # Set titles and layout
#         axes[0].set_title(f"Real Part of U at t={t:.1f}")
#         axes[1].set_title(f"Real Part of V at t={t:.1f}")
#         axes[0].set_xlabel("x")
#         axes[0].set_ylabel("y")
#         axes[1].set_xlabel("x")
#         axes[1].set_ylabel("y")
#         plt.tight_layout()
#         plt.show()
# # Call the function to generate snapshots
# plot_snapshots_chebyshev(solution, tspan, N, X, Y)

# Create and download gif of solutions
def reshape_no_flux(A, N):
    """Split the stacked Chebyshev solution into U and V grids.

    ``A`` has the flattened U values in its first ``(N+1)**2`` rows, followed
    by the flattened V values, with one column per time.

    Returns ``(U, V)``, each of shape ``(N+1, N+1, ncols)``.
    """
    side = N + 1
    npts = side * side
    grid = (side, side, -1)
    u_rows = A[:npts, :]
    v_rows = A[npts:, :]
    return u_rows.reshape(grid), v_rows.reshape(grid)

# Unpack the Chebyshev solution into (N+1, N+1, ntimes) grids for U and V
U2, V2 = reshape_no_flux(A2, N)

frames = []
for i, t in enumerate(tspan):
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # U0/V0 are complex-typed, so the solution is too. Plot the real part
    # explicitly rather than relying on an implicit complex-to-float cast
    # inside contourf.
    c1 = axes[0].contourf(xc, yc, U2[:, :, i].real, cmap="viridis")
    axes[0].set_title(f"No-Flux U (t={t:.1f})")
    fig.colorbar(c1, ax=axes[0])

    c2 = axes[1].contourf(xc, yc, V2[:, :, i].real, cmap="plasma")
    axes[1].set_title(f"No-Flux V (t={t:.1f})")
    fig.colorbar(c2, ax=axes[1])

    fig.tight_layout()
    # Render into memory instead of a scratch "frame.png" in the working directory
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    frames.append(imageio.imread(buf))
gif_path = "no_flux_solutions.gif"
imageio.mimsave(gif_path, frames, fps=10)
files.download(gif_path)

print(f"A2 shape: {A2.shape}\nA2 first entry: {A2[0,0]}\nA2 second entry: {A2[0, 1]}\nA2 last entry: {A2[-1, -1]} ")

