# PEDS example: diffusion surrogate using a diffusion equation coarse solver

This example of PEDS shows how to use the code from: *Pestourie, Raphaël, et al. "Physics-enhanced deep surrogates for PDEs."*  In this example notebook, we illustrate PEDS for the surrogate of the diffusion equation's flux where the approximate solver is a coarse diffusion equation's solver. For this example, the input is the side length of 25 holes in a material.

NB: since the approximate solver is the same for the reaction-diffusion and the diffusion surrogates, training the surrogate models with the 25 holes structures for the reaction-diffusion equation necessitates only to change the dataset from `X/y_fourier25_small.csv` to `X/y_fisher25_small.csv`.
    
## Note on implementation

In this notebook, we run the training on a single processor and J=1.

However, the implementation of this code is meant to be parallel on multiple processors, where the element of a batch is computed in parallel over J groups of CPUs (J is the number models in the ensemble). For example, with a julia file `train_PEDS.jl`containing functions calls similart to the code in the cells below, the command to run large-scale training on a cluster would be:
    
`mpiexec -n 320 julia train_PEDS.jl` 

In [10]:
#initialize functions

include("../src/PEDS.jl")
include("../src/fourier_solver.jl")

##training functionalities
function train_distributed!(comm, commModel, commLeader, mloglik, m, loss, ps, loader, opt, validation_fes; logging=false)

    train_lossmean = 0.
    k = 0
    for d in loader
        X,y = d #change for Fourier
        train_loss, back = Zygote.pullback(() -> loss(commModel, mloglik, m, X, y), ps)
        train_lossmean *= k
        k+=1
        train_lossmean += train_loss
        train_lossmean /= k

        gs = back(Float32(1.))
        if debug && isleader
            if isnan(train_loss)
                @show (model_color, train_loss)
            end
        end
        for x in ps
            gs[x][:] .= (sum_reduce(commModel, (gs[x][:]) ))
            if debug && isleader
                if any(isnan.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isnan,x)), length(findall(isnan,gs[x][:])))
                end

                if any(isinf.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isinf,x)), length(findall(isinf,gs[x][:])))
                end
            end
        end

        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "before update")
                end
            end
        end
        Flux.Optimise.update!(opt, ps, gs)
        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "after update")
                end
            end
        end
    end

    # logging && push!(validation_fes, dFE(comm, commModel, commLeader, m))
    #changing for fourier
    if logging
        if MPI.Comm_rank(comm) == 0
            @show train_lossmean
            push!(validation_fes, train_lossmean)
        end
        MPI.Barrier(comm)
        valid_lossMSE = lossvalid(comm, commModel, commLeader, mloglik, m)
        valid_lossFE = dFE_1d(comm, commModel, commLeader, m)
        if MPI.Comm_rank(comm) == 0
            push!(validation_fes, valid_lossMSE)
            push!(validation_fes, valid_lossFE)
        end
    end
end

function pore(p, resolution; lowval=0.1)

    if resolution % 2==0
        inner = p/2*resolution ÷ 1
        w =  (p*resolution)/2 % 1

        isqb = Int((resolution/2+inner) ÷ 1) # square index to the right
        isqa = Int((resolution/2-inner) ÷ 1) # square index to the left
    else
        inner = p/2*resolution ÷ 1
        w = ((p*resolution)/2 % 1)
        w*= resolution ==1 ? 2 : 1
        isqb = Int((resolution/2+inner) ÷ 1) # square index to the right
        isqa = Int((resolution/2-inner) ÷ 1)+1 # square index to the left
    end

    begin
        A1 = ones((resolution, resolution))
        A2 = ones((resolution, resolution))
        A1[isqa+1:isqb, isqa+1:isqb] .= lowval

        if isqa==0 && w==0
            return A1
        elseif isqa==0 && w!=0
            return A1 .* (1-w)
        else
            A2[isqa:isqb+1, isqa:isqb+1] .= lowval
            return A1 .* (1-w) .+ A2 .* w
        end
    end

end

