In [1]:
using Pkg
Pkg.activate("./")
# Pkg.resolve()
# Pkg.instantiate()

using DrWatson
using MPI
using ParametricOperators
using Parameters
using Profile
using Shuffle
using Zygote
using PyPlot
using NNlib
using NNlibCUDA
using FNO4CO2
using JLD2
using Flux, Random, FFTW
using MAT, Statistics, LinearAlgebra
using CUDA
using ProgressMeter
using InvertibleNetworks:ActNorm
using Random
matplotlib.use("Agg")

[32m[1m  Activating[22m[39m project at `~/Desktop/Research/Code/dfno`


┌ Info: OMEinsum loaded the CUDA module successfully
└ @ OMEinsum /Users/richardr2926/.julia/packages/OMEinsum/lTBCn/src/cueinsum.jl:117
┌ Info: FNO4CO2 is using GPU
└ @ FNO4CO2 /Users/richardr2926/Desktop/Research/Code/FNO4CO2/src/FNO4CO2.jl:15


In [2]:
update = ParametricOperators.update!


# Hyper-parameters of the operator-based FNO built by `PO_FNO4CO2`.
# `@with_kw` provides a keyword constructor that falls back to the defaults below.
@with_kw struct ModelConfig
    nx::Int = 64                     # grid points along x (size of the x DFT)
    ny::Int = 64                     # grid points along y (size of the y DFT)
    nz::Int = 64                     # grid points along z; only referenced by commented-out code here
    nt_in::Int = 51                  # time steps of the input tensor
    nt_out::Int = 51                 # size of the time DFT inside the spectral convolution
    nc_in::Int = 4                   # input channels
    nc_mid::Int = 128                # channels after the first projection layer
    nc_out::Int = 1                  # output channels
    nc_lift::Int = 20                # channel width after lifting (width of every block)
    mx::Int = 4                      # retained modes along x (kept at both ends of the spectrum)
    my::Int = 4                      # retained modes along y (kept at both ends of the spectrum)
    mz::Int = 4                      # retained modes along z; only referenced by commented-out code here
    mt::Int = 4                      # retained modes along t (leading modes only)
    n_blocks::Int = 1                # number of spectral-convolution blocks
    n_batch::Int = 1                 # batch size used to size the bias inputs and reshapes
    dtype::DataType = Float32        # real element type of the operators
    partition::Vector{Int} = [1]     # not read by the code in this notebook
end

ModelConfig

In [8]:
"""
    PO_FNO4CO2(config::ModelConfig)

Assemble the ParametricOperators building blocks of the FNO described by `config`.

Every operator acts on a flattened tensor whose axes are ordered
(channel, x, y, t), and every operator uses `config.dtype` as its real element type.

# Returns
The tuple `(lifts, sconvs, convs, projects, biases, sconv_biases)`:
- `lifts`: pointwise channel map `nc_in -> nc_lift`.
- `sconvs`: one spectral convolution per block (`config.n_blocks` entries).
- `convs`: one pointwise channel map `nc_lift -> nc_lift` per block.
- `projects`: the two pointwise maps `nc_lift -> nc_mid` and `nc_mid -> nc_out`.
- `biases`: diagonal operators for the lift and the two projections, in that order.
- `sconv_biases`: one diagonal operator of width `nc_lift` per block.
"""
function PO_FNO4CO2(config::ModelConfig)

    T = config.dtype

    # Number of (x, y, t) grid points every pointwise operator is replicated over.
    n_grid = config.nx * config.ny * config.nt_in

    # Pointwise channel map `n_in -> n_out`, replicated over the whole grid.
    channel_mix(n_out, n_in) = ParIdentity(T, n_grid) ⊗ ParMatrix(T, n_out, n_in)

    # Diagonal operator over `n` channels, replicated over the whole grid.
    # TODO: Rearrange code for all bias so it makes more sense mathematically
    channel_bias(n) = ParIdentity(T, n_grid) ⊗ ParDiagonal(T, n)

    function spectral_convolution()

        # Build the (x, y, t) Fourier transform with real-valued FFT along time
        fourier_x = ParDFT(Complex{T}, config.nx)
        fourier_y = ParDFT(Complex{T}, config.ny)
        fourier_t = ParDFT(T, config.nt_out)

        # Build restrictions to low-frequency modes: both ends of the spectrum
        # along x and y, only the leading modes along t.
        restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx])
        restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny])
        restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt])

        input_shape = (config.nc_lift, 2*config.mx, 2*config.my, config.mt)
        weight_shape = (config.nc_lift, config.nc_lift, 2*config.mx, 2*config.my, config.mt)

        # Einsum-style labels: label 1 is shared by weight and input but absent
        # from the target, label 5 only appears in weight and target.
        input_order = (1, 2, 3, 4)
        weight_order = (5, 1, 2, 3, 4)
        target_order = (5, 2, 3, 4)

        weight_mix = ParMatrixN(Complex{T}, weight_order, weight_shape, input_order, input_shape, target_order, input_shape)

        # Setup FFT-restrict pattern with Kroneckers
        restrict_dft = (restrict_t * fourier_t) ⊗ (restrict_y * fourier_y) ⊗ (restrict_x * fourier_x) ⊗ ParIdentity(T, config.nc_lift)

        return restrict_dft' * weight_mix * restrict_dft
    end

    sconvs = []
    convs = []
    projects = []
    sconv_biases = []
    biases = []

    # Lift Channel dimension
    lifts = channel_mix(config.nc_lift, config.nc_in)
    push!(biases, channel_bias(config.nc_lift))

    for _ in 1:config.n_blocks
        push!(sconv_biases, channel_bias(config.nc_lift))
        push!(sconvs, spectral_convolution())
        push!(convs, channel_mix(config.nc_lift, config.nc_lift))
    end

    # Uplift channel dimension once more
    push!(biases, channel_bias(config.nc_mid))
    push!(projects, channel_mix(config.nc_mid, config.nc_lift))

    # Project channel dimension
    push!(biases, channel_bias(config.nc_out))
    push!(projects, channel_mix(config.nc_out, config.nc_mid))

    return lifts, sconvs, convs, projects, biases, sconv_biases
