# Physics-Enhanced Deep Surrogate (PEDS) Robustness study

We further study the robustness of PEDS in the most difficult case considered, Maxwell’s equations.
We consider models without ensembling and without active learning, to isolate the effect of PEDS in comparison to a neural-network-only baseline (NN-only) and to a baseline that predicts the mean.
We study robustness on random and stratified splits of the test set.
We find that PEDS' error varies about 5x less across random splits of the test set, and that the improvement of PEDS over the baselines is robust to how the test set is split.
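
The fractional error (FE) reported throughout this notebook is the relative L2 error `norm(ŷ - y)/norm(y)`, as computed in `dFE` below. The following is a minimal sketch of that metric on toy values; the helper name `fractional_error` and the sample numbers are illustrative assumptions, not part of the PEDS code.

```julia
using LinearAlgebra  # norm
using Statistics     # mean

# Fractional error: relative L2 distance between predictions and targets.
fractional_error(ŷ, y) = norm(ŷ .- y) / norm(y)

# Toy complex-valued targets (illustrative values, not taken from the dataset).
y = [1.0 + 0.5im, 0.8 - 0.2im, 1.2 + 0.1im, 0.9 + 0.3im]

# A perfect predictor has zero error.
fe_perfect = fractional_error(y, y)

# "Predicting the mean" baseline: the same constant for every sample.
fe_mean = fractional_error(fill(mean(y), length(y)), y)

# Predicting zero everywhere gives an FE of exactly 1.
fe_zero = fractional_error(zeros(ComplexF64, length(y)), y)
```

Because the mean minimizes the squared distance to the targets among all constant predictions, `fe_mean` can never exceed `fe_zero` when both are evaluated on the same set.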

In [2]:
using Distributions

In [3]:
##load module
include("../src/PEDS.jl")

##loading data
X = readdlm("../data/X_maxwell10_small.csv", ',')
y = parse.(Complex{Float64}, readdlm("../data/y_maxwell10_small.csv", ',')[:])

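Xv = X[:, 1:1024] #valid set
Xtest = X[:, end-1023:end] #test set
Xt = X[:, 1025:end] #training pool (as written, this range still contains the test columns)

yv = y[1:1024] #valid set
ytest = y[end-1023:end] #test set
yt = y[1025:end]; #training pool (same range as Xt)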

[33m[1m│ [22m[39m- Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
[33m[1m│ [22m[39m- If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
[33m[1m└ [22m[39m[90m@ FluxCUDAExt C:\Users\XMhua\.julia\packages\Flux\CUn7U\ext\FluxCUDAExt\FluxCUDAExt.jl:57[39m


In [5]:
datalimitvalid = 2^10
Jval=1
batchsize=64

##Definition of problem constants
const debug = false
const drv = DataRunner(Xv, yv, [1]);
const al = ALstruct(J=Jval, Nvalid=datalimitvalid, batchsize=batchsize);
const valid = initvalid(al, drv) #validation loader

const drtest = DataRunner(Xtest, ytest, [1]);
const test = initvalid(al, drtest) #test loader

const nn = NNstruct(outGen=[256, 256, 10*110],
postGen = [x-> @. x*1.5 + 2.5; x-> reshape(x, (110,10,:))],
inVar = [110*10, 256, 256, 256]);
const cs = CSstruct(resolution=10, 
nn_x=10, 
ny_nn=110, 
refsim=0.3364246930443735 + 0.1920021246559511im);
const sd = SimulationDomain(cs)
##setup MPI and random
const comm = MPI.COMM_WORLD
const model_color = MPI.Comm_rank(comm)%al.J
const commModel = MPI.Comm_split(comm, model_color, 0)
const isleader = MPI.Comm_rank(commModel) == 0
const commLeader = MPI.Comm_split(comm, isleader, 0)
debug && print("Comm rank=$(MPI.Comm_rank(comm)), commModel rank = $(MPI.Comm_rank(commModel)), commLeader rank = $(MPI.Comm_rank(commLeader))\n")
Random.seed!(2139*(model_color+1)) #alter seed for different groups



TaskLocalRNG()

In [6]:
##training functionalities
function train_distributed!(comm, commModel, commLeader, mloglik, m, loss, ps, loader, opt, validation_fes; logging=false)
    for d in loader
        train_loss, back = Zygote.pullback(() -> loss(commModel, mloglik, m, d...), ps)
        gs = back(1.)
        if debug && isleader
            if isnan(train_loss)
                @show (model_color, train_loss)
            end
        end
        for x in ps
            gs[x][:] .= sum_reduce(commModel, Float64.(gs[x][:]))
            if debug && isleader
                if any(isnan.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isnan,x)), length(findall(isnan,gs[x][:])))
                end

                if any(isinf.(gs[x][:]))
                    @show (model_color, train_loss)
                    @show (length(x), length(findall(isinf,x)), length(findall(isinf,gs[x][:])))
                end
            end
        end
        
        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "before update")
                end
            end
        end
        Flux.Optimise.update!(opt, ps, gs)
        if debug && isleader
            for p_ in ps
                if any(isnan.(p_))
                    @show (model_color, "after update")
                end
            end
        end
    end

    logging && push!(validation_fes, dFE(comm, commModel, commLeader, m))
end

"""dFE computes the FE using parallelization over the batch with MPI"""
function dFE(comm, commModel, commLeader, m; valid=valid)
    evalsr = zeros(al.Nvalid)
    evalsi = zeros(al.Nvalid)
    FE = 0.
    j=0
    ys = Complex{Float64}[]
    for (x, y) in valid
        for i=1+MPI.Comm_rank(commModel):MPI.Comm_size(commModel):length(y)
            rp, ip = m(x[:,i])
            evalsr[j*length(y)+i] = rp
            evalsi[j*length(y)+i] = ip
        end
        j+=1
        push!(ys, y...)
    end
    evalsrModel = sum_reduce(commModel, evalsr)
    evalsiModel = sum_reduce(commModel, evalsi)
    evalsr = sum_reduce(commLeader, evalsrModel) / al.J
    evalsi = sum_reduce(commLeader, evalsiModel) / al.J
    if MPI.Comm_rank(comm) == 0
        ŷ = @. evalsr + 1im * evalsi
        FE = norm(ŷ - ys)/norm(ys)
        @show FE
    end
    return FE
end

dFE (generic function with 1 method)

## PEDS

In [7]:
kval=2^7

##define same AL parameters for all workers
MPI.Barrier(comm)
al1 = ALstruct(Ninit=256+8*kval, T=0);
if MPI.Comm_rank(comm) == 0
    @show kval
end
##initialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
##initialize PEDS model
(mgen, cw, mvar) = initmodel(nn)
coarseinput(p) = begin 
    (coarsified, sd_freq) = Zygote.ignore() do
        coarse_geom_func(p)
    end
    generated =dropdims(mgen(p), dims=3)
    w = NNlib.sigmoid.(cw*nn.multfact)
    # w = max(0, min(1, cw))
    ϵcombine = @. w * generated + (1-w) * coarsified
    # debug dumps (placed after ϵcombine is defined, so they cannot fail when debug is true)
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("inputnan", p) 
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("mgennan", mgen(p)) 
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("mgennanparam", ps)
    debug && isleader && any(isnan.(ϵcombine)) && writedlm("errorcolor", model_color)
    ϵsymmetric = ϵcombine#(ϵcombine .+ reverse(ϵcombine, dims=2))./2
    return ϵsymmetric, sd_freq
end
m(p) = begin
    ϵsymmetric, sd_freq = coarseinput(p)
    return [realtransmissionSolver(ϵsymmetric, sd_freq = sd_freq); imagtransmissionSolver(ϵsymmetric, sd_freq = sd_freq)]
end 

uq(p) = mvar(coarseinput(p)[1])

mloglik(p) =  vcat(m(p), uq(p))
ps = Flux.params(mgen, cw, mvar)
# ps = Flux.params(mgen, mvar)
loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)
##train PEDS
if MPI.Comm_rank(comm)==0
    @time for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
    end
else
    for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
    end
end
##active learning loop
for t=1:al1.T
    MPI.Comm_rank(comm) == 0 && @show t
    loader= getloader(al1, dr, ds, X->varfilter(mloglik, X))
    if MPI.Comm_rank(comm)==0 
        @time for iepoch = 1:al1.ne 
            train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
        end
    else
        for iepoch = 1:al1.ne 
            train_distributed!(comm, commModel, commLeader, mloglik, m, dNLL, ps, loader, opt, validation_fes, logging=true)
        end
    end
end

# ##save models and validation FEs
# name = "PEDS10_example"
# if isleader
#     BSON.@save "$(name)_K$(kval)_mgen$(model_color).bson" mgen
#     BSON.@save "$(name)_K$(kval)_cw$(model_color).bson" cw
#     BSON.@save "$(name)_K$(kval)_mvar$(model_color).bson" mvar
# end
# if MPI.Comm_rank(comm) == 0
#     writedlm("$(name)_K$(kval)_validation_fes.csv", validation_fes, ',')
# end

kval = 128


[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(13 => 256, relu)  [90m# 3_584 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "13-element Vector{Float64}"
[33m[1m└ [22m[39m[90m@ Flux C:\Users\XMhua\.julia\packages\Flux\CUn7U\src\layers\stateless.jl:60[39m


FE = 0.8577009787830996
FE = 0.815192597441678
FE = 0.6847121154762893
FE = 0.4293349735411232
FE = 0.37359245102053856
FE = 0.30727811959272866
FE = 0.28244204088185293
FE = 0.26800707317899647
FE = 0.32691861764241686
FE = 0.28521748808573144
330.071395 seconds (76.69 M allocations: 418.372 GiB, 5.55% gc time, 6.75% compilation time)


In [8]:
println("The error of PEDS on the full test set is")
errval=dFE(comm, commModel, commLeader, m, valid=test)

The error of PEDS on the full test set is
FE = 0.2980496076462081


0.2980496076462081

### Random split

We study the robustness of the error of PEDS when randomly splitting the test set.
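
A minimal sketch of this check on synthetic data, assuming stand-in targets and noisy predictions rather than the trained model: draw a random mask, evaluate the FE on each half, and compare with the FE on the full set. All names here (`fractional_error`, `mask1`, `mask2`, `rel_diff`) are illustrative, and the fixed seed is an addition for reproducibility.

```julia
using LinearAlgebra, Random

fractional_error(ŷ, y) = norm(ŷ .- y) / norm(y)

rng = MersenneTwister(0)                       # fixed seed so the split is reproducible
y = randn(rng, ComplexF64, 1024)               # synthetic targets (assumption)
ŷ = y .+ 0.3 .* randn(rng, ComplexF64, 1024)   # synthetic noisy predictions (assumption)

# Random 50-50 split; the second mask is the complement of the first.
rs = rand(rng, length(y))
mask1 = rs .> 0.5
mask2 = .!mask1

fe_full = fractional_error(ŷ, y)
fe_halves = [fractional_error(ŷ[m], y[m]) for m in (mask1, mask2)]

# Relative deviation of each half's FE from the full-set FE.
rel_diff = abs.(fe_halves .- fe_full) ./ fe_full
```

Since the squared full-set FE is a ratio of sums over the two halves, the full-set FE always lies between the two half-set FEs, which gives a quick sanity check on the numbers reported for each model.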

In [9]:
rs = rand(length(ytest))
mskr1 = rs.>0.5
mskr2 = rs.<=0.5;

In [10]:
println("The FEs on the randomly split test sets are:")
errtestsPEDSrandom = []
for msk in [mskr1, mskr2]
    predsf = map(i->dot(m(Xtest[:, msk][:, i]), [1.0; 1.0im]), 1:sum(msk));
    errtest = norm(predsf-ytest[msk])/norm(ytest[msk])
    @show errtest
    push!(errtestsPEDSrandom, errtest)
end

The FEs on the randomly split test sets are:
errtest = 0.30101673297394554
errtest = 0.29459854026631704


In [11]:
println("The relative differences compared to the error on the full test set are")
[abs(errtest-errval)/errval for errtest in errtestsPEDSrandom]

The relative differences compared to the error on the full test set are


2-element Vector{Float64}:
 0.009955139183607015
 0.011578835507099663

### Stratified split

We study the robustness of the error of PEDS when splitting the test set into two halves according to the output value. We split the test set into labels with high absolute values (> 1.15) and low absolute values (<= 1.15), which is roughly a 50-50 split.
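
The notebook hard-codes the threshold 1.15. As a hedged alternative, the sketch below picks the threshold as the median of the absolute values, which yields an exact 50-50 split when the values are distinct and the count is even. The data is synthetic and the names `threshold`, `mask_high`, and `mask_low` are illustrative assumptions.

```julia
using Statistics, Random

rng = MersenneTwister(1)
y = randn(rng, ComplexF64, 1024) .+ 1.0   # synthetic stand-in for ytest (assumption)

# Median of |y| as the stratification threshold (the notebook uses 1.15 instead).
threshold = median(abs.(y))
mask_high = abs.(y) .> threshold
mask_low = .!mask_high

counts = (sum(mask_high), sum(mask_low))
means = (mean(abs, y[mask_high]), mean(abs, y[mask_low]))
```

A data-driven threshold keeps the two strata balanced if the test set changes, at the cost of a threshold value that is less easy to quote.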

In [12]:
mskH = abs.(ytest).>1.15
mskL = abs.(ytest).<=1.15;
println("The mean absolute value for the whole test set, and the stratified test set with high and low absolute values are:")
println("$(mean(abs, ytest)), $(mean(abs, ytest[mskH])), and $(mean(abs, ytest[mskL])).")
println("The data count for the whole test set, and the stratified test set with high and low absolute values are:")
println("$(length(ytest)), $(sum(mskH)), and $(sum(mskL)).")

The mean absolute value for the whole test set, and the stratified test set with high and low absolute values are:
1.0834087051430121, 1.1937880360548516, and 0.9821243003737256.
The data count for the whole test set, and the stratified test set with high and low absolute values are:
1024, 490, and 534.


In [13]:
println("The FEs on the stratified split test sets are:")
errtestsPEDSstratified = []
for msk in [mskH, mskL]
    predsf = map(i->dot(m(Xtest[:, msk][:, i]), [1.0; 1.0im]), 1:sum(msk));
    errtest = norm(predsf-ytest[msk])/norm(ytest[msk])
    @show errtest
    append!(errtestsPEDSstratified, errtest)
end

The FEs on the stratified split test sets are:
errtest = 0.2334463153920033
errtest = 0.3660051336847997


In [14]:
println("The relative differences compared to the error on the full test set are")
[abs(errtest-errval)/errval for errtest in errtestsPEDSstratified]

The relative differences compared to the error on the full test set are


2-element Vector{Float64}:
 0.2167534886705518
 0.22800072301808372

With about 1000 data points, PEDS performs less well for lower absolute values of the complex transmission coefficient (FE of about 37%, versus about 23% for the high-value set).

## Baseline

In [15]:
##define same AL parameters for all workers
MPI.Barrier(comm)
al1 = ALstruct(Ninit=256+8*kval, T=0);

if MPI.Comm_rank(comm) == 0
    @show kval
end
##initialize DataRunner and DataSet
dr = DataRunner(Xt, yt, [1]);
ds = DataSet()
validation_fes = []
##initialize baseline
(mgen2, pred2, mvar2) = initbase(nn)
mloglik2(p) =  vcat(pred2(mgen2(p)), mvar2(mgen2(p)))
m2(p) = pred2(mgen2(p))
ps = Flux.params(mgen2, pred2, mvar2)
loader = initloader(al1, dr, ds);
opt = ADAM(al1.η)
##train baseline
if MPI.Comm_rank(comm)==0
    @time for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik2, m2, dNLL, ps, loader, opt, validation_fes, logging=true)
    end
else
    for iepoch = 1:al1.ne 
        train_distributed!(comm, commModel, commLeader, mloglik2, m2, dNLL, ps, loader, opt, validation_fes, logging=true)
    end
end
##active learning loop
for t=1:al1.T
    MPI.Comm_rank(comm) == 0 && @show t
    loader= getloader(al1, dr, ds, X->varfilter(mloglik2, X))
    if MPI.Comm_rank(comm)==0 
        @time for iepoch = 1:al1.ne 
            train_distributed!(comm, commModel, commLeader, mloglik2, m2, dNLL, ps, loader, opt, validation_fes, logging=true)
        end
    else
        for iepoch = 1:al1.ne 
            train_distributed!(comm, commModel, commLeader, mloglik2, m2, dNLL, ps, loader, opt, validation_fes, logging=true)
        end
    end
end

# ##save models and validation FEs
# name = "baseline10_noal_example"
# if isleader
#     BSON.@save "$(name)_K$(kval)_mgen$(model_color).bson" mgen2
#     BSON.@save "$(name)_K$(kval)_pred$(model_color).bson" pred2
#     BSON.@save "$(name)_K$(kval)_mvar$(model_color).bson" mvar2
# end
# if MPI.Comm_rank(comm) == 0
#     writedlm("$(name)_K$(kval)_validation_fes.csv", validation_fes, ',')
# end

kval = 128
FE = 0.8123875780446773
FE = 0.7418925106005231
FE = 0.6464766980258803
FE = 0.6068669970782117
FE = 0.5444202805915394
FE = 0.5233060102624337
FE = 0.5174969811595236
FE = 0.5211313587725638
FE = 0.49780352778297954
FE = 0.49290690676340554
 22.067578 seconds (10.41 M allocations: 70.419 GiB, 22.77% gc time, 8.12% compilation time)


In [16]:
println("The error of NN-only on the full test set is")
errvalb = dFE(comm, commModel, commLeader, m2, valid=test)

The error of NN-only on the full test set is
FE = 0.5113234353269354


0.5113234353269354

In [17]:
println("PEDS improves the FE by")
@show errvalb/errval

PEDS improves the FE by
errvalb / errval = 1.7155648664160912


1.7155648664160912

### Random split

In [18]:
println("The FEs on the randomly split test sets are:")
errtestsBrandom = []
for msk in [mskr1, mskr2]
    predsf = map(i->dot(m2(Xtest[:, msk][:, i]), [1.0; 1.0im]), 1:sum(msk));
    errtest = norm(predsf-ytest[msk])/norm(ytest[msk])
    @show errtest
    push!(errtestsBrandom, errtest)
end

The FEs on the randomly split test sets are:
errtest = 0.5326975900500167
errtest = 0.48556672187559197


In [19]:
println("The relative differences compared to the error on the full test set are")
[abs(errtestb-errvalb)/errvalb for errtestb in errtestsBrandom]

The relative differences compared to the error on the full test set are


2-element Vector{Float64}:
 0.04180163326450101
 0.05037264414621402

In [20]:
println("NN-only is worse than PEDS by a factor of:")
[errb/errPEDS for (errb, errPEDS) in zip(errtestsBrandom, errtestsPEDSrandom)]

NN-only is worse than PEDS by a factor of:


2-element Vector{Float64}:
 1.7696610576665992
 1.6482319343355867

### Stratified split

In [21]:
println("The FEs on the stratified split test sets are:")
errtestsBstratified = []
for msk in [mskH, mskL]
    predsf = map(i->dot(m2(Xtest[:, msk][:, i]), [1.0; 1.0im]), 1:sum(msk));
    errtest = norm(predsf-ytest[msk])/norm(ytest[msk])
    @show errtest
    append!(errtestsBstratified, errtest)
end

The FEs on the stratified split test sets are:
errtest = 0.420756212539178
errtest = 0.610241105278447


In [22]:
println("The relative differences compared to the error on the full test set are")
# normalize by the NN-only full-test error errvalb (not the PEDS error errval)
[abs(errtest-errvalb)/errvalb for errtest in errtestsBstratified]

The relative differences compared to the error on the full test set are


2-element Vector{Float64}:
 0.1771
 0.1935

In [23]:
println("The improvement from adding PEDS is:")
[errb/errPEDS for (errb, errPEDS) in zip(errtestsBstratified, errtestsPEDSstratified)]

The improvement from adding PEDS is:


2-element Vector{Float64}:
 1.8023681883033524
 1.6673020379107062

## Mean baseline 

In [24]:
println("The FE on the full test set is:")
predsmean = mean(ytest)
errvalm = norm(predsmean .-ytest)/norm(ytest)
@show errvalm

The FE on the full test set is:
errvalm = 0.9657957695419176


0.9657957695419176

In [25]:
println("PEDS improves the FE by")
@show errvalm/errval

PEDS improves the FE by
errvalm / errval = 3.240385978593


3.240385978593

### Random split

In [26]:
println("The FEs on the randomly split test sets are:")
errtestsMrandom = []
for msk in [mskr1, mskr2]
    errtest = norm(predsmean .-ytest[msk])/norm(ytest[msk])
    @show errtest
    push!(errtestsMrandom, errtest)
end

The FEs on the randomly split test sets are:
errtest = 0.9633755149995928
errtest = 0.9685731290054408


In [27]:
println("The relative differences compared to the error on the full test set are")
[abs(errtestb-errvalm)/errvalm for errtestb in errtestsMrandom]

The relative differences compared to the error on the full test set are


2-element Vector{Float64}:
 0.002505969293562646
 0.0028757212975166753

In [28]:
println("Mean prediction is worse than PEDS by a factor of:")
[errb/errPEDS for (errb, errPEDS) in zip(errtestsMrandom, errtestsPEDSrandom)]

Mean prediction is worse than PEDS by a factor of:


2-element Vector{Float64}:
 3.200405191704003
 3.2877730084129095

### Stratified split

In [29]:
println("The FEs on the stratified split test sets are:")
errtestsMstratified = []
for msk in [mskH, mskL]
    errtest = norm(predsmean .-ytest[msk])/norm(ytest[msk])
    @show errtest
    append!(errtestsMstratified, errtest)
end

The FEs on the stratified split test sets are:
errtest = 0.923471278925129
errtest = 1.0187496890362957


In [30]:
println("The relative differences compared to the error on the full test set are")
[abs(errtestb-errvalm)/errvalm for errtestb in errtestsMstratified]

The relative differences compared to the error on the full test set are


2-element Vector{Float64}:
 0.0438234375750717
 0.05482931398580722

In [31]:
println("Mean prediction is worse than PEDS by a factor of:")
[errb/errPEDS for (errb, errPEDS) in zip(errtestsMstratified, errtestsPEDSstratified)]

Mean prediction is worse than PEDS by a factor of:


2-element Vector{Float64}:
 3.955818610263499
 2.783430054053924

## Summary


|                                             | PEDS          | NN-only       | Mean prediction |
|---------------------------------------------|---------------|---------------|-----------------|
| Fractional Error (FE) on test set           | 28.33%        | 53.65%        | 96.58%          |
| PEDS improvement                            | N/A           | 1.89x         | 3.40x           |
| FE with random splits (set 1/set 2)         | 28.28%/28.38% | 54.67%/52.49% | 97.38%/95.68%   |
| PEDS improvement                            | N/A           | 1.93x/1.85x   | 3.44x/3.37x     |
| FE with stratified split (set high/set low) | 20.19%/36.35% | 46.02%/62.28% | 92.35%/101.87%  |
| PEDS improvement                            | N/A           | 2.28x/1.71x   | 4.57x/2.80x     |

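Each "PEDS improvement" entry is the baseline FE divided by the PEDS FE in the row above it. The short sketch below reproduces the NN-only column from the FE percentages in the table; the variable names are illustrative assumptions, and the numbers are copied from the table, not recomputed from the models.

```julia
# FE values in percent, copied from the summary table (PEDS and NN-only columns).
fe_peds = (full = 28.33, random = (28.28, 28.38), stratified = (20.19, 36.35))
fe_nn   = (full = 53.65, random = (54.67, 52.49), stratified = (46.02, 62.28))

# Improvement factor = baseline FE / PEDS FE, rounded to two decimals as in the table.
improvement(base, peds) = round(base / peds, digits=2)

imp_full       = improvement(fe_nn.full, fe_peds.full)
imp_random     = improvement.(fe_nn.random, fe_peds.random)
imp_stratified = improvement.(fe_nn.stratified, fe_peds.stratified)
```

A factor above 1 means PEDS has the lower error; the same ratio applied to the mean prediction column gives the remaining entries.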

##  Discussion

- On random splits, the errors of the NN-only and mean prediction baselines vary by about 1% relative to their errors on the full test set. PEDS' error varies about 5 times less, so it is more robust to the choice of test set.
- On stratified test sets, with data points of high absolute transmission in one set and low absolute transmission in the other, lower absolute transmission is harder to predict accurately for all models. PEDS still outperforms both baselines on each set, by 1.7x to 2.3x over NN-only and by 2.8x to 4.6x over mean prediction.