# PEDS example: Reaction--diffusion surrogate using a diffusion equation coarse solver

This example of PEDS shows how to use the code from: *Pestourie, Raphaël, et al. "Physics-enhanced deep surrogates for PDEs."*  In this example notebook, we illustrate PEDS for the surrogate of the reaction--diffusion equation's flux where the approximate solver is a coarse diffusion equation's solver. For this example, the input is the side length of 16 holes in a material.

NB: since the approximate solver is the same for the reaction-diffusion and the diffusion surrogates, training the surrogate models with the 25 holes structures for the reaction-diffusion equation necessitates only to change the dataset from `X/y_fisher16_small.csv` to `X/y_fourier16_small.csv`.
    
## Note on implementation

In this notebook, we run the training on a single processor and J=1.

However, the implementation of this code is meant to be parallel on multiple processors, where the element of a batch is computed in parallel over J groups of CPUs (J is the number models in the ensemble). For example, with a julia file `train_PEDS.jl`containing functions calls similart to the code in the cells below, the command to run large-scale training on a cluster would be:
    
`mpiexec -n 320 julia train_PEDS.jl` 

In [8]:
#initialize functions

include("../src/PEDS.jl")
include("../src/fourier_solver.jl")

##training functionalities
function train_distributed!(comm, commModel, commLeader, mloglik, m, loss, ps, loader, opt, validation_fes; logging=false)

    train_lossmean = 0.
    k = 0
    for d in loader
        X,y = d #change for Fourier
        train_loss, back = Zygote.pullback(() -> loss(commModel, mloglik, m, X, y), ps)
        train_lossmean *= k
        k+=1
        train_lossmean += train_loss
        train_lossmean /= k

        gs = back(Float32(1.))
        if debug && isleader
            if isnan(train_loss)
                @show (model_color, train_loss)
            end
        end
        for x in ps
            gs[x][:] .= (sum_reduce(commModel, (gs[x][:]) ))
            if debug && isleader
                if any(isnan.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isnan,x)), length(findall(isnan,gs[x][:])))
                end

                if any(isinf.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isinf,x)), length(findall(isinf,gs[x][:])))
                end
            end
        end

        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "before update")
                end
            end
        end
        Flux.Optimise.update!(opt, ps, gs)
        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "after update")
                end
            end
        end
    end

    # logging && push!(validation_fes, dFE(comm, commModel, commLeader, m))
    #changing for fourier
    if logging
        if MPI.Comm_rank(comm) == 0
            @show train_lossmean
            push!(validation_fes, train_lossmean)
        end
        MPI.Barrier(comm)
        valid_lossMSE = lossvalid(comm, commModel, commLeader, mloglik, m)
        valid_lossFE = dFE_1d(comm, commModel, commLeader, m)
        if MPI.Comm_rank(comm) == 0
            push!(validation_fes, valid_lossMSE)
            push!(validation_fes, valid_lossFE)
        end
    end
end

function pore(p, resolution; lowval=0.1)

    if resolution % 2==0
        inner = p/2*resolution ÷ 1
        w =  (p*resolution)/2 % 1

        isqb = Int((resolution/2+inner) ÷ 1) # square index to the right
        isqa = Int((resolution/2-inner) ÷ 1) # square index to the left
    else
        inner = p/2*resolution ÷ 1
        w = ((p*resolution)/2 % 1)
        w*= resolution ==1 ? 2 : 1
        isqb = Int((resolution/2+inner) ÷ 1) # square index to the right
        isqa = Int((resolution/2-inner) ÷ 1)+1 # square index to the left
    end

    begin
        A1 = ones((resolution, resolution))
        A2 = ones((resolution, resolution))
        A1[isqa+1:isqb, isqa+1:isqb] .= lowval

        if isqa==0 && w==0
            return A1
        elseif isqa==0 && w!=0
            return A1 .* (1-w)
        else
            A2[isqa:isqb+1, isqa:isqb+1] .= lowval
            return A1 .* (1-w) .+ A2 .* w
        end
    end

end

