In [51]:
using Revise
using ChainRulesCore
using CUDA
using LinearAlgebra
using MAT
using ParametricOperators
using ProgressMeter
using PyPlot
using Random
using Zygote

In [2]:
# Load data
data_path = "/home/tgrady6/data/NavierStokes_V1e-5_N1200_T20.mat"
data = matread(data_path)

Dict{String, Any} with 3 entries:
  "t" => Float32[1.0 2.0 … 19.0 20.0]
  "u" => [-0.0870507 -0.0404285 … -0.192733 -0.137675; 0.0251584 0.0371931 … 0.…
  "a" => [-0.200997 -0.162123 … -0.259321 -0.22478; -0.0454693 -0.0613871 … 0.0…

In [25]:
# Switch to column-major format
U = Float64.(data["u"])
U = permutedims(U, (2, 3, 4, 1))
(nx, ny, nt, nd) = size(U)

(64, 64, 20, 1200)

In [26]:
# Setup FFT
Fx = ParDFT(nx)
Fy = ParDFT(ny)
Ft = ParDRFT(nt)
F = Ft ⊗ (Fy ⊗ Fx)

# Setup Restriction
mx = 4 # Number of Fourier modes in x
my = 4 # Number of Fourier modes in y
mt = 4 # Number of Fourier modes in z
Rx = ParRestriction(ComplexF64, nx, [1:mx, nx-mx+1:nx])
Ry = ParRestriction(ComplexF64, ny, [1:my, ny-my+1:ny])
Rt = ParRestriction(ComplexF64, Range(Ft), [1:mt])
R = Rt ⊗ Ry ⊗ Rx

# Model parameters
nc = 20 # Lifted space dimension
nb = 4  # Number of spectral convolution blocks

# Model layers
Ic = ParIdentity(ComplexF64, nc)
Is = ParIdentity(Float64, nx*ny*nt)
Ds = [ParDiagonal(ComplexF64, Range(R)) ⊗ ParMatrix(ComplexF64, nc, nc) for i ∈ 1:nb]
Ws = [Is ⊗ ParMatrix(nc, nc) for i ∈ 1:nb]

# Model blocks
σ(x::X) where {X} = tanh.(x)
Fc = F ⊗ Ic
Rc = R ⊗ Ic 
Bs = [σ ∘ (Fc'*Rc'*D*Rc*Fc + W) for (D, W) in zip(Ds, Ws)]
B = ∘(Bs...)
    
# Lifting/Projection
Q = Is ⊗ ParMatrix(nc, 1)
P = Is ⊗ ParMatrix(1, nc)
    
# Full Network
G = P ∘ B ∘ Q;

In [40]:
# Split data
train_split = 0.8
n_train = Int64(round(train_split*nd))
n_valid = nd-n_train

U = (U .- minimum(U))./(maximum(U))
U_train = X[:,:,:,1:n_train]
U_valid = X[:,:,:,n_train+1:end]

64×64×20×240 Array{Float64, 4}:
[:, :, 1, 1] =
 0.354405   0.321781  0.294409  0.263867  …   0.4158      0.387307
 0.345268   0.31735   0.294225  0.277523      0.407946    0.37178
 0.341949   0.307503  0.289195  0.28834       0.417474    0.376199
 0.356984   0.322744  0.30646   0.304228      0.435113    0.390717
 0.388186   0.370142  0.347346  0.337705      0.448963    0.418839
 0.42487    0.413357  0.400049  0.390607  …   0.467986    0.448116
 0.433482   0.423919  0.428782  0.42919       0.485375    0.457407
 0.438219   0.439939  0.459873  0.456413      0.489063    0.452578
 0.459496   0.46605   0.473111  0.458867      0.492598    0.463956
 0.47847    0.479171  0.47229   0.451827      0.49027     0.482007
 0.47716    0.475542  0.466609  0.44785   …   0.486488    0.491613
 0.46054    0.460616  0.454868  0.455588      0.472525    0.469805
 0.471313   0.465853  0.461507  0.48158       0.46418     0.469627
 ⋮                                        ⋱               
 0.0564333  0.116393  0.

In [65]:
# Initialize model weights
θ = CuArray(init(G))
    
# Setup optimization parameters and buffers for ADAM
α  = 1e-3
β1 = 0.9
β2 = 0.999
ϵ  = 1e-8

m = zero(θ)
v = zero(θ)
m̂ = zero(θ)
v̂ = zero(θ)

n_epochs = 10

# Setup training logs
losses_train = zeros(n_train, n_epochs);
losses_valid = zeros(n_valid, n_epochs);

In [66]:
# Training loop
for i ∈ 1:n_epochs
    schedule = Random.shuffle(1:n_train)
    prog = Progress(n_train, "Training: ")
    for (k, j) ∈ enumerate(schedule)
        x = CuArray(vec(repeat(U_train[:,:,1:1,1], 1, 1, nt)))
        y = CuArray(vec(U_train[:,:,:,1]))
        
        g = gradient(p -> begin
            ŷ = G(x, p)
            l = norm(y.-ŷ)
            @ignore_derivatives losses_train[k,i] = l
            return l
        end, θ)[1]
        
        #t = (i-1)*n_train + k
        #m .= β1.*m + (1-β1).*g
        #v .= β2.*v + (1-β2).*g.^2
        #m̂ .= m ./ (1-β1^t)
        #v̂ .= v ./ (1-β2^t)
        #θ .-= α.*m̂./(sqrt.(v̂).+ϵ)
        θ .-= α.*g
    
        ProgressMeter.next!(prog, showvalues = [(:epoch, i), (:batch, k), (:loss, losses_train[k,i])])
    end
end

│  - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. 
└ @ ProgressMeter /home/tgrady6/.julia/packages/ProgressMeter/sN2xr/src/ProgressMeter.jl:618
[32mTraining:   3%|█▍                                       |  ETA: 0:04:37[39m
[34m  epoch:  1[39m
[34m  batch:  31[39m
[34m  loss:   75.28051652991715[39m

LoadError: TaskFailedException

[91m    nested task error: [39mInterruptException:
    Stacktrace:
     [1] [0m[1mpoptask[22m[0m[1m([22m[90mW[39m::[0mBase.InvasiveLinkedListSynchronized[90m{Task}[39m[0m[1m)[22m
    [90m   @ [39m[90mBase[39m [90m./[39m[90m[4mtask.jl:921[24m[39m
     [2] [0m[1mwait[22m[0m[1m([22m[0m[1m)[22m
    [90m   @ [39m[90mBase[39m [90m./[39m[90m[4mtask.jl:930[24m[39m
     [3] [0m[1mwait[22m[0m[1m([22m[90mc[39m::[0mBase.GenericCondition[90m{Base.Threads.SpinLock}[39m[0m[1m)[22m
    [90m   @ [39m[90mBase[39m [90m./[39m[90m[4mcondition.jl:124[24m[39m
     [4] [0m[1m_trywait[22m[0m[1m([22m[90mt[39m::[0mTimer[0m[1m)[22m
    [90m   @ [39m[90mBase[39m [90m./[39m[90m[4masyncevent.jl:138[24m[39m
     [5] [0m[1mwait[22m
    [90m   @ [39m[90m./[39m[90m[4masyncevent.jl:155[24m[39m[90m [inlined][39m
     [6] [0m[1mmacro expansion[22m
    [90m   @ [39m[90m~/.julia/packages/CUDA/DfvRa/lib/cudadrv/[39m[90m[4mstream.jl:169[24m[39m[90m [inlined][39m
     [7] [0m[1m(::CUDA.var"#14#17"{CuStream, Timer, CuDevice, Base.Event})[22m[0m[1m([22m[0m[1m)[22m
    [90m   @ [39m[35mCUDA[39m [90m./[39m[90m[4mthreadingconstructs.jl:258[24m[39m