end

# Number of retained Fourier modes per axis and lifted channel width.
modes = 4
width = 20

# Build the operator-based network: 4 blocks, batch size 2.
config = ModelConfig(mx=modes, my=modes, mt=modes, nc_lift=width, n_blocks=4, n_batch=2)
lifts, sconvs, convs, projects, biases, sconv_biases = PO_FNO4CO2(config)

# To Load Saved Dict: 
# key = load("./data/3D_FNO/.jld2")["key"]

# Collect the parameters of every operator into one container `θ`:
# `init` creates it from the lifting operator, `init!` adds the remaining ones.
θ = init(lifts)
for operator in Iterators.flatten((sconvs, convs, biases, sconv_biases, projects))
    init!(operator, θ)
end

# `gpu_flag` is not defined in this notebook — presumably exported by FNO4CO2 (TODO confirm).
gpu_flag && (global θ = gpu(θ))

# Move the 4th (channel) axis of a 5-D array to the front:
# (x, y, t, c, b) -> (c, x, y, t, b).
xytcb_to_cxytb(x) = permutedims(x, [4, 1, 2, 3, 5])

# Move the leading (channel) axis of a 5-D array to the 4th position:
# (c, x, y, t, b) -> (x, y, t, c, b).
cxytb_to_xytcb(x) = permutedims(x, [2, 3, 4, 1, 5])

"""
    forward(θ, x)

Evaluate the operator-based FNO on `x` using the parameters in `θ`.

Reads the globals `config`, `lifts`, `sconvs`, `convs`, `projects`, `biases`
and `sconv_biases`.

# Arguments
- `θ`: parameter container created by `init`/`init!` for the operators above.
- `x`: flattened input with one column per sample (each column laid out as
  channel, x, y, t). The batch size is taken from `size(x, 2)`, so `x` does not
  have to contain exactly `config.n_batch` columns.

# Returns
The network output as a matrix with one column per sample.
"""
function forward(θ, x::Any)
    # Size the bias inputs from the actual input instead of `config.n_batch`
    # so that inputs with a different number of columns do not fail.
    n_batch = size(x, 2)
    bias_input(op) = ones(DDT(op), Domain(op), n_batch)

    x = lifts(θ) * x + biases[1](θ) * bias_input(biases[1])

    # Loop invariants of the block normalisation.
    n_features = config.nc_lift * config.nx * config.ny * config.nt_in
    ϵ = 1f-5
    # After the reshape below the array is (x, y, t, c, b); reduce over
    # everything except the channel axis.
    reduce_dims = [1, 2, 3, 5]

    for i in 1:config.n_blocks

        x = (sconvs[i](θ) * x) + (convs[i](θ) * x) + (sconv_biases[i](θ) * bias_input(sconv_biases[i]))
        x = cxytb_to_xytcb(reshape(x, (config.nc_lift, config.nx, config.ny, config.nt_in, :)))

        # Per-channel normalisation to zero mean / unit variance (no affine part).
        μ = mean(x; dims=reduce_dims)
        σ² = var(x; mean=μ, dims=reduce_dims, corrected=false)

        x = (x .- μ) ./ sqrt.(σ² .+ ϵ)
        x = reshape(xytcb_to_cxytb(x), (n_features, :))

        # No activation after the last block.
        if i < config.n_blocks
            x = relu.(x)
        end
    end

    x = projects[1](θ) * x + biases[2](θ) * bias_input(biases[2])
    x = relu.(x)

    x = projects[2](θ) * x + biases[3](θ) * bias_input(biases[3])
    return x
end

# x_train = rand(DDT(lifts), Domain(lifts))
# y_train = rand(RDT(projects), Range(projects))

# y = forward(θ, x_train)
# grads = gradient(params -> Flux.mse(forward(params, x_train), y_train), θ)

forward (generic function with 1 method)

In [4]:
# Define raw data directory
mkpath(datadir("training-data"))
perm_path = datadir("training-data", "perm_gridspacing15.0.mat")
conc_path = datadir("training-data", "conc_gridspacing15.0.mat")

# Remote locations of the two datasets. They are kept as plain strings because
# a quoted continuation (`'...' '...'`) inside a backtick command is taken
# literally, which would put the line break and indentation inside the URL.
perm_url = "https://www.dropbox.com/s/o35wvnlnkca9r8k/perm_gridspacing15.0.mat"
conc_url = "https://www.dropbox.com/s/mzi0xgr0z3l553a/conc_gridspacing15.0.mat"

# Download the dataset into the data directory if it does not exist
if !isfile(perm_path)
    run(`wget $perm_url -q -O $perm_path`)
end
if !isfile(conc_path)
    run(`wget $conc_url -q -O $conc_path`)
end

In [5]:

# Load the permeability and concentration arrays from the downloaded .mat files.
perm = matread(perm_path)["perm"];
conc = matread(conc_path)["conc"];

# The third axis of `perm` indexes samples.
nsamples = size(perm, 3)

ntrain = 1000
nvalid = 100

batch_size = config.n_batch
learning_rate = 1f-4

epochs = 1

modes = 4
width = 20

# Spatial grid size and normalised grid spacing.
n = (config.nx,config.ny)
#d = (15f0,15f0) # dx, dy in m
d = (1f0/config.nx, 1f0/config.ny)

# Spatial subsampling stride.
s = 1