function generatepores(ps, resolution)
    n = sqrt(length(ps))
    N = resolution÷n
    @assert n == Int(n)
    @assert N == Int(N)
    n = Int(n)
    N = Int(N)
    pores = zeros((resolution, resolution))
    for (k,p) in enumerate(ps)
        i = (k-1)÷n
        j = (k-1)%n
        pores[i*N+1:(i+1)*N, j*N+1:(j+1)*N] = pore(p, N)
    end
    return pores
end

ChainRules.@non_differentiable generatepores(ps, resolution)



In [9]:
startp, endp = 7, 7
PEDSname = "PEDS"
equation = "exampleFisher"

ninit = 64
Jval = 1

cwval = 0.05
learningrate = 5e-5
nnodes =128

optimizerfunc, optname = Flux.ADAM, "ADAM"

number_epoch = 200
batchsize = 64

datalimitvalid = 2^10

Ks = Int[2^i for i=startp:endp]

coarseresolution = 4
Lx = Ly = 1
nin=Int(coarseresolution^2)
coarse_sim = setsimulation(Lx, Ly, coarseresolution);
nout = coarseresolution^2
const avgOpcoarse = create_cOp(coarseresolution);

endname = "$(equation)_$(2^startp)to$(2^endp)_ne$(number_epoch)_lr$(learningrate)_bs$(batchsize)_opt$(optname)_hn$(nnodes)_coarse$(coarseresolution)_seed";

##loading data
X = readdlm("../data/X_fisher$(coarseresolution^2)_small.csv", ',')
y = readdlm("../data/y_fisher$(coarseresolution^2)_small.csv", ',')[:];

Xv = X[:, 1:1024] #valid set
Xtest = X[:, end-1023:end] #test set
Xt = X[:, 1025:end]

yv = y[1:1024] #valid set
ytest = y[end-1023:end] #test set
yt = y[1025:end];

In [10]:
##Definition of problem constants
const debug = false
const drv = DataRunner(Xv, yv, [1]);
const al = ALstruct(J=Jval, Nvalid=datalimitvalid, batchsize=batchsize);
const valid = initvalid(al, drv) #validation loader

const drtest = DataRunner(Xtest, ytest, [1]);
const test = initvalid(al, drtest) #validation loader

##setup MPI and random
const comm = MPI.COMM_WORLD
const model_color = MPI.Comm_rank(comm)%al.J
const commModel = MPI.Comm_split(comm, model_color, 0)
const isleader = MPI.Comm_rank(commModel) == 0
const commLeader = MPI.Comm_split(comm, isleader, 0)
debug && print("Comm rank=$(MPI.Comm_rank(comm)), commModel rank = $(MPI.Comm_rank(commModel)), commLeader rank = $(MPI.Comm_rank(commLeader))\n")

Random.seed!(2138*(model_color+1)) #alter seed for different groups

function lossvalid(comm, commModel, commLeader, mloglik, m)
    valid_lossmean = 0.
    k = 0
    testmode!(m, true)
    for d in valid
        X,y=d
        valid_loss = dMSE(commModel, mloglik, m, X, y)

        valid_lossmean *= k
        k+=1
        valid_lossmean += valid_loss
        valid_lossmean /= k
    end
    testmode!(m, false)
    # MPI.Comm_rank(commModel)== 0 && @show valid_lossmean
    valid_lossmean = sum_reduce(commLeader, valid_lossmean) / al.J
    MPI.Comm_rank(comm) == 0 && @show valid_lossmean
    return valid_lossmean
end

function dFE_1d(comm, commModel, commLeader, m; valid=valid) """dFE computes the FE using parallelization over the batch with MPI"""
    evalsr = zeros(al.Nvalid)
    FE = 0.
    j=0
    ys = []
    testmode!(m, true)
    for (x, y) in valid
                rp = m(x[:,1+MPI.Comm_rank(commModel):MPI.Comm_size(commModel):length(y)])[:]
        evalsr[j*length(y)+1+MPI.Comm_rank(commModel):MPI.Comm_size(commModel):j*length(y)+length(y)] .= rp
        push!(ys, y...)
        j+=1
    end
    testmode!(m, false)
    evalsrModel = sum_reduce(commModel, evalsr)
    evalsr = sum_reduce(commLeader, evalsrModel) / al.J
    if MPI.Comm_rank(comm) == 0
        FE = norm(evalsr - ys)/norm(ys)
        MPI.Comm_rank(comm) == 0 && @show FE
    end
    return FE