function generatepores(ps, resolution)
    n = sqrt(length(ps))
    N = resolution÷n
    @assert n == Int(n)
    @assert N == Int(N)
    n = Int(n)
    N = Int(N)
    pores = zeros((resolution, resolution))
    for (k,p) in enumerate(ps)
        i = (k-1)÷n
        j = (k-1)%n
        pores[i*N+1:(i+1)*N, j*N+1:(j+1)*N] = pore(p, N)
    end
    return pores
end

ChainRules.@non_differentiable generatepores(ps, resolution)



In [11]:
startp, endp = 7, 7
PEDSname = "PEDS"
equation = "examplefourier"

ninit = 64
Jval = 1

cwval = 0.05
learningrate = 5e-5
nnodes =128

optimizerfunc, optname = Flux.ADAM, "ADAM"

number_epoch = 200
batchsize = 64

datalimitvalid = 2^10

Ks = Int[2^i for i=startp:endp]

coarseresolution = 4
Lx = Ly = 1
nin=Int(coarseresolution^2)
coarse_sim = setsimulation(Lx, Ly, coarseresolution);
nout = coarseresolution^2
const avgOpcoarse = create_cOp(coarseresolution);

endname = "$(equation)_$(2^startp)to$(2^endp)_ne$(number_epoch)_lr$(learningrate)_bs$(batchsize)_opt$(optname)_hn$(nnodes)_coarse$(coarseresolution)_seed";

##loading data
X = readdlm("../data/X_fourier$(coarseresolution^2)_small.csv", ',')
y = readdlm("../data/y_fourier$(coarseresolution^2)_small.csv", ',')[:];

Xv = X[:, 1:1024] #valid set
Xtest = X[:, end-1023:end] #test set
Xt = X[:, 1025:end]

yv = y[1:1024] #valid set
ytest = y[end-1023:end] #test set
yt = y[1025:end];

In [12]:
##Definition of problem constants
const debug = false
const drv = DataRunner(Xv, yv, [1]);
const al = ALstruct(J=Jval, Nvalid=datalimitvalid, batchsize=batchsize);
const valid = initvalid(al, drv) #validation loader

const drtest = DataRunner(Xtest, ytest, [1]);
const test = initvalid(al, drtest) #validation loader

##setup MPI and random
const comm = MPI.COMM_WORLD
const model_color = MPI.Comm_rank(comm)%al.J
const commModel = MPI.Comm_split(comm, model_color, 0)
const isleader = MPI.Comm_rank(commModel) == 0
const commLeader = MPI.Comm_split(comm, isleader, 0)
debug && print("Comm rank=$(MPI.Comm_rank(comm)), commModel rank = $(MPI.Comm_rank(commModel)), commLeader rank = $(MPI.Comm_rank(commLeader))\n")

Random.seed!(2138*(model_color+1)) #alter seed for different groups

function lossvalid(comm, commModel, commLeader, mloglik, m)
    valid_lossmean = 0.
    k = 0
    testmode!(m, true)
    for d in valid
        X,y=d
        valid_loss = dMSE(commModel, mloglik, m, X, y)

        valid_lossmean *= k
        k+=1
        valid_lossmean += valid_loss
        valid_lossmean /= k
    end
    testmode!(m, false)
    # MPI.Comm_rank(commModel)== 0 && @show valid_lossmean
    valid_lossmean = sum_reduce(commLeader, valid_lossmean) / al.J
    MPI.Comm_rank(comm) == 0 && @show valid_lossmean
    return valid_lossmean
end

function dFE_1d(comm, commModel, commLeader, m; valid=valid) """dFE computes the FE using parallelization over the batch with MPI"""
    evalsr = zeros(al.Nvalid)
    FE = 0.
    j=0
    ys = []
    testmode!(m, true)
    for (x, y) in valid
                rp = m(x[:,1+MPI.Comm_rank(commModel):MPI.Comm_size(commModel):length(y)])[:]
        evalsr[j*length(y)+1+MPI.Comm_rank(commModel):MPI.Comm_size(commModel):j*length(y)+length(y)] .= rp
        push!(ys, y...)
        j+=1
    end
    testmode!(m, false)
    evalsrModel = sum_reduce(commModel, evalsr)
    evalsr = sum_reduce(commLeader, evalsrModel) / al.J
    if MPI.Comm_rank(comm) == 0
        FE = norm(evalsr - ys)/norm(ys)
        MPI.Comm_rank(comm) == 0 && @show FE
    end
    return FE