# Number of time steps and normalised time step.
nt = 51
#dt = 20f0    # dt in day
dt = 1f0/(nt-1)

# Run ActNorm once on the training permeabilities; presumably this fixes its
# normalisation statistics for the later `perm_to_tensor` calls (TODO confirm).
AN = ActNorm(ntrain)
AN.forward(reshape(perm[1:s:end,1:s:end,1:ntrain], n[1], n[2], 1, ntrain));

# `conc` is indexed (t, x, y, sample) here; the permutation moves the first axis to position 3.
y_train = permutedims(conc[1:nt,1:s:end,1:s:end,1:ntrain],[2,3,1,4]);
y_valid = permutedims(conc[1:nt,1:s:end,1:s:end,ntrain+1:ntrain+nvalid],[2,3,1,4]);

grid = gen_grid(n, d, nt, dt)

# Build the network inputs from permeability, grid and ActNorm.
x_train = perm_to_tensor(perm[1:s:end,1:s:end,1:ntrain],grid,AN);
x_valid = perm_to_tensor(perm[1:s:end,1:s:end,ntrain+1:ntrain+nvalid],grid,AN);
x_valid_dfno = xytcb_to_cxytb(x_valid)

# value, x, y, t

# Reference Flux network the operator-based model is compared against.
NN = Net3d(modes, width)
gpu_flag && (global NN = NN |> gpu)

Flux.trainmode!(NN, true)
w = Flux.params(NN)

opt = Flux.Optimise.ADAMW(learning_rate, (0.9f0, 0.999f0), 1f-4)
# `Int(...)` throws an InexactError unless `batch_size` divides `ntrain`.
nbatches = Int(ntrain/batch_size)

# Per-batch training loss, per-epoch validation loss and a progress meter.
Loss = zeros(Float32,epochs*nbatches)
Loss_valid = zeros(Float32, epochs)
prog = Progress(round(Int, ntrain * epochs / batch_size))

# plot figure
# First validation sample, kept with a singleton batch axis.
x_plot = x_valid[:, :, :, :, 1:1]
y_plot = y_valid[:, :, :, 1:1]
x_plot_dfno = vec(xytcb_to_cxytb(x_plot))

# Define result directory

sim_name = "3D_FNO"
exp_name = "2phaseflow"

save_dict = @strdict exp_name
plot_path = plotsdir(sim_name, savename(save_dict; digits=6))

"/Users/richardr2926/Desktop/Research/Code/dfno/plots/3D_FNO/exp_name=2phaseflow"

In [6]:
# Epoch and batch index used for the single-batch comparison below.
ep = 1
b = 1

Base.flush(Base.stdout)
# Random assignment of training samples to batches: column `b` holds the
# sample indices of batch `b`.
idx_e = reshape(randperm(ntrain), batch_size, nbatches)

# Batch `b` for the reference network; the last axis indexes samples.
x = x_train[:, :, :, :, idx_e[:,b]]
y = y_train[:, :, :, idx_e[:,b]]

# Same batch for the operator-based model: channel axis first, then one
# flattened column per sample.
x_dfno = reshape(xytcb_to_cxytb(x), (:, config.n_batch))
y_dfno = reshape(y, (:, config.n_batch));

In [11]:
# Gradient of the relative L2 loss w.r.t. the operator parameters `θ`.
grads_dfno = gradient(params -> norm(relu01(forward(params, x_dfno))-y_dfno)/norm(y_dfno), θ)[1]
# Gradient of the same loss w.r.t. the Flux parameters `w` of the reference network.
# The closure also stores the loss value in the global `loss`.
grads = gradient(w) do
    global loss = norm(relu01(NN(x))-y)/norm(y)
    return loss
end

# Loss values of both models on the same batch.
loss_dfno = norm(relu01(forward(θ, x_dfno))-y_dfno)/norm(y_dfno)
loss = norm(relu01(NN(x))-y)/norm(y)

# loss_dfno = norm(relu01(forward(θ, x_dfno)))/norm(y_dfno)
# loss = norm(relu01(NN(x)))/norm(y)

println("DFNO Loss: ", loss_dfno, ". NN Loss: ", loss)

DFNO Loss: 2.7084153. NN Loss: 2.7084153


In [13]:
norm(vec(forward(θ, x_dfno))-vec(NN(x)))/norm(vec(forward(θ, x_dfno))+vec(NN(x)))

1.4674958f-7

In [22]:
# Print the operator key of every entry in `θ`.
# `o` counts entries starting at 1, so the `o == -12` test never matches and
# `test_w1` keeps its initial value 0; change the index to capture an entry.
o = 1
test_w1 = 0

for (v, p) in θ
    println(v)
    if o == -12
        # println(v)
        test_w1 = p
    end
    o += 1
end

ParDiagonal{Float32}(20)
ParDiagonal{Float32}(1)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("436887fb-9529-4f72-b13b-cbd7cd8a05c6"))
ParMatrix{Float32}(20, 20, UUID("05695064-9539-45ba-917b-d798f0983b2c"), 0)
ParMatrix{Float32}(128, 20, UUID("d7e8609e-0a69-4659-b3f4-1adc3e9c329d"), 0)
ParMatrix{Float32}(1, 128, UUID("638d206c-ca9d-488c-aac7-77e86eff89b9"), 0)


ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("582b2a99-7610-49c3-8343-b3f38b0dc614"))
ParMatrix{Float32}(20, 20, UUID("47f5671c-8a3e-48b1-8d31-d7145f3da751"), 0)
ParMatrix{Float32}(20, 20, UUID("5b0aa3d7-eb1b-45b3-afe9-4100b674f4c1"), 0)
ParMatrix{Float32}(20, 20, UUID("e241b9f1-594a-45df-bd97-0c8eb5055f01"), 0)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("b2efce6e-a4c1-4bcd-9b8d-8389ab538e17"))
ParDiagonal{Float32}(128)
ParMatrix{Float32}(20, 4, UUID("27150ea5-b0a8-47b7-9ecb-b20b8b69ccbb"), 0)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("67e19fc2-a6be-416e-b180-0e00427dbb0c"))