end

dFE_1d (generic function with 1 method)

# PEDS

In [11]:
namemodel = "local_MSE_parallel_PEDSparameterized_$endname"

kval =  Ks[1]
##define same AL parameters for all workers
al1 = ALstruct(Ninit=ninit+8*kval, T=0, η=1e-3, J=Jval, ne=number_epoch, batchsize=batchsize);
if MPI.Comm_rank(comm) == 0
    @show kval
    @show 256+8*kval
end
##ititialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
mgen =  Chain(Dense(nin, nnodes, relu),
Dropout(0.5),
Dense(nnodes, nnodes,relu),
Dropout(0.5),
Dense(nnodes, coarseresolution^2, hardtanh),
x-> @. x*0.9/2 + 0.45 + .1)

cw=[cwval]

m(p) = begin
    generatedgeom = mgen(p)
    w = @. max(0, min(1, cw)) # harder function.
    @views map(i->Float32(targetfunc(avgOpcoarse * (w .* generatedgeom[:,i] + (1 .-w) .* generatepores(p[:, i],coarseresolution)[:]), sim=coarse_sim)), 1:length(p)÷ nin)
end
mloglik(p) = 1
ps = Flux.params(mgen, cw)


loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)

if MPI.Comm_rank(comm)==0
    @time for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
else
    for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
end


# ##save models and validation FEs
# if isleader
#     BSON.@save "$(namemodel)_K$(kval)_mgen$(model_color).bson" mgen
#     BSON.@save "$(namemodel)_K$(kval)_cw$(model_color).bson" cw
# end

# if MPI.Comm_rank(comm) == 0
#     writedlm("$(namemodel)_K$(kval)_trainingMSE.csv", validation_fes, ',')
# end

# # evaluate and show the error on the test set
# valid_lossmean = dFE_1d(comm, commModel, commLeader, m)
# if MPI.Comm_rank(comm) == 0
#     @show valid_lossmean
#     writedlm("$(namemodel)_K$(kval)_endvalidFE.csv", [valid_lossmean], ',')
# end

kval = 128
256 + 8kval = 1280