end

dFE_1d (generic function with 1 method)

# PEDS

In [13]:
namemodel = "local_MSE_parallel_PEDSparameterized_$endname"

kval =  Ks[1]
##define same AL parameters for all workers
al1 = ALstruct(Ninit=ninit+8*kval, T=0, η=1e-3, J=Jval, ne=number_epoch, batchsize=batchsize);
if MPI.Comm_rank(comm) == 0
    @show kval
    @show 256+8*kval
end
##ititialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
##initialize baseline
mgen =  Chain(Dense(nin, nnodes, relu),
Dropout(0.5),
Dense(nnodes, nnodes,relu),
Dropout(0.5),
Dense(nnodes, coarseresolution^2, hardtanh),
x-> @. x*0.9/2 + 0.45 + .1)

cw=[cwval]

m(p) = begin
    generatedgeom = mgen(p)
    w = @. max(0, min(1, cw)) # harder function.
    @views map(i->Float32(targetfunc(avgOpcoarse * (w .* generatedgeom[:,i] + (1 .-w) .* generatepores(p[:, i],coarseresolution)[:]), sim=coarse_sim)), 1:length(p)÷ nin)
end
mloglik(p) = 1
ps = Flux.params(mgen, cw)


loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)

if MPI.Comm_rank(comm)==0
    @time for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
else
    for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
end


# ##save models and validation FEs
# if isleader
#     BSON.@save "$(namemodel)_K$(kval)_mgen$(model_color).bson" mgen
#     BSON.@save "$(namemodel)_K$(kval)_cw$(model_color).bson" cw
# end

# if MPI.Comm_rank(comm) == 0
#     writedlm("$(namemodel)_K$(kval)_trainingMSE.csv", validation_fes, ',')
# end

# # evaluate and show the error on the test set
# valid_lossmean = dFE_1d(comm, commModel, commLeader, m)
# if MPI.Comm_rank(comm) == 0
#     @show valid_lossmean
#     writedlm("$(namemodel)_K$(kval)_endvalidFE.csv", [valid_lossmean], ',')
# end

kval = 128
256 + 8kval = 1280