In [23]:
# Print the size of every parameter array in `w`.
# `o` counts entries starting at 1, so the `o == -16` test never matches and
# `test_w2` keeps its initial value 0; change the index to capture an entry.
o = 1
test_w2 = 0
for p in w
    println(size(p))
    if o == -16
        # println(size(p))
        test_w2 = p
    end
    o += 1
end

(1, 1, 1, 4, 20)
(20,)
(20, 20, 8, 8, 4, 1)
(20, 20, 8, 8, 4, 1)
(20, 20, 8, 8, 4, 1)
(20, 20, 8, 8, 4, 1)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 128)
(128,)
(1, 1, 1, 128, 1)
(1,)


In [26]:
# norm(vec(test_w1) - vec(test_w2))

In [12]:
# norm(vec(permutedims(test_w1, [2, 1])) - vec(test_w2))

In [13]:
# Weight initialiser handed to `Flux.Conv` through `init=gen`.
# The global RNG is reseeded with 1234 on every call before sampling, so
# repeated calls with the same `shape` draw from the same RNG state.
function gen(shape...)
    Random.seed!(1234)
    return Flux.glorot_uniform(shape...)
    # return rand(Float32, shape...) / convert(Float32, sqrt(prod(shape)))
end

"""
    compl_mul(x, y)

Multiply the spectral coefficients `x` by the weights `y`, contracting the
input-channel axis independently for every retained mode.

- `x`: (modes1, modes2, modes3, in_channels, batch)
- `y`: (out_channels, in_channels, modes2, modes3, modes1)

Returns an array laid out as (modes1, modes2, modes3, out_channels, batch).
"""
function compl_mul(x::AbstractArray{T, 5}, y::AbstractArray{T, 5}) where T
    n1, n2, n3 = size(x, 1), size(x, 2), size(x, 3)

    # (m1, m2, m3, in, batch) -> (batch, in, m1, m2, m3), modes flattened
    lhs = permutedims(x, [5, 4, 1, 2, 3])
    lhs = reshape(lhs, size(lhs, 1), size(lhs, 2), :)

    # (out, in, m2, m3, m1) -> (in, out, m1, m2, m3), modes flattened
    rhs = permutedims(y, [2, 1, 5, 3, 4])
    rhs = reshape(rhs, size(rhs, 1), size(rhs, 2), :)

    # One (batch x in) * (in x out) product per mode
    mixed = batched_mul(lhs, rhs)

    # (batch, out, modes) -> (batch, out, m1, m2, m3) -> (m1, m2, m3, out, batch)
    unflattened = reshape(mixed, size(mixed, 1), size(mixed, 2), n1, n2, n3)
    return permutedims(unflattened, [3, 4, 5, 2, 1])
end

# Reference implementation of lifting + spectral convolution using Flux and
# FFTW directly; `out_ft` is later compared with the operator-based result.
T = Float32

# Pointwise (1x1x1) lifting convolution without bias, initialised by `gen`.
conv = Flux.Conv((1, 1, 1), config.nc_in=>config.nc_lift; init=gen, bias=false)
xt = conv(x) # xytcb
temp = conv(x) # xytcb

# Put time first so the real FFT halves the time axis.
xt = permutedims(xt, [3,1,2,4,5]) # txycb
x_ft = rfft(xt,[1,2,3])      ## full size FFT

# Spectral weights drawn after reseeding the global RNG; the literal sizes
# 8, 8, 4 equal 2*mx, 2*my, mt for the config used in this notebook.
Random.seed!(1234)
weights = rand(Complex{T}, config.nc_lift, config.nc_lift, 8, 8, 4, 1) ./ convert(T, sqrt(config.nc_lift * config.nc_lift * 8 * 8 * 4))

# Retained modes along the axes of `x_ft` (t, x, y).
modes1 = config.mt
modes2 = config.mx
modes3 = config.my

# only keep low frequency coefficients weights[1,1,1,:,:,1]
# The four corners (low/high modes along x and y, leading modes along t) are
# multiplied by the matching weight slices; everything else is zero-filled so
# that `out_ft` has the same size as `x_ft`.
out_ft = cat(cat(cat(compl_mul(x_ft[1:modes1, 1:modes2, 1:modes3, :,:], weights[:,:,1:4,1:4,:,1]),
                zeros(Complex{T}, modes1, modes2, size(x_ft,3)-2*modes3, size(x_ft,4), size(x_ft,5)), 
                compl_mul(x_ft[1:modes1, 1:modes2, end-modes3+1:end,:,:], weights[:,:,1:4,5:8,:,1]),dims=3),
                zeros(Complex{T}, modes1, size(x_ft, 2)-2*modes2, size(x_ft,3), size(x_ft,4), size(x_ft,5)),
                cat(compl_mul(x_ft[1:modes1, end-modes2+1:end, 1:modes3,:,:], weights[:,:,5:8,1:4,:,1]),
                zeros(Complex{T}, modes1, modes2, size(x_ft,3)-2*modes3, size(x_ft,4), size(x_ft,5)),
                compl_mul(x_ft[1:modes1, end-modes2+1:end, end-modes3+1:end,:,:], weights[:,:,5:8,5:8,:,1]),dims=3)
                ,dims=2),
                zeros(Complex{T}, size(x_ft,1)-modes1, size(x_ft,2), size(x_ft,3), size(x_ft,4), size(x_ft,5)),dims=1)