[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(16 => 128, relu)  [90m# 2_176 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "16×64 Matrix{Float64}"
[33m[1m└ [22m[39m[90m@ Flux C:\Users\XMhua\.julia\packages\Flux\CUn7U\src\layers\stateless.jl:60[39m


train_lossmean = 7.021580160312698e-5
valid_lossmean = 0.006008534163016794
FE = 0.1590898464065475
train_lossmean = 6.343930880590507e-5
valid_lossmean = 0.005388712722646037
FE = 0.1506609532464366
train_lossmean = 6.073727535759983e-5
valid_lossmean = 0.005193160472614931
FE = 0.14790200674054596
train_lossmean = 5.984294943382934e-5
valid_lossmean = 0.00507337551979387
FE = 0.14618630857488982
train_lossmean = 5.881582051025826e-5
valid_lossmean = 0.004932900903069938
FE = 0.14414825542847862
train_lossmean = 5.766853779446097e-5
valid_lossmean = 0.004404720305487411
FE = 0.13621262528719955
train_lossmean = 5.4972617689740483e-5
valid_lossmean = 0.0036726103566856487
FE = 0.1243785880670095
train_lossmean = 5.062316556345056e-5
valid_lossmean = 0.002970007471283161
FE = 0.11185025178143426
train_lossmean = 4.53345106193638e-5
valid_lossmean = 0.0023007961384663325
FE = 0.09844581643782402
train_lossmean = 4.007355397117606e-5
valid_lossmean = 0.002021068690385755
FE = 0.0922674954

FE = 0.04697842306284541
train_lossmean = 2.031708815978466e-5
valid_lossmean = 0.0005126656835754538
FE = 0.046470283194049306
train_lossmean = 2.0558743717219942e-5
valid_lossmean = 0.0004735862200512652
FE = 0.04466401120109804
train_lossmean = 1.9842549250704347e-5
valid_lossmean = 0.00047750983662163404
FE = 0.04484864809719326
train_lossmean = 2.0816670520071465e-5
valid_lossmean = 0.00043920172025215795
FE = 0.04301205687336992
train_lossmean = 1.962913723309387e-5
valid_lossmean = 0.00043175516730146644
FE = 0.04264586886451717
train_lossmean = 2.004006269008572e-5
valid_lossmean = 0.0004904935754397748
FE = 0.04545428773497227
train_lossmean = 2.03281925021446e-5
valid_lossmean = 0.0005717560592315567
FE = 0.04907537004732151
train_lossmean = 2.0036339737253605e-5
valid_lossmean = 0.0005485757363012857
FE = 0.04807026236517016
train_lossmean = 2.04534740004245e-5
valid_lossmean = 0.0005282339804004833
FE = 0.04717059590871703
train_lossmean = 2.0232054677847735e-5
valid_lossme

FE = 0.039766216275705306
train_lossmean = 1.706053388331149e-5
valid_lossmean = 0.00033635487918320543
FE = 0.037640645377522715
train_lossmean = 1.6769591198441322e-5
valid_lossmean = 0.0003264789011504849
FE = 0.037083930216707806
train_lossmean = 1.6916289821128585e-5
valid_lossmean = 0.0003057387437866761
FE = 0.035886692513839226
train_lossmean = 1.6577596330912887e-5
valid_lossmean = 0.00030283045828460474
FE = 0.03571560177950021
train_lossmean = 1.6677859354442085e-5
valid_lossmean = 0.0003251333137242758
FE = 0.03700743033904881
train_lossmean = 1.6366124193373574e-5
valid_lossmean = 0.0003213265754372833
FE = 0.03679014655891946
train_lossmean = 1.6818932637922823e-5
valid_lossmean = 0.0003102249284112057
FE = 0.036149021130968455
train_lossmean = 1.5676108145620785e-5
valid_lossmean = 0.0003343751126704042
FE = 0.037529706496493904
train_lossmean = 1.6033735419035667e-5
valid_lossmean = 0.000385258880433555
FE = 0.04028417966849981
train_lossmean = 1.591927356730325e-5
vali

In [12]:
dFE_1d(comm, commModel, commLeader, m, valid=test)

FE = 0.042173400393197905


0.042173400393197905

## PEDS result

The fractional error on the test set is 0.0481.

# Baseline 

In [13]:
namemodel = "local_MSE_parallel_baselineparameterized_$endname"

##define same AL parameters for all workers
al1 = ALstruct(Ninit=ninit+8*kval, T=0, η=1e-3, J=Jval, ne=number_epoch, batchsize=batchsize);
if MPI.Comm_rank(comm) == 0
    @show kval
    @show 256+8*kval
end
##ititialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
# dr = DataRunner(generateX(Xt), yt, [1]);
ds = DataSet()
validation_fes = []
##initialize baseline
mgen =  Chain(Dense(nin, nnodes, relu),
Dropout(0.5),
#BatchNorm(256, relu),
Dense(nnodes, nnodes, relu),
Dropout(0.5),
#BatchNorm(256, relu),
Dense(nnodes, coarseresolution^2, relu),
#BatchNorm(coarseresolution^2, hardtanh),
Dense(coarseresolution^2, 1, hardtanh),
x-> @. x[:]*0.9/2 + 0.45 + .1)

cw=[cwval]

m(p) = begin
    mgen(p)
end
mloglik(p) = 1
ps = Flux.params(mgen)


loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)

if MPI.Comm_rank(comm)==0
    @time for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
else
    for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dHuber, ps, loader, opt, validation_fes, logging=true)
    end
end


# ##save models and validation FEs
# if isleader
#     BSON.@save "$(namemodel)_K$(kval)_mgen$(model_color).bson" mgen
# end

# if MPI.Comm_rank(comm) == 0
#     writedlm("$(namemodel)_K$(kval)_trainingMSE.csv", validation_fes, ',')
# end

# # evaluate and show the error on the test set
# valid_lossmean = dFE_1d(comm, commModel, commLeader, m)
# if MPI.Comm_rank(comm) == 0
#     @show valid_lossmean
#     writedlm("$(namemodel)_K$(kval)_endvalidFE.csv", [valid_lossmean], ',')
# end



kval = 128
256 + 8kval = 1280
train_lossmean = 0.00010173447728370769
valid_lossmean = 0.010955924627230315
FE = 0.21482395756305167
train_lossmean = 8.159521361515487e-5
valid_lossmean = 0.010607992879124034
FE = 0.21138531084184478
train_lossmean = 7.828849764529866e-5
valid_lossmean = 0.009230786295051112
FE = 0.19718666570905968
train_lossmean = 7.333193996959793e-5
valid_lossmean = 0.008062804025426005
FE = 0.18428977810593444
train_lossmean = 7.290222037833497e-5
valid_lossmean = 0.008826443321329798
FE = 0.19281954935437323
train_lossmean = 7.070507131710219e-5
valid_lossmean = 0.006554441031684627
FE = 0.16615982473976418
train_lossmean = 6.913548821723271e-5
valid_lossmean = 0.007567916033414755
FE = 0.17854444702561276
train_lossmean = 6.736590455398741e-5
valid_lossmean = 0.005790396373718222
FE = 0.156175297545612
train_lossmean = 6.502502727095147e-5
valid_lossmean = 0.007349706590266956
FE = 0.1759515901377205
train_lossmean = 6.609411773297293e-5
valid_lossmean = 0.00613

train_lossmean = 2.889738477663557e-5
valid_lossmean = 0.0037969774829100414
FE = 0.12646699693059962
train_lossmean = 3.01059562912825e-5
valid_lossmean = 0.0034303427284033245
FE = 0.12020622305932997
train_lossmean = 3.0068013157002086e-5
valid_lossmean = 0.0039051054482113147
FE = 0.12825508058377783
train_lossmean = 2.939760083790726e-5
valid_lossmean = 0.003984150792021591
FE = 0.12954661775221682
train_lossmean = 2.8232493659532042e-5
valid_lossmean = 0.0044123885285763985
FE = 0.1363311407117972
train_lossmean = 2.9666223234239003e-5
valid_lossmean = 0.0041667564328170025
FE = 0.1324821149017107
train_lossmean = 2.9460783033015645e-5
valid_lossmean = 0.003575263081318708
FE = 0.12271911009788689
train_lossmean = 2.920291663941632e-5
valid_lossmean = 0.0036102465090528746
FE = 0.12331804284282019
train_lossmean = 2.797753756212576e-5
valid_lossmean = 0.0037085739138738305
FE = 0.1249860850199087
train_lossmean = 2.7827573342119755e-5
valid_lossmean = 0.0034371638067364914
FE = 0

valid_lossmean = 0.0036754921448356264
FE = 0.12442737658812826
train_lossmean = 2.2577467907007598e-5
valid_lossmean = 0.0034124820680793183
FE = 0.11989287762722511
train_lossmean = 2.389906187551235e-5
valid_lossmean = 0.0030568233810049457
FE = 0.11347321725850378
train_lossmean = 2.2660685378129457e-5
valid_lossmean = 0.0037204990953716727
FE = 0.12518687453069782
train_lossmean = 2.3139274029315105e-5
valid_lossmean = 0.003662595735547559
FE = 0.12420889204471669
train_lossmean = 2.2558524683629467e-5
valid_lossmean = 0.003177218956201368
FE = 0.11568625642468193
train_lossmean = 2.200490697034127e-5
valid_lossmean = 0.003145303253296956
FE = 0.11510374585169433
train_lossmean = 2.1913316966714128e-5
valid_lossmean = 0.0034604782826052156
FE = 0.12073307421088868
train_lossmean = 2.18780156521658e-5
valid_lossmean = 0.0038529069808496564
FE = 0.12739502173008543
train_lossmean = 2.180222177719944e-5
valid_lossmean = 0.003305356900731556
FE = 0.1179960244262761
train_lossmean = 2.

In [14]:
dFE_1d(comm, commModel, commLeader, m, valid=test)

FE = 0.1263928144138003


0.1263928144138003

## Baseline result 

The fractional error on the test set is 0.13.

# Overall result

With about 1000 data points, a single model PEDS model achieves an error < 5% and leads to a 2.7x improvement compared to the NN-only baseline. 