# PEDS example: Maxwell surrogate

This example of PEDS shows how to use the code from: *Pestourie, Raphaël, et al. "Physics-enhanced deep surrogates for PDEs."*  In this example notebook, we illustrate PEDS for the surrogate of the diffusion equation's flux where the approximate solver is a coarse diffusion equation's solver. For this example, the input is the width of 10 layers of holes and the one-hot-encoding of three frequencies.
    
## Note on implementation

In this notebook, we run the training on a single processor and J=1.

However, the implementation of this code is meant to be parallel on multiple processors, where the element of a batch is computed in parallel over J groups of CPUs (J is the number models in the ensemble). For example, with a julia file `train_PEDS.jl`containing functions calls similart to the code in the cells below, the command to run large-scale training on a cluster would be:
    
`mpiexec -n 320 julia train_PEDS.jl` 

In [1]:
##load module
include("../src/PEDS.jl")

##loading data
X = readdlm("../data/X_maxwell10_small.csv", ',')
y = parse.(Complex{Float64}, readdlm("../data/y_maxwell10_small.csv", ',')[:])

Xv = X[:, 1:1024] #valid set
Xtest = X[:, end-1023:end] #test set
Xt = X[:, 1025:end]

yv = y[1:1024] #valid set
ytest = y[end-1023:end] #test set
yt = y[1025:end];

In [2]:
datalimitvalid = 2^10
Jval=1
batchsize=64

##Definition of problem constants
const debug = false
const drv = DataRunner(Xv, yv, [1]);
const al = ALstruct(J=Jval, Nvalid=datalimitvalid, batchsize=batchsize);
const valid = initvalid(al, drv) #validation loader

const drtest = DataRunner(Xtest, ytest, [1]);
const test = initvalid(al, drtest) #validation loader

const nn = NNstruct(outGen=[256, 256, 10*110],
postGen = [x-> @. x*1.5 + 2.5; x-> reshape(x, (110,10,:))],
inVar = [110*10, 256, 256, 256]);
const cs = CSstruct(resolution=10, 
nn_x=10, 
ny_nn=110, 
refsim=0.3364246930443735 + 0.1920021246559511im);
const sd = SimulationDomain(cs)
##setup MPI and random
const comm = MPI.COMM_WORLD
const model_color = MPI.Comm_rank(comm)%al.J
const commModel = MPI.Comm_split(comm, model_color, 0)
const isleader = MPI.Comm_rank(commModel) == 0
const commLeader = MPI.Comm_split(comm, isleader, 0)
debug && print("Comm rank=$(MPI.Comm_rank(comm)), commModel rank = $(MPI.Comm_rank(commModel)), commLeader rank = $(MPI.Comm_rank(commLeader))\n")
Random.seed!(2134*(model_color+1)) #alter seed for different groups

MersenneTwister(2134)

In [3]:
##training functionalities
function train_distributed!(comm, commModel, commLeader, mloglik, m, loss, ps, loader, opt, validation_fes; logging=false)
    for d in loader
        train_loss, back = Zygote.pullback(() -> loss(commModel, mloglik, m, d...), ps)
        gs = back(1.)
        if debug && isleader
            if isnan(train_loss)
                @show (model_color, train_loss)
            end
        end
        for x in ps
            gs[x][:] .= sum_reduce(commModel, Float64.(gs[x][:]))
            if debug && isleader
                if any(isnan.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isnan,x)), length(findall(isnan,gs[x][:])))
                end

                if any(isinf.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isinf,x)), length(findall(isinf,gs[x][:])))
                end
            end
        end
        
        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "before update")
                end
            end
        end
        Flux.Optimise.update!(opt, ps, gs)
        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "after update")
                end
            end
        end
    end

    logging && push!(validation_fes, dFE(comm, commModel, commLeader, m))
end

function dFE(comm, commModel, commLeader, m; valid=valid) """dFE computes the FE using parallelization over the batch with MPI""" 
    evalsr = zeros(al.Nvalid)
    evalsi = zeros(al.Nvalid)
    FE = 0.
    j=0
    ys = Complex{Float64}[]
    for (x, y) in valid
        for i=1+MPI.Comm_rank(commModel):MPI.Comm_size(commModel):length(y)
            rp, ip = m(x[:,i])
            evalsr[j*length(y)+i] = rp
            evalsi[j*length(y)+i] = ip
        end
        j+=1
        push!(ys, y...)
    end
    evalsrModel = sum_reduce(commModel, evalsr)
    evalsiModel = sum_reduce(commModel, evalsi)
    evalsr = sum_reduce(commLeader, evalsrModel) / al.J
    evalsi = sum_reduce(commLeader, evalsiModel) / al.J
    if MPI.Comm_rank(comm) == 0
        ŷ = @. evalsr + 1im * evalsi
        FE = norm(ŷ - ys)/norm(ys)
        @show FE
    end
    return FE
end

dFE (generic function with 1 method)

## PEDS

In [4]:
kval=2^7

##define same AL parameters for all workers
MPI.Barrier(comm)
al1 = ALstruct(Ninit=256+8*kval, T=0);
if MPI.Comm_rank(comm) == 0
    @show kval