# println(size(out_ft))
# println(size(xt, 1))
# Back to physical space; `size(xt,1)` is the original length of the time axis.
out_ft = irfft(out_ft, size(xt,1),[1,2,3]) # nt * nx * ny * channels * batch
out_ft = permutedims(out_ft, [2,3,1,4,5]); # nx * ny * nt * channels * batch

In [14]:
# Operator-based counterpart of the reference cell: lifting followed by one
# spectral convolution, all parameters stored in `θ_new`.
lifting = ParIdentity(T, config.nx*config.ny*config.nt_in) ⊗ ParMatrix(T, config.nc_lift, config.nc_in)
θ_new = init(lifting)

# Complex DFTs along x and y, real-input DFT along t.
fourier_x = ParDFT(Complex{T}, config.nx)
fourier_y = ParDFT(Complex{T}, config.ny)
# fourier_z = ParDFT(Complex{T}, config.nz)
fourier_t = ParDFT(T, config.nt_out)

# Build restrictions to low-frequency modes
restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx])
restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny])
# restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:config.mz, config.nz-config.mz+1:config.nz])
restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt])

# Shapes of the restricted input (channel, x modes, y modes, t modes) and of
# the weight tensor, which has one extra leading channel axis.
input_shape = (config.nc_lift, 2*config.mx, 2*config.my, config.mt)
weight_shape = (config.nc_lift, config.nc_lift, 2*config.mx, 2*config.my, config.mt)
target_shape = input_shape

# Index labels: label 1 is shared by weight and input but absent from the
# target; label 5 appears only in weight and target.
input_order = (1, 2, 3, 4)
weight_order = (5, 1, 2, 3, 4)
target_order = (5, 2, 3, 4)

weight_mix = ParMatrixN(Complex{T}, weight_order, weight_shape, input_order, input_shape, target_order, target_shape) 
init!(weight_mix, θ_new)

# Restricted DFT over (t, y, x) with the channel axis left untouched.
dft = (restrict_t * fourier_t) ⊗
    (restrict_y * fourier_y) ⊗
    (restrict_x * fourier_x) ⊗
    ParIdentity(T, config.nc_lift)
# Apply lift -> restricted DFT -> weights -> adjoint of the restricted DFT, then
# reshape and permute to the (x, y, t, c, b) layout of the reference result.
output = cxytb_to_xytcb(reshape(dft' * weight_mix(θ_new) * dft * lifting(θ_new) * x_dfno, (config.nc_lift, config.nx, config.ny, config.nt_out, config.n_batch)))
;

In [15]:
norm(output-out_ft)/norm(output+out_ft)

4.271905f-8

In [16]:
# Weight initialiser handed to `Flux.Conv` through `init=gen`.
# The global RNG is reseeded with 1234 on every call before sampling, so
# repeated calls with the same `shape` draw from the same RNG state.
function gen(shape...)
    Random.seed!(1234)
    return Flux.glorot_uniform(shape...)
    # return rand(Float32, shape...) / convert(Float32, sqrt(prod(shape)))
end

"""
    compl_mul(x, y)

Multiply the spectral coefficients `x` by the weights `y`, contracting the
input-channel axis independently for every retained mode.

- `x`: (modes1, modes2, modes3, in_channels, batch)
- `y`: (out_channels, in_channels, modes2, modes3, modes1)

Returns an array laid out as (modes1, modes2, modes3, out_channels, batch).
"""
function compl_mul(x::AbstractArray{T, 5}, y::AbstractArray{T, 5}) where T
    n1, n2, n3 = size(x, 1), size(x, 2), size(x, 3)

    # (m1, m2, m3, in, batch) -> (batch, in, m1, m2, m3), modes flattened
    lhs = permutedims(x, [5, 4, 1, 2, 3])
    lhs = reshape(lhs, size(lhs, 1), size(lhs, 2), :)

    # (out, in, m2, m3, m1) -> (in, out, m1, m2, m3), modes flattened
    rhs = permutedims(y, [2, 1, 5, 3, 4])
    rhs = reshape(rhs, size(rhs, 1), size(rhs, 2), :)

    # One (batch x in) * (in x out) product per mode
    mixed = batched_mul(lhs, rhs)

    # (batch, out, modes) -> (batch, out, m1, m2, m3) -> (m1, m2, m3, out, batch)
    unflattened = reshape(mixed, size(mixed, 1), size(mixed, 2), n1, n2, n3)
    return permutedims(unflattened, [3, 4, 5, 2, 1])
end

# Step-by-step comparison of the Flux/FFTW reference pipeline with the
# operator-based pipeline. After every stage the difference between the two
# intermediate results is printed.
T = Float32

# x_dfno : cxytb
# x      : xytcb
println("Norm on base input: ", norm(vec(xytcb_to_cxytb(x)) - vec(x_dfno)))

# --- Stage 1: lifting -------------------------------------------------------
conv = Flux.Conv((1, 1, 1), config.nc_in=>config.nc_lift; init=gen, bias=false)
xt = conv(x) # xytcb

lifting = ParIdentity(T, config.nx*config.ny*config.nt_in) ⊗ ParMatrix(T, config.nc_lift, config.nc_in)
θ_new = init(lifting)

xt_dfno = lifting(θ_new) * x_dfno
println("Norm after lifting: ", norm(vec(xytcb_to_cxytb(xt)) - vec(xt_dfno)))

# --- Stage 2: full Fourier transform ----------------------------------------
xt = permutedims(xt, [3,1,2,4,5]) # txycb
xt = rfft(xt,[1,2,3])      ## full size FFT
xt_fft = xt

fourier_x = ParDFT(Complex{T}, config.nx)
fourier_y = ParDFT(Complex{T}, config.ny)
fourier_t = ParDFT(T, config.nt_out)

restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx])
restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny])
restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt])

