# In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib import animation
from numba import njit

# --- cavity geometry and fluid --------------------------------------------
LX = LY = 1             # side lengths of the (square) computational domain
RHO = 1                 # density (not referenced by the code in this section)
NU = 0.01               # viscosity multiplying the vorticity Laplacian
U_TOP = 1               # speed of the moving wall at the last y index

# --- discretisation ---------------------------------------------------------
NX = NY = 31            # grid points per direction, walls included
DT = 0.01               # time step
T_END = 10              # simulated end time
DT_PLOT = T_END / 50    # time between frames stored for the animation

# --- SOR Poisson solver for the stream function -----------------------------
OMEGA_SOR = 1.8         # relaxation factor
TOL = 10 ** (-3)        # stop when the max relative change of a sweep is below this
MAX_ITER = 100          # upper bound on SOR sweeps per time step


class Variable:
    """Scalar field on a uniform 2-D grid (coordinates or flow variables).

    The grid (point counts ``nx``/``ny`` and spacings ``dx``/``dy``) is stored
    on the class and shared by all instances, so :meth:`set_class_variables`
    must be called before the first instance is created. ``value[ix, iy]`` is
    the field at ``(x, y) = (ix * dx, iy * dy)``; indices 0 and -1 along an
    axis are the two walls.

    ``name`` selects the initial and boundary conditions: ``'grid_x'`` and
    ``'grid_y'`` hold the node coordinates, ``'psi'`` is the stream function
    and ``'omg'`` the vorticity.
    """

    @classmethod
    def set_class_variables(cls, lx, ly, nx, ny):
        """Configure the shared grid: an ``lx`` x ``ly`` domain with ``nx`` x ``ny`` nodes (walls included)."""

        cls.__nx = nx
        cls.__ny = ny
        cls.__dx = lx / (nx-1)
        cls.__dy = ly / (ny-1)

    @classmethod
    def check_stable(cls, dt, u_top, nu):
        """Exit the program when ``dt`` exceeds the explicit-scheme stability limit.

        The limit is the smaller of an advective bound (using ``u_top`` as the
        velocity scale) and the diffusive bound for viscosity ``nu``. When
        ``dt`` is larger, an error message is printed and ``sys.exit()`` is called.
        """

        dx = Variable.__dx
        dy = Variable.__dy
        dt_advection = 1/(u_top/dx+u_top/dy)
        dt_diffusion = 1/((2*nu)*(1/(dx**2)+1/(dy**2)))
        dt_stable = min(dt_advection, dt_diffusion)
        if dt > dt_stable:
            print(f'[ERROR] 数値的安定性の条件(DT < {dt_stable})を満たしていません')
            sys.exit()

    def __init__(self, name):
        """Create the field ``name`` and fill it with its initial condition.

        Raises:
            ValueError: if ``name`` is not one of the known variable names.
        """

        self.__name = name

        nx = Variable.__nx
        ny = Variable.__ny
        dx = Variable.__dx
        dy = Variable.__dy
        # Node indices laid out like self.value ([ix, iy]), so x varies along axis 0.
        ix, iy = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
        self.value = np.empty((nx, ny), dtype=np.float64)
        # Broadcasting accepts either a scalar or an (nx, ny) initial condition.
        self.value[...] = self.__set_initial_condition(dx * ix, dy * iy)

    def __set_initial_condition(self, x, y):
        """Return the initial value(s) for this variable at coordinates ``x``, ``y``."""

        if self.__name == 'grid_x':
            return x
        elif self.__name == 'grid_y':
            return y
        elif self.__name in ('psi', 'omg'):
            return 0
        # Previously an unknown name silently produced a NaN-filled field.
        raise ValueError(f'unknown variable name: {self.__name!r}')

    def set_boundary_condition(self, psi = None, u_top = None):
        """Apply this variable's wall boundary condition in place.

        For ``'psi'`` all four walls are set to 0. For ``'omg'`` the wall
        vorticity is derived from the stream function ``psi`` next to each
        wall; the wall at the last y index additionally gets the
        ``-2 * u_top / dy`` term for its tangential speed ``u_top``.
        ``psi`` and ``u_top`` are required for ``'omg'``. Other names are
        left unchanged.
        """

        if self.__name == 'psi':

            self.value[0, :] = 0
            self.value[-1, :] = 0
            self.value[:, 0] = 0
            self.value[:, -1] = 0

        elif self.__name == 'omg':

            dx = Variable.__dx
            dy = Variable.__dy
            self.value[0, :] = -2 * (psi.value[1, :]-psi.value[0, :]) / (dx**2)
            self.value[-1, :] = -2 * (psi.value[-2, :]-psi.value[-1, :]) / (dx**2)
            self.value[:, 0] = -2 * (psi.value[:, 1]-psi.value[:, 0]) / (dy**2)
            # Assigned last, so the corner nodes take the moving-wall value.
            self.value[:, -1] = -2 * (psi.value[:, -2]-psi.value[:, -1]) / (dy**2) - 2 * u_top / dy

    def advection(self, psi):
        """Return the first-order upwind advection term of this field.

        Velocities come from ``psi`` by central differences:
        ``vx = d(psi)/dy`` and ``vy = -d(psi)/dx``. The result has shape
        ``(nx, ny)`` and is zero on the boundary rows/columns.
        """

        nx = Variable.__nx
        ny = Variable.__ny
        result = np.zeros((nx, ny), dtype=np.float64)

        dx = Variable.__dx
        dy = Variable.__dy
        f = self.value
        p = psi.value
        center = f[1:-1, 1:-1]
        vx = (p[1:-1, 2:] - p[1:-1, :-2]) / (2 * dy)
        vy = - (p[2:, 1:-1] - p[:-2, 1:-1]) / (2 * dx)
        # Upwind: backward difference where velocity > 0, forward where < 0.
        # Same operation order as the scalar formula, so results are bit-identical.
        result[1:-1, 1:-1] = (
            np.maximum(vx, 0) * (center - f[:-2, 1:-1]) / dx
            + np.minimum(vx, 0) * (f[2:, 1:-1] - center) / dx
            + np.maximum(vy, 0) * (center - f[1:-1, :-2]) / dy
            + np.minimum(vy, 0) * (f[1:-1, 2:] - center) / dy
        )
        return result

    def laplacian(self):
        """Return the 5-point finite-difference Laplacian (zero on the boundary)."""

        nx = Variable.__nx
        ny = Variable.__ny
        result = np.zeros((nx, ny), dtype=np.float64)

        dx = Variable.__dx
        dy = Variable.__dy
        f = self.value
        center = f[1:-1, 1:-1]
        result[1:-1, 1:-1] = (
            (f[2:, 1:-1] - 2 * center + f[:-2, 1:-1]) / (dx**2)
            + (f[1:-1, 2:] - 2 * center + f[1:-1, :-2]) / (dy**2)
        )
        return result

    def poisson_solver(self, source, omega_sor, tol, max_iter):
        """Solve ``laplacian(self) = source`` in place by SOR.

        Runs at most ``max_iter`` sweeps of ``sor_one_iter`` and re-applies the
        boundary condition after each one. It stops early once a sweep's
        maximum relative change is below ``tol``. The boundary condition is
        applied without arguments, so this is meant for the ``'psi'`` variable.
        """

        dx = Variable.__dx
        dy = Variable.__dy
        for _ in range(max_iter):
            self.value, max_err = sor_one_iter(self.value, source, dx, dy, omega_sor)
            self.set_boundary_condition()
            if max_err < tol:
                break


