$$
\frac{d f}{dt} +  (v \frac{d f}{dx} - x \frac{d f}{dv}) = 0
$$

$$ 
x \in [-\pi, \pi],\qquad y \in [-\pi, \pi] \qquad \mbox{ and } \qquad t \in [0, 200\pi] 
$$

In [8]:
using  FFTW
using  LinearAlgebra
using  Plots, ProgressMeter
using  BenchmarkTools
pyplot()

Plots.PyPlotBackend()

In [14]:
struct Mesh
    
    nx   :: Int
    ny   :: Int
    xmin :: Float64
    xmax :: Float64
    ymin :: Float64
    ymax :: Float64
    dx   :: Float64
    dy   :: Float64
    x    :: Vector{Float64}
    y    :: Vector{Float64}
    
    function Mesh( xmin, xmax, nx, ymin, ymax, ny)
        dx, dy = (xmax-xmin)/nx, (ymax-ymin)/ny
        x = range(xmin, stop=xmax, length=nx+1)[1:end-1]  # we remove the end point
        y = range(ymin, stop=ymax, length=ny+1)[1:end-1]  # periodic boundary condition
        new( nx, ny, xmin, xmax, ymin, ymax, dx, dy, x, y)
    end
    
    function Mesh( x, y)
        dx, dy = x.step, y.step
        nx, ny = x.len, y.len
        xmin, ymin = x.offset, y.offset
        xmax = xmin + (nx-1)*dx
        ymax = ymin + (ny-1)*dy
        new( nx, ny, xmin, xmax, ymin, ymax, dx, dy, x, y)
    end
end

mesh = Mesh(-π, π, 128, -π, π, 256)

mesh = Mesh(-3:0.1:3, -3:0.1:3)

Mesh(61, 61, 31.0, 37.0, 31.0, 37.0, 0.1, 0.1, [-3.0, -2.9, -2.8, -2.7, -2.6, -2.5, -2.4, -2.3, -2.2, -2.1  …  2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0], [-3.0, -2.9, -2.8, -2.7, -2.6, -2.5, -2.4, -2.3, -2.2, -2.1  …  2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0])

### Julia function to compute exact solution

In [10]:
function exact(tf, mesh)
   
    f = zeros(Float64,(mesh.nx, mesh.ny))
    for (i, x) in enumerate(mesh.x), (j, y) in enumerate(mesh.y)
        xn = cos(tf)*x - sin(tf)*y
        yn = sin(tf)*x + cos(tf)*y
        f[i,j] = exp(-(xn-1)*(xn-1)/0.1)*exp(-(yn-1)*(yn-1)/0.1)
    end
    f
    
end

exact (generic function with 1 method)

In [11]:
function error1(f, f_exact)
    maximum(abs.(f .- f_exact))
end

error1 (generic function with 1 method)

In [12]:
function with_fft_transposed(tf, nt, mesh::Mesh)

    dt = tf/nt

    kx = 2π/(mesh.xmax-mesh.xmin)*[0:mesh.nx÷2-1;mesh.nx÷2-mesh.nx:-1]
    ky = 2π/(mesh.ymax-mesh.ymin)*[0:mesh.ny÷2-1;mesh.ny÷2-mesh.ny:-1]
    
    f  = zeros(Complex{Float64},(mesh.nx,mesh.ny))
    f̂  = similar(f)
    fᵗ = zeros(Complex{Float64},(mesh.ny,mesh.nx))
    f̂ᵗ = similar(fᵗ)

    exky = exp.( 1im*tan(dt/2) .* mesh.x' .* ky )
    ekxy = exp.(-1im*sin(dt)   .* mesh.y' .* kx )
    
    FFTW.set_num_threads(4)
    Px = plan_fft(f,  1, flags=FFTW.PATIENT)    
    Py = plan_fft(fᵗ, 1, flags=FFTW.PATIENT)
    
    f .= exact(0.0, mesh)
    
    for n = 1:nt
        transpose!(fᵗ,f)
        mul!(f̂ᵗ, Py, fᵗ)
        f̂ᵗ .= f̂ᵗ .* exky
        ldiv!(fᵗ, Py, f̂ᵗ)
        transpose!(f,fᵗ)
        
        mul!(f̂, Px, f)
        f̂ .= f̂ .* ekxy 
        ldiv!(f, Px, f̂)
        
        transpose!(fᵗ,f)
        mul!(f̂ᵗ, Py, fᵗ)
        f̂ᵗ .= f̂ᵗ .* exky
        ldiv!(fᵗ, Py, f̂ᵗ)
        transpose!(f,fᵗ)
    end
    real(f)
end

with_fft_transposed (generic function with 1 method)

In [13]:
nt, tf = 1000, 200\pi
println( " error = ", error1(with_fft_transposed(tf, nt, mesh), exact(tf, mesh)))

 error = 3.2374103398069565e-13