# Named `full_dft` rather than `fft` so the global does not shadow (or clash
# with) the `fft` function exported by FFTW.
full_dft = (fourier_t ⊗ fourier_y ⊗ fourier_x) ⊗ ParIdentity(T, config.nc_lift)
xt_dfno_fft = full_dft * xt_dfno

println("Norm after fft: ", norm(vec(permutedims(xt, [4,2,3,1,5])) - vec(xt_dfno_fft)))

# --- Stage 3: restriction to the retained modes -----------------------------
modes1 = config.mt
modes2 = config.mx
modes3 = config.my

# println(size(xt))
# Keep the leading modes along t and both ends of the spectrum along x and y.
xt = cat(cat(xt[1:modes1, 1:modes2, 1:modes3, :,:],
                    xt[1:modes1, 1:modes2, end-modes3+1:end,:,:],dims=3),
                cat(xt[1:modes1, end-modes2+1:end, 1:modes3,:,:],
                xt[1:modes1, end-modes2+1:end, end-modes3+1:end,:,:],dims=3)
                ,dims=2) #txycb

# println(size(xt))
# The restriction operators built above are reused here.
dft = (restrict_t * fourier_t) ⊗
    (restrict_y * fourier_y) ⊗
    (restrict_x * fourier_x) ⊗
    ParIdentity(T, config.nc_lift)

xt_dfno = dft * xt_dfno

println("Norm after restriction: ", norm(vec(permutedims(xt, [4,2,3,1,5])) - vec(xt_dfno)))

# --- Stage 4: spectral weights ----------------------------------------------
# The literal sizes 8, 8, 4 equal 2*mx, 2*my, mt for the config used here.
Random.seed!(1234)
weights = rand(Complex{T}, config.nc_lift, config.nc_lift, 8, 8, 4, 1) ./ convert(T, sqrt(config.nc_lift * config.nc_lift * 8 * 8 * 4))

input_shape = (config.nc_lift, 2*config.mx, 2*config.my, config.mt)
weight_shape = (config.nc_lift, config.nc_lift, 2*config.mx, 2*config.my, config.mt)
target_shape = input_shape

# Index labels: label 1 is shared by weight and input but absent from the
# target; label 5 appears only in weight and target.
input_order = (1, 2, 3, 4)
weight_order = (5, 1, 2, 3, 4)
target_order = (5, 2, 3, 4)

weight_mix = ParMatrixN(Complex{T}, weight_order, weight_shape, input_order, input_shape, target_order, target_shape)
θ_test = init(weight_mix)
init!(weight_mix, θ_new)

# Compare the reference weights with a freshly initialised operator weight.
for (k, v) in θ_test
    println("Norm of Weight difference: ", norm(weights - v))
end

# only keep low frequency coefficients weights[1,1,1,:,:,1] 
# xt : txycb
xt_weighted = cat(cat(compl_mul(xt[1:modes1, 1:modes2, 1:modes3, :,:], weights[:,:,1:4,1:4,:,1]),
        compl_mul(xt[1:modes1, 1:modes2, end-modes3+1:end,:,:], weights[:,:,1:4,5:8,:,1]),dims=3),
        cat(compl_mul(xt[1:modes1, end-modes2+1:end, 1:modes3,:,:], weights[:,:,5:8,1:4,:,1]),
        compl_mul(xt[1:modes1, end-modes2+1:end, end-modes3+1:end,:,:], weights[:,:,5:8,5:8,:,1]),dims=3)
        ,dims=2)
xt_dfno = weight_mix(θ_new) * xt_dfno

println("Norm after weightage: ", norm(vec(permutedims(xt_weighted, [4,2,3,1,5])) - vec(xt_dfno)))

# --- Stage 5: zero-pad and transform back -----------------------------------
# Place the weighted corners back into a zero array of the full spectrum size.
out_ft = cat(cat(cat(xt_weighted[1:modes1, 1:modes2, 1:modes3, :,:],
                zeros(Complex{T}, modes1, modes2, size(xt_fft,3)-2*modes3, size(xt_fft,4), size(xt_fft,5)), 
                xt_weighted[1:modes1, 1:modes2, end-modes3+1:end,:,:],dims=3),
                zeros(Complex{T}, modes1, size(xt_fft, 2)-2*modes2, size(xt_fft,3), size(xt_fft,4), size(xt_fft,5)),
                cat(xt_weighted[1:modes1, end-modes2+1:end, 1:modes3,:,:],
                zeros(Complex{T}, modes1, modes2, size(xt_fft,3)-2*modes3, size(xt_fft,4), size(xt_fft,5)),
                xt_weighted[1:modes1, end-modes2+1:end, end-modes3+1:end,:,:],dims=3)
                ,dims=2),
                zeros(Complex{T}, size(xt_fft, 1)-modes1, size(xt_fft,2), size(xt_fft,3), size(xt_fft,4), size(xt_fft,5)),dims=1)

out_ft = irfft(out_ft, config.nt_in,[1,2,3]) # nt * nx * ny * channels * batch
xt_dfno = dft' * xt_dfno

println("Norm after inverse FFT + restrict: ", norm(vec(permutedims(out_ft, [4,2,3,1,5])) - vec(xt_dfno)))

Norm on base input: 0.0


Norm after lifting: 0.0
Norm after fft: 

0.0
Norm after restriction: 0.0
Norm of Weight difference: 

0.0
Norm after weightage: 0.00033601018
Norm after inverse FFT + restrict: 

9.625966e-7


In [19]:
# Weight initialiser handed to `Flux.Conv` through `init=gen`.
# The global RNG is reseeded with 1234 on every call before sampling, so
# repeated calls with the same `shape` draw from the same RNG state.
function gen(shape...)
    Random.seed!(1234)
    return Flux.glorot_uniform(shape...)
    # return rand(Float32, shape...) / convert(Float32, sqrt(prod(shape)))