@njit
def sor_one_iter(solution, source, dx, dy, omega_sor):
    """Perform one in-place SOR sweep for ``laplacian(solution) = source``.

    Only interior nodes are updated (5-point stencil); boundary values are read
    but never written. The error measure is the maximum of
    ``|gs - old| / |gs|`` over interior nodes, where ``gs`` is the unrelaxed
    Gauss-Seidel estimate; nodes whose estimate is exactly 0 are skipped.

    Returns:
        The same ``solution`` array (modified in place) and the error measure.
    """

    n_x, n_y = solution.shape

    diag = 2/(dx**2) + 2/(dy**2)
    worst = 0

    for i in range(1, n_x - 1):
        for j in range(1, n_y - 1):
            previous = solution[i, j]
            gs_value = (
                (solution[i+1, j] + solution[i-1, j]) / (dx**2)
                + (solution[i, j+1] + solution[i, j-1]) / (dy**2)
                - source[i, j]
            ) / diag
            # Blend the previous value with the Gauss-Seidel estimate.
            solution[i, j] = (1-omega_sor) * previous + omega_sor * gs_value

            if gs_value == 0:
                continue  # relative change undefined
            worst = max(worst, abs((gs_value - previous) / gs_value))

    return solution, worst


def _draw_contour_panel(ax, grid_x, grid_y, results, frame, title):
    """Draw ``results[frame]`` on ``ax`` as filled contours with black iso-lines.

    Contour levels span the min/max over *all* stored frames, so every frame
    uses the same levels and colours. The colorbar, built once from a single
    frame, therefore stays valid for the others. Returns the filled ContourSet.
    """
    vmin = np.min(results)
    vmax = np.max(results)
    # `levels=20` alone would choose levels from this frame's data only.
    # Fall back to it when the whole history is constant, because linspace
    # would then give non-increasing levels.
    levels = np.linspace(vmin, vmax, 21) if vmax > vmin else 20
    data = results[frame]
    filled = ax.contourf(grid_x, grid_y, data, vmin=vmin, vmax=vmax, levels=levels, cmap='jet')
    # A constant frame (e.g. the all-zero initial state) has no iso-lines to draw.
    if np.max(data) > np.min(data):
        ax.contour(filled, colors='k', linewidths=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(title)
    ax.set_aspect('equal')
    return filled


def create_plot(frame, *fargs) -> tuple:
    """Redraw the stream-function and vorticity panels for animation frame ``frame``.

    ``fargs`` is ``(grid_x, grid_y, results_t, results_psi, results_omg)``:
    the node coordinate arrays, the saved times, and the saved psi / omg fields.
    Draws onto the module-level ``figure`` and ``axes`` created in the
    ``__main__`` section. Returns the two filled ContourSets.
    """

    grid_x, grid_y, results_t, results_psi, results_omg = fargs

    axes[0].cla()
    axes[1].cla()

    im1 = _draw_contour_panel(axes[0], grid_x, grid_y, results_psi, frame, 'Stream Function (psi)')
    im2 = _draw_contour_panel(axes[1], grid_x, grid_y, results_omg, frame, 'Vorticity (omg)')

    figure.suptitle(f't = {results_t[frame]:.1f} sec')

    return im1, im2


if __name__ == '__main__':

    # `display` is only a builtin inside IPython/Jupyter. Import it explicitly
    # so running this file as a plain script does not raise NameError.
    from IPython.display import display

    Variable.set_class_variables(LX, LY, NX, NY)
    Variable.check_stable(DT, U_TOP, NU)

    grid_x: Variable = Variable('grid_x')
    grid_y: Variable = Variable('grid_y')
    psi: Variable = Variable('psi')
    omg: Variable = Variable('omg')

    # Drive the loop with integer counters. Accumulating `t += DT` in floating
    # point can take one step too many, and int(DT_PLOT / DT) can truncate
    # e.g. 19.999999 down to 19.
    n_steps = int(np.ceil(T_END / DT - 1e-9))
    save_every = max(1, round(DT_PLOT / DT))
    t = 0

    # Prepare the animation; frame 0 is the initial state at t = 0.
    figure, axes = plt.subplots(1, 2, figsize=(10, 5))
    results_t = [t]
    results_psi = [psi.value.copy()]
    results_omg = [omg.value.copy()]

    for step in range(1, n_steps + 1):

        # Explicit Euler step for the vorticity, then solve
        # laplacian(psi) = -omg and refresh the wall vorticity from psi.
        omg.value += (-omg.advection(psi) + NU * omg.laplacian()) * DT
        psi.poisson_solver(-omg.value, OMEGA_SOR, TOL, MAX_ITER)
        omg.set_boundary_condition(psi, U_TOP)
        t = step * DT

        # Store a frame every DT_PLOT (counted from t = 0), plus the final state.
        if step % save_every == 0 or step == n_steps:
            results_t.append(t)
            results_psi.append(psi.value.copy())
            results_omg.append(omg.value.copy())

    # Build the animation; draw the last frame first so the colorbars exist.
    im1, im2 = create_plot(len(results_t)-1, grid_x.value, grid_y.value, results_t, results_psi, results_omg)
    figure.colorbar(im1, ax=axes[0])
    figure.colorbar(im2, ax=axes[1])
    figure.tight_layout()
    anim = animation.FuncAnimation(figure, create_plot, range(len(results_t)),
                                   fargs=(grid_x.value, grid_y.value, results_t, results_psi, results_omg))
    display(HTML(anim.to_jshtml()))
    plt.close()