[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(16 => 128, relu)  [90m# 2_176 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "16×64 Matrix{Float64}"
[33m[1m└ [22m[39m[90m@ Flux C:\Users\XMhua\.julia\packages\Flux\CUn7U\src\layers\stateless.jl:60[39m


train_lossmean = 6.595788668749035e-5
valid_lossmean = 0.0047059730094293865
FE = 0.11343876735389033
train_lossmean = 5.071250014826823e-5
valid_lossmean = 0.0033040698215197383
FE = 0.09505202878676187
train_lossmean = 4.4103692854615465e-5
valid_lossmean = 0.0028588925543352457
FE = 0.08841698460545593
train_lossmean = 4.124046449685829e-5
valid_lossmean = 0.0026160269837089104
FE = 0.08457809370312441
train_lossmean = 3.9693472079149996e-5
valid_lossmean = 0.0025133140301987177
FE = 0.08290107468677564
train_lossmean = 3.924708041711712e-5
valid_lossmean = 0.0024707965201275586
FE = 0.08219686867959194
train_lossmean = 3.910843083785664e-5
valid_lossmean = 0.002459056731764226
FE = 0.08200136030517251
train_lossmean = 3.902580560590542e-5
valid_lossmean = 0.0024558390528505707
FE = 0.08194769330058632
train_lossmean = 3.900507863017524e-5
valid_lossmean = 0.002456321986816268
FE = 0.08195575029826624
train_lossmean = 3.900324035038499e-5
valid_lossmean = 0.0024558801326775436
FE = 

valid_lossmean = 0.0005776059067207955
FE = 0.039742263237635814
train_lossmean = 2.0697162371085624e-5
valid_lossmean = 0.0006378989318274087
FE = 0.04176502254733263
train_lossmean = 2.0021794441031725e-5
valid_lossmean = 0.0005768175100421678
FE = 0.03971513109869016
train_lossmean = 2.0079936913504968e-5
valid_lossmean = 0.0005623019789062154
FE = 0.03921223427684718
train_lossmean = 2.036064073826076e-5
valid_lossmean = 0.0005717001318393029
FE = 0.03953856734943354
train_lossmean = 1.9344257976701382e-5
valid_lossmean = 0.0005251900857794085
FE = 0.03789614577436982
train_lossmean = 1.9906767219386223e-5
valid_lossmean = 0.0005660469023662113
FE = 0.039342594059424155
train_lossmean = 1.9896287163606693e-5
valid_lossmean = 0.0005365762882969732
FE = 0.03830474021058075
train_lossmean = 1.8800288632070433e-5
valid_lossmean = 0.0005496728457834585
FE = 0.038769386109833125
train_lossmean = 1.9430991278744367e-5
valid_lossmean = 0.0005347288349825039
FE = 0.03823874098678345
train_l

train_lossmean = 1.5651732723027207e-5
valid_lossmean = 0.00041475953154985313
FE = 0.033677122013716324
train_lossmean = 1.6423398101835065e-5
valid_lossmean = 0.0003803109826873674
FE = 0.03224825478123329
train_lossmean = 1.5321334657763586e-5
valid_lossmean = 0.00042140831817508113
FE = 0.03394597872408556
train_lossmean = 1.630618881788998e-5
valid_lossmean = 0.0003967101696895372
FE = 0.032936196848882056
train_lossmean = 1.547670065826931e-5
valid_lossmean = 0.0003893114055079552
FE = 0.03262761630015607
train_lossmean = 1.549222430811024e-5
valid_lossmean = 0.000425217849272203
FE = 0.034099069337557415
train_lossmean = 1.5711224412477095e-5
valid_lossmean = 0.00040939786578059406
FE = 0.03345873905791497
train_lossmean = 1.4871146521836443e-5
valid_lossmean = 0.00042137395944362823
FE = 0.03394459483524022
train_lossmean = 1.5604355761510547e-5
valid_lossmean = 0.00041983822860909494
FE = 0.03388268148689505
train_lossmean = 1.549758220603816e-5
valid_lossmean = 0.000422811104

In [14]:
dFE_1d(comm, commModel, commLeader, m, valid=test)

FE = 0.033832562167824085


0.033832562167824085

## PEDS result

The fractional error on the test set is 0.0334. Surprisingly, it is slightly better than the result for an ensemble of 5 models (which achieves an error of 3.8%).

# Baseline 

In [15]:
namemodel = "local_MSE_parallel_baselineparameterized_$endname"

##define same AL parameters for all workers
al1 = ALstruct(Ninit=ninit+8*kval, T=0, η=1e-3, J=Jval, ne=number_epoch, batchsize=batchsize);
if MPI.Comm_rank(comm) == 0
    @show kval
    @show 256+8*kval
end
##ititialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
##initialize baseline
mgen =  Chain(Dense(nin, nnodes, relu),
Dropout(0.5),
Dense(nnodes, nnodes, relu),
Dropout(0.5),
Dense(nnodes, coarseresolution^2, relu),
Dense(coarseresolution^2, 1, hardtanh),
x-> @. x[:]*0.9/2 + 0.45 + .1)

cw=[cwval]

m(p) = begin
    mgen(p)
end
mloglik(p) = 1
ps = Flux.params(mgen)


loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)

if MPI.Comm_rank(comm)==0
    @time for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
else
    for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
end


# ##save models and validation FEs
# if isleader
#     BSON.@save "$(namemodel)_K$(kval)_mgen$(model_color).bson" mgen
# end

# if MPI.Comm_rank(comm) == 0
#     writedlm("$(namemodel)_K$(kval)_trainingMSE.csv", validation_fes, ',')
# end

# # evaluate and show the error on the test set
# valid_lossmean = dFE_1d(comm, commModel, commLeader, m)
# if MPI.Comm_rank(comm) == 0
#     @show valid_lossmean
#     writedlm("$(namemodel)_K$(kval)_endvalidFE.csv", [valid_lossmean], ',')
# end



kval = 128
256 + 8kval = 1280
train_lossmean = 8.368385148995778e-5
valid_lossmean = 0.007339392529853562
FE = 0.14166639434659292
train_lossmean = 7.142089213349383e-5
valid_lossmean = 0.007419936700301887
FE = 0.14244161298232885
train_lossmean = 6.918667806639829e-5
valid_lossmean = 0.007141174621812882
FE = 0.139740280026674
train_lossmean = 6.712250528941349e-5
valid_lossmean = 0.007024524506395929
FE = 0.13859426151637347
train_lossmean = 6.690880123394466e-5
valid_lossmean = 0.006806376152412716
FE = 0.1364252492870103
train_lossmean = 6.59262556685134e-5
valid_lossmean = 0.0067199236827887275
FE = 0.13555606493020808
train_lossmean = 6.524368673006851e-5
valid_lossmean = 0.006533860293929803
FE = 0.1336662317602876
train_lossmean = 6.322947683153269e-5
valid_lossmean = 0.006348083630613328
FE = 0.1317522687840615
train_lossmean = 6.191362016954071e-5
valid_lossmean = 0.006134216570415839
FE = 0.12951388623528937
train_lossmean = 6.150326562985147e-5
valid_lossmean = 0.006062994

train_lossmean = 2.8707024386317372e-5
valid_lossmean = 0.0016395637960505281
FE = 0.06695776405766539
train_lossmean = 2.6992021218622655e-5
valid_lossmean = 0.001806317086288704
FE = 0.07028032768611543
train_lossmean = 2.819025738150143e-5
valid_lossmean = 0.0016420725009071494
FE = 0.06700897068398873
train_lossmean = 2.830537930079218e-5
valid_lossmean = 0.0015047410308714662
FE = 0.06414571707680886
train_lossmean = 2.5930428738607737e-5
valid_lossmean = 0.0014544262737527167
FE = 0.0630641633411193
train_lossmean = 2.7171838569892404e-5
valid_lossmean = 0.001542548081260407
FE = 0.06494655772377776
train_lossmean = 2.5836765837389017e-5
valid_lossmean = 0.0016969460435114366
FE = 0.0681193977308055
train_lossmean = 2.688373864730447e-5
valid_lossmean = 0.001544804620209899
FE = 0.06499404437290145
train_lossmean = 2.6965829132600417e-5
valid_lossmean = 0.001499869638736432
FE = 0.06404180143705919
train_lossmean = 2.641909236850481e-5
valid_lossmean = 0.001541908843121232
FE = 0

train_lossmean = 2.040395616191759e-5
valid_lossmean = 0.00185499004901549
FE = 0.07122091925396312
train_lossmean = 2.1084276559940916e-5
valid_lossmean = 0.0015963013892518053
FE = 0.066068466793681
train_lossmean = 2.1173184177816565e-5
valid_lossmean = 0.0016558647477751386
FE = 0.06728979619798706
train_lossmean = 2.1282625732178136e-5
valid_lossmean = 0.0016538443624231666
FE = 0.06724873221194663
train_lossmean = 1.999627664185319e-5
valid_lossmean = 0.0019933967835475633
FE = 0.07383013445362988
train_lossmean = 2.1576141138830242e-5
valid_lossmean = 0.0016596580962769773
FE = 0.06736682773929453
train_lossmean = 2.0784997451234733e-5
valid_lossmean = 0.0018178321980462135
FE = 0.07050398723825152
train_lossmean = 1.995651334422236e-5
valid_lossmean = 0.0014991865872878924
FE = 0.06402721722735231
train_lossmean = 2.0716006593936913e-5
valid_lossmean = 0.0018001532049543702
FE = 0.07016031282326102
train_lossmean = 2.0782498674371137e-5
valid_lossmean = 0.0015536366845318175
FE

In [16]:
dFE_1d(comm, commModel, commLeader, m, valid=test)

FE = 0.06909264800142227


0.06909264800142227

## Baseline result 

The fractional error on the test set is 0.0597, which is worse than the ensemble of 5 models (with an error of 4.7%).

# Overall result

With about 1000 data points, a single model PEDS model achieves an error < 5% and leads to a 1.8x improvement compared to the NN-only baseline. 