end

# Setup for comparing a Flux lifting convolution *with* bias against the
# operator-based lift + bias.
T = Float32

# x_dfno : cxytb
# x      : xytcb
println("Norm on base input: ", norm(vec(xytcb_to_cxytb(x)) - vec(x_dfno)))

# Pointwise (1x1x1) lifting convolution with bias, initialised by `gen`.
conv = Flux.Conv((1, 1, 1), config.nc_in=>config.nc_lift; init=gen, bias=true)
xt = conv(x) # xytcb

# Dump the convolution's weight and bias arrays.
for w in Flux.params(conv)
    println(w)
end

# Operator counterparts: a diagonal per-channel operator and a channel matrix,
# both replicated over all nx*ny*nt_in grid points.
bias = ParIdentity(T, config.nx*config.ny*config.nt_in) ⊗ ParDiagonal(T, config.nc_lift)
lift = ParIdentity(T, config.nx*config.ny*config.nt_in) ⊗ ParMatrix(T, config.nc_lift, config.nc_in)

println("Bias RxD: ", Range(bias), " ", Domain(bias))
println("Lifting: ", Range(lift), " ", Domain(lift))

# Lift `x` with the global `lift` operator and add the global `bias` operator
# applied to an all-ones input with `config.n_batch` columns.
function lifting_layer(θ, x)
    bias_input = ones(T, Domain(bias), config.n_batch)
    lifted = lift(θ) * x
    return lifted + bias(θ) * bias_input
end

# Parameters of the operator-based lifting layer (bias first, then lift).
θ_new = init(bias)
init!(lift, θ_new)

xt_dfno = lifting_layer(θ_new, x_dfno)
println("Norm after lifting_layer: ", norm(vec(xytcb_to_cxytb(xt)) - vec(xt_dfno)))

# Random target in the (x, y, t, c, b) layout of the convolution output, sized
# from `config` instead of literal numbers, plus its flattened counterpart.
y_t_out = rand(Random.seed!(1234), T, config.nx, config.ny, config.nt_in, config.nc_lift, config.n_batch)
y_dfno_out = reshape(xytcb_to_cxytb(y_t_out), :, config.n_batch)

grads_xt_dfno = gradient(params -> norm(relu01(lifting_layer(params, x_dfno))-y_dfno_out)/norm(y_dfno_out), θ_new)[1]
# Kept in its own variable so the global `w` (parameters of `NN`) is not overwritten.
conv_params = Flux.params(conv)
grads_xt = gradient(conv_params) do
    global loss = norm(relu01(conv(x))-y_t_out)/norm(y_t_out)
    return loss
end

g1 = nothing
g2 = nothing

# Pick the gradient with `nc_lift` entries on each side; the other entry of
# each gradient set has a different number of elements.
for g in grads_xt
    if size(vec(g)) == (config.nc_lift,)
        g1 = g
    end
end

for (k, g) in grads_xt_dfno
    # println(size(g))
    if size(vec(g)) == (config.nc_lift,)
        g2 = g
    end
end

println("Norm of gradient: ", norm(vec(g1) - vec(g2)))