end
##ititialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
##initialize baseline
(mgen, cw, mvar) = initmodel(nn)
coarseinput(p) = begin 
    (coarsified, sd_freq) = Zygote.ignore() do
        coarse_geom_func(p)
    end
    generated =dropdims(mgen(p), dims=3)
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("inputnan", p) 
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("mgennan", mgen(p)) 
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("mgennanparam", ps)
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("errorcolor", model_color)
    w = NNlib.sigmoid.(cw*nn.multfact)
    # w = max(0, min(1, cw))
    ϵcombine = @. w * generated + (1-w) * coarsified
    ϵsymmetric = ϵcombine#(ϵcombine .+ reverse(ϵcombine, dims=2))./2
    return ϵsymmetric, sd_freq
end
m(p) = begin
    ϵsymmetric, sd_freq = coarseinput(p)
    return [realtransmissionSolver(ϵsymmetric, sd_freq = sd_freq); imagtransmissionSolver(ϵsymmetric, sd_freq = sd_freq)]
end 

uq(p) = mvar(coarseinput(p)[1])

mloglik(p) =  vcat(m(p), uq(p))
ps = Flux.params(mgen, cw, mvar)
# ps = Flux.params(mgen, mvar)
loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)
##train baseline
if MPI.Comm_rank(comm)==0
    @time Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
else
    Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
end
##active learning loop
for t=1:al1.T
    MPI.Comm_rank(comm) == 0 && @show t
    loader= getloader(al1, dr, ds, X->varfilter(mloglik, X))
    if MPI.Comm_rank(comm)==0 
        @time Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
    else
        Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
    end
end

# ##save models and validation FEs
# name = "PEDS10_example"
# if isleader
#     BSON.@save "$(name)_K$(kval)_mgen$(model_color).bson" mgen
#     BSON.@save "$(name)_K$(kval)_cw$(model_color).bson" cw
#     BSON.@save "$(name)_K$(kval)_mvar$(model_color).bson" mvar
# end
# if MPI.Comm_rank(comm) == 0
#     writedlm("$(name)_K$(kval)_validation_fes.csv", validation_fes, ',')
# end

kval = 128


┌ Info: Epoch 1
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.8704552623350917


┌ Info: Epoch 2
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.8221926875215404


┌ Info: Epoch 3
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.7027776554085374


┌ Info: Epoch 4
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.5473956441382966


┌ Info: Epoch 5
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.38630812923964514


┌ Info: Epoch 6
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.38606860895747364


┌ Info: Epoch 7
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.2932033751384104


┌ Info: Epoch 8
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.3175109741931523


┌ Info: Epoch 9
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.2668221298945284


┌ Info: Epoch 10
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.267150512771662
642.674078 seconds (91.85 M allocations: 532.067 GiB, 7.72% gc time, 3.38% compilation time)


In [5]:
dFE(comm, commModel, commLeader, m, valid=test)

FE = 0.27713426133559715


0.27713426133559715

## PEDS result

The fractional error on the test set is 0.278.

## Baseline

In [6]:
##define same AL parameters for all workers
MPI.Barrier(comm)
al1 = ALstruct(Ninit=256+8*kval, T=0);

if MPI.Comm_rank(comm) == 0
    @show kval
end
##ititialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
##initialize baseline
(mgen, pred, mvar) = initbase(nn)
mloglik(p) =  vcat(pred(mgen(p)), mvar(mgen(p)))
m(p) = pred(mgen(p))
ps = Flux.params(mgen, pred, mvar)
loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)
##train baseline
if MPI.Comm_rank(comm)==0
    @time Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
else
    Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
end
##active learning loop
for t=1:al1.T
    MPI.Comm_rank(comm) == 0 && @show t
    loader= getloader(al1, dr, ds, X->varfilter(mloglik, X)) 
    if MPI.Comm_rank(comm)==0 
        @time Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
    else
        Flux.@epochs al1.ne train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
    end
end

# ##save models and validation FEs
# name = "baseline10_noal_example"
# if isleader
#     BSON.@save "$(name)_K$(kval)_mgen$(model_color).bson" mgen
#     BSON.@save "$(name)_K$(kval)_pred$(model_color).bson" pred
#     BSON.@save "$(name)_K$(kval)_mvar$(model_color).bson" mvar
# end
# if MPI.Comm_rank(comm) == 0
#     writedlm("$(name)_K$(kval)_validation_fes.csv", validation_fes, ',')
# end

kval = 128


┌ Info: Epoch 1
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.858258856094463


┌ Info: Epoch 2
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.7809582822116725


┌ Info: Epoch 3
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.703747328578164


┌ Info: Epoch 4
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.6416452920266059


┌ Info: Epoch 5
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.6175040164581953


┌ Info: Epoch 6
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.565809767735542


┌ Info: Epoch 7
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.5197141398355253


┌ Info: Epoch 8
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.545490018445373


┌ Info: Epoch 9
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.5087716880945948


┌ Info: Epoch 10
└ @ Main /Users/raphaelpestourie/.julia/packages/Flux/EXOFx/src/optimise/train.jl:154


FE = 0.5260581339153209
111.182660 seconds (7.58 M allocations: 138.211 GiB, 14.31% gc time, 1.59% compilation time)


In [7]:
dFE(comm, commModel, commLeader, m, valid=test)

FE = 0.5412049574053728


0.5412049574053728

## Baseline result 

The fractional error on the test set is 0.541.

# Overall result

With about 1000 data points, a single model PEDS model leads to a 1.9x improvement compared to the NN-only baseline. For the Maxwell surrogate, in contrast to diffusion and reaction-diffusion equations, more points are needed to achieve an accuracy comparable to fabrication error (< 5%).