Norm on base input: 0.0
[-0.33467782;;;; 0.07986212;;;; 0.13177484;;;; -0.0887059;;;;; -0.092202485;;;; 0.47213608;;;; -0.39950258;;;; -0.4850912;;;;; -0.14634603;;;; 0.020354986;;;; -0.095781446;;;; 0.1395616;;;;; 0.28368402;;;; 0.3396219;;;; -0.4455446;;;; 0.46714276;;;;; -0.0018628836;;;; 0.2897644;;;; -0.28694052;;;; 0.19604069;;;;; 0.42789656;;;; 0.06670427;;;; -0.20519775;;;; 0.03636855;;;;; 0.12906319;;;; 0.21138924;;;; 0.3086548;;;; -0.39607054;;;;; -0.04428661;;;; 0.3067041;;;; -0.057995915;;;; 0.3705393;;;;; 0.12084699;;;; 0.4627145;;;; -0.14333922;;;; -0.3488199;;;;; -0.08929753;;;; 0.21535462;;;; -0.41860288;;;; 0.4395476;;;;; 0.1600315;;;; 0.026343644;;;; -0.025538623;;;; -0.4220317;;;;; 0.4344228;;;; 0.46619666;;;; 0.17604738;;;; 0.16655785;;;;; 0.22604805;;;; -0.16614097;;;; -0.30865377;;;; 0.30219877;;;;; 0.18089682;;;; -0.34354162;;;; 0.3965994;;;; -0.119200826;;;;; 0.20674825;;;; -0.43124092;;;; -0.21257704;;;; 0.36172485;;;;; -0.021605551;;;; -0.49456346;;;; -0.21759

0.0


Norm of gradient: 2.32693e-5


In [48]:
64*64*51*20

4177920

In [17]:
# Capture one value of `θ_new` into `test`: `o` starts at 0, so `o == 1`
# selects the second entry in iteration order.
o = 0
test = 0
for (k, v) in θ_new
    if o == 1
        test = v
    end
    o += 1
end

# size(weights)
# size(test)

# sum(test - weights)

In [None]:
# Small standalone ParMatrixN example with integer-valued complex entries.
using ParametricOperators

T = Complex{Int64}

input_shape = (5, 3, 3)
weight_shape = (5, 5, 3, 3)
target_shape = input_shape

# Index labels: label 1 is shared by weight and input but absent from the
# target; label 4 appears only in weight and target.
input_order = (1, 2, 3)
weight_order = (4, 1, 2, 3)
target_order = (4, 2, 3)

operator = ParMatrixN(T, weight_order, weight_shape, input_order, input_shape, target_order, target_shape) 
weights = init(operator)

# Apply the parametrised operator to a random flattened input.
operator(weights) * vec(rand(T, input_shape...))

In [None]:
# The same contraction pattern written directly with OMEinsum.
using OMEinsum

input_shape = (5, 3, 3)
weight_shape = (5, 5, 3, 3)

input = rand(input_shape...)
weight = rand(weight_shape...)

# Label 1 occurs in weight and input only; labels 2 and 3 occur everywhere;
# label 4 occurs in weight and target only.
target_order = (4, 2, 3)
weight_order = (4, 1, 2, 3)
input_order = (1, 2, 3)

target = einsum(EinCode((weight_order,input_order),target_order),(weight, input))

In [23]:
# Scratch cell comparing a per-mode channel mix done with `batched_mul` against
# a Kronecker-structured operator.
tempx = permutedims(x_ft[1:modes1, 1:modes2, 1:modes3, :,:], [4,2,3,1,5]) # nc * nx * ny * nt * batch
# NOTE(review): the slice above keeps the batch axis, so with the values used in
# this notebook it has more than 1*20*64 elements; this reshape appears to be
# the source of the DimensionMismatch recorded in the cell output — confirm.
tempx = reshape(tempx, (1, 20, 64))

tempxout = batched_mul(tempx,weights[1,1,1,:,:,1])

# Restricted spectrum from the operator pipeline, low x/y modes only.
inter = reshape(dft * lifting(θ_new) * x_dfno, (20, 8, 8, 4))
tempy = inter[:, 1:4, 1:4, :]

weights_dfno = ParIdentity(Complex{T}, 4*4*4) ⊗ ParMatrix(Complex{T}, config.nc_lift, config.nc_lift) # nc * nx * ny * nt
θ_new2 = init(weights_dfno)
weights_temp = 0

# Keep the transpose of the (last) parameter array in `θ_new2`.
for (k, v) in θ_new2
    weights_temp = permutedims(v, [2, 1])
end

# Operator application vs. explicit batched multiplication with the transposed weights.
tempyout1 = reshape(weights_dfno(θ_new2) * vec(tempy), (1, 20, 64))
tempyout2 = batched_mul(reshape(tempy, (1,20,64)),weights_temp)

sum(vec(tempyout1) - vec(tempyout2))
# println(sum(vec(tempx) - vec(tempy)))
# println(sum(vec(weights_temp) - vec(weights[1,1,1,:,:,1])))

# Random.seed!(1234)
# baseline = rand(ComplexF32, 20, 20) ./ 20 # 1f0 / convert(real(T), sqrt(20 * 20)) # ./ convert(real(T), sqrt(20 * 20))

# scale = 1f0 / convert(real(T), sqrt(20 * 20))
# Random.seed!(1234)
# weights = scale*rand(Complex{T}, 1, 1, 1, 20, 20, 1)

# println(sum(vec(baseline) - vec(weights[1,1,1,:,:,1])))
# println(sum(vec(baseline) - vec(weights_temp)))

DimensionMismatch: DimensionMismatch: new dimensions (1, 20, 64) must be consistent with array size 2560

In [28]:
# Print the size of every gradient in `grads` and keep the 16th one in `grad_one`.
u = 1
grad_one = nothing

for p in grads
    println(size(p))
    if u == 16
        grad_one = p
    end
    u += 1
end

(1, 1, 1, 4, 20)
(20,)
(20, 20, 8, 8, 4, 1)
(20, 20, 8, 8, 4, 1)
(20, 20, 8, 8, 4, 1)
(20, 20, 8, 8, 4, 1)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 20)
(20,)
(1, 1, 1, 20, 128)
(128,)
(1, 1, 1, 128, 1)
(1,)


In [30]:
# Print every key of `grads_dfno` and keep the 12th gradient (in iteration
# order) in `grad_two`.
o = 1
grad_two = nothing

for (k, v) in grads_dfno
    println(k)
    if o == 12
        grad_two = v
    end
    o += 1
end

ParDiagonal{Float32}(20)
ParDiagonal{Float32}(1)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("436887fb-9529-4f72-b13b-cbd7cd8a05c6"))
ParMatrix{Float32}(20, 20, UUID("05695064-9539-45ba-917b-d798f0983b2c"), 0)
ParMatrix{Float32}(128, 20, UUID("d7e8609e-0a69-4659-b3f4-1adc3e9c329d"), 0)
ParMatrix{Float32}(1, 128, UUID("638d206c-ca9d-488c-aac7-77e86eff89b9"), 0)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("582b2a99-7610-49c3-8343-b3f38b0dc614"))
ParMatrix{Float32}(20, 20, UUID("47f5671c-8a3e-48b1-8d31-d7145f3da751"), 0)
ParMatrix{Float32}(20, 20, UUID("5b0aa3d7-eb1b-45b3-afe9-4100b674f4c1"), 0

)
ParMatrix{Float32}(20, 20, UUID("e241b9f1-594a-45df-bd97-0c8eb5055f01"), 0)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("b2efce6e-a4c1-4bcd-9b8d-8389ab538e17"))
ParDiagonal{Float32}(128)
ParMatrixN{5, 4, 4, ComplexF32}((5, 1, 2, 3, 4), (20, 20, 8, 8, 4), (1, 2, 3, 4), (20, 8, 8, 4), (5, 2, 3, 4), (20, 8, 8, 4), UUID("67e19fc2-a6be-416e-b180-0e00427dbb0c"))
ParMatrix{Float32}(20, 4, UUID("27150ea5-b0a8-47b7-9ecb-b20b8b69ccbb"), 0)


In [31]:
# norm(vec(grad_one) - vec(grad_two))
# norm(vec(permutedims(grad_two, [2, 1])) - vec(grad_one))

1.5936166f-5